In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import numpy as np
import random

%matplotlib notebook
import matplotlib.pyplot as plt

from itertools import chain
from model import Discriminator1, Discriminator2, Encoder, Generator, Hencoder, Aggregator


#debug use
import pdb

In [2]:
def splitImage(x, divisor=2):
    w = x.shape[2]
    w = w // divisor
    
    res = []
    for i in range(divisor):
        for j in range(divisor):
            res.append(x[:, :, i*w:(i+1)*w, j*w:(j+1)*w])
        
    return torch.stack(res, -1)

In [3]:
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,))
                           ])

In [4]:
train_set = torchvision.datasets.MNIST(root="../mnist/", train=True, transform=trans, download=True)
test_set = torchvision.datasets.MNIST(root="../mnist/", train=False, transform=trans, download=True)

batch_size = 1024

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)

test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

In [5]:
device = torch.device("cuda:0")

In [6]:
params = {
        "slope": 2e-2,
        "dropout": 0.2,
        "num_channels": 1,
        "z_dim": 128,
        "device": device,
        "num_views": 4,
        "aggregation_dim": 1024
}

In [7]:
D1 = Discriminator1(params)
D2 = Discriminator2(params)
G = Generator(params)
E = Encoder(params)
A = Aggregator(params)
H = Hencoder(params)

In [8]:
GEH = [G, E, H, A]
D1D2 = [D1, D2, A]

In [9]:
# optimizers
min_optimzer = optim.Adam(chain.from_iterable([m.parameters() for m in GEH]), 1e-4, betas=(0.5, 0.999), weight_decay=2.5*1e-5)
max_optimzer = optim.Adam(chain.from_iterable([m.parameters() for m in D1D2]), 1e-4, betas=(0.5, 0.999), weight_decay=2.5*1e-5)

In [10]:
# set iteration to 0
iter_cnt = 0
EPS = 1e-12

In [57]:
[m.train() for m in GEH]
[m.to(device) for m in GEH]
[m.train() for m in D1D2]
[m.to(device) for m in D1D2]

for epoch in range(30):
        for batch_idx, (batch_x, _) in enumerate(train_loader):
            
            batch_x = batch_x.to(device)
            
            # real loss
            e = E(batch_x)
            D1_Y_out = D1(batch_x, e).squeeze()
            
            # fake loss with noise
            z = torch.randn(batch_x.shape[0], 128, 1, 1).to(device)
            Gz = G(z)
            D1_z_out = D1(Gz, z).squeeze()
            
            multiview = splitImage(batch_x).squeeze()
            # randomly select indices
            indices = random.sample(set([0, 1, 2, 3]), random.randint(1, 3))
            
            
            H_aggre = A(multiview, indices)
            Y_aggre = A(multiview, random.sample(set([0, 1, 2, 3]), random.randint(3, 5)))
            
            D2_r_out = D2(e.squeeze(), Y_aggre)
            
            H_X = H(H_aggre)
            H_X_PRIME = H(Y_aggre)
            
            D2_f_out = D2(H_X.sample(), H_aggre)
            
            kl_loss = torch.distributions.kl.kl_divergence(H_X, H_X_PRIME).mean()
            
            # loss calculation
            max_loss = -torch.mean(torch.log(D1_Y_out + EPS) + torch.log(1 - D1_z_out + EPS)) \
                         -torch.mean(torch.log(D2_r_out + EPS) + torch.log(1 - D2_f_out + EPS))
            
            
            
            min_loss = -torch.mean(torch.log(D1_z_out + EPS) + torch.log(1 - D1_Y_out + EPS)) \
                        -torch.mean(torch.log(D2_f_out + EPS) + torch.log(1 - D2_r_out + EPS)) + kl_loss
            
            max_optimzer.zero_grad()
            max_loss.backward(retain_graph=True)
            max_optimzer.step()
            
            min_optimzer.zero_grad()
            min_loss.backward()
            min_optimzer.step()
            
            if iter_cnt % 100 == 0:
                print("Iter ", iter_cnt, " Max_Loss ", max_loss.item(), " Min_Loss ", min_loss.item())
            
            iter_cnt += 1
            
            

Iter  3600  Max_Loss  1.7973804473876953  Min_Loss  6.730755805969238
Iter  3700  Max_Loss  1.46648108959198  Min_Loss  7.631901264190674
Iter  3800  Max_Loss  1.7657947540283203  Min_Loss  6.901998043060303
Iter  3900  Max_Loss  1.804825782775879  Min_Loss  7.078032493591309
Iter  4000  Max_Loss  1.8235807418823242  Min_Loss  6.962728023529053
Iter  4100  Max_Loss  1.6455621719360352  Min_Loss  6.937549591064453
Iter  4200  Max_Loss  1.4490865468978882  Min_Loss  7.549886703491211
Iter  4300  Max_Loss  1.649861454963684  Min_Loss  7.314852237701416
Iter  4400  Max_Loss  1.4786423444747925  Min_Loss  7.616426467895508
Iter  4500  Max_Loss  1.5956770181655884  Min_Loss  7.09641170501709
Iter  4600  Max_Loss  1.8515230417251587  Min_Loss  7.4279656410217285
Iter  4700  Max_Loss  1.5577671527862549  Min_Loss  7.3778204917907715
Iter  4800  Max_Loss  1.7720904350280762  Min_Loss  7.680863857269287
Iter  4900  Max_Loss  1.6991682052612305  Min_Loss  7.2189226150512695
Iter  5000  Max_Loss  

In [58]:
G.eval()
E.eval()
H.eval()
A.eval()

real = batch_x[10].detach().unsqueeze(0)
plt.imshow(real.cpu().squeeze().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fed2c624780>

In [59]:
z = E(real)
image = G(z)
plt.imshow(image.detach().cpu().squeeze().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fed2c549898>

In [70]:
multiview = splitImage(real).squeeze().unsqueeze(0)
Y_aggre = A(multiview, [0, 1, 2])
z = H(Y_aggre).sample().unsqueeze(-1).unsqueeze(-1)
image = G(z)
plt.imshow(image.detach().cpu().squeeze().numpy())

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fed2c05ee80>