In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
# This is need so you can import larndsim without doing python setup.py install
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir)

In [3]:
import matplotlib.pyplot as plt
from matplotlib import cm, colors
import mpl_toolkits.mplot3d.art3d as art3d

import numpy as np
import eagerpy as ep
import h5py

import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 100

In [4]:
from numpy.lib import recfunctions as rfn
import torch

def torch_from_structured(tracks):
    tracks_np = rfn.structured_to_unstructured(tracks, copy=True, dtype=np.float32)
    return torch.from_numpy(tracks_np).float()

def structered_from_torch(tracks_torch, dtype):
    return rfn.unstructured_to_structured(tracks_torch.cpu().numpy(), dtype=dtype)

### Dataset import
First of all we load the `edep-sim` output. For this sample we need to invert $z$ and $y$ axes.

In [5]:
with h5py.File('module0_corsika.h5', 'r') as f:
    tracks = np.array(f['segments'])

x_start = np.copy(tracks['x_start'] )
x_end = np.copy(tracks['x_end'])
x = np.copy(tracks['x'])

tracks['x_start'] = np.copy(tracks['z_start'])
tracks['x_end'] = np.copy(tracks['z_end'])
tracks['x'] = np.copy(tracks['z'])

tracks['z_start'] = x_start
tracks['z_end'] = x_end
tracks['z'] = x

selected_tracks = tracks[30:40]
selected_tracks_torch = torch_from_structured(np.copy(selected_tracks))

## Simulation
To flexibly keep track of parameters/gradients, simulations are housed in a class `sim_with_grad`. This is derived from class versions of all the other modules. Parameters are housed in `consts`, with method `track_gradients` to promote the constants to `requires_grad=True` PyTorch tensors.

In [6]:
from larndsim.sim_with_grad import sim_with_grad

Generate <class 'larpix.configuration.configuration_v2b.Configuration_v2b'>.pixel_trim_dac using bits (0, 512)
	<function _list_property at 0x12460d280>((<class 'int'>, 0, 31, 64, 8)) 
Generate <class 'larpix.configuration.configuration_v2b.Configuration_v2b'>.threshold_global using bits (512, 520)
	<function _basic_property at 0x12460d1f0>((<class 'int'>, 0, 255)) 
Generate <class 'larpix.configuration.configuration_v2b.Configuration_v2b'>.csa_gain using bits (520, 521)
	<function _compound_property at 0x12460d310>((['csa_gain', 'csa_bypass_enable', 'bypass_caps_en'], (<class 'int'>, <class 'bool'>), 0, 1)) 
Generate <class 'larpix.configuration.configuration_v2b.Configuration_v2b'>.csa_bypass_enable using bits (521, 522)
	<function _compound_property at 0x12460d310>((['csa_gain', 'csa_bypass_enable', 'bypass_caps_en'], (<class 'int'>, <class 'bool'>), 0, 1)) 
Generate <class 'larpix.configuration.configuration_v2b.Configuration_v2b'>.bypass_caps_en using bits (522, 523)
	<function _c

## The simulation
Following the flow of the simulation chain, define a function which takes in the `sim_with_grad` object, runs whatever pieces of the simulation, and returns desired output.

In [7]:
def all_sim(sim, selected_tracks, fields):
    selected_tracks_quench = sim.quench(selected_tracks, sim.birks, fields=fields)
    selected_tracks_drift = sim.drift(selected_tracks_quench, fields=fields)

    # Here we build a map between tracks and event IDs
#    unique_eventIDs = np.unique(selected_tracks['eventID'])
#    event_id_map = np.searchsorted(unique_eventIDs,np.asarray(selected_tracks['eventID']))
#    event_id_map_torch = torch.from_numpy(event_id_map)

#    active_pixels_torch, neighboring_pixels_torch, n_pixels_list_ep = sim.get_pixels(selected_tracks_torch,
#                                                                                     fields=selected_tracks.dtype.names)

#    track_starts_torch, max_length_torch = sim.time_intervals(event_id_map_torch, 
#                                                              selected_tracks_torch, 
#                                                              fields=selected_tracks.dtype.names)

#    signals_ep = sim.tracks_current(neighboring_pixels_torch, selected_tracks_torch, 
#                                          max_length_torch,
#                                          fields=selected_tracks.dtype.names)
    return selected_tracks_drift

  and should_run_async(code)


In [8]:
# Update parameters for training loop
def update_grad_param(sim, name, value):
    setattr(sim, name, value)
    sim.track_gradients([name])

## Experiment: pseudo-data
1. Construct dataset with one set of parameters
2. Initialize parameters to something else
3. See if we can recover default parameters

In [9]:
#Simulate with defaults: in particular, eField == 0.5
sim_default = sim_with_grad()
sim_default.load_detector_properties("../larndsim/detector_properties/module0.yaml",
                             "../larndsim/pixel_layouts/multi_tile_layout-2.2.16.yaml")
out05 = all_sim(sim_default, selected_tracks_torch, selected_tracks.dtype.names)

In [10]:
#Setup simulation object for training
sim = sim_with_grad()
sim.load_detector_properties("../larndsim/detector_properties/module0.yaml",
                             "../larndsim/pixel_layouts/multi_tile_layout-2.2.16.yaml")

In [11]:
#Simple MSE loss between target and output
loss_fn = torch.nn.MSELoss()

#Initialize eField to different value (0.55)
eField = 0.55
lr = 1e-9

#Training loop
for epoch in range(100):
    
    #First, set eField parameter 
    update_grad_param(sim, "eField", eField)
    
    #Simulate with that parameter and get output
    output = all_sim(sim, selected_tracks_torch, selected_tracks.dtype.names)
    
    #Calc loss between simulated and target + backprop
    loss = loss_fn(output, out05)
    loss.backward()
    
    #Gradient descent
    with torch.no_grad():
        eField -= lr*sim.eField.grad
    print(eField)
        
    sim.eField.grad.data.zero_()

tensor(0.5458)
tensor(0.5419)
tensor(0.5382)
tensor(0.5348)
tensor(0.5317)
tensor(0.5288)
tensor(0.5261)
tensor(0.5236)
tensor(0.5214)
tensor(0.5193)
tensor(0.5174)
tensor(0.5157)
tensor(0.5141)
tensor(0.5127)
tensor(0.5114)
tensor(0.5103)
tensor(0.5092)
tensor(0.5083)
tensor(0.5074)
tensor(0.5066)
tensor(0.5059)
tensor(0.5053)
tensor(0.5048)
tensor(0.5043)
tensor(0.5038)
tensor(0.5034)
tensor(0.5030)
tensor(0.5027)
tensor(0.5024)
tensor(0.5022)
tensor(0.5019)
tensor(0.5017)
tensor(0.5015)
tensor(0.5014)
tensor(0.5012)
tensor(0.5011)
tensor(0.5010)
tensor(0.5009)
tensor(0.5008)
tensor(0.5007)
tensor(0.5006)
tensor(0.5006)
tensor(0.5005)
tensor(0.5004)
tensor(0.5004)
tensor(0.5004)
tensor(0.5003)
tensor(0.5003)
tensor(0.5002)
tensor(0.5002)
tensor(0.5002)
tensor(0.5002)
tensor(0.5002)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5001)
tensor(0.5000)
tensor(0.5000)
tensor(0.5000)
tensor(0.5