In [3]:
# Imports

import os
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
import time

# Load the label tables. The training table is indexed by its 'id' column so
# that rows can be looked up by image id with `train.loc[...]`.
train = pd.read_csv('../input/glasses-or-no-glasses/train.csv').set_index('id')
test = pd.read_csv('../input/glasses-or-no-glasses/test.csv')

In [36]:
# configs
# Hyper-parameters and the runtime device shared by the cells of this notebook.
BATCH_SIZE: int = 1
NUM_EPOCHS: int = 2
LEARNING_RATE: float = 0.0002
LAMBDA_CYCLE: int = 10
LAMBDA_IDENTITY: float = 0.0
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [37]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and torch, exports ``PYTHONHASHSEED``
    into the process environment and, when CUDA is available, seeds all CUDA
    devices and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


set_seed(42)

In [38]:
# Output Directories
# Both paths are relative to the notebook's working directory.
results_NoGlasses, results_Glasses = './NoGlasses', './Glasses'

In [39]:
# segregating images with and without glasses
#
# Every file named face-<id>.png in `directory` is routed into one of three
# lists of full paths:
#   * train_imgs_glasses     - id is in `train` and its 'glasses' label is 1
#   * train_imgs_no_glasses  - id is in `train` and its 'glasses' label is 0
#   * test_imgs              - id is listed in `test`
directory = '../input/glasses-or-no-glasses/faces-spring-2020/faces-spring-2020'

# One-off lookup tables replace a linear scan over every possible id per file.
# `train` is indexed by 'id' (see the loading cell), so this maps id -> label.
glasses_by_id = train['glasses'].to_dict()
# NOTE(review): assumes test.csv exposes the image ids in an 'id' column, like
# train.csv does before it is re-indexed - confirm against the file.
test_ids = set(test['id'])

train_imgs_glasses = []
train_imgs_no_glasses = []
test_imgs = []

# Sorted so the lists (and therefore the image pairing) do not depend on the
# arbitrary order os.listdir() returns.
for img in sorted(os.listdir(directory)):
    stem, ext = os.path.splitext(img)
    if ext != '.png' or not stem.startswith('face-'):
        continue  # not one of the face-<id>.png files
    try:
        img_id = int(stem[len('face-'):])
    except ValueError:
        continue  # the id part is not a number

    # `img` is already the full file name; re-wrapping it in 'face-' ... '.png'
    # would point at a file that does not exist.
    img_path = os.path.join(directory, img)

    label = glasses_by_id.get(img_id)
    if label == 0:
        train_imgs_no_glasses.append(img_path)
    elif label == 1:
        train_imgs_glasses.append(img_path)

    if img_id in test_ids:
        test_imgs.append(img_path)

In [40]:
# Report how many paths landed in each list: glasses, no glasses, test (one per line).
print(len(train_imgs_glasses), len(train_imgs_no_glasses), len(test_imgs), sep='\n')

2856
1644
500


In [41]:
class Dataset(torch.utils.data.Dataset):
    """Unpaired image dataset: one glasses face and one no-glasses face per index.

    The two path lists may differ in length. The dataset is as long as the
    longer list, and the shorter list is cycled so every index is valid for both.

    The base class is spelled out in full because this class re-binds the name
    ``Dataset`` in the notebook namespace.

    Args:
        imgs_glasses: Sequence of file paths of faces with glasses.
        img_no_glasses: Sequence of file paths of faces without glasses.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., image0=...)`` and returning a mapping with
            the same two keys.
    """

    def __init__(self, imgs_glasses, img_no_glasses, transform = None):
        super().__init__()
        self.root_dir = directory
        self.glasses_list = imgs_glasses
        self.no_glasses_list = img_no_glasses
        self.transform = transform
        # Kept as attributes for code that reads them; they are set here rather
        # than as a side effect of __len__ so they exist before len() is called.
        self.glasses_len = len(self.glasses_list)
        self.no_glasses_len = len(self.no_glasses_list)
        self.dataset_len = max(self.glasses_len, self.no_glasses_len)

    def __len__(self):
        """Return the length of the longer of the two path lists."""
        return max(len(self.glasses_list), len(self.no_glasses_list))

    def _paths_for(self, idx):
        """Return ``(glasses_path, no_glasses_path)`` for ``idx``.

        Each list is indexed modulo its OWN length, so the shorter list wraps
        around instead of raising IndexError.
        """
        if not self.glasses_list or not self.no_glasses_list:
            raise IndexError("both image lists must be non-empty to build a pair")
        glasses_path = self.glasses_list[idx % len(self.glasses_list)]
        no_glasses_path = self.no_glasses_list[idx % len(self.no_glasses_list)]
        return glasses_path, no_glasses_path

    @staticmethod
    def _load_gray(path):
        """Read ``path`` and return it as a 2-D grayscale numpy array."""
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as img:
            return np.array(ImageOps.grayscale(img.convert('RGB')))

    def __getitem__(self, idx):
        """Return ``(img_glasses, img_no_glasses)`` for ``idx``, transformed if configured."""
        glasses_path, no_glasses_path = self._paths_for(idx)
        img_glasses = self._load_gray(glasses_path)
        img_no_glasses = self._load_gray(no_glasses_path)

        if self.transform:
            # "image" carries the glasses face and "image0" the no-glasses face;
            # each result is read back from the key it was passed in under.
            augmentation = self.transform(image=img_glasses, image0=img_no_glasses)
            img_glasses = augmentation["image"]
            img_no_glasses = augmentation["image0"]

        return img_glasses, img_no_glasses

In [42]:
# Training pipeline: resize to 256x256, random horizontal flip, scale pixels
# with mean 0.5 / std 0.5, then convert to a torch tensor. The "image0" target
# receives the same treatment as the primary "image".
transforms_train = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

# Validation pipeline: the same resize / normalise / to-tensor steps as
# training, but without the random horizontal flip, so validation previews are
# deterministic and comparable from epoch to epoch.
transforms_val = A.Compose(
    [
        A.Resize(256, 256),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2()
    ],
    additional_targets = {"image0":"image"}
)

In [46]:
# Generator

class SelfAttentionBlock(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Queries and keys are 1x1 convolutions down to ``in_dim // 8`` channels (at
    least one); values keep ``in_dim`` channels. The attended features are
    scaled by the learnable scalar ``gamma`` (initialised to zero, so the block
    starts as an identity mapping) and added back to the input.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.in_channel = in_dim
        # Never ask for a zero-channel projection when in_dim < 8.
        reduced_dim = max(in_dim // 8, 1)
        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=reduced_dim, kernel_size=1
        )
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=reduced_dim, kernel_size=1
        )
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1
        )
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x H x W)
        returns:
            out : gamma * attended values + input features (B x C x H x W)
            attention: B x N x N (N is H*W); row i holds the weights position i
                       assigns to every position j and sums to 1
        """
        batch_size, c, h, w = x.shape
        n = h * w

        # (B, C//8, N) -> (B, N, C//8) so it can be matrix-multiplied with the keys
        proj_query = self.query_conv(x).view(batch_size, -1, n).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, n)      # (B, C//8, N)

        energy = torch.bmm(proj_query, proj_key)                 # (B, N, N)
        attention = self.softmax(energy)                         # (B, N, N)

        proj_value = self.value_conv(x).view(batch_size, -1, n)  # (B, C, N)

        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]; the transpose
        # lines the normalised (last) axis of `attention` up with the value positions.
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, c, h, w)

        out = self.gamma * out + x
        return out, attention

    
class ResidualBlock(nn.Module):
    """Residual unit: a single 1x1 convolution whose output is added to its input.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        pointwise = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        self.block = nn.Sequential(pointwise)

    def forward(self, x):
        """Return ``block(x) + x``; the shape is unchanged."""
        transformed = self.block(x)
        return transformed + x
    
    
class Generator(nn.Module):
    """Encoder / bottleneck / decoder generator with a self-attention bottleneck.

    Three stride-2 convolutions downsample the input, the bottleneck applies
    ResidualBlock -> SelfAttentionBlock -> ResidualBlock at 256 channels, and
    three stride-2 transposed convolutions upsample back to the input
    resolution. A final 7x7 convolution followed by tanh produces the output in
    the range [-1, 1].

    Args:
        img_channels: Number of channels of the input image. The output has
            the same number of channels.
    """

    def __init__(self, img_channels):
        super().__init__()

        # nn.Sequential takes the layers as separate arguments, not as one tuple.
        self.downConvBlock = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=1, kernel_size=7, stride=1, padding=3),
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=6, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.ReLU())
        self.convtrans1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)
        self.active1 = nn.Sequential(nn.InstanceNorm2d(128), nn.ReLU())
        self.convtrans2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.convtrans3 = nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=6, stride=2, padding=2)
        self.selfAttentionLayer = SelfAttentionBlock(in_dim=256)
        self.resBlock1 = ResidualBlock(channels=256)
        self.resBlock2 = ResidualBlock(channels=256)
        # Emit as many channels as the input has so input and output images are
        # interchangeable (identical to the previous 1 -> 1 layer when img_channels == 1).
        self.lastlayer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=img_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh())

    def forward(self, x):
        """Translate a batch of images.

        Args:
            x: Tensor of shape (B, img_channels, H, W).

        Returns:
            Tensor of shape (B, img_channels, H, W) with values in [-1, 1].
        """
        # Remember the spatial size entering every strided convolution. The
        # transposed convolutions are then asked to restore exactly those sizes
        # (in reverse order), which works for any input resolution instead of
        # one hard-coded size.
        sizes_before_downsample = []
        out = x
        for layer in self.downConvBlock:
            if isinstance(layer, nn.Conv2d) and max(layer.stride) > 1:
                sizes_before_downsample.append(list(out.shape[-2:]))
            out = layer(out)

        out = self.resBlock1(out)
        out, _ = self.selfAttentionLayer(out)  # the attention map is not needed here
        out = self.resBlock2(out)

        out = self.convtrans1(out, output_size=sizes_before_downsample.pop())
        out = self.active1(out)
        out = self.convtrans2(out, output_size=sizes_before_downsample.pop())
        out = self.convtrans3(out, output_size=sizes_before_downsample.pop())
        return self.lastlayer(out)

In [44]:
# Discriminator

class Block(nn.Module):
    """Discriminator building block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode="reflect"
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2)  # the class is spelled LeakyReLU; `LeakyReLu` does not exist
        )

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        return self.conv(x)
        

class Discriminator(nn.Module):
    """Convolutional discriminator that returns a map of sigmoid scores.

    An initial stride-2 convolution with LeakyReLU is followed by one ``Block``
    per remaining entry of ``features`` (stride 2, except stride 1 for the
    block whose width equals ``features[-1]``) and a final convolution down to
    a single channel.

    Args:
        in_channels: Number of channels of the input image.
        features: Channel widths of the successive stages.
    """

    def __init__(self, in_channels, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect"
            ),
            nn.LeakyReLU(0.2))

        layers = []

        in_channels = features[0]

        for feature in features[1:]:
            layers.append(
                Block(in_channels, feature, stride=1 if feature==features[-1] else 2)
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect"
            ))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-location scores in (0, 1) for the image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [45]:
# utility

def save_some_examples_g(gen, val_loader, epoch, idx, folder):
    """Save a preview of ``gen`` applied to the first no-glasses validation batch.

    The generator output is mapped from [-1, 1] back to [0, 1] and written to
    ``<folder>/<epoch>_<idx>_fake_glasses.png``.

    Args:
        gen: Generator called on the no-glasses images.
        val_loader: Loader yielding ``(img_glasses, img_no_glasses)`` batches.
        epoch: Value used as the first part of the file name.
        idx: Value used as the second part of the file name.
        folder: Output directory; created if it does not exist.
    """
    _, img_no_glasses = next(iter(val_loader))
    img_no_glasses = img_no_glasses.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    # Restore the generator's previous train/eval mode afterwards so that a
    # preview taken in the middle of training does not leave it in eval mode.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake_glasses = gen(img_no_glasses)
            y_fake_glasses = y_fake_glasses*0.5 + 0.5  # undo the mean=0.5 / std=0.5 normalisation
            save_image(y_fake_glasses, os.path.join(folder, f"{epoch}_{idx}_fake_glasses.png"))
    finally:
        gen.train(was_training)

def save_some_examples_n(gen, val_loader, epoch, idx, folder):
    """Save a preview of ``gen`` applied to the first glasses validation batch.

    The generator output is mapped from [-1, 1] back to [0, 1] and written to
    ``<folder>/<epoch>_<idx>_fake_no_glasses.png``.

    Args:
        gen: Generator called on the glasses images.
        val_loader: Loader yielding ``(img_glasses, img_no_glasses)`` batches.
        epoch: Value used as the first part of the file name.
        idx: Value used as the second part of the file name.
        folder: Output directory; created if it does not exist.
    """
    img_glasses, _ = next(iter(val_loader))
    img_glasses = img_glasses.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    # Restore the generator's previous train/eval mode afterwards so that a
    # preview taken in the middle of training does not leave it in eval mode.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake_no_glasses = gen(img_glasses)
            y_fake_no_glasses = y_fake_no_glasses*0.5 + 0.5  # undo the mean=0.5 / std=0.5 normalisation
            save_image(y_fake_no_glasses, os.path.join(folder, f"{epoch}_{idx}_fake_no_glasses.png"))
    finally:
        gen.train(was_training)

def save_checkpoint(model, optimizer, epoch, filename):
    """Write the model and optimizer state dicts to ``<epoch><filename>_cpt.pth.tar``.

    Args:
        model: Module whose ``state_dict()`` is stored under "state_dict".
        optimizer: Optimizer whose ``state_dict()`` is stored under "optimizer".
        epoch: Value prefixed to the file name.
        filename: Base name of the checkpoint file.
    """
    target = f"{epoch!s}{filename}_cpt.pth.tar"
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        target,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: Path of the checkpoint to read.
        model: Module that receives the stored "state_dict" entry.
        optimizer: Optimizer that receives the stored "optimizer" entry.
        lr: Learning rate forced onto every parameter group after loading, so
            the rate stored in the checkpoint does not override the current one.
    """
    print("=> Loading Checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    # The optimizer state lives under its own key; feeding it the model's
    # state dict would fail.
    optimizer.load_state_dict(checkpoint["optimizer"])
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


In [None]:
# train

def getVggFeatures():
    """Placeholder, not implemented yet: takes no arguments and returns None.

    TODO: presumably meant to extract VGG feature maps - confirm the intent.
    """
    return None

def pcclT1():
    """Placeholder, not implemented yet: takes no arguments and returns None."""
    return None

def pcclT2():
    """Placeholder, not implemented yet: takes no arguments and returns None."""
    return None

def trainFn(disc_G=None, disc_N=None, gen_G=None, gen_N=None, loader=None,
            opt_gen=None, opt_disc=None, l1=None, mse=None, val_loader=None, epoch=None):
    """Run one epoch of CycleGAN-style training over ``loader``.

    The parameter order matches the call in the training loop. Every parameter
    defaults to None so the former zero-argument call still returns None
    without doing anything.

    NOTE(review): this is the standard adversarial + cycle (+ optional
    identity) objective, weighted by LAMBDA_CYCLE and LAMBDA_IDENTITY. The
    roles below are inferred from the variable names - confirm them, and add
    any further loss terms the project intends on top.

    Args:
        disc_G: Discriminator scoring glasses images.
        disc_N: Discriminator scoring no-glasses images.
        gen_G: Generator mapping no-glasses images to glasses images.
        gen_N: Generator mapping glasses images to no-glasses images.
        loader: Iterable yielding ``(img_glasses, img_no_glasses)`` batches.
        opt_gen: Optimizer over the parameters of both generators.
        opt_disc: Optimizer over the parameters of both discriminators.
        l1: Loss used for the cycle and identity terms.
        mse: Loss used for the adversarial terms.
        val_loader: Accepted for call-site compatibility; not used here.
        epoch: Only used to label the progress message.

    Returns:
        None.
    """
    if loader is None:
        return None

    disc_loss_total = 0.0
    gen_loss_total = 0.0
    num_batches = 0

    for img_glasses, img_no_glasses in loader:
        img_glasses = img_glasses.to(DEVICE)
        img_no_glasses = img_no_glasses.to(DEVICE)

        fake_glasses = gen_G(img_no_glasses)
        fake_no_glasses = gen_N(img_glasses)

        # ---- discriminators: real -> 1, generated -> 0 ----
        # The fakes are detached so this step does not backpropagate into the generators.
        d_g_real = disc_G(img_glasses)
        d_g_fake = disc_G(fake_glasses.detach())
        d_n_real = disc_N(img_no_glasses)
        d_n_fake = disc_N(fake_no_glasses.detach())
        loss_disc = (
            mse(d_g_real, torch.ones_like(d_g_real))
            + mse(d_g_fake, torch.zeros_like(d_g_fake))
            + mse(d_n_real, torch.ones_like(d_n_real))
            + mse(d_n_fake, torch.zeros_like(d_n_fake))
        ) / 2

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ---- generators: fool both discriminators and close the cycle ----
        d_g_fake = disc_G(fake_glasses)
        d_n_fake = disc_N(fake_no_glasses)
        loss_adv = (
            mse(d_g_fake, torch.ones_like(d_g_fake))
            + mse(d_n_fake, torch.ones_like(d_n_fake))
        )

        cycle_no_glasses = gen_N(fake_glasses)
        cycle_glasses = gen_G(fake_no_glasses)
        loss_cycle = l1(img_no_glasses, cycle_no_glasses) + l1(img_glasses, cycle_glasses)

        loss_gen = loss_adv + LAMBDA_CYCLE * loss_cycle

        # Skip the two extra forward passes when the identity term is switched off.
        if LAMBDA_IDENTITY:
            loss_identity = (
                l1(img_glasses, gen_G(img_glasses))
                + l1(img_no_glasses, gen_N(img_no_glasses))
            )
            loss_gen = loss_gen + LAMBDA_IDENTITY * loss_identity

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        disc_loss_total += loss_disc.item()
        gen_loss_total += loss_gen.item()
        num_batches += 1

    if num_batches:
        print(
            f"epoch {epoch}: disc loss {disc_loss_total / num_batches:.4f}, "
            f"gen loss {gen_loss_total / num_batches:.4f}"
        )
    return None

In [45]:

# Instantiating Generators and Discriminators
# (keyword names match the class signatures: Discriminator(in_channels, ...),
# Generator(img_channels))
disc_G = Discriminator(in_channels=1).to(DEVICE)
disc_N = Discriminator(in_channels=1).to(DEVICE)
gen_G = Generator(img_channels=1).to(DEVICE)
gen_N = Generator(img_channels=1).to(DEVICE)

# Initializing optimizers: one optimizer drives both discriminators, another both generators.
# `torch.optim` is reached through the already-imported `torch` package.
opt_disc = torch.optim.Adam(list(disc_G.parameters()) + list(disc_N.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = torch.optim.Adam(list(gen_G.parameters()) + list(gen_N.parameters()), lr = LEARNING_RATE, betas=(0.5, 0.999))

# Loss Function
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Dataset Instantiation
train_data = Dataset( train_imgs_glasses, train_imgs_no_glasses, transform = transforms_train)
# Only `test_imgs` exists for the test split (no glasses / no-glasses lists are
# built for it), so both slots of the validation dataset draw from that pool.
# NOTE(review): good enough for visual previews; use a labelled hold-out if
# validation must distinguish the two domains.
val_data = Dataset(test_imgs, test_imgs, transform = transforms_val )

# Dataloaders
train_loader = DataLoader(train_data, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, BATCH_SIZE, shuffle=False)

# Training Loop
for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch}")
    trainFn(disc_G, disc_N, gen_G, gen_N, train_loader, opt_gen, opt_disc, L1, mse, val_loader, epoch)

    # One checkpoint per network; each is stored with the optimizer that updates it.
    save_checkpoint(disc_G, opt_disc, epoch, "disc_G")
    save_checkpoint(disc_N, opt_disc, epoch, "disc_N")
    save_checkpoint(gen_G, opt_gen, epoch, "gen_G")
    save_checkpoint(gen_N, opt_gen, epoch, "gen_N")