In [None]:
import torch
from torch import nn
import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

# Seed PyTorch's global random number generator so repeated runs draw the same random numbers.
SEED = 45
torch.manual_seed(SEED)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device: ", device)

1. `transforms.ToTensor()` converts each image to a PyTorch tensor with pixel values scaled to the range [0, 1].

2. `transforms.Normalize((0.5,), (0.5,))` subtracts a mean of 0.5 and divides by a standard deviation of 0.5, so all the values end up between -1 and 1.

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# MNIST training split, stored in (and downloaded to, if missing) the current
# directory, with the preprocessing pipeline applied to each image on access.
train_set = torchvision.datasets.MNIST(
    root=".",
    train=True,
    transform=transform,
    download=True,
)

In [None]:
# Number of images per mini-batch; also reused later when sampling noise.
batch_size = 32

# Reshuffle the training set on every pass and serve it in mini-batches.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Pull one batch from the loader and preview its first 16 images in a 4x4 grid.
batch_iterator = iter(train_loader)
data, mnist_labels = next(batch_iterator)
for i in range(16):
    digit = data[i].reshape(28, 28)  # 2-D 28x28 array for imshow
    ax = plt.subplot(4, 4, i + 1)  # subplot indices are 1-based
    plt.imshow(digit, cmap="gray_r")
    # Remove the tick marks on both axes.
    plt.xticks([])
    plt.yticks([])

In [None]:
#Discriminator model

class Discriminator(nn.Module):
  """Fully connected classifier that scores 784-value inputs with a sigmoid output.

  The network is three hidden blocks (Linear -> ReLU -> Dropout(0.3)) of widths
  1024, 512 and 256, followed by a single-unit Linear layer and a Sigmoid, so
  every input row is mapped to one value in [0, 1].
  """

  def __init__(self):
    super().__init__()
    widths = [784, 1024, 512, 256]
    layers = []
    # Hidden blocks are created in input-to-output order.
    for in_features, out_features in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(in_features, out_features))
      layers.append(nn.ReLU())
      layers.append(nn.Dropout(0.3))
    # Output head: one value per sample squashed into [0, 1].
    layers.append(nn.Linear(widths[-1], 1))
    layers.append(nn.Sigmoid())
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Flatten each sample in ``x`` to 784 values and return a ``(batch, 1)`` tensor of scores."""
    flattened = x.view(x.size(0), 784)
    return self.model(flattened)

In [None]:
#Generator Model

class Generator(nn.Module):
  """Fully connected network that maps 100-value inputs to 1x28x28 outputs.

  Three hidden Linear -> ReLU blocks of widths 256, 512 and 1024 are followed
  by a 784-unit Linear layer and a Tanh, so every output value lies in [-1, 1].
  """

  def __init__(self):
    super().__init__()
    widths = [100, 256, 512, 1024]
    layers = []
    # Hidden blocks are created in input-to-output order.
    for in_features, out_features in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(in_features, out_features))
      layers.append(nn.ReLU())
    # Output head: 784 values per sample, bounded to [-1, 1] by Tanh.
    layers.append(nn.Linear(widths[-1], 784))
    layers.append(nn.Tanh())
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Run ``x`` through the network and reshape the result to ``(batch, 1, 28, 28)``."""
    flat_images = self.model(x)
    return flat_images.view(x.size(0), 1, 28, 28)


In [None]:
def train(generator, discriminator, train_loader, n_epochs, batch_size, loss_function, dis_optim, gen_optim, latent_dim=100):
  """Alternately train a GAN's discriminator and generator.

  For every batch the discriminator is first updated on the real samples
  (labelled 1) concatenated with an equal number of generated samples
  (labelled 0). The generator is then updated on a fresh noise batch so that
  the discriminator's output on its samples moves towards the "real" label.

  Args:
    generator: Module mapping ``(batch, latent_dim)`` noise to samples with the
      same shape as the real data. Must have at least one parameter; all
      tensors are moved to the device of its first parameter.
    discriminator: Module mapping a batch of samples to ``(batch, 1)`` scores.
    train_loader: Iterable yielding ``(real_data, labels)`` pairs; the labels
      are ignored.
    n_epochs: Number of passes over ``train_loader``.
    batch_size: Kept for backward compatibility. The size of each batch is read
      from the batch itself so a smaller final batch is handled correctly.
    loss_function: Callable ``(predictions, targets) -> scalar loss``.
    dis_optim: Optimizer over the discriminator's parameters.
    gen_optim: Optimizer over the generator's parameters.
    latent_dim: Width of the noise vectors fed to the generator.

  Returns:
    None. The two models are updated in place, and the losses of the last
    processed batch are printed every 20th epoch.
  """
  # Follow the generator's device instead of relying on a notebook-level global.
  target_device = next(generator.parameters()).device

  # Stay None until a batch has been processed, so an empty loader cannot
  # trigger a NameError in the progress print below.
  dis_loss = None
  gen_loss = None

  for epoch in range(n_epochs):
    for real_data, _ in train_loader:
      # Real samples
      real_data = real_data.to(target_device)
      # Size everything from the actual batch: the last batch of an epoch may
      # be smaller than the nominal batch size.
      current_batch = real_data.size(0)
      real_data_labels = torch.ones((current_batch, 1)).to(target_device) # All real data labeled 1
      gen_data_labels = torch.zeros((current_batch, 1)).to(target_device) # All generated data labeled 0

      # Generated samples for the discriminator step. No graph is recorded
      # because this step must not backpropagate into the generator.
      noise = torch.randn((current_batch, latent_dim)).to(target_device)
      with torch.no_grad():
        gen_data = generator(noise)

      # Concatenating all the data and labels
      combined_data = torch.cat((real_data, gen_data))
      combined_labels = torch.cat((real_data_labels, gen_data_labels))

      # Discriminator training
      discriminator.zero_grad() # clear previous gradients
      dis_out = discriminator(combined_data) # discriminator output
      dis_loss = loss_function(dis_out, combined_labels) # calculate the loss
      dis_loss.backward() # backpropagation
      dis_optim.step() # update the discriminator's parameters

      # Generator training
      new_noise = torch.randn((current_batch, latent_dim)).to(target_device) # noise for a fresh generated batch

      generator.zero_grad() # clear previous gradients
      gen_out = generator(new_noise) # newly generated data
      dis_out = discriminator(gen_out) # discriminator output
      # Score the generated batch against the "real" label so the generator
      # is pushed towards samples the discriminator accepts.
      gen_loss = loss_function(dis_out, real_data_labels)
      gen_loss.backward() # backpropagation
      gen_optim.step() # update the generator's parameters

    if epoch % 20 == 0 and dis_loss is not None:
      print(f"Epoch: {epoch} Loss D.: {dis_loss}, Loss G.: {gen_loss}")

In [None]:
# Build both networks (discriminator first, then generator) and move each one
# to the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device)
generator = Generator()
generator = generator.to(device)

In [None]:
# Optimisation settings shared by both networks.
lr = 1e-4
n_epochs = 201
loss_function = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
dis_optim = torch.optim.Adam(params=discriminator.parameters(), lr=lr)
gen_optim = torch.optim.Adam(params=generator.parameters(), lr=lr)

In [None]:
# Run the full training loop with the models, data and settings defined above.
train(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    n_epochs=n_epochs,
    batch_size=batch_size,
    loss_function=loss_function,
    dis_optim=dis_optim,
    gen_optim=gen_optim,
)

In [None]:
# Sample one batch of noise vectors and run them through the generator.
noise = torch.randn(batch_size, 100).to(device=device)
# Bring the result back to the CPU, detached from autograd, for plotting.
gen_data = generator(noise).cpu().detach()

# Show the first 16 generated images in a 4x4 grid.
for i in range(16):
    sample = gen_data[i].reshape(28, 28)  # 2-D 28x28 array for imshow
    ax = plt.subplot(4, 4, i + 1)  # subplot indices are 1-based
    plt.imshow(sample, cmap="gray_r")
    # Remove the tick marks on both axes.
    plt.xticks([])
    plt.yticks([])