In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np

%matplotlib notebook
import matplotlib.pyplot as plt

from itertools import chain
from model import Discriminator, Encoder, Generator

#debug use
import pdb

In [2]:
# Preprocessing applied to every MNIST sample: only the PIL-image -> tensor
# conversion is active.
# Normalisation is intentionally left disabled:
#   transforms.Normalize((0.5,), (0.5,))
transform_steps = [transforms.ToTensor()]
trans = transforms.Compose(transform_steps)

In [3]:
# Mini-batch size shared by the train and test loaders.
batch_size = 256

# MNIST splits; both are downloaded into ../mnist/ on first use and share
# the `trans` preprocessing pipeline.
train_set = torchvision.datasets.MNIST(
    "../mnist/", train=True, transform=trans, download=True
)
test_set = torchvision.datasets.MNIST(
    "../mnist/", train=False, transform=trans, download=True
)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [4]:
# Compute device for the models: the first CUDA GPU when one is available,
# otherwise the CPU so the notebook can still be run on a GPU-less machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Hyper-parameters handed to the Discriminator, Generator and Encoder
# constructors (their exact use lives in model.py, not shown here).
params = dict(
    slope=2e-2,      # presumably the LeakyReLU negative slope -- confirm in model.py
    dropout=0.2,     # presumably a dropout probability -- confirm in model.py
    num_channels=1,  # presumably image channels (MNIST is single-channel) -- confirm
    z_dim=128,       # latent size; the training loop samples 128-channel noise
)

In [6]:
# Build the three networks from the shared hyper-parameters and wrap each in
# nn.DataParallel. Construction order (D, G, E) is kept so that parameter
# initialisation consumes the RNG in the same sequence.
D = nn.DataParallel(Discriminator(params))
G = nn.DataParallel(Generator(params))
E = nn.DataParallel(Encoder(params))

In [None]:
# optimizers
# The original cell passed `weight_decay=self.decay`, but there is no `self`
# at notebook scope, so it raised NameError before any optimizer was built.
# 0.0 is Adam's default (no L2 penalty).
# TODO(review): set this to the intended regularisation strength if one was meant.
weight_decay = 0.0
adam_kwargs = dict(lr=1e-4, betas=(0.5, 0.999), weight_decay=weight_decay)

# Encoder and generator are updated jointly by a single optimizer; the
# discriminator has its own.
EG_optimzer = optim.Adam(chain(E.parameters(), G.parameters()), **adam_kwargs)
D_optimzer = optim.Adam(D.parameters(), **adam_kwargs)

# NOTE: the training loop below writes its log-losses out by hand and does
# not call loss_func; it is kept so other cells that reference it still work.
loss_func = nn.BCELoss()

In [None]:
# Small constant added inside log() in the training losses so log(0) cannot occur.
EPS = 1e-16
# Global count of processed mini-batches; re-run this cell to reset it to 0.
iter_cnt = 0

In [None]:
# Adversarial training over (image, latent) pairs:
#   * "real" pair: (x, E(x))   -- a data image with its encoding
#   * "fake" pair: (G(z), z)   -- a generated image with the noise that produced it
# D is trained to tell the pairs apart; E and G are trained jointly to fool it.
D.to(device)
G.to(device)
E.to(device)

# Make sure the networks are in training mode even if this cell is re-run
# after a cell that switched them to eval mode.
D.train()
G.train()
E.train()

for epoch in range(10):
    for batch_x, _ in train_loader:
        # Use the configured device rather than a hard-coded .cuda() call.
        batch_x = batch_x.to(device)
        n_samples = batch_x.shape[0]

        # One forward pass through E and G, reused by both update steps.
        # Latent size comes from params so it cannot drift from the models.
        z = torch.randn(n_samples, params["z_dim"], 1, 1, device=device)
        e = E(batch_x)   # assumes E returns latents shaped like z -- TODO confirm
        G_out = G(z)

        # ---- discriminator step -------------------------------------------
        # E/G outputs are detached so this backward pass touches only D.
        D_real_out = D(batch_x, e.detach()).squeeze()
        D_fake_out = D(G_out.detach(), z).squeeze()
        D_loss = -torch.mean(torch.log(D_real_out + EPS) + torch.log(1 - D_fake_out + EPS))

        D_optimzer.zero_grad()
        D_loss.backward()
        D_optimzer.step()

        # ---- encoder / generator step -------------------------------------
        # Score the pairs again with the *updated* D. Back-propagating through
        # the graph built before D_optimzer.step() is invalid: the step
        # modifies D's weights in place, which autograd either rejects with a
        # RuntimeError or (older versions) silently turns into wrong gradients.
        D_real_out = D(batch_x, e).squeeze()
        D_fake_out = D(G_out, z).squeeze()
        EG_loss = -torch.mean(torch.log(D_fake_out + EPS) + torch.log(1 - D_real_out + EPS))

        # Gradients that this backward leaves on D's parameters are discarded
        # by D_optimzer.zero_grad() at the start of the next iteration.
        EG_optimzer.zero_grad()
        EG_loss.backward()
        EG_optimzer.step()

        if iter_cnt % 100 == 0:
            print("Iter ", iter_cnt, " D_Loss ", D_loss.item(), " EG_Loss ", EG_loss.item())

        iter_cnt += 1

In [None]:
# G is wrapped in nn.DataParallel. On a GPU machine the wrapper requires its
# module's parameters to be on device_ids[0], so calling the wrapper after
# .cpu() fails. Unwrap the underlying generator before moving it to the CPU.
# eval() switches layers such as dropout / batch-norm (if the generator has
# any) to inference behaviour for sampling.
test_net = (G.module if isinstance(G, nn.DataParallel) else G).cpu().eval()

In [None]:
# Sample 10 latent vectors and decode them. The latent size must match what
# the generator was trained with (params["z_dim"] = 128); the previous
# hard-coded 100 did not match the 128-channel noise used during training.
with torch.no_grad():  # inference only: no autograd graph needed
    noise = torch.randn(10, params["z_dim"], 1, 1)
    images = test_net(noise)

# Show the first sample; squeeze() drops size-1 dims
# (presumably leaving a 2-D H x W image for the single-channel model -- confirm).
plt.imshow(images[0].squeeze().numpy())