In [97]:
# Imports
import torch
import torch.nn.functional as F
import pickle
import numpy as np
from torch_fidelity import calculate_metrics
import os
import tempfile
from PIL import Image
import math

In [98]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [99]:
with open('processed_datasets.pkl', 'rb') as f:
    processed_datasets = pickle.load(f)

In [88]:
# Load trained model
def load_trained_model(model_path, device):
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
    model_components = checkpoint['model_components']
    noise_schedule = checkpoint['noise_schedule']
    training_info = checkpoint['training_info']
    
    for name, component in model_components.items():
        component.to(device)
        component.eval()
    
    print(f"Category: {training_info['category']}")
    print(f"Training epochs: {training_info['num_epochs']}")
    print(f"Final loss: {training_info['final_loss']:.4f}")
    print(f"Model components: {list(model_components.keys())}")
    
    return model_components, noise_schedule, training_info

In [89]:
def create_timestep_embedding(timesteps, embedding_dim, device):
    """Create sinusoidal timestep embeddings"""
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    
    # Pad if odd dimension
    if embedding_dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    
    return emb

In [101]:
# Helper function to generate sketches
def model_forward_training(sequences, model_components, categories, timesteps):
    # Get components
    stroke_embedder = model_components['stroke_embedder']
    category_embedder = model_components['category_embedder']
    temporal_encoder = model_components['temporal_encoder']
    noise_predictor = model_components['noise_predictor']
    
    batch_size, seq_len, _ = sequences.shape
    
    # Embed stroke sequences first
    stroke_embeddings = stroke_embedder(sequences)
    embedding_dim = stroke_embeddings.shape[-1]
    
    # Create timestep embedding with correct dimension
    t_emb = create_timestep_embedding(timesteps, embedding_dim, device)
    t_expanded = t_emb.unsqueeze(1).expand(-1, seq_len, -1)
    
    # Embed categories and inject into sequence
    category_embeddings = category_embedder(categories)
    category_expanded = category_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
    
    # Add timestep conditioning
    conditioned_embeddings = stroke_embeddings + category_expanded + t_expanded
    
    # Add layer normalization
    conditioned_embeddings = F.layer_norm(conditioned_embeddings, conditioned_embeddings.shape[-1:])
    
    # Process through LSTM
    lstm_output, _ = temporal_encoder(conditioned_embeddings)
    
    # Predict noise
    predicted_noise = noise_predictor(lstm_output)
    
    return predicted_noise

In [111]:
# Generate samples from the model
def generate_samples(model_components, noise_schedule, num_samples, seq_length, device, num_steps=750, category_id=1):
    generated_samples = []
    
    for component in model_components.values():
        component.eval()
    
    with torch.no_grad():
        for i in range(num_samples):
            # Start with random noise for coordinates
            sketch = torch.randn(1, seq_length, 3, device=device)
            
            # Create realistic pen state pattern
            pen_pattern = []
            remaining_length = seq_length
            while remaining_length > 0:
                if len(pen_pattern) == 0:
                    pen_pattern.append(1)  # Start stroke
                    remaining_length -= 1
                elif remaining_length > 10:
                    draw_length = min(torch.randint(3, 8, (1,)).item(), remaining_length - 3)
                    pen_pattern.extend([2] * draw_length)
                    remaining_length -= draw_length
                    pen_pattern.append(0)  # End stroke
                    remaining_length -= 1
                else:
                    pen_pattern.extend([2] * (remaining_length - 1))
                    pen_pattern.append(3)  # End of drawing
                    remaining_length = 0
            
            pen_states = torch.tensor(pen_pattern[:seq_length], device=device, dtype=torch.float32)
            sketch[0, :len(pen_states), 2] = pen_states
            if len(pen_states) < seq_length:
                sketch[0, len(pen_states):, 2] = 2.0
            
            category = torch.tensor([category_id], device=device)
            
            # Denoising loop
            for step in range(num_steps):
                t = torch.tensor([num_steps - step - 1], device=device)
                
                if t >= len(noise_schedule['alphas']):
                    t = torch.tensor([len(noise_schedule['alphas']) - 1], device=device)
                
                predicted_noise = model_forward_training(sketch, model_components, category, t)
                
                if step < num_steps - 1:
                    alpha_t = noise_schedule['alphas'][t]
                    beta_t = noise_schedule['betas'][t]
                    alpha_cumprod_t = noise_schedule['alphas_cumprod'][t]
                    
                    noise = torch.randn_like(sketch)
                    noise[:, :, 2] = 0
                    
                    coeff1 = 1 / torch.sqrt(alpha_t)
                    coeff2 = beta_t / torch.sqrt(1 - alpha_cumprod_t)
                    
                    sketch[:, :, :2] = coeff1 * (sketch[:, :, :2] - coeff2 * predicted_noise[:, :, :2]) + torch.sqrt(beta_t) * noise[:, :, :2]
                    
                    pen_noise_scale = 0.005 * (num_steps - step) / num_steps
                    sketch[:, :, 2] = sketch[:, :, 2] - pen_noise_scale * predicted_noise[:, :, 2]
                
                pen_states_raw = sketch[:, :, 2]
                pen_states_rounded = torch.round(pen_states_raw)
                pen_states_clamped = torch.clamp(pen_states_rounded, -1, 3)
                sketch[:, :, 2] = pen_states_clamped
            
            # Store sketch without individual normalization
            final_sketch = sketch[0].cpu().numpy()
            final_sketch[:, :2] = np.clip(final_sketch[:, :2], -1.0, 1.0)
            final_sketch[:, 2] = np.clip(np.round(final_sketch[:, 2]), -1, 3)
            
            generated_samples.append(final_sketch)
            
            if (i + 1) % 25 == 0:
                print(f"Generated {i + 1}/{num_samples} samples")
    
    # BATCH-LEVEL STATISTICAL NORMALIZATION
    generated_samples = np.array(generated_samples)
    
    # Extract all coordinates from all samples
    all_coords = generated_samples[:, :, :2].reshape(-1, 2)  # Flatten to (N*seq_len, 2)
    
    # Calculate batch statistics
    current_mean = all_coords.mean(axis=0)  # Mean for x and y separately
    current_std = all_coords.std(axis=0)    # Std for x and y separately
    
    # Target statistics (from real data)
    target_mean = np.array([-0.2245, -0.2245])  # Assuming same for x and y
    target_std = np.array([0.6561, 0.6561])     # Assuming same for x and y
    
    # Normalize coordinates across all samples
    for i in range(len(generated_samples)):
        coords = generated_samples[i, :, :2]
        
        # Apply normalization per dimension
        for dim in range(2):
            if current_std[dim] > 0:
                coords[:, dim] = (coords[:, dim] - current_mean[dim]) / current_std[dim] * target_std[dim] + target_mean[dim]
        
        # Clamp to valid range
        coords = np.clip(coords, -1.0, 1.0)
        generated_samples[i, :, :2] = coords
    
    return generated_samples

In [92]:
def stroke_to_image_and_save(stroke_sequence, save_path, image_size=64):
    img = np.zeros((image_size, image_size))
    
    for i in range(len(stroke_sequence)):
        x, y, pen_state = stroke_sequence[i]
        
        # Skip padding and end markers
        if pen_state == -1 or pen_state == 3:
            continue
            
        # Only draw when pen is down (pen_state == 2)
        if pen_state == 2:
            # Convert from [-1, 1] to [0, image_size-1]
            x_pixel = int((x + 1) * (image_size - 1) / 2)
            y_pixel = int((y + 1) * (image_size - 1) / 2)
            
            x_pixel = np.clip(x_pixel, 0, image_size-1)
            y_pixel = np.clip(y_pixel, 0, image_size-1)
            
            # Draw point with small brush
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    nx, ny = x_pixel + dx, y_pixel + dy
                    if 0 <= nx < image_size and 0 <= ny < image_size:
                        img[ny, nx] = 1.0
    
    # Convert to RGB and save
    img_rgb = np.stack([img, img, img], axis=2)
    img_uint8 = (img_rgb * 255).astype(np.uint8)
    pil_img = Image.fromarray(img_uint8, 'RGB')
    pil_img.save(save_path)

def create_image_directories(real_data, generated_data, temp_dir):
    real_dir = os.path.join(temp_dir, "real")
    gen_dir = os.path.join(temp_dir, "generated")
    
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(gen_dir, exist_ok=True)
    
    # Save real images
    for i, stroke_seq in enumerate(real_data):
        save_path = os.path.join(real_dir, f"real_{i:06d}.png")
        stroke_to_image_and_save(stroke_seq, save_path)
    
    # Save generated images
    for i, stroke_seq in enumerate(generated_data):
        save_path = os.path.join(gen_dir, f"gen_{i:06d}.png")
        stroke_to_image_and_save(stroke_seq, save_path)
    
    return real_dir, gen_dir

In [93]:
def calculate_kid_fid_metrics(real_dir, generated_dir, num_samples):
    metrics = calculate_metrics(
            input1=generated_dir,
            input2=real_dir,
            fid=True,
            kid=True,
            kid_subset_size = num_samples-1,
            cuda=torch.cuda.is_available(),
            verbose=True
        )
    
    fid_score = metrics['frechet_inception_distance']
    kid_score = metrics['kernel_inception_distance_mean']
    
    return fid_score, kid_score

In [94]:
def diagnose_generation_quality(generated_data, real_data):
    """Add this to your evaluation notebook"""
    print("\nGENERATION DIAGNOSTICS:")
    
    # Statistical comparison
    gen_mean = np.mean(generated_data)
    real_mean = np.mean(real_data)
    gen_std = np.std(generated_data)
    real_std = np.std(real_data)
    
    print(f"Generated - Mean: {gen_mean:.4f}, Std: {gen_std:.4f}")
    print(f"Real data - Mean: {real_mean:.4f}, Std: {real_std:.4f}")
    
    # Check for issues
    if abs(gen_mean - real_mean) > 0.5:
        print("⚠ WARNING: Generated data mean differs significantly from real data")
    
    if abs(gen_std - real_std) > 0.3:
        print("⚠ WARNING: Generated data variance differs significantly from real data")
    
    # Check coordinate ranges
    gen_coords = generated_data[:, :, :2].reshape(-1, 2)
    real_coords = real_data[:, :, :2].reshape(-1, 2)
    
    print(f"Generated coord range: [{gen_coords.min():.3f}, {gen_coords.max():.3f}]")
    print(f"Real coord range: [{real_coords.min():.3f}, {real_coords.max():.3f}]")
    
    # Check pen states
    gen_pen = generated_data[:, :, 2]
    real_pen = real_data[:, :, 2]
    
    print(f"Generated pen states: {np.unique(gen_pen.astype(int))}")
    print(f"Real pen states: {np.unique(real_pen.astype(int))}")
    
    results = {
        'mean_diff': abs(gen_mean - real_mean),
        'std_diff': abs(gen_std - real_std),
        'coord_range_match': abs(gen_coords.max() - real_coords.max()) < 0.5,
        'pen_states_valid': len(np.unique(gen_pen.astype(int))) <= 4
    }
    
    return results

In [109]:
# Model evaluation
def run_evaluation(category, device, processed_datasets, model_path, num_samples=100):
    # Load trained model
    model_components, noise_schedule, _ = load_trained_model(model_path, device)
    
    # Load test data
    real_data = processed_datasets[category]['test_data']
    seq_length = real_data.shape[1]
    
    print(f"Using {num_samples} samples")
    # Use subset for faster evaluation
    
    num_samples = min(num_samples, len(real_data))
    real_data_subset = real_data[:num_samples]
    
    generated_data = generate_samples(model_components, noise_schedule, num_samples, seq_length, device, category_id=0, num_steps=500)
    
    diagnose_generation_quality(generated_data, real_data_subset)
    
    # Create tempfiles to save images
    with tempfile.TemporaryDirectory() as temp_dir:
        real_dir, gen_dir = create_image_directories(real_data_subset, generated_data, temp_dir)
        
        try:
            # Calculate metrics
            fid_score, kid_score = calculate_kid_fid_metrics(real_dir, gen_dir, num_samples)
            
            print("\nRESULTS:")
            print(f"FID Score: {fid_score:.4f}")
            print(f"KID Score: {kid_score:.4f}")
            
            return {
                'category': category,
                'FID': fid_score,
                'KID': kid_score,
                'num_samples': num_samples,
                'seq_length': seq_length
            }
            
        except Exception as e:
            print(f"Metric calculation failed: {e}")
            return None

In [110]:
score_dict = run_evaluation('bus', device, processed_datasets, 'models/sketch_diffusion_bus_20250804_221426.pth')

Category: bus
Training epochs: 25
Final loss: 0.0436
Model components: ['stroke_embedder', 'category_embedder', 'temporal_encoder', 'noise_predictor']
Using 100 samples
Generated 25/100 samples
Generated 50/100 samples
Generated 75/100 samples
Generated 100/100 samples

GENERATION DIAGNOSTICS:
Generated - Mean: 0.4202, Std: 1.0833
Real data - Mean: -0.2245, Std: 0.6561
Generated coord range: [-1.000, 1.000]
Real coord range: [-1.000, 1.000]
Generated pen states: [0 1 2 3]
Real pen states: [-1  0  1  2  3]


  pil_img = Image.fromarray(img_uint8, 'RGB')
Creating feature extractor "inception-v3-compat" with features ['2048']
Extracting features from input1
Looking for samples non-recursivelty in "/tmp/tmphjjk8xh3/generated" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                        
Extracting features from input2
Looking for samples non-recursivelty in "/tmp/tmphjjk8xh3/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                        
Frechet Inception Distance: 187.3189529534955
Kernel Inception Distance: 0.2597881031036377 ± 0.0008696856659023659           



RESULTS:
FID Score: 187.3190
KID Score: 0.2598


In [107]:
print(score_dict)

{'category': 'bus', 'FID': 212.28985279593488, 'KID': 0.3002964091300964, 'num_samples': 100, 'seq_length': 451}
