In [1]:
import sys
from pathlib import Path
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

PROJECT_ROOT = Path().resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from mnist_classifier.data import MNISTDataModule
from nn_utils import get_hidden_states, load_models
from nn_plotting import plot_perturved_accuracy
from local_corex import LinearCorex, partition_data
from local_corex._transformers import CorExWrapper, PCAWrapper
from local_corex.utils.plotting import (
    hidden_state_plot,
    multi_rep_plot,
)

# Download MNIST dataset if not already downloaded
from torchvision.datasets import MNIST
from torchvision import transforms
MNIST(f"{PROJECT_ROOT}", train=False, download=True, transform=transforms.ToTensor())

Install CUDA and cudamat (for python) to enable GPU speedups.


Dataset MNIST
    Number of datapoints: 10000
    Root location: C:\Users\tkerby2\Desktop\Projects
    Split: Test
    StandardTransform
Transform: ToTensor()

In [2]:
base_path = f"{PROJECT_ROOT}/Local_CorEx_Demo/paper_mnist/mnist_classifier/model"

sys.path.append(f"{PROJECT_ROOT}/Local_CorEx_Demo/paper_mnist/mnist_classifier/model")

from autoencoder_config import conf as ae_conf
from config import conf as clf_conf

ae_conf = ae_conf['autoencoder']
ae_ckpt = base_path + '/mnist_ae_epoch=091-val_loss=0.5937.ckpt'
clf_conf = clf_conf['classifier']
clf_ckpt = base_path + '/mnist_clf_epoch=068-val_loss=0.0006.ckpt'

do_ae, do_clf = load_models(ae_ckpt, clf_ckpt, ae_conf, clf_conf)

data_module = MNISTDataModule(clf_conf, f"{PROJECT_ROOT}")
data_module.setup('predict')

model_data = get_hidden_states(do_clf, data_module, device=do_clf.device, num_layers=len(clf_conf['hidden_layers']))

inputs=model_data[4]
labels=model_data[5]

do_state_df = pd.DataFrame({
    'input': list(inputs),
    'h1': list(model_data[1]),
    'h2': list(model_data[2]),
    'h3': list(model_data[3]),
    'output': list(model_data[0])
})

Inferred autoencoder config: {'encoder_layers': [500, 400, 300], 'decoder_layers': [400, 500], 'use_batch_norm': True, 'drop_out_p': 0.5, 'lr': 0.002}


In [None]:
import pickle
from_scratch = True

if from_scratch:
    print('computing clusters from scratch')
    num_clusters=20
    indexes = partition_data(inputs, n_partitions=num_clusters, phate_dim=10, n_jobs=-2, seed=42)
    print('saving clusters')
    with open('mnist_20_indexes.pkl', 'wb') as f:
        pickle.dump(indexes, f)
else:
    print("loading clusters")
    with open('mnist_20_indexes.pkl', 'rb') as f:
        indexes = pickle.load(f)

for i in range(0, num_clusters):
    print("Group number:", i, Counter(labels[indexes[i]]))

fig, axes = plt.subplots(2, 10, figsize=(20, 5))

for i, ax in enumerate(axes.flatten()):
    ax.imshow(np.mean(inputs[indexes[i]], axis = 0).reshape(28,28))
    ax.set_title('Frame ' + str(i))

loading clusters


FileNotFoundError: [Errno 2] No such file or directory: 'mnist_20_indexes.pkl'

In [None]:
from local_corex import LinearCorex

partition=11
x = np.concatenate([model_data[3][indexes[partition]], model_data[0][indexes[partition]]], axis=1)

corex_h3_11 = LinearCorex(30, seed=42, gaussianize='outliers')
corex_h3_11.fit(x)
print(corex_h3_11.tcs)
hidden_state_plot(x, corex_h3_11, do_ae, factors=list(range(3)), latent_dim=300, encoder_layer=3, scaler=1.5, output_dim=10)
multi_rep_plot(corex_h3_11, 3, dims=[(15,20),(10,1)], num_per_row=2)

In [None]:
from local_corex import LinearCorex

partition=11
x = np.concatenate([model_data[2][indexes[partition]], model_data[0][indexes[partition]]], axis=1)

corex_h2_11 = LinearCorex(30, seed=42, gaussianize='outliers')
corex_h2_11.fit(x)
print(corex_h2_11.tcs)
hidden_state_plot(x, corex_h2_11, do_ae, factors=list(range(3)), latent_dim=400, encoder_layer=2, scaler=1.5, output_dim=10)
multi_rep_plot(corex_h2_11, 3, dims=[(20,20),(10,1)], num_per_row=2)

In [None]:
from local_corex import LinearCorex

partition=11
x = np.concatenate([model_data[1][indexes[partition]], model_data[0][indexes[partition]]], axis=1)

corex_h1_11 = LinearCorex(30, seed=42, gaussianize='outliers')
corex_h1_11.fit(x)
print(corex_h1_11.tcs)
hidden_state_plot(x, corex_h1_11, do_ae, factors=list(range(3)), latent_dim=500, encoder_layer=1, scaler=1.5, output_dim=10)
multi_rep_plot(corex_h1_11, 3, dims=[(20,25),(10,1)], num_per_row=2)

In [None]:
do_clf.to('cpu')
pred_digit = do_clf(torch.tensor(inputs)).max(1).indices.detach().numpy()
base_accuracies = []
for i in range(num_clusters):
    base_accuracies.append(100*np.mean(pred_digit[indexes[i]] == labels[indexes[i]]))
    print(i, np.round(np.mean(pred_digit[indexes[i]] == labels[indexes[i]]) * 100, 2))

In [None]:
plot_perturved_accuracy(do_clf, corex_h1_11, inputs, labels, indexes, 
                             factor_num=0, hidden_layer_idx=0, num_clusters=20, 
                             num_drop=100, hidden_dim=500)

In [None]:
output = plt.hist(corex_h1_11.moments['MI'][2], bins=50)
print(np.sum(output[0][1:]))

In [None]:
plot_perturved_accuracy(do_clf, corex_h2_11, inputs, labels, indexes, 
                             factor_num=0, hidden_layer_idx=1, num_clusters=20, 
                             num_drop=150, hidden_dim=400)

In [None]:
output = plt.hist(corex_h2_11.moments['MI'][0], bins=50)
print(np.sum(output[0][1:]))

In [None]:
plot_perturved_accuracy(do_clf, corex_h3_11, inputs, labels, indexes, 
                             factor_num=0, hidden_layer_idx=2, num_clusters=20, 
                             num_drop=150, hidden_dim=300)

In [None]:
output = plt.hist(corex_h3_11.moments['MI'][5], bins=50)
print(np.sum(output[0][3:]))