In [2]:

import os
from PIL import Image
from glob import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn


In [16]:
class CartoonDataset(Dataset):
    def __init__(self, original_dir, cartoon_dir, transform=None, limit=1000):
        self.original_paths = sorted(glob(os.path.join(original_dir, "*.png")))[:limit]
        self.cartoon_paths = sorted(glob(os.path.join(cartoon_dir, "*.png")))[:limit]
        self.transform = transform

    def __len__(self):
        return len(self.original_paths)

    def __getitem__(self, idx):
        original = Image.open(self.original_paths[idx]).convert("RGB")
        cartoon = Image.open(self.cartoon_paths[idx]).convert("RGB")

        if self.transform:
            original = self.transform(original)
            cartoon = self.transform(cartoon)

        return original, cartoon

In [4]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

In [5]:

class UNetGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU()
        )
        self.middle = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

In [6]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [7]:
original_dir = "./train/original"
cartoon_dir = "./train/cartoonized"

In [20]:

dataset = CartoonDataset(original_dir, cartoon_dir, transform=transform, limit=1000)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [21]:
generator = UNetGenerator().to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)

In [22]:

for epoch in range(50):
    generator.train()
    total_loss = 0
    for real, cartoon in loader:
        real, cartoon = real.to(device), cartoon.to(device)
        output = generator(real)
        loss = criterion(output, cartoon)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")


Epoch 1, Loss: 0.2226
Epoch 2, Loss: 0.1703
Epoch 3, Loss: 0.1521
Epoch 4, Loss: 0.1446
Epoch 5, Loss: 0.1378
Epoch 6, Loss: 0.1338
Epoch 7, Loss: 0.1293
Epoch 8, Loss: 0.1276
Epoch 9, Loss: 0.1237
Epoch 10, Loss: 0.1211
Epoch 11, Loss: 0.1194
Epoch 12, Loss: 0.1177
Epoch 13, Loss: 0.1156
Epoch 14, Loss: 0.1159
Epoch 15, Loss: 0.1136
Epoch 16, Loss: 0.1133
Epoch 17, Loss: 0.1118
Epoch 18, Loss: 0.1111
Epoch 19, Loss: 0.1088
Epoch 20, Loss: 0.1074
Epoch 21, Loss: 0.1073
Epoch 22, Loss: 0.1058
Epoch 23, Loss: 0.1061
Epoch 24, Loss: 0.1037
Epoch 25, Loss: 0.1041
Epoch 26, Loss: 0.1052
Epoch 27, Loss: 0.1027
Epoch 28, Loss: 0.1024
Epoch 29, Loss: 0.1032
Epoch 30, Loss: 0.1016
Epoch 31, Loss: 0.1011
Epoch 32, Loss: 0.1000
Epoch 33, Loss: 0.1005
Epoch 34, Loss: 0.0993
Epoch 35, Loss: 0.0973
Epoch 36, Loss: 0.0977
Epoch 37, Loss: 0.0981
Epoch 38, Loss: 0.0973
Epoch 39, Loss: 0.0976
Epoch 40, Loss: 0.0962
Epoch 41, Loss: 0.0965
Epoch 42, Loss: 0.0960
Epoch 43, Loss: 0.0957
Epoch 44, Loss: 0.09

In [23]:
torch.save(generator.state_dict(), "generator50.pth")

In [24]:
# 1. Load the model
generator = UNetGenerator().to(device)
generator.load_state_dict(torch.load("generator50.pth", map_location=device))

<All keys matched successfully>

In [25]:
generator.train()

UNetGenerator(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
  )
  (middle): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), pad

In [26]:
# 3. Continue training
for epoch in range(50, 100):  # Start from epoch 51
    total_loss = 0
    for real, cartoon in loader:
        real, cartoon = real.to(device), cartoon.to(device)
        output = generator(real)
        loss = criterion(output, cartoon)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")

Epoch 51, Loss: 0.0998
Epoch 52, Loss: 0.0999
Epoch 53, Loss: 0.1002
Epoch 54, Loss: 0.1001
Epoch 55, Loss: 0.1002
Epoch 56, Loss: 0.0995
Epoch 57, Loss: 0.1000
Epoch 58, Loss: 0.0998
Epoch 59, Loss: 0.1002
Epoch 60, Loss: 0.0999
Epoch 61, Loss: 0.0998
Epoch 62, Loss: 0.0998
Epoch 63, Loss: 0.0998
Epoch 64, Loss: 0.1001
Epoch 65, Loss: 0.0995
Epoch 66, Loss: 0.1000
Epoch 67, Loss: 0.0995
Epoch 68, Loss: 0.0997
Epoch 69, Loss: 0.0997
Epoch 70, Loss: 0.0999
Epoch 71, Loss: 0.1003
Epoch 72, Loss: 0.0999
Epoch 73, Loss: 0.0996
Epoch 74, Loss: 0.1001
Epoch 75, Loss: 0.0998
Epoch 76, Loss: 0.0997
Epoch 77, Loss: 0.0994
Epoch 78, Loss: 0.0994
Epoch 79, Loss: 0.1000
Epoch 80, Loss: 0.0998
Epoch 81, Loss: 0.0998
Epoch 82, Loss: 0.0997
Epoch 83, Loss: 0.0998
Epoch 84, Loss: 0.0998
Epoch 85, Loss: 0.0996
Epoch 86, Loss: 0.0997
Epoch 87, Loss: 0.1001
Epoch 88, Loss: 0.0997
Epoch 89, Loss: 0.1000
Epoch 90, Loss: 0.0996
Epoch 91, Loss: 0.0998
Epoch 92, Loss: 0.0996
Epoch 93, Loss: 0.0995
Epoch 94, L

In [27]:
torch.save(generator.state_dict(), "generatorfinal100.pth")