In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of the documentation-build environment variables that are currently
# set; the set is empty (falsy) when neither variable is defined.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"} & set(os.environ)

In [None]:
# WARNING: no version is pinned here, so the installed release can differ from
# run to run. For a reproducible run it is advised to install a specific
# version instead, e.g. tensorwaves==0.1.2
%pip install -q tensorwaves[doc,jax,pwa,viz] IPython

```{autolink-concat}
```

# Spin alignment

:::{note}

This page is a continuation of [ampform#213](https://ampform--213.org.readthedocs.build/en/213/usage/helicity/spin-alignment.html).

:::

```{autolink-skip}
```

In [None]:
%config InlineBackend.figure_formats = ['svg']
import logging
import warnings

# Keep the rendered notebook output clean: silence all Python warnings and let
# the root logger pass only records of level ERROR and above.
warnings.filterwarnings(action="ignore")
LOGGER = logging.getLogger()
LOGGER.setLevel(level=logging.ERROR)

## Create amplitude model

In [None]:
import qrules
from qrules.particle import ParticleCollection, create_particle

# Build a minimal particle database: the initial and final state particles are
# taken from the PDG as is, plus one custom resonance per two-body subsystem.
PDG = qrules.load_pdg()
particle_db = ParticleCollection()
for pdg_name in ("Lambda(c)+", "p", "K-", "pi+"):
    particle_db.add(PDG[pdg_name])

# Each resonance is a copy of a PDG template particle with its name, LaTeX
# label, and mass overridden; all three get the same width of 0.2.
resonance_overrides = {
    "K*(892)0": dict(name="K*", latex="K^*", mass=0.9),
    "Lambda(1405)": dict(name="Lambda*", latex=R"\Lambda^*", mass=1.6),
    "Delta(1232)++": dict(name="Delta*++", latex=R"\Delta^{*++}", mass=1.4),
}
for template_name, overrides in resonance_overrides.items():
    particle_db.add(create_particle(PDG[template_name], width=0.2, **overrides))

In [None]:
# Let qrules determine the allowed transitions for Lambda(c)+ -> p K- pi+ in the
# helicity formalism. The initial state is listed with spin projections -1/2 and
# +1/2, and the custom `particle_db` is passed instead of relying on a default
# particle list.
reaction = qrules.generate_transitions(
    initial_state=("Lambda(c)+", [-0.5, +0.5]),
    final_state=["p", "K-", "pi+"],
    formalism="helicity",
    particle_db=particle_db,
)

In [None]:
import graphviz

# Convert the transitions to Graphviz DOT source (with `collapse_graphs=True`)
# and render it inline.
dot = qrules.io.asdot(reaction, collapse_graphs=True, size=5)
graphviz.Source(dot)  # last expression of the cell, so the graph is displayed

In [None]:
import ampform
from ampform.dynamics.builder import RelativisticBreitWignerBuilder

# Configure an amplitude builder for the transitions in `reaction`.
builder = ampform.get_builder(reaction)
# Flag all final-state IDs as stable and request a scalar initial-state mass.
# NOTE(review): the precise effect of both switches is defined by ampform;
# confirm against its documentation.
builder.stable_final_state_ids = list(reaction.final_state)
builder.scalar_initial_state_mass = True
# The same relativistic Breit-Wigner builder instance is registered as the
# dynamics for every intermediate particle.
bw_builder = RelativisticBreitWignerBuilder()
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, bw_builder)
model = builder.formulate()

In [None]:
import sympy as sp

# Collect the parameters whose name starts with "C" (the coefficients) and sort
# them on their reversed name, so that the ordering is decided by how the names
# end rather than how they begin.
coefficients = sorted(
    (par for par in model.parameter_defaults if str(par).startswith("C")),
    key=lambda par: str(par)[::-1],
)
# Display the sorted coefficients as a single column.
sp.transpose(sp.Array([coefficients]))

In [None]:
# Hand-picked complex values for the coefficients, addressed by their position
# in the sorted `coefficients` list. The trailing comments give the helicity
# combination that each entry is meant to correspond to.
coefficient_values = [
    # K*
    (0, 1),  # H 1/2,0
    (1, 0.5 + 0.5j),  # H 1/2,-1
    (3, 1j),  # H -1/2,1
    (2, -0.5 - 0.5j),  # H -1/2,0
    # Lambda*
    (4, 1j),  # H -1/2,0
    (5, 0.8 - 0.4j),  # H 1/2,0
    # Delta*
    (6, 0.6 - 0.4j),  # H -1/2,0
    (7, 0.1j),  # H 1/2,0
]
for position, value in coefficient_values:
    model.parameter_defaults[coefficients[position]] = value

In [None]:
# Evaluate the model expression with `doit()` and then replace each parameter
# symbol by its entry in `model.parameter_defaults`.
# NOTE(review): presumably only the kinematic variables remain as free symbols
# afterwards -- confirm.
full_expression = model.expression.doit()
substituted_expression = full_expression.xreplace(model.parameter_defaults)

In [None]:
from IPython.display import Math

# Render the intensity as a multi-line LaTeX equation with "I" on the left-hand
# side. Note that this shows `model.expression` as is, not a `doit()`-evaluated
# or parameter-substituted form.
latex = sp.multiline_latex(sp.Symbol("I"), model.expression)
Math(latex)

## Generate data

### Phase space sample

In [None]:
from tensorwaves.data import (
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

# Fixed seed, so that the same phase-space sample is generated on every run.
rng = TFUniformRealNumberGenerator(seed=0)
# The generator needs the mass of the decaying particle (stored under key -1 of
# `reaction.initial_state`) and the mass of each final-state particle by ID.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
)
phsp_momenta = phsp_generator.generate(800_000, rng)

In [None]:
from tensorwaves.data import SympyDataTransformer

# Transformer that computes the model's kinematic variables from the generated
# momenta, using the JAX backend.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables,
    backend="jax",
)
# Only the real part of each computed variable is kept.
phsp = {
    variable: values.real
    for variable, values in helicity_transformer(phsp_momenta).items()
}

In [None]:
import pandas as pd

# Put the kinematic variables in a DataFrame (one column per key of `phsp`) for
# a quick look at the sample.
phsp_frame = pd.DataFrame(phsp)
phsp_frame.round(3)  # rounded for display only; `phsp_frame` keeps full precision

### Compute intensities

In [None]:
from tensorwaves.function.sympy import create_function

# Turn the parameter-substituted intensity expression into a numerical function
# that runs on the JAX backend.
intensity_func = create_function(substituted_expression, backend="jax")

In [None]:
import numpy as np

# Evaluate the intensity on the phase-space sample and keep its real part as a
# NumPy array; these values serve as histogram weights in the plot below.
phsp_weights = np.array(intensity_func(phsp).real)
phsp_weights.round(4)  # displayed only; `phsp_weights` itself is not rounded

In [None]:
import matplotlib.pyplot as plt

# Plot the intensity-weighted phase-space sample: the three squared two-body
# invariant masses on the top row and two angular distributions on the bottom.
plt.rcParams.update({"font.family": "serif"})
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 5))
# Histogram options shared by all panels: every phase-space event is weighted
# with its computed intensity.
kwargs = dict(
    bins=80,
    weights=phsp_weights,
    histtype="step",
    color="red",
)

# Hide the y ticks on every panel.
for axis in ax.flatten():
    axis.set_yticks([])

# NOTE(review): the labels assume that final-state IDs follow the order of the
# final state p, K-, pi+ (0 = p, 1 = K-, 2 = pi+) -- confirm.
ax[0, 0].set_xlabel(R"$m^2(pK^-)$ [GeV$^2/c^4$]")
ax[0, 1].set_xlabel(R"$m^2(K^-\pi^+)$ [GeV$^2/c^4$]")
# The pion in the final state is a pi+ (as in the K- pi+ label), not a pi-.
ax[0, 2].set_xlabel(R"$m^2(p\pi^+)$ [GeV$^2/c^4$]")
ax[1, 0].set_xlabel(R"$\cos\theta(p)$")
ax[1, 1].set_xlabel(R"$\phi(p)$")

# Explicit tick positions per panel; the tick labels are the tick values.
xticks_per_panel = {
    (0, 0): [2, 2.5, 3, 3.5, 4, 4.5],
    (0, 1): [0.4, 0.6, 0.8, 1, 1.2, 1.4, 1.6, 1.8, 2],
    (0, 2): [1, 1.5, 2, 2.5, 3],
    (1, 0): [-1, -0.5, 0, 0.5, 1],
    (1, 1): [-3, -2, -1, 0, 1, 2, 3],
}
for position, xticks in xticks_per_panel.items():
    ax[position].set_xticks(xticks)
    ax[position].set_xticklabels(xticks)

ax[0, 0].hist(np.array(phsp["m_01"] ** 2), **kwargs)
ax[0, 1].hist(np.array(phsp["m_12"] ** 2), **kwargs)
ax[0, 2].hist(np.array(phsp["m_02"] ** 2), **kwargs)
ax[1, 0].hist(np.array(np.cos(phsp["theta_01"])), **kwargs)
ax[1, 1].hist(np.array(phsp["phi_01"]), **kwargs)
# Nothing is drawn on the sixth panel, so it is removed from the figure (and
# therefore gets no axis label either).
ax[1, 2].remove()

ax[0, 2].set_xlim(1, 3.4)
ax[1, 0].set_xlim(-1, +1)
ax[1, 1].set_xlim(-np.pi, +np.pi)

fig.tight_layout()

plt.show()