In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
from torch.cuda.amp import autocast, GradScaler

In [3]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# === Image Transform ===
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# === Helper function to check if a file is an image ===
def is_image_file(filename):
    return any(filename.lower().endswith(ext) for ext in ['.png', '.jpg', '.jpeg', '.bmp'])

# === Dataset Class for Real and Cartoon Images ===
class CartoonDataset(Dataset):
    def __init__(self, real_dir, cartoon_dir, transform=None, shuffle=True):
        self.real_images = [os.path.join(real_dir, img) for img in os.listdir(real_dir) if is_image_file(img)]
        self.cartoon_images = [os.path.join(cartoon_dir, img) for img in os.listdir(cartoon_dir) if is_image_file(img)]
        self.transform = transform

        # Shuffle the lists to ensure random pairing
        if shuffle:
            random.shuffle(self.real_images)
            random.shuffle(self.cartoon_images)
        
        # Make sure both lists are equal in size (minimum length)
        self.length = min(len(self.real_images), len(self.cartoon_images))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        real_image = Image.open(self.real_images[idx]).convert('RGB')
        cartoon_image = Image.open(self.cartoon_images[idx]).convert('RGB')
        
        if self.transform:
            real_image = self.transform(real_image)
            cartoon_image = self.transform(cartoon_image)
        
        return real_image, cartoon_image


In [4]:
import torch
from torch.utils.data import DataLoader, random_split

# Define dataset
real_dir = "archive (16)/PytorchCtoonGAN/dataset/train_photo"
cartoon_dir = "archive (16)/PytorchCtoonGAN/dataset/Hayao/style"
dataset = CartoonDataset(real_dir, cartoon_dir, transform=transform)

# Total size of dataset
total_size = len(dataset)

# Define split proportions
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size  # Remaining goes to test

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision import transforms

# === Hinge Loss for Discriminator ===
def discriminator_hinge_loss(real_pred, fake_pred):
    real_loss = torch.mean(F.relu(1.0 - real_pred))
    fake_loss = torch.mean(F.relu(1.0 + fake_pred))
    return real_loss + fake_loss

# === Hinge Loss for Generator ===
def generator_hinge_loss(fake_pred):
    return -torch.mean(fake_pred)

# === Perceptual Content Loss using VGG19 ===
class VGGContentLoss(nn.Module):
    def __init__(self, device):
        super(VGGContentLoss, self).__init__()
        vgg = vgg19(pretrained=True).features[:21].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    def forward(self, input, target):
        input = self.norm(input)
        target = self.norm(target)
        return F.l1_loss(self.vgg(input), self.vgg(target))

# === Optional Style Loss ===
def gram_matrix(features):
    (b, c, h, w) = features.size()
    features = features.view(b, c, h * w)
    G = torch.bmm(features, features.transpose(1, 2))
    return G / (c * h * w)

class StyleLoss(nn.Module):
    def __init__(self, device):
        super(StyleLoss, self).__init__()
        vgg = vgg19(pretrained=True).features[:21].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    def forward(self, input, target):
        input = self.norm(input)
        target = self.norm(target)
        input_features = self.vgg(input)
        target_features = self.vgg(target)
        return F.l1_loss(gram_matrix(input_features), gram_matrix(target_features))


In [6]:
import torch.nn as nn
import torch.nn.utils.spectral_norm as spectral_norm

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            spectral_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))
        )

    def forward(self, x):
        return self.model(x)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, kernel_size=3),
            nn.InstanceNorm2d(in_channels),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Initial convolution block
        self.initial = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, kernel_size=7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Downsampling
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(6)])

        # Upsampling
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Final output block
        self.final = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, kernel_size=7),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.down1(x)
        x = self.down2(x)
        x = self.res_blocks(x)
        x = self.up1(x)
        x = self.up2(x)
        return self.final(x)


In [12]:
import torch

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs = 50
n_critic = 2
best_val_g_loss = float('inf')  # For saving best model based on generator validation loss
patience = 5                    # Number of epochs to wait for improvement
epochs_without_improvement = 0 # Early stopping counter

# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss functions
content_loss_fn = VGGContentLoss(device)
style_loss_fn = StyleLoss(device)
use_style_loss = True  # toggle this to False if not needed

# Optimizers
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=5e-5, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))

# Learning rate schedulers
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.5)
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.5)

for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    correct = 0
    total = 0

    for i, (real, cartoon) in enumerate(train_loader):
        real, cartoon = real.to(device), cartoon.to(device)

        # === Train Discriminator ===
        for _ in range(n_critic):
            optimizer_D.zero_grad()
            real_pred = discriminator(cartoon)
            fake_cartoon = generator(real).detach()
            fake_pred = discriminator(fake_cartoon)
            d_loss = discriminator_hinge_loss(real_pred, fake_pred)
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
            optimizer_D.step()

        # === Approximate D Accuracy ===
        real_correct = (real_pred > 0.6).sum().item()
        fake_correct = (fake_pred < -0.6).sum().item()
        correct += real_correct + fake_correct
        total += real_pred.numel() + fake_pred.numel()

        # === Train Generator ===
        optimizer_G.zero_grad()
        fake_cartoon = generator(real)
        fake_pred = discriminator(fake_cartoon)

        g_loss = generator_hinge_loss(fake_pred)
        g_loss += content_loss_fn(fake_cartoon, cartoon)
        if use_style_loss:
            g_loss += 50.0 * style_loss_fn(fake_cartoon, cartoon)

        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
        optimizer_G.step()

    train_accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}, D Accuracy: {train_accuracy:.2f}%")

    # === Validation ===
    generator.eval()
    val_g_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for real, cartoon in val_loader:
            real, cartoon = real.to(device), cartoon.to(device)
            fake_cartoon = generator(real)
            fake_pred = discriminator(fake_cartoon)
            val_real_pred = discriminator(cartoon)

            loss = generator_hinge_loss(fake_pred) + content_loss_fn(fake_cartoon, cartoon)
            if use_style_loss:
                loss += 50.0 * style_loss_fn(fake_cartoon, cartoon)

            val_g_loss += loss

            val_real_correct = (val_real_pred > 0.6).sum().item()
            val_fake_correct = (fake_pred < -0.6).sum().item()
            val_correct += val_real_correct + val_fake_correct
            val_total += val_real_pred.numel() + fake_pred.numel()

    val_g_loss /= len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    print(f"Validation G Loss: {val_g_loss.item():.4f}, D Accuracy: {val_accuracy:.2f}%")

    # ✅ Save best model based on lowest generator validation loss
    if val_g_loss.item() < best_val_g_loss:
        best_val_g_loss = val_g_loss.item()
        epochs_without_improvement = 0
        torch.save(generator.state_dict(), "best_Generator2.pth")
        torch.save(discriminator.state_dict(), "best_discriminator2.pth")
        print(f"✅ Saved Best Model at Epoch {epoch+1} with Val G Loss: {val_g_loss.item():.4f}")
    else:
        epochs_without_improvement += 1
        print(f"⏳ No improvement for {epochs_without_improvement} epoch(s).")

    # 🛑 Early Stopping
    if epochs_without_improvement >= patience:
        print(f"🛑 Early stopping triggered after {epoch+1} epochs.")
        break

    scheduler_D.step()
    scheduler_G.step()

print("🎉 Training Complete!")


Epoch [1/50], D Loss: 0.9600, G Loss: 6.3195, D Accuracy: 37.33%
Validation G Loss: 5.0150, D Accuracy: 33.35%
✅ Saved Best Model at Epoch 1 with Val G Loss: 5.0150
Epoch [2/50], D Loss: 0.5217, G Loss: 6.0691, D Accuracy: 52.92%
Validation G Loss: 4.9940, D Accuracy: 55.72%
✅ Saved Best Model at Epoch 2 with Val G Loss: 4.9940
Epoch [3/50], D Loss: 0.7475, G Loss: 5.7772, D Accuracy: 67.17%
Validation G Loss: 5.6219, D Accuracy: 69.50%
⏳ No improvement for 1 epoch(s).
Epoch [4/50], D Loss: 0.7175, G Loss: 6.4150, D Accuracy: 65.07%
Validation G Loss: 4.5260, D Accuracy: 25.81%
✅ Saved Best Model at Epoch 4 with Val G Loss: 4.5260
Epoch [5/50], D Loss: 0.3328, G Loss: 6.7222, D Accuracy: 68.25%
Validation G Loss: 5.5628, D Accuracy: 60.32%
⏳ No improvement for 1 epoch(s).
Epoch [6/50], D Loss: 0.2809, G Loss: 5.7846, D Accuracy: 60.14%
Validation G Loss: 5.7557, D Accuracy: 66.88%
⏳ No improvement for 2 epoch(s).
Epoch [7/50], D Loss: 0.5260, G Loss: 5.9098, D Accuracy: 64.11%
Validati

In [17]:
import torch
import os
from torchvision.utils import save_image

# Load the generator model
generator = Generator().to(device)
generator.load_state_dict(torch.load("best_Generator2.pth"))
generator.eval()

# Directory to save outputs
output_dir = "test_outputs1"
os.makedirs(output_dir, exist_ok=True)

# Testing loop
with torch.no_grad():
    for idx, (real, _) in enumerate(test_loader):
        real = real.to(device)
        fake_cartoon = generator(real)

        # Denormalize output before saving: [-1, 1] -> [0, 1]
        fake_cartoon = (fake_cartoon + 1) / 2
        real = (real + 1) / 2

        # Save the real and cartoonized images side by side
        for i in range(real.size(0)):
            save_image(torch.cat([real[i], fake_cartoon[i]], dim=2),
                       os.path.join(output_dir, f"sample_{idx * test_loader.batch_size + i}.png"))

print("✅ Cartoonized test images saved in 'test_outputs1' folder.")


✅ Cartoonized test images saved in 'test_outputs1' folder.
