diff --git a/workflows/benchmark.smk b/workflows/benchmark.smk index 391b5e1..8735bff 100644 --- a/workflows/benchmark.smk +++ b/workflows/benchmark.smk @@ -10,7 +10,7 @@ import matplotlib.transforms as mtransforms matplotlib.use("Agg") import numpy as np - +from numpyro.diagnostics import summary import labelshift.algorithms.api as algo import labelshift.experiments.api as exp @@ -22,7 +22,7 @@ ESTIMATORS = { "BBS": algo.BlackBoxShiftEstimator(), "CC": algo.ClassifyAndCount(), "RIR": algo.InvariantRatioEstimator(restricted=True), - "BAY": algo.DiscreteCategoricalMeanEstimator(), + "BAY": algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)), } ESTIMATOR_COLORS = { "BBS": "orangered", @@ -146,6 +146,7 @@ _k_vals = [2, 3, 5, 7, 9] _quality = [0.55, 0.65, 0.75, 0.85, 0.95] _quality_prime = [0.45, 0.55, 0.65, 0.75, 0.80, 0.85, 0.90, 0.95] + BENCHMARKS = { "change_prevalence": BenchmarkSettings( param_name="Prevalence $\\pi'_1$", @@ -179,6 +180,7 @@ BENCHMARKS = { ), } + def get_data_setting(benchmark: str, param: int | str) -> DataSetting: return BENCHMARKS[str(benchmark)].settings[int(param)] @@ -234,6 +236,14 @@ rule apply_estimator: elapsed_time = timer.check() run_ok = True additional_info = {} + + if hasattr(estimator, "get_mcmc"): + samples = estimator.get_mcmc().get_samples(group_by_chain=True) + summ = summary(samples) + n_eff_list = [np.min(d["n_eff"]) for d in summ.values()] + r_hat_list = [np.max(d["r_hat"]) for d in summ.values()] + additional_info = additional_info | {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)} + except Exception as e: elapsed_time = float("nan") estimate = np.full_like(data.n_y_labeled, fill_value=float("nan")) @@ -267,9 +277,13 @@ def _get_paths_to_be_assembled(wildcards): rule assemble_results: output: csv = "results/benchmark-{benchmark}-metric-{metric}.csv", - err = "results/status/benchmark-{benchmark}-metric-{metric}.txt" + err = "results/status/benchmark-{benchmark}-metric-{metric}.txt", + convergence = "results/convergence/benchmark-{benchmark}-metric-{metric}.txt", input: _get_paths_to_be_assembled run: + max_r_hat = -1e9 + min_n_eff = 1e9 + results = [] for pth in input: res = joblib.load(pth) @@ -285,6 +299,10 @@ rule assemble_results: } results.append(nice) + if "max_r_hat" in res.additional_info: + max_r_hat = max(max_r_hat, res.additional_info["max_r_hat"]) + min_n_eff = min(min_n_eff, res.additional_info["min_n_eff"]) + results = pd.DataFrame(results) df_ok = results[results["run_ok"]] @@ -298,6 +316,9 @@ rule assemble_results: df_ok = df_ok.drop(columns=["run_ok", "additional_info"]) df_ok.to_csv(str(output.csv), index=False) + with open(output.convergence, "w") as f: + f.write(f"Max r_hat: {max_r_hat}\n") + f.write(f"Min n_eff: {min_n_eff}\n") def plot_results(ax, df, plot_std: bool = True, alpha: float = 0.5): diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk index 2e79b54..b85f20a 100644 --- a/workflows/misspecified.smk +++ b/workflows/misspecified.smk @@ -3,7 +3,7 @@ # ---------------------------------------------------------------------------------- from dataclasses import dataclass import numpy as np -import pandas as pd +import json import joblib import matplotlib @@ -14,7 +14,7 @@ matplotlib.use("agg") import jax import numpyro import numpyro.distributions as dist - +from numpyro.diagnostics import summary import labelshift.algorithms.api as algo from labelshift.datasets.discrete_categorical import SummaryStatistic @@ -57,18 +57,18 @@ N_POINTS = [100, 1000, 10_000] PI_LABELED = 0.5 PI_UNLABELED = 0.2 -N_MCMC_WARMUP = 500 -N_MCMC_SAMPLES = 1000 +N_MCMC_WARMUP = 1500 +N_MCMC_SAMPLES = 2000 +N_MCMC_CHAINS = 4 COVERAGES = np.arange(0.05, 0.96, 0.05) rule all: - input: expand("plots/{n_points}.pdf", n_points=N_POINTS) - -# rule all: -# input: expand("figures/{setting}-{seed}.pdf", setting=SETTINGS.keys(), seed=SEEDS) + input: + plots = expand("plots/{n_points}.pdf", n_points=N_POINTS), + convergence = "convergence_overall.json", rule generate_data: @@ -82,7 +82,7 @@ rule generate_data: def gaussian_model(observed: Data, unobserved: np.ndarray): sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2))) - mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1)) + mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3)) pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2))) @@ -101,7 +101,7 @@ def gaussian_model(observed: Data, unobserved: np.ndarray): def student_model(observed: Data, unobserved: np.ndarray): df = numpyro.sample('df', dist.Gamma(np.ones(2), np.ones(2))) sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2))) - mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1)) + mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3)) pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2))) @@ -117,36 +117,55 @@ def student_model(observed: Data, unobserved: np.ndarray): numpyro.sample('x', mixture, obs=unobserved) +def generate_summary(samples): + summ = summary(samples) + n_eff_list = [float(np.min(d["n_eff"])) for d in summ.values()] + r_hat_list = [float(np.max(d["r_hat"])) for d in summ.values()] + return {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)} + rule run_gaussian_mcmc: input: "data/{n_points}/{seed}.npy" - output: "samples/{n_points}/Gaussian/{seed}.npy" + output: + samples = "samples/{n_points}/Gaussian/{seed}.npy", + convergence = "convergence/{n_points}/Gaussian/{seed}.joblib", run: data_labeled, data_unlabeled = joblib.load(str(input)) mcmc = numpyro.infer.MCMC( numpyro.infer.NUTS(gaussian_model), num_warmup=N_MCMC_WARMUP, num_samples=N_MCMC_SAMPLES, + num_chains=N_MCMC_CHAINS, ) rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101) mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs) samples = mcmc.get_samples() - joblib.dump(samples, str(output)) + joblib.dump(samples, output.samples) + + summ = generate_summary(mcmc.get_samples(group_by_chain=True)) + joblib.dump(summ, output.convergence) rule run_student_mcmc: input: "data/{n_points}/{seed}.npy" - output: "samples/{n_points}/Student/{seed}.npy" - run: + output: + samples = "samples/{n_points}/Student/{seed}.npy", + convergence = "convergence/{n_points}/Student/{seed}.joblib", + run: data_labeled, data_unlabeled = joblib.load(str(input)) mcmc = numpyro.infer.MCMC( numpyro.infer.NUTS(student_model), num_warmup=N_MCMC_WARMUP, num_samples=N_MCMC_SAMPLES, + num_chains=N_MCMC_CHAINS, ) rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101) mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs) samples = mcmc.get_samples() - joblib.dump(samples, str(output)) + joblib.dump(samples, output.samples) + + summ = generate_summary(mcmc.get_samples(group_by_chain=True)) + joblib.dump(summ, output.convergence) + def _calculate_bins(n: int): @@ -169,15 +188,24 @@ def generate_summary_statistic( rule run_discrete_mcmc: input: "data/{n_points}/{seed}.npy" - output: "samples/{n_points}/Discrete-{n_bins}/{seed}.npy" + output: + samples = "samples/{n_points}/Discrete-{n_bins}/{seed}.npy", + convergence = "convergence/{n_points}/Discrete-{n_bins}/{seed}.joblib", run: data_labeled, data_unlabeled = joblib.load(str(input)) estimator = algo.DiscreteCategoricalMeanEstimator( seed=int(wildcards.seed) + 101, - params=algo.SamplingParams(warmup=N_MCMC_WARMUP, samples=N_MCMC_SAMPLES), + params=algo.SamplingParams( + warmup=N_MCMC_WARMUP, + samples=N_MCMC_SAMPLES, + chains=N_MCMC_CHAINS, + ), ) samples = estimator.sample_posterior(generate_summary_statistic(data_labeled, data_unlabeled.xs, int(wildcards.n_bins))) - joblib.dump(samples, str(output)) + joblib.dump(samples, output.samples) + + summ = generate_summary(estimator.get_mcmc().get_samples(group_by_chain=True)) + joblib.dump(summ, output.convergence) def calculate_hdi(arr, prob: float) -> tuple[float, float]: @@ -196,12 +224,17 @@ def calculate_hdi(arr, prob: float) -> tuple[float, float]: rule contains_ground_truth: - input: "samples/{n_points}/{algorithm}/{seed}.npy" + input: + samples = "samples/{n_points}/{algorithm}/{seed}.npy", + convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib", output: "contains/{n_points}/{algorithm}/{seed}.joblib" run: - samples = joblib.load(str(input)) + samples = joblib.load(input.samples) + convergence = joblib.load(input.convergence) + run_ok = True if convergence["max_r_hat"] < 1.02 else False + pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1] - + results = [] intervals = [] for coverage in COVERAGES: @@ -212,7 +245,7 @@ rule contains_ground_truth: results = np.asarray(results, dtype=float) intervals = np.asarray(intervals, dtype=float) - joblib.dump((results, intervals), str(output)) + joblib.dump((results, intervals, run_ok), str(output)) def _input_paths_calculate_coverages(wildcards): @@ -221,15 +254,62 @@ def _input_paths_calculate_coverages(wildcards): rule calculate_coverages: input: _input_paths_calculate_coverages - output: "coverages/{n_points}/{algorithm}.npy" + output: + coverages = "coverages/{n_points}/{algorithm}.npy", + excluded_runs = "excluded/{n_points}-{algorithm}.json" run: results = [] + + ok_runs = 0 + excluded_runs = 0 for pth in input: - res, _ = joblib.load(pth) - results.append(res) + res, _, run_ok = joblib.load(pth) + if run_ok: + results.append(res) + ok_runs += 1 + else: + excluded_runs += 1 + results = np.asarray(results) coverages = results.mean(axis=0) - np.save(str(output), coverages) + np.save(output.coverages, coverages) + + with open(output.excluded_runs, "w") as fh: + json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh) + +def _input_paths_summarize_convergence(wildcards): + return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS] + + +rule summarize_convergence: + input: _input_paths_summarize_convergence + output: "convergence/{n_points}/{algorithm}.json" + run: + min_n_effs = [] + max_r_hats = [] + for pth in input: + res = joblib.load(pth) + min_n_effs.append(res["min_n_eff"]) + max_r_hats.append(res["max_r_hat"]) + + with open(str(output), "w") as fh: + json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh) + + +rule summarize_convergence_overall: + input: expand("convergence/{n_points}/{algorithm}.json", n_points=N_POINTS, algorithm=["Gaussian", "Student", "Discrete-5", "Discrete-10"]) + output: "convergence_overall.json" + run: + min_n_effs = [] + max_r_hats = [] + for pth in input: + with open(pth) as fh: + res = json.load(fh) + min_n_effs.append(res["min_n_eff"]) + max_r_hats.append(res["max_r_hat"]) + + with open(str(output), "w") as fh: + json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh) rule plot_coverage: input: @@ -243,7 +323,7 @@ rule plot_coverage: sample_discrete10 = "samples/{n_points}/Discrete-10/1.npy", output: "plots/{n_points}.pdf" run: - fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=150, left=0.2, top=0.3, right=1.8) + fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=400, left=0.2, top=0.3, right=1.8) axs = axs.ravel() # Conditional distributions P(X|Y)