In [20]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG so default layer init, weights_init, the noise from
# get_noise, the gradient-penalty epsilons and DataLoader shuffling are reproducible.
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first `num_images` images of a batch in a 5-column grid.

    Args:
        image_tensor: batch of images assumed to lie in [-1, 1] (the range of the
            generator's Tanh output and of the Normalize((0.5,), (0.5,)) MNIST
            transform in this notebook). Either already shaped (N, C, H, W) or
            flat (N, C*H*W); flat batches are reshaped to (N, *size).
        num_images: how many images from the front of the batch to show.
        size: per-image (channels, height, width), used to un-flatten 2-D input.
    """
    # Detach (and move to CPU) before any arithmetic so the rescaling below is
    # not recorded by autograd when a generator output is passed in.
    images = image_tensor.detach().cpu()
    # Only flat batches need reshaping; image-shaped batches are used as-is.
    if images.dim() == 2:
        images = images.reshape(-1, *size)
    # Map [-1, 1] back to [0, 1] for display.
    images = (images + 1) / 2
    image_grid = make_grid(images[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def make_grad_hook():
    """Create a collector for convolution weight gradients.

    Returns:
        (grads, grad_hook): `grads` is a list that `grad_hook` appends to. Pass
        `grad_hook` to `nn.Module.apply`; for every Conv2d / ConvTranspose2d
        sub-module it records a snapshot of `weight.grad`, or None if that
        module has no gradient yet.
    """
    grads = []

    def grad_hook(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            grad = m.weight.grad
            # Copy rather than alias: later backward passes may accumulate into the
            # same tensor, and zero_grad(set_to_none=False) zeroes it in place, which
            # would silently rewrite entries that were already recorded.
            grads.append(None if grad is None else grad.detach().clone())

    return grads, grad_hook

In [21]:
class Generator(nn.Module):
    """DCGAN-style generator that turns a noise vector into one image.

    With no padding, each ConvTranspose2d grows the spatial size as
    (in - 1) * stride + kernel, giving 1 -> 3 -> 6 -> 13 -> 28 pixels.

    Args:
        z_dim: length of each input noise vector.
        im_chan: channels of the generated image.
        hidden_dim: base width; the hidden stages use 4x, 2x and 1x this many channels.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        # (in_chan, out_chan, kernel, stride, final) for each stage, in build order.
        stage_specs = (
            (z_dim, hidden_dim * 4, 3, 2, False),
            (hidden_dim * 4, hidden_dim * 2, 4, 1, False),
            (hidden_dim * 2, hidden_dim, 3, 2, False),
            (hidden_dim, im_chan, 4, 2, True),
        )
        self.gen = nn.Sequential(*[
            self.make_block(c_in, c_out, kernel=k, stride=s, final=last)
            for c_in, c_out, k, s, last in stage_specs
        ])

    def make_block(self, in_chan, out_chan, kernel=3, stride=2, final=False):
        """Build one upsampling stage.

        Every stage starts with ConvTranspose2d; hidden stages follow it with
        BatchNorm2d and an in-place ReLU, the final stage with Tanh.
        """
        layers = [nn.ConvTranspose2d(in_chan, out_chan, kernel, stride)]
        if final:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(out_chan), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Treat each (z_dim,) noise row as a z_dim-channel 1x1 map and upsample it."""
        noise_map = x.view(len(x), self.z_dim, 1, 1)
        return self.gen(noise_map)

def get_noise(n_samples, z_dim, device="cuda"):
    """Sample an (n_samples, z_dim) tensor of standard-normal noise on `device`."""
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)

In [22]:
class Critic(nn.Module):
    """WGAN-GP critic: maps each image to one unbounded real-valued score.

    Hidden stages normalise each sample on its own (GroupNorm with a single group,
    i.e. layer normalisation) instead of using BatchNorm. The gradient penalty
    constrains the gradient of each sample's score with respect to that sample
    alone. BatchNorm makes every score depend on the whole batch, which breaks
    that per-sample objective; the WGAN-GP paper (Gulrajani et al., 2017) drops
    critic batch norm and suggests layer norm instead.

    Args:
        im_chan: channels of the input images.
        hidden_dim: base width; the hidden stages use 1x and 2x this many channels.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.make_block(im_chan, hidden_dim),
            self.make_block(hidden_dim, hidden_dim * 2),
            self.make_block(hidden_dim * 2, 1, final=True),
        )

    def make_block(self, in_chan, out_chan, kernel=4, stride=2, final=False):
        """Build one downsampling stage.

        Hidden stages: Conv2d -> per-sample layer norm -> LeakyReLU(0.2).
        Final stage: a bare Conv2d that emits the raw score (no activation).
        """
        if final:
            return nn.Conv2d(in_chan, out_chan, kernel, stride)
        return nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel, stride),
            # GroupNorm(1, C) normalises over (C, H, W) of each sample independently,
            # so every score stays a function of its own input only.
            nn.GroupNorm(1, out_chan),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Return one row of scores per input; 28x28 inputs shrink 28 -> 13 -> 5 -> 1, giving (N, 1)."""
        crit_pred = self.crit(x)
        return crit_pred.view(len(crit_pred), -1)


In [23]:
n_epochs = 100
z_dim = 64
display_step = 500   # print losses / show images every this many steps
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10        # weight of the gradient penalty in the critic loss
crit_repeats = 5     # critic updates per generator update
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# matching the generator's Tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

In [24]:
# Build both networks on the target device, each with its own Adam optimizer
# sharing the same learning rate and betas.
# Construction order matters for reproducibility: each module's default parameter
# initialisation draws from the global torch RNG.
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))

crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """DCGAN-style initialisation, intended to be passed to `nn.Module.apply`.

    Conv2d / ConvTranspose2d weights ~ N(0, 0.02). BatchNorm2d scale ~ N(1, 0.02)
    and shift = 0. Every other module (and all conv biases) keeps PyTorch's
    default initialisation.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Centre the BatchNorm scale at 1, not 0: a ~0 scale nearly zeroes every
        # normalised activation at init and randomly flips their signs.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Re-initialise every layer in place with weights_init (generator first, then critic,
# which fixes the RNG draw order). The optimizers created above hold the same
# Parameter objects, so they see the new values; apply() returns the module itself,
# so these assignments just rebind the same objects.
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

In [25]:
def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's scores w.r.t. images interpolated between real and fake.

    Args:
        crit: the critic; any callable mapping an image batch to a score batch.
        real: batch of real images.
        fake: batch of generated images, same shape as `real`.
        epsilon: mixing weights broadcastable against the images (the training loop
            passes shape (N, 1, 1, 1)); each mixed image is
            epsilon * real + (1 - epsilon) * fake.

    Returns:
        d(sum of scores)/d(mixed_images), same shape as the mixed batch. The graph is
        kept (create_graph=True) so a penalty built from it can be backpropagated
        into the critic's parameters.
    """
    mixed_images = real * epsilon + fake * (1 - epsilon)
    # autograd.grad needs the mixed images in the graph. Don't depend on the caller
    # passing an `epsilon` (or images) with requires_grad=True: if nothing upstream
    # requires grad, the mixed batch is a leaf and can be made differentiable here.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)
    mixed_scores = crit(mixed_images)
    gradient = torch.autograd.grad(
        inputs=mixed_images,
        outputs=mixed_scores,
        # All-ones grad_outputs: differentiate the sum of scores, which yields
        # each sample's own score gradient when samples are independent.
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

In [26]:
def get_gradient_penalty(gradient):
    """Two-sided WGAN-GP penalty: the batch mean of (||g_i||_2 - 1)^2.

    `gradient` holds one gradient per sample along dim 0; each sample's gradient
    is flattened before its L2 norm is taken.
    """
    per_sample = gradient.view(len(gradient), -1)
    deviation = per_sample.norm(2, dim=1) - 1
    return torch.mean(deviation ** 2)

In [27]:
def get_gen_loss(crit_fake_pred):
    """Generator loss: the negated mean critic score of generated images.

    Minimising it pushes the critic's scores on fakes upward.
    """
    mean_fake_score = crit_fake_pred.mean()
    return -mean_fake_score

In [28]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """Critic loss: mean(fake scores) - mean(real scores) + c_lambda * gp.

    Minimising it raises real scores, lowers fake scores, and, through the
    weighted penalty `gp`, keeps the critic's gradient norm near 1.
    """
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    return score_gap + c_lambda * gp

In [29]:
cur_step = 0
gen_losses = []
crit_losses = []

for epoch in range(n_epochs):

    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # --- Critic: crit_repeats updates per generator update ---
        mean_iter_crit_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            # The critic step never backpropagates into the generator, so don't
            # record a generator graph only to detach it afterwards.
            with torch.no_grad():
                fake_noise = get_noise(cur_batch_size, z_dim, device=device)
                fake = gen(fake_noise)

            crit_fake_pred = crit(fake)
            crit_real_pred = crit(real)

            # Per-sample interpolation weights for the gradient penalty.
            # requires_grad=True puts the interpolated images into the autograd
            # graph so the critic can be differentiated with respect to them.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake, epsilon)
            gp = get_gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            mean_iter_crit_loss += crit_loss.item() / crit_repeats

            # Each critic iteration builds a fresh graph and backpropagates through
            # it exactly once, so there is nothing to retain.
            crit_loss.backward()
            crit_opt.step()
        crit_losses.append(mean_iter_crit_loss)

        # --- Generator: one update against the freshly updated critic ---
        # (This backward also leaves gradients on the critic's parameters; they are
        # cleared by crit_opt.zero_grad() at the start of the next critic step.)
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)
        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()
        gen_opt.step()
        gen_losses.append(gen_loss.item())

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(gen_losses[-display_step:]) / display_step
            crit_mean = sum(crit_losses[-display_step:]) / display_step

            print(f"Step {cur_step} Gen {gen_mean} Crit {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
        cur_step += 1

Output hidden; open in https://colab.research.google.com to view.