In [6]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from collections import Counter

#from helper_funcs import generate_sequences


In [2]:
import wandb
# Start a W&B run; the artifact below is fetched through this run object.
run = wandb.init()
# Register the 'best_model_aiicxkad:v0' model artifact as an input of the run.
artifact = run.use_artifact('matteopeluso1922/cdcd-hmp-param-search-orion_truewarp/best_model_aiicxkad:v0', type='model')
# Download the artifact's files; `artifact_dir` holds whatever download() returns
# (presumably the local directory path; confirm against the W&B API docs).
artifact_dir = artifact.download()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmatteopeluso1922[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [8]:
# Load the sample table from the HDF5 store (key 'df'); the 'otu_arrays' column
# holds one variable-length sequence of integer ids per sample.
loaded_df = pd.read_hdf('../data/sample_otu_arrays.h5', key='df')

# Seed NumPy's global RNG (the split below is additionally pinned via random_state).
np.random.seed(42)

# Split the row index 80/20 into train/test labels.
train_idx, test_idx = train_test_split(loaded_df.index, test_size=0.2, random_state=42)

# Materialise the train and test frames by label lookup.
train_df = loaded_df.loc[train_idx]
test_df = loaded_df.loc[test_idx]

print(f"Train size: {len(train_df)}")
print(f"Test size: {len(test_df)}")
print("\nFirst few training samples:")
print(train_df.head())

# Summarise sequence lengths over the full (unsplit) frame.
array_lengths = [len(x) for x in loaded_df['otu_arrays']]
print(f"\nMin array length: {min(array_lengths)}")
print(f"Max array length: {max(array_lengths)}")
print(f"Mean array length: {np.mean(array_lengths):.2f}")

Train size: 6486
Test size: 1622

First few training samples:
                                                            otu_arrays
Unnamed: 0                                                            
SRR044975.SRS011167  [30, 58, 82, 89, 93, 98, 99, 104, 117, 120, 12...
SRR049604.SRS049164  [9, 10, 11, 14, 15, 16, 17, 20, 28, 30, 31, 32...
SRR331714.SRS076947  [19, 30, 43, 58, 65, 70, 71, 74, 80, 90, 92, 9...
SRR089999.SRS077685  [12, 14, 18, 20, 22, 38, 45, 67, 68, 76, 88, 1...
SRR048091.SRS021563  [19, 30, 45, 52, 58, 60, 65, 70, 74, 80, 90, 9...

Min array length: 3
Max array length: 277
Mean array length: 69.10


In [9]:
import torch
from torch.utils.data import Dataset, DataLoader

class OTUDataset(Dataset):
    """Map-style dataset yielding zero-padded id sequences plus padding masks.

    Every row of ``df['otu_arrays']`` is one variable-length sequence of
    integer ids. Items are right-padded with 0 up to the longest sequence
    present in ``df``.
    """

    def __init__(self, df):
        self.df = df
        # Longest sequence in this frame; all items are padded to this width.
        self.max_len = max(map(len, df['otu_arrays']))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(padded, mask)`` for the row at integer position ``idx``.

        ``padded`` is a ``torch.long`` tensor of length ``max_len``; ``mask``
        is a ``torch.bool`` tensor that is True on padding positions and
        False on real tokens.
        """
        sequence = self.df.iloc[idx]['otu_arrays']
        n_real = len(sequence)

        padded = torch.zeros(self.max_len, dtype=torch.long)
        padded[:n_real] = torch.tensor(sequence)

        # Positions at or beyond the real length are padding.
        mask = torch.arange(self.max_len) >= n_real

        return padded, mask

# Wrap the train/test frames.
# NOTE(review): each OTUDataset derives max_len from its own frame, so the
# padded width of train and test batches may differ.
train_dataset = OTUDataset(train_df)
test_dataset = OTUDataset(test_df)

# Batch the datasets; only the training loader shuffles.
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check the shapes of a single training batch, then stop.
for tokens, mask in train_loader:
   print(f"Batch tokens shape: {tokens.shape}")
   print(f"Batch mask shape: {mask.shape}")

   break

# Vocabulary size = largest id seen anywhere in the full frame, plus one.
# NOTE(review): id 0 doubles as the padding value in OTUDataset and is ignored
# by the losses below; this assumes 0 never occurs as a real id. Confirm.
vocab_size = max(max(x) for x in loaded_df['otu_arrays']) + 1
print(f"\nVocabulary size: {vocab_size}")

Batch tokens shape: torch.Size([8, 277])
Batch mask shape: torch.Size([8, 277])

Vocabulary size: 519


In [13]:
import importlib

import model_arch
#import helper_funcs

# Reload BEFORE binding the class name: importlib.reload() re-executes the
# module, but names previously pulled in with `from model_arch import ...`
# keep pointing at the old objects. Importing the class after the reload
# guarantees this cell always exposes the freshly edited definition.
importlib.reload(model_arch)
#importlib.reload(helper_funcs)

from model_arch import CategoricalScoreDiffusion
#from helper_funcs import generate_sequences

<module 'model_arch' from '/mnt/mnemo9/mpelus/matlas/cdcd_multi_train/cdcd_hmp/simplifiedV1/model_arch.py'>

In [14]:

class TrainingMetrics:
    """Tracks the lowest validation loss observed so far."""

    def __init__(self):
        # Start at +inf so the first finite loss always counts as an improvement.
        self.best_val_loss = float('inf')

    def update_best_metrics(self, val_loss):
        """Record ``val_loss`` if it is strictly lower than the current best.

        Returns:
            True when a new best was stored, False otherwise (ties and NaN
            never count as an improvement).
        """
        if not (val_loss < self.best_val_loss):
            return False
        self.best_val_loss = val_loss
        return True

def train_step(model, tokens, mask, optimizer, device):
    """Run one optimisation step on a batch and return the scalar loss.

    NOTE: a second ``train_step`` is defined later in this same cell and
    rebinds the name, so after the cell runs this version is no longer the
    one being called.

    Args:
        model: object exposing ``sample_time``, ``embedding``, ``get_noise``,
            ``update_time_warping`` and a forward call ``model(xt, mask, t)``.
        tokens: integer id tensor; id 0 is excluded from the loss.
        mask: padding mask forwarded to the model unchanged.
        optimizer: optimiser whose gradients are zeroed and stepped here.
        device: unused; times are sampled on ``tokens.device``.

    Returns:
        The cross-entropy loss as a Python float (NaN if the loss was NaN, in
        which case no update is applied).
    """
    optimizer.zero_grad()

    # One diffusion time per sequence in the batch.
    n_sequences = tokens.shape[0]
    t = model.sample_time(n_sequences, tokens.device)

    # Clean embeddings, then their noised counterpart.
    clean = model.embedding(tokens)
    noisy = clean + model.get_noise(clean, t)

    logits = model(noisy, mask, t)

    # Token-level cross-entropy over the flattened batch, skipping id 0.
    n_classes = logits.size(-1)
    loss = F.cross_entropy(
        logits.view(-1, n_classes),
        tokens.view(-1),
        ignore_index=0,
    )

    # A NaN loss is reported but never back-propagated.
    if torch.isnan(loss):
        return loss.item()

    model.update_time_warping(t, loss.detach())
    loss.backward()
    optimizer.step()
    return loss.item()

def validation_step(model, tokens, mask, device):
    """Compute the denoising cross-entropy loss for one batch without updating anything.

    Args:
        model: object exposing ``sample_time``, ``embedding``, ``get_noise``
            and a forward call ``model(xt, mask, t)``.
        tokens: integer id tensor; id 0 is excluded from the loss.
        mask: padding mask forwarded to the model unchanged.
        device: unused; times are sampled on ``tokens.device``.

    Returns:
        The loss as a Python float.
    """
    # One diffusion time per sequence in the batch.
    t = model.sample_time(tokens.shape[0], tokens.device)

    # Clean embeddings plus the noise the model provides for time t.
    clean = model.embedding(tokens)
    noisy = clean + model.get_noise(clean, t)

    logits = model(noisy, mask, t)

    # Flatten to (positions, classes) / (positions,) and skip id 0 (assumed padding).
    flat_logits = logits.view(-1, logits.size(-1))
    flat_targets = tokens.view(-1)

    return F.cross_entropy(flat_logits, flat_targets, ignore_index=0).item()

def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss, path='best_model.pt'):
    """Serialise the training state to ``path`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored.
        optimizer: optimiser whose ``state_dict()`` is stored.
        scheduler: LR scheduler, or a falsy value; stored as ``None`` when falsy.
        epoch: epoch index recorded alongside the weights.
        train_loss: training loss recorded for this epoch.
        val_loss: validation loss recorded for this epoch.
        path: destination file. Defaults to ``'best_model.pt'`` (the location
            that used to be hard-coded), so existing callers are unaffected.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, path)

def log_metrics(metrics_dict, step_type='batch'):
    """Forward ``metrics_dict`` to ``wandb.log``.

    ``step_type`` is accepted but not used by the current implementation.
    """
    wandb.log(metrics_dict)

def train_epoch(model, train_loader, optimizer, device, epoch):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    NOTE: a second ``train_epoch`` is defined later in this same cell and
    rebinds the name, so after the cell runs this version is no longer the
    one being called.

    Every batch is moved to ``device``, passed through ``train_step``, and its
    loss is shown on the progress bar and sent to ``log_metrics``.
    """
    model.train()
    progress = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    running_loss = 0

    for batch_idx, batch in enumerate(progress):
        tokens, mask = batch
        tokens, mask = tokens.to(device), mask.to(device)

        batch_loss = train_step(model, tokens, mask, optimizer, device)
        running_loss += batch_loss

        progress.set_postfix({'loss': f'{batch_loss:.4f}'})
        batch_metrics = {
            'train/batch_loss': batch_loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': batch_idx,
        }
        log_metrics(batch_metrics)

    # Average over the number of batches, not the number of samples.
    return running_loss / len(train_loader)

def validate_epoch(model, test_loader, device, epoch):
    """Evaluate the model over ``test_loader`` and return the mean batch loss.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. Each batch is moved to ``device`` and scored with
    ``validation_step``; the running loss is shown on the progress bar.

    The previous version also copied every unpadded sequence back to the CPU
    into a list that was never read; that dead work has been removed.

    Returns:
        Sum of per-batch losses divided by the number of batches.
    """
    model.eval()
    val_loss = 0
    val_bar = tqdm(test_loader, desc=f'Validation Epoch {epoch}')

    with torch.no_grad():
        for tokens, mask in val_bar:
            tokens = tokens.to(device)
            mask = mask.to(device)

            loss = validation_step(model, tokens, mask, device)
            val_loss += loss
            val_bar.set_postfix({'loss': f'{loss:.4f}'})

    return val_loss / len(test_loader)



def train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device,
                       use_lr_scheduling=True, val_every=1):
    """Run the full training loop with periodic validation and best-model checkpointing.

    Args:
        model: model passed through to ``train_epoch`` / ``validate_epoch``.
        train_loader: loader iterated once per epoch for training.
        test_loader: loader iterated on validation epochs.
        optimizer: optimiser stepped by the training code.
        num_epochs: number of epochs to run (epochs are numbered from 0).
        device: device batches are moved to.
        use_lr_scheduling: when True, a ``ReduceLROnPlateau`` scheduler
            (mode='min', patience=3, factor=0.5) is stepped with the
            validation loss.
        val_every: validate on epochs where ``epoch % val_every == 0``.
            Defaults to 1 (validate every epoch), which is what the previously
            hard-coded ``epoch % 1`` did despite its "every 5 epochs" comment.

    Raises:
        ValueError: if ``val_every`` is smaller than 1.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")

    metrics = TrainingMetrics()

    scheduler = None
    if use_lr_scheduling:
        # NOTE(review): `verbose` is reported as deprecated by recent PyTorch
        # releases; confirm against the installed version before upgrading.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5, verbose=True
        )

    for epoch in range(num_epochs):
        # Training phase
        avg_train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        log_metrics({'train/epoch_loss': avg_train_loss, 'epoch': epoch})

        # Validation phase (every `val_every` epochs)
        if epoch % val_every == 0:
            avg_val_loss = validate_epoch(model, test_loader, device, epoch)

            log_metrics({
                'val/epoch_loss': avg_val_loss,
                'epoch': epoch
            })

            print(f'\nEpoch {epoch}:')
            print(f'Average Train Loss: {avg_train_loss:.4f}')
            print(f'Average Val Loss: {avg_val_loss:.4f}')

            # The scheduler only sees validation epochs, so its patience is
            # counted in validations rather than raw epochs.
            if scheduler:
                scheduler.step(avg_val_loss)

            # Checkpoint only when the validation loss strictly improves.
            if metrics.update_best_metrics(avg_val_loss):
                save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss)
                log_metrics({
                    'best_model/val_loss': avg_val_loss,
                    'best_model/train_loss': avg_train_loss,
                    'best_model/epoch': epoch
                })
        else:
            print(f'\nEpoch {epoch}: Average Train Loss: {avg_train_loss:.4f}\n')


def train_step(model, tokens, mask, optimizer, device):
    """Run one importance-weighted optimisation step and return the unweighted loss.

    Args:
        model: object exposing ``sample_time``, ``embedding``, ``get_noise``,
            a forward call ``model(xt, mask, t)`` and a ``time_warping``
            attribute with ``get_bin_assignment``, ``get_importance_weights``
            and ``collect_statistics``.
        tokens: integer id tensor; id 0 is excluded from the loss.
        mask: padding mask forwarded to the model unchanged.
        optimizer: optimiser whose gradients are zeroed and stepped here.
        device: unused; times are sampled on ``tokens.device``.

    Returns:
        The unweighted cross-entropy loss as a Python float. When the loss is
        not finite (NaN or +/-inf) it is still returned, but no statistics are
        collected and no parameter update is applied.
    """
    optimizer.zero_grad()

    t = model.sample_time(tokens.shape[0], tokens.device)

    # Get bin assignments and importance weights
    bin_idx = model.time_warping.get_bin_assignment(t)
    importance_weights = model.time_warping.get_importance_weights(bin_idx)

    x0 = model.embedding(tokens)
    noise = model.get_noise(x0, t)
    xt = x0 + noise
    logits = model(xt, mask, t)

    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0
    )

    # Guard against every non-finite loss, not just NaN: back-propagating an
    # infinite loss would push NaN/inf gradients into the weights.
    if torch.isfinite(loss):
        # Collect statistics for time warping.
        # NOTE(review): `loss` is a single batch-mean scalar, so every sampled
        # time is paired with the same value here, and the weights below
        # collapse to their mean. Per-sample losses would be needed for truly
        # per-time statistics; confirm the intended behaviour.
        model.time_warping.collect_statistics(t, loss.detach().expand(tokens.shape[0]))

        # Apply importance weights to loss
        weighted_loss = loss * importance_weights.mean()
        weighted_loss.backward()
        optimizer.step()

    return loss.item()

def train_epoch(model, train_loader, optimizer, device, epoch):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Each batch is moved to ``device``, passed through ``train_step``, and its
    loss is shown on the progress bar and sent to ``log_metrics``. Once the
    pass is complete, ``model.time_warping.update_warping()`` is called a
    single time.
    """
    model.train()
    progress = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    running_loss = 0

    for batch_idx, batch in enumerate(progress):
        tokens, mask = batch
        tokens, mask = tokens.to(device), mask.to(device)

        batch_loss = train_step(model, tokens, mask, optimizer, device)
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})

        batch_metrics = {
            'train/batch_loss': batch_loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': batch_idx,
        }
        log_metrics(batch_metrics)

    # Refresh the time warping once per epoch, after all batches have run.
    model.time_warping.update_warping()

    # Average over the number of batches, not the number of samples.
    return running_loss / len(train_loader)


In [7]:
# Choose the device first so the checkpoint can be mapped onto it while
# loading (a CUDA-saved file would otherwise fail to load on a CPU-only host).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The file is handed straight to load_state_dict() below, i.e. it is treated as
# a bare state_dict, so the restricted `weights_only=True` unpickler is enough
# and avoids executing arbitrary pickled code.
checkpoint = torch.load(
    "/mnt/mnemo9/mpelus/matlas/cdcd_multi_train/cdcd_hmp/artifacts/best_model_aiicxkad:v0/tmp21qdlph1.pt",
    map_location=device,
    weights_only=True,
)

# Initialize model with the same parameters you showed
# NOTE(review): the recorded output of this cell is a RuntimeError from
# load_state_dict: the checkpoint was trained with a different architecture
# (e.g. its `random_matrix` is [1, 4] where this configuration builds [1, 16],
# and its `linear1` has 28 rows where this one has 32). These values must be
# replaced with the hyperparameters of the run that produced the artifact.
embed_dim = 32
num_layers = 3
num_heads = 8
dim_feedforward = 32
num_fourier_features = 16

model = CategoricalScoreDiffusion(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dim_feedforward=dim_feedforward,
    num_fourier_features=num_fourier_features
)


model.load_state_dict(checkpoint)

# Move model to device
model = model.to(device)

  checkpoint = torch.load(f"/mnt/mnemo9/mpelus/matlas/cdcd_multi_train/cdcd_hmp/artifacts/best_model_aiicxkad:v0/tmp21qdlph1.pt")


RuntimeError: Error(s) in loading state_dict for CategoricalScoreDiffusion:
	Unexpected key(s) in state_dict: "transformer.layers.1.self_attn.in_proj_weight", "transformer.layers.1.self_attn.in_proj_bias", "transformer.layers.1.self_attn.out_proj.weight", "transformer.layers.1.self_attn.out_proj.bias", "transformer.layers.1.linear1.weight", "transformer.layers.1.linear1.bias", "transformer.layers.1.linear2.weight", "transformer.layers.1.linear2.bias", "transformer.layers.1.norm1.weight", "transformer.layers.1.norm1.bias", "transformer.layers.1.norm2.weight", "transformer.layers.1.norm2.bias", "transformer.layers.2.self_attn.in_proj_weight", "transformer.layers.2.self_attn.in_proj_bias", "transformer.layers.2.self_attn.out_proj.weight", "transformer.layers.2.self_attn.out_proj.bias", "transformer.layers.2.linear1.weight", "transformer.layers.2.linear1.bias", "transformer.layers.2.linear2.weight", "transformer.layers.2.linear2.bias", "transformer.layers.2.norm1.weight", "transformer.layers.2.norm1.bias", "transformer.layers.2.norm2.weight", "transformer.layers.2.norm2.bias", "transformer.layers.3.self_attn.in_proj_weight", "transformer.layers.3.self_attn.in_proj_bias", "transformer.layers.3.self_attn.out_proj.weight", "transformer.layers.3.self_attn.out_proj.bias", "transformer.layers.3.linear1.weight", "transformer.layers.3.linear1.bias", "transformer.layers.3.linear2.weight", "transformer.layers.3.linear2.bias", "transformer.layers.3.norm1.weight", "transformer.layers.3.norm1.bias", "transformer.layers.3.norm2.weight", "transformer.layers.3.norm2.bias". 
	size mismatch for random_matrix: copying a param with shape torch.Size([1, 4]) from checkpoint, the shape in current model is torch.Size([1, 16]).
	size mismatch for embedding.embedding.weight: copying a param with shape torch.Size([519, 96]) from checkpoint, the shape in current model is torch.Size([519, 160]).
	size mismatch for transformer.layers.0.self_attn.in_proj_weight: copying a param with shape torch.Size([288, 96]) from checkpoint, the shape in current model is torch.Size([480, 160]).
	size mismatch for transformer.layers.0.self_attn.in_proj_bias: copying a param with shape torch.Size([288]) from checkpoint, the shape in current model is torch.Size([480]).
	size mismatch for transformer.layers.0.self_attn.out_proj.weight: copying a param with shape torch.Size([96, 96]) from checkpoint, the shape in current model is torch.Size([160, 160]).
	size mismatch for transformer.layers.0.self_attn.out_proj.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for transformer.layers.0.linear1.weight: copying a param with shape torch.Size([28, 96]) from checkpoint, the shape in current model is torch.Size([32, 160]).
	size mismatch for transformer.layers.0.linear1.bias: copying a param with shape torch.Size([28]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for transformer.layers.0.linear2.weight: copying a param with shape torch.Size([96, 28]) from checkpoint, the shape in current model is torch.Size([160, 32]).
	size mismatch for transformer.layers.0.linear2.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for transformer.layers.0.norm1.weight: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for transformer.layers.0.norm1.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for transformer.layers.0.norm2.weight: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for transformer.layers.0.norm2.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for to_logits.weight: copying a param with shape torch.Size([519, 96]) from checkpoint, the shape in current model is torch.Size([519, 160]).
	size mismatch for time_mlp.0.weight: copying a param with shape torch.Size([96, 8]) from checkpoint, the shape in current model is torch.Size([160, 32]).
	size mismatch for time_mlp.0.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).
	size mismatch for time_mlp.2.weight: copying a param with shape torch.Size([96, 96]) from checkpoint, the shape in current model is torch.Size([160, 160]).
	size mismatch for time_mlp.2.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([160]).

In [15]:
# Model hyperparameters; they are passed to the constructor below and reused
# in the W&B config and in the checkpoint's 'model_args'.
embed_dim = 32
num_layers = 3
num_heads = 8
dim_feedforward = 32
num_fourier_features = 8  # Going from 4 to 8 destabilised the batch loss but seems to have resulted in faster convergence and a lower loss (the original note was cut off after "lower").
model = CategoricalScoreDiffusion(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dim_feedforward=dim_feedforward,
    num_fourier_features=num_fourier_features

)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move model to device
model = model.to(device)


In [18]:
import wandb
# Training hyperparameters; both are also recorded in the W&B config below.
num_epochs = 30
learning_rate = 1e-3

# Close any run that is still open in this kernel before starting a new one.
wandb.finish()
# Start the training run and record the hyperparameters defined in earlier cells.
wandb.init(
    project="diffusion-hmp",
    config={
        "learning_rate": learning_rate,
        "architecture": "restart",
        "dataset": "hmp",
        "epochs": num_epochs,
        "embed_dim": embed_dim,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "dim_feedforward": dim_feedforward,
        "vocab_size": vocab_size,
        "num_fourier_features":num_fourier_features
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmatteopeluso1922[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
# Build the optimiser over the current model's parameters, using the
# learning rate set in the W&B configuration cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Run the training/validation loop; LR scheduling stays at its default (enabled).
train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device)

Training Epoch 0: 100%|██████████| 811/811 [00:09<00:00, 88.05it/s, loss=4.1705]
  output = torch._nested_tensor_from_mask(
Validation Epoch 0: 100%|██████████| 203/203 [00:01<00:00, 185.51it/s, loss=4.7967]



Epoch 0:
Average Train Loss: 4.6073
Average Val Loss: 4.3017


Training Epoch 1: 100%|██████████| 811/811 [00:09<00:00, 87.57it/s, loss=4.7687]
Validation Epoch 1: 100%|██████████| 203/203 [00:01<00:00, 194.91it/s, loss=3.8095]



Epoch 1:
Average Train Loss: 4.2825
Average Val Loss: 4.2097


Training Epoch 2: 100%|██████████| 811/811 [00:09<00:00, 87.64it/s, loss=3.9237]
Validation Epoch 2: 100%|██████████| 203/203 [00:01<00:00, 180.83it/s, loss=4.5829]



Epoch 2:
Average Train Loss: 4.2167
Average Val Loss: 4.1721


Training Epoch 3: 100%|██████████| 811/811 [00:09<00:00, 87.15it/s, loss=2.6374]
Validation Epoch 3: 100%|██████████| 203/203 [00:01<00:00, 196.11it/s, loss=2.5846]



Epoch 3:
Average Train Loss: 4.1111
Average Val Loss: 4.0186


Training Epoch 4: 100%|██████████| 811/811 [00:09<00:00, 88.21it/s, loss=4.3385]
Validation Epoch 4: 100%|██████████| 203/203 [00:01<00:00, 199.00it/s, loss=3.7467]



Epoch 4:
Average Train Loss: 4.0465
Average Val Loss: 4.1151


Training Epoch 5: 100%|██████████| 811/811 [00:08<00:00, 95.18it/s, loss=3.3159]
Validation Epoch 5: 100%|██████████| 203/203 [00:01<00:00, 199.49it/s, loss=4.5987]



Epoch 5:
Average Train Loss: 4.0286
Average Val Loss: 3.9960


Training Epoch 6: 100%|██████████| 811/811 [00:09<00:00, 88.62it/s, loss=5.0100]
Validation Epoch 6: 100%|██████████| 203/203 [00:01<00:00, 196.51it/s, loss=4.8181]



Epoch 6:
Average Train Loss: 4.0229
Average Val Loss: 3.9813


Training Epoch 7: 100%|██████████| 811/811 [00:09<00:00, 88.26it/s, loss=4.3627]
Validation Epoch 7: 100%|██████████| 203/203 [00:01<00:00, 198.60it/s, loss=3.6841]



Epoch 7:
Average Train Loss: 4.0320
Average Val Loss: 3.9816


Training Epoch 8: 100%|██████████| 811/811 [00:09<00:00, 87.62it/s, loss=4.3788]
Validation Epoch 8: 100%|██████████| 203/203 [00:01<00:00, 197.89it/s, loss=4.7033]



Epoch 8:
Average Train Loss: 3.9766
Average Val Loss: 3.9510


Training Epoch 9: 100%|██████████| 811/811 [00:08<00:00, 91.87it/s, loss=3.4580]
Validation Epoch 9: 100%|██████████| 203/203 [00:01<00:00, 197.62it/s, loss=4.4698]



Epoch 9:
Average Train Loss: 3.9958
Average Val Loss: 4.0197


Training Epoch 10: 100%|██████████| 811/811 [00:11<00:00, 70.75it/s, loss=4.4632]
Validation Epoch 10: 100%|██████████| 203/203 [00:01<00:00, 195.16it/s, loss=4.5070]



Epoch 10:
Average Train Loss: 3.9715
Average Val Loss: 4.0329


Training Epoch 11: 100%|██████████| 811/811 [00:09<00:00, 85.69it/s, loss=3.2277]
Validation Epoch 11: 100%|██████████| 203/203 [00:01<00:00, 193.86it/s, loss=4.5556]



Epoch 11:
Average Train Loss: 3.9787
Average Val Loss: 4.0616


Training Epoch 12: 100%|██████████| 811/811 [00:09<00:00, 87.07it/s, loss=3.1893]
Validation Epoch 12: 100%|██████████| 203/203 [00:01<00:00, 198.22it/s, loss=4.6473]



Epoch 12:
Average Train Loss: 3.9602
Average Val Loss: 3.9182


Training Epoch 13: 100%|██████████| 811/811 [00:09<00:00, 86.60it/s, loss=4.4600]
Validation Epoch 13: 100%|██████████| 203/203 [00:01<00:00, 199.66it/s, loss=3.5674]



Epoch 13:
Average Train Loss: 3.9895
Average Val Loss: 4.0094


Training Epoch 14: 100%|██████████| 811/811 [00:09<00:00, 86.50it/s, loss=4.6182]
Validation Epoch 14: 100%|██████████| 203/203 [00:01<00:00, 199.65it/s, loss=3.8143]



Epoch 14:
Average Train Loss: 3.9360
Average Val Loss: 3.9989


Training Epoch 15: 100%|██████████| 811/811 [00:09<00:00, 87.02it/s, loss=4.1996]
Validation Epoch 15: 100%|██████████| 203/203 [00:01<00:00, 196.93it/s, loss=4.0350]



Epoch 15:
Average Train Loss: 3.9571
Average Val Loss: 3.9026


Training Epoch 16: 100%|██████████| 811/811 [00:09<00:00, 88.05it/s, loss=4.3092]
Validation Epoch 16: 100%|██████████| 203/203 [00:01<00:00, 199.53it/s, loss=2.1407]



Epoch 16:
Average Train Loss: 3.9521
Average Val Loss: 3.9565


Training Epoch 17: 100%|██████████| 811/811 [00:09<00:00, 87.35it/s, loss=4.2984]
Validation Epoch 17: 100%|██████████| 203/203 [00:01<00:00, 199.60it/s, loss=4.1038]



Epoch 17:
Average Train Loss: 3.9493
Average Val Loss: 3.9223


Training Epoch 18: 100%|██████████| 811/811 [00:09<00:00, 87.37it/s, loss=2.9275]
Validation Epoch 18: 100%|██████████| 203/203 [00:01<00:00, 197.89it/s, loss=3.0798]



Epoch 18:
Average Train Loss: 3.9172
Average Val Loss: 3.9295


Training Epoch 19: 100%|██████████| 811/811 [00:09<00:00, 85.70it/s, loss=4.3674]
Validation Epoch 19: 100%|██████████| 203/203 [00:01<00:00, 199.62it/s, loss=4.2238]



Epoch 19:
Average Train Loss: 3.9313
Average Val Loss: 3.8920


Training Epoch 20: 100%|██████████| 811/811 [00:11<00:00, 69.97it/s, loss=4.4422]
Validation Epoch 20: 100%|██████████| 203/203 [00:01<00:00, 198.46it/s, loss=4.0653]



Epoch 20:
Average Train Loss: 3.9243
Average Val Loss: 3.8496


Training Epoch 21: 100%|██████████| 811/811 [00:09<00:00, 85.03it/s, loss=2.6731]
Validation Epoch 21: 100%|██████████| 203/203 [00:01<00:00, 197.35it/s, loss=4.2050]



Epoch 21:
Average Train Loss: 3.9648
Average Val Loss: 3.9731


Training Epoch 22: 100%|██████████| 811/811 [00:08<00:00, 95.13it/s, loss=3.9071]
Validation Epoch 22: 100%|██████████| 203/203 [00:01<00:00, 198.19it/s, loss=2.9628]



Epoch 22:
Average Train Loss: 3.9296
Average Val Loss: 3.9945


Training Epoch 23: 100%|██████████| 811/811 [00:08<00:00, 93.82it/s, loss=4.2795]
Validation Epoch 23: 100%|██████████| 203/203 [00:01<00:00, 198.17it/s, loss=4.5445]



Epoch 23:
Average Train Loss: 3.9308
Average Val Loss: 3.8925


Training Epoch 24: 100%|██████████| 811/811 [00:09<00:00, 88.64it/s, loss=3.0517]
Validation Epoch 24: 100%|██████████| 203/203 [00:01<00:00, 199.13it/s, loss=4.4651]



Epoch 24:
Average Train Loss: 3.9394
Average Val Loss: 3.9933


Training Epoch 25: 100%|██████████| 811/811 [00:09<00:00, 86.29it/s, loss=3.7963]
Validation Epoch 25: 100%|██████████| 203/203 [00:01<00:00, 191.52it/s, loss=4.4827]



Epoch 25:
Average Train Loss: 3.9088
Average Val Loss: 3.8955


Training Epoch 26: 100%|██████████| 811/811 [00:09<00:00, 87.95it/s, loss=2.9070]
Validation Epoch 26: 100%|██████████| 203/203 [00:01<00:00, 190.01it/s, loss=4.0245]



Epoch 26:
Average Train Loss: 3.9122
Average Val Loss: 3.9501


Training Epoch 27: 100%|██████████| 811/811 [00:09<00:00, 85.93it/s, loss=2.8811]
Validation Epoch 27: 100%|██████████| 203/203 [00:01<00:00, 191.54it/s, loss=4.3125]



Epoch 27:
Average Train Loss: 3.9192
Average Val Loss: 3.8919


Training Epoch 28: 100%|██████████| 811/811 [00:09<00:00, 88.30it/s, loss=3.5204]
Validation Epoch 28: 100%|██████████| 203/203 [00:01<00:00, 200.00it/s, loss=2.9363]



Epoch 28:
Average Train Loss: 3.9132
Average Val Loss: 3.9255


Training Epoch 29: 100%|██████████| 811/811 [00:10<00:00, 74.97it/s, loss=4.2529]
Validation Epoch 29: 100%|██████████| 203/203 [00:01<00:00, 189.15it/s, loss=3.0139]



Epoch 29:
Average Train Loss: 3.9085
Average Val Loss: 3.8278


In [20]:
# Bundle the current weights, the optimiser state, and the constructor
# arguments so the model can be rebuilt without this notebook's globals.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_args': {
        'vocab_size': vocab_size,
        'embed_dim': embed_dim,
        'num_layers': num_layers,
        'num_heads': num_heads,
        'dim_feedforward': dim_feedforward,
        'num_fourier_features': num_fourier_features
    }
}

# NOTE(review): the file name is hard-coded; the number in it presumably
# refers to a validation loss. Confirm it matches the run being saved.
torch.save(checkpoint, 'model_checkpoint_3.82.pt')

In [None]:
from model_arch import CategoricalScoreDiffusion

# Resolve the device up front so the checkpoint tensors are mapped onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

checkpoint = torch.load('model_checkpoint_2.65.pt', map_location=device)

# Rebuild the architecture from the stored constructor arguments, restore the
# weights, and move the model to the target device before it is used.
model = CategoricalScoreDiffusion(**checkpoint['model_args'])
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# The optimiser must be rebuilt over the NEW model's parameters: loading the
# saved state into the pre-existing optimiser would leave it updating the
# previous model object instead of the one restored above. The hyperparameters
# (including the learning rate) are overwritten by the loaded param_groups.
optimizer = torch.optim.AdamW(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Read the learning rate stored in the checkpoint's optimiser state
# (first parameter group only).
optimizer_state = checkpoint['optimizer_state_dict']
learning_rate = optimizer_state['param_groups'][0]['lr']
print(f"Learning rate: {learning_rate}")

Learning rate: 0.001


In [None]:
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    """Context manager that prints the wall-clock time of its body in milliseconds.

    Args:
        name: label printed in front of the measurement.

    The measurement is printed from a ``finally`` clause, so it is reported
    even when the timed body raises; the exception then propagates unchanged.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{name}: {elapsed_ms:.2f} ms")