In [93]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from collections import Counter

from helper_funcs import generate_sequences


In [94]:
# Load the per-sample OTU arrays (one row per sample, sequences in 'otu_arrays').
loaded_df = pd.read_hdf('./data/sample_otu_arrays.h5', key='df')

# Seed every RNG the notebook draws from. Seeding NumPy alone is not enough:
# DataLoader shuffling and torch weight initialisation use torch's own generator.
np.random.seed(42)
torch.manual_seed(42)

# Split indices into train/test (80/20); random_state makes the split repeatable.
train_idx, test_idx = train_test_split(loaded_df.index, test_size=0.2, random_state=42)

# Create train and test dataframes
train_df = loaded_df.loc[train_idx]
test_df = loaded_df.loc[test_idx]

print(f"Train size: {len(train_df)}")
print(f"Test size: {len(test_df)}")
print("\nFirst few training samples:")
print(train_df.head())

# Summarise sequence lengths over the full dataset (train + test).
array_lengths = [len(x) for x in loaded_df['otu_arrays']]
print(f"\nMin array length: {min(array_lengths)}")
print(f"Max array length: {max(array_lengths)}")
print(f"Mean array length: {np.mean(array_lengths):.2f}")

Train size: 6486
Test size: 1622

First few training samples:
                                                            otu_arrays
Unnamed: 0                                                            
SRR044975.SRS011167  [30, 58, 82, 89, 93, 98, 99, 104, 117, 120, 12...
SRR049604.SRS049164  [9, 10, 11, 14, 15, 16, 17, 20, 28, 30, 31, 32...
SRR331714.SRS076947  [19, 30, 43, 58, 65, 70, 71, 74, 80, 90, 92, 9...
SRR089999.SRS077685  [12, 14, 18, 20, 22, 38, 45, 67, 68, 76, 88, 1...
SRR048091.SRS021563  [19, 30, 45, 52, 58, 60, 65, 70, 74, 80, 90, 9...

Min array length: 3
Max array length: 277
Mean array length: 69.10


In [105]:
import torch
from torch.utils.data import Dataset, DataLoader

class OTUDataset(Dataset):
    """Map-style dataset serving zero-padded OTU ID sequences plus a padding mask.

    Args:
        df: DataFrame with an 'otu_arrays' column holding one sequence of
            integer OTU IDs per row. ID 0 is used as the padding value.
        max_len: Width every sequence is padded to. Defaults to the longest
            sequence in ``df``. Pass a shared value to pad several datasets
            (e.g. train and test) to the same width.

    Raises:
        ValueError: if ``max_len`` is shorter than the longest sequence in ``df``.
    """

    def __init__(self, df, max_len=None):
        self.df = df

        # Extract the column once: a per-item ``df.iloc[idx]['otu_arrays']``
        # lookup materialises a whole row Series on every access.
        self.arrays = list(df['otu_arrays'])

        longest = max((len(x) for x in self.arrays), default=0)
        if max_len is None:
            max_len = longest
        elif max_len < longest:
            raise ValueError(
                f"max_len={max_len} is shorter than the longest sequence ({longest})"
            )
        self.max_len = max_len

    def __len__(self):
        return len(self.arrays)

    def __getitem__(self, idx):
        """Return ``(padded, mask)`` for sample ``idx``.

        ``padded`` is a long tensor of length ``max_len`` (real IDs first, then
        zeros). ``mask`` is a bool tensor of the same length that is False on
        real tokens and True on padding.
        """
        array = self.arrays[idx]
        length = len(array)

        padded = torch.zeros(self.max_len, dtype=torch.long)
        padded[:length] = torch.as_tensor(array, dtype=torch.long)

        mask = torch.zeros(self.max_len, dtype=torch.bool)
        mask[length:] = True

        return padded, mask

# Wrap each split in a padded-sequence dataset. By default each dataset pads to
# the longest sequence it contains, so the two splits may use different widths.
train_dataset = OTUDataset(train_df)
test_dataset = OTUDataset(test_df)

# Create dataloaders; only the training loader is shuffled.
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check the shapes of a single batch, then stop iterating.
for tokens, mask in train_loader:
   print(f"Batch tokens shape: {tokens.shape}")
   print(f"Batch mask shape: {mask.shape}")

   break

# Vocabulary size = largest OTU ID + 1, so that every ID in 0..max is a valid
# index (ID 0 is the value OTUDataset pads with).
vocab_size = max(max(x) for x in loaded_df['otu_arrays']) + 1
print(f"\nVocabulary size: {vocab_size}")

Batch tokens shape: torch.Size([8, 277])
Batch mask shape: torch.Size([8, 277])

Vocabulary size: 519


In [107]:
import importlib

import helper_funcs
import model_arch

# Reload first, THEN bind the names. `importlib.reload` re-executes the module
# but does not rebind names that were already imported with `from ... import`,
# so importing before the reload would leave this notebook holding the stale
# pre-reload class/function.
importlib.reload(model_arch)
importlib.reload(helper_funcs)

from model_arch import CategoricalScoreDiffusion
from helper_funcs import generate_sequences

<module 'helper_funcs' from '/Users/matteo/Documents/MATLAS/full_model_and_SAE/flow_matching/helper_funcs.py'>

In [108]:


def calculate_cooccurrence_correlation(sequences, num_otus, reference_coocur=None, overall=False):
   """Build an OTU co-occurrence matrix and optionally score it against a reference.

   Each sequence becomes one row of a presence/absence matrix with
   ``num_otus - 1`` columns (ID 0 is skipped; OTU ``k`` maps to column
   ``k - 1``). The co-occurrence matrix is the Pearson correlation between
   those columns.

   Args:
       sequences: iterable of OTU ID sequences (0 entries are ignored).
       num_otus: vocabulary size including ID 0.
       reference_coocur: optional co-occurrence matrix to compare against.
       overall: if True compare all entries; otherwise compare only the
           sub-matrix of the (up to) 50 OTUs most frequent in ``sequences``.

   Returns:
       The co-occurrence matrix when ``reference_coocur`` is None; otherwise
       the Pearson correlation between the non-NaN entries of the two
       (sub-)matrices.
   """
   presence = np.zeros((len(sequences), num_otus-1))
   for row, seq in enumerate(sequences):
       for otu in {tok for tok in seq if tok != 0}:
           presence[row, otu-1] = 1

   coocur = np.corrcoef(presence.T)

   # No reference supplied: hand back the raw matrix.
   if reference_coocur is None:
       return coocur

   if overall:
       ours, theirs = coocur, reference_coocur
   else:
       # Restrict both matrices to the most frequent OTUs of `sequences`.
       counts = Counter(tok for seq in sequences for tok in seq if tok != 0)
       top_otus = [otu-1 for otu, _ in counts.most_common(50)]
       ours = coocur[top_otus][:, top_otus]
       theirs = reference_coocur[top_otus][:, top_otus]

   # Keep only entries that are defined in both matrices.
   valid = ~(np.isnan(ours) | np.isnan(theirs))
   return np.corrcoef(ours[valid], theirs[valid])[0,1]

def ce_loss_simple(logits, target_tokens, temperature=0.1): #avg over seq then batch
    """Cross-entropy averaged over non-pad tokens per sequence, then over the batch.

    Args:
        logits: tensor of shape (B, S, V); divided by ``temperature`` before
            the loss is computed.
        target_tokens: integer tensor with B*S elements; ID 0 marks padding
            and is excluded from the per-sequence average.
        temperature: divisor applied to the logits.

    Returns:
        Scalar tensor: mean over the batch of the per-sequence mean token loss.
        A sequence consisting only of padding contributes 0 instead of NaN.
    """
    B, S, V = logits.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed logits, work.
    logits_flat = logits.reshape(-1, V) / temperature
    targets_flat = target_tokens.reshape(-1)

    # True on real tokens, False on padding.
    mask = (targets_flat != 0).view(B, S)

    # Per-token loss, un-reduced so padding can be masked out below.
    token_losses = F.cross_entropy(
        logits_flat,
        targets_flat,
        reduction='none'
    ).view(B, S)

    # Average over each sequence's real tokens. Clamp the count so an
    # all-padding row yields 0/1 = 0 rather than 0/0 = NaN, which would
    # otherwise turn the whole batch mean into NaN.
    seq_lengths = mask.sum(dim=1).clamp(min=1)
    sequence_loss = (token_losses * mask).sum(dim=1) / seq_lengths

    # Average over batch
    return sequence_loss.mean()

In [130]:
def calculate_rmsd(generated_sequences, real_sequences, vocab_size):
    """RMSD between real and generated OTU co-occurrence matrices.

    Both sequence sets are turned into presence/absence matrices
    (``vocab_size - 1`` columns, ID 0 skipped), correlated column-wise, and
    compared on the sub-matrix of the (up to) 50 OTUs most frequent in
    ``real_sequences``.

    Entries that are NaN in either matrix are excluded. They arise when an OTU
    is present in all or none of a set, giving a zero-variance column. Without
    this, one such OTU would make the whole RMSD NaN.

    Args:
        generated_sequences: iterable of generated OTU ID sequences.
        real_sequences: iterable of reference OTU ID sequences.
        vocab_size: vocabulary size including ID 0.

    Returns:
        The root-mean-square difference over the comparable entries, or NaN
        when no entry is defined in both matrices.
    """
    def create_correlation_matrix(sequences):
        # Presence/absence matrix: row per sequence, column per non-zero OTU ID.
        matrix = np.zeros((len(sequences), vocab_size-1))
        for i, seq in enumerate(sequences):
            for otu in {otu for otu in seq if otu != 0}:
                matrix[i, otu-1] = 1

        # Zero-variance columns produce NaN correlations (and a RuntimeWarning);
        # those NaNs are expected and filtered by the caller.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.corrcoef(matrix.T)

    # Top OTUs are chosen from the real data only.
    counts = Counter(otu for seq in real_sequences for otu in seq if otu != 0)
    top_otus = [otu-1 for otu, _ in counts.most_common(50)]

    real_matrix = create_correlation_matrix(real_sequences)
    gen_matrix = create_correlation_matrix(generated_sequences)

    real_subset = real_matrix[top_otus][:, top_otus]
    gen_subset = gen_matrix[top_otus][:, top_otus]

    # Compare only entries defined in both matrices.
    valid = ~np.isnan(real_subset) & ~np.isnan(gen_subset)
    if not valid.any():
        return float('nan')
    return np.sqrt(np.mean((real_subset[valid] - gen_subset[valid])**2))

class TrainingMetrics:
    """Tracks the best (lowest) validation loss and, optionally, the best RMSD."""

    def __init__(self):
        # Both start at +inf so the first finite value always counts as an improvement.
        self.best_val_loss = float('inf')
        self.best_rmsd = float('inf')

    def update_best_metrics(self, val_loss, rmsd=None):
        """Record new bests and report whether anything improved.

        Args:
            val_loss: latest validation loss.
            rmsd: optional latest RMSD; ignored when None, so existing
                callers that pass only ``val_loss`` behave as before.

        Returns:
            True if ``val_loss`` (or ``rmsd``, when given) is strictly lower
            than the best value seen so far.
        """
        improved = False
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            improved = True
        if rmsd is not None and rmsd < self.best_rmsd:
            self.best_rmsd = rmsd
            improved = True
        return improved

def train_step(model, tokens, mask, optimizer, device):
    """Run one optimisation step on a batch and return the loss as a float.

    Args:
        model: module exposing ``sample_time``, ``embedding``, ``get_noise``,
            ``update_time_warping`` and a forward ``model(xt, mask, t)`` that
            returns per-token logits.
        tokens: integer tensor of token IDs; ID 0 is padding and is ignored
            by the loss.
        mask: padding mask forwarded unchanged to the model.
        optimizer: optimizer over ``model``'s parameters.
        device: unused here (tensors are placed by the caller); kept for
            interface compatibility.

    Returns:
        The batch loss as a Python float. When the loss is NaN or infinite the
        step is skipped (no time-warping update, no backward pass, no
        optimizer step) and the non-finite value is still returned.
    """
    optimizer.zero_grad()

    # One time value per batch element, drawn by the model's own sampler.
    t = model.sample_time(tokens.shape[0], tokens.device)

    # Clean embeddings plus model-generated noise for the sampled times.
    x0 = model.embedding(tokens)
    noise = model.get_noise(x0, t)
    xt = x0 + noise

    logits = model(xt, mask, t)

    # Token-level cross-entropy; padding positions (ID 0) do not contribute.
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0
    )

    # Guard against inf as well as NaN: either would corrupt the time-warping
    # statistics and the parameters if propagated.
    if torch.isfinite(loss):
        model.update_time_warping(t, loss.detach())
        loss.backward()
        optimizer.step()

    return loss.item()

def validation_step(model, tokens, mask, device):
    """Compute the loss for one validation batch without updating the model.

    Mirrors the forward pass of training: sample times, embed the tokens, add
    the model's noise, predict logits, and score them with cross-entropy that
    skips padding (ID 0). ``device`` is accepted for interface symmetry but is
    not used. Gradient tracking is not disabled here; the caller is expected
    to handle that.

    Returns:
        The batch loss as a Python float.
    """
    n_samples = tokens.shape[0]
    timesteps = model.sample_time(n_samples, tokens.device)

    clean = model.embedding(tokens)
    noisy = clean + model.get_noise(clean, timesteps)

    vocab_logits = model(noisy, mask, timesteps)
    flat_logits = vocab_logits.view(-1, vocab_logits.size(-1))
    flat_targets = tokens.view(-1)

    # Padding positions (ID 0) are excluded from the mean.
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=0).item()

def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, path='best_model.pt'):
    """Serialise the training state to ``path`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler, or a falsy value to store None instead.
        epoch: epoch index recorded alongside the weights.
        train_loss: training loss recorded in the checkpoint.
        val_loss: validation loss recorded in the checkpoint.
        path: destination file; defaults to 'best_model.pt' so existing
            callers keep writing to the same place.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, path)

def log_metrics(metrics_dict, step_type='batch'):
    """Forward ``metrics_dict`` to ``wandb.log``.

    ``step_type`` is accepted but not used by this function. ``wandb`` must
    already be imported in the notebook namespace when this is called.
    """
    wandb.log(metrics_dict)

def train_epoch(model, train_loader, optimizer, device, epoch):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Puts the model in training mode, moves each batch to ``device``, runs
    ``train_step`` on it, shows the running loss on the progress bar, and logs
    per-batch loss, current learning rate, epoch and batch index.
    """
    model.train()
    progress = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    running_loss = 0

    for batch_idx, batch in enumerate(progress):
        # Each batch is a (tokens, mask) pair; move both to the target device.
        tokens, mask = (part.to(device) for part in batch)

        batch_loss = train_step(model, tokens, mask, optimizer, device)
        running_loss += batch_loss

        progress.set_postfix({'loss': f'{batch_loss:.4f}'})
        current_lr = optimizer.param_groups[0]['lr']
        log_metrics({
            'train/batch_loss': batch_loss,
            'train/learning_rate': current_lr,
            'epoch': epoch,
            'batch': batch_idx
        })

    # Mean over the number of batches, not the number of samples.
    return running_loss / len(train_loader)

def validate_epoch(model, test_loader, device, epoch):
    """Evaluate the model on ``test_loader`` and return the mean batch loss.

    Puts the model in eval mode and runs ``validation_step`` on every batch
    under ``torch.no_grad()``. Only the loss is computed here; no sequences
    are generated and no RMSD is calculated.

    Args:
        model: model to evaluate.
        test_loader: iterable of ``(tokens, mask)`` batches.
        device: device the batches are moved to.
        epoch: epoch index, used only for the progress-bar label.

    Returns:
        Sum of batch losses divided by the number of batches.
    """
    model.eval()
    val_loss = 0
    val_bar = tqdm(test_loader, desc=f'Validation Epoch {epoch}')

    with torch.no_grad():
        for tokens, mask in val_bar:
            tokens = tokens.to(device)
            mask = mask.to(device)

            loss = validation_step(model, tokens, mask, device)
            val_loss += loss
            val_bar.set_postfix({'loss': f'{loss:.4f}'})

    return val_loss / len(test_loader)



def train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device, use_lr_scheduling=True):
    """Train for ``num_epochs`` epochs, validating and checkpointing every 5th epoch.

    Every epoch runs ``train_epoch`` and logs the mean training loss. On
    epochs where ``epoch % 5 == 0`` it additionally runs ``validate_epoch``,
    steps the plateau scheduler on the validation loss, and saves a checkpoint
    when the validation loss is the best seen so far.

    Args:
        model: model to train.
        train_loader: training batches.
        test_loader: validation batches.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        device: device batches are moved to.
        use_lr_scheduling: when True, halve the learning rate via
            ``ReduceLROnPlateau``. The scheduler is stepped only on validation
            epochs, so ``patience=3`` counts validation rounds, not epochs.
    """
    metrics = TrainingMetrics()

    scheduler = None
    if use_lr_scheduling:
        # The `verbose` flag is deprecated in recent PyTorch releases, so it
        # is not passed; learning-rate changes are reported manually below.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5
        )

    for epoch in range(num_epochs):
        # Training phase
        avg_train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        log_metrics({'train/epoch_loss': avg_train_loss, 'epoch': epoch})

        # Validation phase (every 5 epochs)
        if epoch % 5 == 0:
            avg_val_loss = validate_epoch(model, test_loader, device, epoch)

            log_metrics({
                'val/epoch_loss': avg_val_loss,
                'epoch': epoch
            })

            print(f'\nEpoch {epoch}:')
            print(f'Average Train Loss: {avg_train_loss:.4f}')
            print(f'Average Val Loss: {avg_val_loss:.4f}')

            if scheduler:
                lr_before = optimizer.param_groups[0]['lr']
                scheduler.step(avg_val_loss)
                lr_after = optimizer.param_groups[0]['lr']
                if lr_after != lr_before:
                    print(f'Reducing learning rate from {lr_before:.4e} to {lr_after:.4e}')

            # Checkpoint only when the validation loss sets a new best.
            if metrics.update_best_metrics(avg_val_loss):
                save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss)
                log_metrics({
                    'best_model/val_loss': avg_val_loss,
                    'best_model/train_loss': avg_train_loss,
                    'best_model/epoch': epoch
                })
        else:
            print(f'\nEpoch {epoch}: Average Train Loss: {avg_train_loss:.4f}\n')

In [135]:
# Initialize model hyperparameters. Trailing comments record previously tried values.
embed_dim = 16 # previously 8
num_layers = 3 # previously 5
num_heads = 4
dim_feedforward = 16 # previously 32
# Going from 4 to 8 Fourier features destabilised the batch loss but seems to
# have resulted in faster convergence and a lower loss (presumably; the
# original note was cut off after "lower").
num_fourier_features = 8
model = CategoricalScoreDiffusion(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dim_feedforward=dim_feedforward,
    num_fourier_features=num_fourier_features

)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move model to device
model = model.to(device)


In [136]:
import wandb
# Training hyperparameters; both are also recorded in the wandb run config below.
num_epochs = 200
learning_rate = 1e-2

# Finish any wandb run that is still active before starting a new one.
wandb.finish()
# Start a new run and record the hyperparameters defined in earlier cells.
wandb.init(
    project="diffusion-hmp",
    config={
        "learning_rate": learning_rate,
        "architecture": "restart",
        "dataset": "hmp",
        "epochs": num_epochs,
        "embed_dim": embed_dim,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "dim_feedforward": dim_feedforward,
        "vocab_size": vocab_size,
        "num_fourier_features":num_fourier_features
    }
)

0,1
batch,▁▃▆▂▃▄▄▅▅▆▃▃▅▆▆▂▃▃▄▅▇▇▅▆▁▃▅█▂▂▄▆▆▇█▅▅▆▁▃
best_model/epoch,▁█
best_model/train_loss,█▁
best_model/val_loss,█▁
epoch,▁▁▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██
train/batch_loss,█▆▅▄▃▃▃▄▄▄▃▃▂▃▃▂▃▃▂▄▂▄▃▃▃▂▂▃▃▃▄▁▂▃▁▃▂▃▂▂
train/epoch_loss,█▃▂▂▂▁▁▁
train/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/epoch_loss,█▁

0,1
batch,390.0
best_model/epoch,5.0
best_model/train_loss,4.4706
best_model/val_loss,4.43708
epoch,8.0
train/batch_loss,4.07089
train/epoch_loss,4.4276
train/learning_rate,0.001
val/epoch_loss,4.43708


In [137]:
# Build the optimizer over the current model's parameters, using the
# learning rate set in the wandb configuration cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Start training (LR scheduling is enabled by default in train_and_validate).
train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device)

Training Epoch 0: 100%|██████████| 811/811 [00:28<00:00, 28.54it/s, loss=4.3285]
Validation Epoch 0: 100%|██████████| 203/203 [00:02<00:00, 93.38it/s, loss=2.6559]



Epoch 0:
Average Train Loss: 4.1879
Average Val Loss: 3.9029


Training Epoch 1: 100%|██████████| 811/811 [00:27<00:00, 29.67it/s, loss=3.9660]



Epoch 1: Average Train Loss: 3.8453



Training Epoch 2: 100%|██████████| 811/811 [00:26<00:00, 30.26it/s, loss=3.3252]



Epoch 2: Average Train Loss: 3.8384



Training Epoch 3: 100%|██████████| 811/811 [00:27<00:00, 29.49it/s, loss=4.0613]



Epoch 3: Average Train Loss: 3.8035



Training Epoch 4: 100%|██████████| 811/811 [00:26<00:00, 30.39it/s, loss=4.3854]



Epoch 4: Average Train Loss: 3.8678



Training Epoch 5: 100%|██████████| 811/811 [00:26<00:00, 30.57it/s, loss=3.8581]
Validation Epoch 5: 100%|██████████| 203/203 [00:01<00:00, 104.51it/s, loss=4.5168]



Epoch 5:
Average Train Loss: 3.8411
Average Val Loss: 3.8843


Training Epoch 6: 100%|██████████| 811/811 [00:26<00:00, 30.45it/s, loss=4.3109]



Epoch 6: Average Train Loss: 3.8198



Training Epoch 7: 100%|██████████| 811/811 [00:26<00:00, 30.39it/s, loss=3.6409]



Epoch 7: Average Train Loss: 3.8367



Training Epoch 8: 100%|██████████| 811/811 [00:26<00:00, 30.36it/s, loss=3.0007]



Epoch 8: Average Train Loss: 3.7981



Training Epoch 9: 100%|██████████| 811/811 [00:26<00:00, 30.60it/s, loss=3.7241]



Epoch 9: Average Train Loss: 3.8294



Training Epoch 10: 100%|██████████| 811/811 [00:26<00:00, 30.57it/s, loss=4.5010]
Validation Epoch 10: 100%|██████████| 203/203 [00:01<00:00, 105.74it/s, loss=3.0820]



Epoch 10:
Average Train Loss: 3.7893
Average Val Loss: 3.6757


Training Epoch 11: 100%|██████████| 811/811 [00:26<00:00, 30.56it/s, loss=2.9555]



Epoch 11: Average Train Loss: 3.8750



Training Epoch 12: 100%|██████████| 811/811 [00:26<00:00, 30.57it/s, loss=4.1427]



Epoch 12: Average Train Loss: 3.8095



Training Epoch 13: 100%|██████████| 811/811 [00:26<00:00, 30.43it/s, loss=2.8313]



Epoch 13: Average Train Loss: 3.8243



Training Epoch 14: 100%|██████████| 811/811 [00:26<00:00, 30.56it/s, loss=3.5406]



Epoch 14: Average Train Loss: 3.7983



Training Epoch 15: 100%|██████████| 811/811 [00:26<00:00, 30.54it/s, loss=4.2893]
Validation Epoch 15: 100%|██████████| 203/203 [00:01<00:00, 106.02it/s, loss=3.7157]



Epoch 15:
Average Train Loss: 3.8074
Average Val Loss: 3.8003


Training Epoch 16: 100%|██████████| 811/811 [00:26<00:00, 30.47it/s, loss=4.1837]



Epoch 16: Average Train Loss: 3.7768



Training Epoch 17: 100%|██████████| 811/811 [00:26<00:00, 30.18it/s, loss=3.8816]



Epoch 17: Average Train Loss: 3.7923



Training Epoch 18: 100%|██████████| 811/811 [00:27<00:00, 29.32it/s, loss=3.7803]



Epoch 18: Average Train Loss: 3.7400



Training Epoch 19: 100%|██████████| 811/811 [00:26<00:00, 30.15it/s, loss=3.4114]



Epoch 19: Average Train Loss: 3.7644



Training Epoch 20: 100%|██████████| 811/811 [00:26<00:00, 30.28it/s, loss=4.1232]
Validation Epoch 20: 100%|██████████| 203/203 [00:01<00:00, 103.07it/s, loss=3.6387]



Epoch 20:
Average Train Loss: 3.7924
Average Val Loss: 3.8110


Training Epoch 21: 100%|██████████| 811/811 [00:26<00:00, 30.28it/s, loss=4.4074]



Epoch 21: Average Train Loss: 3.7467



Training Epoch 22: 100%|██████████| 811/811 [00:27<00:00, 29.47it/s, loss=2.8486]



Epoch 22: Average Train Loss: 3.7733



Training Epoch 23: 100%|██████████| 811/811 [00:26<00:00, 30.18it/s, loss=4.6399]



Epoch 23: Average Train Loss: 3.7489



Training Epoch 24: 100%|██████████| 811/811 [00:27<00:00, 29.59it/s, loss=2.9454]



Epoch 24: Average Train Loss: 3.7736



Training Epoch 25: 100%|██████████| 811/811 [00:28<00:00, 28.80it/s, loss=3.3630]
Validation Epoch 25: 100%|██████████| 203/203 [00:01<00:00, 103.95it/s, loss=4.3329]



Epoch 25:
Average Train Loss: 3.7805
Average Val Loss: 3.7512


Training Epoch 26: 100%|██████████| 811/811 [00:27<00:00, 29.38it/s, loss=3.3836]



Epoch 26: Average Train Loss: 3.7760



Training Epoch 27: 100%|██████████| 811/811 [00:27<00:00, 28.98it/s, loss=2.9961]



Epoch 27: Average Train Loss: 3.7691



Training Epoch 28: 100%|██████████| 811/811 [00:27<00:00, 29.43it/s, loss=3.5246]



Epoch 28: Average Train Loss: 3.7493



Training Epoch 29: 100%|██████████| 811/811 [00:26<00:00, 30.82it/s, loss=4.3527]



Epoch 29: Average Train Loss: 3.7478



Training Epoch 30: 100%|██████████| 811/811 [00:27<00:00, 29.68it/s, loss=2.8524]
Validation Epoch 30: 100%|██████████| 203/203 [00:02<00:00, 99.95it/s, loss=4.0479] 



Epoch 30:
Average Train Loss: 3.7154
Average Val Loss: 3.6752


Training Epoch 31: 100%|██████████| 811/811 [00:27<00:00, 29.98it/s, loss=4.3475]



Epoch 31: Average Train Loss: 3.7783



Training Epoch 32: 100%|██████████| 811/811 [00:27<00:00, 29.49it/s, loss=4.8974]



Epoch 32: Average Train Loss: 3.8022



Training Epoch 33: 100%|██████████| 811/811 [00:27<00:00, 29.35it/s, loss=4.4593]



Epoch 33: Average Train Loss: 3.7485



Training Epoch 34: 100%|██████████| 811/811 [00:26<00:00, 30.34it/s, loss=3.8310]



Epoch 34: Average Train Loss: 3.7151



Training Epoch 35: 100%|██████████| 811/811 [00:27<00:00, 30.01it/s, loss=3.1949]
Validation Epoch 35: 100%|██████████| 203/203 [00:01<00:00, 105.82it/s, loss=3.9745]



Epoch 35:
Average Train Loss: 3.7853
Average Val Loss: 3.6613


Training Epoch 36: 100%|██████████| 811/811 [00:26<00:00, 30.53it/s, loss=3.8548]



Epoch 36: Average Train Loss: 3.7752



Training Epoch 37: 100%|██████████| 811/811 [00:26<00:00, 30.19it/s, loss=1.4777]



Epoch 37: Average Train Loss: 3.7706



Training Epoch 38: 100%|██████████| 811/811 [00:27<00:00, 29.41it/s, loss=2.6082]



Epoch 38: Average Train Loss: 3.7580



Training Epoch 39: 100%|██████████| 811/811 [00:26<00:00, 30.14it/s, loss=4.7430]



Epoch 39: Average Train Loss: 3.7151



Training Epoch 40: 100%|██████████| 811/811 [00:26<00:00, 30.61it/s, loss=4.8232]
Validation Epoch 40: 100%|██████████| 203/203 [00:01<00:00, 104.97it/s, loss=3.0023]



Epoch 40:
Average Train Loss: 3.7620
Average Val Loss: 3.7332


Training Epoch 41: 100%|██████████| 811/811 [00:26<00:00, 30.55it/s, loss=3.5530]



Epoch 41: Average Train Loss: 3.7270



Training Epoch 42: 100%|██████████| 811/811 [00:26<00:00, 30.68it/s, loss=4.3687]



Epoch 42: Average Train Loss: 3.7591



Training Epoch 43: 100%|██████████| 811/811 [00:26<00:00, 30.50it/s, loss=4.5337]



Epoch 43: Average Train Loss: 3.7186



Training Epoch 44: 100%|██████████| 811/811 [00:26<00:00, 30.55it/s, loss=4.7582]



Epoch 44: Average Train Loss: 3.7288



Training Epoch 45: 100%|██████████| 811/811 [00:26<00:00, 30.55it/s, loss=2.3390]
Validation Epoch 45: 100%|██████████| 203/203 [00:01<00:00, 105.88it/s, loss=3.0325]



Epoch 45:
Average Train Loss: 3.7324
Average Val Loss: 3.7468


Training Epoch 46: 100%|██████████| 811/811 [00:27<00:00, 29.60it/s, loss=4.3324]



Epoch 46: Average Train Loss: 3.7409



Training Epoch 47: 100%|██████████| 811/811 [00:28<00:00, 28.10it/s, loss=4.3845]



Epoch 47: Average Train Loss: 3.7321



Training Epoch 48: 100%|██████████| 811/811 [00:33<00:00, 24.28it/s, loss=5.1703]



Epoch 48: Average Train Loss: 3.7429



Training Epoch 49: 100%|██████████| 811/811 [00:35<00:00, 23.03it/s, loss=4.3174]



Epoch 49: Average Train Loss: 3.7453



Training Epoch 50: 100%|██████████| 811/811 [00:31<00:00, 25.61it/s, loss=3.3565]
Validation Epoch 50: 100%|██████████| 203/203 [00:02<00:00, 98.19it/s, loss=3.8315] 



Epoch 50:
Average Train Loss: 3.7017
Average Val Loss: 3.7542


Training Epoch 51: 100%|██████████| 811/811 [00:26<00:00, 30.13it/s, loss=3.2753]



Epoch 51: Average Train Loss: 3.7778



Training Epoch 52: 100%|██████████| 811/811 [00:27<00:00, 29.73it/s, loss=3.8160]



Epoch 52: Average Train Loss: 3.7088



Training Epoch 53:  49%|████▉     | 396/811 [00:13<00:14, 29.43it/s, loss=3.2663]


KeyboardInterrupt: 

In [117]:
# Manual checkpoint of the current weights and optimizer state, together with
# the constructor arguments needed to rebuild the model when loading.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_args': {
        'vocab_size': vocab_size,
        'embed_dim': embed_dim,
        'num_layers': num_layers,
        'num_heads': num_heads,
        'dim_feedforward': dim_feedforward,
        'num_fourier_features': num_fourier_features
    }
}

# The file name needs the '.' before the extension ('<loss>.pt'), matching the
# naming of the checkpoint file that the loading cell reads.
torch.save(checkpoint, 'model_checkpoint_3.58.pt')

In [None]:
from model_arch import CategoricalScoreDiffusion

# Resolve the device here so loading does not depend on an earlier cell, and
# so a checkpoint written on a GPU can still be read on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

checkpoint = torch.load('model_checkpoint_2.65.pt', map_location=device)

# Rebuild the model from the stored constructor arguments, then restore weights.
model = CategoricalScoreDiffusion(**checkpoint['model_args']).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# The optimizer must be recreated over the NEW model's parameters: reusing the
# previous optimizer object would keep updating the old model's tensors.
# AdamW matches the optimizer used by the training cell (assumed to be the one
# that produced this checkpoint); load_state_dict then restores its
# hyperparameters and moment estimates.
optimizer = torch.optim.AdamW(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Read the learning rate stored in the checkpoint's optimizer state
# (first parameter group).
# NOTE(review): this rebinds the global `learning_rate` that the optimizer
# cell reads, so re-running that cell afterwards uses the checkpoint value.
# Confirm this is intended.
optimizer_state = checkpoint['optimizer_state_dict']
learning_rate = optimizer_state['param_groups'][0]['lr']
print(f"Learning rate: {learning_rate}")

Learning rate: 0.001
