In [None]:
from models import KMeansClassifier, LinearClassifier
import torchvision
import torchvision.transforms as tvtf
import torch
from torch.nn import DataParallel
from torch.utils.data import DataLoader

from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# --- Data loading -----------------------------------------------------------
data_root = '~/.pytorch-datasets'  # download/cache directory handed to torchvision
batch_size = 256                   # used for every DataLoader below

# --- Optimisation -----------------------------------------------------------
epochs = 200          # default number of epochs for train()
learning_rate = 0.01  # SGD learning rate

# --- Model constructor arguments --------------------------------------------
out_channels = 512
K = 10
dropout = 0.2
cluster_temp = 50     # only passed to KMeansClassifier

# --- Reproducibility --------------------------------------------------------
seed = 24
torch.manual_seed(seed)

def _make_loaders(dataset_cls):
    """Build the train/test datasets and their loaders for one dataset class.

    Args:
        dataset_cls: A torchvision dataset class accepting
            ``(root, train=..., transform=..., download=...)``.

    Returns:
        Tuple ``(ds_train, ds_test, dl_train, dl_test)``.
    """
    ds_train = dataset_cls(data_root, train=True, transform=tvtf.ToTensor(), download=True)
    ds_test = dataset_cls(data_root, train=False, transform=tvtf.ToTensor(), download=True)

    dl_train = DataLoader(ds_train, batch_size, shuffle=True)
    # Evaluation metrics do not depend on sample order, so the test split is
    # iterated in a fixed order instead of being reshuffled every epoch.
    dl_test = DataLoader(ds_test, batch_size, shuffle=False)
    return ds_train, ds_test, dl_train, dl_test


(ds_train_mnist, ds_test_mnist,
 dl_train_mnist, dl_test_mnist) = _make_loaders(torchvision.datasets.MNIST)

(ds_train_fashionmnist, ds_test_fashionmnist,
 dl_train_fashionmnist, dl_test_fashionmnist) = _make_loaders(torchvision.datasets.FashionMNIST)

(ds_train_cifar, ds_test_cifar,
 dl_train_cifar, dl_test_cifar) = _make_loaders(torchvision.datasets.CIFAR10)

In [None]:
def loss_fn(y_pred, y_true):
    """Cross-entropy loss with PyTorch's default settings.

    ``y_pred`` is forwarded as the criterion's ``input`` and ``y_true`` as its
    ``target``. PyTorch's cross-entropy applies log-softmax to ``input``
    itself, so it is meant to receive unnormalised scores.
    """
    criterion = torch.nn.CrossEntropyLoss()
    return criterion(y_pred, y_true)

In [None]:
def _run_epoch(model, dl, temperature, optimizer=None, pbar=None):
    """Run one full pass over ``dl`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is switched to training mode and updated
    after every batch. Without one the model is switched to eval mode and the
    pass runs with autograd disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    with torch.set_grad_enabled(is_training):
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)

            if is_training:
                optimizer.zero_grad()

            r = model(x)
            # NOTE(review): loss_fn wraps cross_entropy, which applies
            # log-softmax itself, so feeding it softmax output normalises
            # twice. Kept as-is to preserve the reported experiments; confirm
            # whether `temperature * r` should be passed as logits instead.
            y_pred = torch.softmax(temperature * r, dim=1)
            loss = loss_fn(y_pred, y)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_len = y.size(0)
            # loss_fn returns a batch mean; weight it by the batch size so the
            # epoch figure is a per-sample mean even when the last batch is short.
            loss_sum += loss.item() * batch_len
            num_correct += (y_pred.argmax(dim=1) == y).sum().item()
            num_seen += batch_len

            if pbar is not None:
                pbar.update()

    if num_seen == 0:
        # Empty loader: nothing to average.
        return float('nan'), 0.0
    return loss_sum / num_seen, 100. * num_correct / num_seen


def train(model, optimizer, dl_train, dl_test, num_epochs=epochs, temperature=100):
    """Train ``model`` and evaluate it on ``dl_test`` after every epoch.

    Per-epoch histories are written to the module-level lists
    ``train_losses``, ``train_accuracies``, ``test_losses`` and
    ``test_accuracies`` (reset on every call), which ``plot_all`` reads.
    All entries are plain Python floats. Nothing is returned; the best test
    accuracy is printed at the end.

    Args:
        model: Module mapping an input batch to per-class scores.
        optimizer: Optimizer over ``model``'s parameters.
        dl_train: Loader for the training split.
        dl_test: Loader for the evaluation split.
        num_epochs: Number of passes over ``dl_train``.
        temperature: Factor applied to the model output before the softmax.
    """
    num_batches = len(dl_train.batch_sampler)

    best_accuracy = 0
    best_test_accuracy = 0

    global train_accuracies, train_losses, test_accuracies, test_losses
    train_accuracies = []
    train_losses = []
    test_accuracies = []
    test_losses = []

    with tqdm(total=num_batches, bar_format='{l_bar}{bar}{r_bar}') as pbar:
        for epoch in range(num_epochs):
            pbar.reset()
            pbar.set_description(f'Epoch {epoch + 1}/{num_epochs}')
            pbar.refresh()

            train_loss, train_accuracy = _run_epoch(
                model, dl_train, temperature, optimizer=optimizer, pbar=pbar)
            train_accuracies.append(train_accuracy)
            train_losses.append(train_loss)
            best_accuracy = max(best_accuracy, train_accuracy)
            pbar.set_postfix(dict(accuracy=train_accuracy, best_accuracy=best_accuracy), loss=train_loss)

            # Evaluation: eval mode and no autograd, so dropout is disabled and
            # no computation graph is built or retained.
            test_loss, test_accuracy = _run_epoch(model, dl_test, temperature)
            best_test_accuracy = max(best_test_accuracy, test_accuracy)
            test_losses.append(test_loss)
            test_accuracies.append(test_accuracy)
    print(f'best test accuracy: {best_test_accuracy}')

In [None]:
def plot_all(title):
    """Plot the per-epoch train/test loss and accuracy curves in a 2x2 grid.

    Reads the module-level ``train_losses``, ``train_accuracies``,
    ``test_losses`` and ``test_accuracies`` lists populated by ``train``.

    Args:
        title: Figure-level title.
    """
    # Row-major order matches ``axes.flat``: top row is train, bottom row is test.
    panels = [
        ('train loss', train_losses),
        ('train accuracy', train_accuracies),
        ('test loss', test_losses),
        ('test accuracy', test_accuracies),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle(title)
    for ax, (panel_title, history) in zip(axes.flat, panels):
        # float() also unwraps 0-d tensors (CUDA or grad-tracking ones
        # included), which matplotlib cannot convert on its own.
        ax.plot([float(value) for value in history])
        ax.set_xlabel('epoch')
        ax.set_title(panel_title)
    plt.tight_layout()
    # pyplot.show() works with notebook backends; Figure.show() only warns on
    # non-GUI backends.
    plt.show()

## Experiments

### Classifier with KMeans layer

In [None]:
# KMeans-layer classifier on MNIST.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model — confirm the model expects this shape.
in_dims = ds_train_mnist.data[0].shape
model = DataParallel(
    KMeansClassifier(in_dims, out_channels, K, dropout, cluster_temp)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_mnist, dl_test_mnist)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('kmeans model on mnist')

In [None]:
# KMeans-layer classifier on FashionMNIST.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model — confirm the model expects this shape.
in_dims = ds_train_fashionmnist.data[0].shape
model = DataParallel(
    KMeansClassifier(in_dims, out_channels, K, dropout, cluster_temp)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_fashionmnist, dl_test_fashionmnist)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('kmeans model on fashion mnist')

In [None]:
# KMeans-layer classifier on CIFAR-10.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model, so the channel axis position may differ
# from the batches — confirm the model expects this shape.
in_dims = ds_train_cifar.data[0].shape
model = DataParallel(
    KMeansClassifier(in_dims, out_channels, K, dropout, cluster_temp)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_cifar, dl_test_cifar)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('kmeans model on cifar')

### Classifier with linear layer

In [None]:
# Linear-layer classifier on MNIST.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model — confirm the model expects this shape.
in_dims = ds_train_mnist.data[0].shape
model = DataParallel(
    LinearClassifier(in_dims, out_channels, K, dropout)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_mnist, dl_test_mnist)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('linear model on mnist')

In [None]:
# Linear-layer classifier on FashionMNIST.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model — confirm the model expects this shape.
in_dims = ds_train_fashionmnist.data[0].shape
model = DataParallel(
    LinearClassifier(in_dims, out_channels, K, dropout)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_fashionmnist, dl_test_fashionmnist)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('linear model on fashion mnist')

In [None]:
# Linear-layer classifier on CIFAR-10.
# NOTE(review): in_dims comes from the raw `.data` storage, not from the
# ToTensor() output fed to the model, so the channel axis position may differ
# from the batches — confirm the model expects this shape.
in_dims = ds_train_cifar.data[0].shape
model = DataParallel(
    LinearClassifier(in_dims, out_channels, K, dropout)
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

train(model, optimizer, dl_train_cifar, dl_test_cifar)

In [None]:
# Plot the curves recorded by the train() call above (read from its module-level history lists).
plot_all('linear model on cifar')