# Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import torchvision
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
from PIL import Image
import shutil
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.multiprocessing as mp
from tqdm import tqdm

# Use the 'spawn' start method for torch.multiprocessing; force=True overrides any
# start method already set in this interpreter.
# NOTE(review): presumably chosen so DataLoader worker processes can use CUDA safely.
# The loaders built below use the default num_workers=0, which starts no worker
# processes, so confirm this is still needed.
mp.set_start_method('spawn', force=True)

# Helper Functions

In [None]:
class CustomImageDataset(Dataset):
    """Unlabelled image dataset over the regular files directly inside ``root_dir``.

    Every regular file in ``root_dir`` (non-recursive) is treated as an image.
    The paths are sorted so that the index -> file mapping, and therefore any
    seeded index split built on top of it, is reproducible across machines.
    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        candidates = (os.path.join(root_dir, fname) for fname in os.listdir(root_dir))
        self.image_paths = sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)``.

        The constant 0 is a dummy label, because the dataset has no labels.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the file handle. convert() loads the pixel
        # data into a new image before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

def create_subset_dataset(original_dir, subset_dir, subset_size=100, seed=1000):
    """Copy a reproducible random sample of files from ``original_dir`` into ``subset_dir``.

    Args:
        original_dir: Directory whose regular files (non-recursive) are sampled.
        subset_dir: Destination directory; created, with parents, if missing.
        subset_size: Maximum number of files to copy. If the source holds fewer
            files, all of them are copied.
        seed: Seed for the private RNG that picks the sample. The default of 1000
            is the value this function previously hard-coded.
    """
    os.makedirs(subset_dir, exist_ok=True)

    # Sort first. os.listdir order is filesystem-dependent, so without sorting
    # the same seed could select different files on different machines.
    candidates = (os.path.join(original_dir, fname) for fname in os.listdir(original_dir))
    image_paths = sorted(path for path in candidates if os.path.isfile(path))

    # A private RandomState produces the same shuffle as np.random.seed(seed)
    # followed by np.random.shuffle, without reseeding NumPy's process-wide RNG.
    rng = np.random.RandomState(seed)
    rng.shuffle(image_paths)
    subset_paths = image_paths[:subset_size]

    for img_path in subset_paths:
        shutil.copy(img_path, subset_dir)

    print(f"Subset created with {len(subset_paths)} images in {subset_dir}")

def _shuffled_indices(num_items, seed):
    """Return ``list(range(num_items))`` shuffled by a private RNG seeded with ``seed``."""
    indices = list(range(num_items))
    # Same permutation as np.random.seed(seed) + np.random.shuffle(indices), but
    # without reseeding NumPy's process-wide RNG.
    np.random.RandomState(seed).shuffle(indices)
    return indices


def get_data_loader(data_dir, batch_size, image_size=(128, 128), subset_size=None, seed=1000):
    """Build DataLoaders over the unlabelled images in ``data_dir``.

    Images are resized to ``image_size``, converted to tensors and normalised
    with mean 0.5 and std 0.5. This maps [0, 1] pixel values to [-1, 1]; the
    single-element mean/std apply to every channel.

    Args:
        data_dir: Directory of image files (read via ``CustomImageDataset``).
        batch_size: Batch size for every returned loader.
        image_size: Target size passed to ``transforms.Resize``.
        subset_size: If truthy, return a SINGLE loader over the first
            ``subset_size`` shuffled indices instead of a train/val/test split.
        seed: Seed for the index shuffle. The default of 1000 is the value this
            function previously hard-coded.

    Returns:
        ``(train_loader, val_loader, test_loader)`` built from a 70 % / 15 % /
        remainder split of the shuffled indices. When ``subset_size`` is given,
        a single ``DataLoader`` is returned instead; note the different shape.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist.
    """
    # Raise explicitly rather than assert, which is stripped under `python -O`.
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Directory not found: {data_dir}")

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = CustomImageDataset(root_dir=data_dir, transform=transform)
    indices = _shuffled_indices(len(dataset), seed)

    if subset_size:
        subset_sampler = SubsetRandomSampler(indices[:subset_size])
        return DataLoader(dataset, batch_size=batch_size, sampler=subset_sampler)

    num_images = len(indices)
    train_end = int(0.7 * num_images)
    val_end = train_end + int(0.15 * num_images)
    splits = (indices[:train_end], indices[train_end:val_end], indices[val_end:])

    # One loader per split; each samples its own indices in random order.
    return tuple(
        DataLoader(dataset, batch_size=batch_size, sampler=SubsetRandomSampler(split))
        for split in splits
    )

# Data Loading

### 1.1 Load Data

In [None]:
BATCH_SIZE = 32
IMAGE_SIZE = (128, 128)
# NOTE(review): the landmark and anime-face cells further down read sample images
# from '/home/alumkalryan/aps360-project/data/...'. Confirm which root is correct.
DATA_ROOT = '/root/aps360-project/data'

# Train/val/test loaders for both image domains, built from the full datasets.
train_loader_A, val_loader_A, test_loader_A = get_data_loader(
    os.path.join(DATA_ROOT, 'dataSetA_10k'), BATCH_SIZE, image_size=IMAGE_SIZE)
train_loader_B, val_loader_B, test_loader_B = get_data_loader(
    os.path.join(DATA_ROOT, 'dataSetB_10k'), BATCH_SIZE, image_size=IMAGE_SIZE)

# Landmarking

In [None]:
import cv2
import dlib
from matplotlib import pyplot as plt

In [None]:
LANDMARK_MODEL_PATH = './shape_predictor_68_face_landmarks.dat'

# Face detector and 68-point landmark predictor, both used by get_human_landmarks.
frontalface_detector = dlib.get_frontal_face_detector()
landmark_predictor = dlib.shape_predictor(LANDMARK_MODEL_PATH)

In [None]:
picture = '/home/alumkalryan/aps360-project/data/dataSetA_10k/000010.jpg'
image = cv2.imread(picture)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable. Fail here rather than with an opaque cv2 error later on.
if image is None:
    raise FileNotFoundError(f"Could not read image: {picture}")

In [None]:
# Indices into dlib's 68-point landmark model that get_human_landmarks draws,
# grouped as eye, nose and mouth points.
_EYE_POINTS = (36, 39, 42, 45)
_NOSE_POINTS = (30,)
_MOUTH_POINTS = (48, 54)
_LANDMARK_POINTS = _EYE_POINTS + _NOSE_POINTS + _MOUTH_POINTS


def get_human_landmarks(image, radius=1):
    """Render a sparse landmark map for every face dlib detects in ``image``.

    Args:
        image: BGR image as a NumPy array, as returned by ``cv2.imread``.
        radius: Radius in pixels of the filled white dot drawn per landmark.

    Returns:
        An array with the same shape and dtype as ``image``. It is black except
        for a white dot at each selected eye, nose and mouth landmark of every
        detected face, and entirely black if no face is found.
    """
    img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    img_black = np.zeros_like(image)

    # The second argument is dlib's upsample count: upsample once before detecting.
    for rect in frontalface_detector(img_gray, 1):
        shape = landmark_predictor(img_gray, rect)
        # Query only the landmarks that are drawn, not all 68.
        for idx in _LANDMARK_POINTS:
            point = shape.part(idx)
            cv2.circle(img_black, (point.x, point.y), radius, (255, 255, 255), -1)

    return img_black

In [None]:
# Visualise the landmark map of the sample image loaded above.
landmark_map = get_human_landmarks(image)
plt.imshow(landmark_map)

# Landmarking Discriminators 

## Landmark Consistency Loss

In [None]:
class LandmarkConsistencyLoss(nn.Module):
    """Mean-squared error between the landmark maps of two images.

    Each image goes through ``get_human_landmarks``, which expects a BGR uint8
    NumPy image. The resulting maps are scaled to [0, 1] and compared with
    ``nn.MSELoss``, the mean of squared differences (an L2 loss).

    Accepted inputs:
      * NumPy BGR uint8 images, used as-is.
      * torch tensors shaped (C, H, W) or (N, C, H, W) with RGB channels in
        [-1, 1], the range produced by this file's ``get_data_loader``
        normalisation. NOTE(review): this assumes generator outputs use the same
        range and have 3 channels.

    NOTE: dlib landmark detection is not differentiable, so this value carries no
    gradient back to ``generated_image``. It can be logged as a metric, but on
    its own it will not train a generator.
    """

    def __init__(self):
        super().__init__()
        self.landmark_regressor = get_human_landmarks
        self.l2_loss = nn.MSELoss()  # default reduction='mean' -> mean squared error

    @staticmethod
    def _to_bgr_uint8(image):
        """Convert one (C, H, W) RGB tensor in [-1, 1] to a contiguous BGR uint8 array."""
        pixels = ((image.detach().cpu().float().clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)
        rgb = pixels.permute(1, 2, 0).numpy()
        # Flip RGB -> BGR; cv2 needs a contiguous buffer.
        return np.ascontiguousarray(rgb[..., ::-1])

    def _landmark_maps(self, images):
        """Return the landmark map(s) of ``images`` as a float tensor in [0, 1]."""
        if isinstance(images, np.ndarray):
            return torch.from_numpy(self.landmark_regressor(images)).float() / 255.0
        batch = images if images.dim() == 4 else images.unsqueeze(0)
        maps = [self.landmark_regressor(self._to_bgr_uint8(img)) for img in batch]
        return torch.from_numpy(np.stack(maps)).float().div(255.0).to(images.device)

    def forward(self, generated_image, target_image):
        generated_landmarks = self._landmark_maps(generated_image)
        # Target landmarks are recomputed on every call rather than cached.
        target_landmarks = self._landmark_maps(target_image).to(generated_landmarks.device)
        return self.l2_loss(generated_landmarks, target_landmarks)

## Landmark Matched Global Discriminator

In [None]:
# Reusable conv (or transposed conv) -> InstanceNorm -> optional ReLU -> Dropout block
class ConvolutionalBlock(nn.Module):
    """A convolution followed by InstanceNorm2d, an optional ReLU and Dropout(0.5).

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        is_downsampling: If True, use ``nn.Conv2d`` with reflect padding. If
            False, use ``nn.ConvTranspose2d`` with PyTorch's default padding mode.
        add_activation: If True, apply ``ReLU(inplace=True)``; otherwise ``nn.Identity``.
        **kwargs: Forwarded to the conv layer (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels: int, out_channels: int, is_downsampling: bool = True, add_activation: bool = True, **kwargs):
        super().__init__()
        if is_downsampling:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if add_activation else nn.Identity()
        self.conv = nn.Sequential(conv_layer, nn.InstanceNorm2d(out_channels), activation, nn.Dropout(0.5))

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Identity skip connection around two ConvolutionalBlocks (3x3, padding 1).

    Both inner blocks keep the channel count and use ``is_downsampling=False``
    (transposed convolutions, which preserve spatial size at stride 1). The first
    inner block applies ReLU; the second does not.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(*(
            ConvolutionalBlock(channels, channels, add_activation=with_relu, kernel_size=3, padding=1, is_downsampling=False)
            for with_relu in (True, False)
        ))

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Reusable 4x4 conv -> InstanceNorm -> LeakyReLU(0.2) block, used by GlobalDiscriminator
class ConvInstanceNormLeakyReLUBlock(nn.Module):
    """4x4 reflect-padded convolution without bias, then InstanceNorm2d and LeakyReLU(0.2).

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        stride: Convolution stride. With kernel 4 and padding 1, stride 2 halves
            an even spatial size and stride 1 shrinks it by one.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=False, padding_mode="reflect")
        self.conv = nn.Sequential(conv, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        return self.conv(x)

### Global Discriminator

This replaces the previous discriminator: use `GlobalDiscriminator` wherever the old discriminator was used.

In [None]:
# One discriminator class covering both variants: unconditional (the default), and
# conditional on a landmark map concatenated to the image along the channel axis.

class GlobalDiscriminator(nn.Module):
    """Fully convolutional image discriminator, optionally conditioned on landmarks.

    The output is a 1-channel map of raw scores with no final activation. For
    example, a 128x128 input yields an (N, 1, 6, 6) map.

    Args:
        in_channels: Channels of the image input (3 for RGB).
        pm: Base width. The stages use pm, 2*pm, 4*pm and 8*pm channels.
        conditional: If True and ``landmarks`` is passed to ``forward``, the
            landmarks are concatenated to ``x`` along dim 1 before the first conv.
        landmark_channels: Extra first-conv input channels reserved for the
            landmark map when ``conditional`` is True (e.g. 3 for maps from
            ``get_human_landmarks``). The default 0 keeps the old behaviour, where
            ``in_channels`` had to include the landmark channels already.
    """

    def __init__(self, in_channels=3, pm=64, conditional=False, landmark_channels=0):
        super().__init__()
        self.conditional = conditional
        features = [pm * 1, pm * 2, pm * 4, pm * 8]
        # Landmarks are concatenated onto the image, so the first conv must accept both.
        first_in = in_channels + (landmark_channels if conditional else 0)

        self.initial_layer = nn.Sequential(
            nn.Conv2d(
                first_in,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
                bias=False,
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        channels = features[0]
        for feature in features[1:]:
            layers.append(ConvInstanceNormLeakyReLUBlock(channels, feature, stride=2))
            channels = feature

        # Extra stride-1 block at 8*pm channels, described by the original author as
        # matching "the original structure".
        # NOTE(review): the usual PatchGAN layout instead gives the 8*pm block
        # itself stride 1, with no additional block. Confirm this extra layer is
        # intended.
        layers.append(ConvInstanceNormLeakyReLUBlock(channels, features[-1], stride=1))

        # Final projection to a 1-channel score map.
        layers.append(
            nn.Conv2d(
                in_channels=features[-1],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, landmarks=None):
        # Unconditional models, and calls made without landmarks, see the image alone.
        if self.conditional and landmarks is not None:
            x = torch.cat([x, landmarks], dim=1)
        x = self.initial_layer(x)
        return self.model(x)


## Landmark Guided Local Discriminator

In [None]:
class LocalDiscriminator(nn.Module):
    """Placeholder for the landmark-guided local discriminator.

    ``self.main`` is currently an empty ``nn.Sequential``, so ``forward``
    returns its input unchanged and the module has no parameters.
    TODO: add the local discriminator layers.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential()

    def forward(self, image):
        return self.main(image)

# Anime Face Detection

In [None]:
import sys
import os.path

# NOTE(review): this reads from dataSetB_20k, while the data-loading cell uses
# dataSetB_10k under a different root. Confirm which dataset is intended.
picture = '/home/alumkalryan/aps360-project/data/dataSetB_20k/141_2000.jpg'
image = cv2.imread(picture)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if image is None:
    raise FileNotFoundError(f"Could not read image: {picture}")

def detect(image, cascade_file="./lbpcascade_animeface.xml", *, scale_factor=1.1, min_neighbors=5, min_size=(24, 24)):
    """Detect anime faces in a BGR image and plot them boxed in red.

    Args:
        image: BGR image as a NumPy array, as returned by ``cv2.imread``. It is
            not modified; the boxes are drawn on a copy.
        cascade_file: Path to the OpenCV cascade XML used for detection.
        scale_factor: Forwarded to ``detectMultiScale`` as ``scaleFactor``.
        min_neighbors: Forwarded to ``detectMultiScale`` as ``minNeighbors``.
        min_size: Forwarded to ``detectMultiScale`` as ``minSize``.

    Returns:
        The ``(x, y, w, h)`` detections returned by ``detectMultiScale``.

    Raises:
        RuntimeError: if ``cascade_file`` is missing or cannot be loaded.
        ValueError: if ``image`` is None.
    """
    if not os.path.isfile(cascade_file):
        raise RuntimeError("%s: not found" % cascade_file)

    cascade = cv2.CascadeClassifier(cascade_file)
    # A malformed XML yields an empty classifier rather than an exception.
    if cascade.empty():
        raise RuntimeError("%s: could not be loaded" % cascade_file)
    if image is None:
        raise ValueError("image is None (did cv2.imread fail?)")

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.equalizeHist(gray)

    faces = cascade.detectMultiScale(gray,
                                     scaleFactor=scale_factor,
                                     minNeighbors=min_neighbors,
                                     minSize=min_size)

    annotated = image.copy()
    for (x, y, w, h) in faces:
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 0, 255), 2)  # red in BGR

    # OpenCV images are BGR, but matplotlib expects RGB. Convert before plotting,
    # otherwise the red boxes and the image colours appear channel-swapped.
    plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    return faces

# Run the anime-face detector on the sample image loaded above and plot the result.
detect(image)