In [204]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

In [186]:
# Load the preprocessed datasets (presumably pickled by an earlier preprocessing notebook).
# The variable name `prcessed_datasets` is misspelled but later cells depend on it.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("processed_datasets.pkl", "rb") as file:
    prcessed_datasets = pickle.load(file)

In [187]:
# Show which categories the loaded pickle contains.
prcessed_datasets.keys()

dict_keys(['cat', 'rabbit', 'bus'])

In [None]:
def diffusion_model(input_dim=3, embedding_dim=256, hidden_dim=256, num_categories=3, num_layers=2, dropout=0.1):
    """Build the components of the category-conditioned stroke noise predictor.

    The parts are returned as a plain dict rather than one ``nn.Module``;
    callers wire the forward pass themselves.

    Args:
        input_dim: Features per time step ([x, y, pen_state] -> 3).
        embedding_dim: Width of both the stroke and the category embedding.
        hidden_dim: Hidden size of the LSTM.
        num_categories: Size of the category embedding table.
        num_layers: Number of stacked LSTM layers.
        dropout: Inter-layer LSTM dropout. It is forced to 0 for a single layer
            because ``nn.LSTM`` only applies dropout between stacked layers.

    Returns:
        dict with, in this order:
          'stroke_embedder'   -- nn.Linear(input_dim -> embedding_dim)
          'category_embedder' -- nn.Embedding(num_categories, embedding_dim)
          'temporal_encoder'  -- batch-first nn.LSTM(embedding_dim -> hidden_dim)
          'noise_predictor'   -- nn.Linear(hidden_dim -> input_dim)
    """
    # A single-layer LSTM has nowhere to apply dropout (and warns if given one).
    lstm_dropout = dropout if num_layers > 1 else 0

    # Modules are created in the same order as the keys, so the parameter
    # initialisation consumes the RNG in a fixed sequence.
    return {
        'stroke_embedder': nn.Linear(input_dim, embedding_dim),
        'category_embedder': nn.Embedding(num_categories, embedding_dim),
        'temporal_encoder': nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        ),
        'noise_predictor': nn.Linear(hidden_dim, input_dim),
    }

In [199]:
# Shape of the bus training array, shown below as (num_samples, seq_len, features).
# The 3 features are presumably [x, y, pen_state] -- TODO confirm against the preprocessing notebook.
prcessed_datasets['bus']['train_data'].shape

(10000, 451, 3)

In [200]:
def model_forward_training(sequences, model_components, categories, timesteps):
    """Predict the noise contained in a batch of (noisy) stroke sequences.

    Pipeline: per-step stroke embedding + category embedding broadcast over
    time -> LSTM -> linear head producing one value per input feature.

    Args:
        sequences: Float tensor [batch, seq_len, input_dim].
        model_components: Dict as built by ``diffusion_model``. It must have the keys
            'stroke_embedder', 'category_embedder', 'temporal_encoder' and
            'noise_predictor'.
        categories: Long tensor [batch] of category ids.
        timesteps: Diffusion timesteps for the batch. Currently IGNORED:
            nothing in this forward pass conditions on the noise level.
            TODO(review): add a timestep embedding so the predictor knows t.

    Returns:
        Tensor [batch, seq_len, input_dim] of predicted noise.
    """
    stroke_embedder = model_components['stroke_embedder']
    category_embedder = model_components['category_embedder']
    temporal_encoder = model_components['temporal_encoder']
    noise_predictor = model_components['noise_predictor']

    # Unpacking also enforces that `sequences` is 3-D.
    _, seq_len, _ = sequences.shape

    # Embed stroke sequences: [batch, seq_len, embedding_dim]
    stroke_embeddings = stroke_embedder(sequences)

    # Embed categories ([batch, embedding_dim]) and repeat them at every time step.
    category_embeddings = category_embedder(categories)
    category_expanded = category_embeddings.unsqueeze(1).expand(-1, seq_len, -1)

    # Condition each step additively on its sequence's category.
    conditioned_embeddings = stroke_embeddings + category_expanded

    lstm_output, _ = temporal_encoder(conditioned_embeddings)  # [batch, seq_len, hidden_dim]

    return noise_predictor(lstm_output)

In [201]:
def create_noise_schedule(timesteps=400, beta_start=0.0001, beta_end=0.02):
    """Build a linear variance schedule for the forward diffusion process.

    Args:
        timesteps: Number of diffusion steps T.
        beta_start: Noise variance at the first step.
        beta_end: Noise variance at the last step.

    Returns:
        dict with
          'betas'          -- [T] tensor, linearly spaced from beta_start to beta_end
          'alphas'         -- 1 - betas
          'alphas_cumprod' -- running product of alphas (alpha-bar_t)
          'timesteps'      -- T, so callers can sample t in [0, T)
    """
    betas = torch.linspace(beta_start, beta_end, timesteps)
    alphas = 1.0 - betas

    schedule = {'betas': betas, 'alphas': alphas}
    # alpha-bar_t: how much of the clean signal survives after t steps.
    schedule['alphas_cumprod'] = alphas.cumprod(dim=0)
    schedule['timesteps'] = timesteps
    return schedule

In [202]:
def noise_addition(sequences, noise, timesteps, noise_schedule):
    """Forward-diffuse the x/y coordinates of a batch to the given timesteps.

    Applies, to channels 0 and 1 only:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    Channel 2 onward (the pen state) is copied through unchanged, and
    ``noise[:, :, 2:]`` is never read.

    Args:
        sequences: Clean sequences [batch, seq_len, C] with C >= 2.
        noise: 3-D tensor whose first two channels are added to the coordinates.
        timesteps: Integer timestep indices, one per batch element. A tensor on
            any device, or an int/list, is accepted.
        noise_schedule: Dict providing 'alphas_cumprod' (e.g. from
            ``create_noise_schedule``).

    Returns:
        A new tensor with the same shape as ``sequences``; inputs are not modified.
    """
    alphas_cumprod = noise_schedule['alphas_cumprod']

    # The schedule may be a CPU tensor while the batch lives on the GPU.
    # Indexing a CPU tensor with CUDA indices, or multiplying a CPU [batch, 1, 1]
    # tensor with a CUDA tensor, raises. So put everything on the data's device.
    device = sequences.device
    t_index = torch.as_tensor(timesteps, device=device)
    alpha_t = alphas_cumprod.to(device)[t_index].view(-1, 1, 1)  # [batch, 1, 1]

    noisy_sequences = sequences.clone()
    noisy_sequences[:, :, :2] = (
        torch.sqrt(alpha_t) * sequences[:, :, :2]
        + torch.sqrt(1 - alpha_t) * noise[:, :, :2]
    )
    # noisy_sequences[:, :, 2:] (pen_state) keeps the clean values from the clone.
    return noisy_sequences

In [203]:
# Compute weighted loss for diffusion training
def compute_weigthed_loss(predicted_noise, target_noise, sequences, boundary_weight=2.0, change_threshold=0.5):
    """Mean squared error on the noise, with extra weight at stroke boundaries.

    Time step i (i >= 1) is a boundary when |pen_state[i] - pen_state[i-1]|
    exceeds ``change_threshold``. The x/y error at such steps is multiplied by
    ``boundary_weight``. Every other element, including the pen-state channel,
    has weight 1. The result is the mean over all elements.
    (The misspelled function name is kept so existing callers keep working.)

    Args:
        predicted_noise: Model output [batch, seq_len, C].
        target_noise: Same shape as ``predicted_noise``.
        sequences: [batch, seq_len, >=3]; channel 2 is read as the pen state.
        boundary_weight: Multiplier for x/y error at boundary steps (default 2.0).
        change_threshold: Pen-state change that marks a boundary (default 0.5).

    Returns:
        Scalar loss tensor.
    """
    # Element-wise squared error with the same shape as the predictions.
    per_element = F.mse_loss(predicted_noise, target_noise, reduction='none')

    pen_states = sequences[:, :, 2]  # [batch, seq_len]

    # Step 0 has no predecessor, so it always keeps weight 1.
    pen_changes = (pen_states[:, 1:] - pen_states[:, :-1]).abs() > change_threshold
    step_weight = torch.ones_like(pen_states)
    step_weight[:, 1:] = step_weight[:, 1:].masked_fill(pen_changes, boundary_weight)

    # Only the x/y channels are re-weighted; the rest keep weight 1.
    weights = torch.ones_like(per_element)
    weights[:, :, :2] = step_weight.unsqueeze(-1)

    return (per_element * weights).mean()

In [207]:
def create_dataloader(dataset, batch_size=8, category_mapping=None):
    """Wrap one category's train/test arrays in PyTorch DataLoaders.

    Every sample gets the same category id, looked up from ``dataset['category']``.

    Args:
        dataset: Dict with 'category' (name), 'train_data' and 'test_data'
            (array-likes of equal-length sequences).
        batch_size: Batch size for both loaders.
        category_mapping: Optional {name: id} dict. It defaults to
            {'cat': 0, 'bus': 1, 'rabbit': 2}. The ids must be valid indices into
            the model's category embedding.

    Returns:
        (train_loader, test_loader). Each batch is (sequences float32,
        categories int64). Only the train loader shuffles.

    Raises:
        KeyError: if ``dataset['category']`` is not in the mapping.
    """
    if category_mapping is None:
        category_mapping = {
            'cat': 0,
            'bus': 1,
            'rabbit': 2,
        }

    category_name = dataset['category']
    if category_name not in category_mapping:
        raise KeyError(
            f"Unknown category {category_name!r}; expected one of {sorted(category_mapping)}"
        )
    category_id = category_mapping[category_name]

    def _to_tensor_dataset(data):
        # Explicit float32 copy; the legacy torch.FloatTensor(...) constructor is discouraged.
        sequences = torch.tensor(np.asarray(data), dtype=torch.float32)
        categories = torch.full((len(sequences),), category_id, dtype=torch.long)
        return TensorDataset(sequences, categories)

    train_dataset = _to_tensor_dataset(dataset['train_data'])
    test_dataset = _to_tensor_dataset(dataset['test_data'])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train: {len(train_dataset)} samples, {len(train_loader)} batches")
    print(f"Test: {len(test_dataset)} samples, {len(test_loader)} batches")

    return train_loader, test_loader

In [208]:
# Build train/test loaders for the 'bus' category only (default batch_size=8).
train_loader, test_loader = create_dataloader(prcessed_datasets['bus'])

Train: 10000 samples, 1250 batches
Test: 2000 samples, 250 batches


In [None]:
def setup_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

In [218]:
def initialize_optimizer(model_components, learning_rate=1e-4):
    """Create a single Adam optimizer over every component's parameters.

    Args:
        model_components: Dict mapping component name -> nn.Module.
        learning_rate: Adam learning rate.

    Returns:
        torch.optim.Adam over all parameters, in dict order.

    Also prints the total parameter count and the learning rate.
    """
    # Flatten every module's parameters into one list, preserving dict order.
    params = [p for module in model_components.values() for p in module.parameters()]
    total_params = sum(p.numel() for p in params)

    optimizer = torch.optim.Adam(params, lr=learning_rate)

    print(f"Total parameters: {total_params:,}")
    print(f"Learning rate: {learning_rate}")

    return optimizer

In [None]:
# Helper function to train the model one epochs
def train_one_epoch(model_components, dataloader, noise_schedule, optimizer, device):
    # Set parameters
    total_loss = 0
    num_batches = 0
    
    # Set models to training mode
    for component in model_components.values():
        component.train()
    
    for batch_idx, (sequences, categories) in enumerate(tqdm(dataloader, desc="Training")):
        
        # Move data to device
        sequences = sequences.to(device)
        categories = categories.to(device)
        batch_size = sequences.shape[0]
        
        # Sample random timesteps for diffusion
        timesteps = torch.randint(0, noise_schedule['timesteps'], (batch_size,), device=device)
        
        # Generate noise (only for coordinates, not pen_state)
        noise = torch.randn_like(sequences, device=device)
        noise[:, :, 2] = 0  # No noise on pen_state
        
        # Add progressive noise to sequences
        alphas_cumprod = noise_schedule['alphas_cumprod']
        alpha_t = alphas_cumprod[timesteps].view(-1, 1, 1)
        noisy_sequences = sequences.clone()
        noisy_sequences[:, :, :2] = torch.sqrt(alpha_t) * sequences[:, :, :2] + torch.sqrt(1 - alpha_t) * noise[:, :, :2]
        
        # Forward pass through model
        stroke_embedder = model_components['stroke_embedder']
        category_embedder = model_components['category_embedder']
        temporal_encoder = model_components['temporal_encoder']
        noise_predictor = model_components['noise_predictor']
        
        seq_len = sequences.shape[1]
        stroke_embeddings = stroke_embedder(noisy_sequences)
        category_embeddings = category_embedder(categories)
        category_expanded = category_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
        conditioned_embeddings = stroke_embeddings + category_expanded
        lstm_output, _ = temporal_encoder(conditioned_embeddings)
        predicted_noise = noise_predictor(lstm_output)
        
        # Compute weighted loss
        mse_loss = F.mse_loss(predicted_noise, noise, reduction='none')
        pen_states = sequences[:, :, 2]
        pen_change_weight = torch.ones_like(pen_states)
        pen_changes = torch.abs(pen_states[:, 1:] - pen_states[:, :-1]) > 0.5
        pen_change_weight[:, 1:][pen_changes] = 2.0
        weights = torch.ones_like(mse_loss)
        weights[:, :, :2] = pen_change_weight.unsqueeze(-1).expand(-1, -1, 2)
        loss = (mse_loss * weights).mean()
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        all_params = [p for component in model_components.values() for p in component.parameters()]
        torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
        
        optimizer.step()
        
        # Track loss
        total_loss += loss.item()
        num_batches += 1
        
        # Print progress every 20 batches
        if batch_idx % 20 == 0:
            print(f"   Batch {batch_idx:3d}: Loss = {loss.item():.4f}")
    
    # Calculate average loss
    avg_loss = total_loss / num_batches if num_batches > 0 else 0