In [None]:
# Standard Imports
import numpy as np
from time import time
from matplotlib import pyplot as plt

# SHAPER
from pyshaper.CommonObservables import buildCommmonObservables
from pyshaper.Observables import Observable
from pyshaper.Shaper import Shaper

# Utils
from pyshaper.utils.data_utils import load_cmsopendata
from pyshaper.utils.plot_utils import plot_event


# Device selection: use the first CUDA GPU if PyTorch can see one, otherwise fall back to the CPU.
import torch 

# Query CUDA availability once and derive the device string from it.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# Report which backend was chosen (message text kept identical to before).
print("Using GPU!" if dev == "cuda:0" else "Using CPU!")

# Torch device object used by the SHAPER cells below.
device = torch.device(dev)

In [None]:
# Parameters
R = 0.5
beta = 1.0
N = 50
pt_lower = 475
pt_upper = 525
eta = 1.9
quality = 2
pad = 125
plot_dir = "results"

# Load data. NOTE: the default dataset requires the `energyflow` package to be
# installed; alternatively, provide your own data in the same format.
# The second value returned by `load_cmsopendata` is not used in this notebook.
raw_events, _ = load_cmsopendata("~/.energyflow/", "cms", pt_lower, pt_upper, eta, quality, n = N)

# Convert the ragged list of events into a zero-padded numpy array of shape
# (num_events, pad, 3), where the last axis is (pt, eta, phi).
# Each event is indexed as event[0] -> per-particle coordinates with columns
# (eta, phi) and event[1] -> per-particle pt.
# NOTE(review): layout inferred from this indexing; confirm against load_cmsopendata.
padded_events = np.zeros((len(raw_events), pad, 3))
for i, event in enumerate(raw_events):
    coords, pts = event[0], event[1]
    num_particles = len(pts)

    # Without this check, an oversized event fails with an opaque numpy
    # broadcasting error; fail with an actionable message instead.
    if num_particles > pad:
        raise ValueError(
            f"Event {i} has {num_particles} particles but pad = {pad}; "
            "increase `pad` in the parameters cell."
        )

    padded_events[i, :num_particles, 0] = pts
    padded_events[i, :num_particles, 1] = coords[:, 0]
    padded_events[i, :num_particles, 2] = coords[:, 1]

# Downstream cells consume the padded array under the name `dataset`.
dataset = padded_events

# Visualize the first event of the padded dataset.
# NOTE(review): the zero-padded rows (pt = 0 at eta = phi = 0) are passed to
# plot_event as well; presumably they get zero weight -- confirm.
example_event = dataset[0]
plot_event(
    example_event[:, 1:],  # per-particle (eta, phi)
    example_event[:, 0],   # per-particle pt
    R,
    color="red",
)

In [None]:
# Initialize SHAPER. An empty dict is passed as the first argument
# (presumably no observables are needed just to compute pairwise EMDs -- confirm),
# together with the torch device selected earlier.
shaper = Shaper({}, device)
# NOTE(review): presumably Shaper is a torch module and this moves it to `device`.
shaper.to(device)

# Pairwise EMDs between every pair of events in `dataset`.
# Pass the `beta` and `R` parameters instead of hard-coding them, so that
# edits to the parameters cell actually take effect here.
# NOTE(review): result is presumably a (len(dataset), len(dataset)) matrix -- confirm.
pairwise_emds = shaper.pairwise_emds2(dataset, dataset, beta = beta, R = R)
