From 03395ab7883bcd2cf9bd1f92365f1087b72587aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 May 2021 10:26:08 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../notebooks/learn/details_cls.ipynb | 325 +++++++++--------- 1 file changed, 164 insertions(+), 161 deletions(-) diff --git a/docs/examples/notebooks/learn/details_cls.ipynb b/docs/examples/notebooks/learn/details_cls.ipynb index 1613720e3e..5f3ca25838 100644 --- a/docs/examples/notebooks/learn/details_cls.ipynb +++ b/docs/examples/notebooks/learn/details_cls.ipynb @@ -12,50 +12,55 @@ "import pyhf.contrib.viz.brazil\n", "from pyhf.infer.calculators import AsymptoticCalculator\n", "import scipy.stats\n", - "plt.rcParams['savefig.facecolor']='white'\n", "\n", + "plt.rcParams['savefig.facecolor'] = 'white'\n", + "\n", + "\n", + "def gpdf(qq, muprime, sig, mu):\n", + " cut = mu ** 2 / sig ** 2\n", "\n", - "def gpdf(qq,muprime,sig,mu):\n", - " cut = mu**2/sig**2\n", - " \n", " # 1/√(2π)\n", - " standard_pre = 1/np.sqrt(2*np.pi)\n", + " standard_pre = 1 / np.sqrt(2 * np.pi)\n", "\n", " # compute the arg of the exponential\n", - " muprime_muhat1_minus_over_sig = (qq - (mu**2/sig**2 - (2*mu*muprime/sig**2)) )/(2*mu/sig)\n", - " muprime_muhat2_minus_over_sig = (np.sqrt(qq)-(mu-muprime)/sig) \n", + " muprime_muhat1_minus_over_sig = (\n", + " qq - (mu ** 2 / sig ** 2 - (2 * mu * muprime / sig ** 2))\n", + " ) / (2 * mu / sig)\n", + " muprime_muhat2_minus_over_sig = np.sqrt(qq) - (mu - muprime) / sig\n", "\n", " # chose which one\n", - " muhat_over_sig = np.where(qq>cut,muprime_muhat1_minus_over_sig,muprime_muhat2_minus_over_sig)\n", + " muhat_over_sig = np.where(\n", + " qq > cut, muprime_muhat1_minus_over_sig, muprime_muhat2_minus_over_sig\n", + " )\n", "\n", " # exp(-0.5 (mu^-mu'))\n", - " arg = -0.5*(muhat_over_sig)**2\n", + " arg = -0.5 * (muhat_over_sig) ** 2\n", "\n", - " # compute the jacobian \n", + " # compute the jacobian\n", " # 1/ sigma * dµ^/dq\n", - " oos_j1 = 1/(2*mu/sig)\n", - " oos_j2 = 1/(2*np.sqrt(qq))\n", - " \n", + " oos_j1 = 1 / (2 * mu / sig)\n", + " oos_j2 = 1 / (2 * np.sqrt(qq))\n", + "\n", " # chose which one\n", - " one_over_sig_jacobian = np.where(qq>cut,oos_j1,oos_j2)\n", + " one_over_sig_jacobian = np.where(qq > cut, oos_j1, oos_j2)\n", "\n", " # compute the reparametrized gaussian\n", - " return standard_pre * np.exp(arg)*one_over_sig_jacobian\n", + " return standard_pre * np.exp(arg) * one_over_sig_jacobian\n", + "\n", "\n", - "def testrail_parab(muhat,mu,sigma):\n", - " return (muhat-mu)**2/sigma**2\n", + "def testrail_parab(muhat, mu, sigma):\n", + " return (muhat - mu) ** 2 / sigma ** 2\n", "\n", - "def testrail_flat(muhat,mu,sigma):\n", + "\n", + "def testrail_flat(muhat, mu, sigma):\n", " return np.where(\n", - " muhat<0,\n", - " (muhat-mu)**2/sigma**2 - muhat**2/sigma**2,\n", - " np.where(muhat < mu,\n", - " (muhat-mu)**2/sigma**2,\n", - " 0.0\n", - " )\n", + " muhat < 0,\n", + " (muhat - mu) ** 2 / sigma ** 2 - muhat ** 2 / sigma ** 2,\n", + " np.where(muhat < mu, (muhat - mu) ** 2 / sigma ** 2, 0.0),\n", " )\n", "\n", - "def get_toy_results(test_poi,toys):\n", + "\n", + "def get_toy_results(test_poi, toys):\n", " qmu_tilde = np.asarray(\n", " [\n", " pyhf.infer.test_statistics.qmu_tilde(\n", @@ -77,15 +82,21 @@ " )\n", "\n", " muhat = pars[:, model.config.poi_index]\n", - " return muhat,qmu_tilde \n", + " return muhat, qmu_tilde\n", + "\n", "\n", "model = pyhf.simplemodels.hepdata_like(\n", " signal_data=[30.0], bkg_data=[50.0], bkg_uncerts=[7.0]\n", ")\n", "data = [55.0] + model.config.auxdata\n", "\n", - "asimov_data = pyhf.infer.calculators.generate_asimov_data(0.0,data,model,\n", - " model.config.suggested_init(),model.config.suggested_bounds(),model.config.suggested_fixed()\n", + "asimov_data = pyhf.infer.calculators.generate_asimov_data(\n", + " 0.0,\n", + " data,\n", + " model,\n", + " model.config.suggested_init(),\n", + " model.config.suggested_bounds(),\n", + " model.config.suggested_fixed(),\n", ")" ] }, @@ -95,35 +106,34 @@ "metadata": {}, "outputs": [], "source": [ - "muhatmin, muhatmax = -2,3\n", - "mumin, mumax = 1e-2,1.2\n", + "muhatmin, muhatmax = -2, 3\n", + "mumin, mumax = 1e-2, 1.2\n", "maxsigma = 4\n", "\n", - "qqspan = np.linspace(0,maxsigma**2,10001)\n", - "muspan = np.linspace(mumin, mumax,31)\n", - "muhatspan = np.linspace(muhatmin, muhatmax,1001)\n", + "qqspan = np.linspace(0, maxsigma ** 2, 10001)\n", + "muspan = np.linspace(mumin, mumax, 31)\n", + "muhatspan = np.linspace(muhatmin, muhatmax, 1001)\n", "test_stats = []\n", "vals = []\n", "\n", "for mu_test in muspan:\n", - " calc = AsymptoticCalculator(data,model)\n", + " calc = AsymptoticCalculator(data, model)\n", " ts = calc.teststatistic(mu_test)\n", - " ds,db = calc.distributions(mu_test)\n", - " cl_sb,cl_b,cl_s = calc.pvalues(ts,ds,db)\n", - " test_stats.append([ds.shift,db.shift,ts])\n", - " vals.append([cl_sb,cl_b,cl_s])\n", + " ds, db = calc.distributions(mu_test)\n", + " cl_sb, cl_b, cl_s = calc.pvalues(ts, ds, db)\n", + " test_stats.append([ds.shift, db.shift, ts])\n", + " vals.append([cl_sb, cl_b, cl_s])\n", "\n", "vals = np.array(vals)\n", "test_stats = np.array(test_stats)\n", "\n", "\n", "def getv(mu_test):\n", - " calc = AsymptoticCalculator(data,model)\n", + " calc = AsymptoticCalculator(data, model)\n", " ts = calc.teststatistic(mu_test)\n", - " ds,db = calc.distributions(mu_test)\n", - " cl_sb,cl_b,cl_s = calc.pvalues(ts,ds,db)\n", - " return [cl_sb,cl_b,cl_s]\n", - "\n" + " ds, db = calc.distributions(mu_test)\n", + " cl_sb, cl_b, cl_s = calc.pvalues(ts, ds, db)\n", + " return [cl_sb, cl_b, cl_s]" ] }, { @@ -948,193 +958,186 @@ " vals,\n", "):\n", "\n", - "\n", " mutest_at_index = muspan[index]\n", - " means_at_index = -test_stats[index,[1,0]]\n", - " ts_at_index = -test_stats[index,2]\n", - "\n", + " means_at_index = -test_stats[index, [1, 0]]\n", + " ts_at_index = -test_stats[index, 2]\n", "\n", " tail_span = np.linspace(muhatmin, ts_at_index, 1001)\n", "\n", - "\n", - "\n", " vals_at_index = vals[index]\n", "\n", - "\n", " obs_pllr_ts = pyhf.infer.test_statistics.qmu_tilde(\n", - " mutest_at_index,data,model,\n", - " model.config.suggested_init(),model.config.suggested_bounds(),model.config.suggested_fixed()\n", + " mutest_at_index,\n", + " data,\n", + " model,\n", + " model.config.suggested_init(),\n", + " model.config.suggested_bounds(),\n", + " model.config.suggested_fixed(),\n", " )\n", "\n", - "\n", " ###\n", - " toys_bkg = model.make_pdf(pyhf.tensorlib.astensor([0.0,1.0])).sample((25,))\n", - " toys_sig = model.make_pdf(pyhf.tensorlib.astensor([1.0,1.0])).sample((25,))\n", - " muhat_s,qmu_tilde_s = get_toy_results(mutest_at_index,toys_sig)\n", - " muhat_b,qmu_tilde_b = get_toy_results(mutest_at_index,toys_bkg)\n", - "\n", + " toys_bkg = model.make_pdf(pyhf.tensorlib.astensor([0.0, 1.0])).sample((25,))\n", + " toys_sig = model.make_pdf(pyhf.tensorlib.astensor([1.0, 1.0])).sample((25,))\n", + " muhat_s, qmu_tilde_s = get_toy_results(mutest_at_index, toys_sig)\n", + " muhat_b, qmu_tilde_b = get_toy_results(mutest_at_index, toys_bkg)\n", "\n", " asimov_pllr_ts = pyhf.infer.test_statistics.qmu_tilde(\n", - " mutest_at_index,asimov_data,model,\n", - " model.config.suggested_init(),model.config.suggested_bounds(),model.config.suggested_fixed()\n", + " mutest_at_index,\n", + " asimov_data,\n", + " model,\n", + " model.config.suggested_init(),\n", + " model.config.suggested_bounds(),\n", + " model.config.suggested_fixed(),\n", " )\n", - " sigma = np.sqrt(mutest_at_index**2/asimov_pllr_ts)\n", + " sigma = np.sqrt(mutest_at_index ** 2 / asimov_pllr_ts)\n", " ###\n", "\n", + " tsrail_parab = testrail_parab(muhatspan, mutest_at_index / sigma, sigma=1)\n", + " tsrail_flat = testrail_flat(muhatspan, mutest_at_index / sigma, sigma=1)\n", "\n", - " tsrail_parab = testrail_parab(muhatspan,mutest_at_index/sigma, sigma = 1)\n", - " tsrail_flat = testrail_flat(muhatspan,mutest_at_index/sigma, sigma = 1)\n", - "\n", - "\n", - "\n", - "\n", - " tail_span_qq = np.linspace(obs_pllr_ts, maxsigma**2, 1001)\n", - "\n", + " tail_span_qq = np.linspace(obs_pllr_ts, maxsigma ** 2, 1001)\n", "\n", " unbounded_bounds = model.config.suggested_bounds()\n", " unbounded_bounds[model.config.poi_index] = (-10, 10)\n", - " empirial_muhat = pyhf.infer.mle.fit(\n", - " data,model,\n", - " par_bounds = unbounded_bounds\n", - " )[model.config.poi_index] \n", + " empirial_muhat = pyhf.infer.mle.fit(data, model, par_bounds=unbounded_bounds)[\n", + " model.config.poi_index\n", + " ]\n", " empirial_muhat = empirial_muhat / sigma\n", "\n", + " muhat_pdfs = scipy.stats.norm(np.array(means_at_index)).pdf(\n", + " np.tile(muhatspan.reshape(-1, 1), (1, 2))\n", + " )\n", "\n", - "\n", - " muhat_pdfs = scipy.stats.norm(np.array(means_at_index)).pdf(np.tile(muhatspan.reshape(-1,1),(1,2)))\n", - "\n", - " muhat_pdfs_tail = scipy.stats.norm(np.array(means_at_index)).pdf(np.tile(tail_span.reshape(-1,1),(1,2)))\n", - "\n", - "\n", - "\n", - "\n", - "\n", + " muhat_pdfs_tail = scipy.stats.norm(np.array(means_at_index)).pdf(\n", + " np.tile(tail_span.reshape(-1, 1), (1, 2))\n", + " )\n", "\n", " ax = axarr['A']\n", " ax.set_ylabel(r'$\\tilde{q}_\\mu$')\n", - " ax.set_xlim(muhatmin,muhatmax)\n", - "\n", - " ax.plot(muhatspan,tsrail_parab, c = 'grey', linestyle = 'dashed')\n", - " ax.plot(muhatspan,tsrail_flat, c = 'grey')\n", - " ax.vlines(0,0,maxsigma**2,colors = 'green', linestyles = 'dashed')\n", - " ax.vlines(means_at_index[1],0,maxsigma**2,colors = 'red', linestyles = 'dashed')\n", - " ax.vlines(ts_at_index,0,maxsigma**2,colors = 'black', linestyles = 'dashed')\n", - " ax.vlines(empirial_muhat,0,maxsigma**2,colors = 'orange', linestyles = 'dashed')\n", + " ax.set_xlim(muhatmin, muhatmax)\n", "\n", + " ax.plot(muhatspan, tsrail_parab, c='grey', linestyle='dashed')\n", + " ax.plot(muhatspan, tsrail_flat, c='grey')\n", + " ax.vlines(0, 0, maxsigma ** 2, colors='green', linestyles='dashed')\n", + " ax.vlines(means_at_index[1], 0, maxsigma ** 2, colors='red', linestyles='dashed')\n", + " ax.vlines(ts_at_index, 0, maxsigma ** 2, colors='black', linestyles='dashed')\n", + " ax.vlines(empirial_muhat, 0, maxsigma ** 2, colors='orange', linestyles='dashed')\n", "\n", - " ax.scatter(muhat_b/sigma,qmu_tilde_b, alpha = 0.2, c = 'green', s = 20)\n", - " ax.scatter(muhat_s/sigma,qmu_tilde_s, alpha = 0.2, c = 'red', s = 20)\n", - " ax.scatter(0.0,asimov_pllr_ts, c = 'green')\n", + " ax.scatter(muhat_b / sigma, qmu_tilde_b, alpha=0.2, c='green', s=20)\n", + " ax.scatter(muhat_s / sigma, qmu_tilde_s, alpha=0.2, c='red', s=20)\n", + " ax.scatter(0.0, asimov_pllr_ts, c='green')\n", "\n", - " ax.hlines(asimov_pllr_ts,0.0,muhatmax, colors = 'green', linestyles='dashed')\n", + " ax.hlines(asimov_pllr_ts, 0.0, muhatmax, colors='green', linestyles='dashed')\n", "\n", + " ax.hlines(obs_pllr_ts, ts_at_index, muhatmax, colors='black', linestyles='dashed')\n", + " ax.scatter(empirial_muhat, obs_pllr_ts, c='orange')\n", + " ax.scatter(ts_at_index, obs_pllr_ts, c='black')\n", + " ax.scatter(mutest_at_index / sigma, 0.0, c='grey')\n", "\n", - " ax.hlines(obs_pllr_ts,ts_at_index,muhatmax, colors = 'black', linestyles='dashed')\n", - " ax.scatter(empirial_muhat,obs_pllr_ts,c = 'orange')\n", - " ax.scatter(ts_at_index,obs_pllr_ts,c = 'black')\n", - " ax.scatter(mutest_at_index/sigma,0.0,c = 'grey')\n", - "\n", - " ax.set_ylim(-1,maxsigma**2)\n", - "\n", + " ax.set_ylim(-1, maxsigma ** 2)\n", "\n", " ax = axarr['B']\n", " ax.set_xlabel(r'$p(\\tilde{q}_\\mu)$')\n", - " pqq_s = gpdf(qqspan,mutest_at_index,1.0,mutest_at_index)\n", - " pqq_b = gpdf(qqspan,0.0,1.0,mutest_at_index)\n", + " pqq_s = gpdf(qqspan, mutest_at_index, 1.0, mutest_at_index)\n", + " pqq_b = gpdf(qqspan, 0.0, 1.0, mutest_at_index)\n", "\n", - " pqq_s_tail = gpdf(tail_span_qq,mutest_at_index,1.0,mutest_at_index)\n", - " pqq_b_tail = gpdf(tail_span_qq,0.0,1.0,mutest_at_index)\n", + " pqq_s_tail = gpdf(tail_span_qq, mutest_at_index, 1.0, mutest_at_index)\n", + " pqq_b_tail = gpdf(tail_span_qq, 0.0, 1.0, mutest_at_index)\n", "\n", - " ax.plot(pqq_b,qqspan, c = 'green')\n", - " ax.plot(pqq_s,qqspan, c = 'red')\n", + " ax.plot(pqq_b, qqspan, c='green')\n", + " ax.plot(pqq_s, qqspan, c='red')\n", "\n", - " ax.fill_betweenx(tail_span_qq,pqq_b_tail, facecolor = 'green', alpha = 0.2)\n", - " ax.fill_betweenx(tail_span_qq,pqq_s_tail, facecolor = 'red', alpha = 0.2)\n", + " ax.fill_betweenx(tail_span_qq, pqq_b_tail, facecolor='green', alpha=0.2)\n", + " ax.fill_betweenx(tail_span_qq, pqq_s_tail, facecolor='red', alpha=0.2)\n", " ax.set_xscale('log')\n", "\n", - " ax.hlines(obs_pllr_ts,0.0,10.0, colors = 'black', linestyles='dashed')\n", - " ax.hlines(asimov_pllr_ts,0.0,10.0, colors = 'green', linestyles='dashed')\n", - " ax.set_xlim(1e-3,1e1)\n", - " ax.set_ylim(-1,maxsigma**2)\n", - "\n", - "\n", + " ax.hlines(obs_pllr_ts, 0.0, 10.0, colors='black', linestyles='dashed')\n", + " ax.hlines(asimov_pllr_ts, 0.0, 10.0, colors='green', linestyles='dashed')\n", + " ax.set_xlim(1e-3, 1e1)\n", + " ax.set_ylim(-1, maxsigma ** 2)\n", "\n", " ax = axarr['F']\n", "\n", " ax.set_ylabel(r'$p(\\hat{\\mu}/\\sigma)$')\n", "\n", - " ax.plot(muhatspan,muhat_pdfs[:,0], c = 'green')\n", - " ax.fill_between(tail_span,muhat_pdfs_tail[:,0], facecolor = 'green', alpha = 0.2)\n", + " ax.plot(muhatspan, muhat_pdfs[:, 0], c='green')\n", + " ax.fill_between(tail_span, muhat_pdfs_tail[:, 0], facecolor='green', alpha=0.2)\n", "\n", - " ax.vlines(0,0,.5,colors = 'green', linestyles = 'dashed')\n", - " ax.vlines(means_at_index[1],0,.5,colors = 'red', linestyles = 'dashed')\n", - " ax.vlines(ts_at_index,0,.5,colors = 'black', linestyles = 'dashed')\n", - " ax.vlines(empirial_muhat,0,.5,colors = 'orange', linestyles = 'dashed')\n", - " ax.set_xlim(muhatmin,muhatmax)\n", - " ax.set_ylim(0,.5)\n", + " ax.vlines(0, 0, 0.5, colors='green', linestyles='dashed')\n", + " ax.vlines(means_at_index[1], 0, 0.5, colors='red', linestyles='dashed')\n", + " ax.vlines(ts_at_index, 0, 0.5, colors='black', linestyles='dashed')\n", + " ax.vlines(empirial_muhat, 0, 0.5, colors='orange', linestyles='dashed')\n", + " ax.set_xlim(muhatmin, muhatmax)\n", + " ax.set_ylim(0, 0.5)\n", "\n", " ax = axarr['C']\n", " ax.set_ylabel(r'$p(\\hat{\\mu}/\\sigma)$')\n", "\n", - " ax.plot(muhatspan,muhat_pdfs[:,1], c = 'red')\n", - " ax.fill_between(tail_span,muhat_pdfs_tail[:,1], facecolor = 'red', alpha = 0.2)\n", - " ax.vlines(0,0,.5,colors = 'green', linestyles = 'dashed')\n", - " ax.vlines(means_at_index[1],0,.5,colors = 'red', linestyles = 'dashed')\n", - " ax.vlines(ts_at_index,0,.5,colors = 'black', linestyles = 'dashed')\n", - " ax.vlines(empirial_muhat,0,.5,colors = 'orange', linestyles = 'dashed')\n", - " ax.set_xlim(muhatmin,muhatmax)\n", - " ax.set_ylim(0,.5)\n", + " ax.plot(muhatspan, muhat_pdfs[:, 1], c='red')\n", + " ax.fill_between(tail_span, muhat_pdfs_tail[:, 1], facecolor='red', alpha=0.2)\n", + " ax.vlines(0, 0, 0.5, colors='green', linestyles='dashed')\n", + " ax.vlines(means_at_index[1], 0, 0.5, colors='red', linestyles='dashed')\n", + " ax.vlines(ts_at_index, 0, 0.5, colors='black', linestyles='dashed')\n", + " ax.vlines(empirial_muhat, 0, 0.5, colors='orange', linestyles='dashed')\n", + " ax.set_xlim(muhatmin, muhatmax)\n", + " ax.set_ylim(0, 0.5)\n", "\n", " ax = axarr['D']\n", " ax.set_ylabel(r'$\\mu_\\mathrm{test}$')\n", " ax.set_xlabel(r'$\\hat{\\mu}/\\sigma$')\n", "\n", - " ax.scatter(empirial_muhat,mutest_at_index, c = 'orange')\n", - " ax.plot(-test_stats[:,0],muspan, c = 'red')\n", - " ax.plot(-test_stats[:,1],muspan, c = 'green')\n", - " ax.scatter(means_at_index[1],mutest_at_index, c = 'red')\n", - " ax.vlines(means_at_index[1],mutest_at_index,5,colors = 'red', linestyles = 'dashed')\n", - "\n", - " ax.plot(-test_stats[:,2],muspan, c = 'black')\n", - " ax.scatter(ts_at_index,mutest_at_index, c = 'black')\n", - " ax.vlines(ts_at_index,mutest_at_index,mumax,colors = 'black', linestyles = 'dashed')\n", - " ax.vlines(empirial_muhat,mumin,mumax,colors = 'orange', linestyles = 'dashed')\n", - " ax.hlines(empirial_muhat*sigma,empirial_muhat,muhatmax,colors = 'orange', linestyles = 'dashed')\n", + " ax.scatter(empirial_muhat, mutest_at_index, c='orange')\n", + " ax.plot(-test_stats[:, 0], muspan, c='red')\n", + " ax.plot(-test_stats[:, 1], muspan, c='green')\n", + " ax.scatter(means_at_index[1], mutest_at_index, c='red')\n", + " ax.vlines(means_at_index[1], mutest_at_index, 5, colors='red', linestyles='dashed')\n", + "\n", + " ax.plot(-test_stats[:, 2], muspan, c='black')\n", + " ax.scatter(ts_at_index, mutest_at_index, c='black')\n", + " ax.vlines(ts_at_index, mutest_at_index, mumax, colors='black', linestyles='dashed')\n", + " ax.vlines(empirial_muhat, mumin, mumax, colors='orange', linestyles='dashed')\n", + " ax.hlines(\n", + " empirial_muhat * sigma,\n", + " empirial_muhat,\n", + " muhatmax,\n", + " colors='orange',\n", + " linestyles='dashed',\n", + " )\n", "\n", - " ax.scatter(means_at_index[0],mutest_at_index, c = 'green')\n", + " ax.scatter(means_at_index[0], mutest_at_index, c='green')\n", "\n", - " ax.set_xlim(muhatmin,muhatmax)\n", + " ax.set_xlim(muhatmin, muhatmax)\n", " ax.set_ylim(mumin, mumax)\n", "\n", " ax = axarr['E']\n", " ax.set_xlabel(r'CL')\n", "\n", - " ax.plot(vals[:,0],muspan,c = 'red')\n", - " ax.plot(vals[:,1],muspan,c = 'green')\n", - " ax.plot(vals[:,2],muspan,c = 'grey')\n", - "\n", - "\n", - " ax.scatter(vals_at_index[0],mutest_at_index, c = 'red')\n", - " ax.scatter(vals_at_index[1],mutest_at_index, c = 'green')\n", - " ax.scatter(vals_at_index[2],mutest_at_index, c = 'grey')\n", + " ax.plot(vals[:, 0], muspan, c='red')\n", + " ax.plot(vals[:, 1], muspan, c='green')\n", + " ax.plot(vals[:, 2], muspan, c='grey')\n", "\n", - " ax.hlines(empirial_muhat*sigma,0,1.0,colors = 'orange', linestyles = 'dashed')\n", + " ax.scatter(vals_at_index[0], mutest_at_index, c='red')\n", + " ax.scatter(vals_at_index[1], mutest_at_index, c='green')\n", + " ax.scatter(vals_at_index[2], mutest_at_index, c='grey')\n", "\n", - " ax.set_ylim(mumin,mumax)\n", - " ax.set_xlim(0,1)\n", + " ax.hlines(empirial_muhat * sigma, 0, 1.0, colors='orange', linestyles='dashed')\n", "\n", + " ax.set_ylim(mumin, mumax)\n", + " ax.set_xlim(0, 1)\n", "\n", "\n", "for index in range(len(muspan)):\n", - "# for index in [15]:\n", - " f,axarr = plt.subplot_mosaic(\"\"\"\n", + " # for index in [15]:\n", + " f, axarr = plt.subplot_mosaic(\n", + " \"\"\"\n", " AAABBB\n", " AAABBB\n", " FFF...\n", " CCC...\n", " DDDEEE\n", " DDDEEE\n", - " \"\"\")\n", + " \"\"\"\n", + " )\n", "\n", " plot_explanation(\n", " axarr,\n", @@ -1146,10 +1149,10 @@ " )\n", "\n", " f.set_tight_layout(True)\n", - " f.set_size_inches(10,10)\n", + " f.set_size_inches(10, 10)\n", " f.savefig(f'scan{str(index).zfill(5)}.png')\n", " plt.show()\n", - " plt.clf()\n" + " plt.clf()" ] }, { @@ -1181,4 +1184,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +}