# Step 1: Diffusion models that create the dataset

Main

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.models import resnet50
from torchvision.utils import save_image
from PIL import Image
import os
import time
import datetime
import sys
import json
import numpy as np
from sklearn.cluster import KMeans
from scipy.linalg import sqrtm

def print_gpu_memory():
    """Print the CUDA memory currently allocated and reserved, in gigabytes.

    Nothing is printed when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory allocated: {allocated_gb:.2f}GB")
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU memory cached: {reserved_gb:.2f}GB")

def weights_init_normal(m):
    """Initialise a module's parameters from normal distributions.

    Intended to be passed to ``nn.Module.apply``. Dispatch is done on the
    class name of ``m``:

    * names containing ``"Conv"``: weight ~ N(0.0, 0.02)
    * names containing ``"BatchNorm2d"``: weight ~ N(1.0, 0.02), bias = 0

    Modules that match by name but have no ``weight``/``bias`` parameter
    (for example ``BatchNorm2d(affine=False)``) are left untouched instead of
    raising ``AttributeError``. All other modules are ignored.

    Args:
        m: The module to initialise in place.
    """
    classname = m.__class__.__name__
    # Parameters can be absent (attribute missing) or disabled (set to None).
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if classname.find("Conv") != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)

class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm layers wrapped in an additive skip connection.

    The channel count and spatial size are preserved (3x3 kernels with
    padding 1), so the block output can be added directly to its input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class TransformerBlock(nn.Module):
    """Self-attention block applied on a spatially reduced feature map.

    The input is downsampled by a strided convolution, processed by
    multi-head self-attention and an MLP (each with a skip connection),
    upsampled back by a transposed convolution and finally added to the
    original input.

    Args:
        channels: Number of feature channels. Must be divisible by 8 (group
            norm uses 8 groups) and by ``num_heads``.
        num_heads: Number of attention heads.
        reduction_factor: Stride of the down/up-sampling convolutions. The
            input height and width must be divisible by this value so that the
            upsampled map matches the input for the final skip connection.
    """

    def __init__(self, channels, num_heads=8, reduction_factor=2):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        # Downsampling
        self.down = nn.Conv2d(channels, channels, kernel_size=3,
                              stride=reduction_factor, padding=1)

        # Upsampling. With padding=1 the output size is
        # (h_down - 1) * stride - 2 + kernel, so kernel = stride + 2 gives
        # exactly h_down * stride for any stride (kernel 4 for the default 2).
        self.up = nn.ConvTranspose2d(channels, channels,
                                     kernel_size=reduction_factor + 2,
                                     stride=reduction_factor, padding=1)

        # Transformer components
        self.norm1 = nn.GroupNorm(8, channels)
        self.norm2 = nn.GroupNorm(8, channels)
        self.self_attention = nn.MultiheadAttention(channels, num_heads)

        # MLP block
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels * 4),
            nn.GELU(),
            nn.Linear(channels * 4, channels)
        )

    def forward(self, x):
        """Apply the block to a 4-D tensor ``(batch, channels, h, w)``.

        Returns a tensor with the same shape as ``x``.
        """
        b, c = x.shape[:2]
        identity = x

        # Downsampling
        x = self.down(x)

        # Normalization
        x = self.norm1(x)

        # Reshape for attention: (h*w, batch, channels)
        h_down, w_down = x.shape[-2:]
        x_flat = x.flatten(2).permute(2, 0, 1)

        # Self attention
        attn_output, _ = self.self_attention(x_flat, x_flat, x_flat)

        # Reshape back and first skip connection. reshape (rather than view)
        # keeps this working regardless of the memory layout of the tensors.
        x = attn_output.permute(1, 2, 0).reshape(b, c, h_down, w_down) + x

        # MLP applied per spatial position, with its own skip connection
        x = self.norm2(x)
        x_mlp = x.flatten(2).permute(0, 2, 1)
        x_mlp = self.mlp(x_mlp)
        x = x + x_mlp.permute(0, 2, 1).reshape(b, c, h_down, w_down)

        # Upsampling and second skip connection
        x = self.up(x)
        return x + identity

class E2GANGenerator(nn.Module):
    """Image-to-image generator: conv encoder, residual/transformer core, conv decoder.

    The input goes through a 7x7 stem, two stride-2 downsampling stages, two
    residual blocks, one transformer block, a third residual block, two
    stride-2 upsampling stages and a final 7x7 convolution followed by tanh.

    Args:
        input_channels: Number of channels of the input image.
        output_channels: Number of channels of the generated image.
    """

    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()

        def conv_stage(conv):
            # Every encoder/decoder stage is conv -> instance norm -> ReLU.
            return nn.Sequential(
                conv,
                nn.InstanceNorm2d(conv.out_channels),
                nn.ReLU(inplace=True),
            )

        # Initial convolution
        self.initial = conv_stage(
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3))

        # Downsampling: 64 -> 128 -> 256 channels, halving the size each time
        self.down1 = conv_stage(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1))
        self.down2 = conv_stage(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1))

        # Bottleneck: two residual blocks, the transformer, a third residual block
        self.res1 = ResidualBlock(256)
        self.res2 = ResidualBlock(256)
        self.transformer = TransformerBlock(256)
        self.res3 = ResidualBlock(256)

        # Upsampling: 256 -> 128 -> 64 channels, doubling the size each time
        self.up1 = conv_stage(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                               padding=1, output_padding=1))
        self.up2 = conv_stage(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                               padding=1, output_padding=1))

        # Output convolution; tanh bounds the result to [-1, 1]
        self.output = nn.Sequential(
            nn.Conv2d(64, output_channels, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the generated image."""
        stages = (
            self.initial,
            self.down1, self.down2,
            self.res1, self.res2, self.transformer, self.res3,
            self.up1, self.up2,
            self.output,
        )
        for stage in stages:
            x = stage(x)
        return x

class Discriminator(nn.Module):
    """Convolutional discriminator scoring a pair of channel-concatenated images.

    Four stride-2 convolution stages (the first without normalisation) are
    followed by zero padding and a final convolution that produces a
    single-channel score map.

    Args:
        in_channels: Number of channels of each of the two input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # (input filters, output filters, apply instance norm)
        stage_specs = [
            (in_channels * 2, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]

        layers = []
        for n_in, n_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(n_in, n_out, kernel_size=4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Pad one pixel on the left and top before the final convolution.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Concatenate both images along the channel axis and score the pair."""
        stacked = torch.cat([img_A, img_B], dim=1)
        return self.model(stacked)

class ImagePairsDataset(Dataset):
    """Dataset of (original, filtered) image pairs matched by file name.

    Every entry listed in ``originals_dir`` is expected to have a file with
    the same name in ``filtered_dir``.

    Args:
        originals_dir: Directory containing the original images.
        filtered_dir: Directory containing the filtered counterparts.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, originals_dir, filtered_dir, transform=None):
        self.originals_dir = originals_dir
        self.filtered_dir = filtered_dir
        self.transform = transform
        # Sorted so the index -> file mapping is reproducible across runs.
        names = os.listdir(originals_dir)
        names.sort()
        self.image_names = names

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB images and return ``(original, filtered)``."""
        name = self.image_names[idx]
        paths = [
            os.path.join(folder, name)
            for folder in (self.originals_dir, self.filtered_dir)
        ]
        # Both files are opened before any transform runs.
        original_image, filtered_image = [
            Image.open(path).convert("RGB") for path in paths
        ]

        if not self.transform:
            return original_image, filtered_image

        original_image = self.transform(original_image)
        filtered_image = self.transform(filtered_image)
        return original_image, filtered_image

class DatasetManager:
    """Select a small, representative subset of a dataset via K-means clustering.

    Images are embedded with a pretrained ResNet-50, the embeddings are
    clustered, and for each cluster the sample closest to the centroid is kept.

    Args:
        n_clusters: Requested number of clusters (and therefore the maximum
            number of selected samples).
    """

    def __init__(self, n_clusters=400):
        self.feature_extractor = resnet50(pretrained=True)
        self.feature_extractor.eval()
        self.n_clusters = n_clusters

    def extract_features(self, dataset, batch_size=32, device=None, num_workers=4):
        """Run the feature extractor over ``dataset`` and collect its outputs.

        Args:
            dataset: Dataset yielding ``(image, other)`` pairs; only the first
                element of each pair is used.
            batch_size: Number of images per forward pass.
            device: Device to run on. ``None`` selects CUDA when available and
                falls back to CPU otherwise.
            num_workers: Number of data-loading worker processes.

        Returns:
            A NumPy array with one row of features per dataset item.

        Raises:
            ValueError: If ``dataset`` yields no batches.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device)

        self.feature_extractor = self.feature_extractor.to(device)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            # Pinned memory only helps host-to-GPU copies.
            pin_memory=(device.type == 'cuda'),
        )
        features = []

        print("Extracting features...")
        with torch.no_grad():
            for i, (images, _) in enumerate(dataloader):
                images = images.to(device)
                feat = self.feature_extractor(images)
                features.append(feat.cpu().numpy())
                if i % 10 == 0:
                    print(f"Processed {i * batch_size}/{len(dataset)} images")

        if not features:
            raise ValueError("Cannot extract features from an empty dataset")
        return np.concatenate(features)

    def create_clustered_dataset(self, original_dataset, device=None):
        """Return a ``Subset`` holding one representative sample per cluster.

        Args:
            original_dataset: Dataset to pick representatives from.
            device: Device used for feature extraction; see
                :meth:`extract_features`.

        Returns:
            A ``torch.utils.data.Subset`` of ``original_dataset`` containing,
            for every non-empty cluster, the sample nearest to its centroid.
        """
        features = self.extract_features(original_dataset, device=device)
        features = features.reshape(features.shape[0], -1)

        # K-means cannot produce more clusters than there are samples.
        n_clusters = min(self.n_clusters, features.shape[0])
        if n_clusters < self.n_clusters:
            print(f"Only {features.shape[0]} samples available; "
                  f"reducing clusters from {self.n_clusters} to {n_clusters}")

        print(f"Running K-means clustering with {n_clusters} clusters...")
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        clusters = kmeans.fit_predict(features)

        selected_indices = []
        for i in range(n_clusters):
            cluster_indices = np.where(clusters == i)[0]
            if len(cluster_indices) == 0:
                continue

            # Keep the member closest (Euclidean distance) to the centroid.
            cluster_points = features[cluster_indices]
            centroid = kmeans.cluster_centers_[i]
            distances = np.linalg.norm(cluster_points - centroid, axis=1)
            selected_indices.append(int(cluster_indices[np.argmin(distances)]))

        print(f"Selected {len(selected_indices)} representative samples")
        return torch.utils.data.Subset(original_dataset, selected_indices)

def measure_inference_speed(generator, device, n_samples=100):
    """Measure the average single-image forward time of ``generator``.

    The generator is switched to eval mode (and left in it), warmed up with
    10 forward passes and then timed on a random ``1x3x256x256`` input.

    Args:
        generator: Model to benchmark.
        device: Device the generator lives on (string or ``torch.device``).
        n_samples: Number of timed forward passes.

    Returns:
        The mean forward time in milliseconds.
    """
    generator.eval()
    dummy_input = torch.randn(1, 3, 256, 256).to(device)

    # CUDA kernels run asynchronously, so timings are only meaningful when we
    # wait for the device; on CPU there is nothing to synchronise (and calling
    # torch.cuda.synchronize() without CUDA raises).
    use_cuda = torch.device(device).type == 'cuda'

    def _sync():
        if use_cuda:
            torch.cuda.synchronize()

    # Warmup
    print("Warming up...")
    with torch.no_grad():
        for _ in range(10):
            _ = generator(dummy_input)

    # Actual measurement
    print(f"Measuring inference time over {n_samples} samples...")
    times = []
    with torch.no_grad():
        for _ in range(n_samples):
            _sync()
            # perf_counter is monotonic and higher resolution than time.time.
            start = time.perf_counter()
            _ = generator(dummy_input)
            _sync()
            end = time.perf_counter()
            times.append((end - start) * 1000)  # ms

    avg_time = np.mean(times)
    std_time = np.std(times)
    print(f"Average inference time: {avg_time:.2f}ms ± {std_time:.2f}ms")
    return avg_time

def save_sample_images(generator, val_loader, epoch, save_dir, device):
    """Save a comparison grid for the first validation batch.

    For every image of the batch the input, the generator output and the
    target are stacked along the height axis; the result is written to
    ``{save_dir}/epoch_{epoch}.png``. The generator is left in eval mode.

    Args:
        generator: Model producing the translated images.
        val_loader: Loader yielding ``(input, target)`` batches.
        epoch: Epoch number used in the output file name.
        save_dir: Directory the image is written to.
        device: Device the generator runs on.
    """
    generator.eval()
    with torch.no_grad():
        source, target = next(iter(val_loader))
        source = source.to(device)
        target = target.to(device)
        generated = generator(source)

        # Map from [-1, 1] back to [0, 1] and stack input / output / target
        # vertically.
        panels = [(tensor + 1) / 2 for tensor in (source, generated, target)]
        img_sample = torch.cat(panels, dim=-2)

        out_path = f"{save_dir}/epoch_{epoch}.png"
        images_per_row = min(8, source.size(0))
        save_image(img_sample, out_path, nrow=images_per_row, normalize=False)

def calculate_fid(real_features, fake_features):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance and the distance is
    ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))``.

    Args:
        real_features: Array of shape ``(n_real, dim)``.
        fake_features: Array of shape ``(n_fake, dim)``.

    Returns:
        The distance as a float, or ``float('inf')`` when it cannot be
        computed: inputs that are not 2-D, mismatched feature dimensions,
        fewer than two samples in either set (covariance undefined),
        non-finite inputs, a non-finite result, or any error raised during the
        computation (which is printed).
    """
    try:
        real = np.asarray(real_features, dtype=np.float64)
        fake = np.asarray(fake_features, dtype=np.float64)

        if real.ndim != 2 or fake.ndim != 2 or real.shape[1] != fake.shape[1]:
            return float('inf')
        # A covariance estimate needs at least two observations.
        if real.shape[0] < 2 or fake.shape[0] < 2:
            return float('inf')
        if not (np.isfinite(real).all() and np.isfinite(fake).all()):
            return float('inf')

        # Small diagonal offset keeps the covariances well conditioned.
        eps = 1e-6
        ridge = np.eye(real.shape[1]) * eps

        mu1 = np.mean(real, axis=0)
        # atleast_2d: np.cov returns a 0-d array for a single feature column.
        sigma1 = np.atleast_2d(np.cov(real, rowvar=False)) + ridge

        mu2 = np.mean(fake, axis=0)
        sigma2 = np.atleast_2d(np.cov(fake, rowvar=False)) + ridge

        diff = mu1 - mu2

        covmean = sqrtm(sigma1.dot(sigma2))

        # Numerical error can introduce a small imaginary component.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = float(np.sum(diff**2) + np.trace(sigma1 + sigma2 - 2*covmean))
        if not np.isfinite(fid):
            return float('inf')
        return fid
    except Exception as e:
        # Best effort: callers treat inf as "no usable score".
        print(f"Error in FID calculation: {e}")
        return float('inf')

def evaluate_model(generator, dataloader, device, feature_extractor=None):
    """Score ``generator`` with an FID-style distance over ``dataloader``.

    For every ``(input, target)`` batch the generator output and the target
    are embedded with a feature extractor; the two feature sets are then
    compared with ``calculate_fid``. The generator is left in eval mode.

    Args:
        generator: Model mapping input images to generated images.
        dataloader: Iterable yielding ``(real_A, real_B)`` batches.
        device: Device to run on.
        feature_extractor: Optional module used to embed images. When
            ``None`` a pretrained ResNet-50 is loaded; callers that evaluate
            repeatedly can pass one instance to avoid reloading the weights on
            every call. The module is moved to ``device`` and put in eval mode.

    Returns:
        The distance as a float, or ``float('inf')`` if the loader is empty.
    """
    generator.eval()
    if feature_extractor is None:
        feature_extractor = resnet50(pretrained=True)
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval()

    real_features = []
    fake_features = []
    # Count images directly; dataloader.batch_size can be None (e.g. with a
    # custom batch sampler), which would break a batch_size-based estimate.
    processed = 0

    print("Evaluating model...")
    with torch.no_grad():
        for i, (real_A, real_B) in enumerate(dataloader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            fake_B = generator(real_A)

            real_feat = feature_extractor(real_B)
            fake_feat = feature_extractor(fake_B)

            real_features.append(real_feat.cpu().numpy())
            fake_features.append(fake_feat.cpu().numpy())
            processed += real_A.size(0)

            if i % 10 == 0:
                print(f"Processed {processed}/{len(dataloader.dataset)} images")

    if len(real_features) == 0 or len(fake_features) == 0:
        return float('inf')

    real_features = np.concatenate(real_features)
    fake_features = np.concatenate(fake_features)

    return calculate_fid(real_features, fake_features)


def _e2gan_checkpoint(epoch, generator, discriminator, optimizer_G, optimizer_D, metrics):
    """Bundle model, optimizer and metric state into one dict for ``torch.save``."""
    return {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'metrics': metrics,
    }


def train_e2gan(generator, discriminator, train_loader, val_loader, num_epochs, device, save_dir):
    """Adversarially train ``generator`` against ``discriminator``.

    Each batch performs one generator step (MSE adversarial loss plus
    ``100 *`` L1 pixel loss) followed by one discriminator step. After every
    epoch the generator is evaluated on ``val_loader``; the best-scoring model
    is written to ``{save_dir}/best_model.pt``, a checkpoint is written every
    10 epochs and sample images every 5 epochs.

    Args:
        generator: Model mapping ``real_A`` to ``fake_B``.
        discriminator: Model called as ``discriminator(image, real_A)``.
        train_loader: Loader yielding ``(real_A, real_B)`` training batches.
        val_loader: Loader used for validation and sample images.
        num_epochs: Number of epochs to train.
        device: Device tensors are moved to.
        save_dir: Existing directory for checkpoints and sample images.

    Returns:
        ``(generator, discriminator, metrics)`` where ``metrics`` maps
        ``'g_losses'``, ``'d_losses'``, ``'pixel_losses'``, ``'fid_scores'``,
        ``'epoch_times'`` and ``'running_times'`` to per-epoch lists.
    """
    # Parameters as in the paper
    criterion_GAN = torch.nn.MSELoss()
    criterion_pixel = torch.nn.L1Loss()
    lambda_pixel = 100

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Metrics tracking
    best_fid = float('inf')
    metrics = {
        'g_losses': [],
        'd_losses': [],
        'pixel_losses': [],
        'fid_scores': [],
        'epoch_times': [],
        'running_times': []
    }

    total_start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        generator.train()
        discriminator.train()

        epoch_g_losses = []
        epoch_d_losses = []
        epoch_pixel_losses = []

        for i, (real_A, real_B) in enumerate(train_loader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Train Generator
            optimizer_G.zero_grad()
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B, real_A)
            # Targets take the shape of the discriminator output, so training
            # is not tied to one input resolution (a fixed 16x16 target only
            # matches 256x256 images).
            loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_pixel = criterion_pixel(fake_B, real_B)
            loss_G = loss_GAN + lambda_pixel * loss_pixel
            loss_G.backward()
            optimizer_G.step()

            # Train Discriminator
            optimizer_D.zero_grad()
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            # detach: the discriminator step must not backpropagate into the
            # generator.
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            # Save losses
            epoch_g_losses.append(loss_G.item())
            epoch_d_losses.append(loss_D.item())
            epoch_pixel_losses.append(loss_pixel.item())

            # Print progress
            if i % 5 == 0:
                print(f"\rEpoch [{epoch}/{num_epochs}] Batch [{i}/{len(train_loader)}] "
                      f"d_loss: {loss_D.item():.4f}, g_loss: {loss_G.item():.4f}, "
                      f"pixel: {loss_pixel.item():.4f}", end="")

        # End of epoch
        epoch_time = time.time() - epoch_start_time
        running_time = time.time() - total_start_time

        metrics['epoch_times'].append(epoch_time)
        metrics['running_times'].append(running_time)

        # Average losses, stored as plain floats so the metrics stay trivially
        # serialisable.
        avg_g_loss = float(np.mean(epoch_g_losses))
        avg_d_loss = float(np.mean(epoch_d_losses))
        avg_pixel_loss = float(np.mean(epoch_pixel_losses))

        metrics['g_losses'].append(avg_g_loss)
        metrics['d_losses'].append(avg_d_loss)
        metrics['pixel_losses'].append(avg_pixel_loss)

        # Validation and FID calculation
        print("\nRunning validation...")
        generator.eval()
        val_fid = evaluate_model(generator, val_loader, device)
        metrics['fid_scores'].append(val_fid)

        # Save sample images
        if epoch % 5 == 0:
            save_sample_images(generator, val_loader, epoch, save_dir, device)

        # Save best model
        if val_fid < best_fid:
            best_fid = val_fid
            torch.save(
                _e2gan_checkpoint(epoch, generator, discriminator,
                                  optimizer_G, optimizer_D, metrics),
                f'{save_dir}/best_model.pt')
            print(f"Saved best model with FID: {val_fid:.4f}")

        # Save a checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(
                _e2gan_checkpoint(epoch, generator, discriminator,
                                  optimizer_G, optimizer_D, metrics),
                f'{save_dir}/checkpoint_epoch_{epoch+1}.pt')

        # Print epoch summary
        print(f"\nEpoch {epoch} Summary:")
        print(f"D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}, Pixel Loss: {avg_pixel_loss:.4f}")
        print(f"FID Score: {val_fid:.4f}")
        print(f"Epoch Time: {epoch_time:.2f}s, Total Time: {running_time/60:.2f}m")
        print_gpu_memory()

    return generator, discriminator, metrics

def main():
    """Run the full pipeline: load data, cluster, train, evaluate, save results.

    The image-pair dataset is split 80/10/10 into train/val/test, the training
    split is reduced to cluster representatives, the generator and
    discriminator are trained, and the final test score, inference timings and
    training metrics are written to ``final_results.json`` in a timestamped
    directory under ``results/``.
    """
    # Hyperparameters
    num_epochs = 100
    batch_size = 32      # tuned for a T4 GPU
    image_size = 256
    n_clusters = 400     # as in the paper

    # CUDA setup and seeding
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create run directory
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M')
    run_name = f"e2gan_single_concept_{timestamp}"
    save_dir = f"results/{run_name}"
    os.makedirs(save_dir, exist_ok=True)

    # Dataset transforms: resize, convert to tensor, scale to [-1, 1]
    transform_steps = [
        transforms.Resize((image_size, image_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    transforms_ = transforms.Compose(transform_steps)

    # Load dataset
    print("Loading dataset...")
    full_dataset = ImagePairsDataset(
        originals_dir='/kaggle/input/images/original_images',  # adjust to your data
        filtered_dir='/kaggle/input/images/modified_images',   # adjust to your data
        transform=transforms_
    )

    # Split dataset 80/10/10; the test split absorbs the rounding remainder
    n_total = len(full_dataset)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)
    test_size = n_total - train_size - val_size

    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=split_rng
    )

    print(f"Dataset sizes - Train: {train_size}, Val: {val_size}, Test: {test_size}")

    # Create clustered dataset
    print("Creating clustered dataset...")
    dataset_manager = DatasetManager(n_clusters=n_clusters)
    train_clustered = dataset_manager.create_clustered_dataset(train_dataset)

    # Create dataloaders; only the training loader shuffles
    def build_loader(split, shuffle):
        return DataLoader(
            split,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=4,
            pin_memory=True
        )

    train_loader = build_loader(train_clustered, True)
    val_loader = build_loader(val_dataset, False)
    test_loader = build_loader(test_dataset, False)

    # Initialize models
    print("Initializing models...")
    generator = E2GANGenerator().to(device)
    discriminator = Discriminator().to(device)

    for model in (generator, discriminator):
        model.apply(weights_init_normal)

    # Print initial GPU memory usage
    print("Initial GPU memory usage:")
    print_gpu_memory()

    # Measure initial inference time
    print("\nMeasuring initial inference time...")
    initial_inf_time = measure_inference_speed(generator, device)

    # Training
    print("\nStarting training...")
    training_start_time = time.time()

    generator, discriminator, metrics = train_e2gan(
        generator,
        discriminator,
        train_loader,
        val_loader,
        num_epochs,
        device,
        save_dir
    )

    total_training_time = time.time() - training_start_time
    training_minutes = total_training_time / 60

    # Final evaluation
    print("\nFinal evaluation:")
    test_fid = evaluate_model(generator, test_loader, device)
    final_inf_time = measure_inference_speed(generator, device)

    # Save final results
    final_results = {
        'test_fid': float(test_fid),
        'initial_inference_time_ms': float(initial_inf_time),
        'final_inference_time_ms': float(final_inf_time),
        'total_training_time_minutes': float(training_minutes),
        'gpu_memory_gb': float(torch.cuda.max_memory_allocated() / 1e9),
        'metrics': metrics
    }

    results_path = f'{save_dir}/final_results.json'
    with open(results_path, 'w') as f:
        json.dump(final_results, f, indent=4)

    print("\nTraining completed!")
    print(f"Total training time: {training_minutes:.2f} minutes")
    print(f"Final test FID: {test_fid:.4f}")
    print(f"Final inference time: {final_inf_time:.2f}ms")
    print(f"Results saved in: {save_dir}")

if __name__ == "__main__":
    main()