# Load modules and data

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from copy import deepcopy

import subprocess
import os.path
from time import time

import uproot 

from scipy.optimize import curve_fit

from IPython.display import Markdown, display
def printmd(string):
    """Show *string* in the notebook output, rendered as Markdown."""
    rendered = Markdown(string)
    display(rendered)

In [None]:
# Period tag used to select input files and the trending CSV.
PERIOD = "LHC18r"

# Branch names of interest.
# NOTE(review): not referenced by the code visible in this notebook.
BRANCHES = [
    "runNumber",
    "evtTimeStamp",
    "v0.fEffMass",
]

# Subtracted from `evtTimeStamp` before it is compared with the chunk
# start/stop values.
# NOTE(review): presumably an epoch shift plus a two-hour correction - confirm.
TIME_OFFSET = 1262307600 - 7200

In [None]:
# List every entry of the V0 validation directory (shown as the cell output).
os.listdir('data_validation_V0s/')

In [None]:
# Number of '.root' files in the validation directory whose name contains PERIOD.
sum(1 for f in os.listdir('data_validation_V0s/') if PERIOD in f and f.endswith('.root'))

In [None]:
from time import time

def root2df(fname):
    """Read the 'v0_mass' object of a ROOT file into a pandas DataFrame.

    Returns the DataFrame, or None (after printing the error) when reading
    raises ValueError.
    """
    try:
        tree = uproot.open(fname)['v0_mass']
        frame = tree.pandas.df()
    except ValueError as e_msg:
        print(f'ERROR in {fname} : {e_msg}')
        return None
    return frame

def fname2run(fname):
    """Return the text between the last '_' and the following '.' in *fname*.

    The result is a string, e.g. 'x_LHC18r_296690.root' -> '296690'.
    """
    last_token = fname.rsplit('_', 1)[-1]
    run, _, _ = last_token.partition('.')
    return run
    
    
def gaus(x, amp, mu, sigma):
    """Gaussian of height *amp*, centre *mu* and width *sigma*, evaluated at *x*."""
    pull = (x - mu) / sigma
    return amp * np.exp(-0.5 * pull**2)

def lin(x, a, b):
    """Straight line with slope *a* and intercept *b*, evaluated at *x*."""
    sloped = a * x
    return sloped + b

def fit_func(x, amp, mu, sigma, a, b):
    """Fit model: Gaussian (amp, mu, sigma) on top of the line a*x + b."""
    peak = amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    slope_term = a * x
    # Summed left to right, in the same order as a single-expression formula.
    return peak + slope_term + b



def make_fitting(mass_arr,
                 period_run_chunk,
                 is_bad,
                 binning=(440, 540, 50),
                 bounds_low=(0, 470, 2, -10000, -1000),
                 bounds_high=(25000, 530, 10, 10000, 50000),
                 peak_range=(480, 515),
                 param_names=('amp', 'mu', 'sigma', 'a', 'b'),
                 close_fig=False,
                 verbose=True,
                ):
    """Fit a Gaussian peak on a linear background to a mass histogram and plot it.

    The values in `mass_arr` are histogrammed, `fit_func` is fitted to the bin
    contents with `curve_fit`, and a figure showing the data, the full fit and
    its signal/background components is saved as
    ``fits/fit_K0s_{period}_{run}_{chunk:03}_nbins{binning[2]}.png``
    (the ``fits`` directory is created if it does not exist).

    Parameters
    ----------
    mass_arr : array-like
        Values to histogram (the x-axis label says MeV).
    period_run_chunk : sequence of three items
        ``(period, run, chunk)``; used in the plot text, the output file name
        and the returned record.
    is_bad
        Truthy -> the legend marker is red, otherwise blue. Stored unchanged
        under ``'bad'`` in the returned record.
    binning : (start, stop, num)
        Passed as-is to ``np.linspace`` to build the bin edges, so the
        histogram has ``num - 1`` bins. ``num`` is what is reported as
        ``#bins`` / ``'nbins'``.
    bounds_low, bounds_high : sequences of five numbers
        Fit bounds for (amp, mu, sigma, a, b). They are copied, never modified;
        in the copy of `bounds_high` the entries for ``amp`` and ``b`` are
        replaced by values derived from the histogram.
    peak_range : (low, high)
        Only used for the entry count printed when `verbose` is true.
    param_names : sequence of str
        Labels for the five fitted parameters in the printed/plotted summary.
    close_fig : bool
        Close the figure (and all other open figures) after saving.
    verbose : bool
        Print entry counts and the fitted parameters.

    Returns
    -------
    dict or None
        Record with the identifiers, fitted values, their errors and the
        ``'bound'`` / ``'bad'`` flags; None when `curve_fit` raises
        RuntimeError or ValueError (the error is printed).
    """
    period, run, chunk = period_run_chunk

    # Private copies: the upper bounds are adapted to the data below and that
    # must not leak into the caller's sequence or into the shared defaults.
    bounds_low = list(bounds_low)
    bounds_high = list(bounds_high)

    if verbose:
        masses = np.asarray(mass_arr)
        n_in_peak = np.count_nonzero((masses > peak_range[0]) & (masses < peak_range[1]))
        print(f'#V0s total = {len(mass_arr)}')
        print(f'#V0s in peak ({peak_range}) = {n_in_peak}')

    counts, bin_edges = np.histogram(mass_arr, bins=np.linspace(*binning))
    bin_centers = 0.5*bin_edges[:-1] + 0.5*bin_edges[1:]

    # Data-driven upper limits for the amplitude and the background intercept.
    try:
        bounds_high[0] = counts.max()*1.5 + 1
        bounds_high[4] = (counts[:5].max() - counts[-5:].min())*6*2 + counts[-5:].max()
    except (ValueError, IndexError) as e:
        # Empty histogram or fewer than five bounds: keep the bounds passed in.
        print(e)

    try:
        popt, pcov = curve_fit(fit_func, bin_centers, counts, bounds=(bounds_low, bounds_high))
    except (RuntimeError, ValueError) as err:
        # RuntimeError: the minimisation failed; ValueError: e.g. the
        # data-driven upper bound ended up below the lower bound.
        print(f'ERROR in fitting: \n\t{err}\n\t -> skipping')
        return None

    perr = np.sqrt(np.diag(pcov))
    # A parameter sitting on one of its limits signals an unreliable fit.
    bounded_params = [abs(pval-bl) < 1e-6 or abs(pval-bh) < 1e-6
                      for pval, bl, bh in zip(popt, bounds_low, bounds_high)]
    is_bound = any(bounded_params)
    if is_bound:
        printmd('IS **BOUND**')
    amp, mu, sigma, a, b = popt
    xx = np.linspace(binning[0], binning[1], 10000)
    mean_lo, mean_hi = popt[1]-perr[1]*3, popt[1]+perr[1]*3

    fig, ax = plt.subplots(figsize=(9, 6))
    try:
        ax.grid(color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
        ax.hist(mass_arr, bins=bin_edges, histtype='step', lw=2, label='data')
        ax.plot(xx, fit_func(xx, *popt), '-', lw=3, c='r', label='full fit')
        ax.plot(xx, gaus(xx, amp, mu, sigma), '--', color='cyan', label='signal fit')
        ax.plot(xx, lin(xx, a, b), 'y--', label='bckg fit')
        # Remember the data-driven limits; they are restored after the
        # off-plot legend marker below has been added.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.vlines([mean_lo, mean_hi], 0, 0.05*ylim[1], label=r'$\mu +/- 3\sigma_{\mu}$', lw=0.5)

        ax.set_xlabel(r'$\mathrm{K}^{0}_{S}$ mass [MeV]', fontsize=14)
        ax.xaxis.set_tick_params(labelsize=14)
        ax.yaxis.set_tick_params(labelsize=14)

        fit_param_text = f'{period}/{run}/{chunk:03}\n#counts={len(mass_arr)}\n#bins={binning[2]}\n\nFit results:\n'
        for name, val, err, bounded in zip(param_names, popt, perr, bounded_params):
            # Number of decimals shown for the value follows the error string.
            err_str = f'{err:.2g}'
            precision = len(err_str[err_str.index('.')+1:]) if '.' in err_str else 0
            bounded_suffix = ' <-- boundary' if bounded else ''
            par_str = f'{name:5s} = {val:6.{precision}f} +/- {err:<6.2g} {bounded_suffix}'
            if verbose:
                print(par_str)
            fit_param_text += par_str+'\n'
        if is_bound:
            fit_param_text += r'$\mathbf{FITTED\;VALUE\;HIT\;THE\;PARAM\;BOUNDARY\;!!!}$'
        ax.text(xlim[0]+0.02*(xlim[1]-xlim[0]), 0.98*ylim[1], fit_param_text,
                fontdict={'family': 'monospace'},
                horizontalalignment='left', verticalalignment='top',)

        # Point outside the visible range: it only contributes a legend entry
        # whose colour encodes the is_bad flag.
        ax.scatter(0, -100, c='r' if is_bad else 'b', s=120, label='globalWarning flag')
        ax.legend()

        ax.set_ylim(*ylim)
        ax.set_xlim(*xlim)
        os.makedirs('fits', exist_ok=True)
        fig.savefig(f'fits/fit_K0s_{period}_{run}_{chunk:03}_nbins{binning[2]}.png')
    finally:
        # Also runs when drawing/saving fails, so batch jobs do not pile up figures.
        if close_fig:
            plt.close(fig)
            plt.close('all')

    return {'period': period,
            'run': run,
            'chunk': chunk,
            'nbins': binning[2],
            'counts': len(mass_arr),
            'amp': amp,
            'mu': mu,
            'sigma': sigma,
            'a': a,
            'b': b,
            'amp_err': perr[0],
            'mu_err': perr[1],
            'sigma_err': perr[2],
            'a_err': perr[3],
            'b_err': perr[4],
            'bound': is_bound,
            'bad': is_bad
            }
                                 
                                 

In [None]:
%%time 

# Batch job: for every input ROOT file, fit the mass spectrum of each time
# chunk of the corresponding run with several binnings and write the fit
# records to one CSV per file.
matplotlib.interactive(False)

# NOTE(review): os.listdir returns entries in arbitrary order, so the
# file -> part number mapping (and the start index below) is only
# reproducible while the directory listing order stays the same.
files = [f for f in os.listdir('data_validation_V0s/') if PERIOD in f and f.endswith('.root')]
df_full = pd.read_csv(f'data/trending_merged_{PERIOD}_withGraphs.csv')
# Drop the 'gr*' columns and every 'alias' column except the 'global' ones.
df_full = df_full[[col for col in df_full.columns
                   if not col.startswith('gr') and ('alias' not in col or 'global' in col)]]

chunk_cols = ['chunkStart', 'chunkStop', 'alias_global_Warning', 'period.fString', 'run', 'chunkID']

# Processing starts at file index 73.
# NOTE(review): presumably the earlier parts were produced in a previous session.
for part, fname in enumerate(files[73:], 73):
    tic = time()
    runx = fname2run(fname)
    V0 = root2df('data_validation_V0s/'+fname)
    if V0 is None:
        # root2df has already printed the read error; move on to the next
        # file instead of failing on V0.shape below.
        print(f'\t--- skipping {fname}: could not be read')
        continue

    # NOTE(review): runx is a string; if the 'run' column holds integers this
    # comparison may select nothing - the printed length shows it.
    df = df_full.query('run == @runx')
    print(len(df))
    print(V0.shape)

    # Fallback so the output file name is defined (and not left over from the
    # previous file) when no chunk row matches this run.
    period = PERIOD
    result_arr = []
    for i, row in enumerate(df[chunk_cols].to_numpy()):
        lo, hi, bad, period, run, chunk = row
        print(i, run, chunk)
        # Candidates whose shifted timestamp lies inside the chunk; the mass is scaled by 1000.
        mass_arr = V0.query('evtTimeStamp - @TIME_OFFSET > @lo & evtTimeStamp - @TIME_OFFSET < @hi')['v0.fEffMass']*1000
        for nbins in [25, 50, 100, 200, 500, 1000, 2000]:
            res = make_fitting(mass_arr, [period, run, chunk], bad, binning=(440, 540, nbins), close_fig=True, verbose=False)
            if res is not None:
                result_arr.append(res)

    rdf = pd.DataFrame(result_arr)
    rdf.to_csv(f'fit_results_{period}_part{part}.csv')
    print(f'\t--- exec. time for run={runx}: {time()-tic} sec')

In [None]:
# rdf.to_csv(f'fit_results_{period}.csv')

## Fitting stability tests

In [None]:
%matplotlib inline 

# Compare the fitted peak position for nbins == 100 and nbins == 500.
# The two selections are matched on (period, run, chunk) rather than by row
# position, so a chunk whose fit is missing for one binning is dropped
# instead of shifting every later pair.
rdf_flagged = rdf.query('bad > -1')
mu_pairs = rdf_flagged.query('nbins == 100').merge(
    rdf_flagged.query('nbins == 500'),
    on=['period', 'run', 'chunk'],
    suffixes=('_100', '_500'),
)
data1 = mu_pairs['mu_100'].to_numpy()
data2 = mu_pairs['mu_500'].to_numpy()
mean_diff = np.mean(np.abs(data1 - data2))
plt.scatter(data1, data2, color='none', edgecolors='k', alpha=0.5)
xlim = plt.xlim()
ylim = plt.ylim()
plt.text(xlim[0]+0.1*(xlim[1] - xlim[0]), ylim[0]+0.9*(ylim[1] - ylim[0]), f'mean abs difference = {mean_diff:.3f}')
# Reference line y = x: points on it have identical mu for both binnings.
diag_lo = min(xlim[0], ylim[0])
diag_hi = max(xlim[1], ylim[1])
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi])
# Restore the data-driven view after adding the reference line.
plt.xlim(xlim)
plt.ylim(ylim)

In [None]:
# Error on the fitted sigma versus the number of entries, for one binning.
nbins = 50
rows_nbins = rdf.query('nbins==@nbins')
plt.scatter(
    rows_nbins['counts'],
    rows_nbins['sigma_err'],
    alpha=0.3,
    color='none',
    edgecolors='k',
    s=15,
)
plt.ylim(0, 0.3)
plt.grid()

## Fitting results analysis

In [None]:
%matplotlib inline
# Fitted mu versus sigma for nbins == 50, split by the 'bad' column:
# blue for bad < 1e-6, red with a black edge for bad > 1e-6.
tmp = rdf.query('nbins == 50')
cond_good = 'bad < 1e-6 & counts > 4000 & mu > 495'
cond_bad = 'bad > 1e-6 & counts > 4000 & mu > 495'
plt.grid()
good_rows = tmp.query(cond_good)
plt.scatter(good_rows['mu'], good_rows['sigma'], color='b', alpha=0.3)
bad_rows = tmp.query(cond_bad)
plt.scatter(bad_rows['mu'], bad_rows['sigma'], color='r', edgecolor='k')


In [None]:
# Normalised distributions of the entry count on a common binning:
# blue for bad < 1e-6, red for bad > 1e-6.
bins = np.linspace(tmp['counts'].min(), tmp['counts'].max(), 40)
for flag_cut, colour in (('bad < 1e-6', 'b'), ('bad > 1e-6', 'r')):
    plt.hist(tmp.query(flag_cut)['counts'], bins=bins, density=1, histtype='step', color=colour)