# Face GAN
## 1. Generate realistic images of faces of people who don't exist.
## 2. Enhance face images whose quality has been degraded (e.g. by blurring).

In [None]:
import torchvision
import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import math
import pandas as pd
# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from torchvision.transforms.transforms import GaussianBlur
def _face_pipeline(blur):
    """Build the resize -> (optional blur) -> tensor -> normalize pipeline.

    Both pipelines resize to 256x256 and normalize with the same statistics;
    only the input pipeline (``blur=True``) inserts a 7x7 Gaussian blur with
    sigma in the range (2, 3), producing the degraded version of the image.
    """
    steps = [torchvision.transforms.Resize((256,256))]
    if blur:
        steps.append(torchvision.transforms.GaussianBlur(7, (2,3)))
    steps.append(torchvision.transforms.ToTensor())
    steps.append(
        torchvision.transforms.Normalize(mean=(0.48,0.4,0.4), std=(0.225,0.225,0.225))
    )
    return torchvision.transforms.Compose(steps)


# Degraded (blurred) input images.
img_transforms = _face_pipeline(blur=True)
# Clean reference images: identical pipeline without the blur step.
target_transforms = _face_pipeline(blur=False)

In [None]:
import os
from os.path import abspath, expanduser
from typing import Any, Callable, List, Dict, Optional, Tuple, Union

import torch
from PIL import Image

from torchvision.datasets.utils import (
    download_file_from_google_drive,
    download_and_extract_archive,
    extract_archive,
    verify_str_arg,
)
from torchvision.datasets import VisionDataset


class WIDERFace2(torchvision.datasets.VisionDataset):
    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset, repurposed for image-to-image training.

    Unlike ``torchvision.datasets.WIDERFace``, ``target_transform`` here is applied
    to the *image*, not to the annotations. With ``transform`` degrading the image
    and ``target_transform`` leaving it clean, every sample becomes an
    (input image, target image) pair built from the same photo. When no
    ``target_transform`` is given, the target is the annotation dict (train/val)
    or None (test).

    Args:
        root (string): Root directory where images and annotations are downloaded to.
            Expects the following folder structure if download=False:

            .. code::

                <root>
                    └── widerface
                        ├── wider_face_split ('wider_face_split.zip' if compressed)
                        ├── WIDER_train ('WIDER_train.zip' if compressed)
                        ├── WIDER_val ('WIDER_val.zip' if compressed)
                        └── WIDER_test ('WIDER_test.zip' if compressed)
        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
            Defaults to ``train``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform applied to the same
            RGB PIL image; its result is returned as the target.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: if the extracted dataset folders are missing (after the optional
            download), or if an annotation file cannot be parsed.
    """

    BASE_FOLDER = "widerface"
    FILE_LIST = [
        # File ID                             MD5 Hash                            Filename
        ("15hGDLhsx8bLgLcIRD5DhYt5iBxnjNF1M", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
        ("1GUCogbp16PMGa39thoMMeWxp7Rp5oM8Q", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
        ("1HIfDbVEWKmsYKJZm4lchTBDLW5N7dY5T", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
    ]
    ANNOTATIONS_FILE = (
        "http://shuoyang1213.me/WIDERFACE/support/bbx_annotation/wider_face_split.zip",
        "0e3767bcf0e326556d407bf5bff5d27c",
        "wider_face_split.zip",
    )

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(
            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
        )
        # check arguments
        self.split = verify_str_arg(split, "split", ("train", "val", "test"))

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")

        # One entry per image: {"img_path": str} plus "annotations" for train/val.
        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
        if self.split in ("train", "val"):
            self.parse_train_val_annotations_file()
        else:
            self.parse_test_annotations_file()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target). If ``target_transform`` is set, target is
            ``target_transform(image)``. Otherwise it is the dict of annotations for
            all faces in the image, or None for the test split.
        """
        info = self.img_info[index]

        # Image.open is lazy and keeps the file open; decode eagerly inside a
        # context manager so the handle is released. convert("RGB") guarantees
        # 3 channels, so 3-channel normalization also works for grayscale/CMYK files.
        with Image.open(info["img_path"]) as raw:
            img = raw.convert("RGB")

        if self.target_transform is not None:
            # Deliberate: the target is a second view of the same photo, not the
            # transformed annotations.
            target = self.target_transform(img)
        else:
            target = None if self.split == "test" else info["annotations"]

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self) -> int:
        return len(self.img_info)

    def extra_repr(self) -> str:
        lines = ["Split: {split}"]
        return "\n".join(lines).format(**self.__dict__)

    def parse_train_val_annotations_file(self) -> None:
        """Fill ``self.img_info`` from ``wider_face_{train,val}_bbx_gt.txt``.

        The file is a sequence of records: an image path line, a box-count line,
        then one line of 10 space-separated ints per box. A count of 0 is still
        followed by one line, which this parser consumes as a single row.

        Raises:
            RuntimeError: if the parser reaches an unexpected state.
        """
        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
        filepath = os.path.join(self.root, "wider_face_split", filename)

        with open(filepath) as f:
            file_name_line, num_boxes_line, box_annotation_line = True, False, False
            num_boxes, box_counter = 0, 0
            labels = []
            for line in f:
                line = line.rstrip()
                if file_name_line:
                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
                    img_path = abspath(expanduser(img_path))
                    file_name_line = False
                    num_boxes_line = True
                elif num_boxes_line:
                    num_boxes = int(line)
                    num_boxes_line = False
                    box_annotation_line = True
                elif box_annotation_line:
                    box_counter += 1
                    labels.append([int(x) for x in line.split(" ")])
                    if box_counter >= num_boxes:
                        box_annotation_line = False
                        file_name_line = True
                        labels_tensor = torch.tensor(labels)
                        self.img_info.append(
                            {
                                "img_path": img_path,
                                "annotations": {
                                    "bbox": labels_tensor[:, 0:4],  # x, y, width, height
                                    "blur": labels_tensor[:, 4],
                                    "expression": labels_tensor[:, 5],
                                    "illumination": labels_tensor[:, 6],
                                    "occlusion": labels_tensor[:, 7],
                                    "pose": labels_tensor[:, 8],
                                    "invalid": labels_tensor[:, 9],
                                },
                            }
                        )
                        box_counter = 0
                        labels.clear()
                else:
                    raise RuntimeError(f"Error parsing annotation file {filepath}")

    def parse_test_annotations_file(self) -> None:
        """Fill ``self.img_info`` with image paths listed in ``wider_face_test_filelist.txt``."""
        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
        filepath = abspath(expanduser(filepath))
        with open(filepath) as f:
            for line in f:
                line = line.rstrip()
                img_path = os.path.join(self.root, "WIDER_test", "images", line)
                img_path = abspath(expanduser(img_path))
                self.img_info.append({"img_path": img_path})

    def _check_integrity(self) -> bool:
        """Return True if every extracted archive folder exists (MD5s are not checked)."""
        # Allow original archive to be deleted (zip). Only need the extracted images
        for _, _, filename in [*self.FILE_LIST, self.ANNOTATIONS_FILE]:
            extracted_dir = os.path.join(self.root, os.path.splitext(filename)[0])
            if not os.path.exists(extracted_dir):
                return False
        return True

    def download(self) -> None:
        """Download and extract the image archives and annotation files unless already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        # download and extract image data
        for (file_id, md5, filename) in self.FILE_LIST:
            download_file_from_google_drive(file_id, self.root, filename, md5)
            filepath = os.path.join(self.root, filename)
            extract_archive(filepath)

        # download and extract annotation files
        download_and_extract_archive(
            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
        )


In [None]:
# Each sample is (img_transforms(photo), target_transforms(photo)): WIDERFace2
# applies target_transform to the image itself, giving (degraded, clean) pairs.
train_set = WIDERFace2(root="./", split="train", transform=img_transforms, target_transform=target_transforms, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=128, shuffle=True)

# Evaluation loaders keep a fixed order so per-batch metrics are reproducible.
val_set = WIDERFace2(root="./", split="val", transform=img_transforms, target_transform=target_transforms, download=True)
val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=128, shuffle=False)

test_set = WIDERFace2(root="./", split="test", transform=img_transforms, target_transform=target_transforms, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=128, shuffle=False)

RuntimeError: ignored

In [None]:
# Sanity check: shape of the transformed input image of training sample 9.
train_set[9][0].shape

torch.Size([3, 256, 256])

In [None]:
 class Generator(torch.nn.Module):
    def __init__(self, Q1, Q2):# B):
        super(Generator, self).__init__()
        self.Q1 = Q1
        self.Q2 = Q2
        #self.B = B
        self.norm = torch.nn.BatchNorm2d(3)
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=Q1*3, kernel_size=5, stride=1, padding=2)
        self.conv2 = torch.nn.Conv2d(in_channels=Q1*3, out_channels=Q2*3, kernel_size=5, stride=1, padding=2)
        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout2d(p=0.5)

        self.flatten = torch.nn.Flatten()
        self.fc1 = torch.nn.Linear(in_features=int(3*Q2*256*256),out_features=int(3*256*256), bias=True)#stride happens in 2 dimensions so for every conv layer with a stride of 2 the image shrinks by a factor of 4
        self.unflatten = torch.nn.Unflatten(1, (3,256,256))
        self.prelu = torch.nn.PReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.convblock = torch.nn.Sequential(
            torch.nn.BatchNorm2d(3),
            torch.nn.Conv2d(in_channels=3, out_channels=Q1*3, kernel_size=5, stride=1, padding=2),
            torch.nn.Dropout2d(p=0.5),
            torch.nn.PReLU(),
            torch.nn.Conv2d(in_channels=Q1*3, out_channels=Q2*3, kernel_size=5, stride=1, padding=2),
            torch.nn.Dropout2d(p=0.5),
            torch.nn.PReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=int(3*Q2*256*28),out_features=int(3*256*256), bias=True),
            torch.nn.Unflatten(1, (3,256,256)),
            torch.nn.Sigmoid()#PSigmoid()
        )

    def forward(self, x):
        x = self.convblock(x)
        x = self.convblock(x)
        deep_feature = self.convblock(x)
        output = self.convblock(deep_feature)
        #possibly add batchnorm at the end

        '''
        x = self.norm(x)
        x = self.conv1(x)
        x = self.dropout(x)
        x = self.act(x)
        deep_feature = self.conv2(x)
        x = deep_feature
        x = self.dropout(x)
        x = self.act(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.unflatten(x)
        output = self.sigmoid(x)
        '''
        # stack multiple conv layers
        return output, deep_feature


In [None]:
class GateSigmoid(torch.nn.Module):
    """Sigmoid with a learnable slope: ``sigmoid(p * x)``, with ``p`` initialised to 1."""

    def __init__(self):
        super(GateSigmoid, self).__init__()
        # Learnable scalar slope of shape (1,), broadcast over x.
        self.p = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        # torch.sigmoid is numerically stable. The hand-written 1/(1+exp(-p*x))
        # overflows exp() for large negative p*x, which yields NaN gradients.
        return torch.sigmoid(x * self.p)

In [None]:
class Discriminator(torch.nn.Module):
    """Real/fake classifier for 3x256x256 images.

    Two stages of (5x5 conv, Dropout2d, 2x2 max-pool, ReLU) shrink each spatial
    dimension by 4. ``fc1`` maps the flattened features to a 2-d
    ``deep_feature``, and PReLU -> ``fc2`` -> learnable-slope sigmoid gate turns
    that into a single score per image.

    Args:
        Q1: channel multiplier of the first conv (3*Q1 output channels).
        Q2: channel multiplier of the second conv (3*Q2 output channels).
    """

    def __init__(self, Q1, Q2):
        super(Discriminator, self).__init__()
        self.Q1 = Q1
        self.Q2 = Q2
        width_1, width_2 = Q1 * 3, Q2 * 3
        # Submodules are registered in a fixed order; that order determines
        # parameter initialisation, parameters() order and state_dict keys.
        self.norm = torch.nn.BatchNorm2d(3)
        self.conv1 = torch.nn.Conv2d(3, width_1, kernel_size=5, stride=1, padding=2)
        self.conv2 = torch.nn.Conv2d(width_1, width_2, kernel_size=5, stride=1, padding=2)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = torch.nn.Dropout2d()
        self.act = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()
        # Two 2x2 pools: 256x256 -> 64x64 feature maps with width_2 channels.
        flat_features = width_2 * (256 // 4) * (256 // 4)
        self.fc1 = torch.nn.Linear(flat_features, 2, bias=True)
        self.prelu = torch.nn.PReLU()
        self.fc2 = torch.nn.Linear(2, 1, True)
        self.sigmoid = torch.nn.Sigmoid()  # not used by forward
        self.gate = GateSigmoid()

    def forward(self, x):
        """Return ``(score, deep_feature)`` for an image batch ``x``."""
        h = self.norm(x)
        for conv in (self.conv1, self.conv2):
            h = self.act(self.pool(self.dropout(conv(h))))
        deep_feature = self.fc1(self.flatten(h))
        score = self.gate(self.fc2(self.prelu(deep_feature)))
        return score, deep_feature



In [None]:
def _bce_sum(prob, target_value):
    """Summed binary cross-entropy of ``prob`` against a constant label.

    Equals ``-sum(log(prob))`` for label 1 and ``-sum(log(1 - prob))`` for label 0,
    but ``binary_cross_entropy`` clamps each log term at -100. A saturated
    probability of exactly 0 or 1 therefore gives a large finite loss instead of
    ``inf`` and NaN gradients.
    """
    target = torch.full_like(prob, float(target_value))
    return torch.nn.functional.binary_cross_entropy(prob, target, reduction="sum")


def GANLoss(fake_disc, real_disc):
    """Discriminator objective: ``-sum(log D(real)) - sum(log(1 - D(fake)))``.

    Args:
        fake_disc: discriminator probabilities for generated images, in [0, 1].
        real_disc: discriminator probabilities for real images, in [0, 1].

    Returns:
        Scalar tensor, summed (not averaged) over the batch.
    """
    return _bce_sum(real_disc, 1.0) + _bce_sum(fake_disc, 0.0)


def DISCLoss(fake_disc, real_disc):
    """Same objective as :func:`GANLoss`."""
    return GANLoss(fake_disc, real_disc)


def GENLoss(fake_disc, real_disc):
    """Generator objective: ``-sum(log D(real)) - sum(log D(fake))``.

    When ``real_disc`` comes from real images, the first term does not depend on
    the generator's parameters and only shifts the reported value.

    Returns:
        Scalar tensor, summed (not averaged) over the batch.
    """
    return _bce_sum(real_disc, 1.0) + _bce_sum(fake_disc, 1.0)


In [None]:
batchlen = train_loader.batch_size
val_batch = val_loader.batch_size
generator = Generator(Q1=16, Q2=16,).to(device)
discriminator = Discriminator(Q1=16, Q2=16).to(device)

gen_optim = torch.optim.RMSprop(params=generator.parameters(), lr=0.0001)
disc_optim = torch.optim.RMSprop(params=discriminator.parameters(), lr=0.0001)

# Per-training-batch histories (losses are divided by the actual batch size).
disc_plot = []
gen_plot = []
comb_plot = []
real_plot = []   # mean discriminator score on real images
fake_plot = []   # mean discriminator score on generated images
# Per-validation-batch rates.
false_positive = []  # fraction of generated images scored as real (rounds to 1)
true_positive = []   # fraction of real images scored as real


epochs = 10
batch_number = len(train_set)//batchlen  # kept for later cells; not used for averaging
n_train = len(train_set)  # counts the last, possibly partial, batch too
for epoch in tqdm(range(epochs)):
    generator.train()
    discriminator.train()
    disc_loss = 0
    gen_loss = 0
    for x, t in tqdm(train_loader):
        x = x.to(device)
        t = t.to(device)
        n = x.shape[0]  # actual batch size; the last batch may be smaller

        # --- Discriminator step: the generator output is treated as a constant.
        disc_optim.zero_grad()
        with torch.no_grad():
            y, _ = generator(x)
        real_disc, _ = discriminator(t)
        fake_disc, _ = discriminator(y)
        J = GANLoss(fake_disc, real_disc)
        J.backward()
        disc_optim.step()
        d_loss = J.item()
        disc_loss += d_loss
        disc_plot.append(d_loss / n)
        real_plot.append(real_disc.detach().mean().item())
        fake_plot.append(fake_disc.detach().mean().item())

        # --- Generator step (gradients reaching the discriminator are cleared
        # by the next disc_optim.zero_grad()).
        gen_optim.zero_grad()
        y, _ = generator(x)
        real_disc, _ = discriminator(t)
        fake_disc, _ = discriminator(y)
        J = GENLoss(fake_disc, real_disc)
        J.backward()
        gen_optim.step()
        g_loss = J.item()
        gen_loss += g_loss
        gen_plot.append(g_loss / n)
        comb_plot.append(d_loss / n + g_loss / n)

    # Evaluate with dropout disabled and BatchNorm using its running statistics;
    # in train mode validation batches would also update those statistics.
    generator.eval()
    discriminator.eval()
    fake_accuracy = 0
    real_accuracy = 0
    with torch.no_grad():
        for x, t in val_loader:
            fake, _ = generator(x.to(device))
            zero, _ = discriminator(fake)
            one, _ = discriminator(t.to(device))

            fake_as_real = torch.round(zero).sum().item()
            real_as_real = torch.round(one).sum().item()
            fake_accuracy += len(zero) - fake_as_real
            real_accuracy += real_as_real
            false_positive.append(fake_as_real / len(zero))
            true_positive.append(real_as_real / len(one))
    accuracy = fake_accuracy + real_accuracy

    mean_disc = disc_loss / n_train
    mean_gen = gen_loss / n_train
    print(f"epoch: {epoch}\t Disc. Loss: {mean_disc:1.4f}\t Gen. Loss: {mean_gen:1.4f}\t Combined Loss: {(mean_disc + mean_gen)/2:1.4f}")
    print(f"accuracy: {accuracy/(2*len(val_set)):1.4f}\t fake accuracy: {fake_accuracy/len(val_set):1.4f}\t real accuracy: {real_accuracy/len(val_set):1.4f}")

    # Stop (without overwriting the last good checkpoint) once training diverges.
    if not (math.isfinite(disc_loss) and math.isfinite(gen_loss)):
        break
    # NOTE(review): checkpoint names say "Quickdraw", which looks copied from another project.
    torch.save(generator.state_dict(), "/content/drive/My Drive/Colab Notebooks/Quickdraw_gan_generator.pth")
    torch.save(discriminator.state_dict(), "/content/drive/My Drive/Colab Notebooks/Quickdraw_gan_discriminator.pth")


plt.figure(figsize=(10,7))
plt.plot(disc_plot,label="discriminator")
plt.plot(gen_plot, label="generator")
plt.plot(comb_plot, label="combined")
plt.yscale("log")
plt.legend()

In [None]:
# Scratch arithmetic. 3145728 == 3*16*256*256, which matches the flattened
# 3*Q2*256*256 feature count at Q2=16. 37632 == 3*16*28*28 and 2352 == 3*28*28,
# presumably sizes from an earlier 28x28 setup (TODO confirm).
for quantity in (128 * 3145728, 37632 * 2352, (128 * 3145728) / (37632 * 2352)):
    print(quantity)

402653184
88510464
4.5492156046091905
