In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

In [None]:
# WARNING: it is advised to pin a specific version, e.g. tensorwaves==0.1.2.
# `-q` keeps pip's output quiet; the bracketed extras install optional
# dependencies of tensorwaves (e.g. `jax`, the backend used further down).
%pip install -q tensorwaves[doc,jax,pwa,viz] IPython

```{autolink-concat}
```

# Spin alignment

:::{note}

This page is a continuation of [ampform#213](https://ampform--213.org.readthedocs.build/en/213/usage/helicity/spin-alignment.html).

:::

## Create amplitude model

```{autolink-skip}
```

In [None]:
import logging
import warnings

# Keep the rendered notebook free of noise. The root logger's level is raised
# to ERROR, so loggers that inherit it drop lower-level records, and all
# warnings are silenced. The SVG figure format is already configured by the
# `%config` magic in the first cell of this notebook, so it is not repeated here.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

### Generate transitions

In [None]:
import qrules

# One resonance per two-body subsystem of the p K⁻ π⁺ final state:
# Λ(1405) → p K⁻, Δ(1232)⁺⁺ → p π⁺, K*(1410)⁰ → K⁻ π⁺.
resonances = [
    "Lambda(1405)",
    "Delta(1232)++",
    "K*(1410)0",
]
# Both spin projections (-1/2, +1/2) of the Λc⁺ are included.
reaction = qrules.generate_transitions(
    initial_state=("Lambda(c)+", [-0.5, +0.5]),
    final_state=["p", "K-", "pi+"],
    allowed_intermediate_particles=resonances,
    allowed_interaction_types=["strong", "EM", "weak"],
    formalism="canonical-helicity",
)

In [None]:
import graphviz

# Render the generated transitions as a graphviz diagram; `collapse_graphs=True`
# asks qrules to condense the transitions into fewer graphs.
transition_graph = graphviz.Source(
    qrules.io.asdot(reaction, collapse_graphs=True, size=5)
)
transition_graph

In [None]:
from ampform.helicity.decay import TwoBodyDecay

# Inspect the two-body decay at node 1 of every transition. Report each
# distinct decay (only once) whose parent mass is smaller than the sum of the
# masses of its decay products, i.e. decays that lie below threshold.
printed_decays = set()
for transition in reaction.transitions:
    decay = TwoBodyDecay.from_transition(transition, node_id=1)
    product_names = (child.particle.name for child in decay.children)
    description = f"{decay.parent.particle.latex} → {' '.join(product_names)}"
    if description in printed_decays:
        continue
    printed_decays.add(description)
    parent_mass = decay.parent.particle.mass
    decay_product_masses = sum(child.particle.mass for child in decay.children)
    if parent_mass < decay_product_masses:
        print(
            f"Decay {description} lies below threshold. Parent mass:"
            f" {parent_mass} < {decay_product_masses}"
        )

### Formulate amplitude model

In [None]:
import ampform
from ampform.dynamics import PhaseSpaceFactorComplex
from ampform.dynamics.builder import RelativisticBreitWignerBuilder

builder = ampform.get_builder(reaction)
builder.stable_final_state_ids = list(reaction.final_state)
builder.scalar_initial_state_mass = True

# Switches: attach Breit-Wigner dynamics at all, and if so, whether to use
# the form-factor variants.
add_dynamics = True
use_form_factor = False
if add_dynamics:
    no_ff_bw_builder = RelativisticBreitWignerBuilder()
    # Form-factor variant with complex phase-space factor, used for K resonances.
    complex_bw_builder = RelativisticBreitWignerBuilder(
        form_factor=True, phsp_factor=PhaseSpaceFactorComplex
    )
    bw_builder = RelativisticBreitWignerBuilder(form_factor=True)
    for resonance_name in reaction.get_intermediate_particles().names:
        if not use_form_factor:
            dynamics = no_ff_bw_builder
        elif resonance_name.startswith("K"):
            dynamics = complex_bw_builder
        else:
            dynamics = bw_builder
        builder.set_dynamics(resonance_name, dynamics)
model = builder.formulate()

In [None]:
# Recursively evaluate all unevaluated sub-expressions of the model
# (`deep=True` is SymPy's default for `doit`).
full_expression = model.expression.doit(deep=True)

In [None]:
# NOTE(review): `parameter_defaults` is bound without making a copy, so the
# assignment below presumably also changes the defaults stored on `model`.
# Confirm that this is intended.
parameters = model.parameter_defaults
# Fix the first parameter (in the mapping's iteration order) to exactly 1.
c_par = tuple(parameters)[0]
parameters[c_par] = 1

In [None]:
# Insert the numerical parameter values. Then turn every floating-point 1.0
# that remains in the expression into the exact integer 1.
substituted_expression = (
    full_expression
    .xreplace(parameters)
    .xreplace({1.0: 1})
)

In [None]:
import sympy as sp

# Draw the top three levels of the (unevaluated) model expression tree.
expression_tree = graphviz.Source(
    sp.dotprint(model.expression, maxdepth=3, size=6, bgcolor="none")
)
expression_tree

## Generate data

### Phase space sample

In [None]:
from tensorwaves.data import (
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

# Seeded generator, so the phase-space sample is reproducible.
rng = TFUniformRealNumberGenerator(seed=0)
# The initial state is looked up under state ID -1; final-state masses are
# keyed by their state IDs.
initial_state_mass = reaction.initial_state[-1].mass
final_state_masses = {
    state_id: particle.mass
    for state_id, particle in reaction.final_state.items()
}
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=initial_state_mass,
    final_state_masses=final_state_masses,
)
phsp_momenta = phsp_generator.generate(800_000, rng)

In [None]:
from tensorwaves.data import SympyDataTransformer

# Convert four-momenta into the model's kinematic variables (JAX backend) and
# keep only the real part of each variable.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)
phsp = {
    variable: values.real
    for variable, values in helicity_transformer(phsp_momenta).items()
}

In [None]:
import pandas as pd

# Tabulate the phase-space sample: one column per kinematic variable.
phsp_frame = pd.DataFrame.from_dict(phsp)
phsp_frame

### Compute intensities

In [None]:
from tensorwaves.function.sympy import create_function

# Lambdify the parameter-substituted intensity expression into a numerical
# function that is evaluated with the JAX backend.
fixed_intensity = create_function(substituted_expression, backend="jax")

In [None]:
import numpy as np

# Evaluate the intensity on the phase-space sample. Its real part, copied into
# a NumPy array, serves as the per-event weight in the histograms.
intensities = fixed_intensity(phsp)
phsp_weights = np.array(intensities.real)
phsp_weights

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 5))
# Shared styling for all panels; every histogram is weighted by the intensity.
kwargs = dict(
    alpha=0.7,
    bins=80,
    weights=phsp_weights,
)

for x in ax.flatten():
    x.set_yticks([])

# One weighted histogram per panel: (panel, x-axis label, values).
# Final-state IDs follow `final_state=["p", "K-", "pi+"]`:
# 0 = p, 1 = K⁻, 2 = π⁺. Hence m_02 is the invariant mass of p π⁺.
histograms = [
    ((0, 0), "$m^2(pK^-)$", phsp["m_01"] ** 2),
    ((0, 1), R"$m^2(K^-\pi^+)$", phsp["m_12"] ** 2),
    ((0, 2), R"$m^2(p\pi^+)$", phsp["m_02"] ** 2),
    ((1, 0), R"$\cos\theta(p)$", np.cos(phsp["theta_01"])),
    ((1, 1), R"$\phi(p)$", phsp["phi_01"]),
]
for position, label, values in histograms:
    ax[position].set_xlabel(label)
    ax[position].hist(np.array(values), **kwargs)

# The sixth panel is unused, so drop it (no label is needed for it).
ax[1, 2].remove()

fig.tight_layout()

plt.show()