# Question 2


## (b) Design another neural network, “dis_net”, to discriminate between blurred images and clear images.

- Blurred images can be generated by applying a Gaussian blur to the original MNIST images.
- Train the autoencoder with an L1-norm reconstruction loss plus the discriminator loss.
- Make the reconstructed images as clear as possible; that is, train the autoencoder so that “dis_net” scores its reconstructions as clear images.
- Compare the results of (a) and (b).

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F

In [None]:
# MNIST training split, downloaded to ./data if absent and converted to
# tensors by transforms.ToTensor().
dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Data loader (input pipeline): shuffled mini-batches of 100 samples.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=100,
    shuffle=True,
)

def to_np(x):
    """Return the values of tensor ``x`` as a NumPy array, moving them to the CPU first."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

def to_var(x):
    """Wrap ``x`` in an autograd Variable, moving it to the GPU first when CUDA is available."""
    target = x.cuda() if torch.cuda.is_available() else x
    return Variable(target)

In [None]:
#Encoder
class Encoder(nn.Module):
    """MLP encoder mapping X_dim inputs to a z_dim code.

    Two hidden layers of width N, each followed by dropout (p=0.25, active
    only in training mode) and then ReLU. The final projection has no
    activation.
    """

    def __init__(self, X_dim, N, z_dim):
        super(Encoder, self).__init__()
        # Attribute names become state_dict() keys; renaming them would break
        # loading previously saved weights.
        self.Linear1 = nn.Linear(X_dim, N)
        self.Linear2 = nn.Linear(N, N)
        self.Linear3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.Linear1, self.Linear2):
            hidden = F.relu(F.dropout(layer(hidden), p=0.25, training=self.training))
        return self.Linear3gauss(hidden)

In [None]:
# Decoder
class Decoder(nn.Module):
    """MLP decoder mapping a z_dim code back to X_dim outputs.

    Mirrors Encoder: two hidden layers of width N, each followed by dropout
    (p=0.25, active only in training mode) and then ReLU. The output layer is
    squashed with a sigmoid.
    """

    def __init__(self, X_dim, N, z_dim):
        super(Decoder, self).__init__()
        self.Linear1 = nn.Linear(z_dim, N)
        self.Linear2 = nn.Linear(N, N)
        self.Linear3 = nn.Linear(N, X_dim)

    def forward(self, x):
        x = F.dropout(self.Linear1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.Linear2(x), p=0.25, training=self.training)
        # This ReLU was missing: without a nonlinearity here Linear2 and
        # Linear3 compose into a single affine map, so the second hidden
        # layer added no representational power.
        x = F.relu(x)
        x = self.Linear3(x)
        return torch.sigmoid(x)

In [None]:
# Discriminator
class Dis_Net(nn.Module):
    """MLP discriminator returning a sigmoid score in (0, 1) per input row.

    Input width is z_dim, so as wired up in this notebook it scores latent
    codes rather than images. Two hidden layers of width N, each followed by
    dropout (p=0.2, training mode only) and then ReLU.
    """

    def __init__(self, N, z_dim):
        super(Dis_Net, self).__init__()
        self.Linear1 = nn.Linear(z_dim, N)
        self.Linear2 = nn.Linear(N, N)
        self.Linear3 = nn.Linear(N, 1)

    def forward(self, x):
        h = x
        for layer in (self.Linear1, self.Linear2):
            h = F.relu(F.dropout(layer(h), p=0.2, training=self.training))
        logit = self.Linear3(h)
        return torch.sigmoid(logit)

In [None]:
EPS = 1e-15        # small constant added to probabilities before log() in training
z_red_dims = 16    # dimensionality of the latent code z

# Use the GPU when present and fall back to the CPU otherwise (same check as
# to_var), instead of calling .cuda() unconditionally, which raises on
# machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Q = Encoder(784, 1000, z_red_dims).to(device)       # encoder: image -> z
P = Decoder(784, 1000, z_red_dims).to(device)       # decoder: z -> image
D_gauss = Dis_Net(500, z_red_dims).to(device)       # discriminator on z

In [None]:
# Learning rates: gen_lr drives the reconstruction updates of P and Q,
# reg_lr drives the adversarial (regularization) updates of Q and D_gauss.
gen_lr = 1e-4
reg_lr = 5e-5

# Reconstruction-phase optimizers (decoder P, encoder Q).
optim_P = torch.optim.Adam(P.parameters(), lr=gen_lr)
optim_Q_enc = torch.optim.Adam(Q.parameters(), lr=gen_lr)

# Adversarial-phase optimizers. optim_Q_gen updates the same parameters as
# optim_Q_enc but keeps its own, separate Adam moment estimates.
optim_Q_gen = torch.optim.Adam(Q.parameters(), lr=reg_lr)
optim_D = torch.optim.Adam(D_gauss.parameters(), lr=reg_lr)

# Training-loop bookkeeping.
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)
total_step = 50000

In [None]:
# Start training.
# NOTE(review): part (b) asks for a discriminator ("dis_net") that separates
# blurred images from clear ones. This loop instead trains D_gauss on latent
# codes z (encoder outputs vs. Gaussian samples), i.e. an adversarial-
# autoencoder prior regularizer. Confirm which design is intended.
for step in range(total_step):

    # Fetch the next mini-batch. Start a new pass over the data only once the
    # current iterator is exhausted; the old `(step+1) % iter_per_epoch` reset
    # skipped the last batch of every epoch. Labels are not needed.
    try:
        images, _ = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        images, _ = next(data_iter)
    # Flatten each image to a vector for the encoder's first Linear layer.
    images = to_var(images.view(images.size(0), -1))

    # --- Reconstruction phase: update encoder Q and decoder P -------------
    P.zero_grad()
    Q.zero_grad()

    z_sample = Q(images)    # encode to z
    X_sample = P(z_sample)  # decode to X reconstruction
    # L1 reconstruction loss. Adding EPS to both arguments cancels in |a - b|,
    # so it is omitted.
    recon_loss = F.l1_loss(X_sample, images)

    recon_loss.backward()
    optim_P.step()
    optim_Q_enc.step()

    # --- Discriminator phase: update D_gauss only --------------------------
    # "Real" samples are drawn from N(0, 5^2); "fake" samples are encoder
    # outputs. This constrains the distribution of Q(images) toward it.
    D_gauss.zero_grad()
    Q.eval()
    z_real_gauss = torch.randn(images.size(0), z_red_dims, device=images.device) * 5.
    D_real_gauss = D_gauss(z_real_gauss)

    # No graph through Q: the discriminator loss must not write gradients
    # into the encoder (they would otherwise leak into optim_Q_gen's step).
    with torch.no_grad():
        z_fake_gauss = Q(images)
    D_fake_gauss = D_gauss(z_fake_gauss)

    D_loss = -torch.mean(torch.log(D_real_gauss + EPS) + torch.log(1 - D_fake_gauss + EPS))

    D_loss.backward()
    optim_D.step()

    # --- Generator phase: update Q so D_gauss scores its codes as real -----
    # Clear Q's leftover reconstruction gradients; otherwise optim_Q_gen would
    # step on recon_loss + G_loss gradients combined.
    Q.zero_grad()
    Q.train()
    z_fake_gauss = Q(images)
    D_fake_gauss = D_gauss(z_fake_gauss)
    G_loss = -torch.mean(torch.log(D_fake_gauss + EPS))
    G_loss.backward()
    optim_Q_gen.step()

# Save the encoder weights.
torch.save(Q.state_dict(), 'Q_encoder_weights.pt')