In [None]:
# Install required packages
#!pip install torch torchvision transformers pandas pillow requests matplotlib tqdm ipywidgets gradio

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import requests
import os
import json
from tqdm import tqdm
import math
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import urllib.request
import zipfile
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when torch reports an available GPU, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the global torch and NumPy random number generators with a fixed value
# so that repeated runs of the notebook are reproducible.
torch.manual_seed(42)
np.random.seed(42)


In [None]:
# Dataset download utility
def reporthook(block_num, block_size, total_size):
    """
    Progress callback passed to ``urllib.request.urlretrieve``.

    Prints the amount of data transferred so far (in MB) once every
    16384 blocks, including block 0.

    Args:
        block_num: Number of blocks transferred so far.
        block_size: Size of one block in bytes.
        total_size: Total size reported for the download (not used here).
    """
    # Stay silent for every block that is not a multiple of 16384.
    if block_num % 16384:
        return
    downloaded_mb = block_num * block_size / (1024 * 1024)
    print(f"Downloading... {downloaded_mb:.2f} MB")

def download_dataset_if_not_exists():
    """
    Download and extract the pokedex repository archive into ``dataset/``.

    The function is a no-op when ``dataset/pokedex-main`` already exists.
    Otherwise it downloads the GitHub zip archive to ``pokedex_main.zip`` in
    the working directory, extracts it into ``dataset/`` and removes the
    archive.

    The archive is removed even if the download or the extraction fails, and
    a partially extracted ``dataset/pokedex-main`` directory is deleted on
    failure. Without that cleanup an interrupted run would leave a directory
    behind that makes every later call skip the download.

    Raises:
        Any exception raised by the download or the extraction is propagated
        after cleanup.
    """
    import shutil  # local import: only needed for failure cleanup

    dataset_dir = "dataset"
    pokedex_main_dir = os.path.join(dataset_dir, "pokedex-main")
    zip_url = "https://github.com/cristobalmitchell/pokedex/archive/refs/heads/main.zip"
    zip_path = "pokedex_main.zip"

    # Check if dataset/pokedex-main already exists
    if os.path.exists(pokedex_main_dir):
        print(f"{pokedex_main_dir} already exists. Skipping download.")
        return

    # Create dataset directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)

    try:
        # Download the zip file
        print("Downloading dataset...")
        urllib.request.urlretrieve(zip_url, zip_path, reporthook)
        print("Download complete.")

        # Extract the zip file into the dataset directory
        print("Extracting dataset...")
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)
        except BaseException:
            # Also covers KeyboardInterrupt (kernel interrupt): drop the
            # half-extracted tree so the next call retries from scratch.
            shutil.rmtree(pokedex_main_dir, ignore_errors=True)
            raise
        print("Extraction complete.")
    finally:
        # Remove the archive whether or not the steps above succeeded.
        if os.path.exists(zip_path):
            os.remove(zip_path)

# Fetch the dataset now; the call returns immediately when
# dataset/pokedex-main is already present.
download_dataset_if_not_exists()
print("Dataset ready!")


In [None]:
import torch
import torchvision.transforms as T

import torch
import torchvision.transforms as T

class AugmentationPipeline:
    """
    Random image augmentation pipeline built from torchvision transforms.

    The whole list of augmentations is wrapped in an outer ``RandomApply``
    with probability ``p``. Inside it, the affine and colour-jitter steps are
    each wrapped in their own ``RandomApply`` with probability 0.5.

    Instances are callable, so a pipeline object can be placed directly
    inside ``transforms.Compose`` (which calls each step) instead of having
    to pass its ``.transforms`` attribute.

    Args:
        p: Probability passed to the outer ``RandomApply``.
    """

    def __init__(self, p=0.8):
        self.p = p
        self.transforms = T.RandomApply([
            T.RandomHorizontalFlip(p=0.5),

            # Affine transforms (rotation/translation/scale) are applied only
            # 50% of the time, with a mild intensity (degrees=10).
            T.RandomApply([
                T.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05), fill=1)
            ], p=0.5),

            # ColorJitter is applied only 50% of the time; its parameters are
            # already small.
            T.RandomApply([
                T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
            ], p=0.5),

            # RandomErasing with a low probability: it is a strong
            # augmentation, so it is used sparingly.
            T.RandomErasing(p=0.15, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random'),
        ], p=self.p)

    def apply(self, images):
        """Run the augmentation pipeline on ``images`` and return the result."""
        return self.transforms(images)

    def __call__(self, images):
        """Make the pipeline usable as a transform; equivalent to ``apply``."""
        return self.apply(images)

In [None]:
# Enhanced Pokemon Dataset Class with modular augmentation support
class PokemonDataset(Dataset):
    """
    Pokemon dataset mapping a text description to a sprite image tensor.

    Each sample pairs the tokenized ``description`` column of the CSV with the
    PNG sprite named after the ``national_number`` column, composited on a
    white background, resized to 256x256 and normalized to [-1, 1].
    """

    def __init__(self, tokenizer, csv_path="dataset/pokedex-main/data/pokemon.csv",
                 image_dir="dataset/pokedex-main/images/small_images",
                 max_length=128, augmentation_pipeline=None):
        """
        Args:
            tokenizer: Callable tokenizer; it is invoked with ``max_length``,
                ``padding``, ``truncation`` and ``return_tensors`` keywords.
            csv_path: Path of the tab-separated, UTF-16 LE encoded CSV file.
            image_dir: Directory containing the ``NNN.png`` sprites.
            max_length: Fixed token sequence length (pad / truncate).
            augmentation_pipeline: Optional augmentation inserted between the
                resize and the normalization. May be any callable transform,
                or an object exposing an ``apply(images)`` method (such as
                ``AugmentationPipeline``).

        Raises:
            TypeError: If ``augmentation_pipeline`` is neither ``None``,
                callable, nor an object with a callable ``apply``.
        """
        self.df = pd.read_csv(csv_path, encoding='utf-16 LE', delimiter='\t')
        self.image_dir = Path(image_dir)
        print(f"Dataset caricato: {len(self.df)} Pokemon con descrizioni e immagini")

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmentation_pipeline = augmentation_pipeline

        # ToTensor -> Resize -> (optional augmentation) -> Normalize
        steps = [
            transforms.ToTensor(),
            transforms.Resize((256, 256), antialias=True),
        ]
        augment = self._resolve_augmentation(augmentation_pipeline)
        if augment is not None:
            steps.append(augment)
        # Normalize to [-1, 1]
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
        self.final_transform = transforms.Compose(steps)

    @staticmethod
    def _resolve_augmentation(pipeline):
        """
        Turn ``pipeline`` into something ``transforms.Compose`` can call.

        Returns ``pipeline`` unchanged when it is ``None`` or callable, its
        bound ``apply`` method when it only exposes ``apply``, and raises
        ``TypeError`` otherwise so the error surfaces at construction time
        rather than on the first sample access.
        """
        if pipeline is None or callable(pipeline):
            return pipeline
        apply_fn = getattr(pipeline, 'apply', None)
        if callable(apply_fn):
            return apply_fn
        raise TypeError(
            "augmentation_pipeline must be callable or expose an apply() method, "
            f"got {type(pipeline).__name__}"
        )

    def __len__(self):
        """Return the total number of samples."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return one sample as a dict with the tokenized text, the image tensor
        and some metadata (description, name, index, attention mask).
        """
        # Fetch the corresponding row
        row = self.df.iloc[idx]

        # === TEXT PREPROCESSING ===
        description = str(row['description'])

        # Tokenize the text to a fixed length
        encoded = self.tokenizer(
            description,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Extract token ids and attention mask, dropping the batch dimension
        text_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        # === IMAGE LOADING AND PREPROCESSING ===
        # Build the image path from the zero-padded national number
        image_filename = f"{row['national_number']:03d}.png"
        image_path = self.image_dir/image_filename

        # Load the image with an alpha channel
        image_rgba = Image.open(image_path).convert('RGBA')

        # Handle transparency: composite the sprite on a white background
        background = Image.new('RGB', image_rgba.size, (255, 255, 255))
        background.paste(image_rgba, mask=image_rgba.split()[-1])

        # Apply the final transforms (ToTensor, Resize, augmentation, Normalize)
        image_tensor = self.final_transform(background)

        sample = {
            'text': text_ids,
            'image': image_tensor,
            'description': description,  # For debugging or visualization
            'pokemon_name': row['english_name'],
            'idx': idx,
            'attention_mask': attention_mask,
        }

        return sample

def create_training_setup(tokenizer, train_val_split, batch_size, num_workers=0,
                         num_viz_samples=4, random_seed=42, train_augmentation_pipeline=None):
    """
    Build the datasets, dataloaders and fixed visualization batches.

    Two ``PokemonDataset`` instances are created: one with the training
    augmentation and one without. A seeded random split of the indices is
    applied to both, so the training subset reads from the augmented dataset
    and the validation subset from the plain one.

    Args:
        tokenizer: Tokenizer forwarded to ``PokemonDataset``.
        train_val_split: Fraction of the samples assigned to training.
        batch_size: Batch size of the train and validation loaders.
        num_workers: Worker count of the train and validation loaders.
        num_viz_samples: Size of the fixed visualization batches.
        random_seed: Seed for the index split and the fixed train batches.
        train_augmentation_pipeline: Augmentation for the training dataset.

    Returns:
        dict with keys ``train_loader``, ``val_loader``, ``fixed_train_batch``,
        ``fixed_val_batch``, ``fixed_train_attention_batch``,
        ``fixed_val_attention_batch``, ``train_dataset`` and ``val_dataset``.
    """
    from torch.utils.data import random_split, TensorDataset, Subset

    def first_batch(dataset, size, generator=None):
        """Return the first batch; shuffled only when a generator is given."""
        loader = DataLoader(
            dataset,
            batch_size=size,
            shuffle=generator is not None,
            generator=generator,
        )
        return next(iter(loader))

    # --- Datasets: augmented for training, plain for validation ---
    augmented_full = PokemonDataset(tokenizer=tokenizer, augmentation_pipeline=train_augmentation_pipeline)
    plain_full = PokemonDataset(tokenizer=tokenizer, augmentation_pipeline=None)

    # --- Seeded split of the indices ---
    assert len(augmented_full) == len(plain_full)
    total = len(augmented_full)
    n_train = int(train_val_split * total)
    n_val = total - n_train

    split_generator = torch.Generator().manual_seed(random_seed)
    train_part, val_part = random_split(
        TensorDataset(torch.arange(total)),
        [n_train, n_val],
        generator=split_generator,
    )

    train_dataset = Subset(augmented_full, train_part.indices)
    val_dataset = Subset(plain_full, val_part.indices)

    # --- Main loaders (only the training loader shuffles) ---
    shared_loader_args = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

    # --- Fixed batches for visualization ---
    vis_generator = torch.Generator().manual_seed(random_seed)
    fixed_train_batch = first_batch(train_dataset, num_viz_samples, vis_generator)
    fixed_val_batch = first_batch(val_dataset, num_viz_samples)  # validation is never shuffled

    vis_generator.manual_seed(random_seed)  # reset so the single-sample batch starts from the same seed
    fixed_train_attention_batch = first_batch(train_dataset, 1, vis_generator)
    fixed_val_attention_batch = first_batch(val_dataset, 1)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'fixed_train_batch': fixed_train_batch,
        'fixed_val_batch': fixed_val_batch,
        'fixed_train_attention_batch': fixed_train_attention_batch,
        'fixed_val_attention_batch': fixed_val_attention_batch,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
    }

# Initialize BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-mini')

# Build the full training setup. Augmentation is disabled for this run
# (train_augmentation_pipeline=None).
print("Creating training setup with train/val split and fixed batches...")
training_setup = create_training_setup(
    tokenizer=tokenizer, train_val_split=0.9, batch_size=16, num_workers=0,
    num_viz_samples=4, random_seed=42, train_augmentation_pipeline=None,
)

# Unpack the returned components into notebook-level names
(train_loader, val_loader,
 fixed_train_batch, fixed_val_batch,
 fixed_train_attention_batch, fixed_val_attention_batch,
 train_dataset, val_dataset) = [
    training_setup[key] for key in (
        'train_loader', 'val_loader',
        'fixed_train_batch', 'fixed_val_batch',
        'fixed_train_attention_batch', 'fixed_val_attention_batch',
        'train_dataset', 'val_dataset',
    )
]

print(f"Training setup complete!")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")

# Report the tensor shapes of the fixed batches
print(f"\nFixed batch shapes:")
print(f"  Train batch - Images: {fixed_train_batch['image'].shape}")
print(f"  Train batch - Text: {fixed_train_batch['text'].shape}")
print(f"  Train batch - Attention: {fixed_train_batch['attention_mask'].shape}")
print(f"  Val batch - Images: {fixed_val_batch['image'].shape}")

# Show the fixed batches: training samples on the top row, validation below
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for i, (train_ax, val_ax) in enumerate(zip(axes[0], axes[1])):
    # Map the image from [-1, 1] back to [0, 1] before plotting
    img = (fixed_train_batch['image'][i] + 1) / 2.0
    train_ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    train_ax.set_title(f"Train: {fixed_train_batch['pokemon_name'][i]}")
    train_ax.axis('off')

    img = (fixed_val_batch['image'][i] + 1) / 2.0
    val_ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    val_ax.set_title(f"Val: {fixed_val_batch['pokemon_name'][i]}")
    val_ax.axis('off')

plt.suptitle("Fixed Batches for Training Visualization", fontsize=16)
plt.tight_layout()
plt.show()

# Aliases used by the rest of the notebook
dataloader = train_loader
sample_batch = fixed_train_batch

print(f"\n✅ Dataset and batches loaded successfully from pokemon_dataset.py functionality!")
print(f"Ready for training with proper train/val split and fixed visualization batches.")


In [None]:

# Visual check: the same sample drawn repeatedly, without and with augmentation
print(f"\n🔄 AUGMENTATION VERIFICATION:")
print("Testing that augmentations produce different results for the same Pokemon...")

train_augmentation_pipeline = AugmentationPipeline()

# Index 0 of the full (unsplit) dataset is used for both variants
test_pokemon_idx = 0
original_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_pipeline=None)
augmented_dataset = PokemonDataset(
    tokenizer=tokenizer,
    augmentation_pipeline=train_augmentation_pipeline.transforms,
)

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
fig.suptitle("Augmentation Verification: Same Pokemon with Different Augmentations", fontsize=16)

# Top row: the un-augmented sample, fetched once and drawn five times
original_sample = original_dataset[test_pokemon_idx]
for i, ax in enumerate(axes[0]):
    img = (original_sample['image'] + 1) / 2.0  # back to [0, 1]
    ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f"Original {i+1}")
    ax.axis('off')

# Bottom row: the sample is re-fetched each time so the random augmentation is re-drawn
for i, ax in enumerate(axes[1]):
    augmented_sample = augmented_dataset[test_pokemon_idx]
    img = (augmented_sample['image'] + 1) / 2.0  # back to [0, 1]
    ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f"Augmented {i+1}")
    ax.axis('off')

# Vertical row labels placed just left of the first column
label_style = dict(ha='center', va='center', rotation='vertical', fontsize=12)
axes[0, 0].text(-0.1, 0.5, 'No Augmentation\n(Should be identical)',
                transform=axes[0, 0].transAxes, **label_style)
axes[1, 0].text(-0.1, 0.5, 'With Augmentation\n(Should be different)',
                transform=axes[1, 0].transAxes, **label_style)

plt.tight_layout()
plt.show()

print(f"Pokemon tested: {original_sample['pokemon_name']} (#{original_sample['idx']})")
print(f"Description: {original_sample['description'][:60]}...")
print(f"✅ If augmentations are working correctly, the bottom row should show different variations!")
print(f"   - Look for differences in: rotation, translation, color, brightness, horizontal flip")
print(f"   - The top row should be identical (no augmentation)")
print(f"   - This proves augmentations will be different at each epoch for the same Pokemon!")

In [None]:
class TextEncoder(nn.Module):
    """
    Text encoder: pretrained embedding layer followed by a Transformer encoder.

    The embedding layer is taken from the pretrained model named by
    ``model_name``; a freshly initialized 4-layer ``nn.TransformerEncoder``
    (d_model=256, 4 heads, feed-forward size 1024) is stacked on top of it.

    Args:
        model_name: Name passed to ``AutoModel.from_pretrained``.
        fine_tune_embeddings: Whether the embedding parameters require
            gradients during training.
    """

    def __init__(self, model_name="prajjwal1/bert-mini", fine_tune_embeddings=True):
        super().__init__()
        # Only the embedding layer of the pretrained model is kept.
        pretrained = AutoModel.from_pretrained(model_name)
        self.embedding = pretrained.embeddings
        self.embedding.requires_grad_(fine_tune_embeddings)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256, nhead=4, dim_feedforward=1024, batch_first=True
            ),
            num_layers=4,
        )

    def forward(self, token_ids, attention_mask=None):
        """
        Encode token ids into one feature vector per token.

        Args:
            token_ids: Token id tensor of shape (batch_size, seq_len).
            attention_mask: Optional HuggingFace-style mask (1 = real token,
                0 = padding) of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, embedding_dim).
        """
        # The Transformer encoder expects True at the positions to ignore,
        # i.e. the inverse of the HuggingFace convention.
        padding_mask = None if attention_mask is None else attention_mask.eq(0)
        return self.transformer_encoder(
            src=self.embedding(token_ids),
            src_key_padding_mask=padding_mask,
        )


class ImageCrossAttention(nn.Module):
    """
    Cross-attention from image features (queries) to text features (keys/values).

    The module handles the reshaping between the spatial layout (B, C, H, W)
    and the sequence layout (B, H*W, C), and converts the HuggingFace-style
    padding mask to the convention expected by ``nn.MultiheadAttention``.

    Args:
        embed_dim: Channel size of both the image features and text features.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, image_features, text_features, key_padding_mask=None):
        """
        Args:
            image_features: (B, C, H, W) spatial image features (queries).
                Non-contiguous tensors are accepted.
            text_features: (B, seq_len, embed_dim) text encoder output.
            key_padding_mask: Optional (B, seq_len) tokenizer mask
                (1 = real token, 0 = padding).

        Returns:
            Tuple ``(attn_output_spatial, attn_weights)`` where the first item
            has shape (B, C, H, W) and the second is the attention weight
            tensor returned by ``nn.MultiheadAttention``.
        """
        B, C, H, W = image_features.shape

        # 1. Spatial -> sequence: (B, C, H, W) -> (B, H*W, C).
        #    flatten() copies only when needed, so inputs whose memory layout
        #    cannot be expressed as a view are handled too (.view would raise).
        query_seq = image_features.flatten(2).transpose(1, 2)
        query_norm = self.layer_norm(query_seq)

        # 2. HuggingFace masks use 1 for real tokens and 0 for padding;
        #    MultiheadAttention expects True at the positions to ignore.
        mask = None if key_padding_mask is None else (key_padding_mask == 0)

        # 3. Attention: image positions attend to text tokens.
        attn_output, attn_weights = self.attention(
            query=query_norm,
            key=text_features,
            value=text_features,
            key_padding_mask=mask,
            need_weights=True
        )
        # attn_output: (B, H*W, C)

        # 4. Sequence -> spatial: (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        attn_output_spatial = attn_output.transpose(1, 2).reshape(B, C, H, W)

        return attn_output_spatial, attn_weights


class DecoderBlock(nn.Module):
    """
    Generator block: optional text cross-attention with fusion, then 2x upsampling.

    Pipeline: (optional) cross-attention -> fusion -> ConvTranspose2d upsampling
    -> normalization -> activation.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the upsampled feature map.
        use_attention: Whether to condition on the text via cross-attention.
        text_embed_dim: Channel size of the text features.
        nhead: Number of heads of the cross-attention.
    """

    def __init__(self, in_channels, out_channels, use_attention=True, text_embed_dim=256, nhead=4):
        super().__init__()
        self.use_attention = use_attention

        if self.use_attention:
            # A 1x1 conv aligns the image channels with the text dimension
            # when the two differ; otherwise no adapter is needed.
            needs_adapter = in_channels != text_embed_dim
            attention_dim = text_embed_dim if needs_adapter else in_channels
            self.channel_adapter = (
                nn.Conv2d(in_channels, text_embed_dim, kernel_size=1, bias=False)
                if needs_adapter else None
            )

            self.cross_attention = ImageCrossAttention(embed_dim=attention_dim, num_heads=nhead)
            # 1x1 conv fusing the image features with the text context
            self.fusion_conv = nn.Conv2d(attention_dim * 2, in_channels, kernel_size=1, bias=False)

        # Upsampling block: doubles the spatial resolution
        self.upsample_block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),  # single group: normalizes over (C, H, W) per sample
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x, encoder_output=None, attention_mask=None):
        """
        Args:
            x: (B, in_channels, H, W) feature map.
            encoder_output: Text features; required when attention is enabled.
            attention_mask: Tokenizer mask; required when attention is enabled.

        Returns:
            Tuple ``(features, attn_weights)``; ``attn_weights`` is ``None``
            when the block does not use attention.

        Raises:
            ValueError: If attention is enabled and either ``encoder_output``
                or ``attention_mask`` is missing.
        """
        if not self.use_attention:
            return self.upsample_block(x), None

        if encoder_output is None or attention_mask is None:
            raise ValueError("encoder_output and attention_mask must be provided for attention.")

        # Bring the image features to the attention dimension if necessary
        queries = x if self.channel_adapter is None else self.channel_adapter(x)

        context, attn_weights = self.cross_attention(
            image_features=queries,
            text_features=encoder_output,
            key_padding_mask=attention_mask
        )

        # Concatenate adapted features and text context: (B, 2*attention_dim, H, W),
        # fuse them back to in_channels and add the result as a residual.
        residual = self.fusion_conv(torch.cat([queries, context], dim=1))
        return self.upsample_block(x + residual), attn_weights


class ImageDecoder(nn.Module):
    """
    CNN decoder (generator) that synthesizes images from noise and text features.

    An attention-pooled text context vector is concatenated with the noise and
    projected to a 4x4 feature map, which is then upsampled by a stack of
    ``DecoderBlock``s. Two images are produced: one after the shared blocks
    (64x64) and one after the additional blocks (256x256).

    Args:
        noise_dim: Size of the noise vector.
        text_embed_dim: Channel size of the text encoder output.
        final_image_channels: Number of output image channels.
    """

    def __init__(self, noise_dim, text_embed_dim, final_image_channels=3):
        super().__init__()

        # Scores each text token for the initial context vector.
        # The softmax is applied in forward() so the padding mask can be used.
        self.initial_context_scorer = nn.Sequential(
            nn.Linear(in_features=text_embed_dim, out_features=512),
            nn.Tanh(),
            nn.Linear(in_features=512, out_features=1)
        )

        # Initial linear projection to a 4x4 feature map.
        self.initial_projection = nn.Sequential(
            nn.Linear(noise_dim + text_embed_dim, 256 * 4 * 4),
            nn.GroupNorm(1, 256 * 4 * 4),
            nn.LeakyReLU(inplace=True)
        )

        # Blocks shared by both resolutions (up to 64x64)
        self.blocks_64 = nn.ModuleList([
            # Input: (B, 256, 4, 4)   -> Output: (B, 256, 8, 8)
            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
            # Input: (B, 256, 8, 8)   -> Output: (B, 256, 16, 16)
            DecoderBlock(in_channels=256, out_channels=256, use_attention=True),
            # Input: (B, 256, 16, 16)  -> Output: (B, 128, 32, 32)
            DecoderBlock(in_channels=256, out_channels=128, use_attention=True),
            # Input: (B, 128, 32, 32)  -> Output: (B, 64, 64, 64)
            DecoderBlock(in_channels=128, out_channels=64, use_attention=False),
        ])

        # Additional blocks only for the 256x256 output (64x64 -> 256x256)
        self.blocks_256 = nn.ModuleList([
            # Input: (B, 64, 64, 64)  -> Output: (B, 32, 128, 128)
            DecoderBlock(in_channels=64, out_channels=32, use_attention=False),
            # Input: (B, 32, 128, 128) -> Output: (B, 16, 256, 256)
            DecoderBlock(in_channels=32, out_channels=16, use_attention=False),
        ])

        # Final RGB layer for 256x256: (B, 16, 256, 256) -> (B, 3, 256, 256)
        self.final_conv_256 = nn.Conv2d(16, final_image_channels, kernel_size=3, padding=1)
        self.final_activation_256 = nn.Tanh()

        # Final RGB layer for 64x64: (B, 64, 64, 64) -> (B, 3, 64, 64)
        self.final_conv_64 = nn.Conv2d(64, final_image_channels, kernel_size=3, padding=1)
        self.final_activation_64 = nn.Tanh()

    def _run_blocks(self, x, blocks, encoder_output_full, attention_mask, attention_maps):
        """Run ``x`` through ``blocks``, appending any attention weights to ``attention_maps``."""
        for block in blocks:
            # Text context and mask are only handed to blocks that use attention.
            encoder_ctx = encoder_output_full if block.use_attention else None
            mask_ctx = attention_mask if block.use_attention else None
            x, attn_weights = block(x, encoder_ctx, mask_ctx)
            if attn_weights is not None:
                attention_maps.append(attn_weights)
        return x

    def forward(self, noise, encoder_output_full, attention_mask):
        """
        Args:
            noise: (B, noise_dim) noise vectors.
            encoder_output_full: (B, seq_len, text_embed_dim) text features.
            attention_mask: (B, seq_len) tokenizer mask (1 = token, 0 = padding).

        Returns:
            Tuple ``(image_256, image_64, attention_maps, attention_weights)``:
            the two generated images, the list of per-block cross-attention
            weights, and the (B, seq_len, 1) weights of the initial context.
        """
        # 1. Initial context vector as an attention-weighted mean of the text tokens.
        #    attn_scores: (B, seq_len, 1) logits, one per token.
        attn_scores = self.initial_context_scorer(encoder_output_full)

        # Mask padding tokens before the softmax. The fill value is the most
        # negative finite number of the score dtype: a hard-coded -1e9 is not
        # representable in float16 and fails under mixed precision. The fill
        # is out-of-place so the scorer output is not modified.
        if attention_mask is not None:
            # Mask is (B, seq_len), scores are (B, seq_len, 1).
            pad_positions = attention_mask.unsqueeze(-1) == 0
            attn_scores = attn_scores.masked_fill(
                pad_positions, torch.finfo(attn_scores.dtype).min
            )

        # attention_weights: (B, seq_len, 1), normalized over the tokens
        attention_weights = torch.softmax(attn_scores, dim=1)

        # context_vector: (B, text_embed_dim)
        context_vector = torch.sum(attention_weights * encoder_output_full, dim=1)

        # 2. Input of the initial projection: (B, noise_dim + text_embed_dim)
        initial_input = torch.cat([noise, context_vector], dim=1)

        # 3. Project and reshape: (B, 256 * 4 * 4) -> (B, 256, 4, 4)
        x = self.initial_projection(initial_input)
        x = x.view(x.size(0), 256, 4, 4)

        # 4. Shared path (up to 64x64); x becomes (B, 64, 64, 64)
        attention_maps = []
        x = self._run_blocks(x, self.blocks_64, encoder_output_full, attention_mask, attention_maps)

        # 5. 64x64 output head: (B, 64, 64, 64) -> (B, 3, 64, 64)
        image_64 = self.final_conv_64(x)
        image_64 = self.final_activation_64(image_64)

        # 6. 256x256 path continues with the additional blocks
        x = self._run_blocks(x, self.blocks_256, encoder_output_full, attention_mask, attention_maps)

        # 7. 256x256 output head: (B, 16, 256, 256) -> (B, 3, 256, 256)
        image_256 = self.final_conv_256(x)
        image_256 = self.final_activation_256(image_256)

        return image_256, image_64, attention_maps, attention_weights


class Generator(nn.Module):
    """
    Full text-to-image model: a ``TextEncoder`` followed by an ``ImageDecoder``.

    Args:
        text_encoder_model_name: Pretrained model name for the text encoder.
        noise_dim: Size of the random noise vector sampled per image.
    """

    def __init__(self, text_encoder_model_name="prajjwal1/bert-mini", noise_dim=100):
        super().__init__()
        self.text_encoder = TextEncoder(model_name=text_encoder_model_name)

        # Feature size handed to the decoder as the text embedding dimension.
        text_embed_dim = 256
        self.image_decoder = ImageDecoder(noise_dim=noise_dim, text_embed_dim=text_embed_dim)

        self.noise_dim = noise_dim

    def forward(self, token_ids, attention_mask, return_attentions=False):
        """
        Generate images for a batch of tokenized descriptions.

        Args:
            token_ids: (batch_size, seq_len) token ids.
            attention_mask: (batch_size, seq_len) tokenizer mask.
            return_attentions: When True, also return the attention tensors.

        Returns:
            ``(image_256, image_64)`` by default, or
            ``(image_256, image_64, attention_maps, initial_attention_weights)``
            when ``return_attentions`` is True.
        """
        # Fresh random noise for every call: (batch_size, noise_dim),
        # sampled on the same device as the token ids.
        noise = torch.randn(token_ids.size(0), self.noise_dim, device=token_ids.device)

        # 1. Per-token text features: (batch_size, seq_len, text_embed_dim)
        text_features = self.text_encoder(token_ids, attention_mask=attention_mask)

        # 2. The decoder computes both the initial context attention and the
        #    per-block cross-attention from the full encoder output.
        image_256, image_64, attention_maps, initial_weights = self.image_decoder(
            noise, text_features, attention_mask
        )

        if not return_attentions:
            return image_256, image_64
        return image_256, image_64, attention_maps, initial_weights



# Smoke test: run the untrained generator on two samples of the fixed batch
generator = Generator().to(device)
test_token_ids = sample_batch['text'][:2].to(device)
test_attention_mask = sample_batch['attention_mask'][:2].to(device)
with torch.no_grad():
    generated_images_256, generated_images_64 = generator(test_token_ids, test_attention_mask)
print(f"Generator output shape 256x256: {generated_images_256.shape}")
print(f"Generator output shape 64x64: {generated_images_64.shape}")

# Plot the outputs: 256x256 images on the top row, 64x64 images on the bottom row
plt.figure(figsize=(12, 8))
for i in range(2):
    # Top row: 256x256 output, mapped from [-1, 1] back to [0, 1]
    plt.subplot(2, 2, i+1)
    img = generated_images_256[i].cpu().add(1).div(2.0)
    plt.imshow(img.permute(1, 2, 0).clamp(0, 1))
    plt.title(f"Generated 256x256 Sample {i+1}")
    plt.axis('off')

    # Bottom row: 64x64 output, mapped from [-1, 1] back to [0, 1]
    plt.subplot(2, 2, i+3)
    img = generated_images_64[i].cpu().add(1).div(2.0)
    plt.imshow(img.permute(1, 2, 0).clamp(0, 1))
    plt.title(f"Generated 64x64 Sample {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.show()
print("✅ Generator test successful!")


In [None]:
# Enhanced Discriminator for dynamic image sizes (64x64 or 256x256) with AttnGAN-style dual outputs
class Discriminator256(nn.Module):
    """
    Discriminator with two heads: unconditional and text-conditioned.

    The image path is six stride-2 convolutions; the classifier heads flatten
    its output to ``512 * 4 * 4`` features, which matches 256x256 inputs.
    Both heads end with a sigmoid.

    Args:
        text_dim: Feature size of the text encoder output.
        img_channels: Number of input image channels.
        img_size: Stored on the instance; not used by the layers defined here.
    """

    def __init__(self, text_dim=256, img_channels=3, img_size=256):
        super().__init__()

        self.text_encoder = TextEncoder()
        self.img_size = img_size

        # Image path: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 (spatial size).
        # Every stage is Conv -> [BatchNorm] -> LeakyReLU; the first stage
        # has no BatchNorm.
        conv_layers = []
        in_ch = img_channels
        for stage, out_ch in enumerate((16, 32, 64, 128, 256, 512)):
            conv_layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            if stage > 0:
                conv_layers.append(nn.BatchNorm2d(out_ch))
            conv_layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch
        self.img_path = nn.Sequential(*conv_layers)

        # Projects the pooled text features to 512 dimensions
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512)
        )

        # Two output heads:
        # 1. Unconditional classifier (real/fake without text conditioning)
        self.unconditional_classifier = self._classifier_head(512 * 4 * 4)
        # 2. Conditional classifier (text-conditioned real/fake)
        self.conditional_classifier = self._classifier_head(512 * 4 * 4 + 512)

    @staticmethod
    def _classifier_head(in_features):
        """Build a Linear -> LeakyReLU -> Dropout -> Linear -> Sigmoid head."""
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        """
        Args:
            images: Batch of images.
            text_features: Token ids forwarded to the text encoder (optional).
            text_mask: Tokenizer attention mask for ``text_features`` (optional).
            return_both: When False, only the unconditional output is returned.

        Returns:
            The unconditional output alone when ``return_both`` is False,
            otherwise ``(unconditional_output, conditional_output)``. The
            conditional output is all zeros when no text (or no mask) is given.
        """
        # Encode the image and flatten to (B, features)
        img_features = self.img_path(images)
        img_flat = img_features.view(img_features.size(0), -1)

        # Unconditional output (real/fake without text)
        unconditional_output = self.unconditional_classifier(img_flat)
        if not return_both:
            return unconditional_output

        # Without text there is nothing to condition on: return zeros
        if text_features is None or text_mask is None:
            return unconditional_output, torch.zeros_like(unconditional_output)

        # Encode the text and mean-pool over the sequence dimension.
        # NOTE(review): the mean runs over all positions, padded ones included.
        token_states = self.text_encoder(text_features, text_mask)
        text_vector = self.text_path(token_states.mean(dim=1))

        # Conditional output from the concatenated image and text features
        conditional_output = self.conditional_classifier(
            torch.cat([img_flat, text_vector], dim=1)
        )
        return unconditional_output, conditional_output


class Discriminator64(nn.Module):
    """
    Discriminator for 64x64 images with two sigmoid heads:
    an unconditional real/fake score and a text-conditioned real/fake score.

    Args:
        text_dim: Input width of the text projection (``self.text_path``).
        img_channels: Number of channels of the input images.
    """

    def __init__(self, text_dim=256, img_channels=3):
        super().__init__()

        self.text_encoder = TextEncoder()

        # Four stride-2 convolutions halve the resolution each time: 64 -> 32 -> 16 -> 8 -> 4.
        # The first stage has no BatchNorm; the remaining three do.
        conv_stack = [
            nn.Conv2d(img_channels, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for in_ch, out_ch in ((16, 32), (32, 64), (64, 128)):
            conv_stack += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.img_path = nn.Sequential(*conv_stack)

        # Projects the pooled text encoding to a 512-d vector
        self.text_path = nn.Sequential(
            nn.Linear(text_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
        )

        # AttnGAN-style dual heads over the flattened 128 x 4 x 4 feature map;
        # the conditional head additionally receives the 512-d text vector.
        flat_image_dim = 128 * 4 * 4
        self.unconditional_classifier = self._score_head(flat_image_dim)
        self.conditional_classifier = self._score_head(flat_image_dim + 512)

    @staticmethod
    def _score_head(in_features):
        """Build a Linear -> LeakyReLU -> Dropout -> Linear -> Sigmoid scoring head."""
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, images, text_features=None, text_mask=None, return_both=True):
        """
        Score a batch of images, optionally conditioned on text.

        Returns the unconditional score alone when ``return_both`` is False, otherwise
        ``(unconditional, conditional)``. When either text input is None the conditional
        score is an all-zero tensor shaped like the unconditional one.
        """
        feature_map = self.img_path(images)
        flat_features = feature_map.view(feature_map.size(0), -1)

        # Image-only real/fake score
        uncond_score = self.unconditional_classifier(flat_features)
        if not return_both:
            return uncond_score

        # Nothing to condition on: fall back to a zero placeholder
        if text_features is None or text_mask is None:
            return uncond_score, torch.zeros_like(uncond_score)

        # Mean-pool the text encoder output over dim 1, then project it
        encoded_text = self.text_encoder(text_features, text_mask)
        text_vector = self.text_path(encoded_text.mean(dim=1))

        # Text-conditioned real/fake score on the concatenated image + text features
        joint_features = torch.cat([flat_features, text_vector], dim=1)
        return uncond_score, self.conditional_classifier(joint_features)


# Smoke-test both discriminators (AttnGAN-style dual outputs) on two generated samples
discriminator_256 = Discriminator256().to(device)
with torch.no_grad():
    # Generate test images first (the generator returns a 256x256 and a 64x64 batch)
    test_generated_images_256, test_generated_images_64 = generator(
        sample_batch['text'][:2].to(device),
        sample_batch['attention_mask'][:2].to(device)
    )

    # Test with 256x256 images - both outputs
    disc_unconditional_256, disc_conditional_256 = discriminator_256(
        test_generated_images_256,
        sample_batch['text'][:2].to(device),
        sample_batch['attention_mask'][:2].to(device),
        return_both=True
    )

    # Test unconditional only
    disc_unconditional_only = discriminator_256(
        test_generated_images_256,
        return_both=False
    )

print(f"Discriminator unconditional output shape (256x256): {disc_unconditional_256.shape}")
print(f"Discriminator conditional output shape (256x256): {disc_conditional_256.shape}")
print(f"Discriminator unconditional-only output shape: {disc_unconditional_only.shape}")
print(f"Test generated images shape (256x256): {test_generated_images_256.shape}")

# Test with 64x64 discriminator
discriminator_64 = Discriminator64().to(device)
with torch.no_grad():
    # Test with 64x64 images
    disc_unconditional_64, disc_conditional_64 = discriminator_64(
        test_generated_images_64,
        sample_batch['text'][:2].to(device),
        sample_batch['attention_mask'][:2].to(device),
        return_both=True
    )

print(f"\nDiscriminator unconditional output shape (64x64): {disc_unconditional_64.shape}")
print(f"Discriminator conditional output shape (64x64): {disc_conditional_64.shape}")
print(f"Test generated images shape (64x64): {test_generated_images_64.shape}")

# Show the architecture for both sizes
print("\nDiscriminator architecture for 256x256:")
print(f"Number of conv layers: {len([m for m in discriminator_256.img_path if isinstance(m, nn.Conv2d)])}")
print("Final feature map size: 512 x 4 x 4")
print("Outputs: Unconditional (real/fake) + Conditional (text-conditioned real/fake)")

print("\nDiscriminator architecture for 64x64:")
print(f"Number of conv layers: {len([m for m in discriminator_64.img_path if isinstance(m, nn.Conv2d)])}")
# Discriminator64's last convolution emits 128 channels (its heads take 128 * 4 * 4 inputs)
print("Final feature map size: 128 x 4 x 4")
print("Outputs: Unconditional (real/fake) + Conditional (text-conditioned real/fake)")

print("✅ AttnGAN-style discriminator with dual outputs now supports both 64x64 and 256x256 images!")

In [None]:
from torchvision import models
from torchvision.models import VGG19_Weights


class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG19 pretrained on ImageNet.

    The VGG19 ``features`` stack is cut into four consecutive slices and the L1
    distance between generated and reference activations is summed at the end
    of each slice:
      - slice "relu1_2": features[0:4]   -> ends at index 3  (relu1_2)
      - slice "relu2_2": features[4:9]   -> ends at index 8  (relu2_2)
      - slice "relu3_2": features[9:18]  -> ends at index 17 (relu3_4 in torchvision's layout)
      - slice "relu4_2": features[18:27] -> ends at index 26 (relu4_4 in torchvision's layout)

    NOTE: the last two slice names are historical. In torchvision's VGG19, relu3_2 and
    relu4_2 sit at indices 13 and 22; the slices here run to the end of conv blocks 3
    and 4. The names are kept so existing ``state_dict`` keys stay valid.

    Input images are in [-1,1]. We convert to [0,1], then normalize with ImageNet stats.
    """
    def __init__(self, device):
        super(VGGPerceptualLoss, self).__init__()
        vgg19_features = models.vgg19(weights=VGG19_Weights.DEFAULT).features.to(device).eval()
        # Materialise the layer list once; only layers up to index 26 are used
        layers = list(vgg19_features.children())
        self.slices = nn.ModuleDict({
            "relu1_2": nn.Sequential(*layers[:4]),     # conv1_1, relu1_1, conv1_2, relu1_2
            "relu2_2": nn.Sequential(*layers[4:9]),    # pool1, conv2_1, relu2_1, conv2_2, relu2_2
            "relu3_2": nn.Sequential(*layers[9:18]),   # pool2, conv3_1 ... relu3_4 (whole block 3)
            "relu4_2": nn.Sequential(*layers[18:27])   # pool3, conv4_1 ... relu4_4 (whole block 4)
        })
        # The VGG weights are a fixed feature extractor, never trained
        for param in self.parameters():
            param.requires_grad = False

        self.l1 = nn.L1Loss()
        # ImageNet channel statistics, shaped to broadcast over [B,3,H,W]
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1))

    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B,3,H,W] in range [-1,1].
        Return: sum of L1 distances between VGG feature maps at the end of each slice.
        """
        # Convert to [0,1]
        gen = (img_gen + 1.0) / 2.0
        ref = (img_ref + 1.0) / 2.0
        # Normalize
        gen_norm = (gen - self.mean) / self.std
        ref_norm = (ref - self.mean) / self.std

        loss = 0.0
        x_gen = gen_norm
        x_ref = ref_norm
        # Slices are chained: each one consumes the previous slice's output
        for slice_mod in self.slices.values():
            x_gen = slice_mod(x_gen)
            x_ref = slice_mod(x_ref)
            loss += self.l1(x_gen, x_ref)
        return loss


class SobelLoss(nn.Module):
    """
    Edge-similarity loss: L1 distance between the Sobel gradient-magnitude maps
    of two images, computed on their grayscale versions.
    """
    def __init__(self):
        super().__init__()
        # 3x3 Sobel operators (horizontal and vertical derivative), stored as conv kernels
        horizontal = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
        vertical = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
        self.register_buffer("kernel_x", torch.tensor(horizontal, dtype=torch.float32).view(1, 1, 3, 3))
        self.register_buffer("kernel_y", torch.tensor(vertical, dtype=torch.float32).view(1, 1, 3, 3))
        self.l1 = nn.L1Loss()

        # RGB -> luma weights (ITU-R BT.601), applied as a 1x1 convolution
        luma = torch.tensor([0.299, 0.587, 0.114])
        self.register_buffer("rgb_to_gray_weights", luma.view(1, 3, 1, 1))

    def _get_edges(self, img):
        """
        Compute the Sobel gradient magnitude of an RGB batch.

        Args:
            img: [B, 3, H, W] image tensor in range [-1, 1].
        Returns:
            Gradient magnitude map [B, 1, H, W].
        Raises:
            ValueError: if ``img`` is not 4-dimensional.
        """
        if img.dim() != 4:
            raise ValueError(f"Expected 4D input (got {img.dim()}D)")

        target_device = img.device

        # [-1, 1] -> [0, 1], then collapse RGB to one grayscale channel
        unit_range = (img + 1.0) / 2.0
        gray = F.conv2d(unit_range, self.rgb_to_gray_weights.to(target_device)) # type: ignore

        # Horizontal / vertical derivatives; padding=1 keeps the spatial size
        grad_x, grad_y = (
            F.conv2d(gray, kernel.to(target_device), padding=1) # type: ignore
            for kernel in (self.kernel_x, self.kernel_y)
        )

        # The epsilon keeps the square root away from exactly zero
        return torch.sqrt(grad_x**2 + grad_y**2 + 1e-6)

    def forward(self, img_gen, img_ref):
        """
        img_gen, img_ref: [B, 3, H, W] in range [-1, 1].
        Returns: L1 loss between the edge maps of the two images.
        """
        generated_edges = self._get_edges(img_gen)
        reference_edges = self._get_edges(img_ref)
        return self.l1(generated_edges, reference_edges)


In [None]:
# Enhanced Training utilities and visualization functions from utils.py
def weights_init(m):
    """
    DCGAN-style weight initialisation, meant to be used with ``module.apply(weights_init)``.

    Modules are matched by class name:
      - names containing "Conv": weight ~ N(0.0, 0.02)
      - names containing "BatchNorm": weight ~ N(1.0, 0.02), bias = 0

    Modules that match by name but carry no weight/bias tensor (e.g. a custom
    container whose name contains "Conv", or a norm layer built with
    ``affine=False``) are skipped instead of raising AttributeError.

    Args:
        m: The module to initialise in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight.data, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias.data, 0)

def denormalize_image(tensor):
    """
    Map an image tensor from the [-1, 1] range to [0, 1] for display.

    Args:
        tensor (torch.Tensor): Image tensor with values in [-1, 1].

    Returns:
        torch.Tensor: A new tensor rescaled to [0, 1]; out-of-range values are clamped.
    """
    return tensor.add(1).div(2).clamp(min=0, max=1)

def save_plot_losses(losses_g, losses_d, losses_recon=None, output_dir="training_output", show_inline=True):
    """
    Plot the generator and discriminator loss curves and save the figure as a PNG.

    Args:
        losses_g: Sequence of generator loss values.
        losses_d: Sequence of discriminator loss values.
        losses_recon: Optional sequence of reconstruction loss values (extra curve).
        output_dir: Directory for ``training_losses.png``; created if missing.
        show_inline: If True the figure is shown, otherwise it is closed after saving.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (values, legend label, colour) for every curve to draw
    curves = [
        (losses_g, "Generator Loss", "blue"),
        (losses_d, "Discriminator Loss", "red"),
    ]
    if losses_recon is not None:
        curves.append((losses_recon, "Reconstruction Loss", "green"))

    fig, ax = plt.subplots(figsize=(12, 6))
    for values, label, color in curves:
        ax.plot(values, label=label, color=color)
    ax.set(title="Training Losses", xlabel="Epochs", ylabel="Loss")
    ax.legend()
    ax.grid(True)

    save_path = os.path.join(output_dir, "training_losses.png")
    plt.savefig(save_path)
    print(f"Grafico delle loss salvato in: {save_path}")

    if not show_inline:
        plt.close(fig)
        return
    plt.show()

def save_plot_non_gan_losses(train_losses_history, val_losses_history, output_dir="training_output", show_inline=True, filter_losses=None):
    """
    Generates and saves plots of losses for non-GAN models with multiple loss components.

    One subplot is drawn per loss key (at most three per row), each showing the
    training curve and the validation curve.

    Args:
        train_losses_history: List of dicts containing training losses per epoch
                             e.g., [{'l1': 0.5, 'sobel': 0.3, 'ssim': 0.2}, ...]
        val_losses_history: List of dicts containing validation losses per epoch
        output_dir: Directory to save the plot (created if missing)
        show_inline: Whether to display the plot inline; otherwise the figure is closed
        filter_losses: Optional list of loss names to plot. If None, plots all losses.
                      e.g., ['l1', 'sobel'] to only plot those specific losses

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    os.makedirs(output_dir, exist_ok=True)

    if not train_losses_history or not val_losses_history:
        print("No loss history to plot")
        return

    # Extract all unique loss keys from both training and validation
    all_keys = set()
    for losses_dict in train_losses_history + val_losses_history:
        all_keys.update(losses_dict.keys())

    # 'epoch' is bookkeeping, not a loss
    loss_keys = [key for key in all_keys if key not in ['epoch']]

    # Apply filter if specified
    if filter_losses is not None:
        loss_keys = [key for key in loss_keys if key in filter_losses]

    loss_keys = sorted(loss_keys)  # Sort for consistent ordering

    if not loss_keys:
        print("No valid loss keys found")
        return

    # Create subplots
    n_losses = len(loss_keys)
    cols = min(3, n_losses)  # Max 3 columns
    rows = (n_losses + cols - 1) // cols  # Ceiling division

    # squeeze=False always yields a 2-D (rows, cols) array, so one indexing scheme
    # covers the 1x1, single-row and multi-row layouts alike.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    flat_axes = axes.ravel()  # row-major: position i is (i // cols, i % cols)

    fig.suptitle("Training and Validation Losses", fontsize=16, y=0.98)

    for ax, loss_key in zip(flat_axes, loss_keys):
        # Extract train and validation losses for this key (missing entries plot as 0)
        train_values = [losses.get(loss_key, 0) for losses in train_losses_history]
        val_values = [losses.get(loss_key, 0) for losses in val_losses_history]

        epochs_train = range(1, len(train_values) + 1)
        epochs_val = range(1, len(val_values) + 1)

        # Plot training and validation curves
        if train_values:
            ax.plot(epochs_train, train_values, label=f"Train {loss_key}", color="blue", linewidth=1.5)
        if val_values:
            ax.plot(epochs_val, val_values, label=f"Val {loss_key}", color="red", linewidth=1.5, linestyle='--')

        ax.set_title(f"{loss_key.capitalize()} Loss", fontsize=12)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Set y-axis to start from 0 for better visualization
        ax.set_ylim(bottom=0)

    # Hide unused subplots in the last row
    for unused_ax in flat_axes[n_losses:]:
        unused_ax.set_visible(False)

    plt.tight_layout()

    # Save the plot
    save_path = os.path.join(output_dir, "non_gan_training_losses.png")
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Non-GAN training losses plot saved to: {save_path}")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)


def save_comparison_grid(epoch, model, batch, set_name, device, output_dir="training_output", show_inline=True):
    """
    Generate and save/show a horizontal comparison grid (real images vs. generated images).

    The top row shows the real images from ``batch`` and the bottom row the images
    the model generates from the same text. If the model returns a tuple, the
    second element is used when ``set_name`` contains "64" (and the real images are
    resized to 64x64 to match); otherwise the first element is used.

    Args:
        epoch: Epoch number, used in the title and the output file name.
        model: Generator, called as ``model(token_ids, attention_mask)``.
        batch: Dict with keys "text", "attention_mask", "image", "idx" and "description".
        set_name: Label for the data split; also selects the 64x64 path (see above).
        device: Device the text inputs are moved to.
        output_dir: Directory for the PNG (created if missing).
        show_inline: If True the figure is shown, otherwise it is closed after saving.

    Side effects:
        Writes ``{epoch:03d}_{set_name}_comparison.png`` into ``output_dir``. The model is
        switched to eval mode for generation and put back into training mode afterwards
        if it was training when the function was called.
    """
    os.makedirs(output_dir, exist_ok=True)

    token_ids = batch["text"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    real_images = batch["image"]
    pokemon_ids = batch["idx"]
    descriptions = batch["description"]
    num_images = real_images.size(0)
    use_64 = "64" in set_name

    # Remember the mode so a mid-training visualisation does not leave the model in eval mode
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        with torch.no_grad():
            generated_images = model(token_ids, attention_mask)
            # Handle the case where generator returns both 256x256 and 64x64 images
            if isinstance(generated_images, tuple):
                if use_64:
                    generated_images = generated_images[1]  # Use 64x64 output
                    # Resize real images to 64x64 for comparison
                    real_images = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)
                else:
                    generated_images = generated_images[0]  # Use 256x256 output
    finally:
        if was_training:
            model.train()

    # squeeze=False keeps ``axs`` two-dimensional even for a single-image batch,
    # where plt.subplots(2, 1) would otherwise return a 1-D array.
    fig, axs = plt.subplots(2, num_images, figsize=(4 * num_images, 8.5), squeeze=False)
    resolution = "64x64" if use_64 else "256x256"
    fig.suptitle(
        f"Epoch {epoch} - {set_name.capitalize()} Comparison ({resolution})", fontsize=16, y=0.98
    )

    for i in range(num_images):
        # Row 0: real images
        ax_real = axs[0, i]
        ax_real.imshow(denormalize_image(real_images[i].cpu()).permute(1, 2, 0))
        ax_real.set_title(f"#{pokemon_ids[i]}: {descriptions[i][:35]}...", fontsize=10)
        ax_real.axis("off")

        # Row 1: generated images
        ax_gen = axs[1, i]
        ax_gen.imshow(denormalize_image(generated_images[i].cpu()).permute(1, 2, 0))
        ax_gen.axis("off")

    # Vertical row labels to the left of the first column
    for row, label in enumerate(("Real", "Generated")):
        axs[row, 0].text(
            -0.1,
            0.5,
            label,
            ha="center",
            va="center",
            rotation="vertical",
            fontsize=14,
            transform=axs[row, 0].transAxes,
        )

    plt.tight_layout(rect=(0, 0, 1, 0.95))

    # Always save the image
    save_path = os.path.join(output_dir, f"{epoch:03d}_{set_name}_comparison.png")
    fig.savefig(save_path)

    if show_inline:
        plt.show()
    else:
        plt.close(fig)

def save_attention_visualization(epoch, model, tokenizer, batch, device, set_name, output_dir="training_output", show_inline=True):
    """
    Generate and save a grid-style, multi-level attention visualization.
    Enhanced version from utils.py

    Only the first sample of ``batch`` is visualized. The figure shows:
    1. At the top, the generated image and the prompt.
    2. Below it, a bar chart of the initial attention (global context).
    3. Then a series of grids, one per decoder attention layer.
       Each grid shows the raw heat maps for every relevant token.

    Args:
        epoch: Epoch number, used in the title and the output file name.
        model: Generator (optionally wrapped in nn.DataParallel). It is called with
            ``return_attentions=True`` and must return four values:
            image, an unused value, attention maps, initial context weights.
        tokenizer: Tokenizer providing ``convert_ids_to_tokens``, ``sep_token`` and ``pad_token``.
        batch: Dict with keys "text", "attention_mask", "idx" and "description".
        device: Device the text inputs are moved to.
        set_name: Label for the data split, used in the title and file name.
        output_dir: Directory for the PNG (created if missing).
        show_inline: If True the figure is shown, otherwise it is closed after saving.

    Returns:
        None. Prints a message and returns early when no attention maps or no
        displayable tokens are available.

    NOTE(review): the model is switched to eval mode and not switched back here.
    """
    os.makedirs(output_dir, exist_ok=True)

    model.eval()

    with torch.no_grad():
        token_ids = batch["text"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        if token_ids.dim() > 1:  # Ensure a batch of 1
            token_ids = token_ids[0].unsqueeze(0)
            attention_mask = attention_mask[0].unsqueeze(0)

        pokemon_id = batch["idx"][0]
        description = batch["description"][0]

        # Unwrap DataParallel so the extra keyword argument reaches the generator directly
        model_to_use = model.module if isinstance(model, nn.DataParallel) else model
        generated_image, _, attention_maps, initial_context_weights = model_to_use(
            token_ids, attention_mask, return_attentions=True
        )

    # Entries that are None are dropped
    decoder_attention_maps = [m for m in attention_maps if m is not None]

    if not decoder_attention_maps or initial_context_weights is None:
        print(
            f"Epoch {epoch}: Mappe di attenzione non disponibili. Salto la visualizzazione."
        )
        return

    # Keep only tokens that are attended to (mask == 1) and are not SEP/PAD
    tokens_all = tokenizer.convert_ids_to_tokens(token_ids.squeeze(0))
    display_tokens = []
    for i, token in enumerate(tokens_all):
        if (
            token not in [tokenizer.sep_token, tokenizer.pad_token]
            and attention_mask[0, i] == 1
        ):
            display_tokens.append({"token": token, "index": i})

    if not display_tokens:
        print(
            f"Epoch {epoch}: Nessun token valido da visualizzare per '{description}'. Salto."
        )
        return

    token_indices_to_display = [t["index"] for t in display_tokens]
    img_tensor_cpu = denormalize_image(generated_image.squeeze(0).cpu()).permute(
        1, 2, 0
    )
    num_decoder_layers = len(decoder_attention_maps)
    num_tokens = len(display_tokens)

    # --- Plot creation ---
    # Compute the layout and sizes dynamically
    cols = min(num_tokens, 8)
    rows_per_layer = (num_tokens + cols - 1) // cols
    num_main_rows = (
        2 + num_decoder_layers
    )  # Image, bar chart, and N attention layers
    # Height for the image, the bar chart, and then for every row of every layer
    height_ratios = [3, 2] + [2 * rows_per_layer] * num_decoder_layers
    fig_height = sum(height_ratios)
    fig_width = max(20, 2.5 * cols)

    fig = plt.figure(figsize=(fig_width, fig_height))
    gs_main = fig.add_gridspec(
        num_main_rows, 1, height_ratios=height_ratios, hspace=1.2
    )
    fig.suptitle(
        f"Epoch {epoch}: Attention Visualization for Pokémon #{pokemon_id} ({set_name.capitalize()})",
        fontsize=24,
    )

    # --- 1. Generated image and prompt ---
    ax_main_img = fig.add_subplot(gs_main[0])
    ax_main_img.imshow(img_tensor_cpu)
    ax_main_img.set_title("Generated Image", fontsize=18)
    ax_main_img.text(
        0.5,
        -0.1,
        f"Prompt: {description}",
        ha="center",
        va="top",
        transform=ax_main_img.transAxes,
        fontsize=14,
        wrap=True,
    )
    ax_main_img.axis("off")

    # --- 2. Initial context attention (bar chart) ---
    ax_initial_attn = fig.add_subplot(gs_main[1])
    initial_weights_squeezed = initial_context_weights.squeeze().cpu().numpy()
    token_strings = [t["token"] for t in display_tokens]
    token_indices = [t["index"] for t in display_tokens]
    relevant_weights = initial_weights_squeezed[token_indices]
    ax_initial_attn.bar(
        np.arange(len(token_strings)), relevant_weights, color="skyblue"
    )
    ax_initial_attn.set_xticks(np.arange(len(token_strings)))
    ax_initial_attn.set_xticklabels(token_strings, rotation=45, ha="right", fontsize=10)
    ax_initial_attn.set_title("Initial Context Attention (Global)", fontsize=16)
    ax_initial_attn.set_ylabel("Weight", fontsize=12)
    ax_initial_attn.grid(axis="y", linestyle="--", alpha=0.7)

    # --- 3. Attention per decoder layer (heat-map grids) ---
    for i, layer_attn_map in enumerate(decoder_attention_maps):
        # Dim 1 of the map is treated as a flattened square grid of spatial positions
        map_size_flat = layer_attn_map.shape[1]
        map_side = int(np.sqrt(map_size_flat))
        layer_title = (
            f"Decoder Cross-Attention Layer {i + 1} (Size: {map_side}x{map_side})"
        )
        layer_attn_map_squeezed = layer_attn_map.squeeze(0).cpu()

        # Select only the attention maps of the tokens being displayed
        relevant_attn_maps = layer_attn_map_squeezed[:, token_indices_to_display]

        # Find this layer's min/max values so all its heat maps share one colour scale
        vmin = relevant_attn_maps.min()
        vmax = relevant_attn_maps.max()

        # Create a subgrid for this layer (with one extra column for the colorbar)
        gs_layer = gs_main[2 + i].subgridspec(
            rows_per_layer,
            cols + 1,
            wspace=0.2,
            hspace=0.4,
            width_ratios=[*([1] * cols), 0.1],
        )

        # Create all the axes of the grid
        axes_in_layer = [
            fig.add_subplot(gs_layer[r, c])
            for r in range(rows_per_layer)
            for c in range(cols)
        ]

        # Use the position of the first axis to place the layer title
        if axes_in_layer:
            y_pos = axes_in_layer[0].get_position().y1
            fig.text(
                0.5,
                y_pos + 0.01,
                layer_title,
                ha="center",
                va="bottom",
                fontsize=16,
                weight="bold",
            )

        for j, token_info in enumerate(display_tokens):
            ax = axes_in_layer[j]
            attn_for_token = layer_attn_map_squeezed[:, token_info["index"]]
            heatmap = attn_for_token.reshape(map_side, map_side)
            im = ax.imshow(
                heatmap, cmap="jet", interpolation="nearest", vmin=vmin, vmax=vmax
            )
            ax.set_title(f"'{token_info['token']}'", fontsize=12)
            ax.axis("off")

        # Add the colorbar (driven by the last heat map drawn; all share vmin/vmax)
        cax = fig.add_subplot(gs_layer[:, -1])
        cbar = fig.colorbar(im, cax=cax)
        cbar.ax.tick_params(labelsize=10)
        cbar.set_label("Attention Weight", rotation=270, labelpad=15, fontsize=12)

        # Blank out the unused axes of the grid
        for j in range(num_tokens, len(axes_in_layer)):
            axes_in_layer[j].axis("off")

    plt.tight_layout(rect=(0, 0.03, 1, 0.96))
    save_path = os.path.join(
        output_dir, f"{epoch:03d}_{set_name}_attention_visualization.png"
    )
    plt.savefig(save_path, bbox_inches="tight")

    if show_inline:
        plt.show()
    else:
        plt.close(fig)

def save_checkpoint(generator, discriminator, g_optimizer, d_optimizer, epoch, losses, path):
    """
    Serialize the training state to ``path`` with ``torch.save``.

    The checkpoint dict holds the epoch, the ``state_dict()`` of both networks and
    both optimizers, and the ``losses`` object as given.

    Args:
        generator, discriminator: Objects exposing ``state_dict()``.
        g_optimizer, d_optimizer: Their optimizers, also exposing ``state_dict()``.
        epoch: Epoch value stored under 'epoch'.
        losses: Loss history stored unchanged under 'losses'.
        path: Destination handed to ``torch.save``.
    """
    stateful_parts = (
        ('generator_state_dict', generator),
        ('discriminator_state_dict', discriminator),
        ('g_optimizer_state_dict', g_optimizer),
        ('d_optimizer_state_dict', d_optimizer),
    )

    payload = {'epoch': epoch}
    payload.update((key, part.state_dict()) for key, part in stateful_parts)
    payload['losses'] = losses

    torch.save(payload, path)
    print(f"Checkpoint saved to {path}")



# Build the generator and one discriminator per output resolution
generator = Generator().to(device)
discriminator_256 = Discriminator256(img_size=256).to(device)  # judges the 256x256 output
discriminator_64 = Discriminator64().to(device)    # judges the 64x64 output

# Re-initialise every network with weights_init (generator first, then 256, then 64)
for _network in (generator, discriminator_256, discriminator_64):
    _network.apply(weights_init)

print("✅ Enhanced discriminators initialized with dynamic layer calculation!")

# Adam settings shared by all three optimizers
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
_adam_kwargs = {"lr": lr, "betas": (beta1, beta2)}

optimizer_G = optim.Adam(generator.parameters(), **_adam_kwargs)
optimizer_D_256 = optim.Adam(discriminator_256.parameters(), **_adam_kwargs)
optimizer_D_64 = optim.Adam(discriminator_64.parameters(), **_adam_kwargs)

# Criteria: adversarial (BCE on sigmoid outputs) plus the image-space losses
adv_criterion = nn.BCELoss().to(device)
l1_criterion = nn.L1Loss().to(device)
perc_criterion = VGGPerceptualLoss(device)
sobel_criterion = SobelLoss().to(device)

print("Models and optimizers initialized successfully!")

# Trainable parameter counts per network, then the grand total
_trainable_counts = {
    _label: sum(p.numel() for p in _module.parameters() if p.requires_grad)
    for _label, _module in (
        ("Generator", generator),
        ("Discriminator 256", discriminator_256),
        ("Discriminator 64", discriminator_64),
    )
}
for _label, _count in _trainable_counts.items():
    print(f"{_label} parameters: {_count:,}")
print(f"Total parameters: {sum(_trainable_counts.values()):,}")


In [None]:
# Make sure the 'models' directory exists
os.makedirs('models', exist_ok=True)

# Per-key training history; every key starts with its own empty list
losses = {
    name: []
    for name in ('generator', 'discriminator', 'l1', 'perceptual', 'sobel')
}

# Validation history is tracked separately and has no adversarial entries
val_losses = {
    name: []
    for name in ('l1', 'perceptual', 'sobel', 'total')
}

def validate_model(generator, val_loader, device, l1_criterion, perc_criterion, sobel_criterion):
    """
    Evaluate the generator on the validation set (reconstruction losses only).

    Each batch is generated from its text and compared against the real images with
    the L1, perceptual and Sobel criteria; no adversarial loss is computed. Only the
    first output of the generator is evaluated.

    Args:
        generator: Model called as ``generator(text_ids, attention_mask)``, returning two outputs.
        val_loader: Iterable of dict batches with keys 'image', 'text' and 'attention_mask'.
        device: Device the batch tensors are moved to.
        l1_criterion, perc_criterion, sobel_criterion: Callables ``(generated, real) -> scalar tensor``.

    Returns:
        dict with the per-batch averages under 'l1', 'perceptual' and 'sobel', plus
        'total' (their unweighted sum). If ``val_loader`` yields no batches, every
        value is NaN instead of raising ZeroDivisionError.

    Side effects:
        The generator is put in eval mode during validation and always switched
        back to training mode afterwards, even if a batch raises.
    """
    generator.eval()

    val_l1_loss = 0.0
    val_perc_loss = 0.0
    val_sobel_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in val_loader:
                # Move data to device
                real_images = batch['image'].to(device)
                text_ids = batch['text'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # Generate images (second generator output is not evaluated)
                generated_images, _ = generator(text_ids, attention_mask)

                # Calculate validation losses (no adversarial loss)
                val_l1_loss += l1_criterion(generated_images, real_images).item()
                val_perc_loss += perc_criterion(generated_images, real_images).item()
                val_sobel_loss += sobel_criterion(generated_images, real_images).item()
                num_batches += 1
    finally:
        # Set the model back to training mode no matter how validation ended
        generator.train()

    if num_batches == 0:
        # Nothing was evaluated: report NaN rather than dividing by zero
        nan = float('nan')
        return {'l1': nan, 'perceptual': nan, 'sobel': nan, 'total': nan}

    # Calculate averages
    avg_val_l1 = val_l1_loss / num_batches
    avg_val_perc = val_perc_loss / num_batches
    avg_val_sobel = val_sobel_loss / num_batches
    avg_val_total = avg_val_l1 + avg_val_perc + avg_val_sobel

    return {
        'l1': avg_val_l1,
        'perceptual': avg_val_perc,
        'sobel': avg_val_sobel,
        'total': avg_val_total
    }

# Starting epoch for the training loop and a noise-dimension setting
epoch, noise_dim = 0, 100


In [None]:
from IPython.display import clear_output

# Schedule settings (how each interval is used is decided by the training loop)
total_epochs = 150 # Reduced for faster training in demo
display_interval, save_interval, clear_interval = 1, 15, 22

# Loss weights for AttnGAN-style training (no generator update intervals):
# L1 and adversarial terms are on, perceptual and Sobel terms are off.
lambda_l1 = lambda_adv = 1.0
lambda_perceptual = lambda_sobel = 0.0

# Target values used for real and fake samples
real_label, fake_label = 1.0, 0.0

def create_mismatched_text_batch(text_ids, attention_mask):
    """
    Build a "wrong text" batch by reassigning each sample another sample's text.

    A plain random permutation can map a sample to itself, which would present a
    correctly matched pair as mismatched. Instead, the rows are reordered along a
    single random cycle, so for a batch of two or more no row keeps its own text.
    A batch of one cannot be mismatched and is returned as a copy of itself.

    Args:
        text_ids: Tensor whose first dimension is the batch.
        attention_mask: Tensor with the same batch dimension; reordered identically.

    Returns:
        Tuple ``(wrong_text_ids, wrong_attention_mask)``.
    """
    batch_size = text_ids.size(0)

    # Random visiting order; each row takes the text of the next row in that order,
    # forming one cycle through the whole batch (hence no fixed points).
    order = torch.randperm(batch_size)
    indices = torch.empty_like(order)
    indices[order] = order.roll(-1)
    indices = indices.to(text_ids.device)

    return text_ids[indices], attention_mask[indices]

def compute_discriminator_loss(discriminator, real_images, fake_images,
                              text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,
                              real_labels, fake_labels, adv_criterion):
    """
    AttnGAN-style discriminator loss: the mean of five adversarial terms.

    Terms (dict keys in brackets):
      - real image, unconditional head, target real        ['real_uncond']
      - real image + matching text, target real            ['real_cond']
      - real image + mismatched text, target fake          ['real_wrong']
      - fake image, unconditional head, target fake        ['fake_uncond']
      - fake image + matching text, target fake            ['fake_cond']

    The fake images are detached, so no gradient reaches the generator.

    Returns:
        ``(total_loss, components)``: the averaged loss tensor and a dict of the
        five individual terms as Python floats.
    """
    terms = {}

    # Pass 1: real images paired with their own text
    uncond_real, cond_real = discriminator(real_images, text_ids, attention_mask, return_both=True)
    terms['real_uncond'] = adv_criterion(uncond_real, real_labels)
    terms['real_cond'] = adv_criterion(cond_real, real_labels)

    # Pass 2: real images paired with the wrong text (only the conditional head is used)
    _, cond_mismatch = discriminator(real_images, wrong_text_ids, wrong_attention_mask, return_both=True)
    terms['real_wrong'] = adv_criterion(cond_mismatch, fake_labels)

    # Pass 3: generated images paired with the text they were generated from
    uncond_fake, cond_fake = discriminator(fake_images.detach(), text_ids, attention_mask, return_both=True)
    terms['fake_uncond'] = adv_criterion(uncond_fake, fake_labels)
    terms['fake_cond'] = adv_criterion(cond_fake, fake_labels)

    # Equal-weight average of the five terms
    total_loss = (
        terms['real_uncond'] + terms['real_cond'] + terms['real_wrong']
        + terms['fake_uncond'] + terms['fake_cond']
    ) / 5

    # Scalar copies of each term for logging
    components = {name: term.item() for name, term in terms.items()}

    return total_loss, components

def compute_generator_adversarial_loss(discriminator, fake_images, text_ids, attention_mask,
                                     real_labels, adv_criterion):
    """
    Generator-side adversarial loss against a single discriminator.

    Both discriminator heads are scored against the *real* labels (the generator
    wants its images judged real), and the two terms are averaged. The fake
    images are not detached, so gradients flow back to the generator.
    """
    head_scores = discriminator(fake_images, text_ids, attention_mask, return_both=True)
    uncond_term, cond_term = (adv_criterion(score, real_labels) for score in head_scores)
    return (uncond_term + cond_term) / 2

def compute_reconstruction_losses(fake_images_256, real_images,
                                l1_criterion, perc_criterion, sobel_criterion,
                                lambda_l1, lambda_perceptual, lambda_sobel, device):
    """Evaluate the L1, perceptual and Sobel criteria on generated vs. real images.

    A criterion is only evaluated when its weight is strictly positive;
    otherwise a zero scalar tensor on ``device`` stands in for it, so the
    caller can always call ``.item()`` and multiply by the weight.

    Returns:
        Tuple ``(l1_loss, perc_loss, sobel_loss)`` of unweighted loss tensors.
    """
    def _loss_or_zero(criterion, weight):
        # Skip the criterion entirely when its term is switched off
        if weight > 0:
            return criterion(fake_images_256, real_images)
        return torch.tensor(0.0, device=device)

    l1_loss = _loss_or_zero(l1_criterion, lambda_l1)
    perc_loss = _loss_or_zero(perc_criterion, lambda_perceptual)
    sobel_loss = _loss_or_zero(sobel_criterion, lambda_sobel)
    return l1_loss, perc_loss, sobel_loss

# Announce the run configuration before training starts
print("Starting AttnGAN-style training with dual discriminators...")
print(f"Device: {device}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Batch size: {dataloader.batch_size}")
print(f"Total epochs: {total_epochs}")
print(f"Using dual discriminators: 64x64 and 256x256")
print("-" * 50)

# Main training loop. `epoch` must already hold the starting epoch index
# (set by an earlier cell); the loop variable then overwrites it each epoch.
for epoch in range(epoch, total_epochs):
    epoch_g_loss = 0.0
    epoch_d_loss_64 = 0.0
    epoch_d_loss_256 = 0.0
    epoch_l1_loss = 0.0
    epoch_perc_loss = 0.0
    epoch_sobel_loss = 0.0

    # Track discriminator loss components
    epoch_d256_components = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_wrong': 0.0, 'fake_uncond': 0.0, 'fake_cond': 0.0}
    epoch_d64_components = {'real_uncond': 0.0, 'real_cond': 0.0, 'real_wrong': 0.0, 'fake_uncond': 0.0, 'fake_cond': 0.0}

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{total_epochs}")

    for i, batch in enumerate(progress_bar):
        batch_size = batch['image'].size(0)

        # Move data to device
        real_images = batch['image'].to(device)
        text_ids = batch['text'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Create mismatched text and labels
        wrong_text_ids, wrong_attention_mask = create_mismatched_text_batch(text_ids, attention_mask)
        real_labels = torch.full((batch_size, 1), real_label, device=device, dtype=torch.float)
        fake_labels = torch.full((batch_size, 1), fake_label, device=device, dtype=torch.float)

        # Generate fake images at both resolutions; the real 64x64 targets are
        # obtained by downsampling the real 256x256 batch
        fake_images_256, fake_images_64 = generator(text_ids, attention_mask)
        real_images_64 = F.interpolate(real_images, size=(64, 64), mode='bilinear', align_corners=False)

        # ==========================================
        # Train Both Discriminators
        # ==========================================
        # Clear BOTH discriminators' gradients. The generator step below
        # back-propagates through the discriminators and leaves gradients on
        # their parameters, so each one must be zeroed before its own update;
        # otherwise stale gradients accumulate into every optimizer step.
        optimizer_D_256.zero_grad()
        optimizer_D_64.zero_grad()

        # 256x256 discriminator loss
        d_loss_256, d256_components = compute_discriminator_loss(
            discriminator_256, real_images, fake_images_256,
            text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,
            real_labels, fake_labels, adv_criterion
        )
        d_loss_256.backward()

        # 64x64 discriminator loss
        d_loss_64, d64_components = compute_discriminator_loss(
            discriminator_64, real_images_64, fake_images_64,
            text_ids, attention_mask, wrong_text_ids, wrong_attention_mask,
            real_labels, fake_labels, adv_criterion
        )
        d_loss_64.backward()

        # Update discriminators
        optimizer_D_256.step()
        optimizer_D_64.step()

        # ==========================================
        # Train Generator
        # ==========================================
        optimizer_G.zero_grad()

        # Adversarial losses for both discriminators (fake images are NOT
        # detached here, so gradients reach the generator)
        g_adv_loss_256 = compute_generator_adversarial_loss(
            discriminator_256, fake_images_256, text_ids, attention_mask, real_labels, adv_criterion
        )
        g_adv_loss_64 = compute_generator_adversarial_loss(
            discriminator_64, fake_images_64, text_ids, attention_mask, real_labels, adv_criterion
        )
        adversarial_loss = (g_adv_loss_256 + g_adv_loss_64) / 2

        # Reconstruction losses (computed on the 256x256 output only)
        l1_loss, perc_loss, sobel_loss = compute_reconstruction_losses(
            fake_images_256, real_images, l1_criterion, perc_criterion, sobel_criterion,
            lambda_l1, lambda_perceptual, lambda_sobel, device
        )

        # Total generator loss: weighted sum of adversarial and reconstruction terms
        g_loss = (lambda_adv * adversarial_loss + lambda_l1 * l1_loss +
                 lambda_perceptual * perc_loss + lambda_sobel * sobel_loss)
        g_loss.backward()
        optimizer_G.step()

        # Update loss tracking
        epoch_g_loss += g_loss.item()
        epoch_d_loss_256 += d_loss_256.item()
        epoch_d_loss_64 += d_loss_64.item()
        epoch_l1_loss += l1_loss.item()
        epoch_perc_loss += perc_loss.item()
        epoch_sobel_loss += sobel_loss.item()

        # Update discriminator component tracking
        for key in epoch_d256_components:
            epoch_d256_components[key] += d256_components[key]
            epoch_d64_components[key] += d64_components[key]

        # Update progress bar with detailed discriminator components
        progress_bar.set_postfix({
            'D256': f'{d_loss_256.item():.3f}',
            'D256_r_u': f'{d256_components["real_uncond"]:.3f}',
            'D256_r_c': f'{d256_components["real_cond"]:.3f}',
            'D256_rw': f'{d256_components["real_wrong"]:.3f}',
            'D256_f_u': f'{d256_components["fake_uncond"]:.3f}',
            'D256_f_c': f'{d256_components["fake_cond"]:.3f}',
            'D64': f'{d_loss_64.item():.3f}',
            'D64_r_u': f'{d64_components["real_uncond"]:.3f}',
            'D64_r_c': f'{d64_components["real_cond"]:.3f}',
            'D64_rw': f'{d64_components["real_wrong"]:.3f}',
            'D64_f_u': f'{d64_components["fake_uncond"]:.3f}',
            'D64_f_c': f'{d64_components["fake_cond"]:.3f}',
            'G': f'{g_loss.item():.3f}',
            'L1': f'{l1_loss.item():.3f}',
            'Adv': f'{adversarial_loss.item():.3f}'
        })

    # Calculate average losses for the epoch (per-batch means)
    num_batches = len(dataloader)
    avg_g_loss = epoch_g_loss / num_batches
    avg_d_loss_256 = epoch_d_loss_256 / num_batches
    avg_d_loss_64 = epoch_d_loss_64 / num_batches
    avg_l1_loss = epoch_l1_loss / num_batches
    avg_perc_loss = epoch_perc_loss / num_batches
    avg_sobel_loss = epoch_sobel_loss / num_batches

    # Calculate average discriminator components
    avg_d256_components = {key: val / num_batches for key, val in epoch_d256_components.items()}
    avg_d64_components = {key: val / num_batches for key, val in epoch_d64_components.items()}

    # Store losses (the two discriminator losses are combined into one curve)
    losses['generator'].append(avg_g_loss)
    losses['discriminator'].append((avg_d_loss_256 + avg_d_loss_64) / 2)
    losses['l1'].append(avg_l1_loss)
    losses['perceptual'].append(avg_perc_loss)
    losses['sobel'].append(avg_sobel_loss)

    # Run validation
    print(f"Running validation for epoch {epoch+1}...")
    validation_results = validate_model(generator, val_loader, device,
                                      l1_criterion, perc_criterion, sobel_criterion)

    # Store validation losses
    val_losses['l1'].append(validation_results['l1'])
    val_losses['perceptual'].append(validation_results['perceptual'])
    val_losses['sobel'].append(validation_results['sobel'])
    val_losses['total'].append(validation_results['total'])

    # Periodically wipe the cell output so the notebook does not grow unbounded
    if (epoch + 1) % clear_interval == 0:
        clear_output(wait=True)

    print(f"Epoch [{epoch+1}/{total_epochs}]")
    print(f"  Train - D_256: {avg_d_loss_256:.4f}, D_64: {avg_d_loss_64:.4f}, G_loss: {avg_g_loss:.4f}")
    print(f"  D_256 Components - RU: {avg_d256_components['real_uncond']:.4f}, RC: {avg_d256_components['real_cond']:.4f}, RW: {avg_d256_components['real_wrong']:.4f}, FU: {avg_d256_components['fake_uncond']:.4f}, FC: {avg_d256_components['fake_cond']:.4f}")
    print(f"  D_64 Components  - RU: {avg_d64_components['real_uncond']:.4f}, RC: {avg_d64_components['real_cond']:.4f}, RW: {avg_d64_components['real_wrong']:.4f}, FU: {avg_d64_components['fake_uncond']:.4f}, FC: {avg_d64_components['fake_cond']:.4f}")
    print(f"  Train - L1: {avg_l1_loss:.4f}, Perceptual: {avg_perc_loss:.4f}, Sobel: {avg_sobel_loss:.4f}")
    print(f"  Val   - L1: {validation_results['l1']:.4f}, Perceptual: {validation_results['perceptual']:.4f}, Sobel: {validation_results['sobel']:.4f}, Total: {validation_results['total']:.4f}")
    print(f"  Legend: RU=RealUncond, RC=RealCond, RW=RealWrong, FU=FakeUncond, FC=FakeCond")

    # Display generated images
    if (epoch + 1) % display_interval == 0:
        print(f"\nGenerating sample images at epoch {epoch+1}:")
        print("256x256 Training Images:")
        save_comparison_grid(epoch+1, generator, fixed_train_batch, "train_256", device, show_inline=True)
        print("64x64 Training Images:")
        save_comparison_grid(epoch+1, generator, fixed_train_batch, "train_64", device, show_inline=True)

    # Save checkpoint
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = f'models/checkpoint_epoch_{epoch+1}.pth'
        # torch.save does not create missing directories
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        all_losses = {'train': losses, 'val': val_losses}
        checkpoint = {
            # 0-based index of the epoch that has just finished
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_256_state_dict': discriminator_256.state_dict(),
            'discriminator_64_state_dict': discriminator_64.state_dict(),
            'g_optimizer_state_dict': optimizer_G.state_dict(),
            # NOTE: this key holds the 256x256 discriminator's optimizer state
            'd_optimizer_state_dict': optimizer_D_256.state_dict(),
            'd_64_optimizer_state_dict': optimizer_D_64.state_dict(),
            'losses': all_losses
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        print("256x256 Validation Images:")
        save_comparison_grid(epoch+1, generator, fixed_val_batch, "val_256", device, show_inline=True)
        print("64x64 Validation Images:")
        save_comparison_grid(epoch+1, generator, fixed_val_batch, "val_64", device, show_inline=True)
        save_attention_visualization(epoch+1, generator, tokenizer, fixed_train_batch, device, "train", show_inline=True)
        save_attention_visualization(epoch+1, generator, tokenizer, fixed_val_batch, device, "val", show_inline=True)

print("\nAttnGAN-style training completed!")

In [None]:
# Adversarial curves: generator vs. (combined) discriminator loss per epoch
save_plot_losses(losses_g=losses['generator'], losses_d=losses['discriminator'],
                 output_dir="training_output", show_inline=True)

# save_plot_non_gan_losses expects one dict per epoch, so pivot the
# per-metric lists into per-epoch records. The training 'total' is the plain
# (unweighted) sum of the three reconstruction terms.
train_losses_history = [
    {
        'l1': losses['l1'][ep],
        'perceptual': losses['perceptual'][ep],
        'sobel': losses['sobel'][ep],
        'total': losses['l1'][ep] + losses['perceptual'][ep] + losses['sobel'][ep],
    }
    for ep in range(len(losses['l1']))
]

# Validation already carries its own 'total', so it is copied through as-is
val_losses_history = [
    {metric: val_losses[metric][ep] for metric in ('l1', 'perceptual', 'sobel', 'total')}
    for ep in range(len(val_losses['l1']))
]

save_plot_non_gan_losses(train_losses_history=train_losses_history,
                         val_losses_history=val_losses_history,
                         output_dir="training_output", show_inline=True)

# Summarise the last recorded epoch, if there is one
if not losses['generator']:
    print("No training losses recorded yet.")
else:
    print("\n".join(
        f"Final Train - {label} Loss: {losses[metric][-1]:.4f}"
        for label, metric in (
            ("Generator", 'generator'),
            ("Discriminator", 'discriminator'),
            ("L1", 'l1'),
            ("Perceptual", 'perceptual'),
            ("Sobel", 'sobel'),
        )
    ))

    if val_losses['l1']:
        print("\n".join(
            f"Final Val   - {label} Loss: {val_losses[metric][-1]:.4f}"
            for label, metric in (
                ("L1", 'l1'),
                ("Perceptual", 'perceptual'),
                ("Sobel", 'sobel'),
                ("Total", 'total'),
            )
        ))


In [None]:
# Final results: draw one batch from the training dataloader and render the
# comparison grid for it under both set names (256x256 first, then 64x64)
print("Final Results - Generated Pokemon Sprites (256x256):")
batch = next(iter(dataloader))
save_comparison_grid(epoch=0, model=generator, batch=batch, set_name="final_256",
                     device=device, show_inline=True)

print("Final Results - Generated Pokemon Sprites (64x64):")
save_comparison_grid(epoch=0, model=generator, batch=batch, set_name="final_64",
                     device=device, show_inline=True)


In [None]:
# Enhanced interactive generation function with attention visualization
def generate_pokemon_from_text(description, num_samples=4, show_attention=False):
    """Generate Pokemon sprites from a custom text description and display them.

    The description is tokenized (padded/truncated to 128 tokens), repeated
    ``num_samples`` times and fed to the global ``generator``. The 256x256 and
    64x64 outputs are each handed to ``save_comparison_grid`` (output directory
    ``custom_generation``). With ``show_attention=True`` the generator is called
    with ``return_attentions=True`` and, if it returns attention maps, a
    single-sample attention visualization is rendered as well.

    Args:
        description: text prompt describing the Pokemon.
        num_samples: number of copies of the prompt in the batch; must be >= 1.
        show_attention: whether to also render the attention visualization.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.

    Side effects:
        Runs the generator in eval mode under ``torch.no_grad()`` and restores
        whatever train/eval mode the generator was in beforehand, even if
        generation or visualization raises.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Remember the caller's mode instead of unconditionally switching to
    # train() afterwards, which would silently undo a prior generator.eval()
    was_training = generator.training
    generator.eval()

    try:
        with torch.no_grad():
            # Tokenize the description
            tokens = tokenizer(
                description,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Repeat for multiple samples
            text_ids = tokens['input_ids'].repeat(num_samples, 1).to(device)
            attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(device)

            # Generate images (generator handles text encoding internally)
            attention_maps = None
            if show_attention:
                generated_images_256, generated_images_64, attention_maps, _ = generator(
                    text_ids, attention_mask, return_attentions=True
                )
            else:
                generated_images_256, generated_images_64 = generator(text_ids, attention_mask)

            def _display_batch(images):
                # Batch dict in the layout the visualization helpers are given;
                # the generated images stand in for the "real" ones
                return {
                    'text': text_ids,
                    'attention_mask': attention_mask,
                    'image': images,
                    'description': [description] * num_samples,
                    'pokemon_name': [f"Generated_{i+1}" for i in range(num_samples)],
                    'idx': list(range(num_samples))
                }

            # Use enhanced comparison grid for both resolutions
            for header, images, set_name in (
                ("256x256 Generated Pokemon:", generated_images_256, "custom_256"),
                ("64x64 Generated Pokemon:", generated_images_64, "custom_64"),
            ):
                print(header)
                save_comparison_grid(
                    epoch=0,
                    model=generator,
                    batch=_display_batch(images),
                    set_name=set_name,
                    device=device,
                    output_dir="custom_generation",
                    show_inline=True
                )

            # Show attention visualization if requested
            if show_attention and attention_maps is not None:
                print("\nGenerating attention visualization...")
                # Create single-sample batch for attention visualization
                single_batch = {
                    'text': text_ids[:1],
                    'attention_mask': attention_mask[:1],
                    'description': [description],
                    'idx': [0]
                }
                save_attention_visualization(
                    epoch=0,
                    model=generator,
                    tokenizer=tokenizer,
                    batch=single_batch,
                    device=device,
                    set_name="custom",
                    output_dir="custom_generation",
                    show_inline=True
                )
    finally:
        # Put the generator back in the mode it was in when we were called
        generator.train(was_training)

# Simple visualization function for basic usage
def simple_generate_pokemon(description, num_samples=4):
    """Generate sprites for a text prompt and plot them with plain matplotlib.

    The prompt is tokenized (padded/truncated to 128 tokens), repeated
    ``num_samples`` times and passed to the global ``generator``. The result is
    shown as a 2 x ``num_samples`` figure: 256x256 outputs on the top row,
    64x64 outputs on the bottom row.

    Args:
        description: text prompt describing the Pokemon.
        num_samples: number of columns / copies of the prompt; must be >= 1.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.

    Side effects:
        Runs the generator in eval mode under ``torch.no_grad()`` and restores
        whatever train/eval mode the generator was in beforehand, even if
        generation or plotting raises.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # Remember the caller's mode instead of unconditionally switching to
    # train() afterwards, which would silently undo a prior generator.eval()
    was_training = generator.training
    generator.eval()

    try:
        with torch.no_grad():
            tokens = tokenizer(
                description,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            text_ids = tokens['input_ids'].repeat(num_samples, 1).to(device)
            attention_mask = tokens['attention_mask'].repeat(num_samples, 1).to(device)
            generated_images_256, generated_images_64 = generator(text_ids, attention_mask)

            # Simple matplotlib visualization for both resolutions
            fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 3, 6))
            if num_samples == 1:
                # With a single column subplots() returns a 1-D array;
                # restore the (row, column) layout used below
                axes = axes.reshape(2, 1)

            for i in range(num_samples):
                # Row 0: 256x256 images, row 1: 64x64 images.
                # permute(1, 2, 0) moves the first axis last for imshow.
                for row, (label, images) in enumerate((
                    ("256x256", generated_images_256),
                    ("64x64", generated_images_64),
                )):
                    img = denormalize_image(images[i].cpu()).permute(1, 2, 0)
                    axes[row, i].imshow(img)
                    axes[row, i].set_title(f"{label} - Sample {i+1}")
                    axes[row, i].axis('off')

            plt.suptitle(f'Generated Pokemon: "{description}"', fontsize=14)
            plt.tight_layout()
            plt.show()
    finally:
        # Put the generator back in the mode it was in when we were called
        generator.train(was_training)

# Demo prompts: one per elemental type, used to exercise both generation helpers
test_descriptions = [
    "A fire type pokemon with orange fur and a flame on its tail",
    "A blue water type pokemon with bubbles",
    "A grass type pokemon with green leaves and vines",
    "An electric type pokemon with yellow fur and lightning bolts",
    "A psychic type pokemon with purple coloring and mystical powers"
]

print(
    "🔥 ENHANCED POKEMON GENERATION DEMO",
    "=" * 60,
    "Testing both simple and enhanced generation functions...\n",
    sep="\n",
)

# First prompt goes through the comparison-grid path (attention plot disabled)
print(f"✨ ENHANCED Generation: {test_descriptions[0]}")
generate_pokemon_from_text(test_descriptions[0], num_samples=3, show_attention=False)
print(f"\n{'-' * 80}\n")

# Second prompt goes through the plain-matplotlib path for comparison
print(f"⚡ SIMPLE Generation: {test_descriptions[1]}")
simple_generate_pokemon(test_descriptions[1], num_samples=3)
print(f"\n{'-' * 80}\n")

# Remaining prompts: quick two-sample runs with the simple helper
print("🎮 Quick tests with simple generation:")
for desc in test_descriptions[2:]:
    print(f"\nDescription: {desc}")
    simple_generate_pokemon(desc, num_samples=2)

print(
    "\n" + "=" * 60,
    "💡 PRO TIP: Use show_attention=True for detailed attention analysis!",
    "Example: generate_pokemon_from_text('legendary dragon', show_attention=True)",
    "=" * 60,
    sep="\n",
)


In [None]:
## 6.5. Enhanced Visualization Demo with Utils.py Functions

# Exercise each visualization helper once, writing results to "demo_output"
print("🎨 ENHANCED VISUALIZATION DEMO", "=" * 50, sep="\n")

# Step 1: comparison grid on the fixed training batch
print("\n1. Enhanced Comparison Grid:")
save_comparison_grid(epoch=0, model=generator, batch=fixed_train_batch,
                     set_name="demo", device=device,
                     output_dir="demo_output", show_inline=True)

# Step 2: attention visualization; best-effort, a failure is reported and the
# demo carries on
print("\n2. Attention Visualization (if attention is available):")
try:
    save_attention_visualization(epoch=0, model=generator, tokenizer=tokenizer,
                                 batch=fixed_train_attention_batch, device=device,
                                 set_name="demo", output_dir="demo_output",
                                 show_inline=True)
except Exception as err:
    print(f"Attention visualization not available: {err}",
          "This is normal if the model doesn't have attention mechanisms enabled.",
          sep="\n")

# Step 3: loss plotting on hand-written placeholder curves (not real training data)
print("\n3. Enhanced Loss Plotting:")
demo_losses_g = [3.2, 2.8, 2.5, 2.2, 2.0, 1.8, 1.6, 1.5, 1.4, 1.3]
demo_losses_d = [0.8, 0.7, 0.6, 0.65, 0.7, 0.68, 0.66, 0.64, 0.63, 0.62]
demo_losses_recon = [0.4, 0.35, 0.3, 0.28, 0.25, 0.23, 0.22, 0.21, 0.20, 0.19]

save_plot_losses(losses_g=demo_losses_g, losses_d=demo_losses_d,
                 losses_recon=demo_losses_recon,
                 output_dir="demo_output", show_inline=True)

# Closing summary of what is available and how to call it
print("\n".join([
    "\n✅ Enhanced visualization functions successfully integrated!",
    "📁 All visualizations are saved to 'demo_output' and 'custom_generation' directories",
    "\n🎯 Available enhanced functions:",
    "  • save_comparison_grid() - Enhanced real vs generated comparison",
    "  • save_attention_visualization() - Detailed attention heatmaps",
    "  • save_plot_losses() - Professional loss plotting",
    "  • denormalize_image() - Proper image denormalization",
    "  • generate_pokemon_from_text() - Now with attention visualization!",
    "\n💡 Usage examples:",
    "  generate_pokemon_from_text('fire dragon', num_samples=4, show_attention=True)",
    "  simple_generate_pokemon('electric mouse', num_samples=3)  # For quick testing",
]))


In [None]:
# Final model summary and analysis
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` is false (frozen) are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

print("=" * 60)
print("PIKAPIKAGEN: FINAL MODEL ANALYSIS")
print("=" * 60)

# Count each network once; training uses two discriminators (256x256 and
# 64x64), so both belong in the total
generator_params = count_parameters(generator)
discriminator_256_params = count_parameters(discriminator_256)
discriminator_64_params = count_parameters(discriminator_64)

print("\n📊 MODEL STATISTICS:")
print(f"Generator parameters: {generator_params:,}")
print(f"Discriminator (256x256) parameters: {discriminator_256_params:,}")
print(f"Discriminator (64x64) parameters: {discriminator_64_params:,}")
print(f"Total parameters: {generator_params + discriminator_256_params + discriminator_64_params:,}")

print("\n📈 TRAINING STATISTICS:")
print(f"Total epochs trained: {len(losses['generator'])}")
# The training loop records 'l1', 'perceptual' and 'sobel' (there is no single
# 'reconstruction' series), and the lists are empty if training never ran
if losses['generator']:
    print(f"Final Generator Loss: {losses['generator'][-1]:.4f}")
    print(f"Final Discriminator Loss: {losses['discriminator'][-1]:.4f}")
    print(f"Final L1 Loss: {losses['l1'][-1]:.4f}")
    print(f"Final Perceptual Loss: {losses['perceptual'][-1]:.4f}")
    print(f"Final Sobel Loss: {losses['sobel'][-1]:.4f}")
else:
    print("No training losses recorded yet.")

print("\n🎯 MODEL CAPABILITIES:")
print("✅ Text-to-Image Generation with Attention")
print("✅ BERT-mini Text Encoding (Fine-tuned)")
print("✅ Adversarial Training with Reconstruction Loss")
print("✅ Interactive Custom Text Generation")
print("✅ Real-time Training Visualization")

print("\n📝 ARCHITECTURE SUMMARY:")
print("• Text Encoder: Transformer-based with pre-trained BERT-mini embeddings")
print("• Generator: CNN decoder with multi-layer attention mechanism")
print("• Discriminator: CNN discriminator with text conditioning")
print("• Attention: Allows selective focus on text features during generation")
print("• Loss: Adversarial + L1 + Perceptual + Sobel loss combination")

print("\n🔥 SUCCESS METRICS:")
print("• Successfully generates Pokemon sprites from text descriptions")
print("• Attention mechanism enables fine-grained text-image alignment")
print("• BERT-mini fine-tuning improves domain-specific understanding")
print("• Combined loss function balances realism and text fidelity")
print("• Real-time visualization shows training progress")

print("\n✨ The PikaPikaGen model is now ready for Pokemon sprite generation!")
print("🎮 Try generating your own Pokemon with custom descriptions!")
print("=" * 60)

# Show final generation with interactive input
print("\n🎯 INTERACTIVE DEMO:")
print("Try this: generate_pokemon_from_text('Your custom Pokemon description here!')")
print("\nExample: generate_pokemon_from_text('A dragon type pokemon with silver wings and red eyes', num_samples=4)")

# Quick demonstration
generate_pokemon_from_text("A legendary fire dragon pokemon with golden scales", num_samples=4)
