<a href="https://www.kaggle.com/code/yashikajain/eye-glass-removal?scriptVersionId=95157787" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
# Imports

import os
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
import time

train = pd.read_csv('../input/glasses-or-no-glasses/train.csv')
test = pd.read_csv('../input/glasses-or-no-glasses/test.csv')
train.set_index('id', inplace=True)

In [2]:
# configs 
BATCH_SIZE = 1
LAMBDA_IDENTITY = 0.0
LEARNING_RATE = 2e-4
NUM_EPOCHS = 2
LAMBDA_CYCLE = 10 
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed(42)

In [4]:
# Output Directories
os.mkdir('./NoGlasses')
os.mkdir('./Glasses')

results_NoGlasses = './NoGlasses'
results_Glasses = './Glasses'

In [5]:
# segregating images with and without glasses
directory = '../input/glasses-or-no-glasses/faces-spring-2020/faces-spring-2020'
train_imgs_glasses = []
train_imgs_no_glasses = []
test_imgs = []
for img in os.listdir(directory):
    img_id = int(img.replace('face-', ''  ).replace('.png', ''))
    for i in range(len(train)+1):
        if (i) == (img_id):
            if train.loc[img_id]['glasses'] == 0:
                train_imgs_no_glasses.append(os.path.join(directory, str(img)))
                break
                
            elif train.loc[img_id]['glasses'] == 1:
                train_imgs_glasses.append(os.path.join(directory, str(img)))
                break
                
        
    for i in range(len(test)+1): 
        if int(i) == int(img_id):
            test_imgs.append(os.path.join(directory, str(img)))
            break         

In [6]:
print(len(train_imgs_glasses))
print(len(train_imgs_no_glasses))
print(len(test_imgs))

2856
1644
500


In [7]:
test_imgs_glasses = train_imgs_glasses[:571]
test_imgs_no_glasses = train_imgs_no_glasses[:328]
train_imgs_glasses = train_imgs_glasses[571:]
train_imgs_no_glasses = train_imgs_no_glasses[328:]

In [8]:
class Dataset(Dataset):
    def __init__(self, imgs_glasses, imgs_no_glasses, transform = None ):
        super().__init__()
        self.root_dir = directory
        self.glasses_list = imgs_glasses
        self.no_glasses_list = imgs_no_glasses
        self.transform = transform
        
    def __len__(self):
        self.glasses_len = len(self.glasses_list)
        self.no_glasses_len = len(self.no_glasses_list)
        self.dataset_len = max(self.glasses_len, self.no_glasses_len)
        return self.dataset_len
    
    def __getitem__(self, idx):
        glasses_path = self.glasses_list[idx % self.glasses_len]
        no_glasses_path = self.no_glasses_list[idx % self.no_glasses_len]
        img_glasses = np.array(ImageOps.grayscale(Image.open(glasses_path).convert('RGB')))
        img_no_glasses = np.array(ImageOps.grayscale(Image.open(no_glasses_path).convert('RGB')))
        
        if self.transform:
            augmentation = self.transform(image=img_glasses, image0=img_no_glasses)
            img_glasses = augmentation["image0"]
            img_no_glasses = augmentation["image"]
            
            
        return img_glasses, img_no_glasses

In [9]:
transforms_train = A.Compose(
    [
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2()
    ],
    additional_targets = {"image0":"image"}
)

transforms_val = A.Compose(
    [
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2()
    ],
    additional_targets = {"image0":"image"}
)

In [10]:
# Generator

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity()
        )

    def forward(self, x):
        return self.conv(x)

    
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)
    
class Generator(nn.Module):
    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            ]
        )

        self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.res_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        return torch.tanh(self.last(x))    

In [11]:
# Discriminator

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode="reflect"
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2)       
        )
        
    def forward(self, x):
        return self.conv(x)       
        

class Discriminator(nn.Module):
    def __init__(self, in_channels, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
        nn.Conv2d(
            in_channels,
            features[0],
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode="reflect"
        ),
        nn.LeakyReLU(0.2))
    
        layers = []
    
        in_channels = features[0]
    
        for feature in features[1:]:
            layers.append(
                Block(in_channels, feature, stride=1 if feature==features[-1] else 2)
            )
            in_channels = feature
        
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect"
            ))
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [17]:
# utility

def save_some_examples_g(gen, val_loader, epoch, idx, folder):
    img_glasses, img_no_glasses = next(iter(val_loader))
    img_glasses, img_no_glasses = img_glasses.to(DEVICE), img_no_glasses.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake_glasses = gen(img_no_glasses)
        y_fake_glasses = y_fake_glasses*0.5 + 0.5
        save_image(y_fake_glasses, os.path.join(folder, f"{epoch}_{idx}_fake_glasses.png"))

def save_some_examples_n(gen, val_loader, epoch, idx, folder):
    img_glasses, img_no_glasses = next(iter(val_loader))
    img_glasses, img_no_glasses = img_glasses.to(DEVICE), img_no_glasses.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake_no_glasses = gen(img_glasses)
        y_fake_no_glasses = y_fake_no_glasses*0.5 + 0.5
        save_image(y_fake_no_glasses, os.path.join(folder, f"{epoch}_{idx}_fake_no_glasses.png"))

def save_checkpoint(model, optimizer, epoch, filename):
    filename = str(epoch) + filename + "_cpt.pth.tar"
    #print("=> Saving checkpoint")
    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    #print("=> Loading Checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state(checkpoint["state_dict"])
    optimizer.load_state(checkpoint["state_dict"])
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


In [18]:
# train
def trainFn(disc_G, disc_N, gen_G, gen_N, loader, optim_g, optim_d, l1, mse, val_loader, epoch, epoch_generator_loss, epoch_discriminator_loss):
    loop = tqdm(loader,leave=True, position=0)
    for idx, (img_glasses, img_no_glasses) in enumerate(loop):
      
        img_glasses = img_glasses.to(DEVICE)
        img_no_glasses = img_no_glasses.to(DEVICE)
        
        fake_no_glasses = gen_N(img_glasses)
        disc_ng_real = disc_N(img_no_glasses)
        disc_ng_fake = disc_N(fake_no_glasses.detach())
        disc_ng_real_loss = mse(disc_ng_real, torch.ones_like(disc_ng_real))
        disc_ng_fake_loss = mse(disc_ng_fake, torch.zeros_like(disc_ng_fake))
        disc_ng_loss = disc_ng_fake_loss + disc_ng_real_loss
            
        
        fake_glasses = gen_G(img_no_glasses)
        disc_g_real = disc_G(img_glasses)
        disc_g_fake = disc_G(fake_glasses.detach())
        disc_g_real_loss = mse(disc_g_real, torch.ones_like(disc_g_real))
        disc_g_fake_loss = mse(disc_g_fake, torch.zeros_like(disc_g_fake))
        disc_g_loss = disc_g_real_loss + disc_g_fake_loss
        
        #total disc loss
        D_loss = (disc_ng_loss + disc_g_loss)/2
        
        # backprop and update the weights of discriminator
        optim_d.zero_grad()
        D_loss.backward()
        optim_d.step()
        
        # Training Generators
        
        # Adversarial loss
        disc_g_fake = disc_G(fake_glasses)
        disc_ng_fake = disc_N(fake_no_glasses)
        loss_g_glasses = mse(disc_g_fake, torch.ones_like(disc_g_fake))
        loss_g_no_glasses = mse(disc_ng_fake, torch.ones_like(disc_ng_fake))
        
        # Cycle loss
        cycle_glasses = gen_G(fake_no_glasses)
        cycle_no_glasses = gen_N(fake_glasses)
        glasses_cycle_loss = l1(img_glasses, cycle_glasses)
        no_glasses_cycle_loss = l1(img_no_glasses, cycle_no_glasses)
    
        
        # Total Generator Loss
        G_loss = (loss_g_glasses + loss_g_no_glasses) + LAMBDA_CYCLE*(glasses_cycle_loss + no_glasses_cycle_loss) #+ loss_pccl 
        
        # backprop and updating weights of Generator
        optim_g.zero_grad()
        G_loss.backward()
        optim_g.step()
        
        epoch_generator_loss = epoch_generator_loss + G_loss
        epoch_discriminator_loss = epoch_discriminator_loss + D_loss
        
        #print(f"Generator loss: {G_loss}, Discriminator loss: {D_loss}")
        save_some_examples_n(gen_N, val_loader, epoch, idx, folder=results_NoGlasses)
        save_some_examples_g(gen_G, val_loader, epoch, idx, folder=results_Glasses)
    return epoch_generator_loss, epoch_discriminator_loss
        
        
    

In [19]:

# Instantiating Generators and Discriminators
disc_G = Discriminator(in_channels=1).to(DEVICE)
disc_N = Discriminator(in_channels=1).to(DEVICE)
gen_G = Generator(img_channels=1).to(DEVICE)
gen_N = Generator(img_channels=1).to(DEVICE)

# Initializing optimizers
opt_disc = optim.Adam(list(disc_G.parameters()) + list(disc_N.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = optim.Adam(list(gen_G.parameters()) + list(gen_N.parameters()), lr = LEARNING_RATE, betas=(0.5, 0.999))

# Loss Function
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Dataset Instantiation
train_data = Dataset(imgs_glasses=train_imgs_glasses, imgs_no_glasses=train_imgs_no_glasses, transform=transforms_train)
val_data = Dataset(imgs_glasses=test_imgs_glasses, imgs_no_glasses=test_imgs_no_glasses, transform = transforms_val )

# Dataloaders
train_loader = DataLoader(train_data, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, BATCH_SIZE, shuffle=False)

# Training Loop
for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch}")
    epoch_generator_loss = 0
    epoch_discriminator_loss = 0
    gen_loss, disc_loss = trainFn(disc_G, disc_N, gen_G, gen_N, train_loader, opt_gen, opt_disc, L1, mse, val_loader, epoch, epoch_generator_loss, epoch_discriminator_loss)
    print(f'Generator_loss {gen_loss}')
    print(f'Discriminator_loss {disc_loss}')

    save_checkpoint(disc_G, opt_disc, epoch, "disc_G")
    save_checkpoint(disc_N, opt_disc, epoch, "disc_N")
    save_checkpoint(gen_G, opt_gen, epoch, "gen_G")
    save_checkpoint(gen_N, opt_gen, epoch, "gen_N")

Epoch: 0


100%|██████████| 2285/2285 [31:48<00:00,  1.20it/s]


Generator_loss 9724.5947265625
Discriminator_loss 591.7130737304688
Epoch: 1


100%|██████████| 2285/2285 [30:30<00:00,  1.25it/s]


Generator_loss 9355.0146484375
Discriminator_loss 297.63946533203125
