In [0]:
%load_ext autoreload
%autoreload 2

# Step 1: First, isolate the issue

In [0]:
import logging
import traceback

# Configure logging
# Console logging for the pipeline: timestamped records at INFO level and above.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name='gpt2_pretraining')

In [0]:


try:
    # The stage functions and SimpleTokenizer are defined in cells further down this
    # notebook; fail with a clear message instead of a NameError on a fresh kernel.
    required = ("test_tokenizer", "test_data_loading", "test_model", "test_training", "SimpleTokenizer")
    missing = [name for name in required if name not in globals()]
    if missing:
        raise NameError(f"Run the cells that define {', '.join(missing)} before this one")

    # Local import: GPT2Config is otherwise only imported by later cells.
    from transformers import GPT2Config

    # test_data_loading / test_model / test_training all load the tokenizer with
    # SimpleTokenizer.from_pretrained, which reads <path>/vocab.json, so every stage must
    # point at the saved custom tokenizer directory (the Hub id "gpt2" is not one).
    TOKENIZER_PATH = "models/simple_taxa_tokenizer"  # Set to None to use standard GPT-2 in the tokenizer stage
    SAMPLE_TEXTS = [
        "This is a simple test sentence.",
        "Let's test some domain-specific content that your model might see.",
        "Ruminococcus Phocaeicola Bacteroides Faecalibacterium Eubacterium Roseburia Alistipes"
    ]
    DATA_PATH = "../../data/taxa_sequences.txt"

    # Custom config if needed
    custom_config = GPT2Config(
        vocab_size=50257,  # Update with your tokenizer's vocab size
        n_positions=512,
        n_ctx=512,
        n_embd=768,
        n_layer=6,  # Smaller than standard GPT-2 for faster testing
        n_head=12,
    )

    # (start message, completion message, stage runner), executed in order.
    stages = [
        ("Starting tokenization...", "Tokenization complete",
         lambda: test_tokenizer(TOKENIZER_PATH, SAMPLE_TEXTS)),
        ("Preparing dataset...", "Dataset preparation complete",
         lambda: test_data_loading(DATA_PATH, TOKENIZER_PATH)),
        ("Initializing model...", "Model initialization complete",
         lambda: test_model(TOKENIZER_PATH, config=custom_config)),
        ("Starting training...", "Training complete",
         lambda: test_training(TOKENIZER_PATH)),
    ]

    for start_msg, done_msg, run_stage in stages:
        logger.info(start_msg)
        success, message = run_stage()
        logger.info(f"Test result: {message}")
        if not success:
            # The stage functions report failure as (False, msg) rather than raising; stop
            # here so a failed stage is not followed by a misleading "... complete" log.
            raise RuntimeError(f"Stage failed: {message}")
        logger.info(done_msg)

except Exception as e:
    logger.error(f"Error occurred: {str(e)}")
    logger.error(traceback.format_exc())

# Step 2: Test the tokenizer


In [0]:
import logging
from transformers import GPT2TokenizerFast
import os

# Console logging at INFO; this cell's logger is named "tokenizer_test".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="tokenizer_test")

def test_tokenizer(tokenizer_path=None, sample_texts=None):
    """Smoke-test a tokenizer: per-text tokenize/encode/decode round trip plus one batch encode.

    Args:
        tokenizer_path: Directory of a saved ``SimpleTokenizer``. If it is falsy or does not
            exist on disk, the standard GPT-2 tokenizer (``GPT2TokenizerFast``) is used.
        sample_texts: Texts to try; a two-sentence English default is used when empty/None.

    Returns:
        ``(True, "Tokenizer test completed successfully")`` on success, otherwise
        ``(False, "Error: <exception>")`` after logging the traceback.
    """
    import traceback  # this cell does not import it; don't rely on an earlier cell

    try:
        if tokenizer_path and os.path.exists(tokenizer_path):
            logger.info(f"Loading tokenizer from {tokenizer_path}")
            tokenizer = SimpleTokenizer.from_pretrained(tokenizer_path)
        else:
            # Fallback to standard GPT-2 tokenizer for comparison
            logger.info("Loading standard GPT-2 tokenizer")
            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        # The batch encode below pads, which needs a pad token. GPT-2 has none, so reuse eos
        # there; a tokenizer that already defines one (SimpleTokenizer's <pad>) is left alone
        # so its pad_token and pad_token_id stay consistent.
        if getattr(tokenizer, "pad_token", None) is None:
            logger.info("Setting pad_token to be the same as eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        if not sample_texts:
            sample_texts = [
                "This is a simple test sentence.",
                "Let's test some domain-specific content that your model might see."
            ]

        logger.info("Testing tokenization on sample texts:")
        for text in sample_texts:
            tokens = tokenizer.tokenize(text)
            encoded = tokenizer.encode(text)
            # HF tokenizers return a list of ids, but SimpleTokenizer.encode returns a dict;
            # decoding the dict itself would iterate its keys instead of the ids.
            token_ids = encoded["input_ids"] if isinstance(encoded, dict) else encoded
            decoded = tokenizer.decode(token_ids)

            logger.info(f"\nOriginal: {text}")
            logger.info(f"Tokenized: {tokens}")
            logger.info(f"Token IDs: {token_ids}")
            logger.info(f"Decoded: {decoded}")
            logger.info(f"Roundtrip successful: {text == decoded}")

        logger.info("\nTesting batch encoding:")
        batch_encoding = tokenizer(sample_texts, padding=True, truncation=True, return_tensors="pt")
        logger.info(f"Batch shape: {batch_encoding['input_ids'].shape}")

        return True, "Tokenizer test completed successfully"

    except Exception as e:
        logger.error(f"Tokenizer test failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False, f"Error: {str(e)}"

if __name__ == "__main__":
    # NOTE(review): SimpleTokenizer and the saved tokenizer directory come from cells further
    # down this notebook; run those first, otherwise this cannot use the custom tokenizer.
    TOKENIZER_PATH = "models/simple_taxa_tokenizer"  # Set to None to use standard GPT-2

    # Two taxa "sentences" from the target domain.
    SAMPLE_TEXTS = [
        (
            "Roseburia Ruminococcus Streptococcus Dorea Faecalibacterium Anaerostipes "
            "Bifidobacterium Blautia Anaerobutyricum Agathobaculum Collinsella Klebsiella "
            "Fusicatenibacter Bacteroides Eubacterium Gemmiger Adlercreutzia Phocaeicola "
            "Alistipes Barnesiella Firmicutes"
        ),
        "Ruminococcus Phocaeicola Bacteroides Faecalibacterium Eubacterium Roseburia Alistipes",
    ]

    success, message = test_tokenizer(TOKENIZER_PATH, SAMPLE_TEXTS)
    logger.info(f"Test result: {message}")

In [0]:
# Display the status string returned by test_tokenizer in the cell above.
message

# Step 3: Test data loading and preprocessing

In [0]:
import logging
import traceback
import pandas as pd
import numpy as np
from transformers import GPT2TokenizerFast
import torch
from torch.utils.data import Dataset, DataLoader

# Console logging at INFO; this cell's logger is named "data_test".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="data_test")

class SequenceDataset(Dataset):
    """Tokenize ``texts`` once into fixed-length tensors and serve causal-LM examples.

    Args:
        texts: List of strings.
        tokenizer: Callable as ``tokenizer(texts, truncation=True, padding="max_length",
            max_length=..., return_tensors="pt")`` returning a mapping of per-text tensors
            that includes ``input_ids`` (and usually ``attention_mask``).
        max_length: Pad/truncate every text to this many tokens.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(texts, truncation=True, padding="max_length",
                                   max_length=max_length, return_tensors="pt")

    def __getitem__(self, idx):
        """Return row ``idx`` of every encoding plus ``labels`` copied from ``input_ids``.

        Padded positions (``attention_mask == 0``) get label -100 so the loss ignores them;
        otherwise the model is trained to predict the padding token.
        """
        item = {key: val[idx] for key, val in self.encodings.items()}
        labels = item["input_ids"].clone()
        if "attention_mask" in item:
            labels[item["attention_mask"] == 0] = -100
        item["labels"] = labels
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

def test_data_loading(data_path, tokenizer_path, batch_size=4, max_samples=10):
    """Smoke-test loading a few samples, tokenizing them and iterating a DataLoader.

    Args:
        data_path: ``.csv`` file with a ``text`` column, or ``.txt`` file with one sample per line.
        tokenizer_path: Directory of a saved ``SimpleTokenizer``.
        batch_size: DataLoader batch size.
        max_samples: Number of rows/lines read from the start of the file (None reads all).

    Returns:
        ``(True, "Data test completed successfully")`` or ``(False, "Error: ...")`` after
        logging the traceback.
    """
    import traceback  # don't rely on another cell having imported it
    from itertools import islice

    try:
        logger.info(f"Loading tokenizer from {tokenizer_path}")
        tokenizer = SimpleTokenizer.from_pretrained(tokenizer_path)
        # Only borrow eos as pad when the tokenizer has none: overwriting SimpleTokenizer's
        # pad_token would leave it out of sync with pad_token_id (which is what pads).
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info(f"Loading data from {data_path}")
        # Modify this part based on your data format
        if data_path.endswith('.csv'):
            text_column = "text"  # Update this to your column name
            df = pd.read_csv(data_path, nrows=max_samples)
            texts = df[text_column].tolist()
        elif data_path.endswith('.txt'):
            with open(data_path, 'r', encoding='utf-8') as f:
                # Stream only the first max_samples lines instead of reading the whole file;
                # drop blank lines, which would encode to pure padding.
                texts = [line.strip() for line in islice(f, max_samples)]
                texts = [text for text in texts if text]
        else:
            # Add other data formats as needed
            raise ValueError(f"Unsupported data format: {data_path}")

        if not texts:
            raise ValueError(f"No text samples found in {data_path}")

        logger.info(f"Loaded {len(texts)} text samples")
        logger.info(f"Sample text: {texts[0][:100]}...")

        logger.info("Creating dataset")
        dataset = SequenceDataset(texts, tokenizer)

        logger.info("Creating dataloader")
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        logger.info("Testing batch iteration")
        for i, batch in enumerate(dataloader):
            logger.info(f"Batch {i+1} shape: {batch['input_ids'].shape}")
            if i >= 2:  # Just test a few batches
                break

        logger.info("Data loading and processing test completed successfully")
        return True, "Data test completed successfully"

    except Exception as e:
        logger.error(f"Data test failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False, f"Error: {str(e)}"

# # Execute the test
# if __name__ == "__main__":
#     DATA_PATH = "../../data/taxa_sequences.txt"  # Update this
#     TOKENIZER_PATH = "gpt2"  # Update this
    
#     success, message = test_data_loading(DATA_PATH, TOKENIZER_PATH)
#     logger.info(f"Test result: {message}")

In [0]:
# NOTE(review): the test_data_loading call in the cell above is commented out, so this
# displays whatever `message` an earlier cell assigned, not a data-loading result.
message

# Step 4: Test the model


In [0]:
import logging
import traceback
import torch
# SimpleTokenizer is defined in this notebook (custom tokenizer section), not in transformers;
# importing it from transformers raises ImportError and aborts the whole cell.
from transformers import GPT2Config, GPT2LMHeadModel

# Console logging at INFO; this cell's logger is named "model_test".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_test")

def test_model(tokenizer_path, config=None, sample_text="This is a test"):
    """Smoke-test model construction, one forward pass with loss, and a short generation.

    Args:
        tokenizer_path: Directory of a saved ``SimpleTokenizer``.
        config: Optional ``GPT2Config``. When None, a 6-layer / 768-dim config sized to the
            tokenizer's vocabulary is built.
        sample_text: Text encoded for the forward pass and used as the generation prompt.

    Returns:
        ``(True, "Model test completed successfully")`` or ``(False, "Error: ...")`` after
        logging the traceback.
    """
    import traceback  # don't rely on another cell having imported it

    try:
        logger.info(f"Loading tokenizer from {tokenizer_path}")
        tokenizer = SimpleTokenizer.from_pretrained(tokenizer_path)

        # Only borrow eos as pad when the tokenizer has none; overwriting an existing pad
        # token would leave pad_token and pad_token_id out of sync.
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

        if config is None:
            logger.info("Using default GPT-2 configuration")
            # len() is the HF convention; fall back to vocab_size for tokenizers that do
            # not implement __len__.
            vocab_size = len(tokenizer) if hasattr(tokenizer, "__len__") else tokenizer.vocab_size
            config = GPT2Config(
                vocab_size=vocab_size,
                n_positions=512,
                n_ctx=512,
                n_embd=768,
                n_layer=6,
                n_head=12,
            )

        logger.info("Initializing model")
        model = GPT2LMHeadModel(config)

        total_params = sum(p.numel() for p in model.parameters())
        logger.info(f"Model initialized with {total_params:,} parameters")

        logger.info("Testing forward pass")
        inputs = tokenizer(sample_text, return_tensors="pt")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        model = model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Causal LM: labels are the inputs (the model shifts them internally).
        inputs["labels"] = inputs["input_ids"].clone()

        logger.info("Performing forward pass with loss calculation")
        with torch.no_grad():
            outputs = model(**inputs)

        logger.info(f"Forward pass successful, loss: {outputs.loss.item()}")

        logger.info("Testing text generation")
        input_ids = inputs["input_ids"]
        # Pad with the tokenizer's own pad id when it has one (eos otherwise), and never ask
        # for more positions than the model has position embeddings for.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_length=min(50, getattr(config, "n_positions", 50)),
                num_return_sequences=1,
                pad_token_id=pad_id,
            )

        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        logger.info(f"Generated text: {generated_text}")

        return True, "Model test completed successfully"

    except Exception as e:
        logger.error(f"Model test failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False, f"Error: {str(e)}"

# Execute the test
# if __name__ == "__main__":
    # TOKENIZER_PATH = "gpt2"  # Update this
    
    # # Custom config if needed
    # custom_config = GPT2Config(
    #     vocab_size=50257,  # Update with your tokenizer's vocab size
    #     n_positions=512,
    #     n_ctx=512,
    #     n_embd=768,
    #     n_layer=6,  # Smaller than standard GPT-2 for faster testing
    #     n_head=12,
    # )
    
    # success, message = test_model(TOKENIZER_PATH, config=custom_config)
    # logger.info(f"Test result: {message}")

In [0]:
# NOTE(review): the test_model call in the cell above is commented out, so this displays
# whatever `message` an earlier cell assigned, not a model-test result.
message

# Step 5: Test minimal training loop

In [0]:
import logging
import traceback
import torch
# transformers.AdamW is deprecated (removed in recent releases); torch's AdamW is the
# drop-in replacement.
from torch.optim import AdamW
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
from torch.utils.data import Dataset, DataLoader

# Console logging at INFO; this cell's logger is named "training_test".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("training_test")

class MinimalDataset(Dataset):
    """Tiny in-memory causal-LM dataset used by the training smoke test.

    Args:
        texts: List of strings.
        tokenizer: Callable as ``tokenizer(texts, truncation=True, padding="max_length",
            max_length=..., return_tensors="pt")`` returning a mapping of per-text tensors
            that includes ``input_ids`` (and usually ``attention_mask``).
        max_length: Pad/truncate every text to this many tokens.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(texts, truncation=True, padding="max_length",
                                   max_length=max_length, return_tensors="pt")

    def __getitem__(self, idx):
        """Return row ``idx`` of every encoding plus ``labels`` copied from ``input_ids``.

        Padded positions (``attention_mask == 0``) get label -100 so the loss ignores them.
        """
        item = {key: val[idx] for key, val in self.encodings.items()}
        labels = item["input_ids"].clone()
        if "attention_mask" in item:
            labels[item["attention_mask"] == 0] = -100
        item["labels"] = labels
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

def test_training(tokenizer_path, num_steps=3):
    """Smoke-test a few optimizer steps of a small GPT-2 on four fixed sentences.

    Args:
        tokenizer_path: Directory of a saved ``SimpleTokenizer``.
        num_steps: Number of optimizer steps to run.

    Returns:
        ``(True, "Training test completed successfully")`` if every step ran, otherwise
        ``(False, "Error: ...")`` after logging the traceback.
    """
    import traceback  # don't rely on another cell having imported it

    def _endless(loader):
        # Re-iterate the loader when exhausted so any num_steps works and each pass
        # reshuffles; stop (StopIteration on next()) if the loader yields nothing.
        while True:
            produced = False
            for item in loader:
                produced = True
                yield item
            if not produced:
                return

    try:
        logger.info(f"Loading tokenizer from {tokenizer_path}")
        tokenizer = SimpleTokenizer.from_pretrained(tokenizer_path)
        # Only borrow eos as pad when the tokenizer has none; overwriting an existing pad
        # token would leave pad_token and pad_token_id out of sync.
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token

        sample_texts = [
            "This is a test sentence for training GPT-2.",
            "Let's see if we can train the model without errors.",
            "Databricks runtime should be able to handle this small test.",
            "If this works, we can move on to actual training."
        ]

        logger.info("Creating sample dataset")
        dataset = MinimalDataset(sample_texts, tokenizer)
        dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

        logger.info("Initializing model")
        # len() is the HF convention; fall back to vocab_size for tokenizers without __len__.
        vocab_size = len(tokenizer) if hasattr(tokenizer, "__len__") else tokenizer.vocab_size
        config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=512,
            n_ctx=512,
            n_embd=768,
            n_layer=4,  # Small model for testing
            n_head=12,
        )
        model = GPT2LMHeadModel(config)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model = model.to(device)

        logger.info("Setting up optimizer")
        # torch's AdamW (transformers.AdamW is deprecated); weight_decay=0.0 matches the
        # transformers AdamW default, whereas torch's default is 0.01.
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.0)

        logger.info("Starting mini training loop")
        model.train()

        batches = _endless(dataloader)
        for step in range(num_steps):
            batch = {k: v.to(device) for k, v in next(batches).items()}

            outputs = model(**batch)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            logger.info(f"Step {step+1}/{num_steps}, Loss: {loss.item()}")

        logger.info("Mini training loop completed successfully")
        return True, "Training test completed successfully"

    except Exception as e:
        logger.error(f"Training test failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False, f"Error: {str(e)}"

# Execute the test
# if __name__ == "__main__":
#     TOKENIZER_PATH = "gpt2"  # Update this
    
#     success, message = test_training(TOKENIZER_PATH)
#     logger.info(f"Test result: {message}")

In [0]:
# NOTE(review): the test_training call in the cell above is commented out, so this displays
# whatever `message` an earlier cell assigned, not a training-test result.
message

# Step 6: Test Databricks Runtime

In [0]:
import logging
import traceback
import sys
import torch
import psutil
import os
from transformers import __version__ as transformers_version

# Console logging at INFO; this cell's logger is named "runtime_test".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="runtime_test")

def test_runtime():
    """Log a snapshot of the execution environment.

    Covers Python/PyTorch/Transformers versions, CUDA and GPU names, RAM, CPU count,
    free space on ``/`` and a few path-related environment variables.

    Returns:
        ``(True, "Runtime test completed successfully")``, or ``(False, "Error: ...")`` after
        logging the traceback if any probe raised.
    """
    gib = 1024 ** 3
    try:
        logger.info(f"Python version: {sys.version}")

        cuda_ok = torch.cuda.is_available()
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {cuda_ok}")
        if cuda_ok:
            logger.info(f"CUDA version: {torch.version.cuda}")
            gpu_total = torch.cuda.device_count()
            logger.info(f"GPU count: {gpu_total}")
            for gpu_idx in range(gpu_total):
                logger.info(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")

        logger.info(f"Transformers version: {transformers_version}")

        vm = psutil.virtual_memory()
        logger.info(f"Total memory: {vm.total / gib:.2f} GB")
        logger.info(f"Available memory: {vm.available / gib:.2f} GB")

        logger.info(f"CPU count: {psutil.cpu_count()}")

        root_disk = psutil.disk_usage('/')
        logger.info(f"Disk total: {root_disk.total / gib:.2f} GB")
        logger.info(f"Disk free: {root_disk.free / gib:.2f} GB")

        logger.info("Relevant environment variables:")
        present = (name for name in ('CUDA_VISIBLE_DEVICES', 'PYTHONPATH', 'LD_LIBRARY_PATH')
                   if name in os.environ)
        for name in present:
            logger.info(f"{name}: {os.environ[name]}")

        return True, "Runtime test completed successfully"

    except Exception as e:
        logger.error(f"Runtime test failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False, f"Error: {str(e)}"

if __name__ == "__main__":
    # Collect the environment report and log its one-line status; `message` is displayed
    # by the next cell.
    success, message = test_runtime()
    logger.info(f"Test result: {message}")

In [0]:
# Display the status string returned by test_runtime in the cell above.
message

# Start: full training pipeline


In [0]:
import os
import pandas as pd
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Config, 
    GPT2LMHeadModel, 
    Trainer, 
    TrainingArguments,
    DataCollatorForLanguageModeling
)
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from utils_gpt import *

In [0]:
import random

import numpy as np
import torch

SEED = 42

# Seed every RNG that model initialisation, data shuffling and dropout may draw from.
# The model is constructed on the CPU before being moved to the GPU, so seeding only CUDA
# (as before) left its initial weights unseeded. torch.manual_seed seeds all devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Force CUDA context initialisation up front, then seed all GPUs explicitly.
    torch.zeros(1).cuda()
    torch.cuda.manual_seed_all(SEED)

# Custom tokenizer

In [0]:
import os
import json
import torch

class SimpleTokenizer:
    """Whitespace tokenizer over a fixed vocabulary with a small Hugging Face-like API.

    Text is split on whitespace and each token is looked up in ``token_to_id``; unknown
    tokens map to ``<unk>``. Provides ``__call__``/``encode``, ``decode``, ``tokenize``,
    ``__len__``, ``save_pretrained``/``from_pretrained`` and the ``*_token`` /
    ``*_token_id`` attributes read by the notebook's training code.

    Args:
        token_to_id: Mapping token -> int id. Must contain ``<pad>``, ``<s>``, ``</s>`` and
            ``<unk>`` (KeyError otherwise).
        id_to_token: Mapping int id -> token, used by ``decode``.
        model_max_length: Stored as ``self.model_max_length`` (not used for truncation).
            Defaults to the module-level ``MAX_SEQ_LENGTH`` when that name exists, else None,
            so the class can be instantiated before that constant is defined.
    """

    def __init__(self, token_to_id, id_to_token, model_max_length=None):
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token
        self.pad_token_id = token_to_id["<pad>"]
        self.bos_token_id = token_to_id["<s>"]
        self.eos_token_id = token_to_id["</s>"]
        self.unk_token_id = token_to_id["<unk>"]
        # The model's embedding table is indexed by id, so it needs max_id + 1 rows even when
        # ids are not contiguous; for ids 0..n-1 this equals len(token_to_id).
        self.vocab_size = max(len(token_to_id), max(token_to_id.values()) + 1)

        # Special token attributes expected by HF transformers
        self.all_special_ids = [self.pad_token_id, self.bos_token_id, self.eos_token_id, self.unk_token_id]
        if model_max_length is None:
            model_max_length = globals().get("MAX_SEQ_LENGTH")
        self.model_max_length = model_max_length

        # Special token strings that HF-style code expects
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"

        # Map IDs to special tokens
        self.special_ids_to_tokens = {
            self.pad_token_id: self.pad_token,
            self.bos_token_id: self.bos_token,
            self.eos_token_id: self.eos_token,
            self.unk_token_id: self.unk_token
        }

    def __len__(self):
        """Vocabulary size, so ``len(tokenizer)`` works as it does for HF tokenizers."""
        return self.vocab_size

    def tokenize(self, text):
        """Split a string on whitespace; any non-string input is returned unchanged."""
        if isinstance(text, str):
            return text.split()
        return text

    def __call__(self, sequence, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Make the tokenizer callable like HF tokenizers (delegates to ``encode``)."""
        return self.encode(sequence, max_length, padding, truncation, return_tensors)

    def _ids_for(self, tokens, max_length, truncation):
        """Map tokens to ids (unknown -> unk id), truncating to ``max_length`` if requested."""
        ids = [self.token_to_id.get(token, self.unk_token_id) for token in tokens]
        if truncation and max_length and len(ids) > max_length:
            ids = ids[:max_length]
        return ids

    def encode(self, sequence, max_length=None, padding=False, truncation=False, return_tensors=None):
        """Convert one sequence or a batch of sequences to ``input_ids``/``attention_mask``.

        A list whose items are all strings (including an empty list) is a batch of
        whitespace-separated texts; anything else is a single sequence (a string, or an
        iterable of tokens). Any truthy ``padding`` enables padding: a batch is padded to
        ``max_length`` or, if None, to its longest row; a single sequence is padded only when
        ``max_length`` is given.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (1 = real token, 0 = padding):
            plain lists, or int64 tensors of shape (batch, length) when
            ``return_tensors="pt"`` (a single sequence becomes a batch of one).
        """
        if isinstance(sequence, list) and all(isinstance(s, str) for s in sequence):
            batch_ids = [self._ids_for(seq.split(), max_length, truncation) for seq in sequence]
            batch_attention_masks = [[1] * len(ids) for ids in batch_ids]

            if padding:
                if max_length is not None:
                    target = max_length
                else:
                    target = max((len(ids) for ids in batch_ids), default=0)
                for row, (ids, mask) in enumerate(zip(batch_ids, batch_attention_masks)):
                    missing = target - len(ids)
                    if missing > 0:
                        batch_ids[row] = ids + [self.pad_token_id] * missing
                        batch_attention_masks[row] = mask + [0] * missing

            if return_tensors == "pt":
                if not batch_ids:
                    # torch.tensor([]) would be a 1-D float tensor; keep an empty 2-D int64 batch.
                    empty = torch.zeros((0, 0), dtype=torch.long)
                    return {"input_ids": empty, "attention_mask": empty.clone()}
                return {
                    "input_ids": torch.tensor(batch_ids),
                    "attention_mask": torch.tensor(batch_attention_masks)
                }
            return {
                "input_ids": batch_ids,
                "attention_mask": batch_attention_masks
            }

        # Single sequence
        tokens = sequence.split() if isinstance(sequence, str) else sequence
        ids = self._ids_for(tokens, max_length, truncation)
        attention_mask = [1] * len(ids)
        if padding and max_length:
            padding_length = max_length - len(ids)
            ids = ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

        if return_tensors == "pt":
            return {
                "input_ids": torch.tensor([ids]),
                "attention_mask": torch.tensor([attention_mask])
            }
        return {
            "input_ids": ids,
            "attention_mask": attention_mask
        }

    def decode(self, ids, skip_special_tokens=False):
        """Convert token ids (list or 1-D tensor) back to a space-joined string.

        Ids missing from ``id_to_token`` decode as ``<unk>``; special ids are dropped when
        ``skip_special_tokens`` is True.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        special = set(self.all_special_ids)
        return " ".join(
            self.id_to_token.get(token_id, "<unk>")
            for token_id in ids
            if not (skip_special_tokens and token_id in special)
        )

    def save_pretrained(self, output_dir):
        """Write both vocabulary mappings to ``<output_dir>/vocab.json`` (creating the directory)."""
        os.makedirs(output_dir, exist_ok=True)
        payload = {
            "token_to_id": dict(self.token_to_id),
            # JSON object keys must be strings: stringify the int ids here; from_pretrained
            # converts them back.
            "id_to_token": {str(k): v for k, v in self.id_to_token.items()},
        }
        with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as f:
            json.dump(payload, f)

    @classmethod
    def from_pretrained(cls, input_dir, **kwargs):
        """Load a tokenizer written by ``save_pretrained`` from ``input_dir``.

        Extra keyword arguments (as HF-style call sites pass) are accepted; only
        ``model_max_length`` is used.
        """
        with open(os.path.join(input_dir, "vocab.json"), "r", encoding="utf-8") as f:
            data = json.load(f)
        token_to_id = data["token_to_id"]
        # Convert string keys back to integers
        id_to_token = {int(k): v for k, v in data["id_to_token"].items()}
        return cls(token_to_id, id_to_token, model_max_length=kwargs.get("model_max_length"))

In [0]:
MAX_SEQ_LENGTH = 150
tokenizer_path = "models/simple_taxa_tokenizer"

# load mapping dictionaries
import json
with open('../../data/token_to_id.json', 'r', encoding='utf-8') as f:
    raw_token_to_id = json.load(f)

# Special tokens reserve the lowest ids.
special_tokens = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3}

# A plain dict.update would let a taxon whose id is already 0-3 share that id with a special
# token: it would decode as e.g. "<pad>", and the collator would drop it from the loss as
# padding. If any taxon id falls in the reserved range, shift all taxon ids just past it,
# preserving their relative order.
# NOTE(review): when a shift happens, taxon ids differ from token_to_id.json; anything else
# consuming raw ids should use this tokenizer's saved vocab instead.
taxa_to_id = {tok: idx for tok, idx in raw_token_to_id.items() if tok not in special_tokens}
if taxa_to_id and set(taxa_to_id.values()) & set(special_tokens.values()):
    shift = max(special_tokens.values()) + 1 - min(taxa_to_id.values())
    taxa_to_id = {tok: idx + shift for tok, idx in taxa_to_id.items()}
    print(f"Shifted taxon ids by {shift} to avoid collisions with special tokens")

token_to_id = {**taxa_to_id, **special_tokens}
id_to_token = {idx: tok for tok, idx in token_to_id.items()}
assert len(id_to_token) == len(token_to_id), "token ids must be unique"

tokenizer = SimpleTokenizer(token_to_id, id_to_token)
tokenizer.save_pretrained(tokenizer_path)

In [0]:
# Two taxa "sentences" used to sanity-check batch encoding below.
SAMPLE_TEXTS = [
    (
        "Roseburia Ruminococcus Streptococcus Dorea Faecalibacterium Anaerostipes "
        "Bifidobacterium Blautia Anaerobutyricum Agathobaculum Collinsella Klebsiella "
        "Fusicatenibacter Bacteroides Eubacterium Gemmiger Adlercreutzia Phocaeicola "
        "Alistipes Barnesiella Firmicutes"
    ),
    "Ruminococcus Phocaeicola Bacteroides Faecalibacterium Eubacterium Roseburia Alistipes",
]
    

In [0]:
# Batch-encode the sample texts (no max_length, so rows are padded to the longest one).
tokenizer(SAMPLE_TEXTS, padding=True, truncation=True, return_tensors="pt")

In [0]:
class TaxaSequenceDataset(Dataset):
    """Pre-encode taxa sequences into fixed-length ``input_ids``/``attention_mask`` tensors.

    Args:
        sequences: Iterable of sequences. Each is either a whitespace-separated string (e.g.
            a line read from a text file) or a list/tuple of string tokens.
        tokenizer: Object exposing ``encode(seq, max_length=..., padding=..., truncation=...,
            return_tensors="pt")`` that returns a dict with ``input_ids`` and
            ``attention_mask`` tensors for one sequence.
        max_length: Every sequence is truncated/padded to this many positions.
    """

    def __init__(self, sequences, tokenizer, max_length=128):
        self.inputs = []
        for seq in sequences:
            # SimpleTokenizer.encode treats a list of strings as a *batch* of texts; join
            # token lists so each sequence is encoded as a single row.
            if isinstance(seq, (list, tuple)):
                seq = " ".join(seq)
            encoded = tokenizer.encode(
                seq,
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            self.inputs.append(encoded)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return 1-D ``input_ids`` and ``attention_mask`` for sequence ``idx``."""
        record = self.inputs[idx]
        # encode(..., return_tensors="pt") yields a batch of one; flatten to 1-D.
        return {
            "input_ids": record["input_ids"].reshape(-1),
            "attention_mask": record["attention_mask"].reshape(-1),
        }

# Configure and initialize the model

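Model choice: the GPT-2 architecture (`GPT2LMHeadModel`) trained from scratch, scaled down from GPT-2 Small (`n_embd=64`, 6 layers, 8 heads instead of 768 / 12 / 12).

- It is autoregressive, which suits the generation task (taxa completion).
- Its hidden states are practical to extract as embeddings for downstream supervised learning tasks.
- It handles variable sequence lengths well.
- With only ~355 unique taxa, the vocabulary is small, so vocabulary size is not a concern.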

In [0]:
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [0]:
def create_gpt2_model(vocab_size, n_positions=None, pad_token_id=0, bos_token_id=1,
                      eos_token_id=2, **config_overrides):
    """Create a small, randomly initialised GPT-2 language model for the taxa vocabulary.

    Args:
        vocab_size: Rows in the token embedding table (pass ``tokenizer.vocab_size``).
        n_positions: Maximum sequence length (number of position embeddings). Defaults to
            the notebook-level ``MAX_SEQ_LENGTH``.
        pad_token_id, bos_token_id, eos_token_id: Special ids written into the config;
            the defaults match the notebook's special tokens <pad>=0, <s>=1, </s>=2.
        **config_overrides: Any other ``GPT2Config`` field, overriding the defaults below
            (e.g. ``n_layer=4``).

    Returns:
        A new ``GPT2LMHeadModel`` built from the resulting config.
    """
    if n_positions is None:
        n_positions = MAX_SEQ_LENGTH
    config_kwargs = dict(
        vocab_size=vocab_size,
        n_positions=n_positions,
        n_ctx=n_positions,
        n_embd=64,    # Smaller embedding size
        n_layer=6,    # Fewer layers for faster training
        n_head=8,     # Fewer attention heads (n_embd must be divisible by n_head)
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        attn_pdrop=0.0,   # Attention dropout
        embd_pdrop=0.0,   # Embedding dropout
        resid_pdrop=0.0,  # Residual dropout
    )
    config_kwargs.update(config_overrides)
    return GPT2LMHeadModel(GPT2Config(**config_kwargs))

# Build the model sized to the tokenizer's vocabulary and move it to the selected device.
model = create_gpt2_model(vocab_size=tokenizer.vocab_size).to(device=device)

In [0]:
# Confirm where the model lives: the device of its first parameter.
next(iter(model.parameters())).device

Create a `SimpleDataCollator` class to replace Hugging Face's `DataCollatorForLanguageModeling`. It pads each batch and builds causal language-modeling labels, with padded positions set to -100 so the loss ignores them.

In [0]:

class SimpleDataCollator:
    """Collate tokenized examples into a padded causal-LM batch.

    Replacement for ``DataCollatorForLanguageModeling(mlm=False)`` that works with the
    notebook's ``SimpleTokenizer``.

    Args:
        tokenizer: Only ``tokenizer.pad_token_id`` is read (padding value for ``input_ids``).
        mlm: Kept for interface compatibility; masked-LM is not implemented and the flag is unused.
        debug: If True, print shapes, devices and leading values of every batch. Off by
            default: ``__call__`` runs once per batch, so always-on printing floods the output.
    """

    def __init__(self, tokenizer, mlm=False, debug=False):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.debug = debug

    def __call__(self, features):
        """Pad ``features`` (dicts holding 1-D ``input_ids``/``attention_mask`` tensors).

        Returns:
            Dict of ``(batch, longest)`` tensors: ``input_ids`` right-padded with
            ``pad_token_id``, ``attention_mask`` right-padded with 0, and ``labels``, a copy of
            ``input_ids`` with every padded position (``attention_mask == 0``) set to -100 so
            the loss ignores it.
        """
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        input_ids = pad_sequence(
            [f["input_ids"] for f in features],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        attention_mask = pad_sequence(
            [f["attention_mask"] for f in features],
            batch_first=True,
            padding_value=0,
        )

        labels = input_ids.clone()
        # Mask by attention rather than by pad id, so a real token that shares the pad id
        # (e.g. when pad == eos) still contributes to the loss.
        labels[attention_mask == 0] = -100

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
        if self.debug:
            self._debug_batch(batch)
        return batch

    @staticmethod
    def _debug_batch(batch):
        """Print shapes, devices, NaN checks and leading values of a collated batch."""
        print("Input IDs shape:", batch["input_ids"].shape)
        print("Attention mask shape:", batch["attention_mask"].shape)
        print("Labels shape:", batch["labels"].shape)
        # Check if all tensors are on the same device
        print("Input IDs device:", batch["input_ids"].device)
        print("Attention mask device:", batch["attention_mask"].device)
        print("Labels device:", batch["labels"].device)
        # Check for any NaN values
        print("Any NaN in input_ids:", torch.isnan(batch["input_ids"]).any())
        print("Any NaN in attention_mask:", torch.isnan(batch["attention_mask"]).any())
        print("Any NaN in labels:", torch.isnan(batch["labels"]).any())
        # Print some values
        print("First sequence input_ids:", batch["input_ids"][0][:10])
        print("First sequence attention_mask:", batch["attention_mask"][0][:10])
        print("First sequence labels:", batch["labels"][0][:10])

# Training

In [0]:
from sklearn.model_selection import train_test_split

# Load one sequence per line. readlines() would keep the trailing "\n" on every
# sequence, and blank lines would become zero-length examples, so strip each line
# and drop the ones that are empty after stripping.
with open('../../data/taxa_sequences.txt', 'r', encoding='utf-8') as file:
    sequences = [seq for seq in (line.strip() for line in file) if seq]

# Split sequences into train and test sets (fixed random_state for a reproducible split)
train_sequences, test_sequences = train_test_split(sequences, test_size=0.2, random_state=42)

In [0]:
# Directory that receives checkpoints, logs and the final model.
output_dir = "./gpt2_taxa_seq_model"

# Wrap the raw train/test splits in TaxaSequenceDataset (train first, then eval).
train_dataset, eval_dataset = (
    TaxaSequenceDataset(split, tokenizer) for split in (train_sequences, test_sequences)
)

In [0]:
import torch

# Causal-LM collation (mlm=False): labels are a copy of input_ids with padding set to -100.
data_collator = SimpleDataCollator(
    tokenizer=tokenizer,
    mlm=False,
)

# Evaluation (and best-model reloading) only makes sense when there is an eval split.
has_eval = bool(eval_dataset)

# Set up training arguments.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` (the old name was later removed); switch the keyword if your
# installed version rejects it.
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # fp16 mixed precision requires a CUDA device; TrainingArguments raises if it is
    # requested on a CPU-only machine, so only enable it when a GPU is present.
    fp16=torch.cuda.is_available(),
    dataloader_pin_memory=True,
    evaluation_strategy="epoch" if has_eval else "no",
    save_strategy="epoch",
    save_total_limit=2,
    logging_dir=f"{output_dir}/logs",
    load_best_model_at_end=has_eval,
    full_determinism=False,
)

In [0]:
import os

import torch

# CUDA debugging switches.
# - CUDA_LAUNCH_BLOCKING is read when the CUDA context is created, so it has no effect
#   if the GPU was already used in this kernel (e.g. the model was moved to CUDA);
#   set it in the first cell or restart the kernel.
# - TORCH_USE_CUDA_DSA is a PyTorch *build* option; exporting it at runtime does not
#   enable device-side assertions in a prebuilt wheel.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
if torch.cuda.is_initialized():
    logger.warning(
        "CUDA is already initialized; CUDA_LAUNCH_BLOCKING=1 will only take effect after a kernel restart."
    )

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Train the model
trainer.train()

# Save the trained model and its tokenizer together so output_dir can be reloaded
# on its own with from_pretrained().
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [0]:

# 3. Training function (accept both training and evaluation datasets)
def train_model(model, tokenizer, train_sequences, eval_sequences=None, output_dir="./gpt2_taxa_seq_model",
                num_train_epochs=3, batch_size=8):
    """Fine-tune ``model`` as a causal language model on taxa sequences.

    Args:
        model: Model to train. The same object is returned after training.
        tokenizer: Tokenizer used to build the datasets; saved next to the model.
        train_sequences: Training sequences, one string per example.
        eval_sequences: Optional held-out sequences. When empty or None, no
            evaluation is run and best-model reloading is disabled.
        output_dir: Directory for checkpoints, logs and the final model/tokenizer.
        num_train_epochs: Number of training epochs (default 3).
        batch_size: Per-device batch size for both training and evaluation (default 8).

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    train_dataset = TaxaSequenceDataset(train_sequences, tokenizer)

    # Prepare validation dataset if provided
    eval_dataset = TaxaSequenceDataset(eval_sequences, tokenizer) if eval_sequences else None
    has_eval = bool(eval_dataset)

    # Causal language modeling collator (mlm=False), not masked language modeling
    data_collator = SimpleDataCollator(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Set up training arguments.
    # NOTE(review): newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy`; switch the keyword if your installed version rejects it.
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        evaluation_strategy="epoch" if has_eval else "no",
        save_strategy="epoch",
        save_total_limit=2,
        logging_dir=f"{output_dir}/logs",
        # Reloading the best checkpoint needs per-epoch evaluation to rank checkpoints.
        load_best_model_at_end=has_eval,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Train the model
    trainer.train()

    # Save the trained model and tokenizer
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer

In [0]:
# NOTE: `model` was already trained by `trainer.train()` in an earlier cell, so this call
# trains the same model object again and overwrites the default ./gpt2_taxa_seq_model
# output directory. When running the notebook top to bottom, use only one training path.
model, tokenizer = train_model(model, tokenizer, train_sequences, test_sequences)
    

In [0]:
# Shell out (IPython `!`) to nvidia-smi to inspect GPU memory and utilization.
!nvidia-smi


# Evaluate on the sequence-completion task

In [0]:
def evaluate_sequence_completion(model, tokenizer, test_sequences, prefix_ratio=0.5, generate_fn=None):
    """Score how well the model continues held-out sequences, token by token.

    Each sequence is split on whitespace. The first ``max(3, int(n * prefix_ratio))``
    tokens are used as the prompt. The generated tokens that follow are compared
    position-by-position with the true continuation.

    Args:
        model: Model handed to the generator; switched to eval mode first.
        tokenizer: Tokenizer handed to the generator.
        test_sequences: Iterable of whitespace-separated sequences.
        prefix_ratio: Fraction of each sequence used as the prompt (at least 3 tokens).
        generate_fn: Callable ``(model, tokenizer, prompt, max_length=...) -> str``.
            Defaults to ``generate_from_seed``; pass another decoder to compare strategies.

    Returns:
        dict with:
          - ``"individual_accuracies"``: one entry per input sequence, in input order.
            A sequence too short to leave any target tokens after the prefix gets
            ``nan`` (nothing to score) rather than a free 1.0.
          - ``"mean_accuracy"``: mean over the scored sequences (``nan`` if none).
    """
    if generate_fn is None:
        generate_fn = generate_from_seed

    model.eval()
    accuracies = []  # aligned with test_sequences
    scored = []      # only sequences that had something to predict

    for sequence in test_sequences:
        tokens = sequence.split()
        prefix_len = max(3, int(len(tokens) * prefix_ratio))  # Use at least 3 tokens as prefix
        target_tokens = tokens[prefix_len:]

        if not target_tokens:
            # Nothing left to predict; counting this as 1.0 would inflate the mean.
            accuracies.append(float("nan"))
            continue

        prefix = " ".join(tokens[:prefix_len])
        # NOTE(review): max_length is presumably measured in model tokens by the generator,
        # whereas len(tokens) counts whitespace-separated taxa. If one taxon spans several
        # model tokens, generation may stop early. Confirm against generate_from_seed.
        generated = generate_fn(model, tokenizer, prefix, max_length=len(tokens) + 5)

        # Assumes the generated text starts by re-emitting the prompt with the same
        # whitespace tokenization, so the continuation begins at index prefix_len.
        generated_tokens = generated.split()[prefix_len:]

        # zip() stops at the shorter list; dividing by the full target length makes
        # any missing generated tokens count as misses.
        matches = sum(t == g for t, g in zip(target_tokens, generated_tokens))
        accuracy = matches / len(target_tokens)
        accuracies.append(accuracy)
        scored.append(accuracy)

    mean_accuracy = float(np.mean(scored)) if scored else float("nan")
    return {
        "mean_accuracy": mean_accuracy,
        "individual_accuracies": accuracies,
    }


In [0]:
# Token-level completion accuracy of the model on the held-out split.
completion_results = evaluate_sequence_completion(model, tokenizer, test_sequences)
mean_completion_accuracy = completion_results["mean_accuracy"]
print(f"Mean sequence completion accuracy: {mean_completion_accuracy:.4f}")
    

In [0]:

# Example of sequence completion: pick one held-out sequence and compare its true
# continuation with the model's. The generator is seeded so that Restart & Run All
# shows the same example every time.
example_rng = np.random.default_rng(42)
example_idx = int(example_rng.integers(len(test_sequences)))
example_sequence = test_sequences[example_idx]
tokens = example_sequence.split()
prefix_len = max(3, int(len(tokens) * 0.5))  # same prompt rule as the evaluation (>= 3 tokens)
prefix = " ".join(tokens[:prefix_len])

print("\nExample sequence completion:")
print(f"Prefix: {prefix}")
print(f"Original completion: {' '.join(tokens[prefix_len:])}")
generated = generate_from_seed(model, tokenizer, prefix, max_length=len(tokens) + 5)
print(f"Model completion: {' '.join(generated.split()[prefix_len:])}")



In [0]:
# Demo: embed a handful of held-out sequences; the resulting vectors can feed a
# downstream classifier or regressor.
sample_sequences = test_sequences[:5]  # a few sequences are enough for a demonstration
print("\nExtracting embeddings for supervised learning...")
embeddings = extract_embeddings(model, tokenizer, sample_sequences)
print(
    f"Extracted embeddings shape: {embeddings.shape}\n"
    "These embeddings can now be used for classification or regression tasks."
)
