In [22]:
import torch
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torch.utils.tensorboard import SummaryWriter
import torchvision


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

In [18]:
learning_rate = 3e-4
z_dim = 64
image_dim = 28*28*1
batch_size = 32
num_epochs = 50

In [19]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,))
])

In [20]:
dataset = datasets.MNIST(root='MNIST/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)
opt_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f'runs/GAN_MNIST/fake')
writer_real = SummaryWriter(f'runs/GAN_MNIST/real')
step = 0

In [21]:
for data, label in loader:
    data = data.view(-1, 784)
print(data.shape)

torch.Size([32, 784])


In [23]:
from tqdm.auto import tqdm
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(tqdm(loader, desc=f'Epoch {epoch+1}', unit='batch')):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        #Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()

        lossD.backward(retain_graph=True)

        opt_disc.step()

        #Train Generator min log(1 - D(G(z))) == max log(D(G(z))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
        

Epoch 1:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [0/50] Batch 0/1875                       Loss D: 0.5955, loss G: 0.7351


Epoch 2:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [1/50] Batch 0/1875                       Loss D: 0.3288, loss G: 1.4770


Epoch 3:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [2/50] Batch 0/1875                       Loss D: 0.6986, loss G: 0.8741


Epoch 4:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [3/50] Batch 0/1875                       Loss D: 0.6877, loss G: 1.0533


Epoch 5:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [4/50] Batch 0/1875                       Loss D: 0.8482, loss G: 0.6100


Epoch 6:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [5/50] Batch 0/1875                       Loss D: 0.4213, loss G: 1.3988


Epoch 7:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [6/50] Batch 0/1875                       Loss D: 0.4534, loss G: 1.1871


Epoch 8:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [7/50] Batch 0/1875                       Loss D: 0.5296, loss G: 1.2509


Epoch 9:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [8/50] Batch 0/1875                       Loss D: 0.4934, loss G: 1.4695


Epoch 10:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [9/50] Batch 0/1875                       Loss D: 0.7765, loss G: 0.7153


Epoch 11:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [10/50] Batch 0/1875                       Loss D: 0.6377, loss G: 1.1401


Epoch 12:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [11/50] Batch 0/1875                       Loss D: 0.4885, loss G: 1.3906


Epoch 13:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [12/50] Batch 0/1875                       Loss D: 0.6646, loss G: 0.8649


Epoch 14:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [13/50] Batch 0/1875                       Loss D: 0.6158, loss G: 1.0826


Epoch 15:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [14/50] Batch 0/1875                       Loss D: 0.6138, loss G: 1.3286


Epoch 16:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [15/50] Batch 0/1875                       Loss D: 0.7725, loss G: 1.1235


Epoch 17:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [16/50] Batch 0/1875                       Loss D: 0.7060, loss G: 1.1740


Epoch 18:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [17/50] Batch 0/1875                       Loss D: 0.6880, loss G: 1.3535


Epoch 19:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [18/50] Batch 0/1875                       Loss D: 0.7918, loss G: 1.1859


Epoch 20:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [19/50] Batch 0/1875                       Loss D: 0.5930, loss G: 1.1028


Epoch 21:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [20/50] Batch 0/1875                       Loss D: 0.4546, loss G: 1.4387


Epoch 22:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [21/50] Batch 0/1875                       Loss D: 0.6339, loss G: 0.9857


Epoch 23:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [22/50] Batch 0/1875                       Loss D: 0.7466, loss G: 1.2208


Epoch 24:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [23/50] Batch 0/1875                       Loss D: 0.7662, loss G: 1.0894


Epoch 25:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [24/50] Batch 0/1875                       Loss D: 0.6649, loss G: 0.9264


Epoch 26:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [25/50] Batch 0/1875                       Loss D: 0.6219, loss G: 1.0248


Epoch 27:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [26/50] Batch 0/1875                       Loss D: 0.6099, loss G: 1.0726


Epoch 28:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [27/50] Batch 0/1875                       Loss D: 0.7168, loss G: 0.7754


Epoch 29:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [28/50] Batch 0/1875                       Loss D: 0.6412, loss G: 0.8345


Epoch 30:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [29/50] Batch 0/1875                       Loss D: 0.6987, loss G: 1.0683


Epoch 31:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [30/50] Batch 0/1875                       Loss D: 0.6526, loss G: 0.8966


Epoch 32:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [31/50] Batch 0/1875                       Loss D: 0.6718, loss G: 0.9837


Epoch 33:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [32/50] Batch 0/1875                       Loss D: 0.6198, loss G: 0.9365


Epoch 34:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [33/50] Batch 0/1875                       Loss D: 0.6648, loss G: 0.9870


Epoch 35:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [34/50] Batch 0/1875                       Loss D: 0.6974, loss G: 0.7321


Epoch 36:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [35/50] Batch 0/1875                       Loss D: 0.6526, loss G: 1.1610


Epoch 37:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [36/50] Batch 0/1875                       Loss D: 0.5477, loss G: 1.2032


Epoch 38:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [37/50] Batch 0/1875                       Loss D: 0.6661, loss G: 0.8911


Epoch 39:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [38/50] Batch 0/1875                       Loss D: 0.6703, loss G: 0.7897


Epoch 40:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [39/50] Batch 0/1875                       Loss D: 0.5657, loss G: 1.0464


Epoch 41:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [40/50] Batch 0/1875                       Loss D: 0.6510, loss G: 1.1087


Epoch 42:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [41/50] Batch 0/1875                       Loss D: 0.6407, loss G: 0.9554


Epoch 43:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [42/50] Batch 0/1875                       Loss D: 0.6498, loss G: 0.9204


Epoch 44:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [43/50] Batch 0/1875                       Loss D: 0.6772, loss G: 0.9573


Epoch 45:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [44/50] Batch 0/1875                       Loss D: 0.5269, loss G: 1.1116


Epoch 46:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [45/50] Batch 0/1875                       Loss D: 0.6481, loss G: 0.9933


Epoch 47:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [46/50] Batch 0/1875                       Loss D: 0.7199, loss G: 0.8566


Epoch 48:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [47/50] Batch 0/1875                       Loss D: 0.7859, loss G: 0.7953


Epoch 49:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [48/50] Batch 0/1875                       Loss D: 0.6344, loss G: 0.9836


Epoch 50:   0%|          | 0/1875 [00:00<?, ?batch/s]

Epoch [49/50] Batch 0/1875                       Loss D: 0.6646, loss G: 1.0051
