### Reconstruct all events using four GPUs and the CRS2 spherical optimizer running in batch mode

In [1]:
import math
import pickle
import time
from multiprocessing import Process, Pool

import numpy as np
import pandas as pd
import pkg_resources
import zmq

from freedom.llh_service.llh_service import LLHService
from freedom.llh_service.llh_client import LLHClient

from spherical_opt import spherical_opt

In [2]:
# NOTE: pickle.load can execute arbitrary code -- only open trusted files here
events_path = '/home/atfienberg/freedomDataCopy/public_for_aaron/test_events.pkl'
with open(events_path, 'rb') as events_file:
    events = pickle.load(events_file)

In [3]:
# the allowed-DOM table ships as a resource file inside the freedom package
allowed_doms_file = pkg_resources.resource_filename('freedom', 'resources/allowed_DOMs.npy')
allowed_DOMs = np.load(allowed_doms_file)
ndoms = len(allowed_DOMs)

### LLH service configuration

In [4]:
# Keyword arguments for LLHService. The per-GPU 'ctrl_addr' and 'req_addr'
# entries are not set here; start_service adds them to a copy of this dict.
service_conf = {
    "poll_timeout": 1,
    "flush_period": 1,
    "n_hypo_params": 8,
    "n_hit_features": 8,
    "n_evt_features": ndoms*4,
    "batch_size": {
        "n_hypos": 200,
        "n_observations": 6000,
    },
    "send_hwm": 10000,
    "recv_hwm": 10000,
    "hitnet_file": "/home/atfienberg/freedomDataCopy/public_for_aaron/HitNet_ranger_30_Jul_2020-15h49/epoch_32_model.hdf5",
    "domnet_file": "/home/atfienberg/freedomDataCopy/public_for_aaron/DOMNet_reduced_22_Jul_2020-15h18/epoch_30_model.hdf5",
    "ndoms": ndoms,
}

### Build four services, one per GPU

In [5]:
# number of GPUs to use; one LLH service process is launched per GPU
n_gpus = 4

In [6]:
base_req = "ipc:///tmp/atfrecotestreq"
base_ctrl = "ipc:///tmp/atfrecotestctrl"

# one request address and one control address per GPU service,
# told apart by the GPU index appended to the base address
req_addrs = [f'{base_req}{i}' for i in range(n_gpus)]
ctrl_addrs = [f'{base_ctrl}{i}' for i in range(n_gpus)]

In [7]:
# show the request addresses first, then the control addresses
for addr_list in (req_addrs, ctrl_addrs):
    print(addr_list)

['ipc:///tmp/atfrecotestreq0', 'ipc:///tmp/atfrecotestreq1', 'ipc:///tmp/atfrecotestreq2', 'ipc:///tmp/atfrecotestreq3']
['ipc:///tmp/atfrecotestctrl0', 'ipc:///tmp/atfrecotestctrl1', 'ipc:///tmp/atfrecotestctrl2', 'ipc:///tmp/atfrecotestctrl3']


In [8]:
def start_service(params, index):
    '''Run one LLH service until its work loop returns.

    Meant to be the target of a separate process: it restricts the process
    to the GPU numbered `index`, builds an LLHService from `params` plus the
    control / request addresses for that index, and runs the work loop.

    params: dict of LLHService keyword arguments (left unmodified)
    index: position in ctrl_addrs / req_addrs, also used as the CUDA device id
    '''
    # use a single GPU
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = f'{index}'

    # merge into a new dict so the caller's configuration is not touched
    service_kwargs = {
        **params,
        'ctrl_addr': ctrl_addrs[index],
        'req_addr': req_addrs[index],
    }

    with LLHService(**service_kwargs) as serv:
        print('starting work loop...')
        serv.start_work_loop()

    print('done')

# launch one service process per GPU; the handles are kept in `procs` so the
# processes can be joined when the services are shut down
procs = [Process(target=start_service, args=(service_conf, gpu_index))
         for gpu_index in range(n_gpus)]
for proc in procs:
    proc.start()

starting work loop...
starting work loop...
starting work loop...
starting work loop...
Received die command... flushing and exiting
cleaning up
done
Received die command... flushing and exiting
cleaning up
done
Received die command... flushing and exiting
cleaning up
done
Received die command... flushing and exiting
cleaning up
done


In [9]:
# half-width of the initial x, y, z box around the charge-weighted hit position
init_pos_range = 50

# initial time window, given as offsets from the charge-weighted hit time
time_range = (-1000, 0)

# initial range of log10(energy) for the two energy parameters
log_energy_range = [-1, 2]

# limits of the search range: one (lower, upper) pair per parameter, transposed
# so that row 0 holds every lower limit and row 1 every upper limit
param_search_limits = np.transpose(np.array([
    [-200, 200],         # x
    [-250, 250],         # y
    [-800, -200],        # z
    [8000, 11000],       # time
    [0, 2*math.pi],      # azimuth
    [0, math.pi],        # zenith
    [0.1, 100],          # cascade energy
    [0.1, 100],          # track energy
]))

### Functions for fitting

In [10]:
def initial_box(hits, charge_ind=4, n_params=8):
    '''Build the box from which the optimizer's starting points are drawn.

    hits: 2-d array of hits; the first four columns are averaged as
        x, y, z and time, weighted by column `charge_ind`
    charge_ind: column of `hits` holding the weights (charge)
    n_params: number of rows in the returned table

    Returns a float32 array of shape (n_params, 2); column 0 holds the lower
    and column 1 the upper limit of each parameter.
    '''
    # charge-weighted average of every hit column
    weighted_mean = np.average(hits, weights=hits[:, charge_ind], axis=0)
    center_xyz, center_time = weighted_mean[:3], weighted_mean[3]

    limits = np.empty((n_params, 2), np.float32)

    # x, y, z: average - init_pos_range to average + init_pos_range
    limits[:3] = np.column_stack((center_xyz - init_pos_range,
                                  center_xyz + init_pos_range))

    # time: average + time_range[0] to average + time_range[1]
    limits[3] = (center_time + time_range[0], center_time + time_range[1])

    # azimuth spans 0 to 2pi, zenith spans 0 to pi
    limits[4] = (0, 2*math.pi)
    limits[5] = (0, math.pi)

    # every remaining row is a log energy taken from log_energy_range
    limits[6:] = (log_energy_range[0], log_energy_range[1])

    return limits

def out_of_bounds(params, limits=None):
    '''Flag parameter rows that lie outside the search limits.

    params: array whose last axis holds one value per parameter
    limits: array with the lower limits in limits[0] and the upper limits in
        limits[1]; when omitted, the module-level param_search_limits is
        looked up at call time, so re-running the cell that defines the
        limits takes effect without redefining this function

    Returns a boolean array (a boolean scalar for a single row), True where
    at least one parameter is below its lower or above its upper limit.
    Rows containing NaN are reported as out of bounds.
    '''
    if limits is None:
        limits = param_search_limits
    lower, upper = limits[0], limits[1]
    in_bounds = np.logical_and(lower <= params, params <= upper)
    # np.all replaces np.alltrue, which was removed in NumPy 2.0
    return ~np.all(in_bounds, axis=-1)

In [11]:
# large objective value substituted wherever the evaluation is not usable
NAN_REPLACE_VAL = 1e9

def nan_replace(nll):
    '''Overwrite every NaN in `nll` with NAN_REPLACE_VAL.

    The array is modified in place and the same array object is returned.
    '''
    nan_mask = np.isnan(nll)
    nll[nan_mask] = NAN_REPLACE_VAL
    return nll

def get_batch_closure(client, event):
    '''Build the objective function handed to the optimizer for one event.

    client: object providing eval_llh(hits, doms, params)
    event: mapping with 'hits' and 'doms' arrays

    Returns eval_llh(params), which evaluates the client on the event for the
    given hypothesis parameters and returns a 1-d array in which NaN results
    and results for out-of-bounds parameter rows are replaced by
    NAN_REPLACE_VAL.
    (presumably the client returns negative log-likelihoods, since the
    result is minimized -- confirm against LLHClient)
    '''
    # the event inputs do not depend on the hypothesis, so slice them once
    # here instead of on every objective call
    hits = event['hits'][:, :8]
    doms = event['doms'][allowed_DOMs]

    def eval_llh(params):
        llhs = np.atleast_1d(client.eval_llh(hits, doms, params))

        llhs = nan_replace(llhs)

        # penalize hypotheses outside the search range so they are rejected
        llhs[out_of_bounds(params)] = NAN_REPLACE_VAL

        return llhs
    return eval_llh

def batch_crs_fit(event, client, rng, n_initial_points=97, max_iter=10000, batch_size=12):
    '''Fit one event with the CRS2 spherical optimizer in batch mode.

    event: mapping with 'hits', 'doms' and 'params'; the length of
        event['params'] sets the number of fitted parameters
    client: LLH client used to evaluate the objective
    rng: numpy random Generator; draws the starting points and is passed
        on to the optimizer
    n_initial_points: number of starting points drawn inside the initial box
    max_iter: iteration limit passed to the optimizer
    batch_size: batch size passed to the optimizer

    Returns the result object of spherical_opt.spherical_opt.
    '''
    eval_llh = get_batch_closure(client, event)

    n_params = len(event['params'])

    # build the box with the same parameter count that is sampled below
    box_limits = initial_box(event['hits'], n_params=n_params)

    uniforms = rng.uniform(size=(n_initial_points, n_params))

    # scale the unit-interval samples into the box, one column per parameter
    initial_points = box_limits[:, 0] + uniforms*(box_limits[:, 1] - box_limits[:, 0])

    # energy parameters need to be converted from log energy to energy
    initial_points[:, 6:] = 10**initial_points[:, 6:]

    opt_ret = spherical_opt.spherical_opt(
                        func=eval_llh,
                        method="CRS2",
                        initial_points=initial_points,
                        spherical_indices=[[4,5]],
                        max_iter=max_iter,
                        batch_size=batch_size,
                        rand=rng)

    return opt_ret

In [12]:
def fit_events(events, index=0, seed=None):
    '''Fit a sequence of events through one LLH service.

    events: iterable of event mappings with 'hits', 'doms', 'params', 'retro'
    index: which entry of ctrl_addrs (i.e. which service) to connect to
    seed: optional seed for the random generator; the default None keeps
        the previous behaviour of a freshly seeded, non-reproducible stream

    Returns a list with one tuple per event:
    (optimizer result, LLH at event['params'], LLH at event['retro']).
    '''
    rng = np.random.default_rng(seed)

    client = LLHClient(
                    ctrl_addr=ctrl_addrs[index],
                    conf_timeout=60000
                  )

    outputs = []
    for event in events:
        fit_res = batch_crs_fit(event, client, rng)

        # the same event inputs are used for both reference evaluations
        hits = event['hits'][:, :8]
        doms = event['doms'][allowed_DOMs]

        true_param_llh = client.eval_llh(hits, doms, event['params'])
        retro_param_llh = client.eval_llh(hits, doms, event['retro'])

        outputs.append((fit_res, true_param_llh, retro_param_llh))

    return outputs

### Fit a single event

In [13]:
%%time
# trial run: fit only the first event, through the default service (index 0)
test_out = fit_events(events[:1])

CPU times: user 600 ms, sys: 56 ms, total: 656 ms
Wall time: 24.3 s


In [14]:
# unpack the single result: optimizer output, LLH at the event's 'params',
# and LLH at the event's 'retro' parameters
first_fit, first_true_llh, first_retro_llh = test_out[0]

print(first_fit['fun'])
print(first_true_llh)
print(first_retro_llh)
print('---')
print(first_fit['n_calls'])
print(first_fit['nit'])

-34.666374
-31.858068
-27.370792
---
4020
196


In [15]:
pool_size = 100
events_to_process = len(events)
# events_to_process = 400

# round up so that pool_size consecutive chunks cover every requested event;
# trailing chunks may be shorter or empty
evts_per_proc = int(math.ceil(events_to_process/pool_size))
evt_splits = [events[first:first + evts_per_proc]
              for first in (k*evts_per_proc for k in range(pool_size))]
print(sum(map(len, evt_splits)))

2963


In [16]:
# assign the workers to the services round-robin: worker k uses index k % n_gpus
gpu_inds = np.arange(pool_size) % n_gpus

In [17]:
%%time
# reconstruct with a worker pool; one LLH client per worker, each worker
# receiving its chunk of events and the index of the service it should use
start = time.time()
with Pool(pool_size) as worker_pool:
    worker_args = zip(evt_splits, gpu_inds)
    outs = worker_pool.starmap(fit_events, worker_args)
delta = time.time() - start

CPU times: user 6.78 s, sys: 8.66 s, total: 15.4 s
Wall time: 30min 35s


In [18]:
# delta is the wall-clock duration of the pool run, in seconds
print(f'measured time: {delta/60:.1f} minutes')

measured time: 30.6 minutes


In [19]:
# total number of fitted events returned by all workers
print(sum(map(len, outs)))

2963


In [20]:
# number of fitted parameters, taken from the first event's 'params' entry
n_params = len(events[0]['params'])

In [21]:
# flatten the per-worker result lists into one list, keeping the event order;
# a single comprehension avoids the repeated list copies of sum(..., [])
all_outs = [result for worker_out in outs for result in worker_out]

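Note: the timings below were measured with four Titan X GPUs running in parallel; the per-call and per-iteration figures are the total wall-clock time divided by the call and iteration counts summed over all events.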

In [22]:
# the optimizer result is the first entry of every per-event output tuple
opt_results = [out[0] for out in all_outs]

total_calls = sum(res['n_calls'] for res in opt_results)
total_iters = sum(res['nit'] for res in opt_results)
print(f'{total_calls} total calls')
# wall-clock time of the whole pool run divided by the summed call count
time_per_call = delta/total_calls
print(f'{total_iters} total iters')
# same wall-clock time divided by the summed iteration count
time_per_iter = delta/total_iters
print(f'{total_calls/len(all_outs):.1f} calls per event')
print(f'{time_per_call*1e6:.2f} us per call')

print(f'{total_iters/len(all_outs):.1f} iters per event')
print(f'{time_per_iter*1e6:.2f} us per iter')

13789402 total calls
690397 total iters
4653.9 calls per event
133.08 us per call
233.0 iters per event
2658.01 us per iter


In [23]:
### Build df
# one row per fitted event; every column is collected as a plain list first
evt_idx = list(range(len(all_outs)))
free_fit_llhs = [out[0]['fun'] for out in all_outs]
true_param_llhs = [out[1] for out in all_outs]
retro_param_llhs = [out[2] for out in all_outs]
n_calls = [out[0]['n_calls'] for out in all_outs]
n_iters = [out[0]['nit'] for out in all_outs]

# best-fit values regrouped by parameter: one list per parameter
best_fit_ps = [[out[0]['x'][p_ind] for out in all_outs]
               for p_ind in range(n_params)]

par_names = ['x', 'y', 'z', 'time',
             'azimuth', 'zenith',
             'cascade energy', 'track energy']

df_dict = dict(evt_idx=evt_idx, free_fit_llh=free_fit_llhs,
               true_p_llh=true_param_llhs, retro_p_llh=retro_param_llhs,
               n_calls=n_calls, n_iters=n_iters)

# append the parameter columns after the bookkeeping columns
df_dict.update(zip(par_names, best_fit_ps))

df = pd.DataFrame(df_dict)
df.head()

Unnamed: 0,evt_idx,free_fit_llh,true_p_llh,retro_p_llh,n_calls,n_iters,x,y,z,time,azimuth,zenith,cascade energy,track energy
0,0,-34.461761,-31.858068,-27.370792,2906,145,-23.323169,-6.585221,-259.677067,9797.042344,5.036371,1.051212,8.933672,0.434484
1,1,-96.201569,-90.54525,-88.510056,5989,293,65.400763,-70.280209,-269.121384,9853.956884,4.272805,2.028426,3.612497,0.291329
2,2,-202.435959,-178.86467,-188.484695,7485,362,31.326009,45.812867,-314.598394,9797.352458,2.351043,2.323403,7.053778,1.248822
3,3,-40.53133,-35.01379,-26.499874,3110,156,11.344962,-71.752687,-359.520088,9774.843775,2.33984,2.171221,3.204777,1.300684
4,4,-45.383133,-41.375755,-39.881035,5511,261,108.355228,-62.424106,-482.344295,9855.216802,4.605943,1.596494,1.988044,0.755337


In [24]:
# events whose free-fit value is below the retro value plus a tolerance of 10
# NOTE(review): the `bad` selection further down uses a tolerance of 5, not 10
# -- confirm that the two thresholds are meant to differ
within_tolerance = df['free_fit_llh'] < df['retro_p_llh'] + 10
free_f_better = df[within_tolerance]
frac = len(free_f_better)/len(df)
print(f'free fit better frac: {frac:.2f}')

free fit better frac: 0.98


In [25]:
# events where the free fit is at least 5 above the value at the retro parameters
bad = df.loc[df['free_fit_llh'] >= df['retro_p_llh'] + 5]

In [26]:
# preview the first rows of the `bad` selection
bad.head()

Unnamed: 0,evt_idx,free_fit_llh,true_p_llh,retro_p_llh,n_calls,n_iters,x,y,z,time,azimuth,zenith,cascade energy,track energy
53,53,-22.293381,-36.338837,-42.369255,2666,138,37.061517,-74.601142,-586.969872,9205.867274,2.959649,1.154117,13.358325,1.393736
54,54,-26.306883,-32.668518,-35.08102,5095,250,30.436879,-222.995996,-322.303526,9187.060493,4.151299,1.193359,0.111457,18.920978
234,234,-32.275146,-44.071358,-43.780968,1390,72,5.180658,-120.414944,-575.580559,9289.718886,2.244293,0.5074,48.962139,64.811871
296,296,-35.197254,-41.849667,-49.359653,21192,907,199.54045,46.096608,-318.117535,10993.825979,0.019116,2.758519,28.002436,94.757256
300,300,-51.624302,-60.576248,-64.871147,3214,166,-185.338219,-38.929691,-495.762028,9149.63415,3.566294,1.859423,6.624288,43.340049


In [29]:
# save the summary table; the path is relative to the current working directory
df.to_pickle('./fit_results.pkl')

In [28]:
# kill all the services: send "die" to each control address, then wait for
# the matching service process to exit
for proc, ctrl_addr in zip(procs, ctrl_addrs):
    with zmq.Context.instance().socket(zmq.REQ) as ctrl_sock:
        ctrl_sock.connect(ctrl_addr)
        ctrl_sock.send_string("die")
        # join while the socket is still open so the command can be delivered
        proc.join()
        # the process is gone now; discard any message that was never
        # delivered (e.g. when this cell is re-run) so that closing the
        # socket cannot leave the shared context waiting on it
        ctrl_sock.setsockopt(zmq.LINGER, 0)