In [181]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [182]:
# Function to create a base diffusion model
def diffusion_model(input_dim=3, embedding_dim=128, hidden_dim=256, num_categories=3, num_layers=2, dropout=0.1):
    """Build the untrained building blocks of the base diffusion model.

    Args:
        input_dim: Feature size of one stroke point; also the size of the
            noise vector produced by the predictor head.
        embedding_dim: Width shared by the stroke embedding, the category
            embedding, and the LSTM input.
        hidden_dim: Hidden-state size of the LSTM.
        num_categories: Number of rows in the category embedding table.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout applied between stacked LSTM layers. Replaced by 0
            unless ``num_layers`` is greater than 1.

    Returns:
        A plain ``dict`` (not an ``nn.Module``) with four entries, in order:
        ``'stroke_embedder'`` (``nn.Linear``), ``'category_embedder'``
        (``nn.Embedding``), ``'temporal_encoder'`` (batch-first ``nn.LSTM``)
        and ``'noise_predictor'`` (``nn.Linear``). Because the container is a
        dict, each component has to be moved to a device / handed to an
        optimizer individually.
    """
    # nn.LSTM only applies dropout between stacked layers, and warns when a
    # non-zero value is given for a single layer, so zero it out in that case.
    if num_layers > 1:
        lstm_dropout = dropout
    else:
        lstm_dropout = 0

    # The entries are constructed top to bottom, which keeps the parameter
    # initialisation order (and thus seeded weights) stable.
    return {
        # stroke point features -> embedding
        'stroke_embedder': nn.Linear(input_dim, embedding_dim),
        # category id -> embedding
        'category_embedder': nn.Embedding(num_categories, embedding_dim),
        # sequence model over the embedded inputs
        'temporal_encoder': nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        ),
        # LSTM hidden states -> predicted noise
        'noise_predictor': nn.Linear(hidden_dim, input_dim),
    }