In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from collections import Counter

from helper_funcs import generate_sequences


In [2]:
# Load data
loaded_df = pd.read_hdf('./data/sample_otu_arrays.h5', key='df')

# Set random seed
np.random.seed(42)

# Split indices into train/test
train_idx, test_idx = train_test_split(loaded_df.index, test_size=0.2, random_state=42)

# Create train and test dataframes
train_df = loaded_df.loc[train_idx]
test_df = loaded_df.loc[test_idx]

print(f"Train size: {len(train_df)}")
print(f"Test size: {len(test_df)}")
print("\nFirst few training samples:")
print(train_df.head())

# Let's also look at array lengths
array_lengths = [len(x) for x in loaded_df['otu_arrays']]
print(f"\nMin array length: {min(array_lengths)}")
print(f"Max array length: {max(array_lengths)}")
print(f"Mean array length: {np.mean(array_lengths):.2f}")

Train size: 6486
Test size: 1622

First few training samples:
                                                            otu_arrays
Unnamed: 0                                                            
SRR044975.SRS011167  [30, 58, 82, 89, 93, 98, 99, 104, 117, 120, 12...
SRR049604.SRS049164  [9, 10, 11, 14, 15, 16, 17, 20, 28, 30, 31, 32...
SRR331714.SRS076947  [19, 30, 43, 58, 65, 70, 71, 74, 80, 90, 92, 9...
SRR089999.SRS077685  [12, 14, 18, 20, 22, 38, 45, 67, 68, 76, 88, 1...
SRR048091.SRS021563  [19, 30, 45, 52, 58, 60, 65, 70, 74, 80, 90, 9...

Min array length: 3
Max array length: 277
Mean array length: 69.10


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader

class OTUDataset(Dataset):
   def __init__(self, df):
       self.df = df
       
       # Find max sequence length for padding
       self.max_len = max(len(x) for x in df['otu_arrays'])
       
   def __len__(self):
       return len(self.df)
   
   def __getitem__(self, idx):
       # Get array for this sample
       array = self.df.iloc[idx]['otu_arrays']
       
       # Create padded tensor
       padded = torch.zeros(self.max_len, dtype=torch.long)
       padded[:len(array)] = torch.tensor(array)
       
       # Create mask (False where we have real tokens, True for padding)
       mask = torch.zeros(self.max_len, dtype=torch.bool)
       mask[len(array):] = True
       
       return padded, mask

# Create datasets
train_dataset = OTUDataset(train_df)
test_dataset = OTUDataset(test_df)

# Create dataloaders
batch_size = 68
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Verify shapes
for tokens, mask in train_loader:
   print(f"Batch tokens shape: {tokens.shape}")
   print(f"Batch mask shape: {mask.shape}")

   break

# Get vocab size (maximum token ID + 1 for padding)
vocab_size = max(max(x) for x in loaded_df['otu_arrays']) + 1
print(f"\nVocabulary size: {vocab_size}")

Batch tokens shape: torch.Size([68, 277])
Batch mask shape: torch.Size([68, 277])

Vocabulary size: 519


In [4]:
import model_arch
import helper_funcs
import importlib
from model_arch import CategoricalScoreDiffusion
from helper_funcs import generate_sequences
importlib.reload(model_arch)
importlib.reload(helper_funcs)

<module 'helper_funcs' from '/mnt/mnemo9/mpelus/matlas/cdcd_multi_train/cdcd_hmp/helper_funcs.py'>

In [5]:

class TrainingMetrics:
    def __init__(self):
        self.best_val_loss = float('inf')

        
    def update_best_metrics(self, val_loss):
        improved = False
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            improved = True
        return improved

def train_step(model, tokens, mask, optimizer, device):
    optimizer.zero_grad()
    
    # Sample time using warping
    t = model.sample_time(tokens.shape[0], tokens.device)

    # Get clean embeddings
    x0 = model.embedding(tokens)
  
    
    # Add noise
    noise = model.get_noise(x0, t)

    xt = x0 + noise

    
    # Get model predictions
    logits = model(xt, mask, t)

    
    # Compute loss
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0
    )

    if not torch.isnan(loss):
        model.update_time_warping(t, loss.detach())
        loss.backward()
        optimizer.step()
    
    return loss.item()

def validation_step(model, tokens, mask, device):
    # Sample time using warping
    t = model.sample_time(tokens.shape[0], tokens.device)
    
    # Get clean embeddings
    x0 = model.embedding(tokens)
    
    # Add noise according to N(0, σt²)
    noise = model.get_noise(x0, t)
    xt = x0 + noise
    
    # Get model predictions
    logits = model(xt, mask, t)
    
    # Compute cross-entropy loss with padding handling
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0  # Assuming 0 is padding token
    )
    
    return loss.item()

def save_checkpoint(model, optimizer, scheduler, epoch, train_loss, val_loss):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_loss': train_loss,
        'val_loss': val_loss,
        
    }
    torch.save(checkpoint, 'best_model.pt')

def log_metrics(metrics_dict, step_type='batch'):
    wandb.log(metrics_dict)

def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    train_loss = 0
    train_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    
    for batch_idx, (tokens, mask) in enumerate(train_bar):
        tokens = tokens.to(device)
        mask = mask.to(device)
        
        loss = train_step(model, tokens, mask, optimizer, device)
        train_loss += loss
        
        train_bar.set_postfix({'loss': f'{loss:.4f}'})
        log_metrics({
            'train/batch_loss': loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': batch_idx
        })
    
    return train_loss / len(train_loader)

def validate_epoch(model, test_loader, device, epoch):
    model.eval()
    val_loss = 0
    val_bar = tqdm(test_loader, desc=f'Validation Epoch {epoch}')
    
    # Collect real sequences
    real_sequences = []
    with torch.no_grad():
        for tokens, mask in val_bar:
            tokens = tokens.to(device)
            mask = mask.to(device)
            
            loss = validation_step(model, tokens, mask, device)
            val_loss += loss
            val_bar.set_postfix({'loss': f'{loss:.4f}'})
            
            real_sequences.extend([seq[seq != 0].cpu().numpy() for seq in tokens])

    
    return val_loss / len(test_loader)



def train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device, use_lr_scheduling=True):
    metrics = TrainingMetrics()
    
    scheduler = None
    if use_lr_scheduling:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=3, factor=0.5, verbose=True
        )
    
    for epoch in range(num_epochs):
        # Training phase
        avg_train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
        log_metrics({'train/epoch_loss': avg_train_loss, 'epoch': epoch})
         
        # Validation phase (every 5 epochs)
        if epoch % 1 == 0:
            avg_val_loss = validate_epoch(model, test_loader, device, epoch)
            
            log_metrics({
                'val/epoch_loss': avg_val_loss,
                'epoch': epoch
            })
            
            print(f'\nEpoch {epoch}:')
            print(f'Average Train Loss: {avg_train_loss:.4f}')
            print(f'Average Val Loss: {avg_val_loss:.4f}')
         
            
            if scheduler:
                scheduler.step(avg_val_loss)
            
            if metrics.update_best_metrics(avg_val_loss):
                save_checkpoint(model, optimizer, scheduler, epoch, avg_train_loss, avg_val_loss)
                log_metrics({
                    'best_model/val_loss': avg_val_loss,
                    'best_model/train_loss': avg_train_loss,
                    'best_model/epoch': epoch
                })
        else:
            print(f'\nEpoch {epoch}: Average Train Loss: {avg_train_loss:.4f}\n')


def train_step(model, tokens, mask, optimizer, device):
    optimizer.zero_grad()
    
    t = model.sample_time(tokens.shape[0], tokens.device)
    x0 = model.embedding(tokens)
    noise = model.get_noise(x0, t)
    xt = x0 + noise
    logits = model(xt, mask, t)
    
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0
    )

    if not torch.isnan(loss):
        # Just collect statistics instead of updating
        model.collect_time_statistics(t, loss.detach())
        loss.backward()
        optimizer.step()
    
    return loss.item()


def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    train_loss = 0
    num_batches = len(train_loader)
    
    # Reset statistics at start of epoch
    model.epoch_loss_history.zero_()
    model.epoch_count_history.zero_()
    
    train_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    
    for batch_idx, (tokens, mask) in enumerate(train_bar):
        tokens = tokens.to(device)
        mask = mask.to(device)
        
        loss = train_step(model, tokens, mask, optimizer, device)
        train_loss += loss
        train_bar.set_postfix({'loss': f'{loss:.4f}'})
        
        log_metrics({
            'train/batch_loss': loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': batch_idx
        })
    
    # Update time warping at end of epoch
    model.update_time_warping_epoch()
    
    return train_loss / num_batches

def train_step(model, tokens, mask, optimizer, device):
    optimizer.zero_grad()
    
    t = model.sample_time(tokens.shape[0], tokens.device)
    x0 = model.embedding(tokens)
    noise = model.get_noise(x0, t)
    xt = x0 + noise
    logits = model(xt, mask, t)
    
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        tokens.view(-1),
        ignore_index=0
    )

    if not torch.isnan(loss):
        # Update time warping statistics and weights immediately
        model.collect_time_statistics(t, loss.detach())
        model.update_time_warping_batch()  # New method we'll add
        loss.backward()
        optimizer.step()
    
    return loss.item()


def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    train_loss = 0
    num_batches = len(train_loader)
    
    train_bar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    
    for batch_idx, (tokens, mask) in enumerate(train_bar):
        tokens = tokens.to(device)
        mask = mask.to(device)
        
        loss = train_step(model, tokens, mask, optimizer, device)
        train_loss += loss
        train_bar.set_postfix({'loss': f'{loss:.4f}'})
        
        log_metrics({
            'train/batch_loss': loss,
            'train/learning_rate': optimizer.param_groups[0]['lr'],
            'epoch': epoch,
            'batch': batch_idx
        })
    
    return train_loss / num_batches

In [6]:
# Initialize model
embed_dim =16 #8 
num_layers = 5 #5
num_heads = 4
dim_feedforward = 32 #32
num_fourier_features = 8# going from 4 to 8 destabilised the batch loss but seems o have resulted in a faster convergence and lower
model = CategoricalScoreDiffusion(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dim_feedforward=dim_feedforward,
    num_fourier_features=num_fourier_features
    
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move model to device
model = model.to(device)


In [7]:
import wandb
num_epochs = 200
learning_rate = 1e-3

wandb.finish()
wandb.init(
    project="diffusion-hmp",
    config={
        "learning_rate": learning_rate,
        "architecture": "restart",
        "dataset": "hmp",
        "epochs": num_epochs,
        "embed_dim": embed_dim,
        "num_layers": num_layers,
        "num_heads": num_heads,
        "dim_feedforward": dim_feedforward,
        "vocab_size": vocab_size,
        "num_fourier_features":num_fourier_features
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmatteopeluso1922[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Training parameters
# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Start training
train_and_validate(model, train_loader, test_loader, optimizer, num_epochs, device)

Training Epoch 0: 100%|██████████| 96/96 [00:04<00:00, 23.65it/s, loss=5.0822]
  output = torch._nested_tensor_from_mask(
Validation Epoch 0: 100%|██████████| 24/24 [00:03<00:00,  7.34it/s, loss=5.0721]



Epoch 0:
Average Train Loss: 5.3628
Average Val Loss: 5.0480


Training Epoch 1: 100%|██████████| 96/96 [00:04<00:00, 22.74it/s, loss=4.7633]
Validation Epoch 1: 100%|██████████| 24/24 [00:02<00:00, 11.03it/s, loss=4.6936]



Epoch 1:
Average Train Loss: 4.8146
Average Val Loss: 4.7015


Training Epoch 2: 100%|██████████| 96/96 [00:03<00:00, 24.77it/s, loss=4.3543]
Validation Epoch 2: 100%|██████████| 24/24 [00:03<00:00,  6.98it/s, loss=4.5947]



Epoch 2:
Average Train Loss: 4.6219
Average Val Loss: 4.5783


Training Epoch 3: 100%|██████████| 96/96 [00:04<00:00, 23.85it/s, loss=4.6707]
Validation Epoch 3: 100%|██████████| 24/24 [00:01<00:00, 14.79it/s, loss=4.4326]



Epoch 3:
Average Train Loss: 4.5405
Average Val Loss: 4.4808


Training Epoch 4: 100%|██████████| 96/96 [00:04<00:00, 22.88it/s, loss=4.5727]
Validation Epoch 4: 100%|██████████| 24/24 [00:03<00:00,  7.20it/s, loss=4.6450]



Epoch 4:
Average Train Loss: 4.4549
Average Val Loss: 4.4685


Training Epoch 5: 100%|██████████| 96/96 [00:03<00:00, 29.45it/s, loss=4.3374]
Validation Epoch 5: 100%|██████████| 24/24 [00:03<00:00,  7.32it/s, loss=4.3760]



Epoch 5:
Average Train Loss: 4.4166
Average Val Loss: 4.3824


Training Epoch 6: 100%|██████████| 96/96 [00:04<00:00, 22.78it/s, loss=4.4902]
Validation Epoch 6: 100%|██████████| 24/24 [00:02<00:00, 11.42it/s, loss=4.2369]



Epoch 6:
Average Train Loss: 4.3903
Average Val Loss: 4.4092


Training Epoch 7: 100%|██████████| 96/96 [00:03<00:00, 25.45it/s, loss=4.7689]
Validation Epoch 7: 100%|██████████| 24/24 [00:03<00:00,  7.17it/s, loss=4.4086]



Epoch 7:
Average Train Loss: 4.3611
Average Val Loss: 4.3407


Training Epoch 8: 100%|██████████| 96/96 [00:04<00:00, 22.98it/s, loss=4.5321]
Validation Epoch 8: 100%|██████████| 24/24 [00:01<00:00, 20.57it/s, loss=4.3295]



Epoch 8:
Average Train Loss: 4.3306
Average Val Loss: 4.2751


Training Epoch 9: 100%|██████████| 96/96 [00:04<00:00, 22.83it/s, loss=4.0986]
Validation Epoch 9: 100%|██████████| 24/24 [00:03<00:00,  6.88it/s, loss=4.5237]



Epoch 9:
Average Train Loss: 4.2604
Average Val Loss: 4.2434


Training Epoch 10: 100%|██████████| 96/96 [00:03<00:00, 28.41it/s, loss=4.1995]
Validation Epoch 10: 100%|██████████| 24/24 [00:00<00:00, 55.88it/s, loss=4.0865]



Epoch 10:
Average Train Loss: 4.2637
Average Val Loss: 4.1997


Training Epoch 11: 100%|██████████| 96/96 [00:02<00:00, 42.27it/s, loss=4.0443]
Validation Epoch 11: 100%|██████████| 24/24 [00:00<00:00, 66.75it/s, loss=4.1812]



Epoch 11:
Average Train Loss: 4.2535
Average Val Loss: 4.1553


Training Epoch 12: 100%|██████████| 96/96 [00:02<00:00, 41.43it/s, loss=4.4283]
Validation Epoch 12: 100%|██████████| 24/24 [00:02<00:00,  9.02it/s, loss=4.1299]



Epoch 12:
Average Train Loss: 4.2013
Average Val Loss: 4.1791


Training Epoch 13: 100%|██████████| 96/96 [00:04<00:00, 23.55it/s, loss=3.7490]
Validation Epoch 13: 100%|██████████| 24/24 [00:03<00:00,  7.51it/s, loss=4.4620]



Epoch 13:
Average Train Loss: 4.2017
Average Val Loss: 4.1311


Training Epoch 14: 100%|██████████| 96/96 [00:03<00:00, 31.38it/s, loss=3.7119]
Validation Epoch 14: 100%|██████████| 24/24 [00:02<00:00,  8.36it/s, loss=3.7523]



Epoch 14:
Average Train Loss: 4.1292
Average Val Loss: 4.0650


Training Epoch 15: 100%|██████████| 96/96 [00:04<00:00, 22.89it/s, loss=3.9251]
Validation Epoch 15: 100%|██████████| 24/24 [00:03<00:00,  7.54it/s, loss=4.3253]



Epoch 15:
Average Train Loss: 4.1468
Average Val Loss: 4.1704


Training Epoch 16: 100%|██████████| 96/96 [00:02<00:00, 32.14it/s, loss=4.1871]
Validation Epoch 16: 100%|██████████| 24/24 [00:03<00:00,  7.84it/s, loss=4.1745]



Epoch 16:
Average Train Loss: 4.1092
Average Val Loss: 4.0847


Training Epoch 17: 100%|██████████| 96/96 [00:04<00:00, 22.95it/s, loss=4.2570]
Validation Epoch 17: 100%|██████████| 24/24 [00:03<00:00,  7.32it/s, loss=3.8277]



Epoch 17:
Average Train Loss: 4.0963
Average Val Loss: 4.0521


Training Epoch 18: 100%|██████████| 96/96 [00:02<00:00, 34.09it/s, loss=3.8582]
Validation Epoch 18: 100%|██████████| 24/24 [00:03<00:00,  7.33it/s, loss=4.1435]



Epoch 18:
Average Train Loss: 4.0459
Average Val Loss: 4.1158


Training Epoch 19: 100%|██████████| 96/96 [00:04<00:00, 23.16it/s, loss=4.0876]
Validation Epoch 19: 100%|██████████| 24/24 [00:03<00:00,  7.49it/s, loss=3.9804]



Epoch 19:
Average Train Loss: 4.0808
Average Val Loss: 4.0640


Training Epoch 20: 100%|██████████| 96/96 [00:02<00:00, 33.45it/s, loss=3.8995]
Validation Epoch 20: 100%|██████████| 24/24 [00:03<00:00,  7.74it/s, loss=4.1042]



Epoch 20:
Average Train Loss: 4.0947
Average Val Loss: 4.0679


Training Epoch 21: 100%|██████████| 96/96 [00:04<00:00, 22.93it/s, loss=3.9190]
Validation Epoch 21: 100%|██████████| 24/24 [00:03<00:00,  7.88it/s, loss=3.8125]



Epoch 21:
Average Train Loss: 4.0521
Average Val Loss: 3.9890


Training Epoch 22: 100%|██████████| 96/96 [00:03<00:00, 30.94it/s, loss=4.0073]
Validation Epoch 22: 100%|██████████| 24/24 [00:03<00:00,  7.22it/s, loss=4.0241]



Epoch 22:
Average Train Loss: 4.0082
Average Val Loss: 4.0092


Training Epoch 23: 100%|██████████| 96/96 [00:04<00:00, 22.70it/s, loss=4.1507]
Validation Epoch 23: 100%|██████████| 24/24 [00:02<00:00,  9.04it/s, loss=4.2321]



Epoch 23:
Average Train Loss: 4.0233
Average Val Loss: 4.0313


Training Epoch 24: 100%|██████████| 96/96 [00:03<00:00, 29.16it/s, loss=4.2742]
Validation Epoch 24: 100%|██████████| 24/24 [00:03<00:00,  7.27it/s, loss=3.6869]



Epoch 24:
Average Train Loss: 4.0581
Average Val Loss: 3.8888


Training Epoch 25: 100%|██████████| 96/96 [00:04<00:00, 22.68it/s, loss=3.7281]
Validation Epoch 25: 100%|██████████| 24/24 [00:02<00:00,  9.86it/s, loss=4.2957]



Epoch 25:
Average Train Loss: 3.8819
Average Val Loss: 3.9696


Training Epoch 26: 100%|██████████| 96/96 [00:03<00:00, 27.94it/s, loss=3.8483]
Validation Epoch 26: 100%|██████████| 24/24 [00:03<00:00,  7.17it/s, loss=4.0238]



Epoch 26:
Average Train Loss: 3.8934
Average Val Loss: 3.8700


Training Epoch 27: 100%|██████████| 96/96 [00:04<00:00, 22.67it/s, loss=4.1040]
Validation Epoch 27: 100%|██████████| 24/24 [00:02<00:00, 10.76it/s, loss=4.0548]



Epoch 27:
Average Train Loss: 3.8849
Average Val Loss: 3.8685


Training Epoch 28: 100%|██████████| 96/96 [00:03<00:00, 26.51it/s, loss=4.1792]
Validation Epoch 28: 100%|██████████| 24/24 [00:03<00:00,  7.27it/s, loss=3.9977]



Epoch 28:
Average Train Loss: 3.8928
Average Val Loss: 3.8521


Training Epoch 29: 100%|██████████| 96/96 [00:04<00:00, 23.03it/s, loss=4.1528]
Validation Epoch 29: 100%|██████████| 24/24 [00:02<00:00, 11.49it/s, loss=3.4880]



Epoch 29:
Average Train Loss: 3.8607
Average Val Loss: 3.8140


Training Epoch 30: 100%|██████████| 96/96 [00:03<00:00, 25.85it/s, loss=3.4933]
Validation Epoch 30: 100%|██████████| 24/24 [00:03<00:00,  7.25it/s, loss=3.7965]



Epoch 30:
Average Train Loss: 3.8828
Average Val Loss: 3.7666


Training Epoch 31: 100%|██████████| 96/96 [00:04<00:00, 22.94it/s, loss=3.7610]
Validation Epoch 31: 100%|██████████| 24/24 [00:01<00:00, 12.09it/s, loss=4.1890]



Epoch 31:
Average Train Loss: 3.8967
Average Val Loss: 3.9032


Training Epoch 32: 100%|██████████| 96/96 [00:02<00:00, 40.53it/s, loss=3.7149]
Validation Epoch 32: 100%|██████████| 24/24 [00:00<00:00, 66.94it/s, loss=4.1345]



Epoch 32:
Average Train Loss: 3.8704
Average Val Loss: 3.8540


Training Epoch 33: 100%|██████████| 96/96 [00:02<00:00, 42.39it/s, loss=3.9422]
Validation Epoch 33: 100%|██████████| 24/24 [00:00<00:00, 66.53it/s, loss=4.1170]



Epoch 33:
Average Train Loss: 3.8443
Average Val Loss: 3.8959


Training Epoch 34: 100%|██████████| 96/96 [00:03<00:00, 31.74it/s, loss=4.4045]
Validation Epoch 34: 100%|██████████| 24/24 [00:02<00:00, 11.78it/s, loss=3.6694]



Epoch 34:
Average Train Loss: 3.8608
Average Val Loss: 3.8672


Training Epoch 35: 100%|██████████| 96/96 [00:03<00:00, 30.16it/s, loss=3.8513]
Validation Epoch 35: 100%|██████████| 24/24 [00:02<00:00, 11.65it/s, loss=3.5754]



Epoch 35:
Average Train Loss: 3.8507
Average Val Loss: 3.8528


Training Epoch 36: 100%|██████████| 96/96 [00:03<00:00, 28.81it/s, loss=4.4768]
Validation Epoch 36: 100%|██████████| 24/24 [00:02<00:00, 11.76it/s, loss=3.6812]



Epoch 36:
Average Train Loss: 3.8710
Average Val Loss: 3.8082


Training Epoch 37: 100%|██████████| 96/96 [00:03<00:00, 26.72it/s, loss=3.9573]
Validation Epoch 37: 100%|██████████| 24/24 [00:01<00:00, 15.64it/s, loss=3.6087]



Epoch 37:
Average Train Loss: 3.8520
Average Val Loss: 3.8409


Training Epoch 38: 100%|██████████| 96/96 [00:03<00:00, 25.43it/s, loss=4.0224]
Validation Epoch 38: 100%|██████████| 24/24 [00:01<00:00, 20.92it/s, loss=4.0125]



Epoch 38:
Average Train Loss: 3.8409
Average Val Loss: 3.8706


Training Epoch 39: 100%|██████████| 96/96 [00:03<00:00, 25.40it/s, loss=3.9551]
Validation Epoch 39: 100%|██████████| 24/24 [00:01<00:00, 13.49it/s, loss=3.7907]



Epoch 39:
Average Train Loss: 3.8261
Average Val Loss: 3.8130


Training Epoch 40: 100%|██████████| 96/96 [00:03<00:00, 27.64it/s, loss=3.4866]
Validation Epoch 40: 100%|██████████| 24/24 [00:02<00:00, 10.87it/s, loss=3.6550]



Epoch 40:
Average Train Loss: 3.7569
Average Val Loss: 3.7256


Training Epoch 41: 100%|██████████| 96/96 [00:03<00:00, 28.94it/s, loss=3.6292]
Validation Epoch 41: 100%|██████████| 24/24 [00:02<00:00, 10.12it/s, loss=3.8004]



Epoch 41:
Average Train Loss: 3.6586
Average Val Loss: 3.6599


Training Epoch 42: 100%|██████████| 96/96 [00:03<00:00, 29.74it/s, loss=3.3096]
Validation Epoch 42: 100%|██████████| 24/24 [00:02<00:00,  9.90it/s, loss=3.5164]



Epoch 42:
Average Train Loss: 3.6408
Average Val Loss: 3.6872


Training Epoch 43: 100%|██████████| 96/96 [00:03<00:00, 29.72it/s, loss=3.1166]
Validation Epoch 43: 100%|██████████| 24/24 [00:02<00:00, 10.16it/s, loss=3.6759]



Epoch 43:
Average Train Loss: 3.6810
Average Val Loss: 3.6771


Training Epoch 44: 100%|██████████| 96/96 [00:03<00:00, 29.69it/s, loss=3.8997]
Validation Epoch 44: 100%|██████████| 24/24 [00:02<00:00, 10.27it/s, loss=3.8448]



Epoch 44:
Average Train Loss: 3.6713
Average Val Loss: 3.6470


Training Epoch 45: 100%|██████████| 96/96 [00:02<00:00, 33.67it/s, loss=3.9613]
Validation Epoch 45: 100%|██████████| 24/24 [00:00<00:00, 66.24it/s, loss=4.0208]



Epoch 45:
Average Train Loss: 3.6443
Average Val Loss: 3.7053


Training Epoch 46: 100%|██████████| 96/96 [00:02<00:00, 42.28it/s, loss=3.6722]
Validation Epoch 46: 100%|██████████| 24/24 [00:00<00:00, 66.02it/s, loss=3.7349]



Epoch 46:
Average Train Loss: 3.6547
Average Val Loss: 3.6355


Training Epoch 47: 100%|██████████| 96/96 [00:02<00:00, 41.16it/s, loss=3.5672]
Validation Epoch 47: 100%|██████████| 24/24 [00:03<00:00,  6.23it/s, loss=3.7636]



Epoch 47:
Average Train Loss: 3.6721
Average Val Loss: 3.6272


Training Epoch 48: 100%|██████████| 96/96 [00:04<00:00, 21.81it/s, loss=3.1180]
Validation Epoch 48: 100%|██████████| 24/24 [00:04<00:00,  5.96it/s, loss=3.6748]



Epoch 48:
Average Train Loss: 3.6655
Average Val Loss: 3.6268


Training Epoch 49: 100%|██████████| 96/96 [00:04<00:00, 23.14it/s, loss=3.6092]
Validation Epoch 49: 100%|██████████| 24/24 [00:02<00:00, 10.15it/s, loss=3.3725]



Epoch 49:
Average Train Loss: 3.6173
Average Val Loss: 3.6520


Training Epoch 50: 100%|██████████| 96/96 [00:04<00:00, 21.84it/s, loss=3.4467]
Validation Epoch 50: 100%|██████████| 24/24 [00:04<00:00,  5.47it/s, loss=3.7142]



Epoch 50:
Average Train Loss: 3.6491
Average Val Loss: 3.6714


Training Epoch 51: 100%|██████████| 96/96 [00:04<00:00, 21.84it/s, loss=3.3211]
Validation Epoch 51: 100%|██████████| 24/24 [00:02<00:00, 11.24it/s, loss=3.7119]



Epoch 51:
Average Train Loss: 3.6625
Average Val Loss: 3.6634


Training Epoch 52: 100%|██████████| 96/96 [00:04<00:00, 22.58it/s, loss=3.9297]
Validation Epoch 52: 100%|██████████| 24/24 [00:04<00:00,  5.53it/s, loss=3.6132]



Epoch 52:
Average Train Loss: 3.6562
Average Val Loss: 3.7064


Training Epoch 53: 100%|██████████| 96/96 [00:04<00:00, 21.57it/s, loss=3.8765]
Validation Epoch 53: 100%|██████████| 24/24 [00:03<00:00,  6.69it/s, loss=3.8135]



Epoch 53:
Average Train Loss: 3.6561
Average Val Loss: 3.6863


Training Epoch 54: 100%|██████████| 96/96 [00:03<00:00, 24.76it/s, loss=3.9796]
Validation Epoch 54: 100%|██████████| 24/24 [00:04<00:00,  5.38it/s, loss=3.7481]



Epoch 54:
Average Train Loss: 3.6930
Average Val Loss: 3.6819


Training Epoch 55: 100%|██████████| 96/96 [00:04<00:00, 21.65it/s, loss=3.4034]
Validation Epoch 55: 100%|██████████| 24/24 [00:04<00:00,  5.89it/s, loss=3.3982]



Epoch 55:
Average Train Loss: 3.6628
Average Val Loss: 3.6217


Training Epoch 56: 100%|██████████| 96/96 [00:03<00:00, 25.30it/s, loss=3.5816]
Validation Epoch 56: 100%|██████████| 24/24 [00:04<00:00,  5.39it/s, loss=3.6161]



Epoch 56:
Average Train Loss: 3.6477
Average Val Loss: 3.6471


Training Epoch 57: 100%|██████████| 96/96 [00:04<00:00, 21.77it/s, loss=3.8095]
Validation Epoch 57: 100%|██████████| 24/24 [00:04<00:00,  5.64it/s, loss=3.4115]



Epoch 57:
Average Train Loss: 3.6700
Average Val Loss: 3.5606


Training Epoch 58: 100%|██████████| 96/96 [00:03<00:00, 25.56it/s, loss=3.2341]
Validation Epoch 58: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s, loss=3.4315]



Epoch 58:
Average Train Loss: 3.6275
Average Val Loss: 3.6399


Training Epoch 59: 100%|██████████| 96/96 [00:04<00:00, 21.75it/s, loss=3.5585]
Validation Epoch 59: 100%|██████████| 24/24 [00:04<00:00,  5.54it/s, loss=3.7879]



Epoch 59:
Average Train Loss: 3.6402
Average Val Loss: 3.6656


Training Epoch 60: 100%|██████████| 96/96 [00:03<00:00, 26.07it/s, loss=4.1770]
Validation Epoch 60: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s, loss=3.6632]



Epoch 60:
Average Train Loss: 3.6505
Average Val Loss: 3.5820


Training Epoch 61: 100%|██████████| 96/96 [00:04<00:00, 21.82it/s, loss=3.8978]
Validation Epoch 61: 100%|██████████| 24/24 [00:04<00:00,  5.48it/s, loss=3.5783]



Epoch 61:
Average Train Loss: 3.6565
Average Val Loss: 3.6426


Training Epoch 62: 100%|██████████| 96/96 [00:03<00:00, 25.98it/s, loss=4.1863]
Validation Epoch 62: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s, loss=3.5958]



Epoch 62:
Average Train Loss: 3.6175
Average Val Loss: 3.6801


Training Epoch 63: 100%|██████████| 96/96 [00:04<00:00, 20.85it/s, loss=3.5726]
Validation Epoch 63: 100%|██████████| 24/24 [00:04<00:00,  5.55it/s, loss=3.3395]



Epoch 63:
Average Train Loss: 3.6423
Average Val Loss: 3.6508


Training Epoch 64: 100%|██████████| 96/96 [00:03<00:00, 25.93it/s, loss=3.7585]
Validation Epoch 64: 100%|██████████| 24/24 [00:04<00:00,  5.47it/s, loss=3.4816]



Epoch 64:
Average Train Loss: 3.6120
Average Val Loss: 3.6099


Training Epoch 65: 100%|██████████| 96/96 [00:04<00:00, 21.76it/s, loss=3.8225]
Validation Epoch 65: 100%|██████████| 24/24 [00:04<00:00,  5.41it/s, loss=3.5347]



Epoch 65:
Average Train Loss: 3.6645
Average Val Loss: 3.7330


Training Epoch 66: 100%|██████████| 96/96 [00:03<00:00, 26.06it/s, loss=3.8506]
Validation Epoch 66: 100%|██████████| 24/24 [00:04<00:00,  5.39it/s, loss=3.7869]



Epoch 66:
Average Train Loss: 3.6591
Average Val Loss: 3.6405


Training Epoch 67: 100%|██████████| 96/96 [00:04<00:00, 21.75it/s, loss=4.1466]
Validation Epoch 67: 100%|██████████| 24/24 [00:04<00:00,  5.42it/s, loss=3.7548]



Epoch 67:
Average Train Loss: 3.6655
Average Val Loss: 3.6742


Training Epoch 68: 100%|██████████| 96/96 [00:02<00:00, 34.79it/s, loss=3.3387]
Validation Epoch 68: 100%|██████████| 24/24 [00:00<00:00, 65.57it/s, loss=3.5110]



Epoch 68:
Average Train Loss: 3.6362
Average Val Loss: 3.6729


Training Epoch 69: 100%|██████████| 96/96 [00:02<00:00, 42.18it/s, loss=3.4158]
Validation Epoch 69: 100%|██████████| 24/24 [00:00<00:00, 67.03it/s, loss=3.7942]



Epoch 69:
Average Train Loss: 3.6625
Average Val Loss: 3.6322


Training Epoch 70: 100%|██████████| 96/96 [00:02<00:00, 42.27it/s, loss=3.3033]
Validation Epoch 70: 100%|██████████| 24/24 [00:00<00:00, 66.39it/s, loss=3.9910]



Epoch 70:
Average Train Loss: 3.6350
Average Val Loss: 3.6553


Training Epoch 71: 100%|██████████| 96/96 [00:02<00:00, 39.18it/s, loss=3.9266]
Validation Epoch 71: 100%|██████████| 24/24 [00:00<00:00, 29.92it/s, loss=3.3124]



Epoch 71:
Average Train Loss: 3.6417
Average Val Loss: 3.6575


Training Epoch 72: 100%|██████████| 96/96 [00:03<00:00, 31.95it/s, loss=3.7318]
Validation Epoch 72: 100%|██████████| 24/24 [00:00<00:00, 27.97it/s, loss=3.7582]



Epoch 72:
Average Train Loss: 3.6379
Average Val Loss: 3.6815


Training Epoch 73: 100%|██████████| 96/96 [00:02<00:00, 32.20it/s, loss=3.8657]
Validation Epoch 73: 100%|██████████| 24/24 [00:00<00:00, 27.86it/s, loss=3.7060]



Epoch 73:
Average Train Loss: 3.6625
Average Val Loss: 3.6187


Training Epoch 74: 100%|██████████| 96/96 [00:03<00:00, 31.97it/s, loss=3.5702]
Validation Epoch 74: 100%|██████████| 24/24 [00:00<00:00, 28.80it/s, loss=3.9392]



Epoch 74:
Average Train Loss: 3.6417
Average Val Loss: 3.6732


Training Epoch 75: 100%|██████████| 96/96 [00:02<00:00, 34.55it/s, loss=3.6868]
Validation Epoch 75: 100%|██████████| 24/24 [00:00<00:00, 27.39it/s, loss=3.6906]



Epoch 75:
Average Train Loss: 3.6156
Average Val Loss: 3.6161


Training Epoch 76: 100%|██████████| 96/96 [00:02<00:00, 32.28it/s, loss=3.6074]
Validation Epoch 76: 100%|██████████| 24/24 [00:00<00:00, 27.94it/s, loss=3.8133]



Epoch 76:
Average Train Loss: 3.6466
Average Val Loss: 3.6335


Training Epoch 77: 100%|██████████| 96/96 [00:02<00:00, 32.27it/s, loss=3.3136]
Validation Epoch 77: 100%|██████████| 24/24 [00:00<00:00, 28.52it/s, loss=3.9964]



Epoch 77:
Average Train Loss: 3.6306
Average Val Loss: 3.6760


Training Epoch 78: 100%|██████████| 96/96 [00:02<00:00, 32.56it/s, loss=3.4387]
Validation Epoch 78: 100%|██████████| 24/24 [00:00<00:00, 28.23it/s, loss=3.7163]



Epoch 78:
Average Train Loss: 3.6543
Average Val Loss: 3.6168


Training Epoch 79: 100%|██████████| 96/96 [00:02<00:00, 33.51it/s, loss=3.3433]
Validation Epoch 79: 100%|██████████| 24/24 [00:00<00:00, 27.59it/s, loss=3.9958]



Epoch 79:
Average Train Loss: 3.6865
Average Val Loss: 3.6733


Training Epoch 80: 100%|██████████| 96/96 [00:02<00:00, 32.21it/s, loss=3.2543]
Validation Epoch 80: 100%|██████████| 24/24 [00:00<00:00, 27.90it/s, loss=3.5401]



Epoch 80:
Average Train Loss: 3.6239
Average Val Loss: 3.5407


Training Epoch 81: 100%|██████████| 96/96 [00:02<00:00, 32.31it/s, loss=3.4039]
Validation Epoch 81: 100%|██████████| 24/24 [00:00<00:00, 27.48it/s, loss=3.4988]



Epoch 81:
Average Train Loss: 3.6290
Average Val Loss: 3.6341


Training Epoch 82: 100%|██████████| 96/96 [00:02<00:00, 33.37it/s, loss=3.7970]
Validation Epoch 82: 100%|██████████| 24/24 [00:00<00:00, 28.33it/s, loss=3.7539]



Epoch 82:
Average Train Loss: 3.6398
Average Val Loss: 3.6549


Training Epoch 83: 100%|██████████| 96/96 [00:03<00:00, 31.94it/s, loss=3.6913]
Validation Epoch 83: 100%|██████████| 24/24 [00:00<00:00, 28.38it/s, loss=3.5833]



Epoch 83:
Average Train Loss: 3.6407
Average Val Loss: 3.6013


Training Epoch 84: 100%|██████████| 96/96 [00:03<00:00, 31.81it/s, loss=3.3534]
Validation Epoch 84: 100%|██████████| 24/24 [00:00<00:00, 27.82it/s, loss=3.9553]



Epoch 84:
Average Train Loss: 3.6708
Average Val Loss: 3.6441


Training Epoch 85: 100%|██████████| 96/96 [00:03<00:00, 31.86it/s, loss=3.9950]
Validation Epoch 85: 100%|██████████| 24/24 [00:00<00:00, 27.95it/s, loss=3.8746]



Epoch 85:
Average Train Loss: 3.6461
Average Val Loss: 3.5827


Training Epoch 86: 100%|██████████| 96/96 [00:02<00:00, 36.92it/s, loss=3.5485]
Validation Epoch 86: 100%|██████████| 24/24 [00:00<00:00, 28.11it/s, loss=3.5011]



Epoch 86:
Average Train Loss: 3.6417
Average Val Loss: 3.6332


Training Epoch 87: 100%|██████████| 96/96 [00:03<00:00, 31.83it/s, loss=3.8463]
Validation Epoch 87: 100%|██████████| 24/24 [00:00<00:00, 28.07it/s, loss=3.6075]



Epoch 87:
Average Train Loss: 3.6568
Average Val Loss: 3.6433


Training Epoch 88: 100%|██████████| 96/96 [00:02<00:00, 32.10it/s, loss=3.9680]
Validation Epoch 88: 100%|██████████| 24/24 [00:00<00:00, 30.51it/s, loss=3.6105]



Epoch 88:
Average Train Loss: 3.6772
Average Val Loss: 3.6456


Training Epoch 89: 100%|██████████| 96/96 [00:02<00:00, 32.27it/s, loss=4.2282]
Validation Epoch 89: 100%|██████████| 24/24 [00:00<00:00, 27.46it/s, loss=3.7001]



Epoch 89:
Average Train Loss: 3.6514
Average Val Loss: 3.6314


Training Epoch 90: 100%|██████████| 96/96 [00:02<00:00, 36.73it/s, loss=4.0527]
Validation Epoch 90: 100%|██████████| 24/24 [00:00<00:00, 28.36it/s, loss=3.7107]



Epoch 90:
Average Train Loss: 3.6678
Average Val Loss: 3.6060


Training Epoch 91: 100%|██████████| 96/96 [00:03<00:00, 31.88it/s, loss=3.8543]
Validation Epoch 91: 100%|██████████| 24/24 [00:00<00:00, 27.64it/s, loss=3.5607]



Epoch 91:
Average Train Loss: 3.6308
Average Val Loss: 3.6176


Training Epoch 92: 100%|██████████| 96/96 [00:03<00:00, 30.85it/s, loss=3.6712]
Validation Epoch 92: 100%|██████████| 24/24 [00:00<00:00, 27.81it/s, loss=3.5405]



Epoch 92:
Average Train Loss: 3.6331
Average Val Loss: 3.6602


Training Epoch 93: 100%|██████████| 96/96 [00:03<00:00, 31.32it/s, loss=3.5905]
Validation Epoch 93: 100%|██████████| 24/24 [00:00<00:00, 28.06it/s, loss=3.6564]



Epoch 93:
Average Train Loss: 3.6378
Average Val Loss: 3.6559


Training Epoch 94: 100%|██████████| 96/96 [00:03<00:00, 31.83it/s, loss=3.8305]
Validation Epoch 94: 100%|██████████| 24/24 [00:00<00:00, 27.90it/s, loss=3.7487]



Epoch 94:
Average Train Loss: 3.6435
Average Val Loss: 3.6171


Training Epoch 95: 100%|██████████| 96/96 [00:03<00:00, 30.78it/s, loss=3.8776]
Validation Epoch 95: 100%|██████████| 24/24 [00:00<00:00, 27.50it/s, loss=3.7098]



Epoch 95:
Average Train Loss: 3.6551
Average Val Loss: 3.7013


Training Epoch 96: 100%|██████████| 96/96 [00:03<00:00, 30.92it/s, loss=3.3968]
Validation Epoch 96: 100%|██████████| 24/24 [00:00<00:00, 27.90it/s, loss=3.5005]



Epoch 96:
Average Train Loss: 3.6869
Average Val Loss: 3.5937


Training Epoch 97: 100%|██████████| 96/96 [00:02<00:00, 34.37it/s, loss=3.5994]
Validation Epoch 97: 100%|██████████| 24/24 [00:00<00:00, 27.17it/s, loss=3.7008]



Epoch 97:
Average Train Loss: 3.6209
Average Val Loss: 3.6063


Training Epoch 98: 100%|██████████| 96/96 [00:03<00:00, 31.00it/s, loss=3.8332]
Validation Epoch 98: 100%|██████████| 24/24 [00:00<00:00, 27.48it/s, loss=3.8600]



Epoch 98:
Average Train Loss: 3.6639
Average Val Loss: 3.6811


Training Epoch 99: 100%|██████████| 96/96 [00:03<00:00, 30.99it/s, loss=3.4038]
Validation Epoch 99: 100%|██████████| 24/24 [00:00<00:00, 27.75it/s, loss=3.5705]



Epoch 99:
Average Train Loss: 3.6720
Average Val Loss: 3.6369


Training Epoch 100: 100%|██████████| 96/96 [00:02<00:00, 35.39it/s, loss=3.5759]
Validation Epoch 100: 100%|██████████| 24/24 [00:00<00:00, 27.91it/s, loss=3.7881]



Epoch 100:
Average Train Loss: 3.6401
Average Val Loss: 3.6619


Training Epoch 101: 100%|██████████| 96/96 [00:03<00:00, 30.97it/s, loss=3.6163]
Validation Epoch 101: 100%|██████████| 24/24 [00:00<00:00, 27.89it/s, loss=3.5538]



Epoch 101:
Average Train Loss: 3.6481
Average Val Loss: 3.6048


Training Epoch 102: 100%|██████████| 96/96 [00:03<00:00, 30.96it/s, loss=3.4462]
Validation Epoch 102: 100%|██████████| 24/24 [00:00<00:00, 27.51it/s, loss=3.5728]



Epoch 102:
Average Train Loss: 3.6476
Average Val Loss: 3.6069


Training Epoch 103: 100%|██████████| 96/96 [00:03<00:00, 31.98it/s, loss=4.1694]
Validation Epoch 103: 100%|██████████| 24/24 [00:00<00:00, 27.91it/s, loss=3.3516]



Epoch 103:
Average Train Loss: 3.6588
Average Val Loss: 3.5395


Training Epoch 104: 100%|██████████| 96/96 [00:03<00:00, 31.16it/s, loss=3.2135]
Validation Epoch 104: 100%|██████████| 24/24 [00:00<00:00, 27.54it/s, loss=3.6734]



Epoch 104:
Average Train Loss: 3.6438
Average Val Loss: 3.6218


Training Epoch 105: 100%|██████████| 96/96 [00:03<00:00, 30.99it/s, loss=2.9612]
Validation Epoch 105: 100%|██████████| 24/24 [00:00<00:00, 27.90it/s, loss=3.5997]



Epoch 105:
Average Train Loss: 3.6572
Average Val Loss: 3.6361


Training Epoch 106: 100%|██████████| 96/96 [00:03<00:00, 30.94it/s, loss=3.4446]
Validation Epoch 106: 100%|██████████| 24/24 [00:00<00:00, 28.10it/s, loss=3.4187]



Epoch 106:
Average Train Loss: 3.6600
Average Val Loss: 3.6547


Training Epoch 107: 100%|██████████| 96/96 [00:02<00:00, 40.02it/s, loss=3.6379]
Validation Epoch 107: 100%|██████████| 24/24 [00:00<00:00, 65.99it/s, loss=3.7190]



Epoch 107:
Average Train Loss: 3.6433
Average Val Loss: 3.5981


Training Epoch 108: 100%|██████████| 96/96 [00:02<00:00, 42.20it/s, loss=3.3756]
Validation Epoch 108: 100%|██████████| 24/24 [00:00<00:00, 66.27it/s, loss=3.5472]



Epoch 108:
Average Train Loss: 3.6425
Average Val Loss: 3.6659


Training Epoch 109: 100%|██████████| 96/96 [00:02<00:00, 42.22it/s, loss=3.6659]
Validation Epoch 109: 100%|██████████| 24/24 [00:00<00:00, 65.88it/s, loss=3.7765]



Epoch 109:
Average Train Loss: 3.6563
Average Val Loss: 3.5871


Training Epoch 110: 100%|██████████| 96/96 [00:02<00:00, 34.68it/s, loss=3.7406]
Validation Epoch 110: 100%|██████████| 24/24 [00:00<00:00, 30.18it/s, loss=3.8967]



Epoch 110:
Average Train Loss: 3.6335
Average Val Loss: 3.6460


Training Epoch 111: 100%|██████████| 96/96 [00:02<00:00, 34.10it/s, loss=3.6252]
Validation Epoch 111: 100%|██████████| 24/24 [00:00<00:00, 29.42it/s, loss=3.4235]



Epoch 111:
Average Train Loss: 3.6532
Average Val Loss: 3.6318


Training Epoch 112:  98%|█████████▊| 94/96 [00:02<00:00, 34.57it/s, loss=3.5247]


KeyboardInterrupt: 

In [None]:
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'model_args': {
        'vocab_size': vocab_size,
        'embed_dim': embed_dim,
        'num_layers': num_layers,
        'num_heads': num_heads,
        'dim_feedforward': dim_feedforward,
        'num_fourier_features': num_fourier_features
    }
}

torch.save(checkpoint, 'model_checkpoint_3.58pt')

In [None]:
from model_arch import CategoricalScoreDiffusion

checkpoint = torch.load('model_checkpoint_2.65.pt')
model = CategoricalScoreDiffusion(**checkpoint['model_args'])
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Access the learning rate
# Get the optimizer state dict
optimizer_state = checkpoint['optimizer_state_dict']
learning_rate = optimizer_state['param_groups'][0]['lr']
print(f"Learning rate: {learning_rate}")

Learning rate: 0.001


In [None]:
import time
from contextlib import contextmanager

@contextmanager
def timer(name):
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    print(f"{name}: {(end - start)*1000:.2f} ms")