# Template Learning and Template Matching Pipeline Demo

In [None]:
import random
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from lipstick import GifMaker

# For creating simulated neurons
from laminr.simulation.neurons import neuron1_generator, neuron2_generator, neuron3_generator, complex_neuron_generator, even_gabor_neuron_generator, odd_gabor_neuron_generator, phase_invariant_neuron_generator
from laminr.simulation.utils import random_transformation_matrices, plot_grid_points, plot_grid_border

# For MEI generation
from laminr.utils.mei import generate_mei

# For template learning and template matching
from laminr.cppn import CPPNTemplates
from laminr.utils.pipeline import forward
from laminr.utils.trainer_functions import train_template, match_template

# For results
from laminr.datamodule import JitteringGridDatamodule
from laminr.utils.plot_utils import arrange_images_on_circle # plot_utils.py

# Lookup table from neuron-type name to the simulation generator that builds it.
neuron_type_generators = dict(
    even_gabor=even_gabor_neuron_generator,
    odd_gabor=odd_gabor_neuron_generator,
    complex_cell=complex_neuron_generator,
    arbitrary_invariance_1=neuron1_generator,
    arbitrary_invariance_2=neuron2_generator,
    arbitrary_invariance_3=neuron3_generator,
)

random_seed = 41

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
device

In [None]:
class single_cell_model(nn.Module):
    """Expose a single output unit of a multi-neuron model as its own model.

    Args:
        model: module whose output is indexed as ``output[:, idx]``.
        idx: column of ``model``'s output that this wrapper returns.
    """

    def __init__(self, model, idx):
        super().__init__()
        self.model = model
        self.idx = idx

    def forward(self, x):
        # Pick this neuron's column, then drop all singleton dimensions
        # (a batch of one therefore yields a 0-d tensor).
        responses = self.model(x)
        selected = responses[:, self.idx]
        return selected.squeeze()

class ArbitraryMultiNeuronModel(nn.Module):
    """Combine independent single-neuron models into one multi-neuron model.

    Every sub-model receives the same input; their outputs are stacked along
    dimension 1, so output column ``i`` comes from ``neuron_models_list[i]``.

    Args:
        neuron_models_list: sequence of modules, one per neuron.
        order: optional sequence with one entry per neuron. Only its length
            is checked. NOTE(review): it is stored but never applied in
            ``forward`` — confirm whether a reordering was intended.

    Raises:
        ValueError: if ``order`` is given and its length differs from the
            number of models.
    """

    def __init__(self, neuron_models_list, order=None):
        super().__init__()
        self.neuron_models_list = nn.ModuleList(neuron_models_list)
        self.order = order
        if order is None:
            return
        if len(order) != len(self.neuron_models_list):
            raise ValueError("order should have the same number of elements as models.")

    def forward(self, x):
        return torch.stack([neuron(x) for neuron in self.neuron_models_list], dim=1)

## 1. Create two simulated neurons

In [None]:
# Seed every RNG this notebook imports, not just NumPy's. `random` and `torch`
# were previously left unseeded, so any randomness drawn from them further
# down (presumably e.g. MEI or CPPN initialisation — not visible here) was not
# reproducible across runs. Seeding them does not change the NumPy stream.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
img_res = [100, 100]

neuron_types = [
    "complex_cell",
    "complex_cell",
]

num_neurons = len(neuron_types)
# RF locations drawn uniformly from [-0.2, 0.2) for each of the 2 coordinates.
global_locations = np.random.rand(num_neurons, 2) * .4 - .2
transformation_matrices = random_transformation_matrices(num_neurons)
# Each neuron is built with the inverse of its sampled transformation matrix.
transformation_matrices = np.array([np.linalg.inv(tmat) for tmat in transformation_matrices])

neuron_models = []
for neuron_type, gloc, tmat in zip(neuron_types, global_locations, transformation_matrices):
    neuron_generator = neuron_type_generators[neuron_type]
    print(neuron_generator)
    # The generator's second return value is not needed here.
    neuron_model, _ = neuron_generator(gloc, img_res, transformation_matrix=tmat)
    neuron_models.append(neuron_model)

# One module whose forward stacks the per-neuron outputs along dim 1.
model = ArbitraryMultiNeuronModel(neuron_models)

In [None]:
# Show the combined multi-neuron model as this cell's output.
model

## 2. Optimize MEIs (and get the MEI activation and the RF mask)

In [None]:
# Indices (0-based) of all simulated neurons; an MEI is generated for each.
idxs = np.arange(len(neuron_models))

In [None]:
norm = 1
zscore = .5

# Per-neuron results keyed by neuron index (insertion order follows idxs).
meis = {}
acts = {}
masks = {}
center_pos = {}

# Generate MEI data for each neuron in idxs.
for idx_n, idx in enumerate(idxs):
    print(f'neuron_idx = {idx+1}: neuron number {idx_n+1}/{len(idxs)}')

    # Wrap the multi-neuron model so the MEI optimiser sees a single output.
    single_neuron_model = single_cell_model(model, idx).to(device)
    mei, mei_act, mask, mask_center = generate_mei(single_neuron_model, img_res, norm=norm, zscore=zscore)

    # Save data in dictionary
    meis[idx] = mei
    acts[idx] = mei_act
    masks[idx] = mask
    center_pos[idx] = mask_center

    # The mask is a torch tensor (it is converted with .cpu() when the arrays
    # are assembled later). Bring it to host memory before handing it to
    # NumPy; np.concatenate cannot consume a GPU tensor.
    mask_np = mask.detach().cpu().numpy() if torch.is_tensor(mask) else np.asarray(mask)

    # Plot MEI (left) and RF mask (right) side by side. img_res[0] is used as
    # the panel width throughout, so the second marker, the divider and the
    # x-limits all agree.
    panel_w, panel_h = img_res[0], img_res[1]
    vlim = np.abs(mei).max()
    fig, ax = plt.subplots()
    ax.imshow(np.concatenate((mei, mask_np), axis=1), vmin=-vlim, vmax=vlim, cmap="gray", origin="lower")
    ax.scatter(mask_center["x"], mask_center["y"], s=20, c="crimson")
    ax.scatter(mask_center["x"] + panel_w, mask_center["y"], s=20, c="crimson")
    # Vertical divider between the two panels (previously x ran from
    # img_res[0] to img_res[1], which is slanted for non-square images).
    ax.plot([panel_w, panel_w], [0, panel_h], c="k", lw=.5)
    ax.set(xticks=[], yticks=[], xlim=(0, panel_w * 2 - 1), ylim=(0, panel_h))
    plt.show()

## 3. Template Learning

In [None]:
# Index of the simulated neuron whose invariance manifold is learned as the template.
template_neuron_idx = 0

In [None]:
# Turn the per-neuron dicts into float32 arrays ordered by neuron index.
# NOTE: `meis` and `masks` are rebound here from dicts to arrays.
meis = np.array(list(meis.values())).astype(np.float32)
mei_acts = np.array(list(acts.values())).astype(np.float32)
masks = np.array([m.cpu().data.numpy() for m in masks.values()]).astype(np.float32)
rf_positions = np.array([[c["x"], c["y"]] for c in center_pos.values()]).astype(np.float32)

In [None]:
# Which coordinate-transformation components the CPPN template may use.
allow_scale = True
allow_shear = True
uniform_scale = False

# The template is rendered at the MEI resolution.
img_res = meis[0].shape
batch_size = 1
template_config = {
    "num_neurons": batch_size,
    "img_res": img_res,
    "num_templates": 1,
    "out_channels": 1,
    "widths": [50, 50, 50, 50],
    "positional_encoding_dim": 50,
    "positional_encoding_projection_scale": 10,
    "aux_positional_encoding_dim": 50,
    "aux_positional_encoding_projection_scale": .1,
    "periodic_invariance": True,
    "nonlinearity": nn.Tanh,
    "final_nonlinearity": nn.Tanh,
    "weights_scale": .1,
    "bias": True,
    "only_affine_coordinate_transformation": True,
    "stochastic_coordinate_transformation": False,
    "allow_scale_coordinate_transformation": allow_scale,
    "allow_shear_coordinate_transformation": allow_shear,
    "uniform_scale_coordinate_transformation": uniform_scale,
}

In [None]:
# Baseline stimulus: a uniform image at the midpoint of the pixel range.
pixel_min = -10
pixel_max = 10
mean_pixel_val = (pixel_max + pixel_min) / 2
baseline_input = torch.full((1, 1, *img_res), mean_pixel_val, device=device)

In [None]:
# Single-output view of the template neuron, plus its MEI statistics.
template_neuron_model = single_cell_model(model, template_neuron_idx).to(device)
template_rf_location, template_mei_act, template_rf_mask = (
    rf_positions[template_neuron_idx],
    mei_acts[template_neuron_idx],
    masks[template_neuron_idx],
)

# Fresh CPPN template built from the configuration above.
template = CPPNTemplates(**template_config).to(device)

In [None]:
print("Training Template")
# Activation requirements handed to train_template (interpretation of the
# keys is defined there — presumably thresholds relative to the MEI activation).
requirements = {"avg": .99, "std": 1., "necessary_min": 0.98}
training_options = dict(
    steps_per_epoch=50,
    pixel_min=pixel_min,
    pixel_max=pixel_max,
    std=None,
    norm=1.,
    gaussian_blur_sigma=None,
    img_transform='FixEnergyNormClip',  # used with norm
    num_max_epochs=1000,
)
grid, template, img_transforms, grid_dataloader = train_template(
    template,
    template_neuron_model,
    template_rf_mask,
    template_mei_act,
    requirements,
    **training_options,
)

## 4. Template Matching

In [None]:
# Neuron(s) onto which the learned template is matched, with their MEI data.
target_neurons_idx = [1]
others_rf_mask = masks[target_neurons_idx]
others_rf_location = rf_positions[target_neurons_idx]
others_mei_act = torch.from_numpy(mei_acts[target_neurons_idx]).to(device)

In [None]:
# One coordinate transform per target neuron, then register the RF locations
# of the template neuron and of the targets with the template.
template.reset_coordinate_transform(num_neurons=len(target_neurons_idx))
source_location = torch.from_numpy(template_rf_location).to(device)
target_locations = torch.from_numpy(others_rf_location).to(device)
template.register_coords_shifts(source_location, target_locations)

In [None]:
# Position of each target within the target list (presumably paired with its
# index in `model`, passed as the next argument).
target_positions = np.arange(len(target_neurons_idx))
template = match_template(
    template,
    model,
    img_transforms,
    grid_dataloader,
    template_rf_mask,
    target_positions,
    target_neurons_idx,
    others_mei_act,
    others_rf_mask,
    rotate_angle_and_scale=True,
    patience=25,
    num_epochs=1000,
)

## 5. Results

In [None]:
# Evaluation grid for the results section.
# NOTE: this rebinds `grid_dataloader` (and `grid`), replacing the ones
# returned by train_template.
dataloader_config = {
    "num_invariances": 1,
    "grid_points_per_dim": 24,
    "steps_per_epoch": 1,
}
grid_dataloader = JitteringGridDatamodule(**dataloader_config)
grid = grid_dataloader.grid.clone().to(device)

In [None]:
with torch.no_grad():
    # Template neuron: response to the uniform baseline image.
    template_act_baseline = template_neuron_model(baseline_input).item()

    # Template neuron: mean response over the learned manifold, expressed as a
    # fraction of the baseline-to-MEI activation range.
    template_img_pre, template_img_post, template_acts, _ = forward(
        grid, template, img_transforms, template_neuron_model, return_template=True
    )
    template_act = template_acts.mean().item()
    template_act_relative = (template_act - template_act_baseline) / (template_mei_act - template_act_baseline)

    # Target neurons: the same relative metric for the matched templates.
    others_img_pre, others_img_post, others_acts, _ = forward(
        grid, template, img_transforms, model,
        return_template=False,
        other_neurons_loc_in_list=np.arange(len(target_neurons_idx)),
        other_neurons_loc_in_model=target_neurons_idx,
    )
    # assumes others_acts is 3-D with matching first and last axes
    # (template x grid point x target) — TODO confirm; .diag() keeps each
    # target's response to its own matched template.
    others_act = others_acts.mean(dim=1).diag().cpu().data.numpy()
    others_act_baseline = model(baseline_input)[0, target_neurons_idx].cpu().data.numpy()
    others_act_relative = (others_act - others_act_baseline) / (mei_acts[target_neurons_idx] - others_act_baseline)
    others_act_relative = (others_act - others_act_baseline) / (mei_acts[target_neurons_idx] - others_act_baseline) 

In [None]:
# Post-transform manifold images as NumPy arrays: the template neuron's, and
# the first matched target's.
template_manifold, matched_manifold = (
    template_img_post.cpu().data.numpy(),
    others_img_post[0].cpu().data.numpy(),
)

In [None]:
# Symmetric colour limit shared by both manifolds.
# NOTE(review): max_val / min_val are not used by the GIF cell, which scales
# every frame independently — confirm whether a shared scale was intended.
max_val = max(np.abs(manifold).max() for manifold in (template_manifold, matched_manifold))
min_val = -max_val

In [None]:
num_images = len(template_manifold)

# One GIF frame per manifold image: template (left) vs matched target (right).
# Each panel uses a colour range symmetric around zero, scaled per frame.
with GifMaker("result.gif") as g:
    for img_idx in range(num_images):
        fig, axes = plt.subplots(1, 2)
        panels = (template_manifold[img_idx, 0], matched_manifold[img_idx, 0])
        for ax, img in zip(axes, panels):
            lim = np.abs(img).max()
            ax.imshow(img, vmin=-lim, vmax=lim, origin="lower", cmap="Greys_r")
            ax.set(xticks=[], yticks=[])
        g.add(fig)
        # Close each frame once it has been handed to the GIF writer. Without
        # this every frame stays open (matplotlib warns past 20 open figures)
        # and the inline backend renders all of them at the end of the cell.
        plt.close(fig)
g.show()

---