In [None]:
import jetnet
from jetnet.datasets import JetNet
import numpy as np
import matplotlib.pyplot as plt
import gen_metrics
from tqdm import tqdm
import pandas as pd
from IPython.display import Markdown, display
import pickle
import plotting
from typing import OrderedDict

# Global default font size for every figure in this notebook.
plt.rcParams["font.size"] = 16


In [None]:
# Automatically reload edited project modules (e.g. gen_metrics, plotting) before each
# cell executes, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Output directory for figures saved by this notebook.
plot_dir = "../plots/fgd_inf/Nov6"
# Create it portably without invoking a shell (the previous `os.system("mkdir -p ...")`
# interpolated the path into a shell command and silently ignored failures).
# No-op if the directory already exists.
os.makedirs(plot_dir, exist_ok=True)

# Directory holding the saved generated jets and precomputed real-jet EFPs.
data_dir = "../saved_dir/"

In [None]:
# Load the generated jets and compute their EFPs; real-jet EFPs are loaded precomputed.
gen_jets = np.load(f"{data_dir}/best_epoch_gen_jets-3.npy")
gen_efps = jetnet.utils.efps(gen_jets, efpset_args=[("d<=", 4)])
# NOTE(review): assumes t.npy holds EFPs computed with the same EFP set (d<=4) as above,
# so that columns correspond one-to-one — confirm.
real_efps = np.load(f"{data_dir}/t.npy")

# Real and generated EFPs are compared column-by-column below (single-EFP histograms,
# FGD), so fail early with a clear message if their layouts disagree.
assert real_efps.ndim == 2 and gen_efps.ndim == 2, "expected 2-D EFP arrays"
assert real_efps.shape[1] == gen_efps.shape[1], (
    f"EFP count mismatch: real {real_efps.shape[1]} vs gen {gen_efps.shape[1]}"
)


In [None]:
# Compare the distribution of a single EFP (column 24) between real and generated jets,
# using identical bins for both so the step histograms are directly comparable.
efp_idx = 24
efp_bins = np.linspace(0, 0.2, 101)
for label, efps in (("Real", real_efps), ("Gen", gen_efps)):
    plt.hist(efps[:, efp_idx], efp_bins, histtype='step', label=label)
plt.yscale('log')
plt.xlabel(f"EFP {efp_idx}")
plt.legend();

In [None]:
from scipy.optimize import curve_fit
from scipy.stats import linregress

def linear(x, intercept, slope):
    """Evaluate the straight line ``intercept + slope * x``.

    The ``(x, *params)`` signature is the form ``scipy.optimize.curve_fit`` expects.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    scaled = slope * x
    return intercept + scaled

In [None]:
# Stability check: repeat the 1/N extrapolation of FGD with several RNG seeds and plot
# the extrapolated value with its error bar for each seed.
numb = 5  # passed as num_batches
nump = 200  # passed as num_points
num_seeds = 10  # single source for both the seed loop and the x-axis

means_ses = []
for seed in tqdm(range(num_seeds)):
    res = gen_metrics.one_over_n_extrapolation_repeated_measurements(
        real_efps,
        gen_efps,
        min_samples=5000,
        max_samples=50_000,
        num_batches=numb,
        num_points=nump,
        seed=seed,
    )
    # res[0] / res[1] are plotted below as value / error bar (presumably FGD_inf and its
    # standard error — confirm against gen_metrics).
    means_ses.append([res[0], res[1]])

means_ses = np.array(means_ses)
plt.figure(figsize=(12, 12))
plt.errorbar(range(num_seeds), means_ses[:, 0], means_ses[:, 1], fmt="o")
plt.ylabel(r"FGD$_\infty$")
plt.xlabel("Seed")
plt.savefig(f"{plot_dir}/fgdinf_check_b{numb}_p{nump}_vb.pdf")

In [None]:
# Single repeated-measurement extrapolation (seed 0) with more batches per sample size
# than the seed scan.
numb = 10
nump = 200

res = gen_metrics.one_over_n_extrapolation_repeated_measurements(
    real_efps,
    gen_efps,
    min_samples=5000,
    max_samples=50_000,
    num_batches=numb,
    num_points=nump,
    seed=0,
)

In [None]:
# Raw FGD measurements vs. sample size N, overlaid with the fitted curve
# res[0] + res[4] / N (presumably FGD_inf + slope / N — confirm against gen_metrics).
# The grid starts at min_samples (5000) instead of 0: 1/0 produced a divide-by-zero
# warning and an infinite first point.
n_grid = np.linspace(5000, 50_000, 101)
plt.figure(figsize=(12, 12))
plt.scatter(res[2], res[3])
plt.plot(n_grid, res[0] + res[4] * (1 / n_grid), color="red")
plt.xlabel("N")
plt.ylabel("FGD");

In [None]:
# Same measurements and fit in 1/N space; the point at 1/N = 0 shows res[0] with error
# res[1] (presumably the extrapolated FGD_inf and its uncertainty).
inv_n_grid = 1 / np.linspace(5000, 1e8, 101)
plt.figure(figsize=(12, 12))
plt.errorbar(0, res[0], res[1], fmt="o")
plt.scatter(1 / res[2], res[3])
plt.plot(inv_n_grid, res[0] + res[4] * inv_n_grid, color="red")

In [None]:
# Re-display the seed scan computed earlier (one [value, error] row per seed in means_ses).
# Deliberately NOT saved here: numb/nump may have been reassigned since means_ses was
# computed (the seed scan uses numb = 5, the later single run sets numb = 10), so a
# filename built from them would mislabel the figure.
means_ses = np.array(means_ses)
plt.figure(figsize=(12, 12))
plt.errorbar(range(len(means_ses)), means_ses[:, 0], means_ses[:, 1], fmt="o")
plt.ylabel(r"FGD$_\infty$")
plt.xlabel("Seed");

In [None]:
# Single-batch 1/N extrapolation of the Frechet Gaussian distance (seed 1).
# NOTE: this rebinds `res` with a different layout from the repeated-measurements result
# used above; the following cells index it as res[0][0], res[3][:, 0], etc.
res = gen_metrics.one_over_n_extrapolation(
    real_efps,
    gen_efps,
    gen_metrics.frechet_gaussian_distance,
    min_samples=5_000,
    max_samples=25_000,
    num_batches=1,
    num_points=101,
    seed=1,
)
res

In [None]:
# FGD vs. N with error bars from res[3][:, 1], overlaid with the fitted curve
# res[0][0] + res[0][1] / N.
# The grid starts at 5_000 (min_samples of the extrapolation call) rather than 0, which
# evaluated 1/0.
n_grid = np.linspace(5_000, 50_000, 101)
plt.figure(figsize=(12, 12))
plt.errorbar(res[2], res[3][:, 0], res[3][:, 1], fmt="o")
plt.plot(n_grid, res[0][0] + res[0][1] * (1 / n_grid), color="red")
plt.xlabel("N")
plt.ylabel("FGD");

In [None]:
# 1/N view of the single-batch extrapolation: measured points, the fitted line, and the
# intercept at 1/N = 0 with its error bar res[1][0].
inv_n_grid = np.linspace(0, 0.0002, 101)
plt.figure(figsize=(12, 12))
plt.scatter(1 / res[2], res[3][:, 0])
plt.plot(inv_n_grid, res[0][0] + res[0][1] * inv_n_grid, label="Full Fit", color="red")
plt.errorbar(0, res[0][0], res[1][0], fmt="o", color="red")
plt.title("Without errors on measurements")
plt.xlabel("1/N")
plt.ylabel("FGD")
plt.legend()
# Name the file after the parameters of the one_over_n_extrapolation call that produced
# `res` (num_batches=1, num_points=101), not the stale numb/nump globals (10/200) left
# over from the earlier repeated-measurement cells.
plt.savefig(f"{plot_dir}/fgdinf_fit_b1_p101.pdf")

In [None]:
# FGD vs. N for the one_over_n_extrapolation result currently in `res`.
# Its res[3] has two columns (indexed [:, 0] / [:, 1] in the fit cells), so only column 0
# is plotted; passing the 2-D array made matplotlib reject mismatched x/y sizes.
# No legend: nothing here carries a label.
plt.figure(figsize=(12, 12))
plt.scatter(res[2], res[3][:, 0])
plt.xlabel("N")
plt.ylabel("FGD");

In [None]:
# 1/N view with the full linear fit and its intercept at 1/N = 0.
# Uses the one_over_n_extrapolation layout of `res` (as indexed in the fit cells above):
# res[0] = (intercept, slope), res[1][0] = intercept error, res[3][:, 0] = FGD per N.
inv_n_grid = np.linspace(0, 0.0002, 101)
plt.figure(figsize=(12, 12))
plt.scatter(1 / res[2], res[3][:, 0])
plt.plot(inv_n_grid, res[0][0] + res[0][1] * inv_n_grid, label="Full Fit", color="red")
plt.errorbar(0, res[0][0], res[1][0], fmt="o", color="red")
plt.xlabel("1/N")  # the x values are 1/N, not N
plt.ylabel("FGD")
plt.legend();

In [None]:
# Bare scatter of FGD vs. 1/N (no fit overlay), using column 0 of the two-column res[3]
# produced by one_over_n_extrapolation. No legend: nothing here carries a label.
plt.figure(figsize=(12, 12))
plt.scatter(1 / res[2], res[3][:, 0])
plt.xlabel("1/N")
plt.ylabel("FGD");

In [None]:
# Compare the full fit (all points) with a linear fit to group-averaged points
# (groups of 10 consecutive N values; the last point is dropped so the count divides evenly).
# Uses the one_over_n_extrapolation layout of `res` (as indexed in the fit cells above):
# res[0] = (intercept, slope), res[1][0] = intercept error, res[3][:, 0] = FGD per N.
# NOTE(review): reshape(-1, 10) assumes len(res[2]) - 1 is a multiple of 10
# (true for num_points=101 in that call) — confirm.
inv_n_grid = np.linspace(0, 0.0002, 101)
fgd_per_n = res[3][:, 0]
n_avg = np.mean(res[2][:-1].reshape(-1, 10), axis=1)
fgd_groups = fgd_per_n[:-1].reshape(-1, 10)
fgd_avg = np.mean(fgd_groups, axis=1)
fgd_std = np.std(fgd_groups, axis=1)

# `val` / `cov` were not defined anywhere in the visible notebook (NameError on a fresh
# kernel). Refit here with the notebook's `linear` model and `curve_fit`.
# NOTE(review): presumably reproduces the original "Averaged Fit"; confirm whether it was
# weighted (e.g. sigma=fgd_std) — this refit is unweighted.
val, cov = curve_fit(linear, 1 / n_avg, fgd_avg)

plt.figure(figsize=(12, 12))
plt.scatter(1 / res[2], fgd_per_n)
plt.plot(inv_n_grid, res[0][0] + res[0][1] * inv_n_grid, label="Full Fit", color="red")
plt.plot(inv_n_grid, val[0] + val[1] * inv_n_grid, label="Averaged Fit", color="green")
plt.errorbar(1 / n_avg, fgd_avg, fgd_std, fmt="o", color="green")
plt.errorbar(0, res[0][0], res[1][0], fmt="o", color="red")
# Presumably offset slightly left of 0 so the two intercept markers do not overlap.
plt.errorbar(-0.000001, val[0], np.sqrt(np.diag(cov))[0], fmt="o", color="green")
plt.xlabel("1/N")
plt.ylabel("FGD")
plt.legend();

In [None]:
# FGD vs. N with the fitted curve res[0][0] + res[0][1] / N (one_over_n_extrapolation
# layout of `res`; FGD per N is column 0 of res[3]).
# The grid starts at 5_000 (min_samples of that call) rather than 0, which evaluated 1/0.
n_grid = np.linspace(5_000, 50_000, 101)
plt.figure(figsize=(12, 12))
plt.scatter(res[2], res[3][:, 0])
plt.plot(n_grid, res[0][0] + res[0][1] * (1 / n_grid))
plt.xlabel("N")
plt.ylabel("FGD");

In [None]:
# Group-averaged FGD vs. group-averaged N (groups of 10 consecutive points, last dropped).
# Averages only column 0 of res[3]; reshaping the full two-column array mixed the columns
# and produced twice as many y values as x values.
# NOTE(review): reshape(-1, 10) assumes len(res[2]) - 1 is a multiple of 10
# (num_points=101 in the extrapolation call) — confirm.
n_avg = np.mean(res[2][:-1].reshape(-1, 10), axis=1)
fgd_avg = np.mean(res[3][:-1, 0].reshape(-1, 10), axis=1)
plt.figure(figsize=(12, 12))
plt.scatter(n_avg, fgd_avg)
plt.xlabel("N")
plt.ylabel("FGD");

In [None]:
# Group-averaged FGD (± std within each group of 10 consecutive points) vs. 1/N.
# Uses only column 0 of res[3]; reshaping the full two-column array mixed the columns.
# NOTE(review): reshape(-1, 10) assumes len(res[2]) - 1 is a multiple of 10 — confirm.
n_avg = np.mean(res[2][:-1].reshape(-1, 10), axis=1)
fgd_groups = res[3][:-1, 0].reshape(-1, 10)
plt.figure(figsize=(12, 12))
plt.errorbar(1 / n_avg, np.mean(fgd_groups, axis=1), np.std(fgd_groups, axis=1), fmt="o")
plt.xlabel("1/N")
plt.ylabel("FGD");

In [None]:
# Evaluate FGD with gen_metrics.multi_batch_evaluation over a range of batch sizes,
# collecting one `mean_std` result per batch size in `ms`.
# NOTE(review): `batch_sizes` is not defined anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel; it presumably maps metric names to lists of batch
# sizes — restore its definition. `[9:]` skips the first nine entries (reason not shown).
# `timing` is unpacked but not used in the visible lines.
ms = []
for batch_size in tqdm(batch_sizes["fgd"][9:]):
    mean_std, timing = gen_metrics.multi_batch_evaluation(
        real_efps,
        gen_efps,
        5,  # NOTE(review): positional argument, presumably the number of batches — confirm
        batch_size,
        gen_metrics.frechet_gaussian_distance,
        timing=True,
        normalise=True
    )
    ms.append(mean_std)