In [None]:
import sys
# Add the Google Drive project folder to Python's module search path.
# NOTE(review): this path is spelled 'My Drive/Codes' whereas the dataset and
# checkpoint paths used later in the notebook are spelled 'MyDrive/CODES' --
# confirm which spelling matches the mounted Drive.
sys.path.append('/content/drive/My Drive/Codes/Deep_Learning/GAN/')

# The Drive mount below is commented out, so the '/content/drive' paths used in
# this notebook presumably rely on the Drive being mounted some other way.
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import models
import torch.nn.functional as Functional
from torchvision import transforms
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import pyplot as plt
import torch
import torch.nn as nn

In [None]:
# Settings for the anime-face dataset.
image_size, batch_size = 64, 128  # square image side length; samples per batch
# (mean, std) pair with three values each -- presumably one per RGB channel.
stats = ((0.5, 0.5, 0.5),) * 2

In [None]:
# Build the training dataset: every image is resized and centre-cropped to
# image_size x image_size, converted to a float tensor in [0, 1], and then
# normalized.
train_ds = ImageFolder(
    root="/content/drive/MyDrive/CODES/Deep_Learning/data/anime-face",
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # Normalize with the notebook's `stats` (mean 0.5, std 0.5) instead
            # of ImageNet statistics: this maps pixels to [-1, 1], the same
            # range the generator's final Tanh produces, so real and fake
            # images share one value range when shown to the discriminator.
            transforms.Normalize(*stats),
        ]
    ),
)
# Shuffled mini-batches; pinned memory allows the non-blocking device copies
# performed by to_device().
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3, pin_memory=True)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are processed element by element (recursively) and the
    result is always returned as a ``list``. Anything else is moved with its
    own ``.to(device, non_blocking=True)`` method.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]


class DeviceDataLoader:
    """Wrap an iterable of batches so each batch is moved to a device on the fly."""

    def __init__(self, data_loader, device):
        # Keep references only; nothing is moved until iteration starts.
        self.data_loader = data_loader
        self.device = device

    def __len__(self):
        """Return the length of the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield every batch of the wrapped loader after passing it through to_device()."""
        for batch in self.data_loader:
            yield to_device(batch, self.device)

In [None]:
# Wrap the loader so batches are moved to `device` as they are iterated.
# The isinstance guard keeps this cell idempotent: re-running it does not nest
# one DeviceDataLoader inside another.
if not isinstance(train_dl, DeviceDataLoader):
    train_dl = DeviceDataLoader(train_dl, device)
device

device(type='cpu')

In [None]:
def _discriminator_block(in_channels, out_channels):
    """Return [Conv2d(k=4, s=2, p=1, no bias), BatchNorm2d, LeakyReLU(0.2)].

    With kernel 4 / stride 2 / padding 1 the convolution halves height and width.
    """
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]


# Shape comments assume a 3 x 64 x 64 input image.
discriminator = nn.Sequential(
    *_discriminator_block(3, 64),     # out: 64 x 32 x 32
    *_discriminator_block(64, 128),   # out: 128 x 16 x 16
    *_discriminator_block(128, 256),  # out: 256 x 8 x 8
    *_discriminator_block(256, 512),  # out: 512 x 4 x 4
    # Collapse the 4 x 4 map to a single value per image, then squash to (0, 1).
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Flatten(),
    nn.Sigmoid(),
)

discriminator = to_device(discriminator, device)


In [None]:
latent_size = 128


def _generator_block(in_channels, out_channels, stride=2, padding=1):
    """Return [ConvTranspose2d(k=4, no bias), BatchNorm2d, ReLU].

    With the default stride 2 / padding 1 the transposed convolution doubles
    height and width.
    """
    return [
        nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=stride, padding=padding, bias=False
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]


# Shape comments assume a latent input of shape latent_size x 1 x 1.
generator = nn.Sequential(
    *_generator_block(latent_size, 512, stride=1, padding=0),  # out: 512 x 4 x 4
    *_generator_block(512, 256),  # out: 256 x 8 x 8
    *_generator_block(256, 128),  # out: 128 x 16 x 16
    *_generator_block(128, 64),   # out: 64 x 32 x 32
    # Final upsampling to 3 channels; Tanh bounds the output to [-1, 1].
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)

generator = to_device(generator, device)

In [None]:
def train_descriminator(real_images, optimizer_descriminator):
    """Run one optimisation step of the discriminator.

    The discriminator is trained to output 1 for ``real_images`` and 0 for a
    freshly generated batch of ``batch_size`` fake images.

    Args:
        real_images: batch of real images, already on ``device``.
        optimizer_descriminator: optimizer over the discriminator's parameters.

    Returns:
        A tuple ``(loss, real_score, fake_score)`` of Python floats: the summed
        real + fake binary cross-entropy loss, the mean discriminator output on
        the real batch, and the mean discriminator output on the fake batch.
    """
    optimizer_descriminator.zero_grad()

    # Real images should be classified as 1.
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = Functional.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Fake images should be classified as 0. They are generated without
    # autograd tracking: only the discriminator is updated in this step, so
    # back-propagating through the generator would just waste time and memory
    # and leave stale gradients on the generator's parameters.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(latent)

    fake_preds = discriminator(fake_images)
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_loss = Functional.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    optimizer_descriminator.step()
    return loss.item(), real_score, fake_score


def train_generator(optimizer_generator):
    """Run one optimisation step of the generator.

    A batch of ``batch_size`` fake images is generated and the generator is
    updated so that the discriminator's output on them moves towards 1.

    Args:
        optimizer_generator: optimizer over the generator's parameters.

    Returns:
        A tuple ``(loss, img)``: the binary cross-entropy loss as a Python
        float, and the first generated image as a NumPy array in
        height x width x channel layout (values as produced by the generator's
        final Tanh, i.e. in [-1, 1]).
    """
    optimizer_generator.zero_grad()

    # Generate a batch of fake images from random latent vectors.
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Try to fool the discriminator: the target label for the fakes is 1.
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = Functional.binary_cross_entropy(preds, targets)

    loss.backward()
    optimizer_generator.step()

    # Preview image. `.cpu()` is required before `.numpy()`: converting a CUDA
    # tensor directly raises a TypeError, so the original call only worked when
    # the notebook ran on the CPU.
    img = fake_images[0].clone().detach().cpu().numpy().transpose(1, 2, 0)

    return loss.item(), img

In [None]:
def fit(epochs, lr, start_index=1):
    """Train the global generator and discriminator on ``train_dl``.

    For every batch the discriminator is updated first, then the generator.
    Every 20th batch a generated sample image is displayed. After each epoch
    the state dicts of both networks are saved to Drive.

    Args:
        epochs: number of passes over ``train_dl``.
        lr: learning rate used for both Adam optimizers.
        start_index: currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(losses_g, losses_d)`` of lists holding, per epoch, the mean
        of the per-batch generator and discriminator losses.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    torch.cuda.empty_cache()

    losses_g = []
    losses_d = []

    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        batch_losses_g = []
        batch_losses_d = []
        for count, batch in enumerate(train_dl, start=1):
            real_images = batch[0]
            loss_d, real_score, fake_score = train_descriminator(real_images, optimizer_discriminator)
            loss_g, img = train_generator(optimizer_generator)
            batch_losses_g.append(loss_g)
            batch_losses_d.append(loss_d)

            if count % 20 == 0:
                # The generator ends in Tanh, so pixel values lie in [-1, 1],
                # while imshow expects float RGB data in [0, 1] and clips
                # everything else. Rescale before displaying.
                plt.imshow((img + 1) / 2)
                plt.show()

        if not batch_losses_d:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("train_dl yielded no batches; cannot compute epoch losses")

        losses_d.append(sum(batch_losses_d) / len(batch_losses_d))
        losses_g.append(sum(batch_losses_g) / len(batch_losses_g))
        # The values printed here are those of the last batch of the epoch.
        print(f"Epochs {epoch+1}/{epochs} loss_g: {loss_g:.2f} loss_d: {loss_d:.2f} real_score: {real_score:.2f} fake_score: {fake_score:.2f}")

        # Checkpoint both networks after every epoch (overwrites the previous files).
        torch.save(generator.state_dict(), "/content/drive/MyDrive/CODES/Deep_Learning/data/Generator.pth")
        torch.save(discriminator.state_dict(), "/content/drive/MyDrive/CODES/Deep_Learning/data/Discriminator.pth")

    return losses_g, losses_d

# Training hyper-parameters.
lr = 0.0002
epochs = 25

losses_g, losses_d = fit(epochs, lr)

# Plot the per-epoch mean losses, one labelled figure per network so the two
# curves can be told apart when the notebook is skimmed.
for title, history in (("Discriminator loss", losses_d), ("Generator loss", losses_g)):
    plt.plot(history)
    plt.title(title)
    plt.xlabel("epoch")
    plt.ylabel("loss (mean over batches)")
    plt.show()