# Positional embedding of the Image Transformer

In [None]:
import importlib.util

import matplotlib.pyplot as plt
import seaborn as sns
import torch

# import gammalearn.utils as utils
from gammalearn.data.telescope_geometry import inject_geometry_into_parameters
from gammalearn.data.LST_data_module import GLearnDataModule
from gammalearn.gammalearn_lightning_module import LitGLearnModule
from gammalearn.experiment_runner import Experiment

# Apply seaborn's default theme to all subsequent matplotlib figures.
# set_theme() is the preferred interface; sns.set() is only a legacy alias for it.
sns.set_theme()

### Load the experiment settings file, the data, and the model

In [None]:
# Experiment settings file for the MAE image example; a relative path, so it is
# resolved against the notebook's current working directory.
configuration_file = "../../gammalearn/configuration/examples/experiment_settings_mae_image.py"

# Load the settings file as a Python module named "settings" and wrap it in an Experiment.
# NOTE(review): exec_module executes the settings file as code — only point this at trusted configs.
spec = importlib.util.spec_from_file_location("settings", configuration_file)
settings = importlib.util.module_from_spec(spec)
spec.loader.exec_module(settings)
experiment = Experiment(settings)

# Build and set up the data module, then read the camera geometry from the first
# dataset nested inside the training set.
# NOTE(review): the .dataset.datasets[0].dataset.datasets[0] chain assumes a specific
# nesting of dataset wrappers inside train_set — confirm against GLearnDataModule.setup.
gl_data_module = GLearnDataModule(experiment)
gl_data_module.setup()
geom = gl_data_module.train_set.dataset.datasets[0].dataset.datasets[0].camera_geometry
# Inject the geometry into the network parameters before the model is built below;
# presumably LitGLearnModule reads experiment.net_parameters_dic when constructing the net — confirm.
experiment.net_parameters_dic = inject_geometry_into_parameters(experiment.net_parameters_dic, geom)
dataloader = gl_data_module.train_dataloader()

# Build the Lightning module and keep a handle on the underlying network.
gl_lightning_module = LitGLearnModule(experiment)
model = gl_lightning_module.net

In [None]:
def get_number_of_parameters(model):
    """Print and return the total number of parameters of a model.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``). Every parameter is counted,
            whether or not it requires gradients.

    Returns:
        int: total number of scalar parameters (0 for a parameter-less model).
        The value is returned in addition to being printed so that callers
        can reuse it.
    """
    total_params = sum(param.numel() for param in model.parameters())
    print("Total number of parameters: ", total_params)
    return total_params

In [None]:
# Report the parameter count of the full network and of its encoder and decoder.
# Each line is prefixed with the component name so the three totals can be told apart.
for module_name, module in (
    ("model", model),
    ("encoder", model.encoder),
    ("decoder", model.decoder),
):
    print(f"[{module_name}]", end=" ")
    get_number_of_parameters(module)

### Prototype the per-patch pixel indices

In [None]:
image_height, image_width = 55, 55
patch_size = 11

# Number every pixel 0..H*W-1 in row-major order and lay the ids out as an H x W grid.
pixel_ids = torch.arange(image_height * image_width).view(image_height, image_width)
print(pixel_ids.shape)
pixel_ids

In [None]:
# Whole patches along the width times whole patches along the height; leftover
# pixels that do not fill a full patch are not counted (floor division).
n_patches = (image_width // patch_size) * (image_height // patch_size)
n_patches

In [None]:
# Tile the pixel-id grid into non-overlapping patch_size x patch_size patches:
# the first unfold cuts the rows into bands, the second cuts each band into tiles.
# Each patch's ids are then flattened into one row of the result.
patch_indices = (
    pixel_ids
    .unfold(0, patch_size, patch_size)
    .unfold(1, patch_size, patch_size)
    .reshape(n_patches, -1)
)

print(patch_indices.shape)

### Test the patch pixel indices on a real image

In [None]:
# Grab the "image" entry of the first training sample and report its shape and type.
images = gl_data_module.train_set[0]["image"]
print(f"images {images.shape}")
print(type(images))

In [None]:
# Select one image of the sample; images[0] is presumably its first channel
# (a flat pixel vector, since it is reshaped and indexed by flat ids) — confirm.
image = images[0]

plt.figure(figsize=(10, 10))
# Reuse the prototype's image_height / image_width instead of hard-coding 55,
# so the plot stays consistent if the image size changes.
plt.imshow(image.reshape(image_height, image_width))
plt.axis("off")
plt.tight_layout()
plt.savefig("vit_image.png", dpi=300)

In [None]:
from gammalearn.data.image_processing.patchification import get_patch_indices_and_grid

# NOTE: patch_size is rebound here (it was 11 in the prototype cells), so those
# cells will use 5 if they are re-run after this one.
patch_size = 5
image_size = dict(height=55, width=55)

# Library version of the patchification: flat pixel ids per patch plus the patch grid.
patch_indices, grid = get_patch_indices_and_grid(image_size, patch_size)
print(patch_indices.shape, grid.shape)

In [None]:
# Display the patch grid returned by get_patch_indices_and_grid; it is later passed
# to get_2d_sincos_pos_embedding_from_grid to build the positional embedding.
grid

In [None]:
# Gather each patch's pixels from the flat image using its flat pixel ids, and
# reshape them into a patch_size x patch_size tile (one list entry per patch).
patches = [image[ids].reshape(patch_size, patch_size) for ids in patch_indices.tolist()]

In [None]:
# Number of patches along the image height (rows) and width (columns).
n_h, n_w = image_size["height"] // patch_size, image_size["width"] // patch_size
# Share one colour scale across all subplots so patch intensities are comparable.
vmin = image.min()
vmax = image.max()

plt.figure(figsize=(10, 10))
for i, patch in enumerate(patches):
    # plt.subplot takes (nrows, ncols, index): rows follow the height and columns the width.
    # The previous (n_w, n_h) order only worked because the image is square.
    # Assumes patch_indices enumerates patches row by row — confirm against get_patch_indices_and_grid.
    plt.subplot(n_h, n_w, i + 1)
    plt.imshow(patch, vmin=vmin, vmax=vmax)
    plt.axis("off")
plt.tight_layout()
plt.savefig("vit_patches.png", dpi=300)

### Positional embedding

In [None]:
from gammalearn.nets.positional_embedding import get_2d_sincos_pos_embedding_from_grid

# Token options forwarded to the embedding builder. add_pointing presumably adds an
# extra token for the telescope pointing, and add_token_list is read from the
# backbone config — confirm both against get_2d_sincos_pos_embedding_from_grid.
add_pointing = True
additional_tokens = experiment.net_parameters_dic["parameters"]["backbone"]["parameters"]["add_token_list"]

# Build a 2D sin-cos positional embedding of width 256 over the patch grid.
pos_embed = get_2d_sincos_pos_embedding_from_grid(
    grid=grid,
    additional_tokens=additional_tokens,
    add_pointing=add_pointing,
    embed_dim=256,
)

print(pos_embed.shape)

In [None]:
# Overlay the first (up to) 20 rows of the positional embedding as curves.
# Slicing instead of range(20) avoids an IndexError when there are fewer rows.
# Assumes pos_embed is laid out as (n_tokens, embed_dim) — confirm from the printed shape.
plt.figure(figsize=(10, 5))
for row in pos_embed[:20]:
    plt.plot(row)
plt.xlabel("embedding dimension")
plt.ylabel("value")

In [None]:
# Heatmap of the whole positional embedding (rows as in the curve plot above).
# The colorbar makes the encoded value range readable.
plt.imshow(pos_embed, aspect="auto")
plt.xlabel("embedding dimension")
plt.ylabel("token index")
plt.colorbar()

In [None]:
# RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [2, 2, 55, 55]