In [6]:
import os
import itertools
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom Dataset
class FaceSketchDataset(Dataset):
    def __init__(self, root_dir, transforms_=None, mode='train'):
        self.transform = transforms_
        self.mode = mode
        photos_dir = os.path.join(root_dir, mode, 'photos')
        sketches_dir = os.path.join(root_dir, mode, 'sketches')
        
        # Ensure directories exist
        if not os.path.isdir(photos_dir):
            raise FileNotFoundError(f"Directory not found: {photos_dir}")
        if not os.path.isdir(sketches_dir):
            raise FileNotFoundError(f"Directory not found: {sketches_dir}")
        
        self.img_A_paths = sorted(os.listdir(photos_dir))
        self.img_B_paths = sorted(os.listdir(sketches_dir))
        self.root_dir = root_dir

    def __len__(self):
        return max(len(self.img_A_paths), len(self.img_B_paths))

    def __getitem__(self, idx):
        A_path = os.path.join(self.root_dir, self.mode, 'photos', self.img_A_paths[idx % len(self.img_A_paths)])
        B_path = os.path.join(self.root_dir, self.mode, 'sketches', self.img_B_paths[idx % len(self.img_B_paths)])
        A = Image.open(A_path).convert('RGB')
        B = Image.open(B_path).convert('RGB')
        if self.transform:
            A = self.transform(A)
            B = self.transform(B)
        return {'A': A, 'B': B}

# Define transformations
transform = transforms.Compose([
    transforms.Resize(int(256 * 1.12), Image.BICUBIC),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Create DataLoaders
def get_dataloaders(root_dir, batch_size=4):
    train_dataset = FaceSketchDataset(root_dir, transforms_=transform, mode='train')
    val_dataset = FaceSketchDataset(root_dir, transforms_=transform, mode='val')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, val_loader

# Example usage:
root_dir = '/kaggle/input/person-face-sketches'
train_loader, val_loader = get_dataloaders(root_dir, batch_size=4)  # Set batch_size to 4

import torch.nn as nn
import torch.nn.functional as F

# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, kernel_size=3, stride=1, padding=0),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(features, features, kernel_size=3, stride=1, padding=0),
            nn.InstanceNorm2d(features)
        )

    def forward(self, x):
        return x + self.block(x)

# Generator
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, output_nc, kernel_size=7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A series of convolutional layers
        model = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        model += [
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        # Output layer
        model += [nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# Initialize generators and discriminators
G_A2B = Generator(input_nc=3, output_nc=3).to(device)
G_B2A = Generator(input_nc=3, output_nc=3).to(device)
D_A = Discriminator(input_nc=3).to(device)
D_B = Discriminator(input_nc=3).to(device)

# Initialize weights
def weights_init_normal(m):
    classname = m.__class__.__name__
    if hasattr(m, 'weight') and m.weight is not None:
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

G_A2B.apply(weights_init_normal)
G_B2A.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)

import torch.optim as optim

# Loss functions
criterion_GAN = nn.MSELoss().to(device)
criterion_cycle = nn.L1Loss().to(device)
criterion_identity = nn.L1Loss().to(device)

# Optimizers
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

optimizer_G = optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()), lr=lr, betas=(beta1, beta2))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=lr, betas=(beta1, beta2))

# Learning rate schedulers
from torch.optim import lr_scheduler

lr_lambda = lambda epoch: 1.0 - max(0, epoch - 100) / float(100)
scheduler_G = lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lr_lambda)
scheduler_D_A = lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lr_lambda)
scheduler_D_B = lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lr_lambda)

# Buffers to store previously generated samples
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, 'Empty buffer or trying to create a buffer with negative size.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if torch.rand(1).item() > 0.5:
                    idx = torch.randint(0, self.max_size, (1,)).item()
                    to_return.append(self.data[idx].clone())
                    self.data[idx] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Function to load checkpoints
def load_checkpoint(epoch, checkpoint_dir):
    G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_A2B_epoch_{epoch}.pth')))
    G_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_B2A_epoch_{epoch}.pth')))
    D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_A_epoch_{epoch}.pth')))
    D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_B_epoch_{epoch}.pth')))
    print(f"Loaded models for epoch {epoch}")

# Define checkpoint directory
checkpoint_dir = '/kaggle/input/question-4/checkpoints'  # Update this path if different

# Example: Load models from epoch 1 and start training from epoch 2
starting_epoch = 1
try:
    load_checkpoint(starting_epoch, checkpoint_dir)
    starting_epoch += 1  # Start from the next epoch
except FileNotFoundError:
    print(f"No checkpoint found at epoch {starting_epoch}. Starting from scratch.")

# Training Loop
num_epochs = 50
train_loader, val_loader = get_dataloaders(root_dir, batch_size=4)  # Ensure batch_size is set to 4

for epoch in range(starting_epoch, num_epochs + 1):
    G_A2B.train()
    G_B2A.train()
    D_A.train()
    D_B.train()
    
    for i, batch in enumerate(train_loader):
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        # Adversarial ground truths
        valid = torch.ones((real_A.size(0), 1, 30, 30), requires_grad=False).to(device)
        fake = torch.zeros((real_A.size(0), 1, 30, 30), requires_grad=False).to(device)

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        same_B = G_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * 5.0
        same_A = G_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * 5.0

        # GAN loss
        fake_B = G_A2B(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, valid)

        fake_A = G_B2A(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, valid)

        # Cycle loss
        recovered_A = G_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * 10.0

        recovered_B = G_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * 10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optimizer_G.step()

        ###### Discriminator A ######
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, valid)

        # Fake loss
        fake_A_buffered = fake_A_buffer.push_and_pop(fake_A)
        pred_fake = D_A(fake_A_buffered.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D_A_total = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A_total.backward()
        optimizer_D_A.step()

        ###### Discriminator B ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, valid)

        # Fake loss
        fake_B_buffered = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = D_B(fake_B_buffered.detach())
        loss_D_fake = criterion_GAN(pred_fake, fake)

        # Total loss
        loss_D_B_total = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B_total.backward()
        optimizer_D_B.step()

        if i % 500 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(train_loader)}] "
                  f"Loss_G: {loss_G.item():.4f} Loss_D_A: {loss_D_A_total.item():.4f} "
                  f"Loss_D_B: {loss_D_B_total.item():.4f}")

    # Update learning rates
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()

    # Save model checkpoints
    save_checkpoint_dir = '/kaggle/working/checkpoints'  # Update this path as needed
    os.makedirs(save_checkpoint_dir, exist_ok=True)
    torch.save(G_A2B.state_dict(), os.path.join(save_checkpoint_dir, f'G_A2B_epoch_{epoch}.pth'))
    torch.save(G_B2A.state_dict(), os.path.join(save_checkpoint_dir, f'G_B2A_epoch_{epoch}.pth'))
    torch.save(D_A.state_dict(), os.path.join(save_checkpoint_dir, f'D_A_epoch_{epoch}.pth'))
    torch.save(D_B.state_dict(), os.path.join(save_checkpoint_dir, f'D_B_epoch_{epoch}.pth'))
    print(f"Saved models for epoch {epoch}")

    print(f"Completed epoch {epoch}/{num_epochs}")

# Optional: Function to load checkpoints (Already defined above)
def load_checkpoint(epoch, checkpoint_dir, G_A2B, G_B2A, D_A, D_B):
    G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_A2B_epoch_{epoch}.pth')))
    G_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_B2A_epoch_{epoch}.pth')))
    D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_A_epoch_{epoch}.pth')))
    D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_B_epoch_{epoch}.pth')))
    print(f"Loaded models for epoch {epoch}")


Loaded models for epoch 1


  G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_A2B_epoch_{epoch}.pth')))
  G_B2A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'G_B2A_epoch_{epoch}.pth')))
  D_A.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_A_epoch_{epoch}.pth')))
  D_B.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'D_B_epoch_{epoch}.pth')))


Epoch [2/50] Batch [0/5164] Loss_G: 3.3103 Loss_D_A: 0.0284 Loss_D_B: 0.0721
Epoch [2/50] Batch [500/5164] Loss_G: 3.1871 Loss_D_A: 0.0634 Loss_D_B: 0.1808
Epoch [2/50] Batch [1000/5164] Loss_G: 2.6192 Loss_D_A: 0.1825 Loss_D_B: 0.0882
Epoch [2/50] Batch [1500/5164] Loss_G: 3.2295 Loss_D_A: 0.0714 Loss_D_B: 0.0020
Epoch [2/50] Batch [2000/5164] Loss_G: 2.3332 Loss_D_A: 0.2099 Loss_D_B: 0.4262
Epoch [2/50] Batch [2500/5164] Loss_G: 4.0434 Loss_D_A: 0.2244 Loss_D_B: 0.0061
Epoch [2/50] Batch [3000/5164] Loss_G: 3.7528 Loss_D_A: 0.2686 Loss_D_B: 0.0031
Epoch [2/50] Batch [3500/5164] Loss_G: 2.7278 Loss_D_A: 0.2856 Loss_D_B: 0.0201
Epoch [2/50] Batch [4000/5164] Loss_G: 3.3284 Loss_D_A: 0.2205 Loss_D_B: 0.0098
Epoch [2/50] Batch [4500/5164] Loss_G: 3.4815 Loss_D_A: 0.1992 Loss_D_B: 0.0028
Epoch [2/50] Batch [5000/5164] Loss_G: 4.3149 Loss_D_A: 0.1719 Loss_D_B: 0.0034
Saved models for epoch 2
Completed epoch 2/50
Epoch [3/50] Batch [0/5164] Loss_G: 3.8351 Loss_D_A: 0.2576 Loss_D_B: 0.0028
E

KeyboardInterrupt: 

In [7]:
import os
import shutil

# Define the source and destination directories
checkpoint_dir = '/kaggle/working/checkpoints'
output_dir = '/kaggle/working'

# Get the latest model checkpoint files based on epoch number
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
latest_checkpoint = max(checkpoint_files, key=lambda x: int(x.split('_')[-1].split('.')[0]))

# Copy the latest checkpoint to output directory for download
shutil.copy(os.path.join(checkpoint_dir, latest_checkpoint), output_dir)
print(f"Copied {latest_checkpoint} to output directory.")

# Print download link
from IPython.display import FileLink
print("Click the link below to download the checkpoint:")
FileLink(os.path.join(output_dir, latest_checkpoint))


Copied D_B_epoch_8.pth to output directory.
Click the link below to download the checkpoint:
