In [None]:
import numpy as np 
import pandas as pd 
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from PIL import Image
import matplotlib.pyplot as plt



# 1. Config — everything a reader might tune lives here.
SEED = 42
# Seed every RNG the notebook touches (weight init, DataLoader shuffling, dropout)
# so a Restart & Run All reproduces the same run on the same hardware.
torch.manual_seed(SEED)   # seeds the CPU and all CUDA generators
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ROOT = '/kaggle/input/comic-faces-paired-synthetic-v2/face2comics_v2.0.0_by_Sxela/face2comics_v2.0.0_by_Sxela'
FACES_DIR = os.path.join(ROOT, 'faces')     # source domain: real face photos
COMICS_DIR = os.path.join(ROOT, 'comics')   # target domain: comic renderings
BATCH_SIZE = 128
NUM_WORKERS = 0
IMAGE_SIZE = 256                            # images are resized to IMAGE_SIZE x IMAGE_SIZE

# 2. Dataset
class Paired10Dataset(Dataset):
    """Paired (real face, comic) image dataset for image-to-image translation.

    Pairs are matched by file name: ``real_dir/X.jpg`` is paired with
    ``comic_dir/X.jpg``. A file present in only one directory is skipped, so a
    missing or extra image can never shift every later pair out of alignment
    (index-based pairing of two independently globbed lists would).

    Args:
        real_dir: directory of source-domain ``.jpg`` images.
        comic_dir: directory of target-domain ``.jpg`` images.
        transform: callable applied independently to each RGB PIL image.
        max_samples: keep at most this many pairs, in sorted file-path order;
            ``None`` keeps every pair. Defaults to 8000.
    """

    def __init__(self, real_dir, comic_dir, transform, max_samples=8000):
        comic_by_name = {
            os.path.basename(path): path
            for path in glob.glob(os.path.join(comic_dir, '*.jpg'))
        }
        pairs = []
        for real_path in sorted(glob.glob(os.path.join(real_dir, '*.jpg'))):
            comic_path = comic_by_name.get(os.path.basename(real_path))
            if comic_path is not None:
                pairs.append((real_path, comic_path))
        if max_samples is not None:
            pairs = pairs[:max_samples]

        # Kept as two parallel lists (same attribute names as before) so that
        # real_paths[i] and comic_paths[i] always refer to the same file name.
        self.real_paths = [real for real, _ in pairs]
        self.comic_paths = [comic for _, comic in pairs]
        self.transform = transform

    def __len__(self):
        return len(self.real_paths)

    def __getitem__(self, idx):
        """Return ``(transform(real), transform(comic))`` for pair ``idx``."""
        # convert() returns a new image, so it stays valid after the file is closed.
        with Image.open(self.real_paths[idx]) as real_img:
            real = real_img.convert('RGB')
        with Image.open(self.comic_paths[idx]) as comic_img:
            comic = comic_img.convert('RGB')
        return self.transform(real), self.transform(comic)

# The same deterministic transform is applied to input and target. ToTensor gives
# [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

dataset = Paired10Dataset(
    real_dir=FACES_DIR,
    comic_dir=COMICS_DIR,
    transform=transform,
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory speeds up .to(device) copies; it is only useful with a GPU.
    pin_memory=torch.cuda.is_available(),
)
print(f"Loaded {len(dataset)} samples, {len(loader)} batches.")

In [None]:

def conv_block(in_c, out_c, norm=True):
    """Encoder unit: 4x4 stride-2 conv (no bias) -> optional BatchNorm -> LeakyReLU(0.2).

    With padding 1 the spatial size is halved for even inputs (e.g. 8x8 -> 4x4).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        norm: when True, insert ``BatchNorm2d(out_c)`` between conv and activation.

    Returns:
        ``nn.Sequential`` implementing the unit.
    """
    conv = nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False)
    maybe_bn = (nn.BatchNorm2d(out_c),) if norm else ()
    return nn.Sequential(conv, *maybe_bn, nn.LeakyReLU(negative_slope=0.2, inplace=True))

def deconv_block(in_c, out_c, dropout=False):
    """Decoder unit: 4x4 stride-2 transposed conv (no bias) -> BatchNorm -> ReLU [-> Dropout(0.5)].

    With padding 1 the spatial size is doubled (e.g. 5x5 -> 10x10).

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        dropout: when True, append ``Dropout(0.5)`` after the activation.

    Returns:
        ``nn.Sequential`` implementing the unit.
    """
    modules = (
        nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )
    if dropout:
        modules += (nn.Dropout(0.5),)
    return nn.Sequential(*modules)

class UNetGenerator(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages (d1..d8), 7 decoder stages (u1..u7) and a final layer.

    Every decoder stage after u1, and the final layer, receives the previous
    decoder output concatenated along channels with the mirror-image encoder
    activation (u2 <- d7, ..., final <- d1). Input has 3 channels; output has
    3 channels squashed by Tanh into [-1, 1].

    NOTE(review): with the 256x256 inputs used in this notebook, d8's output is
    1x1, so its BatchNorm normalises across the batch only. In training mode a
    batch of size 1 would fail there; confirm batches always hold > 1 sample.
    """

    def __init__(self):
        super().__init__()
        # Registration order (d1..d8, u1..u7, final) is kept fixed so parameter
        # order and state_dict keys match saved checkpoints.
        encoder_spec = [
            (3, 64, False),      # (in_channels, out_channels, use_batchnorm)
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
            (512, 512, True),
        ]
        for i, (c_in, c_out, use_bn) in enumerate(encoder_spec, start=1):
            setattr(self, f'd{i}', conv_block(c_in, c_out, norm=use_bn))

        # in_channels double from u2 on because of the skip concatenation.
        decoder_spec = [
            (512, 512, True),    # (in_channels, out_channels, use_dropout)
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        for i, (c_in, c_out, use_drop) in enumerate(decoder_spec, start=1):
            setattr(self, f'u{i}', deconv_block(c_in, c_out, dropout=use_drop))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = [self.d1, self.d2, self.d3, self.d4,
                    self.d5, self.d6, self.d7, self.d8]
        decoders = [self.u2, self.u3, self.u4, self.u5,
                    self.u6, self.u7, self.final]

        # Encoder: remember every activation; the deepest one ends up last.
        skips = []
        h = x
        for encode in encoders:
            h = encode(h)
            skips.append(h)

        # Decoder: u1 consumes the bottleneck, each later stage pairs with the
        # next-shallower encoder activation popped from the stack.
        h = self.u1(skips.pop())
        for decode in decoders:
            h = decode(torch.cat([h, skips.pop()], dim=1))
        return h


class PatchDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator with a 70x70 receptive field.

    Scores overlapping patches of an (input image, candidate output) pair and
    returns one raw logit per patch. There is no sigmoid, so it is meant to be
    used with ``BCEWithLogitsLoss``. For 256x256 inputs the output is
    ``(N, 1, 30, 30)``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            conv_block(6, 64, norm=False),   # 256 -> 128
            conv_block(64, 128),             # 128 -> 64
            conv_block(128, 256),            # 64  -> 32
            # Stride 1 here (conv_block is fixed at stride 2). This gives the
            # 30x30 patch map and 70x70 receptive field. With stride 2 the map
            # shrinks to 15x15 and the receptive field grows to 94x94.
            # Same sub-module layout as conv_block, so state_dict keys are unchanged.
            nn.Sequential(
                nn.Conv2d(256, 512, 4, 1, 1, bias=False),
                nn.BatchNorm2d(512),
                nn.LeakyReLU(0.2, inplace=True),
            ),                               # 32  -> 31
            nn.Conv2d(512, 1, 4, 1, 1),      # 31  -> 30: one logit per patch
        )

    def forward(self, x, y):
        """Score the pair ``(x, y)``; they are concatenated along channels (3 + 3 = 6)."""
        return self.model(torch.cat([x, y], dim=1))








# 3. Models, optimisers, losses (`device` comes from the config cell)
G = UNetGenerator().to(device)
D = PatchDiscriminator().to(device)
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
adv_loss = nn.BCEWithLogitsLoss()   # D returns raw logits
l1_loss = nn.L1Loss()
LAMBDA_L1 = 100                     # weight of the L1 reconstruction term in the G loss
EPOCHS = 120
SAVE_EVERY = 40                     # checkpoint interval, in epochs

# 4. Training loop
history = {'G': [], 'D': []}        # mean loss per epoch
for epoch in range(1, EPOCHS + 1):
    # Set training mode explicitly each epoch. The visualisation cell calls
    # G.eval(), and re-running this cell afterwards would otherwise train with
    # dropout disabled and frozen BatchNorm statistics.
    G.train()
    D.train()
    g_running, d_running = 0.0, 0.0

    for real_A, real_B in loader:
        real_A, real_B = real_A.to(device), real_B.to(device)
        fake_B = G(real_A)

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        # fake_B (the generator output) is detached, not D's prediction. D
        # therefore still gets gradients from the fake term, while no
        # gradient flows back into G.
        opt_D.zero_grad()
        pred_real = D(real_A, real_B)
        pred_fake_d = D(real_A, fake_B.detach())
        valid = torch.ones_like(pred_real)   # target for real pairs
        fake = torch.zeros_like(pred_real)   # target for generated pairs
        loss_D = 0.5 * (adv_loss(pred_real, valid) + adv_loss(pred_fake_d, fake))
        loss_D.backward()
        opt_D.step()

        # Generator step: make the updated D call fake_B real, and stay close
        # to the target in L1. This backward also fills D's .grad. Those
        # values are cleared by opt_D.zero_grad() before D's next update.
        opt_G.zero_grad()
        pred_fake = D(real_A, fake_B)
        loss_G = adv_loss(pred_fake, valid) + LAMBDA_L1 * l1_loss(fake_B, real_B)
        loss_G.backward()
        opt_G.step()

        g_running += loss_G.item()
        d_running += loss_D.item()

    history['G'].append(g_running / len(loader))
    history['D'].append(d_running / len(loader))
    print(f"Epoch {epoch} | G: {history['G'][-1]:.4f} | D: {history['D'][-1]:.4f}")

    # Checkpoint model weights only (optimizer state is not saved).
    if epoch % SAVE_EVERY == 0:
        torch.save(G.state_dict(), f"generator_{epoch}.pth")
        torch.save(D.state_dict(), f"discriminator_{epoch}.pth")

# 5. Loss curves
epochs_axis = range(1, len(history['G']) + 1)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(epochs_axis, history['G'], label='Generator')
ax.plot(epochs_axis, history['D'], label='Discriminator')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
ax.legend()
plt.show()




In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Show a few samples as rows of (input face, generated comic, ground-truth comic).
N_SHOW = 4   # number of rows in the grid

# Evaluation mode (dropout off, BatchNorm running stats); restored afterwards so
# later cells are not affected.
was_training = G.training
G.eval()

real_A, real_B = next(iter(loader))
# Only the displayed samples need a forward pass.
real_A = real_A[:N_SHOW].to(device)
real_B = real_B[:N_SHOW].to(device)

with torch.no_grad():
    fake_B = G(real_A)
G.train(was_training)

# Interleave per sample (A0, G0, B0, A1, ...) so that with nrow=3 each grid row
# is one input/output/target triplet.
triplets = torch.stack([real_A, fake_B, real_B], dim=1).flatten(0, 1)
# Inputs/targets are normalised to [-1, 1] and G ends in Tanh. Map that fixed
# range to [0, 1] rather than stretching by the grid's own min/max.
grid = make_grid(triplets, nrow=3, normalize=True, value_range=(-1, 1))

# CHW tensor on the GPU/CPU -> HWC NumPy array for imshow.
grid_np = grid.cpu().permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize=(6, 2 * N_SHOW))
ax.imshow(grid_np)
ax.axis('off')
ax.set_title('Real → Generated → Comic GT')
plt.show()
