In [0]:
%load_ext autoreload
%autoreload 2

In [0]:
import os
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Config, 
    GPT2LMHeadModel, 
    Trainer, 
    TrainingArguments,
    DataCollatorForLanguageModeling
)
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from utils_gpt import *

In [0]:
import torch
import random

# Seed every RNG the notebook relies on. torch.manual_seed seeds the CPU generator
# (used for weight init: the model is created on CPU and only then moved with
# .to(device)) as well as all CUDA devices. numpy's global RNG is used later to
# pick a random example sequence.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Force CUDA initialization at the start
if torch.cuda.is_available():
    torch.zeros(1).cuda()  # Triggers CUDA context init
    torch.cuda.manual_seed_all(SEED)

# customized tokenizer

In [0]:
import os
import json
import torch

class SimpleTokenizer:
    """Whitespace tokenizer over a fixed taxon vocabulary with a small HF-like API.

    Text is split on whitespace and each token is looked up in ``token_to_id``;
    unknown tokens map to ``<unk>``. The vocabulary must contain the special
    tokens ``<pad>``, ``<s>``, ``</s>`` and ``<unk>``.

    Args:
        token_to_id: dict mapping token string -> integer id.
        id_to_token: dict mapping integer id -> token string.
        model_max_length: value exposed as ``self.model_max_length``. When
            omitted, the notebook-level ``MAX_SEQ_LENGTH`` is used (it must be
            defined before the tokenizer is instantiated).
    """

    def __init__(self, token_to_id, id_to_token, model_max_length=None):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.pad_token_id = token_to_id["<pad>"]
        self.bos_token_id = token_to_id["<s>"]
        self.eos_token_id = token_to_id["</s>"]
        self.unk_token_id = token_to_id["<unk>"]
        # Used to size the model's embedding table, so it must exceed the largest
        # id; len(token_to_id) alone is too small when ids are non-contiguous.
        self.vocab_size = max(len(token_to_id), max(token_to_id.values()) + 1)

        # Special token attributes expected by HF transformers
        self.all_special_ids = [self.pad_token_id, self.bos_token_id, self.eos_token_id, self.unk_token_id]
        # Only read the notebook global when no explicit value was given.
        self.model_max_length = MAX_SEQ_LENGTH if model_max_length is None else model_max_length

        # Special token strings that HF expects
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"

        # Map special ids back to their token strings
        self.special_ids_to_tokens = {
            self.pad_token_id: self.pad_token,
            self.bos_token_id: self.bos_token,
            self.eos_token_id: self.eos_token,
            self.unk_token_id: self.unk_token,
        }

    def tokenize(self, text):
        """Split a string on whitespace; non-string input is returned unchanged."""
        if isinstance(text, str):
            return text.split()
        return text

    def __call__(self, sequence, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Make the tokenizer callable like HF tokenizers (same arguments as ``encode``)."""
        return self.encode(sequence, max_length, padding, truncation, return_tensors)

    def _to_ids(self, tokens, max_length, truncation):
        """Map tokens to ids (unknown -> <unk>) and truncate to max_length if requested."""
        ids = [self.token_to_id.get(token, self.unk_token_id) for token in tokens]
        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
        return ids

    def encode(self, sequence, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert one sequence or a batch of sequences to token ids and attention masks.

        A ``list`` whose items are all ``str`` is treated as a batch of
        whitespace-separated sequences. A single ``str``, or any other iterable
        of tokens, is treated as one sequence.

        Args:
            sequence: a string, a batch (list of strings) or an iterable of tokens.
            max_length: truncation length and, when padding, the padded length.
                For a batch padded without ``max_length``, pads to the longest sequence.
                A single sequence is only padded when ``max_length`` is set.
            padding: any truthy value (e.g. True or "max_length") enables padding.
            truncation: truncate to ``max_length`` when both are set.
            return_tensors: "pt" to return torch tensors of shape (batch, len),
                or (1, len) for a single sequence; otherwise Python lists.

        Returns:
            dict with "input_ids" and "attention_mask" (1 = real token, 0 = padding).
        """
        if isinstance(sequence, list) and all(isinstance(s, str) for s in sequence):
            batch_ids = [self._to_ids(seq.split(), max_length, truncation) for seq in sequence]
            batch_attention_masks = [[1] * len(ids) for ids in batch_ids]

            if padding:
                # default=0 keeps an empty batch from raising in max().
                if max_length is not None:
                    target_length = max_length
                else:
                    target_length = max((len(ids) for ids in batch_ids), default=0)
                for i, ids in enumerate(batch_ids):
                    padding_length = target_length - len(ids)
                    if padding_length > 0:
                        batch_ids[i] = ids + [self.pad_token_id] * padding_length
                        batch_attention_masks[i] = batch_attention_masks[i] + [0] * padding_length

            if return_tensors == "pt":
                return {
                    "input_ids": torch.tensor(batch_ids),
                    "attention_mask": torch.tensor(batch_attention_masks),
                }
            return {
                "input_ids": batch_ids,
                "attention_mask": batch_attention_masks,
            }

        # Single sequence: a string is split on whitespace, anything else is
        # taken as an iterable of tokens.
        tokens = sequence.split() if isinstance(sequence, str) else sequence
        ids = self._to_ids(tokens, max_length, truncation)

        attention_mask = [1] * len(ids)
        if padding and max_length:
            # A negative length yields [] here, so over-long unpadded input is left as-is.
            padding_length = max_length - len(ids)
            ids = ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        if return_tensors == "pt":
            return {
                "input_ids": torch.tensor([ids]),
                "attention_mask": torch.tensor([attention_mask]),
            }
        return {
            "input_ids": ids,
            "attention_mask": attention_mask,
        }

    def decode(self, ids, skip_special_tokens=False):
        """Convert a 1D sequence of ids (list or tensor) back to a space-joined string.

        Ids missing from ``id_to_token`` decode as "<unk>"; special ids are
        dropped when ``skip_special_tokens`` is True.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        skipped = set(self.all_special_ids) if skip_special_tokens else set()
        return " ".join(
            self.id_to_token.get(token_id, "<unk>") for token_id in ids if token_id not in skipped
        )

    def save_pretrained(self, output_dir):
        """Write both vocabulary maps to ``<output_dir>/vocab.json`` (creating the directory)."""
        os.makedirs(output_dir, exist_ok=True)

        with open(os.path.join(output_dir, "vocab.json"), "w") as f:
            # JSON object keys must be strings, so integer ids are stringified.
            json.dump({
                "token_to_id": dict(self.token_to_id),
                "id_to_token": {str(k): v for k, v in self.id_to_token.items()},
            }, f)

    @classmethod
    def from_pretrained(cls, input_dir, **kwargs):
        """Load a tokenizer written by ``save_pretrained``.

        Extra keyword arguments (e.g. ``model_max_length``) are forwarded to the constructor.
        """
        with open(os.path.join(input_dir, "vocab.json"), "r") as f:
            data = json.load(f)
        token_to_id = data["token_to_id"]
        # Convert the stringified keys back to integer ids
        id_to_token = {int(k): v for k, v in data["id_to_token"].items()}
        return cls(token_to_id, id_to_token, **kwargs)

In [0]:
MAX_SEQ_LENGTH = 150
tokenizer_path = "models/simple_taxa_tokenizer"

# Load the taxon -> id vocabulary and derive the inverse id -> taxon map.
import json
with open('../../data/token_to_id.json', 'r') as vocab_file:
    token_to_id = json.load(vocab_file)
id_to_token = dict((idx, token) for token, idx in token_to_id.items())

# Register the special tokens at fixed ids 0-3 (the ids create_gpt2_model configures).
# NOTE(review): if the vocabulary file already assigns any of ids 0-3 to a taxon,
# that taxon's id_to_token entry is overwritten here — confirm the file reserves them.
special_tokens = {
    "<pad>": 0,
    "<s>": 1,
    "</s>": 2,
    "<unk>": 3,
}
for special_token, special_id in special_tokens.items():
    token_to_id[special_token] = special_id
    id_to_token[special_id] = special_token

# Build the tokenizer and persist its vocabulary to tokenizer_path.
tokenizer = SimpleTokenizer(token_to_id, id_to_token)
tokenizer.save_pretrained(tokenizer_path)

In [0]:
# Example taxa sequences used to sanity-check the tokenizer.
SAMPLE_TEXTS = [
    "Roseburia Ruminococcus Streptococcus Dorea Faecalibacterium Anaerostipes Bifidobacterium Blautia Anaerobutyricum Agathobaculum Collinsella Klebsiella Fusicatenibacter Bacteroides Eubacterium Gemmiger Adlercreutzia Phocaeicola Alistipes Barnesiella Firmicutes",
    "Ruminococcus Phocaeicola Bacteroides Faecalibacterium Eubacterium Roseburia Alistipes",
]
    

In [0]:
# Batch-encode the samples, padded to the longest one, as PyTorch tensors.
tokenizer(SAMPLE_TEXTS, padding=True, truncation=True, return_tensors="pt")

# dataset and dataloader classes


In [0]:
# custom dataset class
class TaxaSequenceDataset(Dataset):
    """Pre-tokenized, fixed-length dataset of taxa sequences for causal LM training.

    Every sequence is encoded once at construction time and padded/truncated
    to ``max_length``; items are returned as 1D tensors.
    """

    def __init__(self, sequences, tokenizer, max_length=128):
        """
        Args:
            sequences: iterable of sequences; each is either a whitespace-separated
                string of taxa or a list of taxon strings.
            tokenizer: object with ``encode(seq, max_length=, padding=, truncation=,
                return_tensors=)`` returning "input_ids"/"attention_mask" tensors
                (e.g. SimpleTokenizer).
            max_length: length every encoded example is padded/truncated to.
        """
        self.inputs = []

        for seq in sequences:
            # SimpleTokenizer.encode treats a *list* of strings as a batch of
            # sequences; pass token lists as a tuple so they are encoded as one
            # sequence of tokens (without re-splitting the token strings).
            if isinstance(seq, list):
                seq = tuple(seq)
            encoded = tokenizer.encode(
                seq,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            self.inputs.append(encoded)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return {"input_ids", "attention_mask"} for one example, flattened to 1D."""
        input_ids = self.inputs[idx]["input_ids"]
        attention_mask = self.inputs[idx]["attention_mask"]

        # encode(..., return_tensors="pt") yields (1, max_length); flatten to 1D
        if input_ids.dim() > 1:
            input_ids = input_ids.view(-1)
        if attention_mask.dim() > 1:
            attention_mask = attention_mask.view(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

# configure and initialize the model

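Model choice: the GPT-2 architecture (`GPT2LMHeadModel`), scaled down well below GPT-2 Small (64-dim embeddings, 6 layers, 4 attention heads; see `create_gpt2_model`).

- It is autoregressive, which suits the generation task (taxa sequence completion).
- Its hidden states are practical to extract as embeddings for downstream supervised learning tasks.
- It handles variable sequence lengths (up to `MAX_SEQ_LENGTH` positions).
- With 355 unique values (taxa), vocabulary size is not a concern.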

In [0]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [0]:
def create_gpt2_model(vocab_size, n_positions=None, n_embd=64, n_layer=6, n_head=4,
                      pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3):
    """Create a randomly initialised, scaled-down GPT-2 LM model with a custom vocabulary.

    Args:
        vocab_size: size of the token embedding table (must exceed the largest token id).
        n_positions: maximum sequence length; defaults to the notebook-level MAX_SEQ_LENGTH.
        n_embd, n_layer, n_head: model width, depth and attention heads
            (much smaller than GPT-2 Small, for faster training).
        pad_token_id, bos_token_id, eos_token_id, unk_token_id: special token ids;
            the defaults match the ids used when building the tokenizer vocabulary.

    Returns:
        A new ``GPT2LMHeadModel`` on CPU.
    """
    if n_positions is None:
        n_positions = MAX_SEQ_LENGTH

    config = GPT2Config(
        vocab_size=vocab_size,
        n_positions=n_positions,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,

        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        # Not a native GPT2Config field; stored as an extra config attribute.
        unk_token_id=unk_token_id,

        attn_pdrop=0.1,   # Attention dropout
        embd_pdrop=0.1,   # Embedding dropout
        resid_pdrop=0.1,  # Residual dropout

        # These match GPT2Config's defaults; listed explicitly for visibility.
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        # NOTE(review): this is saved with the model and may also disable the KV
        # cache during generate() on the reloaded model — confirm if generation is slow.
        use_cache=False,
    )

    return GPT2LMHeadModel(config)

model = create_gpt2_model(vocab_size=tokenizer.vocab_size).to(device)

In [0]:
# Count every parameter (trainable or not) in the freshly built model.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total parameters: {total_params:,}")

Define a `SimpleDataCollator` to replace HF's `DataCollatorForLanguageModeling`. It pads each batch properly and builds causal language-modeling labels, with padding positions set to -100 so the loss ignores them.

In [0]:

class SimpleDataCollator:
    """Batch collator for causal language modeling (replacement for HF DataCollatorForLanguageModeling).

    Pads or truncates every example to the longest one in the batch, capped at
    ``max_seq_length``. Builds ``labels`` as a copy of ``input_ids`` with padding
    positions set to -100 so the loss ignores them.

    Args:
        tokenizer: object exposing ``pad_token_id``.
        mlm: stored for API compatibility; only causal LM (False) is implemented.
        max_seq_length: optional cap on the batch sequence length.
    """

    def __init__(self, tokenizer, mlm=False, max_seq_length=None):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.max_seq_length = max_seq_length

    def __call__(self, features):
        """Collate a list of {"input_ids", "attention_mask"} dicts into a batch.

        Feature values may be 1D tensors or lists of ints.

        Returns:
            dict with "input_ids", "attention_mask" and "labels", each a
            (batch, max_len) long tensor.

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("SimpleDataCollator received an empty batch")

        input_ids = [torch.as_tensor(f["input_ids"], dtype=torch.long) for f in features]
        attention_mask = [torch.as_tensor(f["attention_mask"], dtype=torch.long) for f in features]

        # Target length: longest example, capped at max_seq_length
        max_len = max(len(ids) for ids in input_ids)
        if self.max_seq_length:
            max_len = min(max_len, self.max_seq_length)

        padded_input_ids = []
        padded_attention_mask = []
        for ids, mask in zip(input_ids, attention_mask):
            # Truncate first: without it, examples longer than the cap keep their
            # length and torch.stack fails on ragged tensors.
            ids, mask = ids[:max_len], mask[:max_len]
            padding_len = max_len - len(ids)
            if padding_len > 0:
                ids = torch.cat([
                    ids,
                    torch.full((padding_len,), self.tokenizer.pad_token_id, dtype=torch.long),
                ])
                mask = torch.cat([mask, torch.zeros(padding_len, dtype=torch.long)])
            padded_input_ids.append(ids)
            padded_attention_mask.append(mask)

        batch = {
            "input_ids": torch.stack(padded_input_ids),
            "attention_mask": torch.stack(padded_attention_mask),
        }

        # Causal LM labels = inputs; the model shifts them internally.
        # Ignore (-100) every pad id and every position the attention mask marks as padding.
        labels = batch["input_ids"].clone()
        ignore = (batch["input_ids"] == self.tokenizer.pad_token_id) | (batch["attention_mask"] == 0)
        labels[ignore] = -100
        batch["labels"] = labels

        return batch

    @staticmethod
    def debug_batch(batch):
        """Print shapes, devices, NaN checks and the first values of a collated batch."""
        print("Input IDs shape:", batch["input_ids"].shape)
        print("Attention mask shape:", batch["attention_mask"].shape)
        print("Labels shape:", batch["labels"].shape)
        # Check if all tensors are on the same device
        print("Input IDs device:", batch["input_ids"].device)
        print("Attention mask device:", batch["attention_mask"].device)
        print("Labels device:", batch["labels"].device)
        # Check for any NaN values
        print("Any NaN in input_ids:", torch.isnan(batch["input_ids"]).any())
        print("Any NaN in attention_mask:", torch.isnan(batch["attention_mask"]).any())
        print("Any NaN in labels:", torch.isnan(batch["labels"]).any())
        # Print some values
        print("First sequence input_ids:", batch["input_ids"][0][:10])
        print("First sequence attention_mask:", batch["attention_mask"][0][:10])
        print("First sequence labels:", batch["labels"][0][:10])

# data loading


In [0]:
from sklearn.model_selection import train_test_split

# Load one whitespace-separated taxa sequence per line. Lines are stripped and
# blank lines dropped: a blank line would otherwise become a training example
# made only of padding.
with open('../../data/taxa_sequences.txt', 'r', encoding='utf-8') as file:
    sequences = [line.strip() for line in file if line.strip()]

# Split sequences into train and test sets (fixed random_state for reproducibility)
train_sequences, test_sequences = train_test_split(sequences, test_size=0.2, random_state=42)

In [0]:
n_sequences = len(sequences)
n_sequences

# training

In [0]:
# Note: train_model() builds its own datasets; these copies are for inspection only.
output_dir = "./gpt2_taxa_seq_model"
train_dataset, eval_dataset = (
    TaxaSequenceDataset(train_sequences, tokenizer),
    TaxaSequenceDataset(test_sequences, tokenizer),
)

In [0]:
def log_diversity_metrics(model, tokenizer, num_samples=100, max_length=25):
    """Sample sequences from the model and log diversity metrics to MLflow.

    Helps monitor whether the model is learning diverse patterns or just
    memorizing / repeating a few training sequences. Each sample is generated
    from a lone BOS token with nucleus sampling (temperature 1.0, top_p 0.9).

    Args:
        model: causal LM exposing ``generate``, ``eval``, ``train``, ``training``
            and ``device`` (e.g. ``GPT2LMHeadModel``).
        tokenizer: tokenizer exposing ``bos_token_id``, ``pad_token_id``,
            ``eos_token_id``, ``all_special_ids``, ``vocab_size`` and ``decode``.
        num_samples: number of sequences to sample.
        max_length: maximum total length (including BOS) of each sample.

    Returns:
        float: fraction of distinct decoded sequences (0.0 when nothing is sampled).

    Logs the MLflow metrics ``sequence_diversity``, ``unique_sequences`` and
    ``vocab_usage_ratio``. Vocabulary usage counts only non-special tokens, so
    the BOS/EOS/padding present in every sample do not inflate it. The model's
    train/eval mode is restored on exit.
    """
    import mlflow  # local import: the notebook only imports mlflow in a later cell

    special_ids = set(tokenizer.all_special_ids)
    generated_sequences = []
    content_tokens = set()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_samples):
                input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(model.device)

                generated = model.generate(
                    input_ids,
                    max_length=max_length,
                    do_sample=True,
                    temperature=1.0,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

                # Decoded text (special tokens removed) for sequence diversity
                generated_sequences.append(tokenizer.decode(generated[0], skip_special_tokens=True))
                # Non-special token ids for vocabulary usage
                content_tokens.update(t for t in generated[0].tolist() if t not in special_ids)
    finally:
        model.train(was_training)

    # Sequence-level diversity
    unique_sequences = len(set(generated_sequences))
    sequence_diversity = unique_sequences / len(generated_sequences) if generated_sequences else 0.0

    # Token-level diversity
    unique_tokens = len(content_tokens)
    vocab_usage = unique_tokens / tokenizer.vocab_size

    mlflow.log_metric("sequence_diversity", sequence_diversity)
    mlflow.log_metric("unique_sequences", unique_sequences)
    mlflow.log_metric("vocab_usage_ratio", vocab_usage)

    print("Diversity Metrics:")
    print(f"  Sequence diversity: {sequence_diversity:.3f}")
    print(f"  Unique sequences: {unique_sequences}/{num_samples}")
    print(f"  Vocabulary usage: {vocab_usage:.3f} ({unique_tokens}/{tokenizer.vocab_size} tokens)")

    return sequence_diversity

In [0]:
import mlflow
import torch
from transformers import Trainer, TrainerCallback
from transformers.integrations import MLflowCallback
import numpy as np



class TaxaModelCallback(TrainerCallback):
    """Trainer callback that logs a generated sample each epoch and eval perplexity to MLflow.

    Args:
        tokenizer: tokenizer exposing ``bos_token_id``, ``pad_token_id``,
            ``eos_token_id`` and ``decode``.
        generation_length: ``max_length`` passed to ``model.generate`` (includes the BOS token).
    """

    def __init__(self, tokenizer, generation_length=50):
        self.tokenizer = tokenizer
        self.generation_length = generation_length
        # (epoch, generated_text) pairs collected at the end of each epoch
        self.epoch_samples = []

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Sample one sequence from BOS, log it to MLflow and print a preview."""
        if state.global_step <= 0 or model is None:
            return

        model.eval()
        input_ids = torch.tensor([[self.tokenizer.bos_token_id]]).to(model.device)
        with torch.no_grad():
            generated = model.generate(
                input_ids,
                max_length=self.generation_length,
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        self.epoch_samples.append((state.epoch, generated_text))

        mlflow.log_text(generated_text, f"samples/epoch_{state.epoch}_sample.txt")
        mlflow.log_metric("epoch", state.epoch, step=state.global_step)

        # Display in notebook
        print(f"\nEpoch {state.epoch}: Generated sample:")
        print(generated_text[:100] + "..." if len(generated_text) > 100 else generated_text)

        model.train()

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log perplexity = exp(eval_loss) for the evaluation that just finished.

        This is done here rather than in ``on_epoch_end``: the HF Trainer runs an
        epoch's evaluation after ``on_epoch_end`` fires, so ``state.log_history``
        at that point holds a stale or missing ``eval_loss``.
        """
        if metrics and "eval_loss" in metrics:
            perplexity = float(np.exp(metrics["eval_loss"]))
            mlflow.log_metric("perplexity", perplexity, step=state.global_step)


class SimpleTaxaTrainer(Trainer):
    """Trainer that also logs the training batch loss to MLflow every 100 global steps."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the model's own LM loss (and outputs if requested).

        ``num_items_in_batch`` / ``**kwargs`` are accepted because newer
        transformers versions pass them to ``compute_loss``; they are not used.
        """
        outputs = model(**inputs)
        loss = outputs.loss

        # Only log during training: compute_loss is also called during evaluation.
        if model.training and self.state.global_step % 100 == 0:
            mlflow.log_metric("batch_loss", loss.item(), step=self.state.global_step)

        return (loss, outputs) if return_outputs else loss

In [0]:
def train_model(model, tokenizer, train_sequences, eval_sequences=None, 
                output_dir="./gpt2_taxa_seq_model", num_epochs=200):
    """
    Train a GPT-2 causal LM on taxa sequences with the HF Trainer.

    Args:
        model: the GPT2LMHeadModel to train (trained in place).
        tokenizer: SimpleTokenizer used for encoding, padding and sampling.
        train_sequences: list of whitespace-separated taxa strings.
        eval_sequences: optional held-out sequences. When given, this enables
            per-epoch evaluation, reloading the best checkpoint by eval_loss,
            early stopping (patience 20 evaluations) and a per-epoch generated
            sample via TaxaModelCallback.
        output_dir: checkpoint directory; the final model is also saved here.
        num_epochs: maximum number of training epochs.

    Returns:
        (model, tokenizer): the objects passed in, after training.
    """
    # EarlyStoppingCallback is not imported by any visible cell.
    from transformers import EarlyStoppingCallback

    # Encode to the same length the model (n_positions) and the collator are
    # configured for; the dataset default of 128 would silently truncate longer sequences.
    train_dataset = TaxaSequenceDataset(train_sequences, tokenizer, max_length=MAX_SEQ_LENGTH)
    eval_dataset = (
        TaxaSequenceDataset(eval_sequences, tokenizer, max_length=MAX_SEQ_LENGTH)
        if eval_sequences else None
    )
    has_eval = eval_dataset is not None

    data_collator = SimpleDataCollator(
        tokenizer=tokenizer,
        mlm=False,
        max_seq_length=MAX_SEQ_LENGTH,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=2,
        # fp16 mixed precision requires a GPU; enabling it on CPU raises.
        fp16=torch.cuda.is_available(),
        dataloader_pin_memory=True,
        # NOTE(review): newer transformers versions rename this to `eval_strategy`.
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=has_eval,
        metric_for_best_model="eval_loss" if has_eval else None,
        greater_is_better=False,
        learning_rate=5e-4,
        warmup_steps=1000,
        weight_decay=0.01,
        max_grad_norm=1.0,
        logging_steps=100,
        seed=42,
    )

    # Only custom callbacks are added; MLflow logging is expected to come from the Databricks integration.
    callbacks = []
    if has_eval:
        callbacks.append(TaxaModelCallback(tokenizer=tokenizer, generation_length=50))
        callbacks.append(EarlyStoppingCallback(
            early_stopping_patience=20,
            early_stopping_threshold=0.001,
        ))

    trainer = Trainer(  # standard Trainer
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        callbacks=callbacks if callbacks else None,
    )

    print(f"Starting training for {num_epochs} epochs...")
    trainer.train()

    model.save_pretrained(output_dir)

    # Generate a few final samples to eyeball quality
    model.eval()
    print("\n=== Final Generated Samples ===")
    with torch.no_grad():
        for i in range(3):
            input_ids = torch.tensor([[tokenizer.bos_token_id]]).to(model.device)
            generated = model.generate(
                input_ids,
                max_length=50,
                do_sample=True,
                temperature=0.8,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            sequence = tokenizer.decode(generated[0], skip_special_tokens=True)
            print(f"Sample {i+1}: {sequence}")

    print(f"Training complete! Model saved to {output_dir}")
    return model, tokenizer

In [0]:
# Close any MLflow run left open by a previous execution before training starts.
mlflow.end_run(status="FINISHED")

In [0]:
# Train on the train split. The eval sequences drive per-epoch evaluation,
# best-checkpoint selection and early stopping.
# NOTE(review): test_sequences is used here for model selection and later for the
# final completion metrics, so those metrics are optimistically biased — consider
# carving out a separate validation split.
model, tokenizer = train_model(
    model,
    tokenizer,
    train_sequences,
    test_sequences,
    output_dir="./gpt2_taxa_seq_model",
)
    

# evaluate on sequence completion task

In [0]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Reload the trained weights saved by train_model().
# NOTE(review): the reloaded model is not moved to `device`; confirm that
# generate_from_seed handles device placement.
model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path="./gpt2_taxa_seq_model")


In [0]:
%pip install python-Levenshtein

In [0]:
import pickle
import numpy as np
from Levenshtein import distance

# First, generate and save sequences once
def generate_and_save_sequences(model, test_sequences, save_path='generated_sequences.pkl', 
                               prefix_ratio=0.3, num_samples=None):
    """Generate completions for held-out sequences once and pickle them for later scoring.

    For each sequence, the first ``max(3, int(len * prefix_ratio))`` taxa are the
    prompt and the remaining taxa are the reference completion.

    Args:
        model: trained causal LM, passed through to ``generate_from_seed``.
        test_sequences: list of whitespace-separated taxa strings.
        save_path: pickle file to write.
        prefix_ratio: fraction of each sequence used as prompt (at least 3 tokens).
        num_samples: if not None, only the first ``num_samples`` sequences are used.

    Returns:
        list of dicts with keys 'prefix_tokens', 'target_tokens', 'generated_tokens'
        (each a list of taxon strings); the same list is pickled to ``save_path``.

    Note:
        Uses the notebook-level ``tokenizer`` and ``generate_from_seed``
        (presumably provided by utils_gpt).
    """
    generated_data = []

    # `is None` so that num_samples=0 means "no sequences" rather than "all of them".
    sequences_to_process = test_sequences if num_samples is None else test_sequences[:num_samples]

    for idx, test_seq in enumerate(sequences_to_process):
        if idx % 100 == 0:
            print(f"Processing sequence {idx}/{len(sequences_to_process)}")

        seq_tokens = test_seq.split()
        prefix_len = max(3, int(len(seq_tokens) * prefix_ratio))  # use at least 3 tokens as prefix
        prefix_tokens = seq_tokens[:prefix_len]
        target_tokens = seq_tokens[prefix_len:]

        # NOTE(review): max_length may exceed the model's n_positions for long
        # sequences — confirm how generate_from_seed handles that.
        generated = generate_from_seed(model, tokenizer, " ".join(prefix_tokens),
                                       max_length=len(seq_tokens) + 5)

        # Assumes `generated` is the prompt followed by the continuation, so its
        # first prefix_len whitespace tokens are the prompt — TODO confirm.
        generated_tokens = generated.split()[prefix_len:]

        generated_data.append({
            'prefix_tokens': prefix_tokens,
            'target_tokens': target_tokens,
            'generated_tokens': generated_tokens,
        })

    with open(save_path, 'wb') as f:
        pickle.dump(generated_data, f)

    print(f"Saved {len(generated_data)} generated sequences to {save_path}")
    return generated_data

# Then calculate metrics from saved data
def _weighted_jaccard(target, generated, position_weight=0.9):
    """Position-weighted overlap of generated taxa with the target, in [0, 1].

    Each generated taxon found in ``target`` contributes
    ``position_weight ** |i - j|``, where i is its position in ``generated`` and
    j is its first position in ``target``. Only the first occurrence of each
    taxon in ``generated`` counts, so repeated taxa cannot push the score
    above 1. The sum is divided by ``len(target)``; an empty target scores 0.
    """
    if not target:
        return 0
    first_target_pos = {}
    for j, taxon in enumerate(target):
        first_target_pos.setdefault(taxon, j)
    seen = set()
    score = 0
    for i, taxon in enumerate(generated):
        if taxon in first_target_pos and taxon not in seen:
            seen.add(taxon)
            score += position_weight ** abs(i - first_target_pos[taxon])
    return score / len(target)


def _token_edit_distance(a, b):
    """Levenshtein distance between two token sequences (unit-cost insert/delete/substitute)."""
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, token_a in enumerate(a, start=1):
        current = [i]
        for j, token_b in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,                          # delete token_a
                current[j - 1] + 1,                       # insert token_b
                previous[j - 1] + (token_a != token_b),   # substitute / match
            ))
        previous = current
    return previous[-1]


def calculate_metrics_from_saved(save_path='generated_sequences.pkl'):
    """Score saved completions (from generate_and_save_sequences) with several metrics.

    Per example, compares 'target_tokens' with 'generated_tokens' using:
      - accuracy: position-wise exact matches / len(target) (1.0 for an empty target)
      - jaccard_similarity: |T ∩ G| / |T ∪ G| over taxon sets
      - f1_score: set-based precision/recall F1
      - weighted_jaccard: position-weighted overlap (see ``_weighted_jaccard``)
      - top_k_accuracy: set overlap of the first k=min(10, len(target)) taxa / k
      - normalized_edit_distance: 1 - token edit distance / max length
        (despite the name, higher is better; 1 means identical)

    Returns:
        dict with the mean of each metric plus ``<name>_std`` standard deviations.

    Note: ``pickle.load`` executes arbitrary code from the file — only use it on
    files this notebook wrote itself.
    """
    with open(save_path, 'rb') as f:
        generated_data = pickle.load(f)

    accuracies = []
    jaccard_scores = []
    f1_scores = []
    weighted_jaccard_scores = []
    top_k_scores = []
    normalized_edit_scores = []

    for data in generated_data:
        target_tokens = data['target_tokens']
        generated_tokens = data['generated_tokens']

        # Position-wise exact-match accuracy
        matches = sum(1 for t, g in zip(target_tokens, generated_tokens) if t == g)
        accuracies.append(matches / len(target_tokens) if target_tokens else 1.0)

        # Set-based metrics
        set_target = set(target_tokens)
        set_generated = set(generated_tokens)
        overlap = len(set_target & set_generated)
        union = set_target | set_generated

        jaccard_scores.append(overlap / len(union) if union else 0)

        precision = overlap / len(set_generated) if set_generated else 0
        recall = overlap / len(set_target) if set_target else 0
        f1_scores.append(2 * (precision * recall) / (precision + recall) if (precision + recall) else 0)

        weighted_jaccard_scores.append(_weighted_jaccard(target_tokens, generated_tokens))

        # Top-k accuracy (k = 10, or fewer for short targets)
        k = min(10, len(target_tokens))
        top_k_scores.append(len(set(target_tokens[:k]) & set(generated_tokens[:k])) / k if k > 0 else 0)

        # Normalized edit similarity over tokens
        edit_dist = _token_edit_distance(target_tokens, generated_tokens)
        max_len = max(len(target_tokens), len(generated_tokens))
        normalized_edit_scores.append(1 - (edit_dist / max_len) if max_len > 0 else 1)

    return {
        'accuracy': np.mean(accuracies),
        'jaccard_similarity': np.mean(jaccard_scores),
        'f1_score': np.mean(f1_scores),
        'weighted_jaccard': np.mean(weighted_jaccard_scores),
        'top_k_accuracy': np.mean(top_k_scores),
        'normalized_edit_distance': np.mean(normalized_edit_scores),
        # Standard deviations
        'accuracy_std': np.std(accuracies),
        'jaccard_similarity_std': np.std(jaccard_scores),
        'f1_score_std': np.std(f1_scores),
        'weighted_jaccard_std': np.std(weighted_jaccard_scores),
        'top_k_accuracy_std': np.std(top_k_scores),
        'normalized_edit_distance_std': np.std(normalized_edit_scores),
    }


In [0]:

# Generate completions for the first 500 held-out sequences once, cache them
# to disk, then score the cached results.
generated_path = 'generated_sequences_test.pkl'
generated_data = generate_and_save_sequences(
    model,
    test_sequences,
    save_path=generated_path,
    num_samples=500,
)
metrics = calculate_metrics_from_saved(generated_path)

metric_labels = [
    ("Accuracy", "accuracy"),
    ("Jaccard Similarity", "jaccard_similarity"),
    ("F1 Score", "f1_score"),
    ("Weighted Jaccard", "weighted_jaccard"),
    ("Top-10 Accuracy", "top_k_accuracy"),
    ("Normalized Edit Distance", "normalized_edit_distance"),
]
print("\nEvaluation Results:")
for label, key in metric_labels:
    print(f"{label}: {metrics[key]:.4f} ± {metrics[key + '_std']:.4f}")

In [0]:
**TODO (Matthieu):** test a distance metric between the true and the predicted taxa order.

In [0]:
**TODO (Matthieu):** can we also integrate … (note left incomplete — clarify what should be integrated).

In [0]:
# def evaluate_sequence_completion(model, tokenizer, test_sequences, prefix_ratio=0.5):
#     model.eval()
#     accuracies = []
    
#     for sequence in test_sequences:
#         # Split the sequence into prefix and target
#         tokens = sequence.split()
#         prefix_len = max(3, int(len(tokens) * prefix_ratio))  # Use at least 3 tokens as prefix
        
#         prefix = " ".join(tokens[:prefix_len])
#         target = " ".join(tokens[prefix_len:])
        
#         # Generate completion using the prefix
#         generated = generate_from_seed(model, tokenizer, prefix, max_length=len(tokens) + 5)
        
#         # Extract the generated continuation (after the prefix)
#         generated_completion = " ".join(generated.split()[prefix_len:])
        
#         # Calculate accuracy (exact match between tokens)
#         target_tokens = target.split()
#         generated_tokens = generated_completion.split()[:len(target_tokens)]  # Truncate to target length
        
#         # If generated is shorter, pad with dummy values that won't match
#         if len(generated_tokens) < len(target_tokens):
#             generated_tokens.extend(["DUMMY"] * (len(target_tokens) - len(generated_tokens)))
        
#         # Calculate token-level accuracy
#         matches = sum(1 for t, g in zip(target_tokens, generated_tokens) if t == g)
#         accuracy = matches / len(target_tokens) if target_tokens else 1.0
#         accuracies.append(accuracy)
        
#     return {
#         "mean_accuracy": np.mean(accuracies),
#         "individual_accuracies": accuracies
#     }


In [0]:
# evaluate_sequence_completion is commented out above; presumably provided by utils_gpt.
completion_results = evaluate_sequence_completion(model, tokenizer, test_sequences)
mean_completion_accuracy = completion_results['mean_accuracy']
print(f"Mean sequence completion accuracy: {mean_completion_accuracy:.4f}")
    

In [0]:
n_test_sequences = len(test_sequences)
n_test_sequences

In [0]:

# Example of sequence completion on one randomly chosen held-out sequence:
# the first half (at least 3 taxa) is the prompt, the rest is the reference.
example_idx = np.random.randint(0, len(test_sequences))
example_sequence = test_sequences[example_idx]
tokens = example_sequence.split()
prefix_len = max(3, int(len(tokens) * 0.5))
prefix = " ".join(tokens[:prefix_len])
reference_completion = " ".join(tokens[prefix_len:])

print("\nExample sequence completion:")
print(f"Prefix: {prefix}")
print(f"Original completion: {reference_completion}")
generated = generate_from_seed(model, tokenizer, prefix, max_length=len(tokens) + 5)
model_completion = " ".join(generated.split()[prefix_len:])
print(f"Model completion: {model_completion}")



In [0]:
# (tokens in the example sequence, tokens in the generated text)
(len(tokens), len(generated.split()))


In [0]:
# Share of the reference completion's distinct taxa that the model also produced.
# The prompt is excluded on both sides: `generated` presumably echoes the prompt
# (the cell above slices it off at prefix_len), so including it would inflate the
# overlap. Distinct-taxa sets are used on both sides so duplicates do not skew the ratio.
reference_taxa = set(tokens[prefix_len:])
predicted_taxa = set(generated.split()[prefix_len:])
intersection = reference_taxa & predicted_taxa
percentage_intersection = 100 * len(intersection) / len(reference_taxa) if reference_taxa else 0.0
percentage_intersection

In [0]:
import torch

**TODO (Matthieu):** extract embeddings and compare the effect of diet.

In [0]:
# Extract embeddings for a handful of held-out sequences (demonstration only).
# extract_embeddings is presumably provided by utils_gpt.
print("\nExtracting embeddings for supervised learning...")
n_demo_sequences = 5
sample_sequences = test_sequences[:n_demo_sequences]
embeddings = extract_embeddings(model, tokenizer, sample_sequences)
print(f"Extracted embeddings shape: {embeddings.shape}")
print("These embeddings can now be used for classification or regression tasks.")


In [0]:
Source data (Databricks Catalog Explorer): [onesource_eu_dev_rni / onebiome / mpa_2](https://adb-7744086575777980.0.azuredatabricks.net/explore/data/onesource_eu_dev_rni/onebiome/mpa_2?o=7744086575777980)