In [1]:
from bins import Bins
from utils import calcAllChi2, calcOneChi2, makeHists
from ROOT import TFile, TH1

bins = Bins.readFrom("ranges.yml")
#histsMC = makeHists("medium_isotropic_eff_ag1230ag_np_9deg.dat", "_MC", bins)
histsData = makeHists("apr12_diele_088_090_ag123ag_2500A_accepted_np_2.dat", "_data", bins)

#outfile = TFIle("out.root","RECREATE")
allHistsMC = []

#calcAllChi2(histsMC, histsData)

Welcome to JupyROOT 6.28/04
Before processing events 12:52:48


  0%|          | 0/38546 [00:00<?, ?it/s]

After processing events 12:52:48


In [2]:
len(histsData[0])

12

In [3]:
import torch

torch.manual_seed(0)

import gpytorch
import botorch

import matplotlib.pyplot as plt

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = (8, 6)

from tqdm.notebook import tqdm

import warnings

In [4]:
TH1.SetDefaultSumw2

N_PARAMS = 3

grid_x = torch.linspace(-1,1,101)

grid_x1, grid_x2, grid_x3 = torch.meshgrid(grid_x, grid_x, grid_x, indexing="ij")

xs = torch.vstack([grid_x1.flatten(), grid_x2.flatten(), grid_x3.flatten()]).transpose(-2,-1)

print(xs)
print(xs.size())

tensor([[-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -0.9800],
        [-1.0000, -1.0000, -0.9600],
        ...,
        [ 1.0000,  1.0000,  0.9600],
        [ 1.0000,  1.0000,  0.9800],
        [ 1.0000,  1.0000,  1.0000]])
torch.Size([1030301, 3])


In [5]:
lb = -1
ub = 1

bounds = torch.tensor([[lb]*N_PARAMS, [ub]*N_PARAMS], dtype=torch.float)
bounds

tensor([[-1., -1., -1.],
        [ 1.,  1.,  1.]])

In [6]:
class GPModel(gpytorch.models.ExactGP, botorch.models.gpytorch.GPyTorchModel):
    _num_outputs = 1

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=1)
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


def fit_gp_model(train_x, train_y, num_train_iters=500):
    # declare the GP
    noise = 1e-4

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GPModel(train_x, train_y, likelihood)
    model.likelihood.noise = noise

    # train the hyperparameter (the constant)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    model.train()
    likelihood.train()

    for i in range(num_train_iters):
        optimizer.zero_grad()

        output = model(train_x)
        loss = -mll(output, train_y)

        loss.backward()
        optimizer.step()

    model.eval()
    likelihood.eval()

    return model, likelihood

In [7]:
num_queries = 100
num_repeats = 1

In [8]:
def lambdas(x):
    return x[0][0], x[0][1], x[0][2]

def cost(x):
    lambda_theta, lambda_phi, lambda_theta_phi = lambdas(x)

    result = 0

    if pow(1-lambda_phi, 2) - pow(lambda_theta-lambda_phi, 2) < 4*pow(lambda_theta_phi, 2):
        result = 1
    if 1 + lambda_theta + 2*lambda_phi < 0:
        result = 1

    return torch.tensor([result])

def objective(x):
    lambda_theta, lambda_phi, lambda_theta_phi = lambdas(x)

    histsMC = makeHists("medium_isotropic_eff_ag1230ag_np_9deg.dat", "_MC", bins, lambda_theta, lambda_phi, lambda_theta_phi)
    chi2, ndf = calcOneChi2(histsMC[0][0], histsData[0][0])
    allHistsMC.append(histsMC[0][0])
    if not chi2 or not ndf:
        return torch.tensor([0])
    return torch.tensor([1./(chi2 / ndf)])

In [9]:
strategy = "cei"  # "cei" or "ei"
strategy = strategy.upper()

# -2 is the default value when no feasible has been found
default_value = -2
feasible_incumbents = torch.ones((num_repeats, num_queries)) * default_value

for trial in range(num_repeats):
    print("trial", trial)

    torch.manual_seed(trial)
    train_x = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(1, 3)
    train_utility = objective(train_x)
    train_cost = cost(train_x)

    for i in tqdm(range(num_queries)):
        print("query", i)

        feasible_flag = (train_cost <= 0).any()

        if feasible_flag:
            feasible_incumbents[trial, i] = train_utility[train_cost <= 0].max()

        utility_model, utility_likelihood = fit_gp_model(
            train_x, train_utility.squeeze(-1)
        )
        cost_model, cost_likelihood = fit_gp_model(train_x, train_cost.squeeze(-1))

        if feasible_flag:
            best_f = train_utility[train_cost <= 0].max()
        else:
            best_f = torch.tensor(default_value)

        if strategy == "EI":
            policy = botorch.acquisition.analytic.ExpectedImprovement(
                model=utility_model,
                best_f=train_utility.max(),
            )

        if strategy == "CEI":
            policy = botorch.acquisition.analytic.ConstrainedExpectedImprovement(
                model=botorch.models.model_list_gp_regression.ModelListGP(
                    utility_model, cost_model
                ),
                best_f=best_f,
                objective_index=0,
                constraints={1: [None, 0]}
            )
        
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', category=RuntimeWarning)
            next_x, acq_val = botorch.optim.optimize_acqf(
                policy,
                bounds=bounds,
                q=1,
                num_restarts=40,
                raw_samples=100,
            )

        next_utility = objective(next_x)
        next_cost = cost(next_x)

        train_x = torch.cat([train_x, next_x])
        train_utility = torch.cat([train_utility, next_utility])
        train_cost = torch.cat([train_cost, next_cost])

trial 0
Before processing events 12:53:30


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:54:02


  0%|          | 0/100 [00:00<?, ?it/s]

query 0




Before processing events 12:54:38


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:55:11
query 1




Before processing events 12:55:47


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:56:19
query 2




Before processing events 12:56:55


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:57:27
query 3




Before processing events 12:58:04


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:58:36
query 4
Before processing events 12:59:13


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 12:59:46
query 5
Before processing events 13:00:25


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:00:58
query 6
Before processing events 13:01:37


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:02:09
query 7
Before processing events 13:02:50


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:03:24
query 8
Before processing events 13:04:02


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:04:34
query 9
Before processing events 13:05:11


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:05:43
query 10
Before processing events 13:06:20


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:06:52
query 11
Before processing events 13:07:29


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:08:01
query 12
Before processing events 13:08:38


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:09:10
query 13
Before processing events 13:09:49


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:10:20
query 14
Before processing events 13:10:57


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:11:29
query 15
Before processing events 13:12:06


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:12:37
query 16
Before processing events 13:13:14


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:13:46
query 17
Before processing events 13:14:23


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:14:55
query 18
Before processing events 13:15:31


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:16:03
query 19
Before processing events 13:16:40


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:17:12
query 20
Before processing events 13:17:48


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:18:20
query 21
Before processing events 13:18:56


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:19:28
query 22
Before processing events 13:20:05


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:20:37
query 23
Before processing events 13:21:13


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:21:45
query 24
Before processing events 13:22:21


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:22:53
query 25
Before processing events 13:23:30


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:24:02
query 26
Before processing events 13:24:38


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:25:09
query 27
Before processing events 13:25:46


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:26:17
query 28
Before processing events 13:26:54


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:27:25
query 29
Before processing events 13:28:02


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:28:33
query 30
Before processing events 13:29:10


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:29:41
query 31
Before processing events 13:30:18


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:30:49
query 32
Before processing events 13:31:25


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:31:57
query 33
Before processing events 13:32:34


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:33:06
query 34
Before processing events 13:33:42


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:34:13
query 35
Before processing events 13:34:50


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:35:22
query 36
Before processing events 13:35:58


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:36:30
query 37
Before processing events 13:37:07


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:37:38
query 38
Before processing events 13:38:14


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:38:46
query 39
Before processing events 13:39:22


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:39:54
query 40
Before processing events 13:40:31


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:41:02
query 41
Before processing events 13:41:39


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:42:10
query 42
Before processing events 13:42:47


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:43:18
query 43
Before processing events 13:43:55


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:44:26
query 44
Before processing events 13:45:03


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:45:34
query 45
Before processing events 13:46:11


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:46:43
query 46
Before processing events 13:47:19


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:47:51
query 47
Before processing events 13:48:28


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:49:00
query 48
Before processing events 13:49:39


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:50:12
query 49
Before processing events 13:50:48


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:51:20
query 50
Before processing events 13:51:57


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:52:28
query 51
Before processing events 13:53:05


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:53:37
query 52
Before processing events 13:54:14


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:54:45
query 53
Before processing events 13:55:23


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:55:54
query 54
Before processing events 13:56:32


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:57:04
query 55
Before processing events 13:57:40


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:58:12
query 56
Before processing events 13:58:49


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 13:59:20
query 57
Before processing events 13:59:58


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:00:29
query 58
Before processing events 14:01:06


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:01:37
query 59
Before processing events 14:02:14


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:02:46
query 60
Before processing events 14:03:23


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:03:55
query 61
Before processing events 14:04:32


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:05:03
query 62
Before processing events 14:05:40


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:06:12
query 63
Before processing events 14:06:51


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:07:24
query 64
Before processing events 14:08:06


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:08:39
query 65
Before processing events 14:09:20


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:09:52
query 66
Before processing events 14:10:33


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:11:05
query 67
Before processing events 14:11:46


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:12:18
query 68
Before processing events 14:13:02


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:13:35
query 69
Before processing events 14:14:17


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:14:50
query 70
Before processing events 14:15:32


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:16:05
query 71
Before processing events 14:16:45


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:17:19
query 72
Before processing events 14:18:02


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:18:34
query 73
Before processing events 14:19:12


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:19:44
query 74
Before processing events 14:20:24


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:20:57
query 75
Before processing events 14:21:38


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:22:10
query 76
Before processing events 14:22:51


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:23:24
query 77
Before processing events 14:24:04


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:24:36
query 78
Before processing events 14:25:18


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:25:51
query 79
Before processing events 14:26:33


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:27:09
query 80
Before processing events 14:27:55


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:28:28
query 81
Before processing events 14:29:13


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:29:45
query 82
Before processing events 14:30:28


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:31:01
query 83
Before processing events 14:31:40


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:32:12
query 84
Before processing events 14:32:52


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:33:25
query 85
Before processing events 14:34:05


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:34:38
query 86
Before processing events 14:35:19


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:35:52
query 87
Before processing events 14:36:34


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:37:07
query 88
Before processing events 14:37:51


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:38:23
query 89
Before processing events 14:39:08


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:39:42
query 90
Before processing events 14:40:26


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:40:59
query 91
Before processing events 14:41:39


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:42:12
query 92
Before processing events 14:42:50


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:43:22
query 93
Before processing events 14:44:03


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:44:35
query 94
Before processing events 14:45:19


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:45:51
query 95
Before processing events 14:46:34


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:47:06
query 96
Before processing events 14:47:47


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:48:19
query 97
Before processing events 14:48:59


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:49:32
query 98
Before processing events 14:50:11


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:50:45
query 99
Before processing events 14:51:28


  0%|          | 0/1385092 [00:00<?, ?it/s]

After processing events 14:52:03


In [10]:
torch.save(feasible_incumbents, f"./{strategy}.pth")
fout = TFile("out.root", "RECREATE")
fout.cd()
for hist in allHistsMC:
    print ("Writing hist: ", hist.GetName())
    hist.Write()
for j, hists in enumerate(histsData):
    for k, hist in enumerate(hists):
            hist.Write()
            print ("Writing hist: ", j, k, hist)
fout.Close()

Writing hist:  hist_0to150_m06m02_MC_iter1
Writing hist:  hist_0to150_m06m02_MC_iter2
Writing hist:  hist_0to150_m06m02_MC_iter3
Writing hist:  hist_0to150_m06m02_MC_iter4
Writing hist:  hist_0to150_m06m02_MC_iter5
Writing hist:  hist_0to150_m06m02_MC_iter6
Writing hist:  hist_0to150_m06m02_MC_iter7
Writing hist:  hist_0to150_m06m02_MC_iter8
Writing hist:  hist_0to150_m06m02_MC_iter9
Writing hist:  hist_0to150_m06m02_MC_iter10
Writing hist:  hist_0to150_m06m02_MC_iter11
Writing hist:  hist_0to150_m06m02_MC_iter12
Writing hist:  hist_0to150_m06m02_MC_iter13
Writing hist:  hist_0to150_m06m02_MC_iter14
Writing hist:  hist_0to150_m06m02_MC_iter15
Writing hist:  hist_0to150_m06m02_MC_iter16
Writing hist:  hist_0to150_m06m02_MC_iter17
Writing hist:  hist_0to150_m06m02_MC_iter18
Writing hist:  hist_0to150_m06m02_MC_iter19
Writing hist:  hist_0to150_m06m02_MC_iter20
Writing hist:  hist_0to150_m06m02_MC_iter21
Writing hist:  hist_0to150_m06m02_MC_iter22
Writing hist:  hist_0to150_m06m02_MC_iter

In [11]:
t = torch.tensor( [ [-1.0000, -0.4678,  0.5433]])
print(t)
print(t.shape)
cost(t)

tensor([[-1.0000, -0.4678,  0.5433]])
torch.Size([1, 3])


tensor([1])

In [12]:
with torch.no_grad():
    predictive_distribution = cost_likelihood(cost_model(xs))
    predictive_mean = predictive_distribution.mean
    predictive_lower, predictive_upper = predictive_distribution.confidence_region()


In [15]:
from ipywidgets import interact, Layout, IntSlider
import numpy as np

def oneplot(ax, tensor, index, cmap):
        tensor_3d = torch.reshape(tensor, (101,101,101))
        pos = ax.imshow(tensor_3d[index], cmap=cmap, interpolation="nearest", origin="lower", 
        vmin=tensor.min(), vmax=tensor.max(), extent=[-1,1,-1,1])
        plt.colorbar(pos)

def f(x):
        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
        oneplot(ax[0][0], predictive_mean, x, "hot")
        oneplot(ax[0][1], predictive_lower, x, "hot")
        oneplot(ax[1][1], predictive_upper, x, "hot")
        oneplot(ax[1][0], predictive_upper-predictive_lower, x, "hot")

interact(f, x=IntSlider(50, 0, 100, 1, layout=Layout(width='500px')))


interactive(children=(IntSlider(value=50, description='x', layout=Layout(width='500px')), Output()), _dom_clas…

<function __main__.f(x)>

In [14]:
with torch.no_grad():
    predictive_distribution_cost = cost_likelihood(cost_model(xs))
    predictive_mean_cost = predictive_distribution.mean
    predictive_lower_cost, predictive_upper_cost = predictive_distribution.confidence_region()


def g(x):
        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
        oneplot(ax[0][0], predictive_mean_cost, x, "cool")
        oneplot(ax[0][1], predictive_lower_cost, x, "cool")
        oneplot(ax[1][1], predictive_upper_cost, x, "cool")


interact(g, x=IntSlider(50, 0, 100, 1, layout=Layout(width='500px')))

interactive(children=(IntSlider(value=50, description='x', layout=Layout(width='500px')), Output()), _dom_clas…

<function __main__.g(x)>