In the previous notebook `random_forests.ipynb` we have seen how the reconstruction of three different tasks
- energy
- direction
- particle type

is done classicly using random forests:

| Telescope_images | --> | Extracted_features | --> RF --> energy

These tasks can also be adressed using CNN, directly from images:

| Telescope_image | --> CNN --> energy

But instead of three killer sharks
<img width="400px" src="http://protecttheoceans.org/wordpress/wp-content/uploads/2013/09/FindingNemoSharks.jpg" alt="ctapipe"/> 

You might prefer a single friendly big blue whale
<img width="400px" src="https://images.mediabiz.de/flbilder/max03/auto03/auto36/03360190/b780x450.jpg" alt="ctapipe"/>


In [1]:
import torch
print('PyTorch version', torch.__version__)
from torchvision import transforms
import indexedconv
print('IndexedConv version', indexedconv.__version__)
import os
import importlib
import torch
import torch.nn as nn
from minimultinet import MiniMultiNet, load_camera_parameters
from astropy.table import Table, vstack, join
import tables
import numpy as np

PyTorch version 1.7.0
IndexedConv version 1.3


In [2]:
checkpoint_path = 'data/net_checkpoint.tar'
net_trained_parameters = torch.load(checkpoint_path, map_location=torch.device('cpu'))

In [3]:
net_trained_parameters['feature.cv_layer0.cv0.weight'].shape

torch.Size([16, 2, 7])

In [4]:
camera_parameters = load_camera_parameters()

In [5]:
minimultinet = MiniMultiNet({}, camera_parameters)

In [6]:
device = torch.device('cpu')
minimultinet.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [7]:
# reading parameters and images from the hdf5 file
allowed_tels = [1, 2, 3, 4]
dl1_path = 'data/dl1.h5'

with tables.open_file(dl1_path) as file:
    table_images = vstack([Table(file.root[f'/dl1/event/telescope/images/tel_00{i}'].read()) for i in allowed_tels])
    table_parameters = vstack([Table(file.root[f'/dl1/event/telescope/parameters/tel_00{i}'].read()) for i in allowed_tels])
    simu_parameters = Table(file.root.simulation.event.subarray.shower.read())
    
good_events = np.isfinite(table_parameters['hillas_intensity'])

table_parameters = table_parameters[good_events]
table_images = table_images[good_events]
table_parameters = join(table_parameters, simu_parameters, keys='event_id')

In [8]:
table_images

obs_id,event_id,tel_id,image [1855],peak_time [1855],image_mask [1855]
int32,int64,int16,float32,float32,bool
7514,153614,1,-1.8468953 .. -2.420372,23.0 .. 14.0,False .. False
7514,192801,1,-1.0080436 .. 2.7459314,25.0 .. 21.375969,False .. False
7514,222202,1,2.3900168 .. 3.8427613,11.747786 .. 27.01762,False .. False
7514,869911,1,-0.86586535 .. -0.71757317,7.0 .. 29.0,False .. False
7514,940800,1,-0.645489 .. -1.4556164,8.0 .. 0.0,False .. False
7514,952004,1,-0.6227406 .. -0.42291516,29.0 .. 22.622976,False .. False
7514,1096802,1,-1.8582697 .. 1.8250699,27.0 .. 16.29847,False .. False
7514,1151704,1,-0.6170534 .. 1.0158916,9.0 .. 3.6071281,False .. False
7514,1633202,1,-0.74359214 .. 1.8047569,0.0 .. 9.277631,False .. False
7514,31012,2,-0.44243342 .. 3.6034596,18.293615 .. 4.9885764,False .. False


In [9]:
event_images = table_images['image']
event_times = table_images['peak_time']

In [10]:
event_images.shape

(46, 1855)

In [11]:
# Let's reshap that a bit to have the dimensions (n_batch, n_filters, n_pixels)

X = torch.tensor([event_images, event_times]).reshape(event_images.shape[0], 2, event_images.shape[1])
X.shape

torch.Size([46, 2, 1855])

In [12]:
pred_dict = minimultinet(X)

In [13]:
energy = 10**(pred_dict['energy'].detach().numpy().flatten())

reco_alt = pred_dict['direction'].detach()[:,0].numpy()
reco_az = pred_dict['direction'].detach()[:,1].numpy()

reco_core_x = pred_dict['impact'].detach()[:,0].numpy()
reco_core_y = pred_dict['impact'].detach()[:,1].numpy()

reco_core_x = pred_dict['impact'].detach()[:,0].numpy()
reco_core_y = pred_dict['impact'].detach()[:,1].numpy()

reco_gammaness = torch.nn.functional.softmax(pred_dict['class'], dim=1).detach()[:,0].numpy()

gammaness_cut = 0.75
reco_type = (reco_gammaness > gammaness_cut).astype(int)

In [14]:
pred = Table([energy, reco_alt, reco_az, reco_type], 
      names=['energy', 'reco_alt', 'reco_az', 'reco_type'])

In [15]:
pred

energy,reco_alt,reco_az,reco_type
float32,float32,float32,int64
0.25576955,-0.011280959,0.03668693,0
0.26230618,-0.019048821,0.039169382,0
0.23677418,-0.018479414,0.043302055,1
0.26939282,-0.018284895,0.06179711,0
0.2642834,-0.024339467,0.041698586,0
0.2649843,-0.014540911,0.061493505,1
0.26146117,-0.017076343,0.043075394,0
0.25628176,-0.011107618,0.0380261,0
0.32826358,-0.05910252,0.055098232,0
0.23325312,-0.011086293,0.03516766,0
