In [None]:
import sys
sys.path.append("../")

import numpy as np
import matplotlib.pyplot as plt
import torch
import simulators.extended_particle_model_freq as esim
import simulators.on_off_sim as oosim
from utils.plots_wald import plot
from models.deep_set_freq import BinnedModel, generate_for_wald

def param_eval(model, X, poi):
    '''
    Parametrized evaluation of a model: append the parameter of interest (POI)
    as an extra trailing column of ``X`` and evaluate ``model`` on the result.

    Parameters
    ----------
    model : callable
        Called once with the augmented tensor; its output is returned as-is.
    X : torch.Tensor
        Input data. The POI column has shape ``(X.shape[0], 1)`` and is
        concatenated along the last dimension, so X is expected to be 2-D
        ``(n_rows, n_features)``.
    poi : float or 0-d tensor
        Value written into every row of the new column.

    Returns
    -------
    Whatever ``model`` returns for the ``(n_rows, n_features + 1)`` input.
    '''
    # Create the POI column directly on X's device and with X's dtype so the
    # concatenation also works for GPU tensors and non-default precisions
    # (a plain torch.ones(...) would always be a default-dtype CPU tensor).
    poi_col = poi * torch.ones(X.shape[0], 1, dtype=X.dtype, device=X.device)
    Xpar = torch.cat([X, poi_col], dim=-1)
    return model(Xpar)

def train(simulator, model, start = 0, Nsteps = 30000, bins = None, save_plot = False,
          hp = None, lrt = None, accumulate = 5, plot_every = 250,
          show_every = 1000, plot_poi = 6):
    '''
    Train ``model`` with a binary cross-entropy loss on batches produced by
    ``generate_for_wald``.

    Each iteration draws a POI uniformly from [3, 7), generates labelled data
    for it, and back-propagates the BCE loss. Gradients are accumulated over
    ``accumulate`` iterations before each Adam step. The two parameter groups
    ``model.onset`` and ``model._per_elem`` use learning rates 1e-3 and 1e-4.

    Parameters
    ----------
    simulator : object
        Forwarded to ``generate_for_wald`` and ``plot``.
    model : torch.nn.Module
        Must expose ``onset`` and ``_per_elem`` sub-modules and be callable as
        ``model(X, poi)``.
    start : int
        Initial iteration index. Only affects the printed index, the plot
        schedule, and saved-file names.
    Nsteps : int
        Number of iterations (backward passes) to run.
    bins : optional
        Forwarded to ``generate_for_wald`` and ``plot``.
    save_plot : bool
        If True, save each diagnostic figure as ``anim<10-digit i>.png``.
    hp, lrt : optional
        Forwarded to ``plot``. They default to the module-level ``_hp`` and
        ``on_off_lrt``, as in the original behaviour.
    accumulate : int
        Number of backward passes summed into each optimizer step (>= 1).
    plot_every : int
        Print the loss and draw a diagnostic plot when ``i % plot_every == 0``.
    show_every : int
        On plotting iterations, also call ``plt.show()`` when
        ``i % show_every == 0``.
    plot_poi :
        Third positional argument forwarded to ``plot`` (presumably the POI at
        which the diagnostic is drawn; confirm against ``utils.plots_wald``).

    Returns
    -------
    torch.nn.Module
        The same model, switched to eval mode.

    Raises
    ------
    ValueError
        If ``accumulate`` is smaller than 1.
    '''
    if accumulate < 1:
        raise ValueError(f'accumulate must be >= 1, got {accumulate}')
    # Resolve plotting dependencies at call time so existing callers that rely
    # on the notebook globals keep working unchanged.
    if hp is None:
        hp = _hp
    if lrt is None:
        lrt = on_off_lrt

    opt = torch.optim.Adam([
        {'params': model.onset.parameters(), 'lr': 1e-3},
        {'params': model._per_elem.parameters(), 'lr': 1e-4},
    ])
    # Discard gradients left on the model by earlier code (e.g. a previous
    # train() call) so the first step only contains this run's batches.
    opt.zero_grad()
    pending = 0  # backward passes accumulated since the last optimizer step
    for i in range(start, start + Nsteps):
        # Re-enter training mode every iteration in case plot() switched the
        # model to eval mode on the previous iteration.
        model.train()
        # Draw a random POI and generate labelled data for it.
        poi = np.random.uniform(3, 7)
        XX, yy, poi = generate_for_wald(simulator, poi, N = 100, bins = bins)
        y = torch.cat(yy)
        p = torch.cat([model(X, poi) for X in XX])
        loss = torch.nn.functional.binary_cross_entropy(p, y)
        loss.backward()
        pending += 1
        if i % plot_every == 0:
            print(i, float(loss))
            plot(simulator, model, plot_poi, hp, lrt, bins = bins)
            if save_plot:
                plt.savefig(f'anim{str(i).zfill(10)}.png')
            if i % show_every == 0:
                plt.show()
            plt.close()
        # Step after exactly `accumulate` backward passes, counted from the
        # start of this call (not from absolute i, which made the very first
        # step fire after a single batch).
        if pending == accumulate:
            opt.step()
            opt.zero_grad()
            pending = 0
    # Apply the trailing partial accumulation instead of silently dropping it
    # (and leaving stale gradients on the model).
    if pending:
        opt.step()
        opt.zero_grad()
    return model.eval()

In [None]:
# Experiment setup for the on/off model.
# NOTE: `train` reads `_hp` and `on_off_lrt` as module-level globals when it
# draws its diagnostic plots, so both must be defined before `train` is called.
# Hyper-parameters of the on/off problem. The names suggest luminosity, nominal
# signal and background yields, and a control-region scale factor — TODO
# confirm against simulators.on_off_sim.
_hp = oosim.on_off_hpars(lumi = 1.0, s0 = 10, b0 = 50, tau = 1)
# Simulator and likelihood-ratio-test function, both built from the same
# reparametrisation and hyper-parameters.
simulator = esim.get_reparam_simulator(oosim.on_off_reparam,_hp)
on_off_lrt = esim.get_reparam_lrtfunc(oosim.on_off_reparam, _hp)
# Deep-set style binned model; see models.deep_set_freq for the argument meanings.
model = BinnedModel(hardscale = 1, n_elem_feats = 10, set_encoder='ele')
# Forwarded to generate_for_wald and plot; what None means there is not
# visible here (possibly "no binning") — TODO confirm.
bins = None
trained_model = train(simulator,model, Nsteps=15000, bins = bins)