In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from collections import Counter

from helper_funcs import generate_sequences


In [2]:
# Load data
loaded_df = pd.read_hdf('./data/sample_otu_arrays.h5', key='df')

# Seed every RNG this notebook relies on: numpy, and torch (model weight init,
# DataLoader shuffling and diffusion-time sampling all draw from torch's RNG).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Split indices into train/test (80/20)
train_idx, test_idx = train_test_split(loaded_df.index, test_size=0.2, random_state=SEED)

# Create train and test dataframes
train_df = loaded_df.loc[train_idx]
test_df = loaded_df.loc[test_idx]

print(f"Train size: {len(train_df)}")
print(f"Test size: {len(test_df)}")
print("\nFirst few training samples:")
print(train_df.head())

# Let's also look at array lengths
array_lengths = [len(x) for x in loaded_df['otu_arrays']]
print(f"\nMin array length: {min(array_lengths)}")
print(f"Max array length: {max(array_lengths)}")
print(f"Mean array length: {np.mean(array_lengths):.2f}")

Train size: 6486
Test size: 1622

First few training samples:
                                                            otu_arrays
Unnamed: 0                                                            
SRR044975.SRS011167  [30, 58, 82, 89, 93, 98, 99, 104, 117, 120, 12...
SRR049604.SRS049164  [9, 10, 11, 14, 15, 16, 17, 20, 28, 30, 31, 32...
SRR331714.SRS076947  [19, 30, 43, 58, 65, 70, 71, 74, 80, 90, 92, 9...
SRR089999.SRS077685  [12, 14, 18, 20, 22, 38, 45, 67, 68, 76, 88, 1...
SRR048091.SRS021563  [19, 30, 45, 52, 58, 60, 65, 70, 74, 80, 90, 9...

Min array length: 3
Max array length: 277
Mean array length: 69.10


In [36]:
import torch
from torch.utils.data import Dataset, DataLoader

class OTUDataset(Dataset):
    """Map-style dataset of OTU id sequences, right-padded to a fixed length.

    Each item is a ``(tokens, mask)`` pair of 1-D tensors of length ``max_len``:
    ``tokens`` (int64) holds the sample's ids followed by zeros, and ``mask``
    (bool) is False on real positions and True on padding.

    Args:
        df: DataFrame with an ``'otu_arrays'`` column of integer sequences.
        max_len: padded length. Defaults to the longest sequence in ``df``
            (0 for an empty frame). Pass an explicit value to give several
            datasets (e.g. train and test) the same padded length.

    Raises:
        ValueError: if ``max_len`` is smaller than the longest sequence in ``df``.
    """

    def __init__(self, df, max_len=None):
        self.df = df

        longest = max((len(x) for x in df['otu_arrays']), default=0)
        if max_len is None:
            max_len = longest
        elif max_len < longest:
            raise ValueError(
                f"max_len={max_len} is shorter than the longest sequence ({longest})"
            )
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Column-first positional lookup avoids building a whole row Series.
        array = self.df['otu_arrays'].iloc[idx]
        n_tokens = len(array)

        # Real ids first, zero padding after.
        padded = torch.zeros(self.max_len, dtype=torch.long)
        padded[:n_tokens] = torch.as_tensor(array, dtype=torch.long)

        # False where we have real tokens, True for padding.
        mask = torch.zeros(self.max_len, dtype=torch.bool)
        mask[n_tokens:] = True

        return padded, mask

# Wrap the train/test splits as padded-token datasets.
train_dataset, test_dataset = OTUDataset(train_df), OTUDataset(test_df)

# Only the training loader is shuffled; evaluation order stays fixed.
batch_size = 11
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check the shapes of a single training batch.
tokens, mask = next(iter(train_loader))
print(f"Batch tokens shape: {tokens.shape}")
print(f"Batch mask shape: {mask.shape}")

# Largest OTU id + 1, i.e. the size of the id range [0, max_id].
# NOTE(review): id 0 is also the padding value used by OTUDataset — confirm that
# no real OTU is encoded as 0.
vocab_size = max(map(max, loaded_df['otu_arrays'])) + 1
print(f"\nVocabulary size: {vocab_size}")

Batch tokens shape: torch.Size([11, 277])
Batch mask shape: torch.Size([11, 277])

Vocabulary size: 519


In [31]:
import importlib

import helper_funcs
import model_arch

# Reload *before* binding names: a `from X import Y` executed before
# `importlib.reload(X)` keeps pointing at the stale, pre-reload object, so
# edits to model_arch.py / helper_funcs.py would not be picked up.
importlib.reload(model_arch)
importlib.reload(helper_funcs)

from model_arch import CategoricalScoreDiffusion
from helper_funcs import generate_sequences

<module 'helper_funcs' from 'd:\\Projects\\UZH\\cdcd_hmp\\helper_funcs.py'>

In [40]:

class TrainingMetrics:
    """Keeps track of the lowest validation loss observed so far."""

    def __init__(self):
        # +inf so the first comparable validation loss counts as an improvement.
        self.best_val_loss = float('inf')

    def update_best_metrics(self, val_loss):
        """Store ``val_loss`` if it is strictly lower than the best so far.

        Returns:
            bool: True when the best loss was updated, False otherwise
            (including for NaN, which never compares lower).
        """
        if not val_loss < self.best_val_loss:
            return False
        self.best_val_loss = val_loss
        return True

def train_step(model, tokens, mask, optimizer, device):
    """Run one optimisation step of the diffusion denoiser on a batch.

    Args:
        model: diffusion model exposing ``sample_time``, ``embedding``,
            ``get_noise``, ``update_time_warping`` and a forward pass
            ``model(xt, mask, t)`` returning per-position vocabulary logits.
        tokens: integer token ids, presumably shape (batch, seq_len).
        mask: boolean padding mask with the same shape as ``tokens``;
            True marks padding positions.
        optimizer: optimizer over ``model``'s parameters.
        device: unused; kept for call-site compatibility (``tokens.device``
            is used instead).

    Returns:
        float: the batch cross-entropy loss. When the loss is not finite
        (NaN or ±inf), neither the time warping nor the parameters are updated.
    """
    optimizer.zero_grad()

    # Sample diffusion time per sample using the model's time warping.
    t = model.sample_time(tokens.shape[0], tokens.device)

    # Clean embeddings, corrupted with time-dependent noise.
    x0 = model.embedding(tokens)
    xt = x0 + model.get_noise(x0, t)

    logits = model(xt, mask, t)

    # Exclude padding via the explicit mask instead of the sentinel id 0, so a
    # genuine token with id 0 (if the vocabulary has one) is still scored.
    targets = tokens.masked_fill(mask, -100)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )

    # Skip NaN *and* infinite losses: back-propagating either corrupts weights.
    if torch.isfinite(loss):
        model.update_time_warping(t, loss.detach())
        loss.backward()
        optimizer.step()

    return loss.item()

def validation_step(model, tokens, mask, device):
    """Compute the denoising cross-entropy loss on one batch, without updating anything.

    Diffusion time is sampled randomly, so the returned value is stochastic.
    No gradient is needed here; the caller is expected to run this under
    ``torch.no_grad()``.

    Args:
        model: diffusion model exposing ``sample_time``, ``embedding``,
            ``get_noise`` and a forward pass ``model(xt, mask, t)``.
        tokens: integer token ids, presumably shape (batch, seq_len).
        mask: boolean padding mask with the same shape as ``tokens``;
            True marks padding positions.
        device: unused; kept for call-site compatibility.

    Returns:
        float: mean cross-entropy over the non-padding positions.
    """
    # Sample time using warping
    t = model.sample_time(tokens.shape[0], tokens.device)

    # Clean embeddings plus noise (the noise itself comes from model.get_noise)
    x0 = model.embedding(tokens)
    xt = x0 + model.get_noise(x0, t)

    logits = model(xt, mask, t)

    # Exclude padding via the explicit mask rather than the sentinel id 0, so a
    # genuine token with id 0 (if any) still counts towards the loss.
    targets = tokens.masked_fill(mask, -100)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )

    return loss.item()

def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, path='best_model.pt'):
    """Serialise a training snapshot to ``path`` with ``torch.save``.

    Args:
        model: module whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler or None; stored as None when absent.
        epoch: epoch index the snapshot was taken at.
        train_loss: mean training loss for that epoch.
        val_loss: mean validation loss for that epoch.
        path: destination file. Defaults to ``'best_model.pt'``, the previously
            hard-coded location, so existing callers are unaffected.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, path)

def log_metrics(metrics_dict, step_type='batch'):
    """Forward ``metrics_dict`` to Weights & Biases via ``wandb.log``.

    ``step_type`` is accepted but not used. ``wandb`` is looked up as a notebook
    global at call time (it is imported in a later cell), so call this only
    after the ``wandb.init`` cell has run.
    """
    wandb.log(data=metrics_dict)

def train_epoch(model, train_loader, optimizer, device, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is moved to ``device`` and passed to ``train_step``. The batch
    loss, current learning rate (first param group), epoch and batch index are
    sent to ``log_metrics``.

    Returns:
        float: sum of per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_total = 0
    progress = tqdm(train_loader, desc=f'Training Epoch {epoch}')

    for step, (batch_tokens, batch_mask) in enumerate(progress):
        step_loss = train_step(
            model, batch_tokens.to(device), batch_mask.to(device), optimizer, device
        )
        running_total += step_loss

        progress.set_postfix({'loss': f'{step_loss:.4f}'})
        log_metrics({
            'train/batch_loss': step_loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': step,
        })

    return running_total / len(train_loader)

def validate_epoch(model, test_loader, device, epoch):
    """Return the mean validation loss of ``model`` over ``test_loader``.

    Runs in eval mode under ``torch.no_grad()``. Each batch is moved to
    ``device`` and scored by ``validation_step``. Diffusion time is sampled
    inside that step, so the result is stochastic.

    Returns:
        float: sum of per-batch losses divided by ``len(test_loader)``.
    """
    model.eval()
    val_loss = 0
    val_bar = tqdm(test_loader, desc=f'Validation Epoch {epoch}')

    # The previous version also copied every unpadded sequence to CPU into an
    # unused `real_sequences` list; that dead work has been removed.
    with torch.no_grad():
        for tokens, mask in val_bar:
            loss = validation_step(model, tokens.to(device), mask.to(device), device)
            val_loss += loss
            val_bar.set_postfix({'loss': f'{loss:.4f}'})

    return val_loss / len(test_loader)



def train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device,
                       use_lr_scheduling=True, val_every=1):
    """Run the training loop with periodic validation and best-model checkpointing.

    Every epoch trains on ``train_loader`` and logs the mean train loss. On
    epochs where ``epoch % val_every == 0`` it also evaluates on ``test_loader``,
    logs and prints both losses, and steps the LR scheduler (if enabled) on the
    validation loss. Whenever the validation loss improves, it saves a checkpoint
    via ``save_checkpoint``.

    Args:
        model, train_loader, test_loader, optimizer, device: passed through to
            ``train_epoch`` / ``validate_epoch``.
        num_epochs: number of training epochs.
        use_lr_scheduling: if True, multiply the LR by 0.5 when the validation
            loss plateaus (``ReduceLROnPlateau``, patience 3 validation rounds).
        val_every: validation frequency in epochs. Defaults to 1 (every epoch),
            matching the previous behaviour.

    Raises:
        ValueError: if ``val_every`` is smaller than 1.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")

    metrics = TrainingMetrics()

    scheduler = None
    if use_lr_scheduling:
        # `verbose=True` is deprecated for PyTorch LR schedulers; LR reductions
        # are printed explicitly after each scheduler step instead.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

    for epoch in range(num_epochs):
        # Training phase
        avg_train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        log_metrics({'train/epoch_loss': avg_train_loss, 'epoch': epoch})

        # Validation phase (every `val_every` epochs)
        if epoch % val_every != 0:
            print(f'\nEpoch {epoch}: Average Train Loss: {avg_train_loss:.4f}\n')
            continue

        avg_val_loss = validate_epoch(model, test_loader, device, epoch)
        log_metrics({
            'val/epoch_loss': avg_val_loss,
            'epoch': epoch
        })

        print(f'\nEpoch {epoch}:')
        print(f'Average Train Loss: {avg_train_loss:.4f}')
        print(f'Average Val Loss: {avg_val_loss:.4f}')

        if scheduler is not None:
            old_lrs = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(avg_val_loss)
            for i, (old_lr, group) in enumerate(zip(old_lrs, optimizer.param_groups)):
                if group['lr'] != old_lr:
                    print(f'Epoch {epoch}: reducing learning rate of group {i} to {group["lr"]:.4e}.')

        if metrics.update_best_metrics(avg_val_loss):
            save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss)
            log_metrics({
                'best_model/val_loss': avg_val_loss,
                'best_model/train_loss': avg_train_loss,
                'best_model/epoch': epoch
            })

In [43]:
# Initialize model
# Hyper-parameters; trailing comments presumably record earlier values that were tried.
embed_dim = 42             # previously 8
num_layers = 6             # previously 5
num_heads = 6
dim_feedforward = 16       # previously 32
# Going from 4 to 8 Fourier features destabilised the batch loss but seems to have
# resulted in faster convergence and a lower ... (original note was truncated).
num_fourier_features = 16

# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CategoricalScoreDiffusion(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dim_feedforward=dim_feedforward,
    num_fourier_features=num_fourier_features,
).to(device)


In [44]:
import wandb

num_epochs = 200
# NOTE(review): very specific value, presumably taken from a sweep — record its source.
learning_rate = 0.0014763510861459355

# Hyper-parameters attached to the W&B run for later comparison.
run_config = {
    "learning_rate": learning_rate,
    "architecture": "restart",
    "dataset": "hmp",
    "epochs": num_epochs,
    "embed_dim": embed_dim,
    "num_layers": num_layers,
    "num_heads": num_heads,
    "dim_feedforward": dim_feedforward,
    "vocab_size": vocab_size,
    "num_fourier_features": num_fourier_features,
}

# Finish any run left active by an earlier execution, then start a new one.
wandb.finish()
wandb.init(project="diffusion-hmp", config=run_config)

In [39]:
# Fresh AdamW optimizer over the current model's parameters, then run the
# train/validate loop (which checkpoints the best model by validation loss).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
train_and_validate(
    model,
    train_loader,
    test_loader,
    optimizer,
    num_epochs=num_epochs,
    device=device,
)

Training Epoch 0:  12%|█▏        | 73/590 [00:03<00:25, 19.94it/s, loss=5.7502]

In [117]:
# Snapshot model/optimizer state plus the constructor arguments needed to rebuild
# the model later with CategoricalScoreDiffusion(**checkpoint['model_args']).
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_args': {
        'vocab_size': vocab_size,
        'embed_dim': embed_dim,
        'num_layers': num_layers,
        'num_heads': num_heads,
        'dim_feedforward': dim_feedforward,
        'num_fourier_features': num_fourier_features
    }
}

# Filename previously read 'model_checkpoint_3.58pt' (missing '.' before the extension).
torch.save(checkpoint, 'model_checkpoint_3.58.pt')

In [None]:
from model_arch import CategoricalScoreDiffusion

# map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
checkpoint = torch.load('model_checkpoint_2.65.pt', map_location=device)
model = CategoricalScoreDiffusion(**checkpoint['model_args']).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# The existing `optimizer` still references the *previous* model's parameters.
# Rebuild it over the restored model before loading its state; otherwise further
# training would update the discarded model. Assumes the checkpoint came from the
# AdamW optimizer used in the training cell (the stored lr overrides `lr` here).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Read the learning rate stored in the loaded checkpoint's optimizer state
# (first parameter group). Note: this overwrites the global `learning_rate`.
optimizer_state = checkpoint['optimizer_state_dict']
param_groups = optimizer_state['param_groups']
learning_rate = param_groups[0]['lr']
print(f"Learning rate: {learning_rate}")

Learning rate: 0.001


In [8]:
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    """Print the wall-clock time spent inside the ``with`` block, in milliseconds.

    Uses ``time.perf_counter``. The timing line is printed even if the block
    raises; the exception then propagates unchanged.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{name}: {elapsed_ms:.2f} ms")