In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse

# Set random seeds for reproducibility
def set_seed(seed):
    """
    Seed every random number generator used in this notebook.

    Args:
        seed: Integer seed applied to Python's `random`, NumPy and PyTorch
            (CPU generator and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Deterministic kernels are not sufficient on their own: if benchmark
    # mode was switched on earlier, cuDNN's autotuner may still select a
    # different algorithm from run to run, so turn it off explicitly.
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Custom Dataset for unpaired sketch-image data
class SketchImageDataset(Dataset):
    # Lower-case file extensions accepted as images when scanning a directory
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, sketch_dir, image_dir, transform=None):
        """
        Dataset for sketch to image translation

        Args:
            sketch_dir: Directory containing sketch images
            image_dir: Directory containing real images
            transform: Optional transforms to apply (the same transform is
                applied to the sketch and to the real image)

        Raises:
            ValueError: If either directory contains no image files.
        """
        self.sketch_dir = sketch_dir
        self.image_dir = image_dir
        self.transform = transform

        self.sketch_files = self._list_images(sketch_dir)
        self.image_files = self._list_images(image_dir)

        # Fail here with a clear message instead of later inside the
        # DataLoader sampler or random.randint.
        if not self.sketch_files:
            raise ValueError(f"No image files found in sketch_dir: {sketch_dir}")
        if not self.image_files:
            raise ValueError(f"No image files found in image_dir: {image_dir}")

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted names of the regular image files inside `directory`."""
        # Sub-directories (e.g. notebook checkpoint folders) and non-image
        # files would otherwise reach Image.open and abort training.
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        """Number of sketches; one sample is produced per sketch."""
        return len(self.sketch_files)

    def __getitem__(self, idx):
        """
        Return `(sketch, real_image)` for sketch number `idx`.

        The real image is drawn at random on every call, so the two items of
        a sample are not a matched pair.
        """
        sketch_path = os.path.join(self.sketch_dir, self.sketch_files[idx])
        # For CycleGAN style training, get a random real image
        random_idx = random.randint(0, len(self.image_files) - 1)
        image_path = os.path.join(self.image_dir, self.image_files[random_idx])

        # convert() returns a new image, so the file can be closed right away
        with Image.open(sketch_path) as sketch_file:
            sketch = sketch_file.convert('RGB')
        with Image.open(image_path) as image_file:
            real_image = image_file.convert('RGB')

        if self.transform:
            sketch = self.transform(sketch)
            real_image = self.transform(real_image)

        return sketch, real_image

# DCGAN Generator with UNet-style skip connections
class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, ngf=64):
        """
        Encoder / residual bottleneck / decoder generator whose decoder
        stages receive the matching encoder activations (UNet-style skips).

        Args:
            in_channels: Number of input channels (3 for RGB)
            out_channels: Number of output channels (3 for RGB)
            ngf: Number of generator filters
        """
        super().__init__()

        def encoder_stage(c_in, c_out, normalize=True):
            # A stride-2 4x4 convolution halves the spatial resolution
            layers = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        def decoder_stage(c_in, c_out):
            # A stride-2 4x4 transposed convolution doubles the spatial resolution
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # Encoder (downsampling); sizes are for a 128x128 input
        self.down1 = encoder_stage(in_channels, ngf, normalize=False)  # 128 -> 64
        self.down2 = encoder_stage(ngf, ngf * 2)                       # 64 -> 32
        self.down3 = encoder_stage(ngf * 2, ngf * 4)                   # 32 -> 16
        self.down4 = encoder_stage(ngf * 4, ngf * 8)                   # 16 -> 8

        # Bottleneck: five residual blocks at the lowest resolution
        self.res_blocks = nn.Sequential(*[ResidualBlock(ngf * 8) for _ in range(5)])

        # Decoder (upsampling). From up2 onwards the input width is twice the
        # previous stage's output because forward() concatenates a skip tensor.
        self.up1 = decoder_stage(ngf * 8, ngf * 4)   # 8 -> 16
        self.up2 = decoder_stage(ngf * 8, ngf * 2)   # 16 -> 32
        self.up3 = decoder_stage(ngf * 4, ngf)       # 32 -> 64
        self.up4 = nn.Sequential(                    # 64 -> 128, output in [-1, 1]
            nn.ConvTranspose2d(ngf * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate a batch of images `x` and return the generated batch."""
        # Encoder activations are kept so the decoder can reuse them
        skip1 = self.down1(x)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        # Deepest encoder stage followed by the residual bottleneck
        bottleneck = self.res_blocks(self.down4(skip3))

        # Decoder: upsample, then concatenate the same-resolution encoder
        # activation along the channel axis
        merged = torch.cat([self.up1(bottleneck), skip3], dim=1)
        merged = torch.cat([self.up2(merged), skip2], dim=1)
        merged = torch.cat([self.up3(merged), skip1], dim=1)

        return self.up4(merged)

# Residual Block for Generator
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        """
        Two reflection-padded 3x3 convolutions with instance normalisation,
        wrapped in an additive shortcut.

        Args:
            channels: Number of input (and output) feature channels
        """
        super().__init__()

        def conv_stage():
            # Reflection padding by 1 keeps the spatial size under a 3x3 kernel
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3, padding=0),
                nn.InstanceNorm2d(channels),
            ]

        # No activation after the second stage: the sum with the input is
        # returned as-is by forward()
        self.block = nn.Sequential(*conv_stage(), nn.ReLU(inplace=True), *conv_stage())

    def forward(self, x):
        """Return `x` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

# Discriminator based on DCGAN
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, ndf=64):
        """
        Patch-level discriminator: instead of one score per image it emits a
        one-channel map of scores (7x7 for a 128x128 input).

        Args:
            in_channels: Number of input channels (3 for RGB)
            ndf: Number of discriminator filters
        """
        super().__init__()

        # Stage 1 (128 -> 64 for a 128x128 input): no normalisation
        layers = [
            nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Stages 2-4 (64 -> 32 -> 16 -> 8): each halves the resolution and
        # doubles the channel count
        width = ndf
        for _ in range(3):
            layers += [
                nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width *= 2

        # Output stage (8 -> 7): stride-1 convolution down to a single score channel
        layers.append(nn.Conv2d(width, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the map of raw (un-activated) patch scores for `x`."""
        patch_scores = self.model(x)
        return patch_scores

# Replay Buffer for CycleGAN's improved stability
class ReplayBuffer:
    """
    Pool of previously generated samples.

    `push_and_pop` returns a batch in which some of the new samples may have
    been swapped for older ones held in the pool.
    """

    def __init__(self, max_size=50):
        """
        Args:
            max_size: Maximum number of samples kept in the pool. 0 disables
                buffering, so `push_and_pop` returns its input unchanged.

        Raises:
            ValueError: If `max_size` is negative.
        """
        if max_size < 0:
            raise ValueError(f"max_size must be non-negative, got {max_size}")
        self.max_size = max_size
        self.buffer = []

    def push_and_pop(self, data):
        """
        Store the samples of `data` and return a batch of the same length.

        Args:
            data: Batch tensor; it is iterated along its first dimension.

        Returns:
            A tensor with the same number of samples as `data`. While the
            pool is filling up the samples pass through unchanged. Once it
            is full, each sample is, with probability 0.5, stored in place
            of a randomly chosen pooled sample, which is returned instead.
        """
        # An empty pool has nothing to sample from; without this guard the
        # replacement branch below would call random.randint(0, -1).
        if self.max_size == 0:
            return data

        result = []
        for element in data:
            element = torch.unsqueeze(element, 0)
            if len(self.buffer) < self.max_size:
                self.buffer.append(element)
                result.append(element)
            else:
                if random.random() < 0.5:
                    i = random.randint(0, self.max_size - 1)
                    # clone() so the returned sample owns its own memory
                    result.append(self.buffer[i].clone())
                    self.buffer[i] = element
                else:
                    result.append(element)
        return torch.cat(result)

def _set_requires_grad(networks, flag):
    """Turn gradient tracking on or off for every parameter of `networks`."""
    for net in networks:
        for param in net.parameters():
            param.requires_grad_(flag)


# Training function for the hybrid GAN model
def train_hybrid_gan(sketch_dir, real_dir, epochs, batch_size, lr, beta1, beta2, lambda_cycle,
                    lambda_identity, save_dir, sample_interval):
    """
    Train the hybrid GAN model

    Two generators (sketch -> real, real -> sketch) and two discriminators
    are trained with adversarial (MSE), cycle-consistency (L1) and optional
    identity (L1) losses. Sample grids and checkpoints are written below
    `save_dir`.

    Args:
        sketch_dir: Directory containing sketch images
        real_dir: Directory containing real images
        epochs: Number of training epochs (must be at least 1)
        batch_size: Batch size
        lr: Learning rate
        beta1, beta2: Adam optimizer parameters
        lambda_cycle: Weight for cycle consistency loss
        lambda_identity: Weight for identity loss (0 disables the term)
        save_dir: Directory to save models and samples
        sample_interval: Interval (in batches) for saving samples; 0 disables
            sample saving

    Raises:
        ValueError: If `epochs` is smaller than 1.
    """
    # Reject this before anything is created on disk: the learning-rate
    # schedule below divides by epochs * 0.5.
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'samples'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True)

    # Transform for input images - resize to 128x128 and normalize to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load dataset
    dataset = SketchImageDataset(sketch_dir, real_dir, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Initialize models
    G_sketch_to_real = Generator().to(device)
    G_real_to_sketch = Generator().to(device)
    D_real = Discriminator().to(device)
    D_sketch = Discriminator().to(device)
    discriminators = (D_real, D_sketch)

    # Initialize replay buffers for CycleGAN stability
    fake_real_buffer = ReplayBuffer()
    fake_sketch_buffer = ReplayBuffer()

    # Loss functions
    criterion_GAN = nn.MSELoss()  # For adversarial loss
    criterion_cycle = nn.L1Loss()  # For cycle consistency loss
    criterion_identity = nn.L1Loss()  # For identity loss

    # Optimizers (both generators share one optimizer)
    optimizer_G = optim.Adam(
        list(G_sketch_to_real.parameters()) + list(G_real_to_sketch.parameters()),
        lr=lr, betas=(beta1, beta2)
    )
    optimizer_D_real = optim.Adam(D_real.parameters(), lr=lr, betas=(beta1, beta2))
    optimizer_D_sketch = optim.Adam(D_sketch.parameters(), lr=lr, betas=(beta1, beta2))

    # Learning rate schedule shared by all three optimizers: the factor stays
    # at 1.0 for the first half of training, then falls linearly and would
    # reach 0 at epoch == epochs.
    decay_start = epochs * 0.5

    def linear_decay(epoch):
        return 1.0 - max(0, epoch - decay_start) / decay_start

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=linear_decay)
    lr_scheduler_D_real = torch.optim.lr_scheduler.LambdaLR(optimizer_D_real, lr_lambda=linear_decay)
    lr_scheduler_D_sketch = torch.optim.lr_scheduler.LambdaLR(optimizer_D_sketch, lr_lambda=linear_decay)

    # Training loop
    for epoch in range(epochs):
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, (sketches, real_images) in pbar:
            # Move data to device
            real_images = real_images.to(device)
            sketches = sketches.to(device)

            # ------------------
            # Train Generators
            # ------------------
            # The discriminators only score the generators here, so their
            # weight gradients are not needed; freezing them skips that work
            # while gradients still flow back to the generators.
            _set_requires_grad(discriminators, False)
            optimizer_G.zero_grad()

            # Identity loss
            if lambda_identity > 0:
                # G_sketch_to_real should generate the same real image when fed a real image
                identity_real = G_sketch_to_real(real_images)
                loss_identity_real = criterion_identity(identity_real, real_images) * lambda_identity

                # G_real_to_sketch should generate the same sketch when fed a sketch
                identity_sketch = G_real_to_sketch(sketches)
                loss_identity_sketch = criterion_identity(identity_sketch, sketches) * lambda_identity
            else:
                loss_identity_real = 0
                loss_identity_sketch = 0

            # GAN loss for G_sketch_to_real
            fake_real = G_sketch_to_real(sketches)
            pred_fake_real = D_real(fake_real)

            # Ground truths, shaped after the discriminator's own output so
            # they stay correct if its patch-map size ever changes
            valid = torch.ones_like(pred_fake_real)
            fake = torch.zeros_like(pred_fake_real)

            loss_GAN_sketch_to_real = criterion_GAN(pred_fake_real, valid)

            # GAN loss for G_real_to_sketch
            fake_sketch = G_real_to_sketch(real_images)
            loss_GAN_real_to_sketch = criterion_GAN(D_sketch(fake_sketch), valid)

            # Cycle consistency loss
            recovered_sketch = G_real_to_sketch(fake_real)
            loss_cycle_sketch = criterion_cycle(recovered_sketch, sketches) * lambda_cycle

            recovered_real = G_sketch_to_real(fake_sketch)
            loss_cycle_real = criterion_cycle(recovered_real, real_images) * lambda_cycle

            # Total generator loss
            loss_G = (loss_identity_real + loss_identity_sketch +
                     loss_GAN_sketch_to_real + loss_GAN_real_to_sketch +
                     loss_cycle_sketch + loss_cycle_real)

            loss_G.backward()
            optimizer_G.step()

            # Discriminators are trained next, so re-enable their gradients
            _set_requires_grad(discriminators, True)

            # -----------------------
            # Train Discriminator Real
            # -----------------------
            optimizer_D_real.zero_grad()

            # Real loss
            loss_real_real = criterion_GAN(D_real(real_images), valid)

            # Fake loss (with buffer); detach so no gradient reaches the generator
            fake_real_ = fake_real_buffer.push_and_pop(fake_real.detach())
            loss_fake_real = criterion_GAN(D_real(fake_real_), fake)

            # Total discriminator real loss
            loss_D_real = (loss_real_real + loss_fake_real) * 0.5
            loss_D_real.backward()
            optimizer_D_real.step()

            # -----------------------
            # Train Discriminator Sketch
            # -----------------------
            optimizer_D_sketch.zero_grad()

            # Real loss
            loss_real_sketch = criterion_GAN(D_sketch(sketches), valid)

            # Fake loss (with buffer)
            fake_sketch_ = fake_sketch_buffer.push_and_pop(fake_sketch.detach())
            loss_fake_sketch = criterion_GAN(D_sketch(fake_sketch_), fake)

            # Total discriminator sketch loss
            loss_D_sketch = (loss_real_sketch + loss_fake_sketch) * 0.5
            loss_D_sketch.backward()
            optimizer_D_sketch.step()

            # Update progress bar
            pbar.set_description(
                f"[Epoch {epoch+1}/{epochs}] "
                f"D_real: {loss_D_real.item():.4f}, D_sketch: {loss_D_sketch.item():.4f}, "
                f"G: {loss_G.item():.4f}, G_adv: {(loss_GAN_sketch_to_real + loss_GAN_real_to_sketch).item():.4f}, "
                f"G_cycle: {(loss_cycle_sketch + loss_cycle_real).item():.4f}"
            )

            # Save sample images (skipped entirely when sample_interval is 0)
            if sample_interval and i % sample_interval == 0:
                # Take up to 8 images for visualization; the last batch of an
                # epoch may hold fewer samples than the configured batch size
                n_samples = min(8, real_images.size(0))

                # Reuse the tensors produced during the generator step
                fake_real_samples = fake_real.detach()[:n_samples]
                fake_sketch_samples = fake_sketch.detach()[:n_samples]
                recovered_real_samples = recovered_real.detach()[:n_samples]
                recovered_sketch_samples = recovered_sketch.detach()[:n_samples]

                # Original inputs
                real_samples = real_images[:n_samples]
                sketch_samples = sketches[:n_samples]

                # Combine all images: with nrow=n_samples each group below
                # becomes one row of the grid
                all_samples = torch.cat([
                    sketch_samples, fake_real_samples, recovered_sketch_samples,
                    real_samples, fake_sketch_samples, recovered_real_samples
                ], dim=0)

                # Save grid image
                grid = torchvision.utils.make_grid(all_samples, nrow=n_samples, normalize=True)
                save_path = os.path.join(save_dir, 'samples', f'epoch_{epoch+1}_batch_{i}.png')
                save_image(grid, save_path)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_real.step()
        lr_scheduler_D_sketch.step()

        # Save models every 10 epochs and after the final epoch
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            torch.save({
                'G_sketch_to_real': G_sketch_to_real.state_dict(),
                'G_real_to_sketch': G_real_to_sketch.state_dict(),
                'D_real': D_real.state_dict(),
                'D_sketch': D_sketch.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D_real': optimizer_D_real.state_dict(),
                'optimizer_D_sketch': optimizer_D_sketch.state_dict(),
                'epoch': epoch
            }, os.path.join(save_dir, 'checkpoints', f'model_epoch_{epoch+1}.pth'))

# Function to generate images from sketches using a trained model
def generate_from_sketch(model_path, sketch_path, output_dir, num_samples=5):
    """
    Generate realistic images from sketches using a trained model

    Args:
        model_path: Path to trained model checkpoint (must contain the key
            'G_sketch_to_real')
        sketch_path: Path to sketch image or directory of sketches; in a
            directory only .png/.jpg/.jpeg files (any letter case) are used
        output_dir: Directory to save generated images
        num_samples: Number of samples to generate per sketch. The first uses
            the sketch as-is; sample i (i >= 1) adds Gaussian noise with
            standard deviation 0.02 * i to the normalized input.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load model
    # NOTE(security): torch.load unpickles the file, so only open checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    G_sketch_to_real = Generator().to(device)
    G_sketch_to_real.load_state_dict(checkpoint['G_sketch_to_real'])
    G_sketch_to_real.eval()

    # Transform for input images (same preprocessing as used for training)
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Check if sketch_path is a directory or a single file
    if os.path.isdir(sketch_path):
        # Sorted for a reproducible processing order; the extension check is
        # case-insensitive so files such as "IMG.PNG" are not skipped.
        sketch_files = [os.path.join(sketch_path, f) for f in sorted(os.listdir(sketch_path))
                       if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    else:
        sketch_files = [sketch_path]

    with torch.no_grad():
        for sketch_file in sketch_files:
            # Load and preprocess sketch; convert() returns a new image, so
            # the file handle can be released immediately
            with Image.open(sketch_file) as sketch_image:
                sketch = sketch_image.convert('RGB')
            sketch_tensor = transform(sketch).unsqueeze(0).to(device)

            base_name = os.path.splitext(os.path.basename(sketch_file))[0]

            # Generate multiple samples with small variations
            for i in range(num_samples):
                # Add small noise for variation if generating multiple samples
                if i > 0:
                    noise_level = 0.02 * i
                    noise = torch.randn_like(sketch_tensor) * noise_level
                    input_tensor = sketch_tensor + noise
                else:
                    input_tensor = sketch_tensor

                # Generate image
                fake_image = G_sketch_to_real(input_tensor)

                # Undo the (0.5, 0.5) normalization: [-1, 1] -> [0, 1]
                fake_image = (fake_image.squeeze(0) * 0.5 + 0.5).clamp(0, 1)

                # save_image performs the [0, 1] -> uint8 conversion and
                # writes the PNG
                output_path = os.path.join(output_dir, f"{base_name}_generated_{i+1}.png")
                save_image(fake_image, output_path)
                print(f"Generated image saved to {output_path}")

# Entry point: set the data locations and hyper-parameters, then run training
if __name__ == "__main__":
    sketch_dir = "/path/to/your/sketches"  # Replace with your actual path
    real_dir = "/kaggle/input/dataset/photos"  # Replace with your actual path
    save_dir = "./results"

    # Hyper-parameters collected in one place and expanded into keyword arguments
    training_config = dict(
        epochs=100,
        batch_size=8,
        lr=0.0002,
        beta1=0.5,
        beta2=0.999,
        lambda_cycle=10.0,
        lambda_identity=5.0,
        sample_interval=100,
    )

    train_hybrid_gan(
        sketch_dir=sketch_dir,
        real_dir=real_dir,
        save_dir=save_dir,
        **training_config
    )


Using device: cuda


[Epoch 1/100] D_real: 0.1696, D_sketch: 0.0902, G: 8.6813, G_adv: 0.6761, G_cycle: 5.4464: 100%|██████████| 24/24 [00:11<00:00,  2.15it/s] 
[Epoch 2/100] D_real: 0.2162, D_sketch: 0.0708, G: 5.3106, G_adv: 0.7249, G_cycle: 3.2034: 100%|██████████| 24/24 [00:11<00:00,  2.14it/s]
[Epoch 3/100] D_real: 0.1646, D_sketch: 0.1038, G: 5.0140, G_adv: 1.5511, G_cycle: 2.3711: 100%|██████████| 24/24 [00:11<00:00,  2.13it/s]
[Epoch 4/100] D_real: 0.1811, D_sketch: 0.1862, G: 3.8090, G_adv: 0.5850, G_cycle: 2.1946: 100%|██████████| 24/24 [00:11<00:00,  2.11it/s]
[Epoch 5/100] D_real: 0.2291, D_sketch: 0.0729, G: 4.1798, G_adv: 1.0746, G_cycle: 2.1504: 100%|██████████| 24/24 [00:11<00:00,  2.09it/s]
[Epoch 6/100] D_real: 0.3239, D_sketch: 0.0988, G: 5.2127, G_adv: 1.7600, G_cycle: 2.5380: 100%|██████████| 24/24 [00:11<00:00,  2.07it/s]
[Epoch 7/100] D_real: 0.1986, D_sketch: 0.1442, G: 4.2291, G_adv: 1.1091, G_cycle: 2.2481: 100%|██████████| 24/24 [00:11<00:00,  2.05it/s]
[Epoch 8/100] D_real: 0.20

In [9]:
!find . -type f | sort



./results/checkpoints/model_epoch_100.pth
./results/checkpoints/model_epoch_10.pth
./results/checkpoints/model_epoch_20.pth
./results/checkpoints/model_epoch_30.pth
./results/checkpoints/model_epoch_40.pth
./results/checkpoints/model_epoch_50.pth
./results/checkpoints/model_epoch_60.pth
./results/checkpoints/model_epoch_70.pth
./results/checkpoints/model_epoch_80.pth
./results/checkpoints/model_epoch_90.pth
./results/samples/epoch_100_batch_0.png
./results/samples/epoch_10_batch_0.png
./results/samples/epoch_11_batch_0.png
./results/samples/epoch_12_batch_0.png
./results/samples/epoch_13_batch_0.png
./results/samples/epoch_14_batch_0.png
./results/samples/epoch_15_batch_0.png
./results/samples/epoch_16_batch_0.png
./results/samples/epoch_17_batch_0.png
./results/samples/epoch_18_batch_0.png
./results/samples/epoch_19_batch_0.png
./results/samples/epoch_1_batch_0.png
./results/samples/epoch_20_batch_0.png
./results/samples/epoch_21_batch_0.png
./results/samples/epoch_22_batch_0.png
./re

In [11]:
!zip model_files.zip ./results/checkpoints/model_epoch_100.pth


  adding: results/checkpoints/model_epoch_100.pth (deflated 10%)


In [14]:
!ls -lh /kaggle/working/model_files.zip


-rw-r--r-- 1 root root 673M Mar 10 14:06 /kaggle/working/model_files.zip
