In [1]:
import random
import sys
import madjax
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax.nn import sigmoid
from jax import custom_jvp

from tqdm import tqdm


from madjax.phasespace.new_flat_phase_space_generator import *

I first run MadGraph with the commands
```
generate mu- > e- ve~ vm
output madjax mu_decay
set auto_update 0
!rm py.py
```

with nonzero masses for the electron and muon.

There is onmly 1 matrix element: `Matrix_1_mum_emvexvm`



# Helper functions

Mostly taken from https://github.com/madjax-hep/madjax/blob/1ea61c8ee3e39360c386319b10e9cd92c8e94c19/examples/KinFit_tt_bqq_bqq.ipynb#L36

I don't know why this `SigmoidStraightThrough` is necessary for the phase space generator...

In [2]:
@custom_jvp
def SigmoidStraightThrough(x):
    return sigmoid(x)

SigmoidStraightThrough.defjvps(lambda t, ans, x: t)

In [3]:
mj = madjax.MadJax("mu_decay")

In [4]:
#Setup Process

process_name = "Matrix_1_mum_emvexvm"
process = mj.processes[process_name]()

E_mu = 1.056600e-01 # muon mass in GeV
E_e = 5.110000e-04  # electron mass in GeV

# generating the phase space params
external_parameters = {}
parameters = mj.parameters.calculate_full_parameters(external_parameters)
external_masses = process.get_external_masses(parameters)


PS_inputs = generate_phase_space_inputs( [E_mu], # initial_masses
                            [E_e ,0,0], # final_masses
                             [E_mu], # beam_Es
                                        beam_types= (0, 0) # I have no idea what the beam type means. I think 1 means we need a pdf
                           )

print(PS_inputs)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


{'initial_masses': [0.10566], 'masses': [0.000511, 0, 0], 'n_initial': 1, 'n_final': 3, 'beam_Es': [0.10566], 'collider_energy': 0.10566, 'beam_types': (0, 0), 'is_beam_factorization_active': (False, False), 'correlated_beam_convolution': False, 'dim_ordered_names': ['x_1', 'x_2', 'x_3', 'x_4', 'x_5'], 'dim_name_to_position': {'x_1': 0, 'x_2': 1, 'x_3': 2, 'x_4': 3, 'x_5': 4}, 'position_to_dim_name': {0: 'x_1', 1: 'x_2', 2: 'x_3', 3: 'x_4', 4: 'x_5'}}


In [5]:
num_PS_points = 100

MEs, weights = [], []

for i in tqdm(range(num_PS_points)):
    random_variables = jnp.array([np.random.standard_normal() for _ in range(nDimPhaseSpace(3))])
    random_variables = SigmoidStraightThrough(random_variables)

    # generate a phase space point at the correct center of massE
    PS_point, jacobian = generateKinematics(PS_inputs, E_mu, random_variables)

    # generate the matrix element in-place
    def matrix_element(process, PS_point, parameters, return_grad=True, do_jit=True):
        def func(external_parameters, random_variables):
            return process.smatrix(PS_point, parameters)
    
        if return_grad:
            return jax.jit(jax.value_and_grad(func)) if do_jit else jax.value_and_grad(func)
        else:
            return jax.jit(func) if do_jit else func

    
    ME_func = matrix_element(process, PS_point, parameters, return_grad = False)
    ME_val = ME_func(external_parameters, random_variables)

    MEs.append(ME_val)
    weights.append(jacobian)

   


100%|██████████| 100/100 [02:09<00:00,  1.29s/it]


In [8]:
averaged_Me = np.mean(np.array(MEs)*np.array(weights))


decay_rate = (1.0/(2.0*E_mu))*averaged_Me
print(f"muon decay rate: {decay_rate}")
lifetime =  6.58e-25/decay_rate
print(f"muon lifetime: {lifetime}")

muon decay rate: 3.2219172992177464e-15
muon lifetime: 2.0422622274003019e-10
