# Lab 1 Solution: Vanilla GAN

In [None]:
import torch, torchvision, torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
# Runtime configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
z_dim = 64           # number of latent channels fed to the generator
g_lr = d_lr = 2e-4   # Adam learning rates for G and D

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_ds = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it is
# unused and DataLoader warns about it, so enable it only when targeting CUDA.
train_loader = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

In [None]:
class Generator(nn.Module):
    """Convolutional generator mapping latent codes to 1x28x28 images in [-1, 1].

    Spatial trace for a 1x1 latent input: 1 -> 3 -> 7 -> 14 -> 28, so the network
    emits MNIST resolution directly. ``output_padding=1`` on the second and last
    transposed convolutions supplies the extra row/column needed for 7 and 28.

    Args:
        z_dim: number of latent channels expected on input.
    """

    def __init__(self, z_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            # z_dim x 1x1 -> 256 x 3x3
            nn.ConvTranspose2d(z_dim, 256, 3, 2, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 256 x 3x3 -> 128 x 7x7
            nn.ConvTranspose2d(256, 128, 4, 2, 1, output_padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 128 x 7x7 -> 64 x 14x14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 64 x 14x14 -> 1 x 28x28, squashed to [-1, 1]
            nn.ConvTranspose2d(64, 1, 5, 2, 2, output_padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: latent tensor of shape (B, z_dim) or (B, z_dim, H, W).

        Returns:
            Tensor of shape (B, 1, 28, 28) with values in [-1, 1].
        """
        if z.dim() == 2:
            # Treat each latent vector as a z_dim-channel 1x1 feature map.
            z = z[:, :, None, None]
        x = self.net(z)
        # Fallback for latents whose spatial extent is not 1x1: resample to MNIST
        # resolution. Bilinear weights are convex, so the [-1, 1] range is kept.
        if x.shape[-2:] != (28, 28):
            x = torch.nn.functional.interpolate(x, size=(28, 28), mode='bilinear', align_corners=False)
        return x

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores each image with a single raw logit.

    Sized for 1x28x28 inputs: the spatial trace is 28 -> 14 -> 7 -> 1, and the
    final reshape requires exactly one value per sample.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # 1 x 28x28 -> 64 x 14x14; no normalisation on the input layer.
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 64 x 14x14 -> 128 x 7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 128 x 7x7 -> 1 x 1x1: a 7x7 kernel collapses the whole map into one score.
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (B, 1) for a batch ``x`` of images."""
        batch_size = x.size(0)
        scores = self.net(x)
        return scores.view(batch_size, 1)

In [None]:

# D's last layer is a bare convolution (no sigmoid), so the loss works on raw logits.
crit = nn.BCEWithLogitsLoss()

# Generator first, then discriminator: parameter initialisation draws from the RNG in this order.
G = Generator(z_dim).to(device)
D = Discriminator().to(device)

# One Adam optimizer per network, both with betas (0.5, 0.999); rates come from the config cell.
opt_g = torch.optim.Adam(params=G.parameters(), lr=g_lr, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(params=D.parameters(), lr=d_lr, betas=(0.5, 0.999))

In [None]:
def train_discriminator_step(real):
    """Run one optimisation step for D on a real batch and a freshly generated fake batch.

    Args:
        real: batch of real images; moved to ``device`` if it is not already there.
            Presumably (B, 1, 28, 28) normalised to [-1, 1] -- confirm at the call site.

    Returns:
        The summed BCE-with-logits loss (reals -> 1, fakes -> 0) as a Python float.
    """
    # No-op when the batch is already on `device`; non_blocking lets a pinned
    # host batch copy asynchronously.
    real = real.to(device, non_blocking=True)
    b = real.size(0)

    # The fakes are only inputs to D in this step, so skip building G's autograd
    # graph (the original built it and then detached the output).
    with torch.no_grad():
        fake = G(torch.randn(b, z_dim, 1, 1, device=device))

    ones = torch.ones(b, 1, device=device)
    zeros = torch.zeros(b, 1, device=device)

    opt_d.zero_grad(set_to_none=True)
    # Reals and fakes go through D as separate batches, reals first.
    loss = crit(D(real), ones) + crit(D(fake), zeros)
    loss.backward()
    opt_d.step()
    return loss.item()

def train_generator_step(b):
    """Run one optimisation step for G and return its loss as a Python float.

    Samples ``b`` latent codes and scores the generated images with D. The loss
    is BCE-with-logits against a target of 1 (the non-saturating generator
    loss), so G is pushed toward making D call its fakes real. ``backward`` also
    writes gradients into D's parameters, but only ``opt_g`` is stepped here.
    """
    noise = torch.randn(b, z_dim, device=device)[:, :, None, None]

    opt_g.zero_grad(set_to_none=True)
    d_logits = D(G(noise))
    g_loss = crit(d_logits, torch.ones(b, 1, device=device))
    g_loss.backward()
    opt_g.step()
    return g_loss.item()