In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

  warn(


In [2]:
### Creating the Discriminator 
class Discriminator(nn.Module):
    """MLP critic that scores flattened images as real (close to 1) or fake (close to 0).

    Args:
        img_dim: length of each flattened input vector (784 for a 28 x 28 x 1 image).

    ``forward(x)`` expects ``x`` whose last dimension is ``img_dim`` and returns one
    probability per input row (last dimension of size 1).
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),  # squash the raw score into a probability: real vs. fake
        ]
        # Attribute name kept as `disc` so state_dict keys stay "disc.<idx>.*"
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        probabilities = self.disc(x)
        return probabilities
    
### Creating the Generator
class Generator(nn.Module):
    """MLP that maps latent noise vectors to flattened images.

    Args:
        img_dim: length of each generated (flattened) image, e.g. 28 * 28 * 1 = 784.
        z_dim: length of each input latent vector.

    The final Tanh keeps every output value in [-1, 1].
    """

    def __init__(self, img_dim, z_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dim),
            nn.Tanh(),
        ]
        # Attribute name kept as `gen` so state_dict keys stay "gen.<idx>.*"
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        images = self.gen(x)
        return images

In [3]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # flattened MNIST image: 1 channel, 28 x 28 pixels
batch_size = 32
num_epochs = 50
seed = 42

# Seed torch's global RNG so weight init, fixed_noise and DataLoader shuffling
# are reproducible on a fresh kernel.
torch.manual_seed(seed)

disc = Discriminator(image_dim).to(device)
gen = Generator(image_dim, z_dim).to(device)

# Fixed latent batch for inspecting generator progress. Drawn from a standard
# normal (torch.randn) to match the noise used in the training loop; torch.rand
# would sample U[0, 1), a distribution the generator is never trained on.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# Named `transform` (not `transforms`) so the imported torchvision.transforms
# module is not shadowed -- shadowing made a re-run of this cell fail with
# AttributeError on `transforms.Compose`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1], matching the generator's Tanh range
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Datasets
dataset = datasets.MNIST(root='./data',
                         transform=transform,
                         download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optimizer
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCELoss expects probabilities, which the Discriminator's final Sigmoid provides
criterion = nn.BCELoss()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:44<00:00, 220637.08it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 114774.87it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:06<00:00, 272652.62it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 931247.43it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw



In [25]:
# Alternate one discriminator step and one generator step per batch.
log_every = 600  # print losses every `log_every` batches

for epoch in range(num_epochs):
    for i, (real, _) in enumerate(loader):
        # Flatten each image to a vector of length image_dim for the MLPs
        real = real.view(-1, image_dim).to(device)
        # Local name so the global `batch_size` hyperparameter is not overwritten
        cur_batch_size = real.shape[0]

        ### Train the Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn((cur_batch_size, z_dim), device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the D step must not backprop into the generator. This keeps the
        # generator's graph intact for the G step, so retain_graph=True is unnecessary.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train the Generator: max log(D(G(z)))
        # Re-score `fake` with the just-updated discriminator; gradients now flow into gen.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if i % log_every == 0:
            print('epoch %d, loss gen %.2f, loss disc %.2f (step : %d )'
                  % (epoch, lossG.item(), lossD.item(), i))

epoch 0.0, loss gen 0.67, loss disc 0.60 (step : 0.0 )
epoch 0.0, loss gen 1.70, loss disc 0.22 (step : 600.0 )
epoch 0.0, loss gen 1.00, loss disc 0.48 (step : 1200.0 )
epoch 0.0, loss gen 1.40, loss disc 0.36 (step : 1800.0 )
epoch 1.0, loss gen 0.82, loss disc 0.63 (step : 0.0 )
epoch 1.0, loss gen 1.25, loss disc 0.34 (step : 600.0 )
epoch 1.0, loss gen 0.99, loss disc 0.66 (step : 1200.0 )
epoch 1.0, loss gen 0.83, loss disc 0.54 (step : 1800.0 )
epoch 2.0, loss gen 0.82, loss disc 0.62 (step : 0.0 )
epoch 2.0, loss gen 1.18, loss disc 0.50 (step : 600.0 )
epoch 2.0, loss gen 0.85, loss disc 0.74 (step : 1200.0 )
epoch 2.0, loss gen 0.59, loss disc 1.00 (step : 1800.0 )
epoch 3.0, loss gen 1.04, loss disc 0.43 (step : 0.0 )
epoch 3.0, loss gen 0.85, loss disc 0.67 (step : 600.0 )
epoch 3.0, loss gen 1.35, loss disc 0.62 (step : 1200.0 )
epoch 3.0, loss gen 0.82, loss disc 0.58 (step : 1800.0 )
epoch 4.0, loss gen 1.02, loss disc 0.56 (step : 0.0 )
epoch 4.0, loss gen 1.28, loss di

In [23]:
# Scratch check: ones_like on a 0-d integer tensor keeps its shape and dtype, displaying tensor(1)
torch.ones_like(input=torch.tensor(5))

tensor(1)