In [1]:
import os
import pickle

import numpy as np

import klnmf

In [2]:
# Benchmark configuration: dataset, update rule, and problem size.
DATASET = "mnist"
METHOD = "mu-standard1999"
N_RUNS = 20
N_TOPICS = 10

# Results for this method/dataset combination are written below this directory.
OUTPATH = f"../results/{METHOD}/{DATASET}"

# Load the data matrix and report its dimensions.
X = np.load(f"../data/{DATASET}.npy")
print(X.shape)

(784, 60000)


Each saved history from an NMF run is stored as a dictionary:

* Keys: iteration numbers
* Values: tuples of (objective-function value, runtime elapsed up to that iteration), so the runtime stored at the last iteration is the total runtime of the run

In [3]:
def get_runtime(history: dict[int, tuple[float, float]]) -> float:
    """
    Get the total runtime until convergence from the saved history.

    Parameters
    ----------
    history:
        Mapping from iteration number to a tuple whose second element is the
        runtime recorded at that iteration.

    Returns
    -------
    float
        The runtime stored under the largest iteration number.

    Raises
    ------
    ValueError
        If ``history`` is empty, since there is no final iteration to read.
    """
    if not history:
        raise ValueError("history is empty; cannot determine the runtime")
    # The largest key is the final iteration; max() avoids sorting every key.
    last_iteration = max(history)
    return history[last_iteration][1]


def benchmark(X: np.ndarray) -> list:
    """
    Fit ``N_RUNS`` independent KLNMF models on ``X`` and collect their histories.

    Each run builds a fresh model with ``N_TOPICS`` topics and the update rule
    named by ``METHOD``, and is seeded with its own run index. After every fit
    the runtime extracted by ``get_runtime`` is printed.

    Parameters
    ----------
    X:
        Data matrix handed unchanged to ``KLNMF.fit``.

    Returns
    -------
    list
        The ``history`` attribute of every fitted model, in run order.
    """
    collected = []

    for seed in np.arange(N_RUNS):
        nmf = klnmf.KLNMF(n_topics=N_TOPICS, update_method=METHOD)
        # The run index doubles as the seed, so each run is seeded differently.
        nmf.fit(X, seed=seed, verbose=1)
        print(f"run {seed} took {get_runtime(nmf.history):.3f}")
        collected.append(nmf.history)

    return collected

In [6]:
# JIT compile: a short throwaway fit before the timed runs, presumably so that
# one-time compilation cost inside klnmf is not counted in the benchmark
# (TODO confirm against klnmf).
model = klnmf.KLNMF(
    n_topics=2,
    update_method=METHOD,
    max_iterations=10,
)
model.fit(X)

In [10]:
# Actual benchmark: fit N_RUNS models on X and keep the history of every run.
histories = benchmark(X)

iteration: 1000; objective: 863824545.60
iteration: 2000; objective: 860853750.59
run 0 took 1469.084
iteration: 1000; objective: 863517940.61
run 1 took 1083.456
run 2 took 398.008
iteration: 1000; objective: 860940494.73
run 3 took 1038.672
run 4 took 598.017
run 5 took 683.997
run 6 took 641.316
iteration: 1000; objective: 861226541.61
run 7 took 906.717
run 8 took 445.267
iteration: 1000; objective: 868757473.73
iteration: 2000; objective: 861114887.41
run 9 took 1893.781
iteration: 1000; objective: 862285045.24
iteration: 2000; objective: 859295614.80
run 10 took 1417.633
iteration: 1000; objective: 868559167.25
run 11 took 707.388
run 12 took 571.312
iteration: 1000; objective: 861648660.50
run 13 took 1292.275
iteration: 1000; objective: 868346398.86
iteration: 2000; objective: 865798220.84
run 14 took 1656.396
run 15 took 575.113
run 16 took 693.856
run 17 took 350.952
iteration: 1000; objective: 860243397.39
run 18 took 966.381
iteration: 1000; objective: 863389310.07
run 19 t

In [11]:
# Create the output directory first: open(..., "wb") does not create missing
# parent directories, and failing here would come only after the benchmark ran.
os.makedirs(OUTPATH, exist_ok=True)
with open(f"{OUTPATH}/histories.pkl", "wb") as handle:
    pickle.dump(histories, handle, protocol=pickle.HIGHEST_PROTOCOL)