In [36]:
# Imports
import torch
import torch.nn as nn
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch_fidelity import calculate_metrics
import json
import os
from pathlib import Path
import shutil

In [37]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [38]:
with open('processed_datasets.pkl', 'rb') as f:
    processed_datasets = pickle.load(f)

In [None]:
# Load trained model
def load_trained_model(model_path, device):
    checkpoint = torch.load(model_path, map_location=device)
    
    model_components = checkpoint['model_components']
    noise_schedule = checkpoint['noise_schedule']
    training_info = checkpoint['training_info']
    
    for name, component in model_components.items():
        component.to(device)
        component.eval()
    
    print(f"Category: {training_info['category']}")
    print(f"Training epochs: {training_info['num_epochs']}")
    print(f"Final loss: {training_info['final_loss']:.4f}")
    print(f"Model components: {list(model_components.keys())}")
    
    return model_components, noise_schedule, training_info

In [50]:
# Helper function to generate sketches
def model_forward_training(sequences, model_components, categories, timesteps):
    # Get components
    stroke_embedder = model_components['stroke_embedder']
    category_embedder = model_components['category_embedder']
    temporal_encoder = model_components['temporal_encoder']
    noise_predictor = model_components['noise_predictor']
    attention_layer = model_components['attention_layer']
    
    batch_size, seq_len, _ = sequences.shape
    
    # Embded stroke sequences
    stroke_embeddings = stroke_embedder(sequences) # [batch, seq_len, embedding_dim]
    
    # Embed categories and inject into sequence
    category_embeddings = category_embedder(categories)  # [batch, embedding_dim]
    category_expanded = category_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
    
    # Combine stroke and category embeddings
    conditioned_embeddings = stroke_embeddings + category_expanded
    
    # Process through LSTM
    lstm_output, _ = temporal_encoder(conditioned_embeddings)  # [batch, seq_len, hidden_dim]
    
    # Self attention
    final_output, _ = attention_layer(lstm_output, lstm_output, lstm_output)
    
    # Predict noise
    predicted_noise = noise_predictor(lstm_output)
    
    return predicted_noise

In [51]:
# Generate samples from the model
def generate_samples(model_components, num_samples, seq_length, device):
    generated_samples = []
    
    with torch.no_grad():
        for i in range(num_samples):
            # Start with random noise
            sketch = torch.randn(1, seq_length, 3, device=device)
            sketch[0, :, 2] = 1.0  # Start with pen down
            
            # Category for conditioning
            category = torch.tensor([0], device=device)
            
            # Denoising loop
            num_steps = 1000
            for step in range(num_steps):
                t = torch.tensor([num_steps - step - 1], device=device)
                predicted_noise = model_forward_training(sketch, model_components, category, t)
                
                # Denoising step
                alpha_t = 1.0 - 0.02 * step / num_steps
                sketch[:, :, :2] = sketch[:, :, :2] - alpha_t * predicted_noise[:, :, :2]
                sketch[:, :, 2] = torch.clamp(sketch[:, :, 2], 0, 3)
            
            generated_samples.append(sketch[0].cpu().numpy())
            
            if (i + 1) % 25 == 0:
                print(f"  Progress: {i + 1}/{num_samples}")
    return np.array(generated_samples)

In [42]:
def stroke_to_image_tensor(stroke_sequence, image_size=64):
    img = np.zeros((image_size, image_size))
    
    for i in range(len(stroke_sequence)):
        x, y, pen_state = stroke_sequence[i]
        
        # Skip padding and end markers
        if pen_state == -1 or pen_state == 3:
            continue
            
        # Only draw when pen is down (pen_state == 2)
        if pen_state == 2:
            # Convert from [-1, 1] to [0, image_size-1]
            x_pixel = int((x + 1) * (image_size - 1) / 2)
            y_pixel = int((y + 1) * (image_size - 1) / 2)
            
            x_pixel = np.clip(x_pixel, 0, image_size-1)
            y_pixel = np.clip(y_pixel, 0, image_size-1)
            
            # Draw point with small brush
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    nx, ny = x_pixel + dx, y_pixel + dy
                    if 0 <= nx < image_size and 0 <= ny < image_size:
                        img[ny, nx] = 1.0
    return img

def convert_to_tensors(stroke_data, image_size=64):
    images = []
    for i, stroke_seq in enumerate(stroke_data):
        img = stroke_to_image_tensor(stroke_seq, image_size)
        # Convert to RGB format: (3, H, W)
        img_rgb = np.stack([img, img, img], axis=0)
        images.append(img_rgb)
        
        if (i + 1) % 100 == 0:
            print(f"  Converted {i + 1}/{len(stroke_data)} images")
    
    # Convert to torch tensor: (N, 3, H, W)
    tensor = torch.from_numpy(np.array(images)).float()
    
    # Normalize to [0, 1] range
    tensor = tensor.clamp(0, 1)
    return tensor

In [43]:
def calculate_kid_fid_metrics(real_tensor, generated_tensor):
    # Ensure tensors are in correct format and range
    real_tensor = real_tensor.clamp(0, 1)
    generated_tensor = generated_tensor.clamp(0, 1)
    
    # Convert to uint8 format (0-255)
    real_tensor = (real_tensor * 255).byte()
    generated_tensor = (generated_tensor * 255).byte()
    
    print(f"Real tensor shape: {real_tensor.shape}")
    print(f"Generated tensor shape: {generated_tensor.shape}")
    
    # Calculate metrics directly from tensors
    metrics = calculate_metrics(
        input1=generated_tensor,
        input2=real_tensor,
        fid=True,
        kid=True,
        cuda=torch.cuda.is_available(),
        verbose=False
    )
    
    fid_score = metrics['frechet_inception_distance']
    kid_score = metrics['kernel_inception_distance_mean']
    
    return fid_score, kid_score

In [47]:
# Model evaluation
def run_evaluation(category, device, processed_datasets, model_path, num_samples=500):
    # Load trained model
    model_components, noise_schedule, training_info = load_trained_model(model_path, device)
    
    # Load test data
    real_data = processed_datasets[category]['test_data']
    seq_length = real_data.shape[1]
    
    print(f"Using {num_samples} samples")
    # Use subset for faster evaluation
    
    num_samples = min(num_samples, len(real_data))
    real_data_subset = real_data[:num_samples]
    
    generated_data = generate_samples(model_components, num_samples, seq_length, device)
    
    # Convert to tensors
    real_tensor = convert_to_tensors(real_data_subset)
    generated_tensor = convert_to_tensors(generated_data)
    
    # Calculate metrics
    fid_score, kid_score = calculate_kid_fid_metrics(real_tensor, generated_tensor)
    
    if fid_score is not None:
        print("\nRESULTS:")
        print(f"FID Score: {fid_score:.4f}")
        print(f"KID Score: {kid_score:.4f}")
        
        return {
            'category': category,
            'FID': fid_score,
            'KID': kid_score,
            'num_samples': num_samples,
            'seq_length': seq_length
        }
    else:
        return None

In [48]:
score_dict = run_evaluation('bus', device, processed_datasets, 'models/sketch_diffusion_bus_20250802_144910.pth')

Category: bus
Training epochs: 5
Final loss: 0.0319
Model components: ['stroke_embedder', 'category_embedder', 'temporal_encoder', 'noise_predictor']
Using 500 samples


TypeError: 'collections.OrderedDict' object is not callable