# EBMT dataset: prediction task

In [1]:
import jax
jax.config.update('jax_platform_name', 'cpu')

import collections
import csv
import matplotlib.pyplot as plt
import numpy as np

from mixmarkov import CTMC, GamMixCTMC, FiniteMixCTMC, summarize_sequences
from mixmarkov.utils import draw_chain

## Loading & processing the data

In [2]:
# Every level of a categorical covariate is encoded as a fixed random
# 5-dimensional vector, drawn once from a seeded generator so that the
# encoding is reproducible.
rng = np.random.default_rng(seed=0)

# Transplant-year bands -> embedding vectors.
year_levels = ("1985-1989", "1990-1994", "1995-1998")
vecs_year = rng.normal(size=(3, 5))
year_map = {level: vec for level, vec in zip(year_levels, vecs_year)}

# Age classes -> embedding vectors (drawn after the year vectors).
agecl_levels = ("<=20", "20-40", ">40")
vecs_agecl = rng.normal(size=(3, 5))
agecl_map = {level: vec for level, vec in zip(agecl_levels, vecs_agecl)}

In [3]:
seqs = list()
feats = list()


def _finalize_sequence(seq, last_stop):
    """Close a sequence before it is stored.

    If the last recorded point of `seq` is earlier than `last_stop`, append a
    final `(same state, last_stop)` point so that the sequence extends to the
    last `Tstop` read for that id. The list is modified in place and returned.
    """
    if seq[-1][1] < last_stop:
        seq.append((seq[-1][0], last_stop))
    return seq


# Rows are grouped by "id"; each row describes one candidate transition
# from -> to over [Tstart, Tstop]. One sequence of (state, time) points and
# one feature vector are built per id.
# newline="" is the mode the csv module expects for files it reads.
with open("../data/ebmt.dat", newline="") as f:
    cur = None       # id of the sequence currently being built
    seq = None       # (state, time) points of the current sequence
    max_stop = None  # Tstop of the most recently read row
    for row in csv.DictReader(f, delimiter=" "):
        # Shift state labels down by one (presumably 1-based in the file).
        src = int(row["from"]) - 1
        dst = int(row["to"]) - 1
        if row["id"] != cur:
            # A new id starts: store the previous sequence, if any.
            if cur is not None:
                seqs.append(_finalize_sequence(seq, max_stop))
            cur = row["id"]
            seq = [(src, float(row["Tstart"]))]
            # Features: year vector, age-class vector, plus 5 random values
            # drawn per sequence. The "match" and "proph" columns are not used.
            feats.append(np.concatenate((
                year_map[row["year"]],
                agecl_map[row["agecl"]],
                rng.normal(size=(5,)),
            )))
        # Only rows with status "1" add a point to the sequence.
        if row["status"] == "1":
            seq.append((dst, float(row["Tstop"])))
        max_stop = float(row["Tstop"])
    # The last sequence in the file must be closed exactly like the others;
    # the guard also avoids a NameError when the file has no data rows.
    if cur is not None:
        seqs.append(_finalize_sequence(seq, max_stop))

# Boolean mask of the sequences whose last recorded time exceeds 1800 days
# (the same value as the prediction horizon `tf` used below).
idx = [seq[-1][1] > 1800 for seq in seqs]

seqs = np.array(seqs, dtype=object)[idx]
feats = np.array(feats, dtype=float)[idx]

In [4]:
# Problem dimensions and time horizons.
m = len(seqs)  # number of sequences
n = 6          # number of states
t0 = 60        # time at which the starting state is observed
tf = 1800      # time at which the state is to be predicted

# Shuffle sequences and features jointly, with a fixed seed.
rng = np.random.default_rng(seed=1)
idx = rng.permutation(m)

seqs = seqs[idx]
xs = feats[idx]

# Per-sequence summaries over the full sequences, and over the part of each
# sequence before t0 (the second element of the split result is discarded).
ks, ts = summarize_sequences(seqs, n)
(ks_offset, ts_offset), _ = summarize_sequences(seqs, n, split=t0)

# True for every entry of `ks` that is non-zero in at least one sequence.
mask = ks.sum(axis=0).astype(bool)

# One-hot rows: the state occupied at t0 and at tf for each sequence.
states_t0 = np.zeros((m, n))
states_tf = np.zeros((m, n))

for row, seq in enumerate(seqs):
    found_t0 = False
    found_tf = False
    last_state = None
    for state, t in seq:
        # The first point strictly after a horizon tells us that the
        # previously seen state was the one occupied at that horizon.
        if not found_t0 and t > t0:
            states_t0[row, last_state] = 1.0
            found_t0 = True
        if not found_tf and t > tf:
            states_tf[row, last_state] = 1.0
            found_tf = True
        last_state = state

# Fold boundaries for cross-validation: fold j covers zs[j]:zs[j + 1].
n_splits = 10
zs = np.linspace(0, m, num=(n_splits + 1), dtype=int)

In [5]:
def log_loss(pred, actual):
    """Compute the total log-loss of `pred` weighted by `actual`.

    Args:
        pred: array of predicted probabilities.
        actual: array broadcastable against `pred` that weights each entry
            (one-hot rows in this notebook).

    Returns:
        The sum over all entries of `-log(pred) * actual`, where each
        `-log(pred)` term is capped at 1e6 so that a zero predicted
        probability yields a large but finite penalty.
    """
    # -log(0) = inf is expected here and is capped by the clip below, so
    # numpy's divide-by-zero warning is silenced for this one operation.
    with np.errstate(divide="ignore"):
        neg_log = -np.log(pred)
    return np.sum(np.clip(neg_log, None, 1e6) * actual)

In [6]:
# Mean held-out log-loss of each model, keyed by a short model name.
res = {}

## Markov chain

In [7]:
%%time
# Plain CTMC baseline, evaluated on the cross-validation folds in `zs`.
model1 = CTMC(mask)
tot = 0.0
tmp = list()  # running total of the loss after each fold

for lo, hi in zip(zs[:-1], zs[1:]):
    # Train on everything outside the held-out range [lo, hi).
    ks_train, ts_train, xs_train = (
        np.concatenate((arr[:lo], arr[hi:])) for arr in (ks, ts, xs)
    )
    model1.fit(ks_train, ts_train, xs=xs_train, l2=3.0, verbose=False)
    # Predict over a horizon of tf - t0, starting from the states at t0.
    held_out = slice(lo, hi)
    pred = model1.predict(states_t0[held_out], xs=xs[held_out], t=(tf - t0))
    tot += log_loss(pred, states_tf[held_out])
    tmp.append(tot)

  return np.sum(np.clip(-np.log(pred), None, 1e6) * actual)


CPU times: user 2min 26s, sys: 19.5 s, total: 2min 46s
Wall time: 18.2 s


In [8]:
# Average the accumulated held-out loss over all m sequences.
ctmc_loss = tot / m
res["ctmc"] = ctmc_loss
print(ctmc_loss)

1.0431594319636326


## Infinite mixture

In [9]:
%%time
# Gamma-mixture CTMC, evaluated on the same folds, with and without the
# pre-t0 offsets passed to `predict`.
model2 = GamMixCTMC(mask)
tot_no = 0.0
tot_wo = 0.0

for lo, hi in zip(zs[:-1], zs[1:]):
    # Train on everything outside the held-out range [lo, hi).
    ks_train, ts_train, xs_train = (
        np.concatenate((arr[:lo], arr[hi:])) for arr in (ks, ts, xs)
    )
    model2.fit(ks_train, ts_train, xs=xs_train, l2=3.0, verbose=False)
    held_out = slice(lo, hi)
    # Arguments shared by both prediction variants.
    shared = dict(xs=xs[held_out], t=(tf - t0), n_samples=100)
    # No offsets.
    pred = model2.predict(states_t0[held_out], **shared).block_until_ready()
    tot_no += log_loss(pred, states_tf[held_out])
    # With offsets.
    pred = model2.predict(
        states_t0[held_out],
        offset=(ks_offset[held_out], ts_offset[held_out]),
        **shared,
    ).block_until_ready()
    tot_wo += log_loss(pred, states_tf[held_out])
    # One dot per completed fold.
    print(".", end="", flush=True)
print()

  return np.sum(np.clip(-np.log(pred), None, 1e6) * actual)


..........
CPU times: user 6min 25s, sys: 1min 49s, total: 8min 14s
Wall time: 1min 11s


In [10]:
# Mean held-out loss per sequence, without ("no") and with ("wo") offsets.
res.update({"gammix-no": tot_no / m, "gammix-wo": tot_wo / m})
for key in ("gammix-no", "gammix-wo"):
    print(res[key])

0.9072854221321921
0.6166308245249096


## Finite mixture

In [11]:
%%time
# Finite-mixture CTMC with 5 components, evaluated on the same folds, with
# and without the pre-t0 offsets passed to `predict`.
model3 = FiniteMixCTMC(mask, n_comps=5)
tot_no = 0.0
tot_wo = 0.0

for lo, hi in zip(zs[:-1], zs[1:]):
    # Train on everything outside the held-out range [lo, hi).
    ks_train, ts_train, xs_train = (
        np.concatenate((arr[:lo], arr[hi:])) for arr in (ks, ts, xs)
    )
    held_out = slice(lo, hi)
    # Divide-by-zero warnings from numpy are silenced for fit, predict and
    # the loss computation alike.
    with np.errstate(divide="ignore"):
        model3.fit(ks_train, ts_train, xs=xs_train, l2=1.5, seed=0, verbose=False)
        # Arguments shared by both prediction variants.
        shared = dict(xs=xs[held_out], t=(tf - t0))
        # No offsets.
        pred = model3.predict(states_t0[held_out], **shared).block_until_ready()
        tot_no += log_loss(pred, states_tf[held_out])
        # With offsets.
        pred = model3.predict(
            states_t0[held_out],
            offset=(ks_offset[held_out], ts_offset[held_out]),
            **shared,
        ).block_until_ready()
        tot_wo += log_loss(pred, states_tf[held_out])
    # One dot per completed fold.
    print(".", end="", flush=True)
print()

..........
CPU times: user 1h 58min 50s, sys: 2h 56min 34s, total: 4h 55min 25s
Wall time: 9min 20s


In [12]:
# Mean held-out loss per sequence, without ("no") and with ("wo") offsets.
res["finmix-no"], res["finmix-wo"] = tot_no / m, tot_wo / m
for key in ("finmix-no", "finmix-wo"):
    print(res[key])

1.6047254040055374
0.7864194385405932


## Summarizing the results

In [13]:
# Summary table of the mean held-out log-loss per model (lower is better).
# The plain CTMC has no offset variant, so its single score is shown in both
# columns. The first row is labelled "CTMC" to match the model actually
# fitted (it was previously mislabelled "DTMC").
rule = "|-----------|-----------|-----------|"
rows = [
    ("CTMC", res["ctmc"], res["ctmc"]),
    ("Inf. mix.", res["gammix-no"], res["gammix-wo"]),
    ("Fin. mix.", res["finmix-no"], res["finmix-wo"]),
]

print(rule)
print("| Model     | no offset | w/ offset |")
print(rule)
for label, no_offset, with_offset in rows:
    print("| {:<9} |     {:.3f} |     {:.3f} |".format(label, no_offset, with_offset))
print(rule)

|-----------|-----------|-----------|
| Model     | no offset | w/ offset |
|-----------|-----------|-----------|
| DTMC      |     1.043 |     1.043 |
| Inf. mix. |     0.907 |     0.617 |
| Fin. mix. |     1.605 |     0.786 |
|-----------|-----------|-----------|
