# Project 24 - Image generation with GANs

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms

In [None]:
# MNIST training split, converted to tensors with ToTensor() and stored
# under ./Datasets (downloaded on first run).
transform = transforms.ToTensor()

train = datasets.MNIST(
    root='Datasets',
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of 256 images for GAN training.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=256, shuffle=True)

In [None]:
class generator(nn.Module):
    """Fully connected generator mapping 100-d noise vectors to 28x28 images.

    Hidden widths are 32 -> 64 -> 128, each followed by LeakyReLU(0.2) and
    dropout(0.3). The 784-unit output layer is squashed with tanh, so every
    pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # The attribute names dense0..dense3 / dropout become the state_dict
        # keys, so they must stay stable for saved weights to reload.
        self.dense0 = nn.Linear(in_features=100, out_features=32)
        self.dense1 = nn.Linear(in_features=32, out_features=64)
        self.dense2 = nn.Linear(in_features=64, out_features=128)
        self.dense3 = nn.Linear(in_features=128, out_features=784)

        self.dropout = nn.Dropout(p=0.3)

    def forward(self, X):
        """Return a batch of images of shape (X.shape[0], 28, 28)."""
        hidden = X
        for layer in (self.dense0, self.dense1, self.dense2):
            hidden = self.dropout(F.leaky_relu(layer(hidden), negative_slope=0.2))
        pixels = torch.tanh(self.dense3(hidden))
        return pixels.view(pixels.shape[0], 28, 28)

In [None]:
class discriminator(nn.Module):
    """Fully connected discriminator scoring 28x28 images with one raw logit.

    Each input is flattened to 784 values and passed through hidden layers of
    width 128 -> 64 -> 32 (LeakyReLU(0.2) + dropout(0.3)). The final linear
    layer has no sigmoid, so the output is an unnormalised logit.
    """

    def __init__(self):
        super().__init__()

        # Attribute names double as state_dict keys; keep them stable.
        self.dense0 = nn.Linear(in_features=784, out_features=128)
        self.dense1 = nn.Linear(in_features=128, out_features=64)
        self.dense2 = nn.Linear(in_features=64, out_features=32)
        self.dense3 = nn.Linear(in_features=32, out_features=1)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, X):
        """Return logits of shape (X.shape[0], 1)."""
        flat = X.view(X.shape[0], 28 * 28)
        for layer in (self.dense0, self.dense1, self.dense2):
            flat = self.dropout(F.leaky_relu(layer(flat), negative_slope=0.2))
        return self.dense3(flat)

In [None]:
# Freshly initialised adversaries (the generator is constructed first).
G, D = generator(), discriminator()

In [None]:
# One Adam optimiser per network, both with learning rate 2e-3.
G_opt = optim.Adam(params=G.parameters(), lr=2e-3)
D_opt = optim.Adam(params=D.parameters(), lr=2e-3)

In [None]:
# Binary cross-entropy applied directly to logits, averaged over the batch.
criterion = nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Move both networks' parameters onto the selected device (Module.to works in
# place); the notebook displays the repr returned by the last call.
G.to(device)
D.to(device)

In [None]:
# Adversarial training. Each mini-batch performs one generator update followed
# by one discriminator update. After every epoch the average losses are printed
# and five images from the last generated batch are plotted.

# Make sure dropout is active even if a later cell (e.g. G.eval()) ran first.
G.train()
D.train()

for epoch in range(100):
    D_running_loss = 0
    G_running_loss = 0

    for i, (real_images, _) in enumerate(train_loader):

        batch_size = real_images.size(0)
        # ToTensor() yields pixels in [0, 1]; rescale to [-1, 1] so real images
        # share the range of the generator's tanh output.
        real_images = real_images * 2 - 1
        real_images = real_images.to(device)

        # --- Generator update: push D to classify generated images as real (1).
        G_opt.zero_grad()

        noise = np.random.uniform(low=-1, high=1., size=(batch_size, 100))
        noise = torch.from_numpy(noise).float().to(device)

        fake_images = G(noise)
        fake_outputs = D(fake_images)

        fake_labels = torch.ones(batch_size, device=device)
        G_loss = criterion(fake_outputs.view(*fake_labels.shape), fake_labels)
        G_loss.backward()
        G_opt.step()

        # --- Discriminator update: real images -> 0.9 (one-sided label
        # smoothing), freshly generated images -> 0.
        D_opt.zero_grad()  # also discards grads G_loss.backward() left on D
        real_outputs = D(real_images)
        real_labels = torch.full((batch_size,), 0.9, device=device)
        D_real_loss = criterion(real_outputs.view(*real_labels.shape), real_labels)

        noise = np.random.uniform(low=-1, high=1., size=(batch_size, 100))
        noise = torch.from_numpy(noise).float().to(device)
        # The discriminator loss must not backpropagate into G; generating
        # without autograd avoids a wasted backward pass through the generator.
        with torch.no_grad():
            fake_images = G(noise)
        fake_outputs = D(fake_images)
        fake_labels = torch.zeros(batch_size, device=device)
        D_fake_loss = criterion(fake_outputs.view(*fake_labels.shape), fake_labels)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_opt.step()

        D_running_loss += D_loss.item()
        G_running_loss += G_loss.item()

        print(f'Epoch: {epoch} - Loop{i} Discriminator cost {D_running_loss} Generator cost {G_running_loss}')

    # Convert the accumulated sums into per-batch averages.
    D_running_loss /= len(train_loader)
    G_running_loss /= len(train_loader)

    print(f'Epoch: {epoch} - Discriminator cost {D_running_loss} Generator cost {G_running_loss}')

    # Show five samples from the last generated batch (values in [-1, 1];
    # imshow rescales them for display).
    samples = fake_images.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 5, figsize=(10, 5))
    for i in range(5):
        ax[i].imshow(samples[i].reshape(28, 28), cmap='gray')
        ax[i].xaxis.set_visible(False)
        ax[i].yaxis.set_visible(False)
    plt.show()

In [None]:
# Persist the trained weights into Weights/, the directory the reload cells
# read from. The directory is created if needed; existing files there are
# overwritten.
import os

os.makedirs('Weights', exist_ok=True)
torch.save(G.state_dict(), os.path.join('Weights', 'generator.pth'))
torch.save(D.state_dict(), os.path.join('Weights', 'discriminator.pth'))

In [None]:
# Rebuild the generator and restore its saved weights. The new module lives on
# the CPU, so the tensors are mapped to the CPU as well; this also lets
# weights saved from a CUDA run load on a machine without a GPU.
G = generator()
G.load_state_dict(torch.load('Weights/generator.pth', map_location='cpu'))


In [None]:
# Rebuild the discriminator and restore its saved weights, mapping tensors to
# the CPU (where the fresh module lives) so CUDA-saved weights load anywhere.
D = discriminator()
D.load_state_dict(torch.load('Weights/discriminator.pth', map_location='cpu'))

In [None]:
# Twenty 100-dimensional latent vectors drawn uniformly from [-1, 1).
noise = np.random.uniform(low=-1., high=1., size=(20, 100))

In [None]:
# Display the raw latent samples as the cell output.
noise

In [None]:
# NumPy float64 array -> float32 tensor placed on `device`.
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)

In [None]:
# Place the reloaded generator on `device` and switch it to evaluation mode,
# which disables its dropout layers; the cell displays the returned module.
G.to(device)

G.eval()


In [None]:
# Generate one image per latent vector. Inference only, so skip building the
# autograd graph.
with torch.no_grad():
    forecast = G(noise)

In [None]:
# Detach from autograd, copy to host memory and expose as a NumPy array.
forecast = forecast.detach().cpu().numpy()

In [None]:
# Display the generated pixel values as the cell output.
forecast

In [None]:
# Render every generated sample in its own grayscale figure.
for image in forecast:
    plt.imshow(image.squeeze(), cmap='gray')
    plt.show()