# Positional embedding of the LST Transformer

In [None]:
import importlib.util
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from matplotlib.colors import BASE_COLORS

from gammalearn.data.telescope_geometry import inject_geometry_into_parameters
from gammalearn.data.LST_data_module import GLearnDataModule
from gammalearn.gammalearn_lightning_module import LitGLearnModule
from gammalearn.experiment_runner import Experiment

# Apply seaborn's default theme to all subsequent matplotlib figures.
# set_theme() is the preferred spelling; sns.set() is only an alias for it.
sns.set_theme()

### Load the experiment settings file, the data and the model

In [None]:
# Settings file path, relative to the notebook's working directory.
configuration_file = "../../gammalearn/configuration/examples/experiment_settings_mae_lst.py"

# Execute the settings file as a standalone module named "settings" and wrap it in an Experiment.
spec = importlib.util.spec_from_file_location("settings", configuration_file)
settings = importlib.util.module_from_spec(spec)
spec.loader.exec_module(settings)
experiment = Experiment(settings)

# Instantiate the data module and build its datasets.
gl_data_module = GLearnDataModule(experiment)
gl_data_module.setup()

# Reach through the nested `.dataset` / `.datasets[0]` wrappers of the training set
# to the first underlying dataset, which exposes the camera geometry.
first_dataset = gl_data_module.train_set.dataset.datasets[0].dataset.datasets[0]
geom = first_dataset.camera_geometry

# Inject the camera geometry into the network parameters; this happens before
# LitGLearnModule is built from the experiment below.
net_parameters_with_geometry = inject_geometry_into_parameters(experiment.net_parameters_dic, geom)
experiment.net_parameters_dic = net_parameters_with_geometry
dataloader = gl_data_module.train_dataloader()

# Build the Lightning module and keep a handle on the underlying network.
gl_lightning_module = LitGLearnModule(experiment)
model = gl_lightning_module.net

In [None]:
def get_number_of_parameters(model, *, trainable_only=False):
    """Print and return the number of parameters of a model.

    Args:
        model: Any object whose ``parameters()`` yields tensors, e.g. a
            ``torch.nn.Module``.
        trainable_only: If True, count only parameters with ``requires_grad``
            set. Defaults to False, which counts every parameter (the
            original behavior).

    Returns:
        The total number of scalar parameter values (0 for a model without
        parameters).
    """
    total_params = sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )
    print("Total number of parameters: ", total_params)
    # Returned so callers can use the count, not just read it from stdout.
    return total_params

In [None]:
# Report the parameter count of the full model, then of its encoder and decoder.
for submodule in (model, model.encoder, model.decoder):
    get_number_of_parameters(submodule)

### Plot the pixel positions given by the camera geometry

In [None]:
# Pixel coordinates of the camera, cast to float32 numpy arrays.
x, y = (axis_coords.value.astype(np.float32) for axis_coords in (geom.pix_x, geom.pix_y))
print(x.shape, y.shape)

# Split the pixel IDs into consecutive groups of n_pixels_per_module. This assumes
# the pixels are numbered module by module; the next plot checks it visually.
n_pixels_per_module = 7
pixel_ids = torch.arange(geom.n_pixels)
patch_indices = pixel_ids.view(-1, n_pixels_per_module)

print("Number of pixels:", geom.n_pixels)
print(patch_indices.shape)
print(patch_indices[0])

# Scatter of all pixel positions.
plt.figure(figsize=(5, 5))
plt.plot(x, y, "rx")

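### Check whether consecutive pixel IDs belong to the same module

Each group of consecutive pixel IDs is drawn in its own color; if the grouping is correct, every group should form one compact cluster.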

In [None]:
# Group the pixel coordinates n_pixels_per_module at a time (consecutive pixel IDs),
# reusing the grouping size from patch_indices instead of a second hard-coded 7.
xx = x.reshape(-1, n_pixels_per_module)
yy = y.reshape(-1, n_pixels_per_module)

# Drop white ("w"): it is practically invisible on the light seaborn background.
colors = [color for color in BASE_COLORS if color != "w"]
n_modules_to_plot = 15

plt.figure(figsize=(5, 5))
plt.xlim([-1.5, 1.5])
plt.ylim([-1.5, 1.5])
for i in range(n_modules_to_plot):
    # One color per group: a correct grouping yields compact single-color clusters.
    plt.plot(xx[i, :], yy[i, :], colors[i % len(colors)], marker="x")

### Calculate the centroid of each module

In [None]:
from gammalearn.data.image_processing.patchification import (
    check_patches,
    get_centroids_from_patches,
    get_patch_indices_and_centroids_from_geometry,
)

# Centroids of the pixel groups in patch_indices (presumably one row per group).
patch_centroids = get_centroids_from_patches(patch_indices, geom)
print(patch_centroids.shape)

# Pixels as blue dots, with the centroids overlaid as red crosses.
centroid_x = patch_centroids[:, 0]
centroid_y = patch_centroids[:, 1]
plt.figure(figsize=(5, 5))
plt.plot(x, y, "b.")
plt.plot(centroid_x, centroid_y, "rx")

In [None]:
# Let the library derive both the pixel grouping and the centroids from the geometry;
# this overwrites the manually built patch_indices / patch_centroids from above.
indices_and_centroids = get_patch_indices_and_centroids_from_geometry(geom)
patch_indices, patch_centroids = indices_and_centroids

print(patch_centroids.shape)

In [None]:
# Sanity-check the patches against the camera geometry; width_ratio is passed
# through unchanged and its exact meaning is defined by check_patches.
width_ratio = 1.2
check_patches(patch_indices, patch_centroids, geom, width_ratio=width_ratio)

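### Centroid rescaling

Shift the centroids so that both coordinates start at 0, then scale them by a common factor so that their vertical extent equals $\sqrt{N}$, where $N$ is the number of patches.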

In [None]:
# Work on a copy so that patch_centroids itself stays untouched.
centroids = deepcopy(patch_centroids)
n_patches = len(centroids)

# Scale factor chosen so that the vertical extent of the centroids becomes sqrt(n_patches).
y_width = np.ptp(centroids[:, 1])
ratio = np.sqrt(n_patches) / y_width

# Shift both coordinates so that they start at 0, then apply the common scale factor.
for axis in (0, 1):
    centroids[:, axis] -= centroids[:, axis].min()
centroids *= ratio

print(n_patches, np.sqrt(n_patches), y_width, ratio)

### Positional embedding

In [None]:
from gammalearn.nets.positional_embedding import get_2d_sincos_pos_embedding_from_patch_centroids

# Extra (non-patch) tokens declared for the backbone in the experiment settings.
backbone_parameters = experiment.net_parameters_dic["parameters"]["backbone"]["parameters"]
additional_tokens = backbone_parameters["add_token_list"]
add_pointing = True

# NOTE(review): the raw patch_centroids are passed here, not the rescaled
# `centroids` from the previous cell; presumably the helper rescales internally.
# Confirm.
pos_embed = get_2d_sincos_pos_embedding_from_patch_centroids(
    centroids=patch_centroids,
    embed_dim=256,
    additional_tokens=additional_tokens,
    add_pointing=add_pointing,
)
print(pos_embed.shape)

In [None]:
pos_embed  # displayed as the cell output: the positional embedding computed above

In [None]:
# Overlay the first 20 rows of the embedding (presumably one row per token), one curve each.
plt.figure(figsize=(10, 5))
n_curves = 20
for row_index in range(n_curves):
    plt.plot(pos_embed[row_index, :])

In [None]:
# Heatmap of the whole positional-embedding matrix; the colorbar maps colors back to values.
embedding_image = plt.imshow(pos_embed, aspect="auto")
plt.colorbar(embedding_image)

In [None]:
embed_dim = 256

# 2D sine-cosine positional embedding computed by hand (MAE-style): a quarter of the
# channels per (axis, function) pair, with geometric frequencies
#   omega_k = 1 / 10000^(k / (embed_dim / 4)),  k = 0 .. embed_dim / 4 - 1.
# sin and cos of the same axis share the same frequencies.
# NOTE(review): meant to mirror get_2d_sincos_pos_embedding_from_patch_centroids for
# the patch tokens; that helper's internals are not visible here, so confirm.
if embed_dim % 4 != 0:
    raise ValueError(f"embed_dim must be divisible by 4, got {embed_dim}")
n_freqs = embed_dim // 4
omega = 1.0 / 10000 ** (torch.arange(n_freqs, dtype=torch.float32) / n_freqs)

# Cast to float32 so the products below never hit a dtype mismatch
# (the rescaled centroids may be float64, or a numpy array).
coords = torch.as_tensor(centroids, dtype=torch.float32)
phase_x = torch.outer(coords[:, 0], omega)  # (len(centroids), embed_dim // 4)
phase_y = torch.outer(coords[:, 1], omega)  # (len(centroids), embed_dim // 4)

# Channel layout: [sin_x | cos_x | sin_y | cos_y] -> (len(centroids), embed_dim)
pos_embed = torch.cat(
    [torch.sin(phase_x), torch.cos(phase_x), torch.sin(phase_y), torch.cos(phase_y)],
    dim=1,
)

In [None]:
pos_embed  # displayed as the cell output: the manually computed embedding (it replaced the helper's result)