<a href="https://colab.research.google.com/github/satoruk-icepp/mlhep2019_2_phase/blob/master/analysis/pytorch_mnist_conditional_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Install comet_ml (and skorch) only when $COLAB_GPU is set (presumably a Colab runtime).
! [ -n "$COLAB_GPU" ] && pip install skorch comet_ml



In [0]:
%%writefile .comet.config
[comet]
# The API key is deliberately NOT stored in this file: export COMET_API_KEY in the
# environment before creating the Experiment. The key previously committed here
# is exposed and must be revoked/rotated.
logging_file = /tmp/comet.log
logging_file_level = info

Overwriting .comet.config


In [0]:
from comet_ml import Experiment

# Open a Comet experiment; the training loop logs its losses and sample figures to it.
experiment = Experiment(
    project_name="BayesMNISTGAN",
)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/satoruk-icepp/bayesmnistgan/d3acec18d35546ffb031116a697eec5c



In [0]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import math
import numpy as np

# Device configuration: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs so weight init, data shuffling and noise draws are repeatable
# across Restart & Run All (CUDA kernels may still add nondeterminism).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [0]:
def one_hot(a, num_classes):
    """Encode integer class labels as one-hot rows.

    Args:
        a: array-like of integer labels (NumPy array or CPU torch tensor);
            it is flattened, so any shape is accepted.
        num_classes: number of classes, i.e. the width of each row.

    Returns:
        float64 NumPy array of shape (number of labels, num_classes).
    """
    labels = np.asarray(a).reshape(-1)
    # Deliberately not squeezed: a batch with a single label must still give a
    # (1, num_classes) row so torch.cat([x, label], dim=1) in G/D keeps working.
    return np.eye(num_classes)[labels]

In [0]:
# Mini-batch size shared by the data loaders and the training/sampling code.
bs = 100

# MNIST is single-channel. ToTensor yields [0, 1] floats (per torchvision docs) and
# Normalize(0.5, 0.5) maps them to [-1, 1], matching the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST datasets; only the training call requests a download.
# NOTE(review): download=False for the test split assumes the training download
# already put the test files under ./mnist_data/ — confirm for your torchvision version.
train_dataset = datasets.MNIST(
    root='./mnist_data/', train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root='./mnist_data/', train=False, transform=transform, download=False
)

# Data loaders (input pipeline): only the training stream is shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=shuffle)
    for ds, shuffle in ((train_dataset, True), (test_dataset, False))
)

In [0]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise vector, one-hot label) -> flattened image.

    Args:
        g_input_dim: length of the noise vector z.
        g_output_dim: length of the flattened generated image.
        num_classes: width of the one-hot label concatenated to z
            (defaults to 10, the previously hard-coded value).
    """

    def __init__(self, g_input_dim, g_output_dim, num_classes=10):
        super(Generator, self).__init__()
        # Hidden widths grow 256 -> 512 -> 1024 before projecting to the image size.
        self.fc1 = nn.Linear(g_input_dim + num_classes, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    def forward(self, x, label):
        """Generate images conditioned on ``label``.

        Args:
            x: noise batch, shape (batch, g_input_dim).
            label: one-hot labels, shape (batch, num_classes).

        Returns:
            Tensor of shape (batch, g_output_dim) with values in [-1, 1] (tanh).
        """
        h = torch.cat([x, label], dim=1)
        h = F.leaky_relu(self.fc1(h), 0.2)
        h = F.leaky_relu(self.fc2(h), 0.2)
        h = F.leaky_relu(self.fc3(h), 0.2)
        return torch.tanh(self.fc4(h))
    
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (flattened image, one-hot label) -> P(real).

    Args:
        d_input_dim: length of the flattened input image.
        num_classes: width of the one-hot label concatenated to the image
            (defaults to 10, the previously hard-coded value).
    """

    def __init__(self, d_input_dim, num_classes=10):
        super(Discriminator, self).__init__()
        # Hidden widths shrink 1024 -> 512 -> 256 before the single output unit.
        self.fc1 = nn.Linear(d_input_dim + num_classes, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x, label):
        """Score how likely ``x`` is a real sample of class ``label``.

        Args:
            x: image batch, shape (batch, d_input_dim).
            label: one-hot labels, shape (batch, num_classes).

        Returns:
            Tensor of shape (batch, 1) with sigmoid outputs in (0, 1).
        """
        h = torch.cat([x, label], dim=1)
        # training=self.training: F.dropout defaults to training=True, which kept
        # dropout active even after D.eval(); tie it to the module's mode instead.
        h = F.dropout(F.leaky_relu(self.fc1(h), 0.2), 0.3, training=self.training)
        h = F.dropout(F.leaky_relu(self.fc2(h), 0.2), 0.3, training=self.training)
        h = F.dropout(F.leaky_relu(self.fc3(h), 0.2), 0.3, training=self.training)
        return torch.sigmoid(self.fc4(h))

In [0]:
# build network
z_dim = 100  # length of the generator's input noise vector

# `.data` holds the stacked training images (dims 1 and 2 are height and width);
# `.train_data` is a deprecated alias of it in torchvision.
mnist_dimx = train_dataset.data.size(1)
mnist_dimy = train_dataset.data.size(2)
mnist_dim = mnist_dimx * mnist_dimy  # flattened image length: G's output, D's input

G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)



In [0]:
# Display the generator's layer structure.
G

Generator(
  (fc1): Linear(in_features=110, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [0]:
# Display the discriminator's layer structure.
D

Discriminator(
  (fc1): Linear(in_features=794, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [0]:
class NoiseLoss(torch.nn.Module):
    """SGHMC noise term for Bayesian GAN training.

    ``forward`` returns ``sum_p sum(noise_p * p) / observed``, where ``noise_p`` is
    freshly drawn i.i.d. N(0, noise_std**2) with the same shape, dtype and device
    as parameter ``p``. Its gradient w.r.t. ``p`` is therefore ``noise_p / observed``,
    i.e. element-wise Gaussian noise injected into the parameter update.

    Args:
        params: unused; kept so existing ``NoiseLoss(params=..., ...)`` calls work.
        noise_std: standard deviation of the injected noise.
        observed: default normaliser; can be overridden per call in ``forward``.
    """

    def __init__(self, params, noise_std, observed=None):
        super(NoiseLoss, self).__init__()
        self.observed = observed
        self.noise_std = noise_std

    def forward(self, params, observed=None):
        """Return the noise loss for an iterable of parameter tensors.

        Raises:
            ValueError: if ``observed`` was given neither here nor at construction.

        NOTE(review): an earlier comment said the scale should be sqrt(2*alpha/eta)
        (alpha = drag strength, eta = learning rate), while this notebook passes
        noise_std=sqrt(2*alpha*lr) — confirm which is intended.
        """
        if observed is None:
            observed = self.observed
        if observed is None:
            raise ValueError("NoiseLoss needs `observed`, either at construction or per call")

        noise_loss = 0.0
        for var in params:
            # Independent noise per element, created on var's device and not
            # tracked by autograd. The previous torch.randn(1) shared one scalar
            # across the whole tensor and lived on the CPU, which does not mix
            # with CUDA parameters.
            noise = self.noise_std * torch.randn_like(var)
            noise_loss = noise_loss + torch.sum(noise * var)
        return noise_loss / observed

class PriorLoss(torch.nn.Module):
    """Gaussian weight-prior penalty for Bayesian GAN training.

    ``forward`` returns ``sum_p sum(p**2 / prior_std**2) / observed`` over the
    tensors in ``params``, i.e. twice the negative log density of N(0, prior_std**2)
    up to an additive constant, divided by ``observed``.

    Args:
        prior_std: standard deviation of the zero-mean Gaussian prior.
        observed: default normaliser, used when ``forward`` receives none.
    """

    def __init__(self, prior_std=1., observed=None):
        super().__init__()
        self.observed = observed
        self.prior_std = prior_std

    def forward(self, params, observed=None):
        n_obs = self.observed if observed is None else observed
        variance = self.prior_std * self.prior_std
        total = 0.0
        for weight in params:
            total = total + torch.sum(weight * weight / variance)
        return total / n_obs

In [0]:
# Loss: binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Optimizers: one Adam per network, both with the same learning rate.
lr = 0.0002
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=lr) for net in (G, D))

In [0]:
# Bayesian-GAN regularisers: a Gaussian prior and an SGHMC noise term per network.
# alpha is the strength of the drag term; the noise std is sqrt(2 * alpha * lr).
# NOTE(review): observed=10000 presumably stands for the dataset size, but the MNIST
# training split is usually 60000 images — confirm the intended value.
alpha = 0.002
gprior_criterion, dprior_criterion = (
    PriorLoss(prior_std=1., observed=10000) for _ in range(2)
)
gnoise_criterion, dnoise_criterion = (
    NoiseLoss(params=net.parameters(), noise_std=math.sqrt(2 * alpha * lr), observed=10000)
    for net in (G, D)
)

In [0]:
def D_train(x, y):
    """Run one discriminator update on a real batch and a matching fake batch.

    Args:
        x: batch of real images; reshaped to ``mnist_dim`` features per sample.
        y: integer class labels for ``x``; the same labels condition the fakes.

    Returns:
        Python float: total D loss (real BCE + fake BCE + prior + noise terms).
    """
    D.zero_grad()

    # Size targets from the batch itself (not the global `bs`) so a short final
    # batch still lines up with its labels.
    n = x.size(0)
    y_label = torch.FloatTensor(one_hot(y, 10)).to(device)

    # Real images should be scored as 1.
    x_real = x.view(-1, mnist_dim).to(device)
    y_real = torch.ones(n, 1, device=device)
    D_real_loss = criterion(D(x_real, y_label), y_real)

    # Fake images should be scored as 0. detach() keeps this loss from
    # back-propagating into G; only D's parameters are updated here.
    z = torch.randn(n, z_dim, device=device)
    x_fake = G(z, y_label).detach()
    y_fake = torch.zeros(n, 1, device=device)
    D_fake_loss = criterion(D(x_fake, y_label), y_fake)

    # Bayesian GAN terms on D's weights.
    D_prior_loss = dprior_criterion(D.parameters())
    D_noise_loss = dnoise_criterion(D.parameters())

    D_loss = (D_real_loss + D_fake_loss) + (D_prior_loss + D_noise_loss)
    D_loss.backward()
    D_optimizer.step()
    return D_loss.item()

In [0]:
def G_train(x, y):
    """Run one generator update: push D to score conditional fakes as real.

    Args:
        x: real image batch; unused here, kept for signature parity with D_train.
        y: integer class labels used to condition the generated batch.

    Returns:
        Python float: total G loss (adversarial BCE + prior + noise terms).
    """
    G.zero_grad()

    # Batch size comes from the labels, not the global `bs`, so z, labels and
    # targets always agree (e.g. on a short final batch).
    n = len(y)
    z = torch.randn(n, z_dim, device=device)
    y_label = torch.FloatTensor(one_hot(y, 10)).to(device)
    y_real = torch.ones(n, 1, device=device)

    # Non-saturating generator loss: BCE of D(G(z|y), y) against the "real" target.
    D_output = D(G(z, y_label), y_label)
    G_loss = criterion(D_output, y_real)

    # Bayesian GAN terms on G's weights.
    G_prior_loss = gprior_criterion(G.parameters())
    G_noise_loss = gnoise_criterion(G.parameters())
    G_loss = G_loss + (G_prior_loss + G_noise_loss)

    # backward() also accumulates gradients in D; D_train zeroes them before
    # D's next step, and only G's optimizer steps here.
    G_loss.backward()
    G_optimizer.step()
    return G_loss.item()

In [0]:
def _log_digit_grid():
    """Sample one image per digit 0-9 from G and log them to Comet as a 2x5 grid."""
    labels = torch.FloatTensor(one_hot(torch.arange(10), 10)).to(device)
    with torch.no_grad():
        z = torch.randn(10, z_dim, device=device)
        # Move to CPU so matplotlib can render samples produced on the GPU.
        samples = G(z, labels).view(10, 1, mnist_dimx, mnist_dimy).cpu()

    fig = plt.figure(figsize=(30, 12))
    grid = plt.GridSpec(2, 5, wspace=0.4, hspace=0.3)
    for digit in range(10):
        plt.subplot(grid[digit // 5, digit % 5])
        plt.imshow(samples[digit][0])
        # Title each panel with the digit it was conditioned on (previously every
        # panel showed y[0], the first training label of the current batch).
        plt.title("%d" % digit)
    experiment.log_figure(figure=fig)
    # This runs every 10 batches; without closing, open figures pile up in memory.
    plt.close(fig)


n_epoch = 200
ibatch = 0  # global step across epochs, used for Comet metric steps
with experiment.train():
    for epoch in range(1, n_epoch + 1):
        D_losses, G_losses = [], []
        for x, y in train_loader:
            D_loss = D_train(x, y)
            G_loss = G_train(x, y)
            D_losses.append(D_loss)
            G_losses.append(G_loss)
            experiment.log_metric("d_loss", D_loss, step=ibatch)
            experiment.log_metric("g_loss", G_loss, step=ibatch)
            if ibatch % 10 == 0:
                _log_digit_grid()
            ibatch += 1

        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, n_epoch, np.mean(D_losses), np.mean(G_losses)))

torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size



torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size([100, 10])
torch.Size

In [0]:
import os

# Save a grid of conditional samples. Labels cycle 0-9 so that, with nrow=10,
# each row of the saved grid shows the ten digit classes in order.
os.makedirs('./samples', exist_ok=True)  # save_image does not create directories
with torch.no_grad():
    test_z = torch.randn(bs, z_dim, device=device)
    test_labels = torch.FloatTensor(one_hot(torch.arange(bs) % 10, 10)).to(device)
    # G needs a label argument; the previous G(test_z) call raised a TypeError.
    generated = G(test_z, test_labels)

    # G ends in tanh, so outputs lie in [-1, 1]; rescale to the [0, 1] range
    # save_image expects (values outside it are clipped).
    images = (generated.view(generated.size(0), 1, mnist_dimx, mnist_dimy) + 1) / 2
    save_image(images, './samples/sample_.png', nrow=10)