In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from cycler import cycler

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage, Resize
import torchvision.datasets as dset
to_img  = ToPILImage()
resize_img  = Resize(200)
def show_img(img):
    return resize_img(to_img(img))

cmap=plt.cm.tab10
c = cycler('color', cmap(np.linspace(0,1,10)))
plt.rcParams["axes.prop_cycle"] = c

%config InlineBackend.figure_format = 'retina'
%matplotlib notebook
%matplotlib notebook

%load_ext autoreload
%autoreload 2

In [21]:
def img_ae(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 64, 64)
    return x

In [22]:
%matplotlib notebook
%matplotlib notebook

In [23]:
bs = 1

dataset = dset.ImageFolder(root='/scratch/rag394/data/gaussian_generator/',
                           transform=transforms.Compose([
                                   transforms.Grayscale(),
                                   transforms.Resize(64),
                                   transforms.CenterCrop(64),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,),(0.5,)),
                               ]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs,
                                         shuffle=False, num_workers=32)

In [None]:
targets = pd.read_csv('/scratch/rag394/data/gaussian_generator/gaussian_parameters.csv').drop(['Unnamed: 0'], 1).values
targets = np.reshape(targets,[10000//bs,bs,4])
targets = Variable(torch.from_numpy(targets).type(torch.FloatTensor), requires_grad=False)
targets = targets.cuda()

In [None]:
class GaussianModel(nn.Module):
    def __init__(self):
        super(GaussianModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 4)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [24]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 8, 3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1))  # b, 8, 2, 2
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),  # b, 16, 5, 5
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),  # b, 1, 28, 28
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [25]:
model = autoencoder()
model.cuda()
optimizer = torch.optim.Adam(model.parameters(),lr=1.0e-3)
criterion = nn.MSELoss()
total_loss = []

In [None]:
num_epochs = 5
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = Variable(img).cuda()
        # ===================forward=====================
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch+1, num_epochs, loss.data[0]))
    total_loss.append(loss.data[0])
    
    pic = img_ae(output.cpu().data)
    save_image(pic, '/scratch/rag394/data/ae/image_{}.png'.format(epoch))

torch.save(model.state_dict(), '/scratch/rag394/data/ae/conv_autoencoder.pth')

epoch [1/5], loss:0.0029
epoch [2/5], loss:0.0025
epoch [3/5], loss:0.0024


In [None]:
fig, ax = plt.subplots(figsize=(9,7))
pd.Series(total_loss, name='MSE Loss').plot(ax=ax,color='darkred', 
                                            lw=0.6, marker='s', markersize=3., legend=True)
ax.grid(alpha=0.2)
ax.set_xlabel('Epoch')

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.spines['bottom'].set_alpha(False)
ax.grid(alpha=0.3)