In [0]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell is executed.
%load_ext autoreload
%autoreload 2

In [0]:
import os
import pandas as pd

# data loading and preprocessing

**load the taxa abundance table**

In [0]:
# Read the genus-level table through Spark, convert it to pandas and index it
# by the 'barcode' column.
df_genus = (
    spark.read.table("onesource_eu_dev_rni.onebiome.mpa4_genus_level")
    .toPandas()
    .set_index('barcode')
)
print(df_genus.shape)

In [0]:
# Display the loaded frame (rows indexed by barcode).
df_genus

In [0]:
# Lower-case every column label, then keep only the text after the last '|'.
df_genus.columns = [label.split('|')[-1] for label in df_genus.columns.str.lower()]


In [0]:
# Show the column labels that do not carry the 'g__' prefix.
print(list(filter(lambda label: not label.startswith('g__'), df_genus.columns)))


In [0]:
# Remove the leading 'g__' rank prefix from each column label.
# The prefix is stripped only when it is actually present: an unconditional
# `[3:]` slice would chop three characters off labels that lack the prefix and
# would keep eating characters each time this cell is re-run.
df_genus.columns = [
    label[3:] if label.startswith('g__') else label
    for label in (col.split('|')[-1].lower() for col in df_genus.columns)
]
df_genus.head()

In [0]:
# Collect column names that are identical except for their final character.
similar_cols = []

columns = df_genus.columns

for pos, first in enumerate(columns):
    if not first:
        # Empty label: print it and skip it.
        print(first)
        continue
    for second in columns[pos + 1:]:
        if not second:
            print(second)
            continue
        same_stem = first[:-1] == second[:-1]
        if same_stem and first[-1] != second[-1]:
            similar_cols += [first, second]



In [0]:
import seaborn as sns

# De-duplicate and sort the flagged names, then plot those columns as a heatmap.
similar_cols = sorted({*similar_cols})
df_genus_similar = df_genus.loc[:, similar_cols]
sns.heatmap(df_genus_similar)

**generate the taxa sequences**

In [0]:
# For every row, keep the columns whose value is not 0 and order their names
# from the largest value to the smallest.
list_sequences = [
    row[row != 0].sort_values(ascending=False).index.tolist()
    for _, row in df_genus.iterrows()
]


**check the sequence length distribution**

In [0]:
import matplotlib.pyplot as plt

# Length of every sequence, computed once and reused for the maximum and the plot.
seq_lengths = [len(seq) for seq in list_sequences]
MAX_SEQ_LENGTH = max(seq_lengths)
print(MAX_SEQ_LENGTH)

plt.hist(seq_lengths, bins=100)
plt.title("Sequence Length Distribution in Dicaprio MPA4")
plt.xlabel("Sequence Length")
plt.ylabel("Frequency");

In [0]:

# One space-separated string per sample.
sequences = list(map(" ".join, list_sequences))

In [0]:
# Spot-check a single joined sequence; index 2222 is hard-coded and raises
# IndexError when fewer sequences exist.
sequences[2222]

In [0]:
# Number of joined sequences (one per row of df_genus).
len(sequences)

In [0]:
# Persist the sequences, one per line.
with open('../data/dicaprio_genus_sequences.txt', 'w') as out_file:
    out_file.writelines(f"{sequence}\n" for sequence in sequences)

# Extend the taxa id mapper
Extend the token-to-id mapping that was created from curatedMetagenomics so that it also covers the taxa in this dataset.


In [0]:
# Sorted column labels of df_genus, used below as the taxa vocabulary.
UNIQUE_TAXA = df_genus.columns.sort_values()
UNIQUE_TAXA

In [0]:
import json

# Read the token -> id mapping stored in genus_token_to_id_cmg.json.
with open('../data/genus_token_to_id_cmg.json', 'r') as mapping_file:
    token_to_id = json.loads(mapping_file.read())

In [0]:
import json

# Extend the mapping loaded from genus_token_to_id_cmg.json instead of
# replacing it: tokens that already have an id keep it, and taxa from
# UNIQUE_TAXA that are missing receive the next free ids.
# With an empty mapping the first id is 4, matching the previous `idx + 4`
# offset (presumably ids 0-3 are reserved for special tokens — TODO confirm).
# NOTE(review): assumes the loaded ids are integers — verify against the JSON file.
first_free_id = max(token_to_id.values(), default=3) + 1
new_tokens = [token for token in UNIQUE_TAXA if token not in token_to_id]
token_to_id = {
    **token_to_id,
    **{token: first_free_id + offset for offset, token in enumerate(new_tokens)},
}

with open('../data/token_to_id.json', 'w') as file:
    json.dump(token_to_id, file)

In [0]:
# Build the tokenizer from the joined sequences and the sorted taxa labels.
# NOTE(review): `create_simple_mapping_tokenizer` is not defined or imported in
# the visible cells — confirm where it comes from.
tokenizer = create_simple_mapping_tokenizer(sequences, UNIQUE_TAXA) 

In [0]:
# Encode the first sample's ranked taxa list, requesting 'pt' tensors.
tokenizer.encode(list_sequences[0], return_tensors='pt')

In [0]:
# Number of dimensions of the encoded 'input_ids' tensor for the first sample.
tokenizer.encode(list_sequences[0], return_tensors='pt')['input_ids'].dim()

In [0]:
# custom dataset class
class TaxaSequenceDataset(Dataset):
    """Dataset of tokenizer-encoded taxa sequences.

    Every sequence is encoded once, at construction time, and the encodings
    are kept in memory in ``self.inputs``.
    """

    def __init__(self, sequences, tokenizer, max_length=128):
        """
        Args:
            sequences: List of sequences, where each sequence is a list of string tokens
            tokenizer: Simple tokenizer instance; its ``encode`` must return a
                mapping holding "input_ids" and "attention_mask" tensors
            max_length: Maximum sequence length, forwarded to the tokenizer
                for padding and truncation
        """
        self.inputs = [
            tokenizer.encode(
                seq,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            for seq in sequences
        ]

    def __len__(self):
        """Return the number of encoded sequences."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the encoding at ``idx`` with both tensors flattened to 1D."""
        encoded = self.inputs[idx]
        item = {}
        for key in ("input_ids", "attention_mask"):
            tensor = encoded[key]
            # Tensors with more than one dimension are flattened.
            item[key] = tensor.view(-1) if tensor.dim() > 1 else tensor
        return item

# configure and initialize the model

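Model chosen: GPT-2 (the `gpt2` architecture, GPT-2 Small family), instantiated below from a custom, reduced-size configuration rather than from pretrained weights.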

- It is autoregressive, which suits the generation task (taxa completion).
- It is practical for extracting embeddings for downstream supervised learning tasks.
- It handles variable sequence lengths well.
- With only 355 unique values (taxa), vocabulary size is not a concern.

In [0]:
def create_gpt2_model(vocab_size):
    """Create a GPT-2 model with custom vocab size.

    The model is built from a fresh ``GPT2Config``; no pretrained weights are
    loaded here. The context length comes from the notebook-level
    ``MAX_SEQ_LENGTH``.

    Args:
        vocab_size: Number of entries in the token vocabulary.

    Returns:
        A ``GPT2LMHeadModel`` constructed from the config.
    """
    # Special-token ids are placed relative to the end of the vocabulary.
    # NOTE(review): confirm these match the ids the tokenizer actually uses.
    special_token_ids = {
        "pad_token_id": vocab_size - 4,  # <pad>
        "bos_token_id": vocab_size - 3,  # <s>
        "eos_token_id": vocab_size - 2,  # </s>
    }
    # All dropout probabilities are set to zero.
    dropout = {"attn_pdrop": 0.0, "embd_pdrop": 0.0, "resid_pdrop": 0.0}

    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=MAX_SEQ_LENGTH,
        n_ctx=MAX_SEQ_LENGTH,
        n_embd=64,  # embedding size
        n_layer=6,  # number of transformer layers
        n_head=8,   # attention heads per layer
        **special_token_ids,
        **dropout,
    )
    return GPT2LMHeadModel(config)

# Build the model sized to the tokenizer's vocabulary and move it to `device`.
# NOTE(review): `device` is not defined in the visible cells — confirm it is set earlier.
model = create_gpt2_model(tokenizer.vocab_size).to(device)

In [0]:
# Device of the model's first parameter.
next(model.parameters()).device

Create a `SimpleDataCollator` class to replace the HF `DataCollatorForLanguageModeling`; it handles padding and creates the labels needed for causal language modeling.

In [0]:

class SimpleDataCollator:
    """Simple data collator for causal language modeling.

    Pads every example of a batch to the length of the longest example,
    stacks the results, and adds ``labels``: a copy of ``input_ids`` in which
    positions equal to the pad token id are set to -100.

    Args:
        tokenizer: Object exposing ``pad_token_id``.
        mlm: Stored on the instance; not used by this collator.
        debug: When True, print shapes, devices and sample values of every
            batch. Off by default so that training output is not flooded.
    """

    def __init__(self, tokenizer, mlm=False, debug=False):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.debug = debug

    def __call__(self, features):
        """Collate a list of {"input_ids", "attention_mask"} examples.

        Args:
            features: Non-empty list of dicts holding 1D tensors.

        Returns:
            Dict with stacked "input_ids", "attention_mask" and "labels".

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("SimpleDataCollator received an empty batch")

        pad_id = self.tokenizer.pad_token_id
        max_len = max(len(feature["input_ids"]) for feature in features)

        padded_input_ids = []
        padded_attention_mask = []
        for feature in features:
            ids = feature["input_ids"]
            mask = feature["attention_mask"]
            padding_len = max_len - len(ids)
            if padding_len > 0:
                # new_full / new_zeros keep the dtype and device of the inputs.
                ids = torch.cat([ids, ids.new_full((padding_len,), pad_id)])
                mask = torch.cat([mask, mask.new_zeros(padding_len)])
            padded_input_ids.append(ids)
            padded_attention_mask.append(mask)

        # Stack into batches
        batch = {
            "input_ids": torch.stack(padded_input_ids),
            "attention_mask": torch.stack(padded_attention_mask),
        }

        # For causal language modeling the labels mirror the inputs;
        # padding is marked as -100 so it is ignored in the loss calculation.
        labels = batch["input_ids"].clone()
        labels[batch["input_ids"] == pad_id] = -100
        batch["labels"] = labels

        if self.debug:
            self._debug_batch(batch)

        return batch

    @staticmethod
    def _debug_batch(batch):
        """Print diagnostic information about a collated batch."""
        print("Input IDs shape:", batch["input_ids"].shape)
        print("Attention mask shape:", batch["attention_mask"].shape)
        print("Labels shape:", batch["labels"].shape)
        # Check if all tensors are on the same device
        print("Input IDs device:", batch["input_ids"].device)
        print("Attention mask device:", batch["attention_mask"].device)
        print("Labels device:", batch["labels"].device)
        # Check for any NaN values
        print("Any NaN in input_ids:", torch.isnan(batch["input_ids"]).any())
        print("Any NaN in attention_mask:", torch.isnan(batch["attention_mask"]).any())
        print("Any NaN in labels:", torch.isnan(batch["labels"]).any())
        # Print some values
        print("First sequence input_ids:", batch["input_ids"][0][:10])
        print("First sequence attention_mask:", batch["attention_mask"][0][:10])
        print("First sequence labels:", batch["labels"][0][:10])

# training

In [0]:
output_dir = "./gpt2_taxa_seq_model"

# Encode both splits with the same tokenizer.
# NOTE(review): `train_sequences` / `test_sequences` are not defined in the
# visible cells — confirm where the split is created.
train_dataset, eval_dataset = (
    TaxaSequenceDataset(split, tokenizer)
    for split in (train_sequences, test_sequences)
)

In [0]:
# Collator used to assemble batches (mlm flag off).
data_collator = SimpleDataCollator(tokenizer=tokenizer, mlm=False)

# Set up training arguments
training_args = TrainingArguments(
    # where checkpoints and logs go
    output_dir=output_dir,
    overwrite_output_dir=True,
    logging_dir=f"{output_dir}/logs",
    # schedule and batch sizes
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # runtime options
    fp16=True,
    dataloader_pin_memory=True,
    full_determinism=False,
    # evaluation and checkpointing depend on whether the eval dataset is non-empty
    evaluation_strategy="epoch" if eval_dataset else "no",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=bool(eval_dataset),
)

In [0]:
import os

# CUDA debugging switches.
# NOTE(review): presumably these only help when set before CUDA is
# initialised — confirm they are still needed at this point.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1", "TORCH_USE_CUDA_DSA": "1"})

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# Save the trained model and tokenizer
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [0]:

# 3. Training function (accept both training and evaluation datasets)
def train_model(model, tokenizer, train_sequences, eval_sequences=None,
                output_dir="./gpt2_taxa_seq_model", num_train_epochs=3, batch_size=8):
    """Train ``model`` on taxa sequences and save the model and tokenizer.

    Args:
        model: Model handed to ``Trainer``.
        tokenizer: Tokenizer used to encode the sequences and build the collator.
        train_sequences: Sequences used for training.
        eval_sequences: Optional sequences used for per-epoch evaluation.
            ``None`` or an empty collection disables evaluation.
        output_dir: Directory for checkpoints, logs and the final save.
        num_train_epochs: Number of training epochs (default 3).
        batch_size: Per-device batch size for training and evaluation (default 8).

    Returns:
        The ``(model, tokenizer)`` pair after training.
    """
    train_dataset = TaxaSequenceDataset(train_sequences, tokenizer)

    # Prepare the validation dataset if provided. The check is on the encoded
    # dataset's length rather than on the truthiness of `eval_sequences`, so
    # array-like inputs (whose truth value is ambiguous) are accepted too.
    eval_dataset = None
    if eval_sequences is not None:
        candidate = TaxaSequenceDataset(eval_sequences, tokenizer)
        if len(candidate) > 0:
            eval_dataset = candidate
    has_eval = eval_dataset is not None

    # Data collator for causal language modeling (mlm is disabled)
    data_collator = SimpleDataCollator(
        tokenizer=tokenizer,
        mlm=False
    )

    # Set up training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        save_total_limit=2,
        logging_dir=f"{output_dir}/logs",
        load_best_model_at_end=has_eval,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Train the model
    trainer.train()

    # Save the trained model and tokenizer
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer

In [0]:
# NOTE(review): `model` was already passed through `trainer.train()` in an
# earlier cell; this call trains the same object again — confirm that is intended.
model, tokenizer = train_model(model, tokenizer, train_sequences, test_sequences)
    

# evaluate on sequence completion task

In [0]:
def evaluate_sequence_completion(model, tokenizer, test_sequences, prefix_ratio=0.5):
    """Measure token-level accuracy of model completions.

    Each sequence is split into a prefix and a target; the prefix is given to
    ``generate_from_seed`` and the generated continuation is compared with the
    target position by position.

    Args:
        model: Model to evaluate; switched to eval mode.
        tokenizer: Tokenizer forwarded to ``generate_from_seed``.
        test_sequences: Iterable of whitespace-separated token strings.
        prefix_ratio: Fraction of each sequence used as the prefix
            (at least 3 tokens are always used).

    Returns:
        Dict with "mean_accuracy" and the per-sequence "individual_accuracies".
    """
    model.eval()

    def _token_accuracy(sequence):
        """Return the fraction of target tokens reproduced at the same position."""
        tokens = sequence.split()
        # Use at least 3 tokens as prefix
        prefix_len = max(3, int(len(tokens) * prefix_ratio))
        target_tokens = tokens[prefix_len:]

        generated = generate_from_seed(
            model,
            tokenizer,
            " ".join(tokens[:prefix_len]),
            max_length=len(tokens) + 5,
        )

        # Continuation after the prefix; a short continuation is padded with
        # dummy values (a non-positive repeat count adds nothing), and zip()
        # below truncates a long one to the target length.
        candidate = generated.split()[prefix_len:]
        candidate += ["DUMMY"] * (len(target_tokens) - len(candidate))

        if not target_tokens:
            return 1.0
        hits = sum(expected == produced for expected, produced in zip(target_tokens, candidate))
        return hits / len(target_tokens)

    accuracies = [_token_accuracy(sequence) for sequence in test_sequences]
    return {
        "mean_accuracy": np.mean(accuracies),
        "individual_accuracies": accuracies
    }


In [0]:
# Score completions on the test sequences and report the mean.
completion_results = evaluate_sequence_completion(model, tokenizer, test_sequences)
mean_accuracy = completion_results['mean_accuracy']
print(f"Mean sequence completion accuracy: {mean_accuracy:.4f}")
    

In [0]:

# Example of sequence completion on one randomly drawn test sequence.
# NOTE(review): no random seed is set in the visible cells, so the chosen
# example differs between runs.
example_idx = np.random.randint(0, len(test_sequences))
example_sequence = test_sequences[example_idx]
tokens = example_sequence.split()
prefix_len = max(3, int(len(tokens) * 0.5))
prefix = " ".join(tokens[:prefix_len])
original_completion = ' '.join(tokens[prefix_len:])

print("\nExample sequence completion:")
print(f"Prefix: {prefix}")
print(f"Original completion: {original_completion}")

generated = generate_from_seed(model, tokenizer, prefix, max_length=len(tokens) + 5)
model_completion = ' '.join(generated.split()[prefix_len:])
print(f"Model completion: {model_completion}")



In [0]:
# Extract embeddings for a few test sequences
print("\nExtracting embeddings for supervised learning...")

# Just use a few sequences for demonstration
sample_sequences = test_sequences[:5]
embeddings = extract_embeddings(model, tokenizer, sample_sequences)

embeddings_shape = embeddings.shape
print(f"Extracted embeddings shape: {embeddings_shape}")
print("These embeddings can now be used for classification or regression tasks.")
