In [1]:
from torch import nn
import torch
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
from tqdm.notebook import tqdm

from models.MNIST_MLP_cgan import MnistCGAN, MLPGenerator, MLPDiscriminator
from models.helper_functions import display_images

In [2]:
# Probe each accelerator once and reuse the results when choosing the device.
has_gpu = torch.cuda.is_available()

# `torch.has_mps` is deprecated and only says whether this PyTorch build was
# compiled with MPS support; it does not say whether an MPS device is usable
# on this machine. Ask the backend directly, guarding against older builds
# that do not ship `torch.backends.mps` at all.
_mps_backend = getattr(torch.backends, "mps", None)
has_mps = bool(_mps_backend is not None and _mps_backend.is_available())

# Preference order is unchanged: Apple MPS first, then CUDA, then CPU.
device = "mps:0" if has_mps else "cuda:0" if has_gpu else "cpu"

### Define parameters, Load MNIST

In [3]:
# Training hyperparameters.
EPOCHS = 50
batch_size = 32

# Preprocessing: convert to a tensor, then normalise as (x - 0.5) / 0.5.
transform = T.Compose([T.ToTensor(), T.Normalize(mean=0.5, std=0.5)])

# Undoes the normalisation above for display: (y - (-1)) / 2 == (y + 1) / 2.
inverse_transform = T.Normalize(mean=-1.0, std=2.0)

# MNIST train/test splits stored under ./data (download=True fetches them
# when they are not already present).
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; both loaders use num_workers=4.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=4
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=4
)

# Take the first test batch and keep its first image as a sample.
dataiter = iter(testloader)
images, labels = next(dataiter)
random_image = images[0]

### Train function

In [4]:
def train_GAN(model, train_loader, EPOCHS, device, display_every=1):
    """Train a conditional GAN, alternating discriminator and generator steps.

    For every batch, ``model.D_step(X, y)`` is called first and the ``z`` it
    returns is passed on to ``model.G_step(z, y)``. At the end of each epoch
    the mean discriminator/generator losses are shown on the epoch progress
    bar, and every ``display_every`` epochs one generated sample per label
    0-9 is displayed.

    Args:
        model: Object providing ``D_step(X, y) -> (z, D_loss)``,
            ``G_step(z, y) -> G_loss`` and ``model(n, labels)``, which
            returns a 4-D tensor (presumably ``(N, C, H, W)``, since it is
            permuted to channels-last for display -- TODO confirm).
        train_loader: Iterable yielding ``(X, y)`` tensor batches.
        EPOCHS: Number of passes over ``train_loader``.
        device: Device that batches and the sample labels are moved to.
        display_every: Show generated samples when
            ``epoch % display_every == 0``. The default of 1 shows them after
            every epoch; ``0`` or ``None`` disables the preview.

    Returns:
        None.
    """
    # Fixed conditioning labels 0..9, created on `device` so they live on the
    # same device as the training batches handed to the model.
    sample_labels = torch.arange(0, 10, device=device)

    epoch_bar = tqdm(range(0, EPOCHS), desc="Epoch")
    for epoch in epoch_bar:
        G_running_loss = 0.0
        D_running_loss = 0.0
        n_batches = 0

        for X, y in tqdm(train_loader, leave=False, desc=f"Epoch {epoch}"):
            X = X.to(device)
            y = y.to(device)

            # iteration for discriminator
            z, D_loss = model.D_step(X, y)

            # iteration for the generator
            G_loss = model.G_step(z, y)

            # Accumulate plain floats: adding loss tensors directly could keep
            # their autograd graphs alive until the end of the epoch.
            D_running_loss += float(D_loss)
            G_running_loss += float(G_loss)
            n_batches += 1

        # Report mean losses on the epoch bar; skipped for an empty loader to
        # avoid dividing by zero.
        if n_batches:
            epoch_bar.set_postfix(
                D_loss=f"{D_running_loss / n_batches:.4f}",
                G_loss=f"{G_running_loss / n_batches:.4f}",
            )

        # Preview one generated sample per class.
        if display_every and epoch % display_every == 0:
            with torch.no_grad():
                samples = model(10, sample_labels)
                # .cpu() is a no-op for CPU tensors; it is added so plotting
                # does not depend on the helper handling device tensors.
                display_images(samples.cpu().permute(0, 2, 3, 1), figsize=(13, 2), title="", cmap="gray")
                plt.show()

In [None]:
# Build the conditional GAN from the MLP discriminator/generator and train it.
# The epoch count comes from the EPOCHS constant in the parameters cell, so
# changing that constant actually changes the length of this run.
model = MnistCGAN(D_model=MLPDiscriminator(), G_model=MLPGenerator(), device=device)
train_GAN(model, trainloader, EPOCHS, device)