In [None]:
# Standard Imports
import numpy as np
from time import time
from matplotlib import pyplot as plt
import os


from pyspecter.SPECTER import SPECTER
from pyspecter.Observables import Observable
# from pyspecter.SpecialObservables import SpecialObservables

# Utils
from pyspecter.utils.data_utils import load_cmsopendata, load_triangles
from pyspecter.utils.plot_utils import newplot, plot_event

# Jax
from jax import grad, jacobian, jit
import jax.numpy as jnp
from jax import random

# SPECTER
from pyspecter.SpectralEMD_Helper import compute_spectral_representation

In [None]:
# Parameters
# NOTE(review): the SHAPER cell below rebinds R to 1.0, and `initialize` in the JAX
# cell reads the global R as its starting radius, so the effective value depends on
# which cells have been run (and in which order).
R = 0.01
this_dir = ""  # output directory for .npy files; only referenced by commented-out code
this_study = "cmsopendata"  # file-name prefix for saved results
n_events = 100  # forwarded to load_cmsopendata as n=
epochs = 150  # number of optimisation epochs for the batched JAX fit

# Load simulated ("sim") CMS Open Data jets.
# NOTE(review): the positional arguments (475, 525, 1.9, 2) and pad = 75 are passed
# straight through; their meaning (presumably a pT window, a rapidity cut, a quality
# flag and a per-jet padding length) should be confirmed against load_cmsopendata.
# Whether "~" is expanded is also up to the loader -- TODO confirm.
# NOTE(review): sim_weights and k are not used elsewhere in this notebook.
dataset_open, sim_weights, k = load_cmsopendata("~/.energyflow/", "sim", 475, 525, 1.9, 2, pad = 75, n = n_events)
# NOTE(review): triangle_events / triangle_indices are not used elsewhere in this notebook.
triangle_events, triangle_indices = load_triangles(180, 180, R = 1.0, return_indices=True)

dataset = dataset_open

In [None]:
# SHAPER
from pyshaper.CommonObservables import buildCommmonObservables
from pyshaper.Observables import Observable
from pyshaper.Shaper import Shaper

# Necessary GPU nonsense
import torch 

# Run on the first CUDA device when torch can see one, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using GPU!" if dev.startswith("cuda") else "Using CPU!")
device = torch.device(dev)


# Generate new events

# NOTE(review): this rebinds the global R (0.01 in the Parameters cell) to 1.0. It is
# passed to buildCommmonObservables below AND read by `initialize` in the JAX cell as
# the starting circle radius, so re-running cells out of order changes the fit.
R = 1.0
# NOTE(review): N, angles and energies are not used anywhere else in this notebook.
N = 180
angles = np.linspace(0, np.pi, N)
energies = np.linspace(0, 1, N)

# Convert every padded jet into the (points, weights) tuple handed to SHAPER:
# column 0 of each row is taken as the weight, columns 1 and 2 as the coordinates.
events = []
for i in range(dataset.shape[0]):

    event = dataset[i]

    zs = event[:,0]
    points = event[:,1:3]

    events.append((points, zs))


# EMDs

# NOTE(review): `N = 3, beta = 2` are forwarded to the pyshaper helper; their meaning
# is defined there -- confirm before changing.
commonObservables, pointers = buildCommmonObservables(N = 3, beta = 2, R = R, device = device)
# NOTE(review): the variable name says "2-subjettiness" but it holds the
# "1-Ringiness" observable.
_2subjettiness = commonObservables["1-Ringiness"]


# Collect our observables in one dictionary
observables = {}
observables["1-Ringiness"] = _2subjettiness



# Initialize SHAPER
shaper = Shaper(observables, device)
shaper.to(device)


# Only filled by the commented-out loop further down.
EMDs = []
params = []

plot_dictionary = {
    "plot_directory" : "Plots",
    "gif_directory" : "gifs",  
    "extension" : "png",
    "title" : "SIM Jets"
}
# N = 100, scaling = 0.9, epsilon = 0.001, early_stopping= 25, early_stopping_fraction = 0.95, plot_dictionary=plot_dictionary)



# NOTE(review): this passes a single (points, zs) tuple rather than a list of events,
# while the commented-out loop below indexes dataset_params["1-Ringiness"][j] for every
# j in the dataset -- confirm what Shaper.calculate expects.
dataset_emds, dataset_params = shaper.calculate(events[0], epochs = 500, verbose=True, lr = 0.01, N = 50, scaling = 0.9, epsilon = 0.001, early_stopping= 25, plot_dictionary=plot_dictionary)
# for j in range(dataset.shape[0]):
#     e = dataset_params["1-Ringiness"][j]["EMD"]
#     EMDs.append(e)
#     params.append(dataset_params["1-Ringiness"][j]["Radius"])



# filename = f"{this_study}_shaper_EMDs.npy"
# save_dir = os.path.join(this_dir, filename)
# np.save(save_dir, EMDs)

# filename = f"{this_study}_shaper_params.npy"
# save_dir = os.path.join(this_dir, filename)
# np.save(save_dir, params)

In [None]:
from pyspecter.SpectralEMD_Helper import ds2_events1_spectral2
import jax.example_libraries.optimizers as jax_opt
from jax.example_libraries import optimizers
import jax
import jax.lax as lax
import tqdm

# Compile the spectral-representation helper once; the single-event path, the
# vmapped batch path and the 1D animation below all reuse this jitted version.
compiled_compute_spectral_representation = jit(compute_spectral_representation)




def initialize(event, N, seed):
    """Return the starting parameter dict for the circle fit.

    `event`, `N` and `seed` are part of the signature (and are mapped/broadcast by
    `jax.vmap(initialize, in_axes=(0, None, None))`) but are not used by the body.
    The starting radius is the module-level global `R`, so its value depends on
    which notebook cells have been executed.
    """
    starting_radius = R
    return dict(Radius=starting_radius)


def sample_circle(params, N, seed):
    """Draw `N` equally weighted points uniformly on a circle centred at the origin.

    Args:
        params: dict whose "Radius" entry is the circle radius.
        N: number of points; static under the jit below (it fixes the array shape).
        seed: integer used to build a `jax.random.PRNGKey`; equal seeds give
            identical angles.

    Returns:
        Array of shape (N, 3) whose rows are (1/N, x, y).
    """
    radius = params["Radius"]
    angles = jax.random.uniform(
        jax.random.PRNGKey(seed), shape=(N,), minval=0., maxval=2 * jnp.pi
    )
    weights = jnp.ones(N) / N
    return jnp.stack(
        [weights, radius * jnp.cos(angles), radius * jnp.sin(angles)], axis=1
    )

sample_circle = jax.jit(sample_circle, static_argnums=(1,))

@jax.jit
def project(params):
    """Project the parameters onto the feasible set by clamping "Radius" at 0.

    A new dict is returned instead of writing into `params`. Under jit the input
    is already a traced copy, but with jit disabled (e.g. `jax.disable_jit()`)
    the previous in-place assignment would have silently modified the caller's
    dict. All other keys are passed through unchanged.
    """
    return {**params, "Radius": jnp.maximum(params["Radius"], 0.)}
    



def train_step(epoch, spectral_event, params):
    """Spectral EMD between a freshly sampled 25-point circle and `spectral_event`.

    The epoch number seeds the circle sampling, so the angles change from epoch
    to epoch but are reproducible. The distance call is wrapped in
    `jax.checkpoint` so its intermediates are recomputed rather than stored when
    differentiating.
    """
    rematerialised_semd = jax.checkpoint(ds2_events1_spectral2)
    circle_event = sample_circle(params, 25, seed = epoch)
    return rematerialised_semd(circle_event, spectral_event)

# Forward-mode Jacobian of the loss w.r.t. `params` (argument index 2). This is
# built from the un-jitted train_step defined above, before that name is rebound
# to its jitted version on the next line -- keep this ordering.
gradient_train_step = jax.jit(jax.jacfwd(train_step, argnums = 2))
train_step  = jax.jit(train_step)

# Batched versions. in_axes=None broadcasts the epoch (RNG seed) to every event;
# 0 maps over the leading axis of the events / spectral representations / params.
vmapped_compute_spectral_representation = jax.vmap(compiled_compute_spectral_representation, in_axes = (0,))
vmapped_train_step = jax.vmap(train_step, in_axes = (None, 0, 0))
vmapped_gradient_train_step = jax.vmap(gradient_train_step, in_axes = (None, 0, 0))
vmapped_initialize = jax.vmap(initialize, in_axes = (0, None, None))
vmapped_project = jax.vmap(project, in_axes = (0,))



# Chain rule: d(sEMD)/d(params) = d(sEMD)/d(shape_event) . d(shape_event)/d(params)
# The sampled circle is an (N, 3) array, not a scalar, so its derivative w.r.t. the
# params dict is a Jacobian (jax.grad only supports scalar outputs).
few_to_many_grad = jax.jacfwd(sample_circle, argnums=0)
# Gradient of the scalar spectral EMD w.r.t. the sampled circle event (first argument).
many_to_few = jax.grad(ds2_events1_spectral2)

def my_gradient(epoch, spectral_event, params):
    """Gradient of the spectral EMD w.r.t. `params`, assembled explicitly by the chain rule.

    Mathematically this is the same quantity as `gradient_train_step` (same
    sampling seed), up to floating-point differences.

    Returns:
        A dict with the same keys and leaf shapes as `params`.
    """
    shape_event = sample_circle(params, 25, seed = epoch)
    d_event_d_params = few_to_many_grad(params, 25, epoch)
    d_loss_d_event = many_to_few(shape_event, spectral_event)
    # Each Jacobian leaf has shape event_shape + param_shape; contracting all
    # event axes against dL/d(event) leaves a param_shape gradient.
    return jax.tree_util.tree_map(
        lambda jac: jnp.tensordot(d_loss_d_event, jac, axes=d_loss_d_event.ndim),
        d_event_d_params,
    )

# Map over events; the epoch (seed) is shared, matching the other vmapped helpers.
vmapped_my_gradient = jax.vmap(my_gradient, in_axes = (None, 0, 0))

@jax.jit
def finite_differences_gradient(epoch, spectral_event, params, epsilon=1e-2):
    """Central finite-difference estimate of d train_step / d params.

    Each scalar entry p of `params` is perturbed on its own by +/- h with the
    relative step h = epsilon * p, falling back to the absolute step
    h = epsilon when p is exactly 0. A purely relative step would give 0/0 = NaN
    there, which is reachable because `project` clamps Radius to 0. Both
    evaluations use the same `epoch` seed, so the circle sampling noise cancels
    in the difference. The function is jit-compilable.

    Args:
        epoch: integer forwarded to `train_step` (seeds the circle sampling).
        spectral_event: target spectral representation forwarded to `train_step`.
        params: dict (pytree) of parameters to differentiate with respect to.
        epsilon: relative step size (absolute for zero-valued entries).

    Returns:
        A pytree with the same structure and shapes as `params` holding the
        partial derivatives.
    """
    # Local import keeps the notebook's import cells unchanged.
    from jax.flatten_util import ravel_pytree

    flat_params, unravel = ravel_pytree(params)
    if flat_params.size == 0:
        return params

    steps = epsilon * jnp.where(flat_params != 0, flat_params, 1.0)

    def loss_at(flat):
        return train_step(epoch, spectral_event, unravel(flat))

    partials = []
    # flat_params.size is static under jit, so this loop unrolls at trace time.
    for index in range(flat_params.size):
        offset = jnp.zeros_like(flat_params).at[index].set(steps[index])
        loss_plus = loss_at(flat_params + offset)
        loss_minus = loss_at(flat_params - offset)
        partials.append((loss_plus - loss_minus) / (2 * steps[index]))

    return unravel(jnp.stack(partials))


vmapped_finite_differences_gradient = jax.vmap(finite_differences_gradient, in_axes = (None, 0, 0))

def compute_single_event(event, learning_rate = 0.001, epochs = 150, finite_difference = True):
    """Fit a circle to one event with projected Adam on the spectral EMD.

    Args:
        event: a single event, passed to `compiled_compute_spectral_representation`.
        learning_rate: Adam step size.
        epochs: number of optimisation steps.
        finite_difference: if True use `finite_differences_gradient`, otherwise
            the forward-mode `gradient_train_step`.

    Returns:
        (min_loss, params, losses): the smallest sEMD seen, the projected
        parameters evaluated in the final epoch (not necessarily the minimiser),
        and the per-epoch sEMD array of length `epochs`.
    """
    grad_fn = finite_differences_gradient if finite_difference else gradient_train_step
    target = compiled_compute_spectral_representation(event)
    params = initialize(event, 75, seed = random.PRNGKey(0))

    opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
    state = opt_init(params)

    epoch_losses = np.zeros((epochs,))
    for step in tqdm.tqdm(range(epochs)):
        params = project(get_params(state))
        loss = train_step(step, target, params)
        state = opt_update(step, grad_fn(step, target, params), state)
        # Re-project after the update and write the result back into the optimiser state.
        state = replace_params_in_state(state, project(get_params(state)))
        epoch_losses[step] = loss

    return jnp.min(epoch_losses), params, epoch_losses





def compute_events(events, learning_rate = 0.001, epochs = 150, finite_difference = True, save_history = False):
    """Fit a circle to every event of a batch with projected Adam on the spectral EMD.

    All events are optimised together through the vmapped helpers; the epoch
    number seeds the circle sampling and is shared by every event.

    Args:
        events: batch of events with the event index on axis 0.
        learning_rate: Adam step size.
        epochs: maximum number of optimisation steps.
        finite_difference: if True use the finite-difference gradient, otherwise
            the forward-mode Jacobian of the loss.
        save_history: if True also return the projected parameters after each
            completed epoch.

    Returns:
        (min_losses, best_params, losses) or, with save_history,
        (min_losses, best_params, losses, params_history):
        - min_losses: per-event minimum of `losses` over epochs.
        - best_params: dict holding, per event, the parameters at which that
          minimum was evaluated (left at the initial value for an event whose
          loss never drops below +inf, e.g. NaN).
        - losses: (epochs, n_events) array; rows of epochs skipped by early
          stopping keep the 99999 sentinel.
        - params_history: list of parameter dicts, one per completed epoch.

    Training stops early once, for more than 10 consecutive epochs, no event's
    loss is lower than it was 10 epochs earlier.
    """
    spectral_event = vmapped_compute_spectral_representation(events)
    print(spectral_event.shape, events.shape)
    params = vmapped_initialize(events, 75, 0)
    best_params = params.copy()
    num_events = events.shape[0]
    best_losses = np.full((num_events,), np.inf)
    params_history = []

    # Optimizer
    opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
    opt_state = opt_init(params)

    losses = np.ones((epochs, num_events)) * 99999
    early_stopping_counter = 0

    for epoch in tqdm.tqdm(range(epochs)):

        params = vmapped_project(get_params(opt_state))

        sEMD = vmapped_train_step(epoch, spectral_event, params)
        if finite_difference:
            grads = vmapped_finite_differences_gradient(epoch, spectral_event, params)
        else:
            grads = vmapped_gradient_train_step(epoch, spectral_event, params)
        opt_state = opt_update(epoch, grads, opt_state)

        # Re-project after the update and write the result back into the optimiser state.
        new_params = vmapped_project(get_params(opt_state))
        opt_state = replace_params_in_state(opt_state, new_params)
        losses[epoch] = sEMD

        # Track the best parameters per event (vectorised). sEMD was evaluated at
        # `params` (before this update), so those -- not `new_params` -- are
        # what achieved this loss. Done before the early-stop check so the final
        # epoch is also considered.
        improved = losses[epoch] < best_losses
        best_losses = np.where(improved, losses[epoch], best_losses)
        best_params = {
            key: jnp.where(
                improved.reshape(improved.shape + (1,) * (jnp.ndim(value) - 1)),
                value,
                best_params[key],
            )
            for key, value in params.items()
        }

        # Early stopping: count consecutive epochs in which no event beat its
        # loss from 10 epochs earlier; stop once that happened more than 10 times.
        if epoch > 10:
            if np.all(losses[epoch] >= losses[epoch - 10]):
                early_stopping_counter += 1
            else:
                early_stopping_counter = 0

        if early_stopping_counter > 10:
            break

        if save_history:
            params_history.append(new_params.copy())

    min_losses = jnp.min(losses, axis = 0)
    if save_history:
        return min_losses, best_params, losses, params_history

    return min_losses, best_params, losses



def replace_params_in_state(opt_state, new_params):
    """Return `opt_state` with its parameter values replaced by `new_params`.

    Used to apply a projection (e.g. clamping the radius) after an optimiser step
    without resetting the optimiser's internal statistics (Adam's moments are kept).

    For a `jax.example_libraries.optimizers.OptimizerState`, a
    (packed_state, tree_def, subtree_defs) triple, the first element of every
    per-leaf state, i.e. the parameter value x of (x, m, v) for Adam, is replaced.
    The previous recursive implementation never matched this layout and so
    returned the state unchanged, which meant the projection was silently lost.

    Args:
        opt_state: optimiser state. An `OptimizerState`-like triple; a legacy
            (params_dict, extra) pair; or another tuple containing such pairs.
        new_params: pytree with the same structure as the optimised parameters.

    Returns:
        A new optimiser state of the same kind. Non-tuple inputs are returned
        unchanged.

    Raises:
        ValueError: if `new_params` does not match the optimiser's parameter tree.
    """
    if isinstance(opt_state, tuple) and len(opt_state) == 2 and isinstance(opt_state[0], dict):
        # Legacy (params_dict, extra) pair.
        return (new_params, opt_state[1])

    is_optimizer_state = isinstance(opt_state, optimizers.OptimizerState) or (
        isinstance(opt_state, tuple)
        and len(opt_state) == 3
        and isinstance(opt_state[1], jax.tree_util.PyTreeDef)
    )
    if is_optimizer_state:
        states_flat, tree_def, subtree_defs = opt_state
        new_leaves, new_tree_def = jax.tree_util.tree_flatten(new_params)
        if new_tree_def != tree_def:
            raise ValueError(
                f"new_params structure {new_tree_def} does not match optimiser structure {tree_def}"
            )
        new_states_flat = []
        for leaf, state_leaves, subtree_def in zip(new_leaves, states_flat, subtree_defs):
            state = jax.tree_util.tree_unflatten(subtree_def, state_leaves)
            # Tuple states (e.g. Adam's (x, m, v)) keep everything after x;
            # plain states (e.g. SGD) are just x.
            state = (leaf,) + tuple(state[1:]) if isinstance(state, tuple) else leaf
            new_states_flat.append(jax.tree_util.tree_leaves(state))
        return optimizers.OptimizerState(tuple(new_states_flat), tree_def, subtree_defs)

    if isinstance(opt_state, tuple):
        # Unpack and modify recursively
        return tuple(replace_params_in_state(sub_state, new_params) for sub_state in opt_state)

    # Leaf node or unknown type, return unchanged
    return opt_state






In [None]:
# Batched fit of every event using the forward-mode Jacobian (finite_difference=False).
# `losses` and `history` (per-epoch projected parameters) are read as globals by the
# animation helpers below. NOTE(review): despite its name, `shape_events` is the
# best-parameter dict returned by compute_events (e.g. shape_events["Radius"]).
sEMDs, shape_events, losses, history = compute_events(dataset, 0.001, epochs, finite_difference=False, save_history=True)


In [None]:
# Events whose fitted circle came out smallest and largest.
radii = shape_events["Radius"]
min_radii_index, max_radii_index = np.argmin(radii), np.argmax(radii)
min_radius, max_radius = radii[min_radii_index], radii[max_radii_index]

print(min_radius, max_radius)



In [12]:
from pyjet import cluster
from PIL import Image
import glob

def kT_N(events, N, R):
    """Recluster each event into exactly `N` exclusive kT (p = 1) subjets.

    Args:
        events: iterable of events; the first three entries of every particle row
            are read as (pt, eta, phi), and the mass is set to 0.
        N: number of exclusive subjets per event.
        R: jet radius passed to the clustering. It was previously ignored in
            favour of a hard-coded 0.4.

    Returns:
        np.ndarray of shape (len(events), N, 3) with rows (pt fraction, eta, phi).
        Each event's pt column is normalised to sum to 1, or set to 0 where the
        normalisation is 0/0. Rows beyond the subjets found stay zero.
    """
    four_vector_dtype = [("pt", "f8"), ("eta", "f8"), ("phi", "f8"), ("mass", "f8")]
    jets = []

    for event in events:

        # Set up massless 4-vectors
        four_vectors = np.array(
            [(particle[0], particle[1], particle[2], 0) for particle in event],
            dtype=four_vector_dtype,
        )

        # Cluster with kT (p = 1) using the requested radius
        sequence = cluster(four_vectors, R=R, p=1)
        subjets = sequence.exclusive_jets(N)

        output = np.zeros((N, 3))
        for i, subjet in enumerate(subjets):
            output[i] = (subjet.pt, subjet.eta, subjet.phi)

        # Normalize
        output[:, 0] = np.nan_to_num(output[:, 0] / np.sum(output[:, 0]))

        jets.append(output)

    return np.array(jets)

def draw(ax, center, radius):
    """Overlay an unfilled, semi-transparent purple circle of `radius` at `center` on `ax`."""
    outline = plt.Circle(center, radius, color='purple', fill=False, alpha = 0.5, lw = 3)
    ax.add_artist(outline)


def cumulative_spectral_function(omega, radius):
    """Cumulative distribution of pairwise distances for uniform points on a circle.

    Evaluates (2 / pi) * arcsin(omega / (2 * radius)). Where that is undefined
    (omega beyond the diameter, giving NaN from arcsin) the value is replaced by 1,
    and the result is capped at 1.
    """
    half_chord_ratio = omega / 2 / radius
    cdf = 2 * np.arcsin(half_chord_ratio) / np.pi
    cdf = np.nan_to_num(cdf, nan = 1.0)
    return np.minimum(cdf, 1)

def animate_history_2D(event_number, delete = True):
    """Render the per-epoch circle fit of one event on its event display as a GIF.

    Reads the module-level `dataset`, `history` (one parameter dict per epoch)
    and `losses` (per-epoch, per-event sEMD). Frames are written to
    ``gifs/temp/{event_number}_2D_{i}.png`` and the animation to
    ``gifs/animation_2D_{event_number}.gif``.

    Args:
        event_number: index of the event in `dataset`.
        delete: remove the temporary PNG frames once the GIF has been written.
    """
    os.makedirs("gifs/temp", exist_ok=True)

    event = dataset[event_number]
    # The circle is centred on the single exclusive kT jet axis (eta, phi); it
    # does not depend on the epoch, so cluster once instead of once per frame.
    center = kT_N([event,], 1, 1)[0][0,1:]

    frame_paths = []
    for i in range(len(history)):

        fig, ax = newplot()

        plot_event(event, R = 0.75, ax = ax, color = "red", show = False)

        radius = history[i]["Radius"][event_number]
        draw(ax, center, radius)

        # Text at the top right that says the epoch
        ax.text(0.95, 0.95, f"Epoch: {i}", horizontalalignment='right', verticalalignment='top', transform=ax.transAxes)

        # Text at the bottom left that says the radius and sEMD
        ax.text(0.05, 0.05, f"Radius: {radius:.3f}\n sEMD: {losses[i][event_number]:.3f}", horizontalalignment='left', verticalalignment='bottom', transform=ax.transAxes)

        # Add two lines of text to the upper left
        upper_margin = 0.96
        spacing = 0.05
        ax.text(0.02, upper_margin - 0 * spacing, "CMS Open Sim", transform=ax.transAxes, verticalalignment='top')
        ax.text(0.02, upper_margin - 1 * spacing, f"2011AJets, Event {event_number}", transform=ax.transAxes, verticalalignment='top')

        frame_path = f"gifs/temp/{event_number}_2D_{i}.png"
        fig.savefig(frame_path)
        # Close each figure so a long history does not keep every frame in memory.
        plt.close(fig)
        frame_paths.append(frame_path)

    if not frame_paths:
        return

    # Convert the images to a gif. The first frame is the base image, so only
    # the remaining frames are appended (appending all would show frame 0 twice).
    frames = [Image.open(path) for path in frame_paths]
    try:
        frames[0].save(f"gifs/animation_2D_{event_number}.gif", format="GIF", append_images=frames[1:],
                       save_all=True, duration=100, loop=0)
    finally:
        for frame in frames:
            frame.close()

    if delete:
        for image in glob.glob(f"gifs/temp/{event_number}_2D_*.png"):
            os.remove(image)


def animate_history_1D(event_number, delete = True):
    """Animate the fitted circle's cumulative spectral function against the event's.

    For each entry of the module-level `history`, this draws the event's
    cumulative spectral weight (red, solid). It also draws
    `cumulative_spectral_function` at that epoch's radius (purple, dashed). Each
    frame is written to ``gifs/temp/{event_number}_1D_{i}.png``, and the frames
    are assembled into ``gifs/animation_1D_{event_number}.gif``. Also reads the
    globals `dataset` and `losses`.

    Args:
        event_number: index of the event in `dataset`.
        delete: remove the temporary PNG frames once the GIF has been written.
    """
    os.makedirs("gifs/temp", exist_ok=True)

    event = dataset[event_number]
    spectral_event = compiled_compute_spectral_representation(event)
    omegas_event = spectral_event[:,0]

    cumulative_2EE = np.cumsum(spectral_event[:,1])

    frame_paths = []
    for i in range(len(history)):

        fig, ax = newplot()

        radius = history[i]["Radius"][event_number]

        ax.plot(omegas_event, cumulative_2EE, label = "Original", color = "red", lw = 3)
        ax.plot(omegas_event, cumulative_spectral_function(omegas_event, radius), label = "Reconstructed", color = "purple", lw = 3, alpha = 0.5, ls = "--")

        ax.set_xlabel(r"$\omega$")
        ax.set_ylabel(r"Cumulative Spectral Function")

        # Text at the top right that says the epoch
        ax.text(0.95, 0.95, f"Epoch: {i}", horizontalalignment='right', verticalalignment='top', transform=ax.transAxes)

        # Text at the bottom left that says the radius and sEMD
        ax.text(0.05, 0.05, f"Radius: {radius:.3f}\n sEMD: {losses[i][event_number]:.3f}", horizontalalignment='left', verticalalignment='bottom', transform=ax.transAxes)

        # Add two lines of text to the upper left; the label now names the
        # animated event instead of always saying "Event 0".
        upper_margin = 0.96
        spacing = 0.05
        ax.text(0.02, upper_margin - 0 * spacing, "CMS Open Sim", transform=ax.transAxes, verticalalignment='top')
        ax.text(0.02, upper_margin - 1 * spacing, f"2011AJets, Event {event_number}", transform=ax.transAxes, verticalalignment='top')

        ax.set_ylim(0, 1.2)

        frame_path = f"gifs/temp/{event_number}_1D_{i}.png"
        fig.savefig(frame_path)
        # Close each figure so a long history does not keep every frame in memory.
        plt.close(fig)
        frame_paths.append(frame_path)

    if not frame_paths:
        return

    # Convert the images to a gif. The first frame is the base image, so only
    # the remaining frames are appended (appending all would show frame 0 twice).
    frames = [Image.open(path) for path in frame_paths]
    try:
        frames[0].save(f"gifs/animation_1D_{event_number}.gif", format="GIF", append_images=frames[1:],
                       save_all=True, loop=0, duration = 100)
    finally:
        for frame in frames:
            frame.close()

    if delete:
        for image in glob.glob(f"gifs/temp/{event_number}_1D_*.png"):
            os.remove(image)


In [13]:
# Animate the events with the smallest and largest fitted radius, keeping the PNG frames.
radii = shape_events["Radius"]
min_radii_index, max_radii_index = np.argmin(radii), np.argmax(radii)
min_radius, max_radius = radii[min_radii_index], radii[max_radii_index]

animate_history_2D(min_radii_index, delete = False)
animate_history_2D(max_radii_index, delete = False)
animate_history_2D(max_radii_index, delete = False)

In [None]:

# Render both animations for event 0 -- the event display with the fitted circle (2D)
# and the cumulative spectral function comparison (1D) -- deleting the temporary frames.
animate_history_2D(0)
animate_history_1D(0)