# GammaLearn inference example

In this notebook, we'll see how to run inference interactively with a trained model.    
To process larger amounts of data, please use the command-line interface (CLI) instead.

## Setup
For this example, we'll use the toy model trained by the integration tests.
Start by running the following cells:

In [None]:
from importlib.metadata import version as runtime_version
print(runtime_version("gammalearn"))

In [None]:
import subprocess

# This command trains a network using the settings file defined below.
# check=True stops the notebook with an error if the training command fails.

settings_file = "../../gammalearn/configuration/examples/experiment_settings_train_MC.py"
subprocess.run(["gammalearn", settings_file], check=True)

## Load an experiment from the settings file

In [None]:
from gammalearn.experiment_runner import load_experiment
experiment = load_experiment(settings_file)
experiment.experiment_name

In [None]:
# This cell is only needed to load the camera geometry from the data.
# It should be refactored and simplified.

from gammalearn.data.telescope_geometry import get_dataset_geom, inject_geometry_into_parameters

gl_data_module_train = experiment.data_module_train["module"](experiment)
gl_data_module_train.setup_train()
geometries = []
get_dataset_geom(gl_data_module_train.train_set, geometries)
experiment.net_parameters_dic = inject_geometry_into_parameters(experiment.net_parameters_dic, geometries[0])

## Load the model and its weights from a checkpoint

In [None]:
from pathlib import Path

# Here we take the checkpoint from the previous training run
checkpoint_path = Path(experiment.main_directory) / experiment.experiment_name / 'last.ckpt'
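
The cell above assumes that training wrote a `last.ckpt` file in the experiment directory. If it did not (for example, if checkpointing was configured differently), a small helper can fall back to the most recently modified `*.ckpt` file. This is a minimal sketch: `find_checkpoint` is a hypothetical helper, not part of gammalearn, and it assumes the checkpoints are stored somewhere under the experiment directory.

```python
from pathlib import Path


def find_checkpoint(run_dir):
    """Return run_dir/'last.ckpt' if it exists, else the newest '*.ckpt' under run_dir.

    Hypothetical helper (not a gammalearn function); assumes checkpoints are
    '*.ckpt' files located in run_dir or one of its subdirectories.
    """
    run_dir = Path(run_dir)
    last = run_dir / "last.ckpt"
    if last.is_file():
        return last
    # Fall back to the most recently modified checkpoint, searching subdirectories too
    candidates = sorted(run_dir.rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f"No checkpoint found under {run_dir}")
    return candidates[-1]
```

It could replace the path construction above with `checkpoint_path = find_checkpoint(Path(experiment.main_directory) / experiment.experiment_name)`. Raising `FileNotFoundError` early gives a clearer message than a failure inside `load_from_checkpoint`.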

In [None]:
from gammalearn.gammalearn_lightning_module import LitGLearnModule

model = LitGLearnModule.load_from_checkpoint(checkpoint_path, experiment=experiment, strict=False)
model.eval()
model.to('cpu')

## Load the data

In [None]:
gl_data_module_test = experiment.data_module_test["module"](experiment)
gl_data_module_test.setup_test()
test_dataloaders = gl_data_module_test.test_dataloaders()
dataloader = test_dataloaders[0]

### Get a batch
Let's get the first batch to play with.     
A batch contains the images (2 channels), the true labels, and the DL1 parameters of the events.

In [None]:
batch = next(iter(dataloader))
batch
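
Displaying a whole batch prints every tensor value, which quickly becomes hard to read. A small recursive helper can instead report the shape of each entry while keeping the batch's nesting. This is a sketch under the assumption that the batch is a (possibly nested) dict of tensors and lists, as the `batch['image']` and `batch['label']` accesses in this notebook suggest; `describe` is a hypothetical helper, not a gammalearn function.

```python
from collections.abc import Mapping, Sequence


def describe(obj):
    """Recursively replace array-like leaves by their shape, keeping the nesting.

    Works on anything exposing a `.shape` attribute (torch tensors, NumPy arrays);
    other leaves are reported by their type name.
    """
    if hasattr(obj, "shape"):
        return tuple(obj.shape)
    if isinstance(obj, Mapping):
        return {key: describe(value) for key, value in obj.items()}
    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        return [describe(value) for value in obj]
    return type(obj).__name__
```

For `describe(batch)`, we would expect the `'image'` entry to have a shape of the form `(batch_size, 2, n_pixels)`, the two channels being the charges and the time map used in the next section.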

## Visualize images

In [None]:
import matplotlib.pyplot as plt
from ctapipe.visualization import CameraDisplay

# The first sample contains two channels, the charges and the time map, which are the inputs to g-PhysNet
sample = batch['image'][0]
image = sample[0]
time_map = sample[1]

geom = geometries[0]

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
display = CameraDisplay(geom, image, ax=axes[0])
display.add_colorbar()
display.axes.set_title('Image')

display = CameraDisplay(geom, time_map, ax=axes[1])
display.add_colorbar()
display.axes.set_title('Time map')

## Inference

In [None]:
import torch

with torch.no_grad():
    output = model(batch['image'])
    
output

## Comparison to true parameters

This output can be compared to the true parameters from the MC simulation.    
Note that we are using a toy model with little training here, so the results will be far from satisfying ;-)

In [None]:
batch['label']
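
Reading raw tensors side by side gets tedious beyond a few events. For regression tasks, a per-task mean absolute error gives a quick summary. This is a minimal sketch with several assumptions: that `output` and `batch['label']` are both dicts keyed by task name, that for the chosen tasks the prediction and the label have the same shape and units (the model may predict transformed quantities, so check this first), and that classification outputs (logits) are left out. `mean_abs_errors` is a hypothetical helper, not part of gammalearn.

```python
from collections.abc import Mapping


def _flatten(values):
    """Flatten nested lists (e.g. the result of `tensor.tolist()`) into a list of floats."""
    if isinstance(values, (list, tuple)):
        return [x for v in values for x in _flatten(v)]
    return [float(values)]


def mean_abs_errors(output: Mapping, labels: Mapping, tasks):
    """Mean absolute error between predictions and true values, per regression task.

    Hypothetical helper: output[task] and labels[task] are assumed to hold values
    of the same shape and units, given as (nested) lists of numbers.
    """
    errors = {}
    for task in tasks:
        pred = _flatten(output[task])
        true = _flatten(labels[task])
        if not pred or len(pred) != len(true):
            raise ValueError(f"{task}: {len(pred)} predictions vs {len(true)} true values")
        errors[task] = sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)
    return errors
```

Usage, with placeholder task names to replace by keys actually present in both dicts: `tasks = ["energy", "impact"]`, then `mean_abs_errors({t: output[t].tolist() for t in tasks}, {t: batch['label'][t].tolist() for t in tasks}, tasks)`. Converting with `.tolist()` keeps the helper independent of torch.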