## Autoencoders

In [None]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.datasets as Datasets
import torchvision.transforms as Transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Load the data.
# Each image is resized to 28 pixels and converted to a tensor.
# Transforms.Resize replaces Transforms.Scale, which torchvision deprecated
# and subsequently removed.
transform = Transforms.Compose([
    Transforms.Resize(28),
    Transforms.ToTensor(),
])

# download=False: the MNIST files are expected to already exist under datapath.
datapath = '../../Datasets/pytorch'
dataset = Datasets.MNIST(datapath, download=False, transform=transform)
dataloader = DataLoader(dataset, batch_size=256, num_workers=4)
dataset_size = len(dataset)



In [None]:
#inspect some data

def imshow(img, title=None):
    """Display an image tensor with matplotlib.

    Args:
        img: tensor laid out channels-first; it is converted to a NumPy
            array and its axes are reordered with ``[1, 2, 0]`` so the
            channel axis comes last, as ``plt.imshow`` expects.
        title: optional text shown above the image. Nothing is drawn when
            it is ``None``.
    """
    pixels = img.numpy()
    plt.imshow(np.transpose(pixels, [1, 2, 0]))
    # The title argument was previously accepted but never applied.
    if title is not None:
        plt.title(title)
    plt.show()
    
# Pull one batch from the loader, report its dimensions, then tile the
# images into a single grid and display it.
imgs, labels = next(iter(dataloader))
print(imgs.size())
inp = torchvision.utils.make_grid(tensor=imgs)
imshow(img=inp)

In [None]:
class Net(nn.Module):
    """Single-hidden-layer fully connected autoencoder.

    Encodes a 784-feature input down to 32 features and decodes it back to
    784 features, squashed by a sigmoid.
    """

    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(784, 32)  # encoder
        self.fc2 = nn.Linear(32, 784)  # decoder
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.maxpool = nn.MaxPool1d(kernel_size=2)

    def forward(self, inp):
        """Return the reconstruction of ``inp``.

        Args:
            inp: tensor whose last dimension holds 784 features.

        Returns:
            Tensor with 784 features in its last dimension and values in
            the sigmoid's output range.
        """
        code = F.relu(self.fc1(inp))
        # torch.sigmoid computes the same values as F.sigmoid, which a range
        # of PyTorch releases report as deprecated.
        return torch.sigmoid(self.fc2(code))


net = Net()

In [None]:
import time

# Reconstruction objective: mean squared error between the network output and
# the flattened input image. The optimizer uses Adadelta's default settings.
loss_criterion = nn.MSELoss()
optimizer = optim.Adadelta(net.parameters())

num_epochs = 50
for i in range(num_epochs):

    since = time.time()

    print('Epoch {}/{}'.format(i + 1, num_epochs))
    # Accumulates the per-batch loss values over the whole epoch.
    running_loss = 0.0
    for imgs, labels in dataloader:
        # Flatten every image to a 784-vector for the linear layers.
        imgs = imgs.view(-1, 784)

        optimizer.zero_grad()

        # Forward pass; the reconstruction target is the input itself.
        out = net(imgs)
        loss = loss_criterion(out, imgs)

        # The loss is a 0-dim tensor, so it cannot be indexed with [0];
        # .item() extracts its value as a Python float.
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

    print('Running Loss: {}'.format(running_loss))
    time_elapsed = time.time() - since
    print('Running time: {}m {}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Every fifth epoch, display the reconstructions of one batch.
    if (i + 1) % 5 == 0:
        imgs, labels = next(iter(dataloader))
        imgs = imgs.view(-1, 784)
        # Visualization only: no gradients need to be tracked here.
        with torch.no_grad():
            out = net(imgs)
        # Restore the image layout expected by make_grid.
        out_t = out.view(-1, 1, 28, 28)
        inp = torchvision.utils.make_grid(out_t)

        imshow(inp)
    