In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
import torch.distributions as dists
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST

%matplotlib inline
%reload_ext autoreload
%autoreload 2

from models import *
import utils as ut

In [None]:
# transforms — two independently-sampled augmentation pipelines ("views").
# `ut.Transform` receives the list of pipelines; presumably it applies each
# one to the same image and returns the list of views (TODO confirm against
# utils.Transform).
deg = 45  # max rotation in degrees, applied in both directions
n_views = 2


def _make_view_transform(deg):
    """Return one random augmentation pipeline: rotation, perspective, to-tensor."""
    return transforms.Compose([
        transforms.RandomRotation(degrees=(-deg, deg)),
        transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
        transforms.ToTensor(),
    ])


lst_trans = [_make_view_transform(deg) for _ in range(n_views)]
transform = ut.Transform(lst_trans)

# dataset — download on first use so the notebook runs on a fresh machine
dataset_train = MNIST(root='./data', train=True, transform=transform, download=True)
dataset_test = MNIST(root='./data', train=False, transform=transform, download=True)

# loader
batch_size = 1024
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

# device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model — 2-D latent space so the embedding can be scattered directly later
model = AE_MNIST(2).to(device)

# optim
optimizer = torch.optim.Adam(model.parameters(),
                             lr=1e-5,
                             weight_decay=5e-4)

# estimation — contrastive loss over the list of per-view embeddings
loss = ut.SimCLR()

In [None]:
# plot loader — show the first image of each of the two augmented views
# taken from one batch of the test loader
fig, ax = plt.subplots(1, 2)
views, _ = next(iter(loader_test))
for k in range(2):
    first_img = views[k][0, 0]
    ax[k].imshow(first_img)

In [None]:
# Train the encoder with the contrastive loss and record the mean
# train/test loss of every epoch in `log`.
n_epochs = 200

log = {
    'loss_train': [],
    'loss_test': [],
}


def _encode_views(lst_x):
    """Encode each augmented view of a batch on `device`.

    lst_x: the per-view image batches yielded by the loader (assumed to be a
           list of tensors collated from `ut.Transform` output — TODO confirm).
    Returns the list of embeddings in the same order, as passed to `loss`.
    """
    return [model.encode(x.to(device)) for x in lst_x]


for epoch in range(n_epochs):
    # train — must iterate the training split, not the test split
    model.train()
    lst_l = []
    for lst_x, _ in loader_train:
        l = loss(_encode_views(lst_x))
        optimizer.zero_grad()
        # one backward per step: retaining the graph would only keep
        # activations alive for no benefit
        l.backward()
        optimizer.step()
        lst_l.append(l.item())
    log['loss_train'].append(np.mean(lst_l))
    # test — record the loss of every batch, not only the last one
    model.eval()
    lst_l = []
    with torch.no_grad():
        for lst_x, _ in loader_test:
            l = loss(_encode_views(lst_x))
            lst_l.append(l.item())
    log['loss_test'].append(np.mean(lst_l))
    # log
    print(
        f"epoch: {epoch+1}",
        f"train: {log['loss_train'][-1]: .5f}",
        f"test: {log['loss_test'][-1]: .5f}")

In [None]:
# plot encoder — scatter the 2-D embedding of the un-augmented MNIST test
# set, one colour per digit label

# dataset — plain tensors (no augmentation); separate names so the two-view
# `transform` and the training loaders from the setup cell are not clobbered
transform_eval = transforms.ToTensor()
dataset_eval = MNIST(root='./data', train=False, transform=transform_eval, download=True)

# loader — ordering is irrelevant for a scatter plot
loader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

# encode the whole test set once, without building autograd graphs
model.eval()
lst_z, lst_y = [], []
with torch.no_grad():
    for x, y in loader_eval:
        lst_z.append(model.encode(x.to(device)).cpu())
        lst_y.append(y)
z_all = torch.cat(lst_z)
y_all = torch.cat(lst_y)

# one scatter call per label with an explicit legend label, so the legend
# does not depend on call order or on the length of the colour cycle
fig, ax = plt.subplots(figsize=(10, 10))
for label in range(10):
    idx = y_all == label
    ax.scatter(z_all[idx, 0], z_all[idx, 1], label=str(label))
ax.legend()
ax.set(title="Encoder embedding of the MNIST test set",
       xlabel="$z_0$", ylabel="$z_1$")
plt.show()