# Polygon Color Generation with Conditional UNet

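This notebook runs inference with, and evaluates, a trained conditional UNet: given an input polygon image and a color name, the model generates an RGB image of that polygon in the requested color. The sections below load the model, visualize predictions on validation and training inputs, report per-sample validation MSE, save generated samples to disk, and summarize the model architecture.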

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import json
import os
from unet_model import ConditionalUNet

## 1. Load the Trained Model

In [None]:
# Color vocabulary. The integer assigned to each name must match the mapping
# used during training: the model is conditioned on these indices and is built
# with num_colors=len(COLOR_MAP).
_COLOR_NAMES = ('red', 'blue', 'green', 'yellow', 'purple', 'orange', 'cyan', 'magenta')
COLOR_MAP = {name: idx for idx, name in enumerate(_COLOR_NAMES)}

# Inverse lookup (index -> color name) for display.
IDX_TO_COLOR = {idx: name for name, idx in COLOR_MAP.items()}

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Hyperparameters must match the checkpoint that is loaded below.
model = ConditionalUNet(
    n_channels=1,               # single-channel input (see preprocess_image: Grayscale)
    n_classes=3,                # three output channels, converted to an RGB image
    num_colors=len(COLOR_MAP),
    embed_dim=64,
)
model = model.to(device)

# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# The checkpoint is expected to be a dict storing the weights under 'model_state_dict'.
state = torch.load('best_polygon_unet.pth', map_location=device)['model_state_dict']
model.load_state_dict(state)
model.eval()  # switch dropout / batch-norm layers (if any) to inference behavior
print("Model loaded successfully!")

## 2. Define Inference Functions

In [None]:
def preprocess_image(image_path, size=(256, 256)):
    """Load a polygon image and turn it into a model-ready input tensor.

    The image is forced to RGB, converted to single-channel grayscale, resized
    to ``size`` and scaled to [0, 1] by ``ToTensor``.

    Args:
        image_path: Path to the input polygon image.
        size: Target (height, width) passed to ``transforms.Resize``. The
            default (256, 256) is presumably the training resolution -- confirm
            against the training pipeline before changing it.

    Returns:
        A float tensor of shape (1, 1, H, W) (batch dimension added).
    """
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    # Context manager releases the file handle deterministically instead of
    # whenever the lazily-loaded PIL image gets garbage-collected.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return transform(rgb_image).unsqueeze(0)  # add batch dimension

def generate_colored_polygon(model, polygon_path, color_name, device):
    """Run the conditional UNet on one polygon image and return the result as a PIL image.

    Args:
        model: Trained conditional UNet, called as ``model(image, color_idx)``.
            The caller is responsible for putting it in eval mode.
        polygon_path: Path to the input polygon image.
        color_name: A key of ``COLOR_MAP`` (e.g. ``'red'``).
        device: Device the model lives on; inputs are created/moved there.

    Returns:
        A ``PIL.Image`` built from the model output clipped to [0, 1] and
        quantized to uint8.

    Raises:
        KeyError: If ``color_name`` is not in ``COLOR_MAP``; the message lists
            the valid names. Checked before any file I/O.
    """
    if color_name not in COLOR_MAP:
        raise KeyError(f"Unknown color {color_name!r}; expected one of {list(COLOR_MAP)}")

    input_tensor = preprocess_image(polygon_path).to(device)
    color_idx = torch.tensor([COLOR_MAP[color_name]], dtype=torch.long, device=device)

    with torch.no_grad():
        output = model(input_tensor, color_idx)

    # (1, C, H, W) -> (H, W, C), the layout PIL expects.
    output_np = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    output_np = np.clip(output_np, 0.0, 1.0)
    # Round instead of truncating so values near 1.0 map to 255 (e.g. 0.999 -> 255, not 254).
    output_uint8 = np.round(output_np * 255).astype(np.uint8)
    return Image.fromarray(output_uint8)

def visualize_results(polygon_path, color_name, generated_image, actual_image_path=None):
    """Plot the input polygon beside the generated image, plus the ground truth if given.

    Args:
        polygon_path: Path to the input polygon image (displayed in RGB).
        color_name: Color name used in the panel titles.
        generated_image: Image returned by ``generate_colored_polygon``.
        actual_image_path: Optional path to the ground-truth output image. When
            falsy, only two panels are drawn.

    Returns:
        The matplotlib ``Figure`` (already shown via ``plt.show()``).
    """
    show_actual = bool(actual_image_path)
    fig, axes = plt.subplots(1, 3 if show_actual else 2, figsize=(15, 5))

    def _draw(ax, img, title):
        # One panel: image, title, no axis ticks/frame.
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    _draw(axes[0], Image.open(polygon_path).convert('RGB'),
          f"Input: {os.path.basename(polygon_path)}")
    _draw(axes[1], generated_image, f"Generated: {color_name} polygon")
    if show_actual:
        _draw(axes[2], Image.open(actual_image_path), f"Actual: {color_name} polygon")

    plt.tight_layout()
    plt.show()
    return fig

## 3. Test with Validation Data

In [None]:
# Load validation metadata. JSON is UTF-8 by specification, so pass the encoding
# explicitly instead of relying on the platform default (e.g. cp1252 on Windows).
VAL_DIR = 'validation'
N_PREVIEW = 3  # number of validation samples to visualize below

with open(os.path.join(VAL_DIR, 'data.json'), 'r', encoding='utf-8') as f:
    val_data = json.load(f)

print(f"Testing with {len(val_data)} validation samples...")

# Visualize the first few samples next to their ground-truth outputs.
for i, item in enumerate(val_data[:N_PREVIEW]):
    polygon_path = os.path.join(VAL_DIR, 'inputs', item['input_polygon'])
    color_name = item['colour']
    actual_path = os.path.join(VAL_DIR, 'outputs', item['output_image'])

    print(f"\nTest {i+1}: {item['input_polygon']} -> {color_name}")

    generated = generate_colored_polygon(model, polygon_path, color_name, device)
    visualize_results(polygon_path, color_name, generated, actual_path)

## 4. Interactive Testing

In [None]:
# Interactive helper: run the model on a single (polygon, color) pair.
def interactive_test(polygon_name, color_name, inputs_dir="training/inputs"):
    """Generate and plot one polygon/color combination.

    Args:
        polygon_name: File name of a polygon image inside ``inputs_dir``.
        color_name: One of the keys of ``COLOR_MAP``.
        inputs_dir: Directory holding the input polygons. Defaults to the
            training inputs, matching the previous hard-coded behavior.

    Returns:
        The generated PIL image, or ``None`` if the polygon file or the color
        is unavailable (a message is printed in that case).
    """
    polygon_path = os.path.join(inputs_dir, polygon_name)

    # isfile rather than exists: a directory with a matching name is not a usable input.
    if not os.path.isfile(polygon_path):
        print(f"Polygon {polygon_name} not found!")
        return None

    if color_name not in COLOR_MAP:
        print(f"Color {color_name} not available. Available colors: {list(COLOR_MAP.keys())}")
        return None

    generated = generate_colored_polygon(model, polygon_path, color_name, device)
    visualize_results(polygon_path, color_name, generated)
    return generated

# List the valid arguments for interactive_test. Sorted because os.listdir
# returns entries in arbitrary, filesystem-dependent order.
print("Available polygons:", sorted(f for f in os.listdir('training/inputs') if f.endswith('.png')))
print("Available colors:", list(COLOR_MAP.keys()))

In [None]:
# Test specific combinations. Running them in a loop (rather than leaving a bare
# call as the cell's last expression) stops Jupyter from auto-displaying the
# PIL image returned by interactive_test a second time below the figure.
for demo_polygon, demo_color in [
    ('triangle.png', 'blue'),
    ('square.png', 'green'),
    ('circle.png', 'red'),
]:
    interactive_test(demo_polygon, demo_color)

## 5. Performance Analysis (Validation MSE)

In [None]:
import torch.nn.functional as F

def calculate_metrics(generated, actual):
    """Compute the pixel-wise mean squared error between a generated image and its target.

    Both images are converted to RGB and scaled to [0, 1] with ``ToTensor``
    before comparison. This means ground-truth files stored as RGBA, palette
    or grayscale no longer yield mismatched channel counts (or raw palette
    indices). If the sizes differ, ``actual`` is resized to ``generated``'s size.

    Args:
        generated: PIL image produced by the model.
        actual: Ground-truth PIL image.

    Returns:
        MSE as a Python float (lower is better). SSIM is not computed.
    """
    generated = generated.convert('RGB')
    actual = actual.convert('RGB')
    if actual.size != generated.size:
        # NOTE(review): assumes targets were resized similarly during training -- confirm.
        actual = actual.resize(generated.size)

    to_tensor = transforms.ToTensor()
    gen_tensor = to_tensor(generated).unsqueeze(0)
    actual_tensor = to_tensor(actual).unsqueeze(0)

    return F.mse_loss(gen_tensor, actual_tensor).item()

# Evaluate every validation sample and report per-sample and aggregate MSE.
mse_scores = []

print("Validation Set Performance:")
print("-" * 50)

for item in val_data:
    polygon_path = os.path.join('validation', 'inputs', item['input_polygon'])
    color_name = item['colour']
    actual_path = os.path.join('validation', 'outputs', item['output_image'])

    generated = generate_colored_polygon(model, polygon_path, color_name, device)
    # Load the target eagerly and close the file. RGB matches the model's 3-channel output.
    with Image.open(actual_path) as actual_file:
        actual = actual_file.convert('RGB')

    mse = calculate_metrics(generated, actual)
    mse_scores.append(mse)

    print(f"{item['input_polygon']} -> {color_name}: MSE = {mse:.4f}")

# np.min / np.max raise ValueError on an empty sequence, so guard the summary.
if mse_scores:
    print(f"\nAverage MSE: {np.mean(mse_scores):.4f}")
    print(f"Min MSE: {np.min(mse_scores):.4f}")
    print(f"Max MSE: {np.max(mse_scores):.4f}")
else:
    print("\nNo validation samples to evaluate.")

## 6. Save Generated Samples

In [None]:
# Write one generated image per validation sample into SAMPLES_DIR. Each file
# is named after its ground-truth output image, prefixed with "generated_".
SAMPLES_DIR = 'generated_samples'
os.makedirs(SAMPLES_DIR, exist_ok=True)

for item in val_data:
    sample = generate_colored_polygon(
        model,
        f"validation/inputs/{item['input_polygon']}",
        item['colour'],
        device,
    )
    sample.save(f"{SAMPLES_DIR}/generated_{item['output_image']}")

print("Generated samples saved to 'generated_samples/' directory")

## 7. Model Summary

In [None]:
# Print model summary
print("Model Architecture:")
print("-" * 50)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Use each tensor's actual element size instead of assuming 4-byte float32, so
# the figure stays correct for fp16/bf16/fp64 weights. Buffers (if any) are not counted.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {param_bytes / (1024*1024):.2f} MB")