In [1]:
!pip install torch_fidelity

Collecting torch_fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->torch_fidelity)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->torch_fidelity)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->torch_fidelity)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->torch_fidelity)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->torch_fidelity)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->torch_fidel

In [2]:
# Imports
import json
import math
import os
import pickle
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch_fidelity import calculate_metrics

In [3]:
# Select the compute device: use CUDA when it is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Load the preprocessed dataset; later cells read it as
# processed_datasets[category]['test_data'].
# SECURITY: pickle.load can execute arbitrary code contained in the file, so
# this must only be pointed at a dataset file from a trusted source.
with open('/kaggle/input/quuick-draw-dataset/processed_datasets.pkl', 'rb') as dataset_file:
    processed_datasets = pickle.load(dataset_file)

In [5]:
# Load trained model
def load_trained_model(model_path, device):
    """Load a saved checkpoint and prepare its components for inference.

    The checkpoint must be a mapping with the keys 'model_components',
    'noise_schedule' and 'training_info'. Every value of 'model_components'
    is moved to `device` and switched to eval mode, then a short summary of
    the training run is printed.

    SECURITY: the file is read with ``weights_only=False``, i.e. it is fully
    unpickled and can run arbitrary code. Only load trusted checkpoints.

    Args:
        model_path: Path of the checkpoint file.
        device: Device used as ``map_location`` and as target for the components.

    Returns:
        Tuple ``(model_components, noise_schedule, training_info)``.
    """
    ckpt = torch.load(model_path, map_location=device, weights_only=False)
    model_components, noise_schedule, training_info = (
        ckpt[key] for key in ('model_components', 'noise_schedule', 'training_info')
    )

    # Inference only: put each component on the target device and disable training behaviour
    for module in model_components.values():
        module.to(device)
        module.eval()

    component_names = list(model_components.keys())
    print(f"Category: {training_info['category']}")
    print(f"Training epochs: {training_info['num_epochs']}")
    print(f"Final loss: {training_info['final_loss']:.4f}")
    print(f"Model components: {component_names}")

    return model_components, noise_schedule, training_info

In [6]:
# Generate timestep embeddings
def create_timestep_embedding(timesteps, embedding_dim, device):
    """Create sinusoidal timestep embeddings.

    The first ``embedding_dim // 2`` columns hold sines and the next
    ``embedding_dim // 2`` columns hold cosines of the timesteps scaled by
    geometrically spaced frequencies. An odd ``embedding_dim`` gets one
    trailing zero column so the result always has ``embedding_dim`` columns.

    Args:
        timesteps: Tensor of timesteps, expected to be 1-D (one per batch element).
        embedding_dim: Width of the returned embedding.
        device: Device on which the frequency table is created.

    Returns:
        Tensor with one row per timestep and ``embedding_dim`` columns.
    """
    half_dim = embedding_dim // 2
    # Guard the denominator: for embedding_dim 2 or 3, half_dim - 1 is 0 and the
    # unguarded division raised ZeroDivisionError. Unchanged for half_dim >= 2.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
    angles = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    # Pad if odd dimension. The column is built explicitly rather than sliced
    # from `emb`, which has zero columns when embedding_dim == 1.
    if embedding_dim % 2 == 1:
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)

    return emb

In [7]:
# Forward pass used during sampling: predicts the noise for a batch of stroke sequences
def model_forward_training(sequences, model_components, categories, timesteps):
    """Run one conditioned forward pass and return the predicted noise.

    The stroke embedding of every sequence position is combined with the
    category embedding (weight 0.7) and a sinusoidal timestep embedding
    (weight 0.3), layer-normalised over the last dimension, passed through
    the temporal encoder and finally through the noise predictor.

    Args:
        sequences: 3-D tensor whose second dimension is the sequence length.
        model_components: Mapping with the callables 'stroke_embedder',
            'category_embedder', 'temporal_encoder' and 'noise_predictor'.
        categories: Input for the category embedder (one entry per batch element).
        timesteps: Timesteps forwarded to `create_timestep_embedding`.

    Returns:
        The output of the noise predictor.
    """
    embed_strokes, embed_category, encoder, predictor = (
        model_components[key]
        for key in ('stroke_embedder', 'category_embedder', 'temporal_encoder', 'noise_predictor')
    )

    # Unpacking three values also enforces that the input is 3-D
    _, seq_len, _ = sequences.shape

    def over_sequence(per_item):
        # Repeat a per-batch-element vector at every sequence position
        return per_item.unsqueeze(1).expand(-1, seq_len, -1)

    # Embed stroke sequences
    stroke_emb = embed_strokes(sequences)

    # Category conditioning, repeated along the sequence
    category_seq = over_sequence(embed_category(categories))

    # Timestep conditioning, same width as the stroke embedding
    time_seq = over_sequence(
        create_timestep_embedding(timesteps, stroke_emb.shape[-1], sequences.device)
    )

    # Weighted sum of the three signals, then normalise over the feature dimension
    combined = stroke_emb + 0.7 * category_seq + 0.3 * time_seq
    combined = F.layer_norm(combined, combined.shape[-1:])

    # The encoder returns (output, state); only the output is used
    encoded, _ = encoder(combined)

    return predictor(encoded)

In [8]:
# Generate samples
def generate_samples(model_components, noise_schedule, num_samples, seq_length, device, num_steps=500, category_id=1):
    """Sample sketches by iterative denoising, one sketch at a time.

    Every sample starts as Gaussian noise of shape (1, seq_length, 3): two
    coordinate channels followed by one pen-state channel. It is refined for
    `num_steps` iterations with `model_forward_training` and the coefficients
    in `noise_schedule`. Timestep indices count down from ``num_steps - 1`` to
    0 and are capped at the last index of the schedule. The last iteration
    (timestep 0) applies no denoising update. The pen-state channel receives
    noise scaled by 0.1 and is rounded and clamped to [-1, 3] after every
    iteration.

    Args:
        model_components: Mapping of model parts; each is switched to eval mode.
        noise_schedule: Mapping with 'alphas', 'betas' and 'alphas_cumprod',
            each indexed by a one-element tensor (presumably tensors, since
            torch.sqrt is applied to the looked-up values).
        num_samples: Number of sketches to generate.
        seq_length: Sequence length of every generated sketch.
        device: Device on which sampling runs.
        num_steps: Number of denoising iterations per sketch.
        category_id: Category index passed to the model for every sketch.

    Returns:
        NumPy array built from the generated sketches, each of shape
        (seq_length, 3) with coordinates clipped to [-1, 1] and pen states
        rounded and clipped to [-1, 3].
    """
    for module in model_components.values():
        module.eval()

    samples = []

    with torch.no_grad():
        for i in range(num_samples):
            # Pure noise for ALL channels (coordinates AND pen states)
            sketch = torch.randn(1, seq_length, 3, device=device)
            category = torch.tensor([category_id], device=device)

            # Walk the timesteps from high to low
            for raw_t in range(num_steps - 1, -1, -1):
                # Cap the index so it never runs past the end of the schedule
                last_index = len(noise_schedule['alphas']) - 1
                t = torch.tensor([min(raw_t, last_index)], device=device)

                eps_pred = model_forward_training(sketch, model_components, category, t)

                # raw_t == 0 is the final iteration, which skips the update
                if raw_t > 0:
                    alpha_t = noise_schedule['alphas'][t]
                    beta_t = noise_schedule['betas'][t]
                    alpha_bar_t = noise_schedule['alphas_cumprod'][t]

                    fresh_noise = torch.randn_like(sketch)
                    inv_sqrt_alpha = 1 / torch.sqrt(alpha_t)
                    eps_weight = beta_t / torch.sqrt(1 - alpha_bar_t)
                    sigma = torch.sqrt(beta_t)

                    # Coordinates get the full noise term
                    sketch[:, :, :2] = (
                        inv_sqrt_alpha * (sketch[:, :, :2] - eps_weight * eps_pred[:, :, :2])
                        + sigma * fresh_noise[:, :, :2]
                    )
                    # Pen states get a noise term reduced to one tenth
                    sketch[:, :, 2] = (
                        inv_sqrt_alpha * (sketch[:, :, 2] - eps_weight * eps_pred[:, :, 2])
                        + (sigma * 0.1) * fresh_noise[:, :, 2]
                    )

                # Snap pen states to integers in [-1, 3] after every iteration
                sketch[:, :, 2] = torch.clamp(torch.round(sketch[:, :, 2]), -1, 3)

            # Move to NumPy and enforce the final value ranges
            result = sketch[0].cpu().numpy()
            result[:, :2] = np.clip(result[:, :2], -1.0, 1.0)
            result[:, 2] = np.clip(np.round(result[:, 2]), -1, 3)
            samples.append(result)

            if (i + 1) % 25 == 0:
                print(f"Generated {i + 1}/{num_samples} samples")

    return np.array(samples)

In [9]:
def stroke_to_image_and_save(stroke_sequence, save_path, image_size=64):
    """Rasterise a stroke sequence as white dots on black and save it as an RGB image.

    Only points whose pen state equals 2 (pen down) are drawn; every other
    state, including padding (-1) and the end marker (3), is ignored. Each
    drawn point is a 3x3 square, cropped at the image border.

    Args:
        stroke_sequence: Iterable of (x, y, pen_state) rows; x and y are
            mapped from [-1, 1] to pixel coordinates and clipped to the image.
        save_path: Destination path handed to ``PIL.Image.save``.
        image_size: Side length of the square image in pixels.
    """
    canvas = np.zeros((image_size, image_size))
    last_pixel = image_size - 1

    for x, y, pen_state in stroke_sequence:
        if pen_state != 2:
            continue

        # Convert from [-1, 1] to [0, image_size-1], then clip out-of-range values
        col = np.clip(int((x + 1) * (image_size - 1) / 2), 0, last_pixel)
        row = np.clip(int((y + 1) * (image_size - 1) / 2), 0, last_pixel)

        # 3x3 brush: the lower bound is clamped by hand, the upper bound by slicing
        canvas[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2] = 1.0

    # Replicate the single channel into RGB and scale to 8-bit
    rgb = np.stack([canvas, canvas, canvas], axis=2)
    pixels = (rgb * 255).astype(np.uint8)
    Image.fromarray(pixels, 'RGB').save(save_path)

def create_image_directories(real_data, generated_data, temp_dir):
    """Render real and generated stroke sequences into two image folders.

    Creates ``<temp_dir>/real`` and ``<temp_dir>/generated`` (if missing) and
    writes one PNG per sequence via `stroke_to_image_and_save`, named with a
    zero-padded six-digit index.

    Args:
        real_data: Iterable of real stroke sequences.
        generated_data: Iterable of generated stroke sequences.
        temp_dir: Parent directory for the two image folders.

    Returns:
        Tuple ``(real_dir, gen_dir)`` with the paths of the two folders.
    """
    real_dir = os.path.join(temp_dir, "real")
    gen_dir = os.path.join(temp_dir, "generated")
    for directory in (real_dir, gen_dir):
        os.makedirs(directory, exist_ok=True)

    # Real images first
    for i, stroke_seq in enumerate(real_data):
        stroke_to_image_and_save(stroke_seq, os.path.join(real_dir, f"real_{i:06d}.png"))

    # Then generated images
    for i, stroke_seq in enumerate(generated_data):
        stroke_to_image_and_save(stroke_seq, os.path.join(gen_dir, f"gen_{i:06d}.png"))

    return real_dir, gen_dir

In [10]:
def calculate_kid_fid_metrics(real_dir, generated_dir, num_samples):
    """Compute FID and KID between a generated and a real image folder.

    Delegates to ``torch_fidelity.calculate_metrics`` with the generated
    folder as ``input1`` and the real folder as ``input2``. CUDA is used when
    available.

    Args:
        real_dir: Directory with the real images.
        generated_dir: Directory with the generated images.
        num_samples: Number of images per folder; the KID subset size is set
            to ``num_samples - 1``.
            NOTE(review): for num_samples <= 1 that subset size is not
            positive; confirm how calculate_metrics reacts.

    Returns:
        Tuple ``(fid_score, kid_score)`` read from the
        'frechet_inception_distance' and 'kernel_inception_distance_mean'
        entries of the metrics result.
    """
    metric_options = dict(
        input1=generated_dir,
        input2=real_dir,
        fid=True,
        kid=True,
        kid_subset_size=num_samples - 1,
        cuda=torch.cuda.is_available(),
        verbose=True,
    )
    scores = calculate_metrics(**metric_options)

    return scores['frechet_inception_distance'], scores['kernel_inception_distance_mean']

In [20]:
# Model evaluation
def run_evaluation(category, device, processed_datasets, model_path, num_samples=150, num_steps=1000, category_id=0):
    """Evaluate one trained model with FID and KID against real test sketches.

    Loads the checkpoint, generates as many sketches as real test sketches
    are used, renders both sets to images in a temporary directory and
    computes the metrics on the two image folders.

    Args:
        category: Key into `processed_datasets`; its 'test_data' entry
            supplies the real sketches and the sequence length.
        device: Device used for loading the model and for sampling.
        processed_datasets: Mapping ``category -> {'test_data': array, ...}``.
        model_path: Path of the checkpoint to evaluate.
        num_samples: Upper bound on the number of real and generated sketches;
            reduced to the size of the test set if that is smaller.
        num_steps: Number of denoising iterations per generated sketch.
        category_id: Category index passed to the model during sampling.

    Returns:
        Dict with the keys 'FID', 'KID', 'num_samples' and 'seq_length', or
        ``None`` when the metric calculation raised an exception.
    """
    # Load trained model
    model_components, noise_schedule, _ = load_trained_model(model_path, device)

    # Load test data
    real_data = processed_datasets[category]['test_data']
    seq_length = real_data.shape[1]

    # Clamp first so that the reported count is the number actually used
    num_samples = min(num_samples, len(real_data))
    print(f"Using {num_samples} samples")

    real_data_subset = real_data[:num_samples]

    generated_data = generate_samples(
        model_components,
        noise_schedule,
        num_samples,
        seq_length,
        device,
        num_steps=num_steps,
        category_id=category_id,
    )

    # The images only need to exist while the metrics are computed
    with tempfile.TemporaryDirectory() as temp_dir:
        real_dir, gen_dir = create_image_directories(real_data_subset, generated_data, temp_dir)

        try:
            # Calculate metrics
            fid_score, kid_score = calculate_kid_fid_metrics(real_dir, gen_dir, num_samples)

            # Plain Python floats keep the result dict JSON-serialisable
            fid_score = float(fid_score)
            kid_score = float(kid_score)

            print("\nRESULTS:")
            print(f"FID Score: {fid_score:.4f}")
            print(f"KID Score: {kid_score:.4f}")

            return {
                'FID': fid_score,
                'KID': kid_score,
                'num_samples': num_samples,
                'seq_length': seq_length,
            }
        except Exception as e:
            # Deliberate best-effort: report the failure and let the caller
            # continue with the remaining categories
            print(f"Metric calculation failed: {e}")
            return None

In [12]:
# Collects one entry per category: the value returned by run_evaluation
score_dict = dict()

In [None]:
# Evaluate the 'cat' checkpoint and store its result
cat_checkpoint = '/kaggle/input/sketch_model_last/pytorch/cat_200epochs/1/sketch_diffusion_cat_20250806_201127.pth'
score_dict['cat'] = run_evaluation('cat', device, processed_datasets, cat_checkpoint)

In [None]:
# Evaluate the 'bus' checkpoint and store its result
bus_checkpoint = '/kaggle/input/sketch_model_last/pytorch/175_epochs/1/sketch_diffusion_bus_20250804_221426.pth'
score_dict['bus'] = run_evaluation('bus', device, processed_datasets, bus_checkpoint)

In [43]:
# Evaluate the 'rabbit' checkpoint and store its result
rabbit_checkpoint = '/kaggle/input/sketch_model_last/pytorch/rabbit_200epochs/1/sketch_diffusion_rabbit_20250807_084937.pth'
score_dict['rabbit'] = run_evaluation('rabbit', device, processed_datasets, rabbit_checkpoint)

Category: rabbit
Training epochs: 200
Final loss: 0.0361
Model components: ['stroke_embedder', 'category_embedder', 'temporal_encoder', 'noise_predictor']
Using 250 samples
Generated 25/250 samples
Generated 50/250 samples
Generated 75/250 samples
Generated 100/250 samples
Generated 125/250 samples
Generated 150/250 samples
Generated 175/250 samples
Generated 200/250 samples
Generated 225/250 samples
Generated 250/250 samples


Creating feature extractor "inception-v3-compat" with features ['2048']
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 141MB/s]
Extracting features from input1
Looking for samples non-recursivelty in "/tmp/tmp4plr07gl/generated" with extensions png,jpg,jpeg
Found 250 samples
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                         
Extracting features from input2
Looking for samples non-recursivelty in "/tmp/tmp4plr07gl/real" with extensions png,jpg,jpeg
Found 250 samples
Processing samples                                                         
Frechet Inception Distance: 48.31993427307994
                                                                                 


RESULTS:
FID Score: 48.3199
KID Score: 0.0459


Kernel Inception Distance: 0.04588724626409189 ± 0.00013062694246060124


In [None]:
# Show the collected per-category results (an entry is None when metric calculation failed)
print(score_dict)

In [None]:
# Persist the per-category evaluation results as indented JSON
with open("eval_results_v1.json", "w") as results_file:
    json.dump(score_dict, results_file, indent=2)