### Install LAMINR

In [None]:
%%capture
# Install LAMINR with its Colab extras; %%capture suppresses the pip output.
!pip install "laminr[colab]"

### Import necessary modules and set the device

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from laminr import neuron_models, get_mei_dict, InvarianceManifold

# Seed the global NumPy and PyTorch RNGs so MEI generation and manifold
# learning give the same results on every Restart & Run All.
# NOTE(review): assumes laminr draws from these global generators — confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"You are using {device.upper()} as device.")

### Specify the input shape used both to create the simulated neurons and to optimize the inputs

In [None]:
# Single-channel 100x100 images, ordered (channels, height, width). The same
# shape drives the simulated neurons' resolution and the MEI optimisation.
n_channels, height, width = 1, 100, 100
input_shape = [n_channels, height, width]

### Load the (pretrained or simulated) neuron model

In [None]:
# Build the "demo1" simulated neuron population at the input's spatial
# resolution (everything after the channel axis), then move it to `device`.
img_res = input_shape[1:]
model = neuron_models.simulated("demo1", img_res=img_res)
model = model.to(device)

### Generate MEIs (Maximally Exciting Inputs)

In [None]:
# Constraints shared by MEI generation and, later, the invariance-manifold
# pipeline. Going by the keyword names, pixel values are bounded to [-1, 1]
# and images are held at unit norm (NOTE(review): confirm against laminr docs).
image_constraints = dict(
    pixel_value_lower_bound=-1.0,
    pixel_value_upper_bound=1.0,
    required_img_norm=1.0,
)
meis_dict = get_mei_dict(model, input_shape, **image_constraints)

In [None]:
# Plot each neuron's optimised MEI side by side on a zero-centred grey colormap,
# so positive and negative pixel values get equal visual weight.
# squeeze=False keeps `axes` 2-D even when there is only one neuron; otherwise
# plt.subplots returns a bare Axes and iterating over it fails.
n_meis = len(meis_dict)
fig, axes = plt.subplots(1, n_meis, figsize=(3 * n_meis, 3), squeeze=False)
for ax, (neuron_idx, mei_dict) in zip(axes[0], meis_dict.items()):
    mei = mei_dict["mei"]
    # Symmetric colour limits around zero, based on the largest |pixel|.
    vmax = np.abs(mei).max()
    ax.imshow(mei[0], vmin=-vmax, vmax=vmax, cmap="Greys_r")  # first channel
    ax.set(xticks=[], yticks=[], title=f"Neuron {neuron_idx}")
plt.show()

### Initialize and run the invariance manifold pipeline

In [None]:
# Build the invariance-manifold pipeline from the model and its MEIs, reusing
# the same image constraints that were passed to get_mei_dict.
inv_manifold = InvarianceManifold(model, meis_dict, **image_constraints)

In [None]:
# Learn the invariance manifold of the template neuron (neuron 0).
# One step per epoch keeps this demo quick; presumably raise it for a
# full-quality run (NOTE(review): confirm against laminr docs).
template_idx = 0
steps_per_epoch = 1
template_imgs, template_activations = inv_manifold.learn(
    template_idx, steps_per_epoch=steps_per_epoch
)

In [None]:
# Match (align) the learned template to each target neuron: here neurons 1 and 2.
target_idxs = [1, 2]
match_result = inv_manifold.match(target_idxs)
aligned_imgs, aligned_activations = match_result

### Visualize the learned template

In [None]:
# Create a GIF of the learned template; the next cell displays it.
# NOTE(review): the method name suggests the GIF is also written to disk — confirm.
gif = inv_manifold.save_learned_template_as_gif()

In [None]:
# Display the learned-template GIF.
gif.show()

### Visualize the matched templates

In [None]:
# Show the template matched to the first target neuron. Index into
# `target_idxs` instead of hard-coding the neuron id, so this cell stays in
# sync with the neurons actually passed to `inv_manifold.match`.
gif = inv_manifold.save_matched_template_as_gif(target_neuron_idx=target_idxs[0])
gif.show()

In [None]:
# Show the template matched to the second target neuron. Index into
# `target_idxs` instead of hard-coding the neuron id, so this cell stays in
# sync with the neurons actually passed to `inv_manifold.match`.
gif = inv_manifold.save_matched_template_as_gif(target_neuron_idx=target_idxs[1])
gif.show()