In [None]:
import data
import matplotlib.pyplot as plt
import torch


# Build the train/test loaders. The literal 32 is forwarded as-is to
# data.get_dataloaders (presumably the batch size -- confirm in the data module).
loader_train, loader_test = data.get_dataloaders(32)

# Preview the first element of training item #2 as an inverted-grayscale image.
preview_item = loader_train.dataset[2]
plt.imshow(preview_item[0].squeeze(), cmap="gray_r")

In [None]:
from torch import nn


class AutoEncoder(nn.Module):
    """Fully connected autoencoder for single-channel 28x28 images.

    The encoder flattens its input and compresses 784 values to a 256-value
    code (784 -> 512 with ReLU, then 512 -> 256 with Sigmoid). The decoder is
    one linear layer plus Sigmoid that expands the code back to 784 values,
    which are then viewed as ``(-1, 1, 28, 28)``.
    """

    def __init__(self):
        super().__init__()
        self._flatten = nn.Flatten()

        # Encoder layers, created in network order.
        to_hidden = nn.Linear(28 * 28, 512)
        to_code = nn.Linear(512, 256)
        self._encoder = nn.Sequential(to_hidden, nn.ReLU(), to_code, nn.Sigmoid())

        # Decoder: the Sigmoid keeps every reconstructed value inside (0, 1).
        to_pixels = nn.Linear(256, 28 * 28)
        self._decoder = nn.Sequential(to_pixels, nn.Sigmoid())

    def encode(self, x):
        """Flatten ``x`` and return its 256-value code."""
        flat = self._flatten(x)
        return self._encoder(flat)

    def decode(self, x):
        """Map a code back to pixel values, viewed as ``(-1, 1, 28, 28)``."""
        pixels = self._decoder(x)
        return pixels.view(-1, 1, 28, 28)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the resulting code."""
        code = self.encode(x)
        return self.decode(code)


# Seed torch's global RNG *before* building the model: the nn.Linear layers
# draw their initial weights at construction time, so seeding only ahead of
# the training loop leaves the starting weights different on every run.
torch.manual_seed(666)

autoEncoder = AutoEncoder()
print(autoEncoder)

In [None]:
from torch import nn, optim
from tqdm.notebook import tqdm


# Adam over every parameter of the autoencoder, library-default hyper-parameters.
optimizer = optim.Adam(params=autoEncoder.parameters())

# Reconstruction objective: mean squared error between output and input.
loss_function = nn.MSELoss()

def train():
    """Run one optimisation pass over ``loader_train``.

    For each batch the autoencoder's reconstruction is compared with the
    batch itself via ``loss_function`` and ``optimizer`` takes one step.
    The second element of each batch is ignored. A "Batch" progress bar
    shows the most recent loss value.
    """
    progress = tqdm(desc="Batch", total=len(loader_train))
    with progress:
        for batch, _unused in loader_train:
            reconstruction = autoEncoder(batch)
            batch_loss = loss_function(reconstruction, batch)

            # Clear old gradients, back-propagate, then update the weights.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            progress.set_postfix(loss=batch_loss.item())
            progress.update()

def test():
    """Report the mean per-batch reconstruction loss over ``loader_test``.

    Each batch is passed through ``autoEncoder`` with gradient tracking
    disabled. The per-batch ``loss_function`` values are summed, divided by
    the number of batches, and the result is written with ``tqdm.write``.
    The second element of each batch is ignored.
    """
    test_loss = 0.0
    # Autograd is disabled inside the body rather than with a decorator: a
    # decorator expression is evaluated when the ``def`` statement runs,
    # whereas this ``torch`` lookup only happens when test() is called.
    with torch.no_grad():
        for X, _ in loader_test:
            pred = autoEncoder(X)
            test_loss += loss_function(pred, X).item()
    test_loss /= len(loader_test)
    tqdm.write(f"Test -> Loss: {test_loss}")

In [None]:
import torch
from tqdm.notebook import trange


# Fix torch's global RNG state right before the training loop starts.
torch.manual_seed(666)

# Epochs are numbered 1..20; each one trains on the full training loader and
# then reports the loss on the test loader.
epoch_bar = trange(1, 21, desc="Epoch")
for epoch in epoch_bar:
    train()
    test()

In [None]:
import torch
import matplotlib.pyplot as plt


# Encode and decode training item #2 without tracking gradients, then show the
# reconstruction with the same inverted-grayscale colormap as the input preview.
with torch.no_grad():
    sample_image = loader_train.dataset[2][0]
    x_train_enc = autoEncoder.encode(sample_image)
    x_train_dec = autoEncoder.decode(x_train_enc)

    reconstructed = x_train_dec.squeeze()
    plt.imshow(reconstructed, cmap="gray_r")