In [None]:
import numpy as np
SEED = 12345  # fixed seed so "Restart & Run All" reproduces every random draw below
rng = np.random.default_rng(SEED)

from lymph.models import Unilateral
from lymixture import LymphMixture
from lymixture.utils import map_to_simplex
from fixtures import (
    get_graph,
    get_patient_data,
    SIMPLE_SUBSITE,
)

In [None]:
# Build a mixture of `num_components` unilateral models on the "medium" test
# graph, load the patients split by subsite, and configure the "max_llh"
# modality with sensitivity = specificity = 1 (presumably an error-free
# diagnostic). The last line displays the resulting subgroups.
graph = get_graph(size="medium")
patient_data = get_patient_data()
num_components = 3

unilateral_kwargs = {"graph_dict": graph}
mixture = LymphMixture(
    num_components=num_components,
    model_cls=Unilateral,
    model_kwargs=unilateral_kwargs,
)
mixture.load_patient_data(patient_data, split_by=SIMPLE_SUBSITE)
mixture.set_modality("max_llh", sens=1., spec=1.)
mixture.subgroups

In [None]:
# Inspect the modalities held by the subgroup model of subsite "C05".
c05_subgroup = mixture.subgroups["C05"]
c05_subgroup.get_all_modalities()

In [None]:
# Draw one uniform point per patient in a (num_components - 1)-dimensional cube
# and map each row through `map_to_simplex` (presumably onto the probability
# simplex, yielding one responsibility vector per patient — confirm in lymixture).
resp_from_cube = rng.uniform(size=(len(patient_data), num_components - 1))
resp = np.array(list(map(map_to_simplex, resp_from_cube)))

mixture.set_resps(resp)

In [None]:
# Display the responsibilities as stored, using get_resps' default arguments.
stored_resps = mixture.get_resps()
stored_resps

In [None]:
# Set the "early" distribution to 11 evenly spaced values from 0 to 1
# (presumably one weight per time step), then show every distribution.
early_weights = np.linspace(0., 1., 11)
mixture.set_distribution("early", early_weights)
mixture.get_all_distributions()

In [None]:
# Random mixture coefficients: one row per component, one column per subgroup.
# Dividing by the column sums makes each subgroup's coefficients sum to one
# across the components; the assertion guards that invariant before handing
# the matrix to the mixture.
random_coefs = rng.uniform(size=(num_components, len(mixture.subgroups)))
random_coefs /= random_coefs.sum(axis=0)
assert np.allclose(random_coefs.sum(axis=0), 1.), "coefs must sum to 1 per subgroup"
mixture.set_mixture_coefs(random_coefs)

In [None]:
# Print the modalities of every subgroup model, one subgroup per line.
for _subsite, subgroup_model in mixture.subgroups.items():
    print(subgroup_model.get_all_modalities())

In [None]:
# Give "early" linearly rising and "late" linearly falling weights over 11
# entries. The arrays are not normalized here — presumably set_distribution
# normalizes them; TODO confirm.
# NOTE(review): early rising / late falling looks reversed relative to the usual
# intuition (late stages weighted towards later times) — confirm intent.
rising_weights = np.linspace(0, 10, 11)
falling_weights = np.linspace(10, 0, 11)
mixture.set_distribution("early", rising_weights)
mixture.set_distribution("late", falling_weights)

In [None]:
# Print the distributions held by each component model.
for component_model in mixture.components:
    print(component_model.get_all_distributions())

In [None]:
# Replace every parameter name reported by the mixture with an independent
# U(0, 1) draw, push them in, and display what the mixture reports back (it
# may differ if set_params rescales some values, e.g. coefficients — TODO confirm).
params_to_set = {name: rng.uniform() for name in mixture.get_params()}

mixture.set_params(**params_to_set)
mixture.get_params()

In [None]:
# Sum subsite "C05"'s mixture coefficient over all components, reading the flat
# parameter dict (keys of the form "<component index>_C05_coef"). The dict is
# fetched once instead of once per component. If set_params keeps the
# coefficients normalized, this should be ~1 — compare with the next cell.
current_params = mixture.get_params()
total = sum(
    (current_params[f"{c}_C05_coef"] for c, _ in enumerate(mixture.components)),
    0.,
)

total

In [None]:
# Column sums of the mixture-coefficient matrix, assuming the same
# (components x subgroups) layout that set_mixture_coefs was given above.
coef_matrix = mixture.get_mixture_coefs()
coef_matrix.sum(axis=0)

In [None]:
# Per-patient mixture likelihoods, marginalized (presumably over components)
# and not on the log scale.
mixture.patient_mixture_likelihoods(marginalize=True, log=False)

In [None]:
# Likelihood of the loaded data using likelihood()'s default arguments.
default_likelihood = mixture.likelihood()
default_likelihood

In [None]:
# Likelihood evaluated with the normalized responsibilities passed explicitly.
normalized_resps = mixture.get_resps(norm=True)
mixture.likelihood(given_resps=normalized_resps)

In [None]:
# Normalize the mixture coefficients, then sum the repeated (non-log)
# coefficients along axis 1 — presumably one row per patient, each expected
# to sum to ~1; TODO confirm the layout of repeat_mixture_coefs.
mixture.normalize_mixture_coefs()
repeated_coefs = mixture.repeat_mixture_coefs(log=False)
repeated_coefs.sum(axis=1)

In [None]:
# Use the unmarginalized, non-log patient mixture likelihoods as the new
# responsibilities, then display them normalized.
unmarginalized_likelihoods = mixture.patient_mixture_likelihoods(
    marginalize=False,
    log=False,
)
mixture.set_resps(unmarginalized_likelihoods)
mixture.get_resps(norm=True)