In [None]:
%matplotlib notebook

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Reproducibility: seed every RNG this notebook draws from -- torch (weight
# init, input-noise masks, any DataLoader shuffling) and numpy (the random
# sample image picked for display).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load both FashionMNIST splits from ./data (downloading them on first run),
# converting every sample image with ToTensor(). The training split is built
# first, then the test split.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

In [None]:
batch_size = 64

# Create data loaders. The training loader reshuffles every epoch so the
# optimizer does not see the same fixed batch sequence each pass; the test
# loader keeps dataset order.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single test batch to confirm tensor shapes and the label dtype.
X, y = next(iter(test_dataloader))
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)

In [None]:
# Randomly drops values in a tensor to zero with probability p
def drop_data(data, p, device=None):
    """Zero each element of ``data`` independently with probability ``p``.

    Unlike ``nn.Dropout``, surviving values are NOT rescaled by 1/(1-p).

    Args:
        data: tensor to corrupt; it is not modified in place.
        p: drop probability in [0, 1]. 0 keeps every element, 1 zeroes all.
        device: accepted for backward compatibility only. The mask is always
            created on ``data.device`` so it matches the data's placement.

    Returns:
        A new tensor with the same shape as ``data``.

    Raises:
        ValueError: if ``p`` lies outside [0, 1].
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"drop probability p must be in [0, 1], got {p}")
    # Sample straight on the data's device rather than on the CPU followed by
    # a host-to-device copy on every call. torch.rand draws from [0, 1), so
    # P(rand < 1 - p) == 1 - p is the keep probability.
    keep = torch.rand(data.shape, device=data.device) < (1 - p)
    return keep.int() * data

In [None]:
# A small denoising autoencoder: each forward pass zeroes a random fraction
# (noise_p, 0.25 by default) of the flattened input, encodes it down to a
# 3-D code, and decodes that code back to 28*28 values.
class DenoisingAE(nn.Module):
    """Fully connected denoising autoencoder for 28x28 inputs.

    Args:
        noise_p: probability that each flattened input element is zeroed by
            ``add_noise`` in ``forward``. Defaults to 0.25, the value that
            was previously hard-coded.
    """

    def __init__(self, noise_p=0.25):
        super().__init__()
        self.noise_p = noise_p
        self.flatten = nn.Flatten()
        # Corruption function applied to the flattened input in forward().
        self.add_noise = drop_data
        # 784 -> 250 -> 50 -> 3 bottleneck.
        self.encode = nn.Sequential(
            nn.Linear(28*28, 250),
            nn.ReLU(),
            nn.Linear(250, 50),
            nn.ReLU(),
            nn.Linear(50, 3),
        )
        # Mirror of the encoder: 3 -> 50 -> 250 -> 784.
        self.decode = nn.Sequential(
            nn.Linear(3, 50),
            nn.ReLU(),
            nn.Linear(50, 250),
            nn.ReLU(),
            nn.Linear(250, 28*28)
        )

    def forward(self, x):
        """Flatten, corrupt, encode and decode ``x``; returns the flat reconstruction.

        NOTE: the noise is applied on every call, including under model.eval().
        """
        x = self.flatten(x)
        # Build the noise mask on the input's own device instead of relying
        # on a notebook-level ``device`` global.
        x = self.add_noise(x, self.noise_p, x.device)
        x = self.encode(x)
        x = self.decode(x)
        return x
    

In [None]:
# Build the model on the selected device (``device`` is defined in the earlier
# configuration cell; recomputing it here would be a duplicate definition),
# plus the loss function and optimizer.
model = DenoisingAE().to(device)
# Element-wise mean squared error between reconstruction and target.
loss_fn = nn.MSELoss()
# Adam with a small L2 penalty (weight_decay) on all parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [None]:
# Train model for one epoch
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader``; return the mean batch loss.

    The target for every batch is the input batch itself (autoencoder
    objective): the prediction is reshaped to the input's shape and compared
    with it.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``y`` is ignored.
        model: module whose output has as many elements as its input.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` when the
        dataloader yields no batches.
    """
    # An evaluation pass may have left the model in eval mode; switch back so
    # train-time-only layers (dropout, batch norm) behave correctly.
    model.train()
    # Send batches to wherever the model's weights live.
    model_device = next(model.parameters()).device
    loss_sum = torch.zeros((), device=model_device)
    num_batches = 0
    for X, _ in dataloader:
        X = X.to(model_device)

        # Compute reconstruction error against the input batch.
        pred = model(X)
        loss = loss_fn(pred.reshape(X.shape), X)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device; one .item() at the end avoids a device->host
        # sync on every batch.
        loss_sum += loss.detach()
        num_batches += 1
    return loss_sum.item() / num_batches if num_batches else float("nan")

In [None]:
# Evaluate model accuracy on one batch
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            pred = model(X)
            test_loss += loss_fn(torch.reshape(pred, (X.shape[0], 1, 28, 28)), X).item()
    test_loss /= num_batches
    print(f"Avg loss: {test_loss:>8f} \n")

In [None]:
# Fit the autoencoder for a fixed number of passes over the training set,
# printing the test-set reconstruction loss after each one.
epochs = 50
for t in range(epochs):
    print("Epoch {}".format(t + 1))
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

In [None]:
# Display a random test image next to the model's reconstruction of it.
# forward() zeroes a random subset of the input before encoding, so the
# reconstruction is produced from a corrupted copy of the image on the left.
image = test_data[np.random.randint(len(test_data))][0]
# Inference only: no autograd graph is needed just to plot the result.
with torch.no_grad():
    output = model(image.to(device)).reshape(28, 28).cpu()

fig, (ax_original, ax_reconstruction) = plt.subplots(1, 2)
ax_original.imshow(image.reshape(28, 28))
ax_original.set_title("Original")
ax_reconstruction.imshow(output)
ax_reconstruction.set_title("Reconstruction")
for ax in (ax_original, ax_reconstruction):
    ax.axis("off")
plt.show()

In [None]:
# Map FashionMNIST label ids to readable class names for the plot legend.
clothing = {0: "tee/top",
            1: "trouser",
            2: "pullover",
            3: "dress",
            4: "coat",
            5: "sandal",
            6: "shirt",
            7: "sneaker",
            8: "bag",
            9: "ankle boot"}

# Encode every test image into the model's 3-D latent space. model.encode is
# called directly (bypassing forward()), so no input noise is applied. Codes
# and labels are appended together, so they stay aligned regardless of loader
# order.
encoding_batches = []
label_batches = []
with torch.no_grad():
    # Batched encoding instead of one forward pass (and device transfer) per image.
    for X, y in test_dataloader:
        encoding_batches.append(model.encode(X.flatten(start_dim=1).to(device)).cpu())
        label_batches.append(y)
encodings = torch.cat(encoding_batches).numpy()
classes = [clothing[label] for label in torch.cat(label_batches).tolist()]

In [None]:
import plotly.express as px
import pandas as pd
# Interactive 3-D scatter of the latent codes, coloured by clothing class.
fig = px.scatter_3d(
    x=encodings[:, 0],
    y=encodings[:, 1],
    z=encodings[:, 2],
    color=classes,
    labels={"x": "latent 1", "y": "latent 2", "z": "latent 3", "color": "class"},
    title="FashionMNIST test set in the autoencoder's 3-D latent space",
)
fig.show()