# Step 1: Diffusion model that creates the dataset

Main

In [None]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torch import nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import itertools
import torch.cuda.amp as amp
import gc

# CUDA optimizations (global backend switches, applied once at import time)
# Let cuDNN benchmark convolution algorithms and reuse the fastest one found.
torch.backends.cudnn.benchmark = True
# Permit TF32 math for matmuls and cuDNN convolutions on GPUs that support it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

class Config:
    """Hyper-parameters and input paths for the quick test run."""

    # Dataset parameters
    MAX_IMAGES = 100          # Cap on the number of images loaded for the test
    IMAGE_SIZE = 128         # Reduced image size
    
    # Training parameters
    BATCH_SIZE = 2          # Very small batch size
    NUM_EPOCHS = 3          # Few epochs for the test
    BASE_CHANNELS = 32      # Reduced base channel count (was 64)
    
    # Model parameters
    USE_AMP = True         # Use mixed precision
    PIN_MEMORY = True      
    NUM_WORKERS = 0        # Multiprocessing disabled for the test
    
    # Paths
    ORIGINALS_DIR = '/kaggle/input/images/original_images'
    FILTERED_DIR = '/kaggle/input/images/modified_images'

class ImagePairsDataset(Dataset):
    """Paired dataset yielding ``(original, filtered)`` images with the same file name.

    File names are collected from ``originals_dir``; the file of the same name
    is then read from ``filtered_dir``. Only the first ``Config.MAX_IMAGES``
    names, in sorted order, are kept.

    Args:
        originals_dir: Directory containing the source images.
        filtered_dir: Directory containing the modified counterpart of each
            source image, stored under the same file name.
        transform: Optional callable applied independently to both images of
            a pair.
    """

    def __init__(self, originals_dir, filtered_dir, transform=None):
        self.originals_dir = originals_dir
        self.filtered_dir = filtered_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort before slicing so
        # that the MAX_IMAGES subset (and any seeded split of it) is
        # reproducible from run to run.
        self.image_names = sorted(os.listdir(originals_dir))[:Config.MAX_IMAGES]
        print(f"Loading {len(self.image_names)} images")

    def __len__(self):
        """Return the number of image pairs kept."""
        return len(self.image_names)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(original, filtered)`` pair at position ``idx``.

        Raises:
            FileNotFoundError: If the image is missing from either directory.
        """
        image_name = self.image_names[idx]
        original_image = self._load_rgb(os.path.join(self.originals_dir, image_name))
        filtered_image = self._load_rgb(os.path.join(self.filtered_dir, image_name))

        if self.transform is not None:
            original_image = self.transform(original_image)
            filtered_image = self.transform(filtered_image)

        return original_image, filtered_image

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with instance norm, wrapped in an identity skip.

    The channel count and spatial size of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        """Return ``relu(x + norm2(conv2(relu(norm1(conv1(x))))))``."""
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = F.relu(hidden)
        hidden = self.norm2(self.conv2(hidden))
        # Skip connection: add the untouched input back before the last ReLU.
        return F.relu(hidden + x)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer applied to a 2-D feature map.

    Every spatial position of the ``[B, C, H, W]`` input is treated as one
    token of width ``C``. The layer applies self-attention, an optional
    cross-attention over ``text_features``, and a 4x-expansion MLP, each added
    back residually, and returns a tensor of the input shape.
    """

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads

        # Attention over the image tokens themselves.
        self.self_attention = nn.MultiheadAttention(channels, num_heads)

        # Attention from image tokens to an external conditioning sequence.
        self.cross_attention = nn.MultiheadAttention(channels, num_heads)

        # Position-wise feed-forward network.
        hidden_width = channels * 4
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, channels),
        )

        # One LayerNorm in front of each of the three sub-layers.
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)
        self.norm3 = nn.LayerNorm(channels)

    def forward(self, x, text_features=None):
        """Run the block on ``x``; cross-attention is skipped without context.

        ``text_features``, when given, is used as both key and value of the
        cross-attention (presumably laid out as ``[S, B, C]`` like the image
        tokens -- confirm against the caller).
        """
        batch, chans, height, width = x.shape
        # [B, C, H, W] -> [H*W, B, C]: sequence-first token layout.
        tokens = x.flatten(2).permute(2, 0, 1)

        # Self-attention sub-layer.
        normed = self.norm1(tokens)
        attended, _ = self.self_attention(normed, normed, normed)
        tokens = tokens + attended

        # Cross-attention sub-layer (only when conditioning is supplied).
        if text_features is not None:
            queries = self.norm2(tokens)
            attended, _ = self.cross_attention(queries, text_features, text_features)
            tokens = tokens + attended

        # Feed-forward sub-layer.
        tokens = tokens + self.mlp(self.norm3(tokens))

        # [H*W, B, C] -> [B, C, H, W]
        return tokens.permute(1, 2, 0).view(batch, chans, height, width)

class Generator(nn.Module):
    """Encoder / bottleneck / decoder image-to-image generator.

    Layout: a 7x7 stem, two stride-2 downsampling convolutions, three residual
    blocks with a transformer block inserted after the second one, two
    stride-2 transposed convolutions, and a 7x7 output convolution followed by
    ``tanh``.
    """

    def __init__(self, input_channels=3, output_channels=3, base_channels=64):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4

        def stage(layer, width):
            # conv-like layer -> instance norm -> in-place ReLU
            return nn.Sequential(layer, nn.InstanceNorm2d(width), nn.ReLU(True))

        # Stem
        self.initial = stage(
            nn.Conv2d(input_channels, c1, kernel_size=7, padding=3), c1
        )

        # Encoder: each stage halves the resolution and doubles the channels
        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down1 = stage(nn.Conv2d(c1, c2, **down_kwargs), c2)
        self.down2 = stage(nn.Conv2d(c2, c4, **down_kwargs), c4)

        # Bottleneck
        self.resblocks = nn.ModuleList([ResidualBlock(c4) for _ in range(3)])
        self.transformer = TransformerBlock(c4)

        # Decoder: mirrors the encoder
        up_kwargs = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up1 = stage(nn.ConvTranspose2d(c4, c2, **up_kwargs), c2)
        self.up2 = stage(nn.ConvTranspose2d(c2, c1, **up_kwargs), c1)

        # Output head
        self.output = nn.Sequential(
            nn.Conv2d(c1, output_channels, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x, text_features=None):
        """Translate ``x``; ``text_features`` is forwarded to the transformer."""
        feats = self.down2(self.down1(self.initial(x)))

        res_a, res_b, res_c = self.resblocks
        feats = res_a(feats)
        # The transformer sits between the second and third residual blocks.
        feats = self.transformer(res_b(feats), text_features)
        feats = res_c(feats)

        feats = self.up2(self.up1(feats))
        return self.output(feats)

class Discriminator(nn.Module):
    """Convolutional critic that scores a pair of images.

    The two inputs are concatenated along the channel axis, passed through
    three stride-2 4x4 convolutions and one stride-1 4x4 convolution, and
    reduced to a single-channel score map by a final 4x4 convolution.
    """

    def __init__(self, input_channels=3, base_channels=64):
        super().__init__()
        w1, w2, w4, w8 = (base_channels * mult for mult in (1, 2, 4, 8))

        def norm_stage(in_width, out_width, stride):
            # 4x4 conv -> instance norm -> in-place LeakyReLU(0.2)
            return nn.Sequential(
                nn.Conv2d(in_width, out_width, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(out_width),
                nn.LeakyReLU(0.2, True),
            )

        # First layer sees both images stacked, hence twice the channels;
        # it has no normalization.
        self.input = nn.Sequential(
            nn.Conv2d(input_channels * 2, w1, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        )

        self.down1 = norm_stage(w1, w2, 2)
        self.down2 = norm_stage(w2, w4, 2)
        self.down3 = norm_stage(w4, w8, 1)

        # Single-channel score map
        self.output = nn.Conv2d(w8, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x, y):
        """Return the score map for the channel-wise concatenation of ``x`` and ``y``."""
        feats = torch.cat((x, y), dim=1)
        for layer in (self.input, self.down1, self.down2, self.down3, self.output):
            feats = layer(feats)
        return feats

class GANTrainer:
    """Training loop for a conditional image-to-image GAN.

    The generator maps ``real_img`` to a prediction of ``mod_img``. The
    discriminator scores ``(candidate, real_img)`` pairs and is trained with a
    least-squares objective (MSE against all-ones / all-zeros targets). The
    generator loss is the same adversarial term plus an L1 reconstruction term
    weighted by ``lambda_l1``.

    Args:
        generator: Module called as ``generator(real_img)``.
        discriminator: Module called as ``discriminator(candidate, real_img)``.
        train_loader: Iterable of ``(real_img, mod_img)`` training batches.
        val_loader: Iterable of ``(real_img, mod_img)`` validation batches.
        device: Device (``torch.device`` or string) the models and batches are
            moved to.
        lr: Adam learning rate shared by both optimizers.
        n_epochs: Number of epochs run by :meth:`train`.
        lambda_l1: Weight of the L1 term in the generator loss.
        use_amp: Request mixed-precision training. It only takes effect on a
            CUDA device; elsewhere autocast and gradient scaling are turned
            off explicitly.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 train_loader,
                 val_loader,
                 device,
                 lr=2e-4,
                 n_epochs=Config.NUM_EPOCHS,
                 lambda_l1=100.0,
                 use_amp=Config.USE_AMP):

        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.n_epochs = n_epochs
        self.lambda_l1 = lambda_l1

        # torch.cuda.amp is CUDA-only: disable it up front on other devices
        # (and when the configuration asks for full precision).
        self.use_amp = bool(use_amp) and torch.device(device).type == 'cuda'

        # Loss functions
        self.criterion_gan = nn.MSELoss()
        self.criterion_l1 = nn.L1Loss()

        # Optimizers
        self.optimizer_g = optim.Adam(self.generator.parameters(),
                                      lr=lr, betas=(0.5, 0.999))
        self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                      lr=lr, betas=(0.5, 0.999))

        # Gradient scalers for mixed precision (pass-through when disabled)
        self.scaler_g = amp.GradScaler(enabled=self.use_amp)
        self.scaler_d = amp.GradScaler(enabled=self.use_amp)

        # Learning-rate schedule: constant for the first half of training,
        # then a linear decay.
        def lambda_rule(epoch):
            decay_start_epoch = self.n_epochs // 2
            return 1.0 - max(0, epoch - decay_start_epoch) / float(self.n_epochs - decay_start_epoch + 1)

        self.scheduler_g = LambdaLR(self.optimizer_g, lr_lambda=lambda_rule)
        self.scheduler_d = LambdaLR(self.optimizer_d, lr_lambda=lambda_rule)

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or NaN when ``count`` is zero."""
        return total / count if count else float('nan')

    def train_epoch(self):
        """Run one pass over ``train_loader``.

        Returns:
            Tuple ``(avg_g_loss, avg_d_loss)`` averaged over the batches; both
            are NaN when the loader yields no batches.
        """
        # Free cached memory before the epoch starts
        torch.cuda.empty_cache()
        gc.collect()

        self.generator.train()
        self.discriminator.train()

        total_g_loss = 0
        total_d_loss = 0
        n_batches = len(self.train_loader)

        for real_img, mod_img in self.train_loader:
            real_img = real_img.to(self.device, non_blocking=True)
            mod_img = mod_img.to(self.device, non_blocking=True)

            # ---- Discriminator step ----
            self.optimizer_d.zero_grad(set_to_none=True)

            with amp.autocast(enabled=self.use_amp):
                fake_img = self.generator(real_img)
                # detach(): the discriminator update must not back-propagate
                # into the generator.
                pred_fake = self.discriminator(fake_img.detach(), real_img)
                loss_d_fake = self.criterion_gan(pred_fake, torch.zeros_like(pred_fake))

                pred_real = self.discriminator(mod_img, real_img)
                loss_d_real = self.criterion_gan(pred_real, torch.ones_like(pred_real))

                loss_d = (loss_d_fake + loss_d_real) * 0.5

            self.scaler_d.scale(loss_d).backward()
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()

            # ---- Generator step ----
            self.optimizer_g.zero_grad(set_to_none=True)

            with amp.autocast(enabled=self.use_amp):
                # Re-score the same fake with the freshly updated discriminator,
                # this time keeping the graph back to the generator.
                pred_fake = self.discriminator(fake_img, real_img)
                loss_g_gan = self.criterion_gan(pred_fake, torch.ones_like(pred_fake))

                loss_g_l1 = self.criterion_l1(fake_img, mod_img) * self.lambda_l1

                loss_g = loss_g_gan + loss_g_l1

            self.scaler_g.scale(loss_g).backward()
            self.scaler_g.step(self.optimizer_g)
            self.scaler_g.update()

            total_g_loss += loss_g.item()
            total_d_loss += loss_d.item()

        return self._mean(total_g_loss, n_batches), self._mean(total_d_loss, n_batches)

    def validate(self):
        """Return the mean L1 loss of the generator over ``val_loader``.

        The result is NaN when the loader yields no batches (for example when
        the validation split is empty), instead of raising
        ``ZeroDivisionError``.
        """
        self.generator.eval()
        self.discriminator.eval()

        total_val_loss = 0
        n_batches = len(self.val_loader)

        with torch.no_grad(), amp.autocast(enabled=self.use_amp):
            for real_img, mod_img in self.val_loader:
                real_img = real_img.to(self.device, non_blocking=True)
                mod_img = mod_img.to(self.device, non_blocking=True)

                fake_img = self.generator(real_img)
                val_loss = self.criterion_l1(fake_img, mod_img)
                total_val_loss += val_loss.item()

        return self._mean(total_val_loss, n_batches)

    def train(self):
        """Train for ``n_epochs`` epochs and checkpoint the best model.

        After every epoch the schedulers are stepped, and the model and
        optimizer states are written to ``best_model.pth`` whenever the
        validation loss improves. A NaN validation loss never counts as an
        improvement.

        Returns:
            Dict with per-epoch lists under ``'g_losses'``, ``'d_losses'`` and
            ``'val_losses'``.
        """
        best_val_loss = float('inf')
        training_history = {
            'g_losses': [],
            'd_losses': [],
            'val_losses': []
        }

        for epoch in range(self.n_epochs):
            g_loss, d_loss = self.train_epoch()
            val_loss = self.validate()

            # Update learning rates
            self.scheduler_g.step()
            self.scheduler_d.step()

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'generator_state_dict': self.generator.state_dict(),
                    'discriminator_state_dict': self.discriminator.state_dict(),
                    'g_optimizer_state_dict': self.optimizer_g.state_dict(),
                    'd_optimizer_state_dict': self.optimizer_d.state_dict(),
                    'epoch': epoch,
                }, 'best_model.pth')

            # Store losses
            training_history['g_losses'].append(g_loss)
            training_history['d_losses'].append(d_loss)
            training_history['val_losses'].append(val_loss)

            print(f"Epoch [{epoch+1}/{self.n_epochs}] - "
                  f"G_loss: {g_loss:.4f}, D_loss: {d_loss:.4f}, Val_loss: {val_loss:.4f}")

            # Monitor GPU memory
            if torch.cuda.is_available():
                print(f"GPU Memory: {torch.cuda.memory_allocated() / 1024**2:.1f}MB / "
                      f"{torch.cuda.memory_reserved() / 1024**2:.1f}MB")

        return training_history

# Initialization and training
def train_gan(train_loader, val_loader, device, base_channels=None, n_epochs=None):
    """Build the generator and discriminator, then train them.

    Args:
        train_loader: Loader of ``(original, filtered)`` training batches.
        val_loader: Loader of ``(original, filtered)`` validation batches.
        device: Device the models are trained on.
        base_channels: Width of the first layer of both networks. Defaults to
            ``Config.BASE_CHANNELS``.
        n_epochs: Number of training epochs. Defaults to
            ``Config.NUM_EPOCHS``.

    Returns:
        Tuple ``(trainer, history)``: the ``GANTrainer`` instance and the
        loss history returned by its ``train`` method.
    """
    # Resolve defaults at call time so the values declared in Config are
    # honoured instead of separate hard-coded numbers.
    if base_channels is None:
        base_channels = Config.BASE_CHANNELS
    if n_epochs is None:
        n_epochs = Config.NUM_EPOCHS

    print("Initializing models...")
    generator = Generator(base_channels=base_channels)
    discriminator = Discriminator(base_channels=base_channels)

    # Print the number of trainable parameters
    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Generator parameters: {count_parameters(generator):,}")
    print(f"Discriminator parameters: {count_parameters(discriminator):,}")

    trainer = GANTrainer(
        generator=generator,
        discriminator=discriminator,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        lr=2e-4,
        n_epochs=n_epochs
    )

    print("Starting training...")
    history = trainer.train()
    return trainer, history

# Main execution
if __name__ == "__main__":
    # Script entry point: build the data pipeline, train the GAN and plot the
    # loss curves.

    # Set random seed (also fixes the random_split below)
    torch.manual_seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Transformations: resize to the configured square size and scale each
    # channel to [-1, 1], the range of the generator's tanh output.
    transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Paths (single source of truth is Config)
    originals_dir = Config.ORIGINALS_DIR
    filtered_dir = Config.FILTERED_DIR

    # Create dataset
    print("Loading dataset...")
    dataset = ImagePairsDataset(originals_dir, filtered_dir, transform=transform)

    # Split dataset 80 / 10 / 10; the test split takes whatever remains after
    # integer truncation.
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    train_set, val_set, test_set = random_split(
        dataset, [train_size, val_size, test_size]
    )

    # DataLoaders with the reduced batch size
    train_loader = DataLoader(
        train_set,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY
    )

    val_loader = DataLoader(
        val_set,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY
    )

    test_loader = DataLoader(
        test_set,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY
    )

    print(f"Dataset sizes: Train={len(train_set)}, Val={len(val_set)}, Test={len(test_set)}")

    # Clear GPU memory before starting
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # Start training
    trainer, history = train_gan(train_loader, val_loader, device)

    print("Training completed!")

    # Plot the per-epoch losses and save the figure to disk
    plt.figure(figsize=(10, 5))
    plt.plot(history['g_losses'], label='Generator Loss')
    plt.plot(history['d_losses'], label='Discriminator Loss')
    plt.plot(history['val_losses'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training History')
    plt.savefig('training_history.png')
    plt.close()