# Simplified MNIST decomposition example
This notebook shows how a simplified model trained on MNIST can be decomposed into its eigenvalues and eigenvectors.
The *bulk* of the work is done by ``eigen`` (imported from ``mnist.simple.utils``), which should not be too hard to understand.

In [None]:
%load_ext autoreload
%autoreload 2

from mnist.simple.model import Model
from mnist.simple.utils import MNIST, eigen
import plotly.express as px
from einops import *
import torch

# Build a single-layer model, fit it on the MNIST train/test split,
# and plot the resulting accuracy curves.
model = Model.from_config(
    epochs=30,
    wd=0.5,
    latent_noise=0.0,
    input_noise=2.0,
    n_layer=1,
).cuda()

train = MNIST(train=True, download=True)
test = MNIST(train=False, download=True)

# The decomposition cell turns autograd off globally, so switch it back on
# before fitting (matters when this cell is re-run after that one).
torch.set_grad_enabled(True)
metrics = model.fit(train, test)

# Last expression of the cell: the notebook renders this figure.
px.line(metrics, x=metrics.index, y=["train/acc", "val/acc"], title="Acc")

In [None]:
# Decompose the trained model into eigenvalues / eigenvectors.
# Autograd is disabled globally here and stays off for any later cells
# until something re-enables it (the training cell does).
torch.set_grad_enabled(False)
# NOTE(review): the second argument (8) looks like a target class / digit
# index for the decomposition — confirm against mnist.simple.utils.eigen.
vals, vecs = eigen(model, 8)

# Spectrum of eigenvalues.
px.line(vals.cpu()).show()

# Diverging palette centred on zero so positive and negative components
# are shown symmetrically.
color = {"color_continuous_scale": "RdBu", "color_continuous_midpoint": 0.0}

# Take the last five entries along dim 0, reverse their order, and view each
# one as a 28x28 image, one facet per entry.
# Assumes vecs stores one flattened 784-pixel eigenvector per row, ordered by
# ascending eigenvalue, so these are the largest five — TODO confirm.
px.imshow(
    vecs[-5:].flip(0).view(-1, 28, 28).cpu(),
    facet_col=0,
    **color,
).show()