In [42]:
# Imports
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
from torch_fidelity import calculate_metrics
import os
from pathlib import Path
import shutil
import tempfile
from PIL import Image

In [28]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [29]:
with open('processed_datasets.pkl', 'rb') as f:
    processed_datasets = pickle.load(f)

In [30]:
# Load trained model
def load_trained_model(model_path, device):
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    
    model_components = checkpoint['model_components']
    noise_schedule = checkpoint['noise_schedule']
    training_info = checkpoint['training_info']
    
    for name, component in model_components.items():
        component.to(device)
        component.eval()
    
    print(f"Category: {training_info['category']}")
    print(f"Training epochs: {training_info['num_epochs']}")
    print(f"Final loss: {training_info['final_loss']:.4f}")
    print(f"Model components: {list(model_components.keys())}")
    
    return model_components, noise_schedule, training_info

In [37]:
# Helper function to generate sketches
def model_forward_training(sequences, model_components, categories, timesteps):
    # Get components
    stroke_embedder = model_components['stroke_embedder']
    category_embedder = model_components['category_embedder']
    temporal_encoder = model_components['temporal_encoder']
    noise_predictor = model_components['noise_predictor']
    #attention_layer = model_components['attention_layer']
    
    batch_size, seq_len, _ = sequences.shape
    
    # Embded stroke sequences
    stroke_embeddings = stroke_embedder(sequences) # [batch, seq_len, embedding_dim]
    
    # Embed categories and inject into sequence
    category_embeddings = category_embedder(categories)  # [batch, embedding_dim]
    category_expanded = category_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
    
    # Combine stroke and category embeddings
    conditioned_embeddings = stroke_embeddings + category_expanded
    
    # Process through LSTM
    lstm_output, _ = temporal_encoder(conditioned_embeddings)  # [batch, seq_len, hidden_dim]
    
    # Self attention
    #final_output, _ = attention_layer(lstm_output, lstm_output, lstm_output)
    
    # Predict noise
    predicted_noise = noise_predictor(lstm_output)
    
    return predicted_noise

In [32]:
# Generate samples from the model
def generate_samples(model_components, noise_schedule, num_samples, seq_length, device, num_steps=500, category_id=1):
    generated_samples = []
    
    # Set all model components to eval mode
    for component in model_components.values():
        component.eval()
    
    with torch.no_grad():
        for i in range(num_samples):
            # Start with random noise
            sketch = torch.randn(1, seq_length, 3, device=device)
            sketch[0, :, 2] = 2.0  # Start with pen down
            
            # Category for conditioning
            category = torch.tensor([category_id], device=device)
            
            # Simple denoising loop (same as your generate_simple_sketch)
            for step in range(num_steps):
                # Current timestep (going from high noise to low noise)
                t = torch.tensor([num_steps - step - 1], device=device)
                
                # Predict noise using your model (using your model_forward_training function)
                predicted_noise = model_forward_training(sketch, model_components, category, t)
                
                # PROPER DDPM denoising step (exactly like your function)
                if step < num_steps - 1:
                    # Get noise schedule values
                    alpha_t = noise_schedule['alphas'][t]
                    beta_t = noise_schedule['betas'][t]
                    alpha_cumprod_t = noise_schedule['alphas_cumprod'][t]
                    
                    # DDPM formula
                    noise = torch.randn_like(sketch) if step < num_steps - 1 else torch.zeros_like(sketch)
                    coeff1 = 1 / torch.sqrt(alpha_t)
                    coeff2 = beta_t / torch.sqrt(1 - alpha_cumprod_t)
                    
                    # Apply to coordinates only
                    sketch[:, :, :2] = coeff1 * (sketch[:, :, :2] - coeff2 * predicted_noise[:, :, :2]) + torch.sqrt(beta_t) * noise[:, :, :2]
                    
                    # For pen states, use gentler denoising
                    pen_noise_scale = 0.01 * (num_steps - step) / num_steps
                    sketch[:, :, 2] = sketch[:, :, 2] - pen_noise_scale * predicted_noise[:, :, 2]
                
                # Keep pen states reasonable
                sketch[:, :, 2] = torch.clamp(sketch[:, :, 2], 0, 3)
            
            # Add the generated sketch to samples
            generated_samples.append(sketch[0].cpu().numpy())
            
            # Progress tracking
            if (i + 1) % 50 == 0:
                print(f"Generated {i + 1}/{num_samples} samples")
    return np.array(generated_samples)

In [43]:
def stroke_to_image_and_save(stroke_sequence, save_path, image_size=64):
    img = np.zeros((image_size, image_size))
    
    for i in range(len(stroke_sequence)):
        x, y, pen_state = stroke_sequence[i]
        
        # Skip padding and end markers
        if pen_state == -1 or pen_state == 3:
            continue
            
        # Only draw when pen is down (pen_state == 2)
        if pen_state == 2:
            # Convert from [-1, 1] to [0, image_size-1]
            x_pixel = int((x + 1) * (image_size - 1) / 2)
            y_pixel = int((y + 1) * (image_size - 1) / 2)
            
            x_pixel = np.clip(x_pixel, 0, image_size-1)
            y_pixel = np.clip(y_pixel, 0, image_size-1)
            
            # Draw point with small brush
            for dx in [-1, 0, 1]:
                for dy in [-1, 0, 1]:
                    nx, ny = x_pixel + dx, y_pixel + dy
                    if 0 <= nx < image_size and 0 <= ny < image_size:
                        img[ny, nx] = 1.0
    
    # Convert to RGB and save
    img_rgb = np.stack([img, img, img], axis=2)
    img_uint8 = (img_rgb * 255).astype(np.uint8)
    pil_img = Image.fromarray(img_uint8, 'RGB')
    pil_img.save(save_path)

def create_image_directories(real_data, generated_data, temp_dir):
    real_dir = os.path.join(temp_dir, "real")
    gen_dir = os.path.join(temp_dir, "generated")
    
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(gen_dir, exist_ok=True)
    
    # Save real images
    print("Saving real images...")
    for i, stroke_seq in enumerate(real_data):
        save_path = os.path.join(real_dir, f"real_{i:06d}.png")
        stroke_to_image_and_save(stroke_seq, save_path)
        if (i + 1) % 100 == 0:
            print(f"  Saved {i + 1}/{len(real_data)} real images")
    
    # Save generated images
    print("Saving generated images...")
    for i, stroke_seq in enumerate(generated_data):
        save_path = os.path.join(gen_dir, f"gen_{i:06d}.png")
        stroke_to_image_and_save(stroke_seq, save_path)
        if (i + 1) % 100 == 0:
            print(f"  Saved {i + 1}/{len(generated_data)} generated images")
    
    return real_dir, gen_dir

In [44]:
def calculate_kid_fid_metrics(real_dir, generated_dir):
    metrics = calculate_metrics(
            input1=generated_dir,
            input2=real_dir,
            fid=True,
            kid=True,
            cuda=torch.cuda.is_available(),
            verbose=True
        )
    
    fid_score = metrics['frechet_inception_distance']
    kid_score = metrics['kernel_inception_distance_mean']
    
    return fid_score, kid_score

In [48]:
# Model evaluation
def run_evaluation(category, device, processed_datasets, model_path, num_samples=100):
    # Load trained model
    model_components, noise_schedule, _ = load_trained_model(model_path, device)
    
    # Load test data
    real_data = processed_datasets[category]['test_data']
    seq_length = real_data.shape[1]
    
    print(f"Using {num_samples} samples")
    # Use subset for faster evaluation
    
    num_samples = min(num_samples, len(real_data))
    real_data_subset = real_data[:num_samples]
    
    generated_data = generate_samples(model_components, noise_schedule, num_samples, seq_length, device, category_id=0, num_steps=500)
    
    # Calculate metrics
    with tempfile.TemporaryDirectory() as temp_dir:
        print("Creating image directories...")
        real_dir, gen_dir = create_image_directories(real_data_subset, generated_data, temp_dir)
        
        print("Calculating FID/KID metrics...")
        try:
            fid_score, kid_score = calculate_kid_fid_metrics(real_dir, gen_dir)
            
            print("\nRESULTS:")
            print(f"FID Score: {fid_score:.4f}")
            print(f"KID Score: {kid_score:.4f}")
            
            return {
                'category': category,
                'FID': fid_score,
                'KID': kid_score,
                'num_samples': num_samples,
                'seq_length': seq_length
            }
            
        except Exception as e:
            print(f"Metric calculation failed: {e}")
            return None

In [49]:
score_dict = run_evaluation('bus', device, processed_datasets, 'models/sketch_diffusion_bus_20250803_084052.pth')

Category: bus
Training epochs: 100
Final loss: 0.0190
Model components: ['stroke_embedder', 'category_embedder', 'temporal_encoder', 'noise_predictor']
Using 100 samples
Generated 50/100 samples
Generated 100/100 samples
Creating image directories...
Saving real images...
  Saved 100/100 real images
Saving generated images...


  pil_img = Image.fromarray(img_uint8, 'RGB')


  Saved 100/100 generated images
Calculating FID/KID metrics...


Creating feature extractor "inception-v3-compat" with features ['2048']
Extracting features from input1
Looking for samples non-recursivelty in "/tmp/tmpaz41h91n/generated" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                        
Extracting features from input2
Looking for samples non-recursivelty in "/tmp/tmpaz41h91n/real" with extensions png,jpg,jpeg
Found 100 samples
Processing samples                                                        


Metric calculation failed: KID subset size 1000 cannot be smaller than the number of samples (input_1: 100, input_2: 100). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to proceed.


Frechet Inception Distance: 435.27740579160036


In [51]:
print(score_dict)

None
