# Inference Using the Expectation-Maximization Algorithm

In this notebook, we demonstrate how to train mixture models of lymphatic progression. We do this on a simple synthetic dataset and check whether, and how well, we can recover the parameters used to generate it.

## Imports

In [None]:
from collections import namedtuple
from typing import Literal, Any, Callable

import numpy as np
import pandas as pd

from lymixture import LymphMixture
from lymixture import utils
from lymph.models import Unilateral

# Single source of randomness for the whole notebook, seeded for reproducibility.
SEED = 42
rng = np.random.default_rng(SEED)

## Synthetic Data

We define the parameters and configuration used to draw synthetic patient data.

In [None]:
# A diagnostic modality, characterized by its specificity and sensitivity.
Modality = namedtuple("Modality", "spec sens")

# Directed acyclic graph of the lymphatic network: the tumor spreads to the
# lymph node levels II and III, and level II additionally drains into III.
GRAPH_DICT = {
    ("tumor", "T"): ["II", "III"],
    ("lnl", "II"): ["III"],
    ("lnl", "III"): [],
}

# The only diagnostic modality through which the synthetic patients are observed.
MODALITIES = dict(path=Modality(0.9, 0.9))

# Distributions over the time to diagnosis, per T-stage: a fixed binomial over
# the time steps 0..10 for "early", and for "late" a function from `utils`
# (presumably parametrized by `late_p` below -- confirm in lymixture.utils).
DISTRIBUTIONS = dict(
    early=utils.binom_pmf(k=np.arange(11), n=10, p=0.3),
    late=utils.late_binomial,
)

# Ground-truth parameters of component 1: mostly direct spread T -> III.
PARAMS_C1 = dict(
    TtoII_spread=0.05,
    TtoIII_spread=0.25,
    IItoIII_spread=0.5,
    late_p=0.5,
)
# Ground-truth parameters of component 2: mostly direct spread T -> II.
PARAMS_C2 = dict(
    TtoII_spread=0.25,
    TtoIII_spread=0.05,
    IItoIII_spread=0.1,
    late_p=0.5,
)

# Column of the patient table that stores the subsite label ("c1", "c2", "c3").
SUBSITE_COL = ("tumor", "1", "subsite")

In [None]:
# Maps a modality name (e.g. "path") to its `Modality(spec, sens)` namedtuple.
# This matches how `create_model` consumes it: it reads `.spec` and `.sens`.
ModalityDict = dict[str, Modality]


def create_model(
    model_kwargs: dict[str, Any] | None = None,
    modalities: ModalityDict | None = None,
    distributions: dict[str, list[float] | Callable] | None = None,
) -> Unilateral:
    """Create a model to draw patients from.

    Args:
        model_kwargs: Keyword arguments for the `Unilateral` constructor. Any
            falsy value (`None` or `{}`) falls back to `{"graph_dict": GRAPH_DICT}`.
        modalities: Diagnostic modalities to set on the model. Any falsy value
            falls back to `MODALITIES`.
        distributions: Distributions over diagnosis times, keyed by T-stage.
            Either explicit PMFs or callables. Any falsy value falls back to
            `DISTRIBUTIONS`.

    Returns:
        A `Unilateral` model with the modalities and distributions set.
    """
    model = Unilateral(**(model_kwargs or {"graph_dict": GRAPH_DICT}))

    for name, modality in (modalities or MODALITIES).items():
        model.set_modality(name, modality.spec, modality.sens)

    for t_stage, dist in (distributions or DISTRIBUTIONS).items():
        model.set_distribution(t_stage, dist)

    return model

In [None]:
def draw_datasets(
    model: Unilateral,
    num_c1: int,
    num_c2: int,
    num_c3: int,
    tstage_ratio: float,
    mix: float,
    rng: np.random.Generator,
) -> pd.DataFrame:
    """Draw patients for the three datasets.

    Patients of subsite "c1" are drawn with `PARAMS_C1`, those of "c2" with
    `PARAMS_C2`. The `num_c3` patients of subsite "c3" are a mixture: a
    fraction `mix` of them is drawn with `PARAMS_C1`, the rest with `PARAMS_C2`.

    Args:
        model: Model used for drawing. Its parameters are overwritten; after
            the call they are set to `PARAMS_C2`.
        num_c1: Number of "c1" patients.
        num_c2: Number of "c2" patients.
        num_c3: Number of "c3" patients (always exactly this many).
        tstage_ratio: Probability of the first entry in `stage_dist`. That is
            presumably the first T-stage set on the model ("early") -- confirm.
        mix: Fraction of "c3" patients drawn from component 1.
        rng: Random number generator used for all draws.

    Returns:
        All patients concatenated (c1, then c2, then c3), with a fresh index and
        the subsite label stored in `SUBSITE_COL`.
    """
    # Split c3 so that both parts always add up to exactly `num_c3`. Truncating
    # `num_c3 * mix` and `num_c3 * (1 - mix)` separately can lose a patient,
    # e.g. num_c3=3, mix=0.5 -> 1 + 1 = 2.
    num_c3_from_c1 = round(num_c3 * mix)
    num_c3_from_c2 = num_c3 - num_c3_from_c1
    stage_dist = [tstage_ratio, 1 - tstage_ratio]

    model.set_params(**PARAMS_C1)
    c1_pool = model.draw_patients(
        num=num_c1 + num_c3_from_c1,
        stage_dist=stage_dist,
        rng=rng,
    )
    model.set_params(**PARAMS_C2)
    c2_pool = model.draw_patients(
        num=num_c2 + num_c3_from_c2,
        stage_dist=stage_dist,
        rng=rng,
    )

    # The surplus patients of each pool form the mixed "c3" subsite.
    c3_data = pd.concat(
        [
            c1_pool.iloc[num_c1:],
            c2_pool.iloc[num_c2:],
        ],
        ignore_index=True,
        axis=0,
    )
    # `.copy()` makes the column assignments below act on independent frames
    # rather than on slices of the pools (avoids SettingWithCopy issues).
    c1_data = c1_pool.iloc[:num_c1].copy()
    c2_data = c2_pool.iloc[:num_c2].copy()

    c1_data[SUBSITE_COL] = "c1"
    c2_data[SUBSITE_COL] = "c2"
    c3_data[SUBSITE_COL] = "c3"

    return pd.concat([c1_data, c2_data, c3_data], ignore_index=True, axis=0)

In [None]:
# 1000 patients per subsite; 40% early T-stage; "c3" is an even mix of both components.
model = create_model()
synthetic_data = draw_datasets(
    model=model,
    num_c1=1000,
    num_c2=1000,
    num_c3=1000,
    tstage_ratio=0.4,
    mix=0.5,
    rng=rng,
)

# Peek at a few randomly chosen patients.
sample_positions = rng.choice(synthetic_data.index, size=6, replace=False)
synthetic_data.iloc[sample_positions]

## Model Initialization

Now, we define the mixture model and load the just drawn data. Note that we use only two components, hoping that the `"c3"` subgroup can be described as a mixture of these two components.

In [None]:
# Reuse the graph and the subsite column from the data generation instead of
# re-typing them, so the fitted model is guaranteed to have the same structure
# as the one that generated the data.
num_components = 2

mixture = LymphMixture(
    model_cls=Unilateral,
    model_kwargs={"graph_dict": GRAPH_DICT},
    num_components=num_components,
    universal_p=False,
)
mixture.load_patient_data(
    synthetic_data,
    split_by=SUBSITE_COL,
    mapping=lambda x: x,  # identity: keep the subsite labels "c1", "c2", "c3" as they are
)

Set the diagnostic modality to be the same as in the generated dataset.

In [None]:
# Modality namedtuples unpack in field order: (spec, sens).
for modality_name, (spec, sens) in MODALITIES.items():
    mixture.set_modality(name=modality_name, spec=spec, sens=sens)

Fix the distribution over diagnosis times. Again, we set this to be the same as during the synthetic data generation.

In [None]:
# Same diagnosis-time distributions per T-stage as used to generate the data.
for stage, time_prior in DISTRIBUTIONS.items():
    mixture.set_distribution(stage, time_prior)

## Inference

In [None]:
from lymixture.em import expectation, maximization

The iterative steps of computing the expectation over the latent variables (E-step) and maximizing the model parameters (M-step) can be initialized with an arbitrary set of starting parameters.

In [None]:
# Random starting values for every parameter. This also randomizes the mixture
# coefficients, which `normalize_mixture_coefs` then rescales into valid weights.
params = {k: rng.uniform() for k in mixture.get_params()}
mixture.set_params(**params)
mixture.normalize_mixture_coefs()

# Random initial responsibilities, normalized along axis 1 (the same axis along
# which the E-step's responsibilities sum to one). These are plain probabilities,
# NOT log-probabilities, hence the name. They are not passed to the mixture here:
# the first E-step below recomputes the responsibilities from `params`.
init_resps = utils.normalize(rng.uniform(size=mixture.get_resps().shape).T, axis=0).T

Then we define a function to check the convergence of the algorithm.

In [None]:
def is_converged(
    history: list[dict[str, float]],
    rtol: float = 1e-4,
    atol: float = 1e-8,
) -> bool:
    """Check if the EM algorithm has converged.

    Convergence is declared when the likelihoods of the last two snapshots
    (their "llh" entries) agree within the tolerances of `np.isclose`.

    Args:
        history: Snapshots recorded after each iteration. Each must contain an
            "llh" key.
        rtol: Relative tolerance passed to `np.isclose`.
        atol: Absolute tolerance passed to `np.isclose` (defaults to its own
            default of 1e-8).

    Returns:
        `False` if fewer than two snapshots exist. Otherwise, whether the last two
        likelihoods are close. Always a plain Python `bool`.
    """
    if len(history) < 2:
        return False

    old, new = history[-2]["llh"], history[-1]["llh"]
    # np.isclose returns np.bool_; convert to honor the annotated return type.
    return bool(np.isclose(old, new, rtol=rtol, atol=atol))

Finally, we alternate between computing the expectation of the latent variables (E-step) and maximizing the expected complete-data log-likelihood w.r.t. the model parameters (M-step).

While the algorithm runs, we print the incomplete-data likelihood after each round to monitor its convergence.

In [None]:
# Safety net: stop after this many iterations even if the likelihood has not
# converged, so a non-converging run cannot hang the notebook.
MAX_ITERATIONS = 1000

count = 0
snapshot = {
    "llh": mixture.incomplete_data_likelihood(),
    **mixture.get_params(as_dict=True, as_flat=True),
}
history = [snapshot]

while not is_converged(history, rtol=1e-4):
    if count >= MAX_ITERATIONS:
        print(f"stopping: no convergence after {MAX_ITERATIONS} iterations")
        break

    print(f"iteration {count:>3d}: {history[-1]['llh']:.3f}")
    count += 1

    # E-step: log-responsibilities computed from the current parameters.
    log_resps = expectation(mixture, params, log=True)
    mixture.set_resps(np.exp(log_resps))
    # Sanity check: the responsibilities sum to one along axis 1.
    assert np.allclose(np.exp(np.logaddexp.reduce(log_resps, axis=1)), 1.0)
    # M-step: update the parameters given the responsibilities.
    params = maximization(mixture, log_resps)

    snapshot = {
        "llh": mixture.incomplete_data_likelihood(),
        **mixture.get_params(as_dict=True, as_flat=True),
    }
    history.append(snapshot)

# The loop only prints the state *before* each iteration; report the final one too.
print(f"iteration {count:>3d}: {history[-1]['llh']:.3f} (final)")

## Results

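After convergence, we look at how the likelihood and the parameters evolved over the iterations. No EM iteration should decrease the incomplete-data likelihood, so the curve should increase monotonically. A drop would hint at a problem, e.g. an inexact numerical maximization in the M-step.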

In [None]:
# Top panel: likelihood. Bottom panel: T->II spread of both components.
history_df = pd.DataFrame(history)
spread_cols = ("0_TtoII_spread", "1_TtoII_spread")
history_df.plot(
    y=["llh", *spread_cols],
    subplots=[("llh",), spread_cols],
    sharex=True,
    xlim=(0, None),
);

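More importantly, let's check whether the learned parameters reproduce the ones we used to generate the data. Keep in mind that component labels are arbitrary: the mixture may learn our second component as component `0` and vice versa (label switching). In that case, the fixed and learned values below appear swapped between the `0_` and `1_` parameters.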

In [None]:
# Ground truth, keyed like the mixture's flat params: "<component index>_<param name>".
fixed_params = {
    f"{idx}_{name}": value
    for idx, component_params in enumerate([PARAMS_C1, PARAMS_C2])
    for name, value in component_params.items()
}
learned_params = mixture.get_params(as_dict=True, as_flat=True)

for name, fixed in fixed_params.items():
    learned = learned_params[name]
    print(f"{name:>16s}: {fixed = :.3f}, {learned = :.3f}")