In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim 
import torchvision 
import torchvision.datasets as datasets
from torch.utils.data import DataLoader 
import torchvision.transforms as transforms 
from torch.utils.tensorboard import SummaryWriter 

In [2]:
class Discriminator(nn.Module):
    def __init__(self,img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),               #in gans leakyReLU is good default choice
            nn.Linear(128,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    def __init__(self, z_dim,img_dim):       #z_dim is dim of latent noise that teh gen will be taking
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),                           #i/p normalized betwn 1 and -1 so do same for this too
        )
    def forward(self, x):
        return self.gen(x)


In [6]:
#hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu" 
lr = 1e-4 # andrej kar tweet this is best lr for adamopt
# lr = 3e-4 # andrej kar tweet this is best lr for adamopt
z_dim = 64 #128,256 #gans are sensitive to these hyperparameters
#newer papers have better wasys to stabalizing gans 
img_dim  = 28*28*1
batch_size = 32
num_epochs = 50

disc = Discriminator(img_dim).to(device)  #yeha ewta chuter tanab ??
gen = Generator(z_dim, img_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transform = transforms.Compose(
    [
        transforms.ToTensor(), 
        transforms.Normalize((0.1307,),(0.3081,))]
)
dataset = datasets.MNIST(root='dataset/', transform= transform, download= True)
loader = DataLoader(dataset, batch_size= batch_size, shuffle = True)
opt_disc = optim.Adam(disc.parameters(), lr= lr)
opt_gen = optim.Adam(gen.parameters(), lr = lr)
criterion = nn.BCELoss()  #similar form to previous intro gan equation 
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")        #confused about this separate video for tensorboard
step = 0 #step for tensorboard

for epoch in range(num_epochs):
    for batch_idx, (real,_) in enumerate(loader):  #gans are unsupervised in a way that you dont need label for each image just real or fake
        real = real.view(-1, 784).to(device) #-1 to keep number of examples in a batch
        batch_size = real.shape[0]
        
        ## train discriminator : max log(D(real))+log(1-D(G(z)))
        noise = torch.randn(batch_size,z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)   #faltten everything 
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))  # derived using the bce los fxn pytorch documentation 
        ##here bce minimizes the -ve of the log(D(real)) which is ultimate way of formulating maximization of the +ve exp
        disc_fake = disc(fake).view(-1)   #fake.detach() can be done to detach fake 
        lossD_fake = criterion(disc_fake, torch.zeors_like(disc_fake))
        lossD = (lossD_fake+lossD_real)/2
        disc.zero_grad()
        lossD.backward(retain_graph = True)   #when we do .backward() allthe forward pass we generate will be remove form the cache
        opt_disc.step()
        

        ## train generator min log(1-D(G(z))) <----> max  log(D(G(z)))
        output = disc(fake).view(-1)  # to compute actual gradients we need intermediate grads from these fakes
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}]\ "
                f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
            )
            
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1,1,28,28)
                data = real.reshape(-1,1,28,28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize = True)
                img_grid_real = torchvision.utils.make_grid(data,normalize=True)
                
                writer_fake.add_image(
                    "Mnists Fake images", img_grid_fake, global_step= step
                )
                
                writer_real.add_image(
                    "Mnist real Images", img_grid_real, global_step = step
                )
                step += 1




  f"Epoch [{epoch}/{num_epochs}]\ "
  f"Epoch [{epoch}/{num_epochs}]\ "


AttributeError: 'Compose' object has no attribute 'Compose'