In [1]:
import os, json, time
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from glob import glob
import pandas as pd
from sklearn.metrics import confusion_matrix

In [6]:

# Read the experiment settings from the JSON file shipped with exp4.
with open("../experiments/exp4/config.json") as f:
    config = json.load(f)

# Image folders: photo / cartoon pairs for training and for validation.
original_dir = config["original_dir"]
cartoon_dir = config["cartoon_dir"]
val_original_dir = config["val_original_dir"]
val_cartoon_dir = config["val_cartoon_dir"]

# Batch size and epoch budget.
batch_size = config["batch_size"]
total_epochs = config["total_epochs"]

# Weight of the L1 term in the generator loss, and the Adam step size.
lambda_l1 = config["lambda_l1"]
learning_rate = config["learning_rate"]

# Optional cap on the number of image pairs loaded per dataset.
limit = config["limit"]

# Adam momentum coefficients.
beta1 = config["beta1"]
beta2 = config["beta2"]

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [9]:
# Dataset
class CartoonDataset(Dataset):
    """Dataset of (original, cartoon) image pairs read from two folders.

    The ``*.png`` files of each folder are sorted by path and paired by
    position, so the i-th original is matched with the i-th cartoon.

    Args:
        original_dir: Folder containing the original ``*.png`` images.
        cartoon_dir: Folder containing the cartoon ``*.png`` images.
        transform: Optional callable applied to both images of a pair.
        limit: Optional maximum number of pairs to keep; a falsy value
            (``None`` or ``0``) keeps everything.
    """

    def __init__(self, original_dir, cartoon_dir, transform=None, limit=None):
        self.original_paths = sorted(glob(os.path.join(original_dir, "*.png")))
        self.cartoon_paths = sorted(glob(os.path.join(cartoon_dir, "*.png")))
        if limit:
            self.original_paths = self.original_paths[:limit]
            self.cartoon_paths = self.cartoon_paths[:limit]

        # Pairs are matched by index, so only the common prefix is usable.
        # Without this, a shorter cartoon folder raises IndexError in
        # __getitem__ part-way through an epoch.
        n_originals = len(self.original_paths)
        n_cartoons = len(self.cartoon_paths)
        if n_originals != n_cartoons:
            import warnings

            warnings.warn(
                f"CartoonDataset: {n_originals} original images but "
                f"{n_cartoons} cartoon images; using the first "
                f"{min(n_originals, n_cartoons)} of each."
            )
            n_pairs = min(n_originals, n_cartoons)
            self.original_paths = self.original_paths[:n_pairs]
            self.cartoon_paths = self.cartoon_paths[:n_pairs]

        self.transform = transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.original_paths)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB and return ``(original, cartoon)``.

        Both images go through ``self.transform`` when one was supplied.
        """
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the with-block exits.
        with Image.open(self.original_paths[idx]) as img:
            real = img.convert("RGB")
        with Image.open(self.cartoon_paths[idx]) as img:
            cartoon = img.convert("RGB")
        if self.transform:
            real = self.transform(real)
            cartoon = self.transform(cartoon)
        return real, cartoon

In [10]:
# Preprocessing shared by both images of a pair: resize to 256x256, convert
# to a tensor, then normalise each of the three channels with mean 0.5 and
# std 0.5.
_channel_mean = [0.5, 0.5, 0.5]
_channel_std = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(_channel_mean, _channel_std),
    ]
)

In [11]:
# Training pairs are shuffled; the validation loader keeps a fixed order.
# Both datasets share the same transform and the same `limit`.
train_loader = DataLoader(
    CartoonDataset(original_dir, cartoon_dir, transform, limit),
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    CartoonDataset(val_original_dir, val_cartoon_dir, transform, limit),
    batch_size=batch_size,
    shuffle=False,
)

In [12]:
class Generator(nn.Module):
    """Convolutional encoder / bottleneck / decoder image-to-image network.

    Four stride-2 convolutions (three in ``encoder``, one in ``middle``) halve
    the spatial size each time; four stride-2 transposed convolutions in
    ``decoder`` double it back. The final activation is ``Tanh``, and the
    network maps 3 input channels to 3 output channels.
    """

    def __init__(self):
        super().__init__()

        def down(in_ch, out_ch, norm=True):
            # Stride-2 conv (kernel 4, padding 1), optional batch norm, ReLU.
            layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1)]
            if norm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU())
            return layers

        def up(in_ch, out_ch):
            # Stride-2 transposed conv (kernel 4, padding 1), batch norm, ReLU.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        # Layers are created in the same order and at the same Sequential
        # indices as before, so state_dict keys remain unchanged.
        self.encoder = nn.Sequential(
            *down(3, 64, norm=False),
            *down(64, 128),
            *down(128, 256),
        )
        self.middle = nn.Sequential(*down(256, 512))
        self.decoder = nn.Sequential(
            *up(512, 256),
            *up(256, 128),
            *up(128, 64),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through encoder, bottleneck and decoder in turn."""
        features = self.encoder(x)
        bottleneck = self.middle(features)
        return self.decoder(bottleneck)

In [13]:

class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a map of sigmoid scores.

    Three stride-2 convolutions with LeakyReLU(0.2) are followed by a
    stride-1 convolution down to a single channel and a ``Sigmoid``. There is
    no flattening, so the output keeps two spatial dimensions.
    """

    def __init__(self):
        super().__init__()
        slope = 0.2

        # First stage has no batch norm.
        layers = [nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(slope)]
        # Two further downsampling stages with batch norm.
        for in_ch, out_ch in ((64, 128), (128, 256)):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(slope))
        # Stride-1 projection to one channel, squashed into (0, 1).
        layers.append(nn.Conv2d(256, 1, 4, 1, 1))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.net(x)
        return scores

In [14]:
# Setup: networks first (generator, then discriminator), moved to `device`.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# One Adam optimiser per network, sharing the learning rate and betas.
_adam_betas = (beta1, beta2)
g_optimizer = torch.optim.Adam(
    generator.parameters(), lr=learning_rate, betas=_adam_betas
)
d_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=_adam_betas
)

# Binary cross-entropy on the discriminator's sigmoid outputs, and L1
# distance between generated and target images.
adversarial_loss = nn.BCELoss()
content_loss = nn.L1Loss()

In [20]:

# Per-epoch training history.
g_losses = []
d_losses = []
d_accuracies = []

# Per-epoch validation history.
val_g_losses = []
val_d_losses = []
val_accuracies = []

# Index of the first epoch the training cell will run.
start_epoch = 0

In [27]:
# Training
#
# Each execution of this cell trains EPOCHS_PER_RUN epochs starting at
# `start_epoch`. After every epoch it evaluates on `val_loader`, appends to
# the history lists, rewrites losses.json, and every SAVE_EVERY epochs writes
# a discriminator confusion matrix and a checkpoint.
# NOTE(review): `total_epochs` is read from the config but is not consulted
# here; confirm whether it should bound this loop.
EPOCHS_PER_RUN = 10
SAVE_EVERY = 10

# Create the output folders up front so a missing directory fails (or is
# fixed) before any training time has been spent.
for _out_dir in (
    "../experiments/exp4/logs",
    "../experiments/exp4/losses",
    "../experiments/exp4/checkpoints",
):
    os.makedirs(_out_dir, exist_ok=True)

# The epoch averages divide by the number of batches; fail with a clear
# message instead of a ZeroDivisionError at the end of the first epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each contain at least one batch")

for epoch in range(start_epoch, start_epoch + EPOCHS_PER_RUN):
    start_time = time.time()

    # ---------------------------------------------------------------- train
    generator.train()
    discriminator.train()
    g_epoch_loss = 0
    d_epoch_loss = 0
    accuracy_accum = 0
    batch_count = 0

    for real, cartoon in train_loader:
        real, cartoon = real.to(device), cartoon.to(device)

        # Discriminator step: target cartoons are labelled 1, generated
        # images 0. The generated batch is detached so that this backward
        # pass does not reach the generator.
        fake = generator(real).detach()
        d_real = discriminator(cartoon)
        d_fake = discriminator(fake)
        real_labels = torch.ones_like(d_real)
        fake_labels = torch.zeros_like(d_fake)
        d_loss = 0.5 * (adversarial_loss(d_real, real_labels) + adversarial_loss(d_fake, fake_labels))
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Discriminator accuracy: one decision per image, taken from the mean
        # of its score map (computed from the pre-step outputs above).
        pred_real = (d_real.mean(dim=[1, 2, 3]) > 0.5).float()
        pred_fake = (d_fake.mean(dim=[1, 2, 3]) < 0.5).float()
        correct = pred_real.sum().item() + pred_fake.sum().item()
        accuracy_accum += correct / (pred_real.numel() + pred_fake.numel())
        batch_count += 1

        # Generator step: adversarial term against the "real" labels plus the
        # L1 distance to the target cartoon, weighted by lambda_l1.
        fake = generator(real)
        g_adv = adversarial_loss(discriminator(fake), real_labels)
        g_l1 = content_loss(fake, cartoon)
        g_loss = g_adv + lambda_l1 * g_l1
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        g_epoch_loss += g_loss.item()
        d_epoch_loss += d_loss.item()

    train_g_loss = g_epoch_loss / len(train_loader)
    train_d_loss = d_epoch_loss / len(train_loader)
    train_acc = accuracy_accum / batch_count

    # ----------------------------------------------------------- validation
    generator.eval()
    discriminator.eval()
    save_artifacts = (epoch + 1) % SAVE_EVERY == 0
    val_correct = 0
    val_total = 0
    val_g_loss_total = 0
    val_d_loss_total = 0
    # Per-image labels / predictions for the confusion matrix (1 = real).
    # Gathered in this same pass, so no second sweep over val_loader is needed.
    all_labels, all_preds = [], []
    with torch.no_grad():
        for real, cartoon in val_loader:
            real, cartoon = real.to(device), cartoon.to(device)

            fake = generator(real)
            d_real = discriminator(cartoon)
            d_fake = discriminator(fake)

            real_score = d_real.mean(dim=[1, 2, 3])
            fake_score = d_fake.mean(dim=[1, 2, 3])
            val_correct += (real_score > 0.5).sum().item() + (fake_score < 0.5).sum().item()
            val_total += real_score.numel() + fake_score.numel()

            if save_artifacts:
                all_preds.extend((real_score > 0.5).int().cpu().tolist())
                all_preds.extend((fake_score > 0.5).int().cpu().tolist())
                all_labels.extend([1] * real_score.numel())
                all_labels.extend([0] * fake_score.numel())

            # Validation losses. Labels are rebuilt per batch because the
            # last batch may be smaller than the others.
            real_labels = torch.ones_like(d_real)
            fake_labels = torch.zeros_like(d_fake)
            d_loss = 0.5 * (adversarial_loss(d_real, real_labels) + adversarial_loss(d_fake, fake_labels))
            val_d_loss_total += d_loss.item()

            # d_fake already holds discriminator(fake); in eval mode under
            # no_grad a second forward pass would return the same values.
            g_adv = adversarial_loss(d_fake, real_labels)
            g_l1 = content_loss(fake, cartoon)
            g_loss = g_adv + lambda_l1 * g_l1
            val_g_loss_total += g_loss.item()

    val_acc = val_correct / val_total
    val_g_loss = val_g_loss_total / len(val_loader)
    val_d_loss = val_d_loss_total / len(val_loader)

    # Commit all six histories together, and advance start_epoch with them,
    # so an interrupt cannot leave the lists with different lengths and a
    # re-run of this cell continues from the next epoch number.
    g_losses.append(train_g_loss)
    d_losses.append(train_d_loss)
    d_accuracies.append(train_acc)
    val_g_losses.append(val_g_loss)
    val_d_losses.append(val_d_loss)
    val_accuracies.append(val_acc)
    start_epoch = epoch + 1

    # Confusion matrix (rows: true class, columns: predicted class).
    if save_artifacts:
        # labels=[0, 1] pins the matrix to 2x2 in the Fake/Real order used by
        # the DataFrame index and columns below.
        cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
        cm_df = pd.DataFrame(cm, index=["Fake", "Real"], columns=["Pred Fake", "Pred Real"])
        cm_df.to_csv(f"../experiments/exp4/logs/confusion_epoch_{epoch+1}.csv")

    end_time = time.time()
    epoch_time = end_time - start_time

    print(
        f"Epoch {epoch+1} | "
        f"Train G Loss: {train_g_loss:.4f} | "
        f"Train D Loss: {train_d_loss:.4f} | "
        f"Train Acc: {train_acc:.4f} | "
        f"Val G Loss: {val_g_loss:.4f} | "
        f"Val D Loss: {val_d_loss:.4f} | "
        f"Val Acc: {val_acc:.4f} | "
        f"Time: {epoch_time:.2f}s"
    )

    # Save losses (the whole history is rewritten every epoch).
    with open("../experiments/exp4/losses/losses.json", "w") as f:
        json.dump({
            "g_losses": g_losses,
            "d_losses": d_losses,
            "d_accuracies": d_accuracies,
            "val_g_losses" : val_g_losses,
            "val_d_losses" : val_d_losses,
            "val_accuracies": val_accuracies
        }, f)

    # Checkpoint: model and optimiser state plus the number of epochs done.
    if save_artifacts:
        torch.save({
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'g_optimizer': g_optimizer.state_dict(),
            'd_optimizer': d_optimizer.state_dict(),
            'epoch': epoch + 1
        }, f"../experiments/exp4/checkpoints/cartoongan_epoch{epoch+1}.pth")

fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([4, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: torch.Size([8, 1, 31, 31])
fake: torch.Size([8, 3, 256, 256])
disc(fake): torch.Size([8, 1, 31, 31])
real_labels: t

KeyboardInterrupt: 