In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
torch.manual_seed(0)

def show_tensor_image(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    """Draw the leading images of a batch as one grid with matplotlib.

    Args:
        image_tensor: batch of images. Values are shifted by +1 and halved
            before plotting, which maps the range [-1, 1] onto [0, 1].
        num_images: number of images taken from the front of the batch.
        size: accepted for signature compatibility; the body does not use it.
        nrow: forwarded to ``make_grid`` as its ``nrow`` argument.
        show: when true, ``plt.show()`` is called after drawing.
    """
    rescaled = (image_tensor + 1) / 2
    on_cpu = rescaled.detach().cpu()
    grid = make_grid(on_cpu[:num_images], nrow=nrow)
    # Move the channel axis last for imshow; squeeze drops any size-1 axis.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    if not show:
        return
    plt.show()

In [2]:
class Generator(nn.Module):
    """Transposed-convolution generator.

    Args:
        input_dim: length of each input vector (one per sample).
        im_chan: number of channels in the produced image.
        hidden_dim: base channel width; inner blocks use multiples of it.

    With the kernel/stride choices below, a 1x1 input grows
    1 -> 3 -> 6 -> 13 -> 28 pixels per side.
    """

    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super().__init__()
        self.input_dim = input_dim
        # Built in the same order as they run, so parameter initialisation
        # consumes random numbers in forward order.
        stages = [
            self.make_block(input_dim, hidden_dim * 4),
            self.make_block(hidden_dim * 4, hidden_dim * 2, kernel=4, stride=1),
            self.make_block(hidden_dim * 2, hidden_dim),
            self.make_block(hidden_dim, im_chan, kernel=4, final=True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_block(self, in_chan, out_chan, kernel=3, stride=2, final=False):
        """Return one upsampling stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh.
        """
        upsample = nn.ConvTranspose2d(in_chan, out_chan, kernel, stride)
        if final:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_chan),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Reshape ``x`` to (batch, input_dim, 1, 1) and run the stack."""
        batch = len(x)
        as_feature_map = x.view(batch, self.input_dim, 1, 1)
        return self.gen(as_feature_map)


def get_noise(n_samples, input_dim, device='cuda'):
    """Return an (n_samples, input_dim) tensor of standard-normal samples on ``device``."""
    shape = (n_samples, input_dim)
    return torch.randn(shape, device=device)

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution discriminator that emits raw (un-squashed) scores.

    Args:
        im_chan: number of channels in the input tensor.
        hidden_dim: base channel width of the hidden convolutions.

    For 28x28 inputs the spatial size shrinks 28 -> 13 -> 5 -> 1, so
    ``forward`` returns one value per sample.
    """

    def __init__(self, im_chan=1, hidden_dim=64):
        super().__init__()
        stages = [
            self.make_block(im_chan, hidden_dim),
            self.make_block(hidden_dim, hidden_dim * 2),
            self.make_block(hidden_dim * 2, 1, final=True),
        ]
        self.disc = nn.Sequential(*stages)

    def make_block(self, in_chan, out_chan, kernel=4, stride=2, final=False):
        """Return one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is the bare Conv2d.
        """
        conv = nn.Conv2d(in_chan, out_chan, kernel, stride)
        if final:
            return conv
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_chan),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Score ``x`` and flatten everything after the batch axis."""
        scores = self.disc(x)
        return scores.view(len(scores), -1)

In [None]:
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    """One-hot encode integer ``labels`` into vectors of length ``n_classes``."""
    encoded = F.one_hot(labels, num_classes=n_classes)
    return encoded

In [None]:
def combine_vectors(x, y):
    """Concatenate ``x`` and ``y`` along dimension 1 and return the result as float32."""
    joined = torch.cat((x, y), dim=1)
    return joined.float()

In [None]:
# ---- Experiment configuration ----
mnist_shape = (1, 28, 28)  # (channels, height, width) of one image
# NOTE(review): MNIST labels are the digits 0-9, so only the first ten of
# these one-hot slots can ever be set. Value kept as-is; confirm 20 is intended.
n_classes = 20
criterion = nn.BCEWithLogitsLoss()
n_epochs = 100
z_dim = 64
display_step = 1000
batch_size = 128
lr = 0.0002
# Fall back to CPU so the notebook can still start on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ToTensor followed by Normalize(0.5, 0.5) rescales pixels to [-1, 1],
# the same range as the generator's final Tanh.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Downloads MNIST into the current directory on first use.
dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    """Return input sizes for the two networks once class labels are appended.

    Returns:
        A ``(generator_input_dim, discriminator_im_chan)`` tuple: the noise
        length plus ``n_classes``, and the image channel count
        (``mnist_shape[0]``) plus ``n_classes``.
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes

In [None]:
# Both networks take the class label alongside their usual input, so their
# input sizes are widened by n_classes.
gen_dim, disc_dim = get_input_dimensions(z_dim, mnist_shape, n_classes)

# Networks are created in the same order as before (generator first) so that
# seeded parameter initialisation is unchanged.
gen = Generator(input_dim=gen_dim).to(device)
disc = Discriminator(im_chan=disc_dim).to(device)

# One Adam optimizer per network, sharing the configured learning rate.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``Module.apply``.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). BatchNorm2d
    weights are drawn from the same distribution and their biases are zeroed.
    Any other module type is left untouched.

    NOTE(review): the BatchNorm scale is centred on 0.0, not 1.0 — confirm
    this is intended.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(m.bias, 0)

# Module.apply visits every sub-module and returns the module itself, so the
# names gen / disc keep pointing at the same (now re-initialised) networks.
for network in (gen, disc):
    network.apply(weights_init)






KeyboardInterrupt: ignored

In [None]:
# Conditional-GAN training loop: one discriminator update followed by one
# generator update per batch. Per-step losses are kept so that a running
# mean can be printed every `display_step` steps.
cur_step = 0
gen_losses = []
disc_losses = []

for epoch in range(n_epochs):
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        # Class conditioning. The vector form is appended to the noise; the
        # image form repeats each one-hot entry over the whole image plane so
        # it can be stacked onto the image as extra channels.
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1], mnist_shape[2])

        # ---- Discriminator update ----
        disc_opt.zero_grad()
        # Create the noise on the configured device rather than relying on
        # get_noise's hard-coded default.
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)

        # detach() keeps the discriminator loss from reaching the generator.
        fake_image_and_labels = combine_vectors(fake.detach(), image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)
        disc_fake_pred = disc(fake_image_and_labels)
        disc_real_pred = disc(real_image_and_labels)

        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph: the fake batch was detached above, and the generator
        # step below runs a fresh discriminator forward pass, so nothing from
        # this graph is needed again.
        disc_loss.backward()
        disc_opt.step()

        disc_losses.append(disc_loss.item())

        # ---- Generator update ----
        gen_opt.zero_grad()
        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        disc_fake_pred = disc(fake_image_and_labels)
        # The generator is rewarded when the discriminator scores fakes as real.
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        gen_losses.append(gen_loss.item())

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(gen_losses[-display_step:]) / display_step
            disc_mean = sum(disc_losses[-display_step:]) / display_step

            print(f"Step {cur_step} Disc {disc_mean} Gen {gen_mean}")
            show_tensor_image(fake)
            show_tensor_image(real)

        cur_step += 1

In [None]:
# Put the generator into evaluation mode before sampling from it. eval()
# returns the module itself, so `gen` is unchanged; the trailing semicolon
# keeps the notebook from echoing the module repr.
gen.eval();

In [None]:
import math

# Number of steps between the two class labels in every interpolation strip.
n_interpolation = 9
# One noise vector repeated for every step, so only the label changes along
# the strip. It is created on the configured device so that it matches the
# labels, which interpolate_class moves to `device`.
interpolation_noise = get_noise(1, z_dim, device=device).repeat(n_interpolation, 1)

def interpolate_class(first_number, second_number):
    """Draw generator outputs while blending from one class label to another.

    The one-hot labels of ``first_number`` and ``second_number`` are mixed
    linearly in ``n_interpolation`` steps (0% to 100% of the second label).
    Each mix is paired with the shared ``interpolation_noise`` and fed to
    ``gen``. The images are drawn on the current matplotlib axes without
    calling ``plt.show()``.

    Args:
        first_number: class index at the start of the blend.
        second_number: class index at the end of the blend.
    """
    first_label = get_one_hot_labels(torch.tensor([first_number], dtype=torch.long), n_classes)
    second_label = get_one_hot_labels(torch.tensor([second_number], dtype=torch.long), n_classes)

    # Column vector of mixing weights; broadcasts against the (1, n_classes) labels.
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label

    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    # Sampling only, so no autograd graph is needed.
    with torch.no_grad():
        fake = gen(noise_and_labels)
    # The plotting helper in this notebook is `show_tensor_image` (singular);
    # the previous `show_tensor_images` call raised NameError.
    show_tensor_image(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

# ---- Single interpolation between two digits ----
start_plot_number = 3
end_plot_number = 7

plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)

_ = plt.axis('off')

# ---- Grid of interpolations: every ordered pair drawn from plot_numbers ----
plot_numbers = [2, 3, 4, 5, 7]
n_numbers = len(plot_numbers)
plt.figure(figsize=(8, 8))

for row, first_plot_number in enumerate(plot_numbers):
    for col, second_plot_number in enumerate(plot_numbers):
        # Subplot indices are 1-based and run row by row.
        cell_index = row * n_numbers + col + 1
        plt.subplot(n_numbers, n_numbers, cell_index)
        interpolate_class(first_plot_number, second_plot_number)
        plt.axis("off")

plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()