<a href="https://colab.research.google.com/github/s-mostafa-a/pytorch_learning/blob/master/simple_generative_adversarial_net/MNIST_GANs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
class DeviceDataLoader:
    """Iterable wrapper that moves every batch of a data loader onto a target device.

    Batches that are lists or tuples are moved element by element (recursively) and
    come back as lists; anything else must provide a ``.to(device, non_blocking=...)``
    method (e.g. a ``torch.Tensor``).
    """

    def __init__(self, dl, device):
        # Wrapped iterable of batches, and the device every yielded batch is sent to.
        self.dl, self.device = dl, device

    def __len__(self):
        """Return the number of batches of the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to ``self.device``."""
        for batch in self.dl:
            yield self.to_device(batch, self.device)

    def to_device(self, data, device):
        """Move ``data`` to ``device``; lists/tuples are converted recursively into lists."""
        if not isinstance(data, (list, tuple)):
            # Leaf value: delegate to its own .to(); non_blocking lets pinned-memory copies overlap.
            return data.to(device, non_blocking=True)
        return [self.to_device(item, device) for item in data]

class MNIST_GANS:
    """Train a fully-connected GAN (generator ``G`` vs. discriminator ``D``) on flattened images.

    ``D`` maps a flattened image of length ``image_size`` to a probability (Sigmoid output);
    ``G`` maps a ``latent_size`` noise vector to a flattened image in [-1, 1] (Tanh output).
    After every epoch a grid of images generated from a fixed set of latent vectors is saved
    to ``sample_dir`` so training progress can be inspected visually.
    """

    def __init__(self, dataset, image_size, device, num_epochs=50, loss_function=None, batch_size=100,
                 hidden_size=2561, latent_size=64, sample_dir='./../data/samples', log_every=600):
        """Build the data pipeline, both networks and their Adam optimizers.

        Args:
            dataset: dataset yielding ``(image, label)`` pairs; wrapped in a shuffled
                ``DataLoader`` whose batches are moved to ``device``.
            image_size: number of values in one flattened image (``D`` input / ``G`` output width).
            device: device the models and batches are moved to (skipped when falsy).
            num_epochs: number of passes over ``dataset`` performed by :meth:`run`.
            loss_function: criterion applied to ``D``'s probabilities; defaults to a fresh
                ``nn.BCELoss()`` for each instance.
            batch_size: loader batch size, and number of fixed latent vectors used for samples.
            hidden_size: width of the two hidden layers in both networks.
            latent_size: dimensionality of the noise vectors fed to ``G``.
            sample_dir: directory that receives the per-epoch sample grids (created if missing).
            log_every: record and print statistics every ``log_every`` steps of an epoch.
        """
        # NOTE(review): hidden_size=2561 looks like a typo for 256; the default is kept so
        # existing callers/results are unchanged — confirm the intended value with the author.
        if loss_function is None:
            # Not a default-argument instance: that would be created once and shared by all objects.
            loss_function = nn.BCELoss()
        self.device = device
        bare_data_loader = DataLoader(dataset, batch_size, shuffle=True)
        self.data_loader = DeviceDataLoader(bare_data_loader, device)
        self.loss_function = loss_function
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.batch_size = batch_size
        self.log_every = log_every
        self.D = nn.Sequential(
            nn.Linear(image_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid())
        self.G = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, image_size),
            nn.Tanh())
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=0.0002)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=0.0002)
        self.sample_dir = sample_dir
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(self.sample_dir, exist_ok=True)
        if self.device:
            self.G.to(device)
            self.D.to(device)
        # Fixed latent vectors so the saved sample grids are comparable across epochs.
        self.sample_vectors = torch.randn(self.batch_size, self.latent_size).to(self.device)
        self.num_epochs = num_epochs

    @staticmethod
    def denormalize(x):
        """Map values from [-1, 1] (G's Tanh range) back to [0, 1], clamping outliers."""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def reset_grad(self):
        """Zero the gradients tracked by both optimizers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def train_discriminator(self, images):
        """Run one optimizer step on ``D`` using real ``images`` and an equal number of fakes.

        Args:
            images: flattened real images, shape ``(n, image_size)``; ``n`` may be smaller
                than ``batch_size`` (e.g. the last batch of an epoch).

        Returns:
            ``(d_loss, real_score, fake_score)``: the summed loss and ``D``'s outputs on the
            real and generated images.
        """
        n = images.size(0)
        # Labels sized from the actual batch so a short final batch does not mismatch.
        real_labels = torch.ones(n, 1).to(self.device)
        fake_labels = torch.zeros(n, 1).to(self.device)

        outputs = self.D(images)
        d_loss_real = self.loss_function(outputs, real_labels)
        real_score = outputs

        new_sample_vectors = torch.randn(n, self.latent_size).to(self.device)
        # detach(): only D is updated here, so skip backpropagating through G.
        fake_images = self.G(new_sample_vectors).detach()
        outputs = self.D(fake_images)
        d_loss_fake = self.loss_function(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        self.reset_grad()
        d_loss.backward()
        self.d_optimizer.step()

        return d_loss, real_score, fake_score

    def train_generator(self, batch_size=None):
        """Run one optimizer step on ``G``, rewarding fakes that ``D`` classifies as real.

        Args:
            batch_size: number of fakes to generate; defaults to ``self.batch_size``.

        Returns:
            ``(g_loss, fake_images)``.
        """
        n = self.batch_size if batch_size is None else batch_size
        new_sample_vectors = torch.randn(n, self.latent_size).to(self.device)
        fake_images = self.G(new_sample_vectors)
        # Target "real" (1) for fakes: G is trained to fool D.
        labels = torch.ones(n, 1).to(self.device)
        g_loss = self.loss_function(self.D(fake_images), labels)

        self.reset_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return g_loss, fake_images

    def save_fake_images(self, index):
        """Generate images from the fixed latent vectors and save them as a PNG grid.

        Args:
            index: number embedded (zero-padded to 4 digits) in the output file name.
        """
        # Sampling only: no autograd graph is needed.
        with torch.no_grad():
            fake_images = self.G(self.sample_vectors)
        # assumes image_size == 784 (single-channel 28x28 images) — TODO generalize if needed
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        fake_fname = 'fake_images-{0:0=4d}.png'.format(index)
        print('Saving', fake_fname)
        save_image(self.denormalize(fake_images), os.path.join(self.sample_dir, fake_fname),
                   nrow=10)

    def run(self):
        """Train for ``num_epochs`` epochs, saving a sample grid after each epoch.

        Returns:
            ``(d_losses, g_losses, real_scores, fake_scores)``: lists of the statistics
            recorded every ``log_every`` steps.
        """
        total_step = len(self.data_loader)
        d_losses, g_losses, real_scores, fake_scores = [], [], [], []

        for epoch in range(self.num_epochs):
            for i, (images, _) in enumerate(self.data_loader):
                # Flatten using the real batch length: the last batch can be shorter
                # than batch_size when len(dataset) is not a multiple of it.
                n = images.size(0)
                images = images.reshape(n, -1)

                d_loss, real_score, fake_score = self.train_discriminator(images)
                g_loss, _ = self.train_generator(n)

                if (i + 1) % self.log_every == 0:
                    d_val, g_val = d_loss.item(), g_loss.item()
                    real_val, fake_val = real_score.mean().item(), fake_score.mean().item()
                    d_losses.append(d_val)
                    g_losses.append(g_val)
                    real_scores.append(real_val)
                    fake_scores.append(fake_val)
                    print(f'Epoch [{epoch}/{self.num_epochs}], Step [{i + 1}/{total_step}], '
                          f'd_loss: {d_val:.4f}, g_loss: {g_val:.4f}, '
                          f'D(x): {real_val:.2f}, D(G(z)): {fake_val:.2f}')
            self.save_fake_images(epoch + 1)
        return d_losses, g_losses, real_scores, fake_scores


if __name__ == '__main__':
    # Only start the (long) download + training run when executed directly (script or
    # notebook cell), not when this module is imported for its class definitions.
    image_size = 28 * 28  # flattened single-channel 28x28 images, as reshaped in save_fake_images
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1], matching G's Tanh range.
    mnist = MNIST(root='./../data', train=True, download=True,
                  transform=Compose([ToTensor(), Normalize(mean=(0.5,), std=(0.5,))]))
    gans = MNIST_GANS(dataset=mnist, image_size=image_size, device=device)
    gans.run()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


0.0%0.1%0.2%0.2%0.3%0.4%0.5%0.6%0.7%0.7%0.8%0.9%1.0%1.1%1.2%1.2%1.3%1.4%1.5%1.6%1.7%1.7%1.8%1.9%2.0%2.1%2.1%2.2%2.3%2.4%2.5%2.6%2.6%2.7%2.8%2.9%3.0%3.1%3.1%3.2%3.3%3.4%3.5%3.6%3.6%3.7%3.8%3.9%4.0%4.0%4.1%4.2%4.3%4.4%4.5%4.5%4.6%4.7%4.8%4.9%5.0%5.0%5.1%5.2%5.3%5.4%5.5%5.5%5.6%5.7%5.8%5.9%6.0%6.0%6.1%6.2%6.3%6.4%6.4%6.5%6.6%6.7%6.8%6.9%6.9%7.0%7.1%7.2%7.3%7.4%7.4%7.5%7.6%7.7%7.8%7.9%7.9%8.0%8.1%8.2%8.3%8.3%8.4%8.5%8.6%8.7%8.8%8.8%8.9%9.0%9.1%9.2%9.3%9.3%9.4%9.5%9.6%9.7%9.8%9.8%9.9%10.0%10.1%10.2%10.2%10.3%10.4%10.5%10.6%10.7%10.7%10.8%10.9%11.0%11.1%11.2%11.2%11.3%11.4%11.5%11.6%11.7%11.7%11.8%11.9%12.0%12.1%12.1%12.2%12.3%12.4%12.5%12.6%12.6%12.7%12.8%12.9%13.0%13.1%13.1%13.2%13.3%13.4%13.5%13.6%13.6%13.7%13.8%13.9%14.0%14.0%14.1%14.2%14.3%14.4%14.5%14.5%14.6%14.7%14.8%14.9%15.0%15.0%15.1%15.2%15.3%15.4