In [None]:
from hipe4ml import plot_utils

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import uproot

In [None]:
# Load the signal and background candidate tables into pandas DataFrames.
# Uses the uproot3-style `.pandas.df()` accessor on each TTree.
# The signal table comes from SignalTable.root; the background from DataTable_18LS.root.
df_sig, df_bkg = (
    uproot.open(root_path)[tree_name].pandas.df()
    for root_path, tree_name in (
        ("SignalTable.root", "SignalTable"),
        ("DataTable_18LS.root", "DataTable"),
    )
)

In [None]:
# Names of the table columns used as training features. The order matters:
# the axis-label, bin-count and log-scale lists are indexed by the same position.
training_columns = [
    "ProngsDCA",
    "He3ProngPvDCA",
    "He3ProngPvDCAXY",
    "PiProngPvDCA",
    "PiProngPvDCAXY",
    "TPCnSigmaHe3",
    "TPCnSigmaPi",
    "NpidClustersHe3",
    "V0CosPA",
    "pt",
]

In [None]:
# Matplotlib mathtext axis labels, one per entry of training_columns and in the SAME order
# (the plotting cell looks them up by position). The column each label belongs to is noted
# on the right. The labels for He3ProngPvDCAXY and PiProngPvDCA were previously swapped.
training_labels = [
    r"$\mathrm{DCA_{daughters}}$ (cm)",                # ProngsDCA
    r"$\mathrm{DCA_{PV} \/ ^{3}He} $ (cm)",            # He3ProngPvDCA
    r"$\mathrm{DCA_{PV XY} \/ ^{3}He}$ (cm)",          # He3ProngPvDCAXY
    r"$\mathrm{DCA_{PV} \/ \pi} $ (cm)",               # PiProngPvDCA
    r"$\mathrm{DCA_{PV XY} \/ \pi}$ (cm)",             # PiProngPvDCAXY
    r"n$\sigma_{\mathrm{TPC}} \/ \mathrm{^{3}He}$",    # TPCnSigmaHe3
    r"n$\sigma_{\mathrm{TPC}} \/ \mathrm{\pi}$",       # TPCnSigmaPi
    r"n$_{\mathrm{cluster TPC}} \/ \mathrm{^{3}He}$",  # NpidClustersHe3
    r"cos($\theta_{\mathrm{pointing}}$)",              # V0CosPA
    r"$p_\mathrm{T}$ (GeV/$c$)",                       # pt
]

In [None]:
# Per-feature histogram settings, positionally aligned with training_columns.
# Number of histogram bins for each feature:
bins = [
    80,   # ProngsDCA
    63,   # He3ProngPvDCA
    63,   # He3ProngPvDCAXY
    63,   # PiProngPvDCA
    63,   # PiProngPvDCAXY
    79,   # TPCnSigmaHe3
    78,   # TPCnSigmaPi
    127,  # NpidClustersHe3
    63,   # V0CosPA
    63,   # pt
]
# Whether each panel's y axis is logarithmic: currently every feature uses a log scale.
log_scale = [True] * len(bins)

In [None]:
# Sanity check: the plotting cell walks all per-feature lists in lockstep, so each one
# must have exactly one entry per training column. A mismatch would silently mislabel panels.
n_features = len(training_columns)
assert len(training_labels) == n_features, "need one axis label per training column"
assert len(bins) == n_features, "need one bin count per training column"
assert len(log_scale) == n_features, "need one log-scale flag per training column"
len(bins)

In [None]:
# Overlay the unit-area-normalised signal and background distribution of every training
# feature, one panel per feature on a 3x4 grid; grid cells without a feature are removed.
fig, axs = plt.subplots(3, 4, figsize=(35, 22))
axs = axs.flatten()
for ax, variable, label, n_bins, use_log in zip(axs, training_columns, training_labels, bins, log_scale):
    sig_values = df_sig[variable].dropna()
    bkg_values = df_bkg[variable].dropna()
    # Shared bin edges over the union of both samples. Passing only a bin *count* would bin
    # each sample over its own min..max, giving the two overlaid histograms different bin
    # widths so they could not be compared bin by bin.
    edges = np.histogram_bin_edges(np.concatenate([sig_values, bkg_values]), bins=n_bins)
    # density=True matches distplot's norm_hist=True; alpha=0.4 matches distplot's default
    # transparency so the overlapping region stays visible.
    ax.hist(sig_values, bins=edges, density=True, log=use_log, alpha=0.4, color='tab:red', label='Signal')
    ax.hist(bkg_values, bins=edges, density=True, log=use_log, alpha=0.4, color='tab:blue', label='Background')
    ax.set_xlabel(label, fontsize=30)
    ax.set_ylabel('counts (arb. units)', fontsize=30)
    # Visible window: from the background minimum to the signal maximum (original choice).
    ax.set_xlim(bkg_values.min(), sig_values.max())
    ax.tick_params(direction='in')

# Drop the grid cells that have no feature assigned.
for unused_ax in axs[len(training_columns):]:
    fig.delaxes(unused_ax)
# The legend lives on axs[-4] (grid cell 8) and is pushed right into the freed-up space.
axs[-4].legend(bbox_to_anchor=(3.9, 0.58), prop={'size': 48}, frameon=False)
# Figure-coordinate annotations; fig.text avoids depending on pyplot's "current axes".
# Raw string: the LaTeX backslashes would otherwise be (invalid) Python escape sequences.
fig.text(0.61, 0.31, "ALICE Performance", fontsize=48)
fig.text(0.595, 0.263, r"Pb-Pb $\sqrt{s_{\mathrm{NN}}} = $ 5.02TeV ", fontsize=48);

In [None]:
# Export the feature-distribution figure as PNG, PDF and EPS.
fig.savefig("feature_distribution.png", bbox_inches='tight')
fig.savefig("feature_distribution.pdf", bbox_inches='tight')
# The EPS export is rasterized (presumably because PostScript cannot represent the
# semi-transparent histograms). Restore the previous flag afterwards: otherwise the figure
# stays rasterized, and re-running this cell would silently produce a rasterized PDF too.
was_rasterized = fig.get_rasterized()
fig.set_rasterized(True)
try:
    fig.savefig("feature_distribution.eps", format="eps", bbox_inches='tight')
finally:
    fig.set_rasterized(was_rasterized)