<a href="https://colab.research.google.com/github/satoruk-icepp/mlhep2019_2_phase/blob/master/analysis/pytorch_mnist_conditional_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
! [ ! -z "$COLAB_GPU" ] && pip install skorch comet_ml



In [0]:
%%writefile .comet.config
[comet]
api_key=mIel5ZAPOioTs0Cij75dSSQXs
logging_file = /tmp/comet.log
logging_file_level = info

Overwriting .comet.config


In [0]:
from comet_ml import Experiment
experiment = Experiment(project_name="BayesMNISTGAN")

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/satoruk-icepp/bayesmnistgan/94f0cdd89c274d0e8b2ad54de0ee52ae



In [0]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import math
import numpy as np
from IPython.display import clear_output

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [0]:
def one_hot(a, num_classes):
    return np.squeeze(np.eye(num_classes)[a.reshape(-1)])

In [0]:
gen_prior_std=1
disc_prior_std=1
bs = 100
z_dim = 90
lr = 0.0008 
alpha=0.002
n_epoch = 200
Nresblock = 5
params={'batch_size': bs,
#         'data_size':N,
        'epochs': n_epoch,
#         'energyscale': EnergyDepositScale,
        'noise_dim': z_dim,
        'learning_rate':lr,
        'alpha':alpha,
        'gen_prior_std':gen_prior_std,
        'disc_prior_std':disc_prior_std,
#         'gnoise_alpha':gnoise_alpha,
#         'Ngen':Ngen,
#         'Ndisc':Ndisc,
#         'PXscale':PXscale,
#         'PYscale':PYscale,
#         'PZscale':PZscale,
#         'XPosscale':XPosscale,
#         'YPosscale':YPosscale,
        'Nresblock':Nresblock
}
experiment.log_parameters(params)

In [0]:


# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
#     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [0]:
class ReducedConv(nn.Module):
    def __init__(self,input_size,output_size, input_dim, output_dim,kernel_size):
        super(ReducedConv, self).__init__()
        scale = float(output_dim+kernel_size-3)/float(input_dim)
        self.ups = nn.Upsample(scale_factor = scale,mode = 'bilinear',align_corners=False )
        self.ref = nn.ReflectionPad2d(1)
        self.conv = nn.Conv2d(input_size,output_size,kernel_size)
    def forward(self,x):
        return self.conv(self.ref(self.ups(x)))
#         return self.ref(self.ups(x))

In [0]:
class ResidualBlock(nn.Module):
    def __init__(self,input_size):
        super(ResidualBlock, self).__init__()        
        self.conv1 = nn.Conv2d(input_size,input_size,3,padding=1)
    def forward(self,xraw):
        x = F.leaky_relu(self.conv1(xraw))
        x = F.leaky_relu(self.conv1(x)+xraw)
        return x

In [0]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.output_dim = g_output_dim
        self.fc1 = nn.Linear(g_input_dim+10, 64*4*4)
#         self.bn_fc1 = nn.BatchNorm1d(self.fc1.out_features)        
#         self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
#         self.bn_fc2 = nn.BatchNorm1d(self.fc2.out_features)
#         self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
#         self.bn_fc3 = nn.BatchNorm1d(self.fc3.out_features)
#         self.fc4 = nn.Linear(self.fc3.out_features, 16*(g_output_dim-)**2)
        self.trconv1 = nn.ConvTranspose2d(64, 32, 5,stride=2)
        self.trconv2 = nn.ConvTranspose2d(32, 16, 5,stride=2)
        self.trconv3 = nn.ConvTranspose2d(16, 1, 4)

        self.bn1     = nn.BatchNorm2d(self.trconv1.out_channels)
        self.bn2     = nn.BatchNorm2d(self.trconv2.out_channels)        
        self.bnres   = nn.BatchNorm2d(16)
        self.resblock = ResidualBlock(16)
        self.redconv1 = ReducedConv(64,32,4,12,3)
        self.redconv2 = ReducedConv(32,16,12,20,3)        
        self.redconv3 = ReducedConv(16,1 ,20,28,3)                
    
    # forward method
    def forward(self, x,label):
        x = torch.cat([x,label],dim=1)
        x = F.leaky_relu(self.fc1(x), 0.2)        
        x = x.view(-1,64,4,4)
#         x = F.leaky_relu(self.bn_fc1(self.fc1(x)), 0.2)
#         x = F.leaky_relu(self.bn_fc2(self.fc2(x)), 0.2)
#         x = F.leaky_relu(self.bn_fc3(self.fc3(x)), 0.2)
#         x = F.leaky_relu(self.fc4(x), 0.2)
#         x = x.view(-1,16,self.output_dim,self.output_dim)
#         x = x.view(-1,16,self.output_dim-4,self.output_dim-4)
        x= F.leaky_relu(self.bn1(self.redconv1(x)))
        x= F.leaky_relu(self.bn2(self.redconv2(x)))
#         x= F.leaky_relu(self.bn1(self.trconv3(x)))        
#         for i in range(Nresblock):
#             x = F.leaky_relu(self.bnres(self.resblock(x)),0.2)
        x= self.redconv3(x)

#         x = F.leaky_relu(self.redconv1(x) ,0.2)
#         x = self.redconv2(x)
#         x = self.resblock(x)
        x = torch.tanh(x)
        return x
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        kernel_size =4
        self.input_dim = d_input_dim
        self.conv1 = nn.Conv2d(1,16,kernel_size)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.conv2 = nn.Conv2d(self.conv1.out_channels,self.conv1.out_channels*2,kernel_size)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
#         self.conv3 = nn.Conv2d(self.conv2.out_channels,self.conv2.out_channels,kernel_size)        
#         self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)        
#         self.conv4 = nn.Conv2d(self.conv3.out_channels,self.conv3.out_channels,kernel_size)                
#         self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)                
        nlayer=0
        self.fc1 = nn.Linear(self.conv2.out_channels*(d_input_dim-(kernel_size-1)*2)**2+10, 1)
#         self.fc1 = nn.Linear(d_input_dim**2+10, 1024)
#         self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
#         self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
#         self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x,label):
#         print(x.shape,label.shape)
#         x = x.view(-1,self.fc1.in_features)
        x = F.leaky_relu(self.conv1(x),0.2)
#         x = F.dropout(x, 0.3)    
        x = F.leaky_relu(self.bn2(self.conv2(x)),0.2)
#         x = F.leaky_relu(self.bn3(self.conv3(x)),0.2)
#         x = F.dropout(x, 0.3)    
#         x = F.leaky_relu(self.bn4(self.conv4(x)),0.2)
        x = x.view(-1,self.fc1.in_features-10)
        label = label.view(-1,10)    
        x = torch.cat([x,label],dim=1)
#         x = F.leaky_relu(self.fc1(x), 0.2)
#         x = F.dropout(x, 0.3)
#         x = F.leaky_relu(self.fc2(x), 0.2)
#         x = F.dropout(x, 0.3)
#         x = F.leaky_relu(self.fc3(x), 0.2)
#         x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc1(x))

In [0]:
# build network

dataset_size = train_dataset.train_data.size(0)
mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)
mnist_dimx = train_dataset.train_data.size(1)
mnist_dimy = train_dataset.train_data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dimx).to(device)
D = Discriminator(mnist_dimx).to(device)



In [0]:
G

Generator(
  (fc1): Linear(in_features=100, out_features=1024, bias=True)
  (trconv1): ConvTranspose2d(64, 32, kernel_size=(5, 5), stride=(2, 2))
  (trconv2): ConvTranspose2d(32, 16, kernel_size=(5, 5), stride=(2, 2))
  (trconv3): ConvTranspose2d(16, 1, kernel_size=(4, 4), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnres): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (resblock): ResidualBlock(
    (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (redconv1): ReducedConv(
    (ups): Upsample(scale_factor=3.0, mode=bilinear)
    (ref): ReflectionPad2d((1, 1, 1, 1))
    (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
  )
  (redconv2): ReducedConv(
    (ups): Upsample(scale_factor=1.6666666666666667, mode=bilinear)
    (ref): ReflectionPad2d((1, 1, 1, 1))
    (co

In [0]:
D

Discriminator(
  (conv1): Conv2d(1, 16, kernel_size=(4, 4), stride=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(4, 4), stride=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=15498, out_features=1, bias=True)
)

In [0]:
class NoiseLoss(torch.nn.Module):
  # need the scale for noise standard deviation
  # scale = noise  std
    def __init__(self, params, noise_std, observed=None):
        super(NoiseLoss, self).__init__()
        self.observed = observed
        self.noise_std = noise_std

    def forward(self, params,  observed=None):
    # scale should be sqrt(2*alpha/eta)
    # where eta is the learning rate and alpha is the strength of drag term
        if observed is None:
            observed = self.observed

#         assert scale is not None, "Please provide scale"
        noise_loss = 0.0
        for var in params:
            # This is scale * z^T*v
            # The derivative wrt v will become scale*z
#             _noise = noise.normal_(0.,self.noise_std)
            _noise = self.noise_std*torch.randn(1)
            noise_loss += torch.sum(Variable(_noise)*var)
        noise_loss /= observed
        return noise_loss

class PriorLoss(torch.nn.Module):
  # negative log Gaussian prior
    def __init__(self, prior_std=1., observed=None):
        super(PriorLoss, self).__init__()
        self.observed = observed
        self.prior_std = prior_std

    def forward(self, params, observed=None):
        if observed is None:
            observed = self.observed
        prior_loss = 0.0
        for var in params:
            prior_loss += torch.sum(var*var/(self.prior_std*self.prior_std))
        prior_loss /= observed
        return prior_loss

In [0]:
# loss
criterion = nn.BCELoss() 

# optimizer

G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

In [0]:

gprior_criterion = PriorLoss(prior_std=gen_prior_std, observed=dataset_size)
gnoise_criterion = NoiseLoss(params=G.parameters(), noise_std=math.sqrt(2 * alpha * lr), observed=dataset_size)
dprior_criterion = PriorLoss(prior_std=disc_prior_std, observed=dataset_size)
dnoise_criterion = NoiseLoss(params=D.parameters(), noise_std=math.sqrt(2 * alpha * lr), observed=dataset_size)

In [0]:
def D_train(x,y):
    #=======================Train the discriminator=======================#
    D.zero_grad()
    
    # train discriminator on real
    x_real,y_label, y_real = x.view(-1,1, mnist_dimx,mnist_dimx), torch.FloatTensor(one_hot(y,10)),torch.ones(bs, 1)
    x_real,y_label, y_real = Variable(x_real.to(device)), Variable(y_label.to(device)),Variable(y_real.to(device))

    D_output = D(x_real,y_label)
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output

    # train discriminator on facke
    z = Variable(torch.randn(bs, z_dim).to(device))
    x_fake, y_fake = G(z,y_label), Variable(torch.zeros(bs, 1).to(device))

    D_output = D(x_fake,y_label)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output
    # Bayesian Loss
    D_prior_loss = dprior_criterion(D.parameters())
    D_noise_loss = dnoise_criterion(D.parameters())

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss += D_prior_loss + D_noise_loss
    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item(), D_real_loss.data.item(),D_fake_loss.data.item(),D_prior_loss.data.item(),D_noise_loss.data.item()

In [0]:
def G_train(x,y):
    #=======================Train the generator=======================#
    G.zero_grad()

    z = Variable(torch.randn(bs, z_dim).to(device))
    y_label = Variable(torch.FloatTensor(one_hot(y,10)).to(device))
    y_real  = Variable(torch.ones(bs, 1).to(device))
#     print(y_label.shape)
    G_output = G(z,y_label)
    D_output = D(G_output,y_label)
    G_real_loss = criterion(D_output, y_real)
    G_prior_loss = gprior_criterion(G.parameters())
    G_noise_loss = gnoise_criterion(G.parameters())
    G_loss=G_real_loss+G_prior_loss+G_noise_loss

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()
        
    return G_loss.data.item(),G_real_loss.data.item(),G_prior_loss.data.item(),G_noise_loss.data.item()

In [0]:

ibatch = 0
with experiment.train():
    for epoch in range(1, n_epoch+1):
        D_losses, G_losses = [], []
        for batch_idx, (x, y) in enumerate(train_loader):
#             print(x,y.shape)
            D_loss,D_real_loss,D_fake_loss,D_prior_loss,D_noise_loss = D_train(x,y)
            G_loss,G_real_loss,G_prior_loss,G_noise_loss = G_train(x,y)
            D_losses.append(D_loss)
            G_losses.append(G_loss)
            y_label = Variable(torch.FloatTensor(one_hot(y,10)).to(device))
#             print(y_label.shape)
            experiment.log_metric("d_loss", D_loss,step=ibatch)
            experiment.log_metric("g_loss", G_loss,step=ibatch)
            experiment.log_metric("d_real_loss", D_real_loss,step=ibatch)
            experiment.log_metric("d_fake_loss", D_fake_loss,step=ibatch)
            experiment.log_metric("d_prior_loss", D_prior_loss,step=ibatch)            
            experiment.log_metric("d_noise_loss", D_noise_loss,step=ibatch)            
            experiment.log_metric("g_real_loss", G_real_loss,step=ibatch)
            experiment.log_metric("g_prior_loss", G_prior_loss,step=ibatch)
            experiment.log_metric("g_noise_loss", G_noise_loss,step=ibatch)
            if ibatch%10==0:
                clear_output()
                plt.figure(figsize=(30,12))
                grid = plt.GridSpec(2, 5, wspace=0.4, hspace=0.3)
                label = torch.LongTensor([i for i in range(10)])
                label =  Variable(torch.FloatTensor(one_hot(label,10)).to(device))                
                z = Variable(torch.randn(10, z_dim).to(device))                
                generated=G(z,label)
                generated=generated.view(generated.size(0), 1, 28, 28)
                for i in range(10):
                    plt.subplot(grid[i//5,i%5])
                    plt.imshow(generated[i][0].detach())
                    plt.title("%d"%(i))
                experiment.log_figure(figure=plt)
            ibatch += 1

        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
                (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

In [0]:
with torch.no_grad():
    test_z = Variable(torch.randn(bs, z_dim).to(device))
    generated = G(test_z)

    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png')