# Lab 2 Solution: Fixed Vanilla GAN

In [None]:
import torch, torchvision, torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1]
# so real images share the range of the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# PyTorch warns and skips it, so enable it only when training on CUDA.
loader = DataLoader(
    torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

z_dim = 100          # latent vector length sampled for the generator
g_lr = d_lr = 2e-4   # Adam learning rates for G and D

In [None]:
class D(nn.Module):
    """Convolutional discriminator for single-channel 28x28 images.

    Two stride-2 convolutions downsample 28 -> 14 -> 7, then a 7x7
    convolution collapses the feature map to one raw logit per image.
    No sigmoid is applied; the logits are meant for BCEWithLogitsLoss.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # 1 x 28 x 28 -> 64 x 14 x 14 (no BatchNorm on the input layer)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 64 x 14 x 14 -> 128 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # 128 x 7 x 7 -> 1 x 1 x 1 logit
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one logit per image as a tensor of shape (batch, 1)."""
        logits = self.net(x)
        return logits.view(x.size(0), 1)

In [None]:
class G(nn.Module):
    """Convolutional generator mapping latent vectors to 1x28x28 images.

    The transposed-convolution stack upsamples 1x1 -> 7x7 -> 14x14 -> 28x28,
    the reverse of the discriminator's 28 -> 14 -> 7 -> 1 path, so the output
    is exactly MNIST-sized without any resampling step. The final Tanh keeps
    pixels in [-1, 1].

    Args:
        z_dim: length of the latent vector (default 100).
    """

    def __init__(self, z_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            # z_dim x 1 x 1 -> 128 x 7 x 7
            nn.ConvTranspose2d(z_dim, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 7 x 7 -> 64 x 14 x 14
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 64 x 14 x 14 -> 1 x 28 x 28
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent codes.

        Args:
            z: latent batch, either (batch, z_dim) or (batch, z_dim, 1, 1).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        # Accept flat latent vectors by adding the 1x1 spatial dims.
        if z.dim() == 2:
            z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)

In [None]:
# Build both networks on the selected device (D first, then G).
Dnet, Gnet = D().to(device), G().to(device)

# D outputs raw logits, so the sigmoid is folded into the loss.
crit = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, each with its own learning rate and
# beta1 = 0.5 (a common choice for GAN training).
opt_d, opt_g = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net, lr in ((Dnet, d_lr), (Gnet, g_lr))
)

In [None]:
from tqdm import tqdm

# Single pass over the loader, alternating one D update and one G update per batch.
for real, _ in tqdm(loader):
    real = real.to(device)
    b = real.size(0)  # the final batch may be smaller than batch_size

    # ---- Discriminator step: real images -> label 1, generated -> label 0 ----
    opt_d.zero_grad(set_to_none=True)
    z = torch.randn(b, z_dim, device=device)
    # detach() stops this loss from back-propagating into Gnet.
    fake = Gnet(z[:, :, None, None]).detach()
    loss_d = crit(Dnet(real), torch.ones(b, 1, device=device))
    loss_d = loss_d + crit(Dnet(fake), torch.zeros(b, 1, device=device))
    loss_d.backward()
    opt_d.step()

    # ---- Generator step: fresh samples scored against label 1 ----
    # (non-saturating loss). Gradients this leaves on Dnet are cleared by
    # opt_d.zero_grad at the top of the next iteration.
    opt_g.zero_grad(set_to_none=True)
    z = torch.randn(b, z_dim, device=device)
    fake = Gnet(z[:, :, None, None])
    loss_g = crit(Dnet(fake), torch.ones(b, 1, device=device))
    loss_g.backward()
    opt_g.step()

print('Fixed training loop runs.')