In [1]:
# Make the project root (the parent of the notebook's working directory)
# importable, so project-local packages can be imported in later cells.

import sys
from pathlib import Path

ROOT = Path().resolve().parent  # equivalent to .parents[0]

# Appended, so it is searched after existing entries; skipped if already there.
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))

In [2]:
# Import libraries

# Automatically reload edited modules (e.g. the project-local `image` package)
# before each cell executes, so code changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

import plotly.express as px
import torch
import torch.nn.functional as F
# NOTE(review): wildcard import hides which einops names are used; no einops
# function is called in the cells visible here — consider explicit imports
# (e.g. `from einops import rearrange`) or removing it if truly unused.
from einops import *
from kornia.augmentation import RandomGaussianNoise

# Project-local package; resolvable because the first cell appends the
# repository root to sys.path.
from image import CIFAR10, EMNIST, SVHN, Model

# Passed as `device=` to the datasets and to `.to(...)` on the models below.
device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <063DD42B-A1DA-3BBE-AD5E-03C12F7E3DBA> /Users/shidqietaufiqurrahman/workspace/miniconda3/envs/fact-ai/lib/python3.12/site-packages/torchvision/image.so
  warn(


In [3]:
# Load the datasets: CIFAR10 and SVHN converted to grayscale, EMNIST using the
# "letters" split. Each train split is constructed before its test split.
cifar10_train = CIFAR10(train=True, device=device, grayscale=True)
cifar10_test = CIFAR10(train=False, device=device, grayscale=True)

svhn_train = SVHN(split="train", device=device, grayscale=True)
svhn_test = SVHN(split="test", device=device, grayscale=True)

emnist_letters_train = EMNIST(train=True, device=device, split="letters")
emnist_letters_test = EMNIST(train=False, device=device, split="letters")

# Human-readable class names, presumably indexed by each dataset's class id
# (NOTE(review): confirm EMNIST letters ids are 0-based, i.e. 0 -> "A").
cifar10_labels = (
    "airplane automobile bird cat deer dog frog horse ship truck".split()
)
svhn_labels = [str(digit) for digit in range(10)]
emnist_letters_labels = [chr(code) for code in range(ord("A"), ord("Z") + 1)]

Files already downloaded and verified


  entry = pickle.load(f, encoding="latin1")


Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


In [4]:
# Instantiate the models.
# Hyperparameters shared by all three runs are defined once so they cannot drift
# apart; only the per-dataset input/output sizes are passed separately.
TRAIN_CONFIG = dict(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420
)
NOISE_STD = 0.5  # std of the Gaussian-noise augmentation used during training

# CIFAR10 / SVHN are loaded as grayscale 32x32 images -> 1024 inputs; the output
# size is left at Model's default (presumably 10 classes — confirm).
cifar10_model = Model.from_config(**TRAIN_CONFIG, d_input=1024).to(device)
svhn_model = Model.from_config(**TRAIN_CONFIG, d_input=1024).to(device)
# EMNIST letters: 26 classes; the input size is left at Model's default
# (NOTE(review): presumably 784 for 28x28 images — confirm).
emnist_letters_model = Model.from_config(**TRAIN_CONFIG, d_output=26).to(device)

# Train each model with always-on (p=1) Gaussian noise augmentation. A fresh
# augmentation object is created per run, and the fits run in the same order
# the models were created.
cifar10_metrics = cifar10_model.fit(
    cifar10_train, cifar10_test, RandomGaussianNoise(mean=0, std=NOISE_STD, p=1)
)
svhn_metrics = svhn_model.fit(
    svhn_train, svhn_test, RandomGaussianNoise(mean=0, std=NOISE_STD, p=1)
)
emnist_letters_metrics = emnist_letters_model.fit(
    emnist_letters_train,
    emnist_letters_test,
    RandomGaussianNoise(mean=0, std=NOISE_STD, p=1),
)

train/loss: 1.721, train/acc: 0.393, val/loss: 1.696, val/acc: 0.403: 100%|██████████| 100/100 [01:24<00:00,  1.18it/s]
train/loss: 1.490, train/acc: 0.499, val/loss: 1.194, val/acc: 0.760: 100%|██████████| 100/100 [02:06<00:00,  1.27s/it]
train/loss: 0.396, train/acc: 0.875, val/loss: 0.311, val/acc: 0.909: 100%|██████████| 100/100 [02:59<00:00,  1.79s/it]


In [5]:
# Decompose each trained model into eigenvalues / eigenvectors.
# NOTE(review): the indexing below assumes `decompose()` returns vecs shaped
# (n_classes, n_eigen, d_input) with eigenvalues sorted ascending, so index -1
# along dim 1 is each class's top eigenvector — confirm in image.Model.
svhn_vals, svhn_vecs = svhn_model.decompose()
emnist_vals, emnist_vecs = emnist_letters_model.decompose()
cifar10_vals, cifar10_vecs = cifar10_model.decompose()

N_SHOWN = 10  # classes shown per dataset == facets per row
SIDE = 32  # common image side for the grid (CIFAR10/SVHN inputs are 32x32)

# EMNIST inputs are 28x28: upsample the top eigenvector of its first N_SHOWN
# classes to SIDE x SIDE so all three datasets share one facet grid.
emnist_top10 = emnist_vecs[:N_SHOWN, -1]
emnist_top10_32 = F.interpolate(
    emnist_top10.reshape(N_SHOWN, 1, 28, 28),
    size=(SIDE, SIDE),
    mode="bilinear",
    align_corners=False,
).reshape(N_SHOWN, SIDE * SIDE)

# Only the top eigenvector of each class is plotted; one dataset per facet row,
# top to bottom: SVHN, EMNIST letters, CIFAR10.
vecs = torch.cat(
    [
        svhn_vecs[:, -1],
        emnist_top10_32,
        cifar10_vecs[:, -1],
    ]
)
# Scale each eigenvector to [-1, 1] so one diverging colour scale fits every facet.
vecs /= vecs.abs().max(1, keepdim=True).values

# Facet titles, in the same order as the rows of `vecs`.
labels = svhn_labels + emnist_letters_labels[:N_SHOWN] + cifar10_labels

color = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0.0)
fig = px.imshow(
    vecs.view(-1, SIDE, SIDE).cpu(),
    facet_col=0,
    facet_col_wrap=N_SHOWN,
    height=330,
    width=1000,
    facet_row_spacing=0.1,
    **color,
)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=0, t=20))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

# plotly express builds wrapped facet grids bottom-up, so the position of an
# annotation in fig.layout.annotations is not its facet index. Read the index
# from each default title ("facet_col=<k>") instead of relying on list order.
for annotation in fig.layout.annotations:
    facet_idx = int(annotation.text.rsplit("=", 1)[-1])
    annotation.update(text=f"<b>{labels[facet_idx]}</b>", y=annotation.y + 0.005)

fig.show()