In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
print("package loaded")

package loaded


In [2]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0,1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
    def forward(self, x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0,1),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )
    def forward(self, x):
        return self.gen(x)

# Hyperparametrs etc
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,)),]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size = batch_size, shuffle = True)
opt_disc = optim.Adam(disc.parameters(), lr = lr)
opt_gen = optim.Adam(gen.parameters(), lr = lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

In [3]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        #Train Discriminator : max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_fake + lossD_real) / 2
        disc.zero_grad()
        lossD.backward(retain_graph = True)
        opt_disc.step()


        ## Train Generator min log(1 - D(G(z))) <--> max log(D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx ==0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \ "
                f"Loss D: {lossD:.4f}, Loss G:{lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28 )
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step = step
                )

                writer_real.add_image(
                    "Mnist real Images", img_grid_real, global_step = step
                )

                step+=1 




Epoch [0/50] \ Loss D: 0.7662, Loss G:0.7181
Epoch [1/50] \ Loss D: 0.7218, Loss G:0.8296
Epoch [2/50] \ Loss D: 0.7030, Loss G:0.7666
Epoch [3/50] \ Loss D: 0.7448, Loss G:0.8513
Epoch [4/50] \ Loss D: 0.7750, Loss G:0.8272
Epoch [5/50] \ Loss D: 0.2623, Loss G:1.7289
Epoch [6/50] \ Loss D: 0.5109, Loss G:1.2280
Epoch [7/50] \ Loss D: 0.8954, Loss G:0.6832
Epoch [8/50] \ Loss D: 0.6194, Loss G:1.0272
Epoch [9/50] \ Loss D: 0.8469, Loss G:0.6488
Epoch [10/50] \ Loss D: 0.6203, Loss G:0.9111
Epoch [11/50] \ Loss D: 0.7395, Loss G:0.7766
Epoch [12/50] \ Loss D: 0.4679, Loss G:1.1203
Epoch [13/50] \ Loss D: 1.0092, Loss G:0.6748
Epoch [14/50] \ Loss D: 0.5820, Loss G:1.0073
Epoch [15/50] \ Loss D: 0.7740, Loss G:0.8569
Epoch [16/50] \ Loss D: 0.8137, Loss G:0.7001
Epoch [17/50] \ Loss D: 0.5507, Loss G:1.0649
Epoch [18/50] \ Loss D: 0.6710, Loss G:0.7440
Epoch [19/50] \ Loss D: 1.0640, Loss G:0.6139
Epoch [20/50] \ Loss D: 0.7195, Loss G:0.9140
Epoch [21/50] \ Loss D: 0.7967, Loss G:0.674