In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import  MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

<torch._C.Generator at 0x7ff5dd87d5a0>

In [2]:
def show_tensor_images(image_tensor, num_images=25, size=(1,28,28)):
    """Plot the first ``num_images`` images of a batch as a grid with 5 images per row.

    Args:
        image_tensor: batch of images with values in [-1, 1]. Either already
            image-shaped (N, C, H, W) or flattened to (N, C*H*W).
        num_images: maximum number of images taken from the front of the batch.
        size: per-image shape (C, H, W); used to un-flatten a 2-D batch.
    """
    # Shift [-1, 1] (Tanh output / Normalize(0.5, 0.5) data) into [0, 1] for imshow.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    if image_unflat.dim() == 2:
        # Flattened batch: restore the per-image shape so make_grid sees N images
        # instead of treating the whole matrix as one picture.
        image_unflat = image_unflat.reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid returns (C, H, W); matplotlib expects channels last.
    plt.imshow(image_grid.permute(1,2,0).squeeze())
    plt.show()

In [3]:
class Generator(nn.Module):
    """Transposed-convolution generator that turns a noise vector into an image.

    Args:
        z_dim: length of the input noise vector.
        im_chan: number of channels in the produced image.
        hidden_dim: base channel width; inner layers use multiples of it.

    With the kernel/stride choices below (no padding) a 1x1 input grows
    1 -> 3 -> 6 -> 13 -> 28 pixels per side.
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=128):
        super().__init__()
        self.z_dim = z_dim
        stages = [
            self.make_block(z_dim, hidden_dim * 4),
            self.make_block(hidden_dim * 4, hidden_dim * 2, kernel=4, stride=1),
            self.make_block(hidden_dim * 2, hidden_dim),
            self.make_block(hidden_dim, im_chan, kernel=4, final=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_block(self, in_chan, out_chan, kernel=3, stride=2, final=False):
        """Return one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh so outputs lie in [-1, 1].
        """
        upsample = nn.ConvTranspose2d(in_chan, out_chan, kernel, stride)
        if final:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_chan),
            nn.ReLU(inplace=True),
        )

    def unsqueeze_noise(self, noise):
        """Reshape (batch, z_dim) noise to (batch, z_dim, 1, 1) for the conv stack."""
        batch = len(noise)
        return noise.view(batch, self.z_dim, 1, 1)

    def forward(self, x):
        """Generate images from a batch of noise vectors."""
        return self.gen(self.unsqueeze_noise(x))

def get_noise(n_samples, z_dim, device="cuda"):
    """Sample ``n_samples`` standard-normal noise vectors of length ``z_dim``.

    Returns a tensor of shape (n_samples, z_dim) created on ``device``.
    """
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)


In [4]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator that scores images as real or fake.

    Args:
        im_chan: number of channels in the input images.
        hidden_dim: channel width of the first hidden layer (doubled in the second).

    The output is a raw logit (no sigmoid). For 28x28 inputs the spatial size
    shrinks 28 -> 13 -> 5 -> 1, giving one score per image.
    """

    def __init__(self, im_chan=1, hidden_dim=32):
        super().__init__()
        stages = [
            self.make_block(im_chan, hidden_dim),
            self.make_block(hidden_dim, hidden_dim * 2),
            self.make_block(hidden_dim * 2, 1, final=True),
        ]
        self.disc = nn.Sequential(*stages)

    def make_block(self, in_chan, out_chan, kernel=4, stride=2, final=False):
        """Return one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is the bare Conv2d.
        """
        conv = nn.Conv2d(in_chan, out_chan, kernel, stride)
        if final:
            return conv
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_chan),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Score a batch of images; the result is flattened to (batch, -1)."""
        scores = self.disc(x)
        n_rows = len(scores)
        return scores.view(n_rows, -1)

In [5]:
# Binary cross-entropy on raw logits: the discriminator's last layer is a plain
# convolution with no sigmoid, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
z_dim = 64            # length of the generator's input noise vector
display_step = 1000   # number of training steps between loss printouts / image grids
batch_size = 128
lr = 0.0002           # Adam learning rate (shared by both networks)
beta_1 = 0.5          # Adam betas
beta_2 = 0.999
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of crashing on the first .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Normalize((0.5,), (0.5,)) rescales ToTensor's output to [-1, 1],
# the same range the generator's final Tanh produces.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
])

# MNIST is downloaded into the current directory on first run.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
# Each network gets its own Adam optimizer; both share lr and betas.
# Construction order (generator first) is kept so seeded initialisation is unchanged.
gen = Generator(z_dim=z_dim).to(device=device)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(beta_1, beta_2))

disc = Discriminator().to(device=device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """Initialise one layer in place; meant to be passed to ``nn.Module.apply``.

    * Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02); bias left as is.
    * BatchNorm2d: scale drawn from N(1, 0.02), bias set to 0.

    The BatchNorm scale is centred on 1 rather than 0: a scale near zero would
    start every normalisation layer multiplying its output by ~0.02 with a
    random sign, shrinking activations and gradients at the start of training.
    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Module.apply calls weights_init on every submodule and returns the module
# itself, so the names are simply rebound; generator is initialised first.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)




In [7]:
n_epochs = 50
cur_step = 0
# Running averages over the last `display_step` steps (reset after each printout).
mean_gen_loss = 0
mean_disc_loss = 0

for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # ---- Discriminator step: fakes should score 0, real images 1 ----
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        # The generator is not updated in this step, so skip building its
        # autograd graph altogether.
        with torch.no_grad():
            fake = gen(fake_noise)

        disc_fake_pred = disc(fake)
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss+disc_real_loss)/2

        mean_disc_loss += disc_loss.item() / display_step
        # No retain_graph: the generator step below runs its own forward pass,
        # so this graph is never backpropagated through again.
        disc_loss.backward()
        disc_opt.step()

        # ---- Generator step: fresh fakes should fool the discriminator (score 1) ----
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        mean_gen_loss += gen_loss.item() / display_step

        # Count the finished step before checking, so every report (including
        # the first) averages exactly `display_step` losses.
        cur_step += 1
        if cur_step % display_step == 0:
            print(f"Step {cur_step} Disc {mean_disc_loss} Gen {mean_gen_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_gen_loss = 0
            mean_disc_loss = 0


Output hidden; open in https://colab.research.google.com to view.