In [1]:
# Standard Imports
import numpy as np
from time import time
from matplotlib import pyplot as plt


from pyspecter.SPECTER import SPECTER
from pyspecter.Observables import Observable
# from pyspecter.SpecialObservables import SpecialObservables

# Utils
from pyspecter.utils.data_utils import load_cmsopendata, load_triangles
from pyspecter.utils.plot_utils import newplot
from pyshaper.utils.plot_utils import plot_event

# Jax
from jax import grad, jacobian, jit
import jax.numpy as jnp
from jax import random

# SPECTER
from pyspecter.SpectralEMD_Helper import compute_spectral_representation

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
  register_backend(TensorflowBackend())


In [2]:
# Parameters
# Radius passed to `load_triangles`; also scales the random initial positions in `initialize`.
R = 1.0
# Size of the triangle grid along each of the two indices requested from `load_triangles`.
N_GRID = 180
# NOTE(review): second-index slice used as the "equal energy" triangles — confirm what index 120 means.
EQUAL_ENERGY_INDEX = 120
this_dir = "studies/n_subjettiness/"
this_study = "triangles"

triangle_events, triangle_indices = load_triangles(N_GRID, N_GRID, R = R, return_indices=True)

# Scatter the flat list of triangles into a (N_GRID, N_GRID, 3, 3) grid keyed by their indices.
triangle_events_matrix = np.zeros((N_GRID, N_GRID, 3, 3))
for idx, event in zip(triangle_indices, triangle_events):
    triangle_events_matrix[idx[0], idx[1]] = event

# NOTE: basic slicing returns a *view*; writing into `equal_energy_triangles`
# also modifies `triangle_events_matrix`.
equal_energy_triangles = triangle_events_matrix[:, EQUAL_ENERGY_INDEX]


In [13]:
from pyspecter.SpectralEMD_Helper import ds2_events1_spectral2
import jax.example_libraries.optimizers as jax_opt
import jax
import tqdm


# Gradients
# NOTE(review): despite the "vmap_" prefix these are plain aliases, not
# jax.vmap-wrapped functions; the test cell calls them on single events.
vmap_compute_spectral_representation = compute_spectral_representation
vmap_ds2_events1_spectral2 = ds2_events1_spectral2
# Gradient of the sEMD w.r.t. its first argument (jax.grad default argnums=0);
# both names refer to the same function object.
ds2_events1_spectral2_gradients = vmap_ds2_events1_spectral2_gradients = grad(ds2_events1_spectral2)


# Function to enforce normalization of energies
def project(events):
    """Project each event's energy column onto the probability simplex.

    Column 0 (the energy fractions) of every event is replaced by its Euclidean
    projection onto {z : z >= 0, sum(z) = 1}, using the sort / cumulative-sum
    threshold construction. All other columns are returned unchanged.

    Args:
        events: array of shape (..., num_particles, num_features). A single
            event (num_particles, num_features) or a batch of any rank is accepted.
            numpy input is converted to a jax array.

    Returns:
        A jax array with the same shape as `events` and normalised energies.
    """
    # jnp.asarray is needed because numpy arrays have no functional `.at` updates.
    events = jnp.asarray(events)
    zs = events[..., 0]
    num_particles = zs.shape[-1]

    # Sort energies in descending order and build the candidate thresholds
    # v_k = (sum_{j<=k} u_j - 1) / k.
    u = jnp.sort(zs, axis=-1)[..., ::-1]
    ks = jnp.arange(1, num_particles + 1)
    v = (jnp.cumsum(u, axis=-1) - 1) / ks

    # u_k > v_k holds on a prefix, so the last index where it holds is count - 1.
    rho = jnp.sum(u > v, axis=-1, keepdims=True) - 1
    w = jnp.take_along_axis(v, rho, axis=-1)

    return events.at[..., 0].set(jnp.maximum(zs - w, 0))

# For N-spronginess, the shape is just the params
def shape_from_params(params):
    """Map optimiser parameters to a shape event.

    The parameters are used directly as the shape, so this is the identity:
    the very same object is handed back.
    """
    shape = params
    return shape


def initialize(events, N, seed):
    """Build the starting N-particle shape for every event in the batch.

    Every particle starts with energy 1/N (column 0). Positions (columns 1:3)
    are drawn once from a standard normal scaled by R/4 and shared by all
    events. Particle 0 is then pinned to (0, 0) and particle 1 to (1, 0).

    Args:
        events: batch of events; only `events.shape[0]` (the batch size) is used.
        N: number of particles in the shape.
        seed: a jax PRNG key passed to `jax.random.normal`.

    Returns:
        jax array of shape (events.shape[0], N, 3).

    NOTE(review): the pinned positions look tailored to the triangle test, and
    for N < 2 the particle-1 update is out of range — confirm intended behaviour.
    Reads the module-level `R`.
    """
    n_events = events.shape[0]
    positions = R/4 * jax.random.normal(seed, (N, 2))

    shape = jnp.ones((n_events, N, 3)) / N
    shape = shape.at[:, :, 1:3].set(positions)

    # Pin the first two particles to fixed positions.
    shape = shape.at[:, 0, 1:3].set((0, .000))
    shape = shape.at[:, 1, 1:3].set((1, 0))
    return shape

def compute_N_spronginess(events, N, epochs = 100, learning_rate = 1e-3, seed = 0):
    """Fit an N-particle shape to the events by minimising the sEMD with Adam.

    Each epoch: read the current parameters, normalise their energies with
    `project`, then evaluate the sEMD and its gradient w.r.t. the projected
    shape via `train_step`. Adam then updates the (unprojected) parameters.

    Args:
        events: events to fit; `events.shape[0]` is used as the batch size.
        N: number of particles in the fitted shape.
        epochs: number of optimisation steps (must be >= 1).
        learning_rate: Adam step size.
        seed: integer seed for `random.PRNGKey` used by `initialize`
            (the default 0 reproduces the previous hard-coded key).

    Returns:
        (sEMDs, shape_events, losses): the sEMD from the final epoch, the
        projected shape evaluated in that epoch, and an (epochs, batch) array
        of per-epoch losses.

    Raises:
        ValueError: if `epochs < 1` (there would be no loss to return).

    NOTE(review): the `vmap_*` helpers are currently plain aliases, not
    jax.vmap-ed, so batched input may not be handled per event — confirm.
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")

    # Spectral representation of the target events; fixed during optimisation.
    events_spectral = vmap_compute_spectral_representation(events)

    # Initial shape parameters.
    shape_events = initialize(events, N, seed = random.PRNGKey(seed))

    # Adam over the unprojected shape parameters.
    opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
    opt_state = opt_init(shape_events)

    losses = np.zeros((epochs, events.shape[0]))

    for epoch in tqdm.tqdm(range(epochs)):
        params = get_params(opt_state)
        # Normalise energies before evaluating the loss (projected gradient descent).
        shape_events = project(shape_from_params(params))

        sEMDs, grads = train_step(epoch, events_spectral, shape_events)
        opt_state = opt_update(epoch, grads, opt_state)

        losses[epoch] = sEMDs

    return sEMDs, shape_events, losses


@jax.jit
def train_step(epoch, spectral_events, shape_events):
    """Evaluate the sEMD and its gradient for the current shape (jit-compiled).

    Args:
        epoch: current epoch; accepted for the caller's convenience but unused.
        spectral_events: spectral representation of the target events.
        shape_events: current (projected) shape events.

    Returns:
        (sEMD, gradient), where NaNs/infs in the gradient are replaced by
        finite values via `jnp.nan_to_num`.
    """
    loss = vmap_ds2_events1_spectral2(shape_events, spectral_events)
    gradient = vmap_ds2_events1_spectral2_gradients(shape_events, spectral_events)
    safe_gradient = jnp.nan_to_num(gradient)
    return loss, safe_gradient


# Test gradient: sanity-check the projection, spectral representation, sEMD and
# its gradient on a hand-built (deliberately unnormalised) triangle.
print(equal_energy_triangles[0])

# Work on a copy: `equal_energy_triangles` is a view into `triangle_events_matrix`,
# so assigning into it in place would silently overwrite the loaded data and make
# this cell non-idempotent.
test_triangles = equal_energy_triangles.copy()
test_triangles[0] = np.array([[2/3, 0., 0.],
                              [1/3, 1., 0.],
                              [1/3, 1., 0.]])

normalized = project(test_triangles)
print(normalized[0])

spectral_equal_energy_triangle = vmap_compute_spectral_representation(test_triangles[0])
test_event = np.array(((1/3, 0.0, 0.), (2/3, 1.0, 0.0)))
sperctral_test_event = vmap_compute_spectral_representation(test_event)

print(sperctral_test_event)
print(spectral_equal_energy_triangle)

print(vmap_ds2_events1_spectral2(test_event, spectral_equal_energy_triangle))
print(ds2_events1_spectral2_gradients(test_event, spectral_equal_energy_triangle))

print(sperctral_test_event[1, 1], (spectral_equal_energy_triangle[2, 1] + spectral_equal_energy_triangle[3, 1]))
print(np.sum(spectral_equal_energy_triangle[:, 1]))



[[0.66666667 0.         0.        ]
 [0.16666667 1.         0.        ]
 [0.16666667 1.         0.        ]]
[[0.5555556  0.         0.        ]
 [0.22222221 1.         0.        ]
 [0.22222221 1.         0.        ]]
[[0.         0.5555556 ]
 [1.         0.44444445]]
[[0.         0.6666667 ]
 [0.         0.22222222]
 [1.         0.44444445]
 [1.         0.44444445]]
1.3333336
[[-1.3333334 -0.8888891  0.       ]
 [-0.6666667  0.8888891  0.       ]]
0.44444445 0.8888889
1.7777779


In [12]:
from pyspecter.SpectralEMD_Helper import cross_term, cross_term_improved

# Compare the original and "improved" cross-term implementations.
# NOTE(review): the first call passes the raw `test_event` rather than a spectral
# representation like the other two calls — confirm this is intended.
self_term = cross_term(test_event, test_event)
print(self_term)
pair_term = cross_term(spectral_equal_energy_triangle, sperctral_test_event)
print(pair_term)
pair_term_improved = cross_term_improved(spectral_equal_energy_triangle, sperctral_test_event)
print(pair_term_improved)

0.44444445
0.44444442
0.44444442
