In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
from numpy import linalg as LA
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torch.optim as optim
from PIL import Image
import os

In [3]:
class sparse_autoencoder(torch.nn.Module):

    def __init__(self):
        super(sparse_autoencoder, self).__init__()
        input_size = 784
        output_size = 784      
        self.hidden = nn.Linear(input_size, 256)
        self.out = nn.Linear(256, output_size)
        self.batch_size = 8

    def forward(self, x):
        h = F.sigmoid(self.hidden(x))
        y_hat = self.out(h)
        return y_hat, h
    
    def loss(self, x, y, beta=0.2, rho=0.1):
        y_hat, h = self.forward(x)
        rho_hat = torch.sum(h, dim=0) / self.batch_size
        c = nn.MSELoss()
        l = c(y_hat, y) + beta * torch.sum(rho * torch.log(rho / rho_hat) + (1-rho) * torch.log((1-rho) / (1-rho_hat)))
        return l

In [4]:
ae = sparse_autoencoder()
optimizer = optim.Adam(ae.parameters(), lr=3e-4)

In [5]:
batch_size = 8
train = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
trainset = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)

In [6]:
train
trans = transforms.ToPILImage()

In [7]:
for epoch in range(10):
    cnt = 0
    l = 0
    for data in trainset:
        data = data[0].squeeze()
        x = torch.reshape(data, (batch_size, 784))
        y = x.clone()
        optimizer.zero_grad()
        y_hat, h = ae(x)
        loss = ae.loss(x, y)
        loss.backward()
        optimizer.step()
        l += loss
        cnt += 1
        if cnt % 600 == 0:
            cnt = 0
            print(l / 600)
            l = 0



tensor(1.1667, grad_fn=<DivBackward0>)
tensor(0.1815, grad_fn=<DivBackward0>)
tensor(0.1628, grad_fn=<DivBackward0>)
tensor(0.1576, grad_fn=<DivBackward0>)
tensor(0.1630, grad_fn=<DivBackward0>)
tensor(0.1495, grad_fn=<DivBackward0>)
tensor(0.1409, grad_fn=<DivBackward0>)
tensor(0.1453, grad_fn=<DivBackward0>)
tensor(0.1351, grad_fn=<DivBackward0>)
tensor(0.1488, grad_fn=<DivBackward0>)
tensor(0.1334, grad_fn=<DivBackward0>)
tensor(0.1270, grad_fn=<DivBackward0>)
tensor(0.1234, grad_fn=<DivBackward0>)
tensor(0.1254, grad_fn=<DivBackward0>)
tensor(0.1237, grad_fn=<DivBackward0>)
tensor(0.1319, grad_fn=<DivBackward0>)
tensor(0.1212, grad_fn=<DivBackward0>)
tensor(0.1209, grad_fn=<DivBackward0>)
tensor(0.1096, grad_fn=<DivBackward0>)
tensor(0.1097, grad_fn=<DivBackward0>)
tensor(0.1103, grad_fn=<DivBackward0>)
tensor(0.1135, grad_fn=<DivBackward0>)
tensor(0.1102, grad_fn=<DivBackward0>)
tensor(0.1041, grad_fn=<DivBackward0>)
tensor(0.0950, grad_fn=<DivBackward0>)
tensor(0.1040, grad_fn=<D

In [8]:
test = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.Compose([transforms.ToTensor()]))
testset = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)

In [9]:
results = []
cnt = 0
os.makedirs("img", exist_ok=True)
for test in testset:
    data = test[0].squeeze()
    x = torch.reshape(data, (batch_size, 784))
    out = ae(x)[0].detach().numpy()
    cnt += 1
    for i in range(data.shape[0]):
        plt.imsave('./img/' + str(cnt) + str(i) + 'org.png', data[i], cmap='gray')
        plt.imsave('./img/' + str(cnt) + str(i) + 'recovered.png', out[i].reshape(28, 28), cmap='gray')

In [10]:
w = list(ae.parameters())
en = w[0].detach().numpy()

In [11]:
os.makedirs("latent", exist_ok=True)
cnt = 0
for i in en:
    cnt += 1
    plt.imsave('./latent/' + str(cnt)+ 'latent.png', i.reshape(28, 28), cmap='gray')