In [40]:
%load_ext autoreload
%autoreload 2
import os

import torch
import torchvision
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
# make a folder to save output images
from unet import UNet
if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')

# a function used to transform numpy array to image format
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

#################### select your hyperparameters ############################
num_epochs = 10
batch_size = 4
n_samples = 20

####### define image transforms, you can have other choices, explore it! #####
transform = transforms.ToTensor()

# download MNIST dataset
dataset = MNIST('./data', transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [42]:
# define autoencoder
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True), 
            nn.Linear(128, 32))
        self.decoder = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28), 
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [43]:
# set the model to be the autoencoder defined above
model = UNet(n_channels=1, n_classes=1)
############ chose appropriate loss function ######################
criterion = nn.MSELoss()

# set the optimizer, explore the effect of different optimizers
lr = 0.1
optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8)

In [44]:
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        if i > n_samples:
            break
        img, _ = data
        #img = img.view(img.size(0), -1)
        # ===================forward=====================
        output = model(img)
        target = img
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================user interaction========================
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.data))
    if epoch % 1 == 0:
        pic = to_img(output.data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
        
        ori_pic = to_img(img.data)
        save_image(ori_pic, './mlp_img/ori_image_{}.png'.format(epoch))

epoch [1/10], loss:0.0457
epoch [2/10], loss:0.0196
epoch [3/10], loss:0.0228
epoch [4/10], loss:0.0225
epoch [5/10], loss:0.0142
epoch [6/10], loss:0.0132
epoch [7/10], loss:0.0108
epoch [8/10], loss:0.0103
epoch [9/10], loss:0.0116
epoch [10/10], loss:0.0088
