In [None]:
# Dataset: Anime Face Dataset (Kaggle) — https://www.kaggle.com/datasets/splcher/animefacedataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader 
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import matplotlib.pyplot as plt

In [None]:
# Data location and loader / image settings.
DATA_DIR = '../input'   # root folder handed to ImageFolder
batch_size = 128
image_size = 64         # images are resized and center-cropped to image_size x image_size
# (mean, std) per RGB channel, unpacked into T.Normalize below.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

In [None]:
# Resize the shorter edge to image_size, center-crop to a square, convert to a [0, 1]
# tensor, then normalize per channel with `stats` (mean, std).
train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats),
]))

# shuffle=True: ImageFolder indexes files in a fixed order, so without shuffling every
# epoch would present identical batches in identical order.
# pin_memory=True lets the non_blocking host->GPU copies in to_device run asynchronously.
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [None]:
# Preview: print the shape of one batch and show its first image.
for preview_images, _ in train_dl:
    print(preview_images.shape)
    # Undo T.Normalize per channel before plotting; normalized values (about [-1, 1] with
    # the current stats) would otherwise be clipped by imshow. .cpu() keeps this cell
    # working even after train_dl has been wrapped to yield GPU tensors.
    norm_mean = torch.tensor(stats[0]).view(-1, 1, 1)
    norm_std = torch.tensor(stats[1]).view(-1, 1, 1)
    plt.imshow((preview_images[0].cpu() * norm_std + norm_mean).permute(1, 2, 0))
    break


In [None]:
def denorm(img_tensors, mean=None, std=None):
    """Invert ``T.Normalize`` so images return to their pre-normalization range.

    Args:
        img_tensors: image tensor whose channel axis is third from the end,
            e.g. (C, H, W) or (B, C, H, W).
        mean: per-channel means used for normalization; defaults to ``stats[0]``.
        std: per-channel standard deviations; defaults to ``stats[1]``.

    Returns:
        A new tensor ``img_tensors * std + mean`` computed channel-wise.
    """
    if mean is None:
        mean = stats[0]
    if std is None:
        std = stats[1]
    # Shape (C, 1, 1) so each channel's value broadcasts over H, W and any batch axis.
    mean_t = torch.as_tensor(mean, dtype=img_tensors.dtype, device=img_tensors.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img_tensors.dtype, device=img_tensors.device).view(-1, 1, 1)
    return img_tensors * std_t + mean_t

In [None]:
from torchvision.utils import make_grid
def show_images(images,nmax=64):
    """Plot up to ``nmax`` images from a normalized batch as an 8-column grid.

    Args:
        images: batch of normalized images (assumed (B, C, H, W) — TODO confirm);
            may live on any device.
        nmax: maximum number of images to draw.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # matplotlib cannot read CUDA tensors, and batches from DeviceDataLoader live on
    # `device`, so slice first and then move only the shown images to the CPU.
    shown = denorm(images.detach()[:nmax].cpu()).clamp(0, 1)
    ax.imshow(make_grid(shown, nrow=8).permute(1, 2, 0))

def show_batch(dl,nmax=64):
    """Plot the first batch yielded by ``dl`` via ``show_images``; labels are ignored.

    Does nothing when ``dl`` yields no batches.
    """
    for batch_images, _labels in dl:
        show_images(batch_images, nmax)
        return

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def to_device(data,device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to ``device``.

    Args:
        data: a tensor, or a list/tuple of tensors (e.g. a DataLoader batch
            ``[images, labels]``).
        device: target device, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The moved tensor, or a list of moved elements (tuples become lists).
    """
    if isinstance(data, (list, tuple)):
        # Forward `device` to each recursive call.
        return [to_device(x, device) for x in data]
    # non_blocking makes the copy asynchronous when the source is in pinned memory.
    return data.to(device, non_blocking=True)

class DeviceDataLoader:
    """Iterable wrapper that moves every batch of ``dl`` onto ``device`` as it is yielded."""

    def __init__(self,dl,device) -> None:
        self.device = device
        self.dl = dl

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        for batch in self.dl:
            moved = to_device(batch, self.device)
            yield moved

In [None]:
# Wrap only once: re-running this cell must not nest DeviceDataLoader inside itself.
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)

In [36]:
# DCGAN-style discriminator: maps a 3 x 64 x 64 image to one probability that it is real.
# Each conv with kernel 4, stride 2, padding 1 halves the spatial size:
# 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4 conv (stride 1, no padding) reduces 4 -> 1.
# NOTE(review): the single output per image relies on 64 x 64 inputs (image_size = 64);
# another size would give several values per image after Flatten and mismatch the
# (B, 1) targets used with F.binary_cross_entropy.
discriminator = nn.Sequential(
    # 3 x 64 x 64 -> 64 x 32 x 32
    # NOTE(review): the DCGAN paper omits BatchNorm on the discriminator's input layer;
    # consider dropping it here.
    nn.Conv2d(3,64,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2,inplace=True),

    # 64 x 32 x 32 -> 128 x 16 x 16
    nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2,inplace=True),

    # 128 x 16 x 16 -> 256 x 8 x 8
    nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2,inplace=True),

    # 256 x 8 x 8 -> 512 x 4 x 4
    nn.Conv2d(256,512,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2,inplace=True),

    # 512 x 4 x 4 -> 1 x 1 x 1
    nn.Conv2d(512,1,kernel_size=4,stride=1,padding=0,bias=False),

    # (B, 1, 1, 1) -> (B, 1), then squash into (0, 1) as F.binary_cross_entropy expects.
    nn.Flatten(),
    nn.Sigmoid()
)

In [None]:
# nn.Module.to moves parameters in place and returns the module (shown as cell output).
discriminator.to(device=device)

In [None]:
# Channels of the (B, latent_size, 1, 1) noise tensor fed to the generator's first layer.
latent_size: int = 128

In [None]:
# DCGAN-style generator: maps noise shaped (B, latent_size, 1, 1) to a 3 x 64 x 64 image.
# Spatial size per stage: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
generator = nn.Sequential(
    # latent_size x 1 x 1 -> 512 x 4 x 4 (kernel 4, stride 1, no padding)
    nn.ConvTranspose2d(latent_size,512,kernel_size=4,stride=1,padding=0,bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),

    # 512 x 4 x 4 -> 256 x 8 x 8 (kernel 4, stride 2, padding 1 doubles H and W)
    nn.ConvTranspose2d(512,256,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),

    # 256 x 8 x 8 -> 128 x 16 x 16
    nn.ConvTranspose2d(256,128,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),

    # 128 x 16 x 16 -> 64 x 32 x 32
    nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1,bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),

    # 64 x 32 x 32 -> 3 x 64 x 64; Tanh yields values in [-1, 1], the range
    # T.Normalize(*stats) produces for real images with the current stats.
    nn.ConvTranspose2d(64,3,kernel_size=4,stride=2,padding=1,bias=False),
    nn.Tanh()
)

In [None]:
# nn.Module.to moves parameters in place and returns the module (shown as cell output).
generator.to(device=device)

In [None]:
def train_discriminator(real_images,opt_d):
    """Run one discriminator update on a real batch plus an equally sized fake batch.

    Args:
        real_images: batch of normalized real images already on ``device``.
        opt_d: optimizer over the discriminator's parameters.

    Returns:
        (loss, real_score, fake_score): the summed BCE loss and the mean discriminator
        output on real and on fake images, all as Python floats.
    """
    opt_d.zero_grad()
    n_images = real_images.size(0)

    # Real images should be scored as 1.
    real_preds = discriminator(real_images)
    real_targets = torch.ones(n_images, 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Match the real batch size so the last (possibly shorter) batch stays balanced.
    latent = torch.randn(n_images, latent_size, 1, 1, device=device)
    # detach(): this step only updates D, so skip backprop through the generator.
    fake_images = generator(latent).detach()

    # Fake images should be scored as 0.
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g):
    """Run one generator update that pushes the discriminator to score fakes as real.

    Args:
        opt_g: optimizer over the generator's parameters.

    Returns:
        The generator's BCE loss as a Python float.
    """
    opt_g.zero_grad()

    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    scores = discriminator(generator(noise))

    # All-ones targets: the loss falls as D(G(z)) approaches 1.
    wanted = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(scores, wanted)

    g_loss.backward()
    opt_g.step()
    return g_loss.item()

In [None]:
from torchvision.utils import save_image
import os
# Folder that receives the per-epoch sample grids; created if it does not exist yet.
sample_dir = 'generated'
os.makedirs(name=sample_dir, exist_ok=True)

In [None]:
def save_samples(index,latent_tensors):
    """Generate images from ``latent_tensors`` and save them as one 8-column PNG grid.

    Args:
        index: number embedded, zero-padded to 4 digits, in the file name
            ``generated_images-XXXX.png`` inside ``sample_dir``.
        latent_tensors: noise batch for the generator, e.g. ``fixed_latent``.
    """
    # Sampling needs no gradients; skip building an autograd graph.
    # NOTE(review): the generator is not switched to eval mode here, so BatchNorm uses
    # this batch's statistics and updates its running stats.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated_images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving Image', fake_fname)

In [None]:
# 64 noise vectors reused after every epoch so saved sample grids are comparable over time.
fixed_latent = torch.randn((64, latent_size, 1, 1), device=device)

In [None]:
def fit(epochs,lr,start_idx = 1):
    """Train discriminator and generator together for ``epochs`` passes over ``train_dl``.

    Each call creates fresh Adam optimizers (betas=(0.5, 0.999)) for both networks.
    After every epoch a grid generated from ``fixed_latent`` is saved via ``save_samples``.

    Args:
        epochs: number of passes over ``train_dl``.
        lr: learning rate shared by both optimizers.
        start_idx: offset added to the epoch index when naming saved samples
            (``save_samples`` receives ``epoch + start_idx``).

    Returns:
        (losses_g, losses_d, real_scores, fake_scores): one value per epoch, each taken
        from the last batch of that epoch.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    torch.cuda.empty_cache()

    losses_g, losses_d, real_scores, fake_scores = [], [], [], []

    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        loss_d = None
        for real_images, _ in train_dl:
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)

        # Fail with a clear message instead of an UnboundLocalError when nothing was loaded.
        if loss_d is None:
            raise ValueError('train_dl yielded no batches; check DATA_DIR and the dataset')

        losses_d.append(loss_d)
        losses_g.append(loss_g)
        real_scores.append(real_score)
        fake_scores.append(fake_score)
        print('Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}'.format(
            epoch + 1, epochs, loss_g, loss_d, real_score, fake_score))
        save_samples(epoch + start_idx, fixed_latent)

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
# Adam learning rate shared by both networks, and number of passes over the dataset.
lr = 2e-4
epochs = 10

In [None]:
# history = (losses_g, losses_d, real_scores, fake_scores), one entry per epoch.
history = fit(epochs=epochs, lr=lr)

In [38]:
from torchsummary import summary
# Input shape must match the generator's first layer: latent_size channels at 1 x 1.
summary(generator, (latent_size, 1, 1))

Layer (type:depth-idx)                   Output Shape              Param #
├─ConvTranspose2d: 1-1                   [-1, 512, 4, 4]           1,048,576
├─BatchNorm2d: 1-2                       [-1, 512, 4, 4]           1,024
├─ReLU: 1-3                              [-1, 512, 4, 4]           --
├─ConvTranspose2d: 1-4                   [-1, 256, 8, 8]           2,097,152
├─BatchNorm2d: 1-5                       [-1, 256, 8, 8]           512
├─ReLU: 1-6                              [-1, 256, 8, 8]           --
├─ConvTranspose2d: 1-7                   [-1, 128, 16, 16]         524,288
├─BatchNorm2d: 1-8                       [-1, 128, 16, 16]         256
├─ReLU: 1-9                              [-1, 128, 16, 16]         --
├─ConvTranspose2d: 1-10                  [-1, 64, 32, 32]          131,072
├─BatchNorm2d: 1-11                      [-1, 64, 32, 32]          128
├─ReLU: 1-12                             [-1, 64, 32, 32]          --
├─ConvTranspose2d: 1-13                  [-1, 3, 64, 64

Layer (type:depth-idx)                   Output Shape              Param #
├─ConvTranspose2d: 1-1                   [-1, 512, 4, 4]           1,048,576
├─BatchNorm2d: 1-2                       [-1, 512, 4, 4]           1,024
├─ReLU: 1-3                              [-1, 512, 4, 4]           --
├─ConvTranspose2d: 1-4                   [-1, 256, 8, 8]           2,097,152
├─BatchNorm2d: 1-5                       [-1, 256, 8, 8]           512
├─ReLU: 1-6                              [-1, 256, 8, 8]           --
├─ConvTranspose2d: 1-7                   [-1, 128, 16, 16]         524,288
├─BatchNorm2d: 1-8                       [-1, 128, 16, 16]         256
├─ReLU: 1-9                              [-1, 128, 16, 16]         --
├─ConvTranspose2d: 1-10                  [-1, 64, 32, 32]          131,072
├─BatchNorm2d: 1-11                      [-1, 64, 32, 32]          128
├─ReLU: 1-12                             [-1, 64, 32, 32]          --
├─ConvTranspose2d: 1-13                  [-1, 3, 64, 64