In [1]:
import torch
import torch.nn as nn

import torchvision
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter



In [None]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super().__init__()

        # Try adding batch normalization later
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.LeakyReLU(0.2)
        self.conv1 = self._block(1, 256, 3, 2, 1, self.relu)
        self.conv2 = self._block(256, 256, 3, 2, 1, self.relu)

        # self.conv1 = nn.Conv2d(1, 256, kernel_size=3, stride =2, padding=1)
        # self.bn1 = nn.BatchNorm2d(256)
        # self.conv2 = nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1)
        # self.bn2 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(256 * 7 * 7, 1)
        
    def _block(self, in_channels, out_channels, kernel_size, stride, padding, activation):

        return nn.Sequential(*[
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            activation
        ])
    
    def forward(self, X):
        out = self.conv1(X)
        out = self.conv2(out)
        out = out.view(out.size(0), -1)
        out = self.sigmoid(self.fc1(out))

        return out
    
class Generator(nn.Module):
    def __init__(self, z_dim):
        super().__init__()

        self.fc1 = nn.Linear(z_dim, 256 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(256*7*7)
        self.relu = nn.LeakyReLU(0.2)
        self.tanh = nn.Tanh()

        self.conv1 = self._block(256, 128, 4, 2, 1, self.relu)
        self.conv2 = self._block(128, 64, 4, 2, 1, self.relu)
        self.conv3 = self._block(64, 1, 3, 1, 1, self.tanh)


        # self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride = 2, padding=1)
        # self.bn2 = nn.BatchNorm2d(128)
        # self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding =1)
        # self.bn3 = nn.BatchNorm2d(64)
        # self.conv3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding =1)

        

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, activation):
        return nn.Sequential(*[
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            activation,
        ])

    def forward(self, X):
        out = self.relu(self.bn1(self.fc1(X)))
        out = out.view(out.size(0), 256, 7, 7)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)

        return out


In [7]:
# Hyper params
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1
batch_size = 32
epochs = 50

disc = Discriminator()
gene = Generator(z_dim)
fixed_noise = torch.randn((batch_size, z_dim))



In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

In [9]:
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

optim_disc = torch.optim.Adam(params=disc.parameters(), lr=lr)
optim_gene = torch.optim.Adam(params=gene.parameters(), lr=lr)
criterion = nn.BCELoss()

writer_fake = SummaryWriter('runs/fake')
writer_real = SummaryWriter('runs/real')
step = 1


for epoch in range(epochs):
    for idx, (real, _) in enumerate(loader):
        batch_size = real.shape[0]
        noise = torch.randn(batch_size, z_dim)
        fake = gene(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake)/2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        optim_disc.step()

        # Generator
        out = disc(fake).view(-1)
        lossG = criterion(out, torch.ones_like(out))
        gene.zero_grad()
        lossG.backward()
        optim_gene.step()

        if idx == 0:
            print(f"Epoch [{epoch}/{epochs}]; lossD: {lossD:.4f}; lossG: {lossG:.4f}")

            with torch.no_grad():
                fake = gene(fixed_noise).view(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

                writer_fake.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step+=1

    

Epoch [0/50]; lossD: 0.0004; lossG: 13.6932
Epoch [1/50]; lossD: 0.0000; lossG: 12.0279
Epoch [2/50]; lossD: 0.0000; lossG: 13.7280
Epoch [3/50]; lossD: 0.0000; lossG: 15.2406
Epoch [4/50]; lossD: 0.0000; lossG: 16.1687
Epoch [5/50]; lossD: 0.0214; lossG: 5.2874
Epoch [6/50]; lossD: 0.0463; lossG: 2.6583
Epoch [7/50]; lossD: 0.1735; lossG: 2.7121
Epoch [8/50]; lossD: 0.2397; lossG: 2.4559
Epoch [9/50]; lossD: 0.1702; lossG: 3.2550
Epoch [10/50]; lossD: 0.1584; lossG: 2.2367
Epoch [11/50]; lossD: 0.2307; lossG: 3.2700
Epoch [12/50]; lossD: 0.2097; lossG: 1.9792
Epoch [13/50]; lossD: 0.2402; lossG: 2.7962
Epoch [14/50]; lossD: 0.2690; lossG: 1.9417
Epoch [15/50]; lossD: 0.1579; lossG: 3.1928
Epoch [16/50]; lossD: 0.2352; lossG: 2.3110
Epoch [17/50]; lossD: 0.3720; lossG: 2.5588


KeyboardInterrupt: 