# Testing the full EM algorithm

This notebook is a simple test on synthetic data that checks whether the EM algorithm, combined with the lymphatic progression model, works as intended.

As always, we start with some imports.

In [None]:
import numpy as np
import pandas as pd
from lymixture import LymphMixture
from lymixture.utils import binom_pmf, late_binomial, normalize
from lymph.models import Unilateral

# Seeded generator, so that every random draw made through `rng` below is
# reproducible across runs.
rng = np.random.default_rng(12345)

## Synthetic Data

The following parameters were used to generate a synthetic dataset of 3000 patients. One third used `PARAMS_C1`, another third `PARAMS_C2`, and the last third represents a 50/50 mix of the two.

```python
PARAMS_C1 = {
    "TtoII_spread": 0.5,
    "TtoIII_spread": 0.25,
    "TtoIV_spread": 0.1,
    "IItoIII_spread": 0.4,
    "IIItoIV_spread": 0.3,
    "late_p": 0.5,
}
PARAMS_C2 = {
    "TtoII_spread": 0.65,
    "TtoIII_spread": 0.15,
    "TtoIV_spread": 0.05,
    "IItoIII_spread": 0.5,
    "IIItoIV_spread": 0.4,
    "late_p": 0.5,
}
```

Below we load the synthetic dataset generated with these parameters.

In [None]:
# The CSV has three header rows, which pandas turns into a three-level
# column MultiIndex (e.g. the ("tumor", "1", "subsite") column used later).
data = pd.read_csv("data/mixture.csv", header=[0, 1, 2])
data.head(n=5)

In [None]:
# Display (number of rows, number of columns) of the loaded table; rows are
# presumably one per patient -- confirm against the dataset description above.
data.shape

## Model Initialization

In [None]:
# Lymphatic network of the model: the tumor spreads to LNLs II and III, and
# LNL II can spread further to LNL III.
# NOTE(review): the generating parameters listed above also contain LNL IV
# (TtoIV_spread, IIItoIV_spread). IV has no outgoing edges there, so leaving it
# out should not change the II/III part of the model -- confirm this is intended.
graph = {
    ("tumor", "T"): ["II", "III"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): [],
}
num_components = 2

mixture = LymphMixture(
    num_components=num_components,
    model_cls=Unilateral,
    universal_p=False,
    model_kwargs=dict(graph_dict=graph),
)
# Patients are grouped by the value in the ("tumor", "1", "subsite") column;
# the mapping passes that value through unchanged.
mixture.load_patient_data(
    data,
    mapping=lambda subsite: subsite,
    split_by=("tumor", "1", "subsite"),
)

Set the diagnostic modality to be the same as in the generated dataset.

In [None]:
# Pathology as a perfect modality: both values are 1.0 (presumably specificity
# and sensitivity -- confirm against lymph's `set_modality` signature).
# An alternative setup used a "diagnose" modality with the values (1.0, 0.81).
mixture.set_modality("path", 1., 1.)
mixture.get_all_modalities()

Fix the distribution over diagnosis times for early T-stage (T1 & T2) to be a binomial distribution over the time steps 0 to 10 with parameter $p=0.3$.

The late T-stage's diagnosis time distribution is a binomial one with a free model parameter that needs to be learned as well.

In [None]:
# Early T-stage: fixed binomial PMF with n = 10 and p = 0.3, evaluated on the
# diagnosis time steps 0, 1, ..., 10.
mixture.set_distribution("early", binom_pmf(np.arange(0, 10 + 1), 10, 0.3))
# Late T-stage: parametric binomial whose parameter is fitted with the model.
mixture.set_distribution("late", late_binomial)
mixture.get_all_distributions()

## The EM-Algorithm

Here, we initialize random model parameters and latent variables/responsibilities.

In [None]:
from lymixture.em import expectation, maximization

# Start from a random point: one uniform draw in [0, 1) per model parameter,
# loaded into the mixture, followed by rescaling of the mixture coefficients
# (presumably so they sum to one -- see `normalize_mixture_coefs`).
params = {param_name: rng.uniform() for param_name in mixture.get_params()}
mixture.set_params(**params)
mixture.normalize_mixture_coefs()

# Random initial responsibilities with the same shape as the mixture's own.
# Transposing, normalizing along axis 0 and transposing back normalizes along
# the rows of the original array.
latent = normalize(
    rng.uniform(size=mixture.get_resps().shape).T,
    axis=0,
).T

Then we define some helper functions, as well as a function to check the convergence of the algorithm.

In [None]:
def to_numpy(params: dict[str, float]) -> np.ndarray:
    """Return the values of ``params`` as a 1D float array, in dict order.

    Args:
        params: Mapping from names to scalar values.

    Returns:
        Array of shape ``(len(params),)`` and ``dtype=float``.
    """
    return np.fromiter(params.values(), dtype=float, count=len(params))


def is_converged(
    history: list[dict[str, float]],
    rtol: float = 1e-4,
) -> bool:
    """Tell whether the last two snapshots in ``history`` are numerically close.

    Values are matched by key (not by position) and compared with
    :func:`numpy.allclose` using the given relative tolerance.

    Args:
        history: Chronological list of snapshots, each mapping a name (e.g.
            ``"llh"`` or a flat parameter name) to a scalar value.
        rtol: Relative tolerance forwarded to :func:`numpy.allclose`.

    Returns:
        ``False`` while fewer than two snapshots exist; otherwise whether all
        values of the last two snapshots are close.

    Raises:
        ValueError: If the last two snapshots do not share the same keys.
    """
    if len(history) < 2:
        return False

    old, new = history[-2], history[-1]
    if old.keys() != new.keys():
        raise ValueError(
            "cannot compare snapshots with different keys: "
            f"{sorted(old.keys() ^ new.keys())}"
        )

    # Align the older snapshot on the newer one's key order, so a difference
    # in insertion order can never pair up two unrelated quantities.
    aligned_old = {key: old[key] for key in new}
    return np.allclose(to_numpy(aligned_old), to_numpy(new), rtol=rtol)

Iterate the computation of the expectation value of the latent variables (E-step) and the maximization of the (complete) data log-likelihood w.r.t. the model parameters (M-step).

In [None]:
# Hard upper bound on EM iterations, so that a run which never reaches the
# convergence tolerance still terminates.
max_iter = 1000


def take_snapshot(model) -> dict[str, float]:
    """Record the model's current "llh" value and its flat parameter dict."""
    return {
        "llh": model.incomplete_data_likelihood(),
        **model.get_params(as_dict=True, as_flat=True),
    }


count = 0
snapshot = take_snapshot(mixture)
history = [snapshot]

while not is_converged(history, rtol=1e-2):
    if count >= max_iter:
        print(f"stopping: no convergence after {max_iter} iterations")
        break
    print(f"iteration {count:>3d}: {history[-1]['llh']:.3f}")
    count += 1

    # E-step: latent variables (responsibilities) given the current params;
    # each row must sum to one.
    latent = expectation(mixture, params)
    assert np.allclose(latent.sum(axis=1), 1.)
    # M-step: update the model parameters given the responsibilities.
    params = maximization(mixture, latent)

    snapshot = take_snapshot(mixture)
    history.append(snapshot)

# Report the value reached after the last completed iteration.
print(f"iteration {count:>3d}: {history[-1]['llh']:.3f}")

## Results

Let's have a look at the convergence and the parameters.

In [None]:
# One row per recorded snapshot; row 0 is the random initialization.
history_df = pd.DataFrame(history)

# Top panel: the "llh" trace; bottom panel: the T -> II spread parameter of
# both mixture components.
history_df.plot(
    y=["llh", "0_TtoII_spread", "1_TtoII_spread"],
    subplots=[("llh",), ("0_TtoII_spread", "1_TtoII_spread")],
    xlim=(0, None),
    sharex=True,
);

In [None]:
# Display the mixture's current parameters as a flat dict (the same call that
# produces the parameter entries of each EM history snapshot).
mixture.get_params(as_dict=True, as_flat=True)

## Sample model parameters

In [None]:
# Import the function under an alias: the result below is bound to the name
# `complete_samples`, which would otherwise shadow the imported function.
from lymixture.em import (
    complete_samples as draw_complete_samples,
    sample_model_params,
)

samples = sample_model_params(mixture, steps=20)

# Draw the subset with the notebook's seeded generator (not the global
# `np.random` state) so it is reproducible. Never request more draws than
# there are samples, which `choice(..., replace=False)` would reject.
num_subset = min(50, len(samples))
indices = rng.choice(len(samples), size=num_subset, replace=False)
reduced_set = samples[indices]
complete_samples = draw_complete_samples(mixture, reduced_set)