In [None]:
# Standard Imports
import numpy as np
from time import time
from matplotlib import pyplot as plt
import os

from particleloader import load
from particleloader.utils import center_and_normalize_zyphi
from pyspecter.SPECTER import SPECTER
from pyspecter.CommonObservables import build_jet_observables

# Utils
# Prefer the standalone rikabplotlib helpers; fall back to the copy bundled
# with pyspecter when that package is not importable.
try:
    from rikabplotlib.plot_utils import newplot, plot_event
except ImportError:
    # Only an import failure should trigger the fallback; a bare `except`
    # would also swallow KeyboardInterrupt / SystemExit and unrelated bugs.
    from pyspecter.utils.plot_utils import newplot, plot_event


In [None]:
# %%%%%%%%%% Load Data %%%%%%%%%%

# ----- Dataset selection -----
jets_dataset_name = "SPECTER_qcd_jets" # Also try: "SPECTER_top_jets"
cache_dir = "~/.ParticleLoader"  # NOTE(review): not referenced by any visible cell — confirm it is still needed
N_samples = 150                  # second argument handed to load()
batch_size = 10000               # stride used when slicing the dataset into batches

# ----- Fit settings -----
R = 1.0        # forwarded to plot_event() and build_jet_observables()
lr = 0.005     # forwarded to observable.compute() as learning_rate
epochs = 1500  # forwarded to observable.compute() as epochs

# Load the events, run them through center_and_normalize_zyphi, then keep at
# most the first 125 entries along axis 1 and the first 3 along the last axis.
# NOTE(review): presumably axis 1 is particles and the last axis is (z, y, phi),
# judging by the helper's name — confirm.
jets_dataset = center_and_normalize_zyphi(load(jets_dataset_name, N_samples))[:, :125, :3]

# Plot an event
fig, ax = newplot("full")
plot_event(ax, jets_dataset[0], R)

In [None]:
# Build the observables for the chosen R.
jet_observables_dict = build_jet_observables(R = R)

# One row per observable: (key into jet_observables_dict, short label used
# as the key of the result dictionaries and as the plot label).
_observable_table = [
    ("spLineliness", "line"),
    ("spRinginess", "ring"),
    ("spDiskiness", "disk"),
    ("1-sPronginess", "1sprong"),
    ("2-sPronginess", "2sprong"),
    ("3-sPronginess", "3sprong"),
]

# Parallel lists (same order, same length) consumed by the later cells.
observable_keys = [key for key, _ in _observable_table]
observable_names = [label for _, label in _observable_table]

In [None]:
def _merge_batches(chunks):
    """Combine a list of per-batch results into one result.

    A single batch is returned untouched, so a one-batch run yields exactly
    the object produced by the observable. With several batches, dictionaries
    are merged key by key and everything else is joined with
    ``np.concatenate`` along the first axis.
    """
    if len(chunks) == 1:
        return chunks[0]
    head = chunks[0]
    if isinstance(head, dict):
        return {key: _merge_batches([chunk[key] for chunk in chunks]) for key in head}
    # NOTE(review): assumes array-like leaves whose first axis indexes events — confirm.
    return np.concatenate(chunks, axis=0)


batch = 0

# Final results, keyed by the short observable name.
emds_dict = {}
params_dict = {}
hard_emds_dict = {}
hard_params_dict = {}

# Per-batch results collected while looping. Writing straight into the
# dictionaries above would let every batch overwrite the previous one, leaving
# only the last batch whenever the dataset spans more than one batch.
emds_batches = {}
params_batches = {}
hard_emds_batches = {}
hard_params_batches = {}

for batch_start in range(0, jets_dataset.shape[0], batch_size):

    # Slicing past the end is fine: the last batch is simply shorter.
    batch_end = batch_start + batch_size
    dataset = jets_dataset[batch_start:batch_end]

    for observable_key, observable_name in zip(observable_keys, observable_names):

        observable = jet_observables_dict[observable_key]

        # compute() returns four values; only the first two are kept.
        emds, params, loss_history, params_history = observable.compute(dataset, learning_rate= lr, early_stopping=150, N_sample = 125, finite_difference=False, epochs = epochs)

        emds_batches.setdefault(observable_name, []).append(emds)
        params_batches.setdefault(observable_name, []).append(params)

        # Try exact computation. Any failure is treated as "not available",
        # but the reason is reported so genuine errors are not hidden.
        try:
            hard_emds, hard_params = observable.hard_compute(dataset)
        except Exception as e:
            print(f"{observable_name} has no exact computation ({type(e).__name__}: {e})")
        else:
            hard_emds_batches.setdefault(observable_name, []).append(hard_emds)
            hard_params_batches.setdefault(observable_name, []).append(hard_params)

    batch += 1

# Collapse the per-batch lists into one entry per observable.
for merged, collected in (
    (emds_dict, emds_batches),
    (params_dict, params_batches),
    (hard_emds_dict, hard_emds_batches),
    (hard_params_dict, hard_params_batches),
):
    for name, chunks in collected.items():
        merged[name] = _merge_batches(chunks)



In [None]:
# One figure per observable: filled histogram of the fitted values, with the
# exact ("hard") values overlaid as a step histogram when they exist.
for observable_name in observable_names:

    fig, ax = newplot("full")

    emds = emds_dict[observable_name]
    # `alpha` used to be passed twice in this call, which Python rejects as a
    # SyntaxError (keyword argument repeated); it is now given once.
    ax.hist(emds, bins=100, alpha=0.5, label=observable_name, color="red")

    # Exact values are only stored for observables whose hard_compute
    # succeeded. An explicit membership test replaces the bare
    # `except: pass`, which would also have hidden real plotting errors.
    if observable_name in hard_emds_dict:
        hard_emds = hard_emds_dict[observable_name]
        ax.hist(hard_emds, bins=100, alpha=0.5, label="Hard", color="darkred", histtype="step")

    # Show the labels passed to hist() above.
    ax.legend()

