In [8]:
import os
from pathlib import Path
import pickle
from statkit.non_parametric import bootstrap_score

# Point ARTEFACT_DIR at the directory holding this experiment's artefacts.
# NOTE(review): this is assigned before the project imports below, presumably
# because one of them reads ARTEFACT_DIR at import time — confirm.
# The absolute path is machine-specific; adjust it when running elsewhere.
os.environ['ARTEFACT_DIR'] = '/home/hylke/digits_more_topics'


import jax
import jax.numpy as jnp
from jax.tree_util import tree_map
from mubelnet.random import PRNGSequence
from mubelnet.utils import perplexity
import numpy as np
from sklearn.decomposition import NMF

from dataset import load_digits


# --- Reproducibility -------------------------------------------------------
RANDOM_SEED = 42
# Key sequence seeded with RANDOM_SEED.
key_seq = PRNGSequence(jax.random.PRNGKey(RANDOM_SEED))


# --- Model hyperparameters -------------------------------------------------
MODEL = "multinomial_belief"
n_topics = 10
# Three hidden layers, each `n_topics` wide.
HIDDEN_LAYER_SIZES = (n_topics,) * 3
GAMMA_0 = 1.0
ETA = 0.05
# Key under which the bottom layer is stored in the params/state dictionaries;
# it depends on which model variant is selected.
if MODEL == "multinomial_belief":
    _bottom_layer_name = f"{MODEL}/~/multinomial_layer"
else:
    _bottom_layer_name = f"{MODEL}/~/poisson_layer"

# --- Data ------------------------------------------------------------------
X_train, X_test = load_digits()
n_features = X_train.shape[1]

def probability(params, state):
    """Return row-normalised probabilities reconstructed from the bottom layer.

    `phi` is looked up in the bottom-layer entry of `params` first and falls
    back to the bottom-layer entry of `state`; `theta` is always read from the
    state entry "copy[theta(1)]". The product `theta @ phi` is rescaled so
    that it sums to one along the last axis.
    """
    layer_params = params.get(_bottom_layer_name, {})
    layer_state = state[_bottom_layer_name]

    # Prefer phi stored as a parameter; otherwise use the one kept in the state.
    phi_from_state = layer_state.get("phi")
    phi = layer_params.get("phi", phi_from_state)
    theta = layer_state["copy[theta(1)]"]

    mixture = theta @ phi
    normaliser = mixture.sum(axis=-1, keepdims=True)
    mixture /= normaliser
    return mixture


def evaluate(params, states, X, axis=(0, 1)):
    """Compute the perplexity of `X` under the averaged model probabilities.

    Args:
        params: Parameter dictionary forwarded to `probability`.
        states: State dictionary forwarded to `probability`.
        X: Observations passed to `perplexity` together with the averaged
            probabilities.
        axis: Axis or axes to average the probabilities over. The default
            `(0, 1)` averages over chains and samples. An int, `None`, or any
            sequence of ints (list or tuple) is accepted.

    Returns:
        Whatever `perplexity(X, probs)` returns for the averaged probabilities.
    """
    # The default is an immutable tuple rather than a shared list. Sequences
    # are normalised to a tuple because NumPy's `mean` rejects list axes.
    if axis is None or isinstance(axis, int):
        reduce_axes = axis
    else:
        reduce_axes = tuple(axis)
    probs = probability(params, states).mean(axis=reduce_axes)
    return perplexity(X, probs)

In [2]:
def load_trace(pkl_path: Path, thin: int = 1):
    """Load a pickled trace and keep every `thin`-th entry along axis 1.

    Args:
        pkl_path: Path of the pickle file containing the trace (a pytree of
            arrays).
        thin: Step applied along the second axis of every leaf
            (`x[:, ::thin]`). `1` keeps everything.
            NOTE(review): the second axis is presumably the sample axis —
            confirm against the code that wrote the trace.

    Returns:
        A pytree with the same structure as the pickled trace, with every
        leaf thinned.

    Raises:
        ValueError: If `thin` is smaller than 1.
    """
    if thin < 1:
        # A zero step is invalid and a negative step would silently reverse
        # the trace, so both are rejected up front.
        raise ValueError(f"thin must be a positive integer, got {thin}")

    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # load traces that come from a trusted source.
    with open(pkl_path, 'rb') as fi:
        trace = pickle.load(fi)
    return tree_map(lambda x: x[:, ::thin], trace)

In [3]:
# Held-out perplexity for the trace stored under the "10-20-30" run directory.
# The path is built from ARTEFACT_DIR and MODEL (both set in the first cell)
# rather than repeating the absolute, machine-specific location.
trace_path_10_20_30 = (
    Path(os.environ['ARTEFACT_DIR'])
    / 'digits_more_topics' / MODEL / '10-20-30' / 'samples' / 'sample_1000.pkl'
)
trace_10_20_30 = load_trace(trace_path_10_20_30)
# Empty params: `probability` then takes phi from the trace state. The result
# is averaged over the first two axes (chains and samples).
probs = probability({}, trace_10_20_30).mean(axis=(0, 1))
# The trace is no longer needed once the averaged probabilities exist.
del trace_10_20_30
pp_estimate_1 = bootstrap_score(X_test, probs, metric=perplexity, random_state=42)
print("Perplexity on unseen observations", pp_estimate_1)
print(pp_estimate_1.latex())

Perplexity on unseen observations 3.07e+01 (95 % CI: 3.05e+01-3.08e+01)
3.07$^{+0.01}_{-0.01} \cdot 10^{1}$


In [4]:
# Held-out perplexity for the trace stored under the "20-30" run directory.
# The path is built from ARTEFACT_DIR and MODEL (both set in the first cell)
# rather than repeating the absolute, machine-specific location.
trace_path_20_30 = (
    Path(os.environ['ARTEFACT_DIR'])
    / 'digits_more_topics' / MODEL / '20-30' / 'samples' / 'sample_1000.pkl'
)
trace_20_30 = load_trace(trace_path_20_30)
# Empty params: `probability` then takes phi from the trace state. The result
# is averaged over the first two axes (chains and samples).
probs = probability({}, trace_20_30).mean(axis=(0, 1))
# The trace is no longer needed once the averaged probabilities exist.
del trace_20_30
pp_estimate_2 = bootstrap_score(X_test, probs, metric=perplexity, random_state=42)
print("Perplexity on unseen observations", pp_estimate_2)
print(pp_estimate_2.latex())

Perplexity on unseen observations 3.07e+01 (95 % CI: 3.06e+01-3.08e+01)
3.07$^{+0.01}_{-0.01} \cdot 10^{1}$


In [5]:
# Held-out perplexity for the trace stored under the "30" run directory.
# The path is built from ARTEFACT_DIR and MODEL (both set in the first cell)
# rather than repeating the absolute, machine-specific location.
trace_path_30 = (
    Path(os.environ['ARTEFACT_DIR'])
    / 'digits_more_topics' / MODEL / '30' / 'samples' / 'sample_1000.pkl'
)
trace_30 = load_trace(trace_path_30)
# Empty params: `probability` then takes phi from the trace state. The result
# is averaged over the first two axes (chains and samples).
probs = probability({}, trace_30).mean(axis=(0, 1))
# The trace is no longer needed once the averaged probabilities exist.
del trace_30
pp_estimate_3 = bootstrap_score(X_test, probs, metric=perplexity, random_state=42)
print("Perplexity on unseen observations", pp_estimate_3)
print(pp_estimate_3.latex())

Perplexity on unseen observations 3.10e+01 (95 % CI: 3.08e+01-3.11e+01)
3.10$^{+0.01}_{-0.01} \cdot 10^{1}$


In [6]:
# NMF baseline: 30 components, Kullback-Leibler loss, multiplicative-update
# solver, fitted on the training counts.
nmf = NMF(
    n_components=30,
    solver='mu',
    beta_loss="kullback-leibler",
    max_iter=10_000,
    random_state=42,
)
nmf.fit(X_train)

# Project the training data onto the fitted components, reconstruct, and
# rescale each row so that it sums to one.
h = nmf.transform(X_train)
unnormed_probs = h @ nmf.components_
row_totals = unnormed_probs.sum(axis=-1, keepdims=True)
probs_nmf = unnormed_probs / row_totals

In [9]:
# Compute performance on test set.
# 1) Flag entries where NMF assigns zero probability (log-probability of -inf)
#    to a feature that is actually observed in the test data; such entries
#    would make the perplexity infinite.
ln_probs_nmf = jnp.log(probs_nmf)
is_inf = np.isneginf(ln_probs_nmf) & (X_test > 0)

# 2) Delete the affected rows. `np.where` yields one index per offending
#    entry, so a row with several such entries would be listed repeatedly;
#    `np.unique` reports each row once.
zero_rows = np.unique(np.where(is_inf)[0])
print('Deleting rows', zero_rows)
X_test_subset = np.delete(X_test, zero_rows, axis=0)
probs_nmf_subset = np.delete(probs_nmf, zero_rows, axis=0)
# Legacy name kept so that any cell still referring to it keeps working.
probs_sigprof_xtrct_subset = probs_nmf_subset

# NOTE: this score is computed on a reduced test set, whereas the belief
# network scores above use every row of X_test.
pp_nmf = bootstrap_score(X_test_subset, probs_nmf_subset, metric=perplexity, random_state=43)
print('Perplexity on unseen observations (removed samples causing infinite perplexity)', pp_nmf)
print(pp_nmf.latex())


Deleting rows [  87  502 1264 1685]
Perplexity on unseen observations (removed samples causing infinite perplexity) 3.42e+01 (95 % CI: 3.38e+01-3.45e+01)
3.42$^{+0.03}_{-0.03} \cdot 10^{1}$
