In [19]:
from b2heavy.ThreePointFunctions.corr3pts import BINSIZE

from b2heavy.TwoPointFunctions.utils     import Tmax
from b2heavy.TwoPointFunctions.types2pts import CorrelatorIO, Correlator
from b2heavy.TwoPointFunctions.fitter    import StagFitter, p_value, standard_p

from b2heavy.FnalHISQMetadata import params

import lsqfit
import copy
import h5py

import numpy             as np
import gvar              as gv
import matplotlib.pyplot as plt
import pandas            as pd
import corrfitter        as cf

from tqdm  import tqdm
from scipy import linalg as la

import jax 
import jax.numpy         as jnp
jax.config.update("jax_enable_x64", True)


In [20]:
def fexp(Nt):
    return lambda t,E,Z: Z * ( jnp.exp(-E*t) + jnp.exp(-E*(Nt-t)) ) 

In [21]:
ensemble = 'Fine-Phys'
mom      = '400'
mes      = 'Dst'
data_dir = '/Users/pietro/code/data_analysis/BtoD/Alex/'

mdata = params(ensemble)

cov_specs = dict(scale=True, shrink=True, cutsvd=1e-12)
jkb       = BINSIZE[ensemble]

In [22]:
io   = CorrelatorIO(ensemble,'Dst',mom,PathToDataDir=data_dir)
stag = StagFitter(
    io       = io,
    jkBin    = BINSIZE[ensemble],
    smearing = ['1S-1S','d-d','d-1S']
)

In [23]:
nconf = io.ReadCorrelator(jkBin=1).shape[-2]

In [24]:
teff     = (17,28)
tmin     = int(0.65//mdata['aSpc'].mean)
# tmin     = int(1.//mdata['aSpc'].mean)
tmax_all = 28
Nstates = 3
errmax  = 0.25

In [25]:
effm,effa = stag.meff(teff,**cov_specs)

  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )


In [26]:
def chi2red(fit):
    chi2red = fit.chi2
    for k,pr in fit.prior.items():
        for i,p in enumerate(pr):
            chi2red -= ((fit.pmean[k][i]-p.mean)/p.sdev)**2
    return chi2red

def ndof(fit):
    return len(fit.y) - sum([len(pr) for k,pr in fit.prior.items()]) 

In [27]:
xdata, ydata, yfull = stag.format(alljk=True,**cov_specs)

xfit, yfit, yjk = {},{},{}
for smr,pol in stag.keys:
    tmax = Tmax(ydata[smr,pol],errmax=errmax)
    # tmax = tmax_all
    
    xfit[smr,pol] = xdata[tmin:tmax+1]
    yfit[smr,pol] = ydata[smr,pol][tmin:tmax+1]
    yjk [smr,pol] = yfull[smr,pol][:,tmin:tmax+1]

yflat   = np.concatenate([yfit[k] for k in stag.keys])
yflatjk = np.hstack([yjk[k] for k in stag.keys])


def fitfcn(func):
    def _fitfcn(xd,p):
        tmp = []
        for sm,pl in stag.keys:
            fcn = func(Nstates,stag.Nt,sm,pl)
            tmp.append(fcn(xfit[sm,pl],p))
        return np.concatenate(tmp)
    return _fitfcn 

In [28]:
def custom(Nstates,Nt,sm,pol):
    sm1,sm2 = sm.split('-')
    mix = sm1!=sm2

    def aux(t,p):
        E0, E1 = p['E'][0], p['E'][0] + jnp.exp(p['E'][1])
        Z0 = jnp.exp(p[f'Z_{sm1}_{pol}'][0]) * jnp.exp(p[f'Z_{sm2}_{pol}'][0])
        Z1 = jnp.exp(p[f'Z_{sm1}_{pol}'][1]) * jnp.exp(p[f'Z_{sm2}_{pol}'][1])
        ans = fexp(Nt)(t,E0,Z0) + fexp(Nt)(t,E1,Z1) * (-1)**(t+1)

        Es = [E0,E1]
        for i in range(2,2*Nstates):
            Ei = Es[i-2] + jnp.exp(p['E'][i])
            Z = p[f'Z_{sm if mix else sm1}_{pol}'][i-2 if mix else i]**2

            ans += fexp(Nt)(t,Ei,Z) * (-1)**(i*(t+1))

            Es.append(Ei)
        return ans

    return aux

cfcn = fitfcn(custom)

In [37]:
def corrnew(Nstates,Nt,sm,pol):
    sm1,sm2 = sm.split('-')
    mix = sm1!=sm2

    def aux(t,p):
        erg    = jnp.exp(p[f'dE'  ])
        ergo   = jnp.exp(p[f'dE.o'])
        # erg [0] = p[f'dE'][0]
        # ergo[0] = erg[0] + jnp.exp(p[f'dE.o'][0])

        c2 = 0.
        for n in range(Nstates):
            Z0 = jnp.exp(p[f'Z.{sm1}.{pol}'  ][n]) * jnp.exp(p[f'Z.{sm2}.{pol}'  ][n])
            Z1 = jnp.exp(p[f'Z.{sm1}.{pol}.o'][n]) * jnp.exp(p[f'Z.{sm2}.{pol}.o'][n])

            if n>0:
                Z0 = p[f'Z.{sm if mix else sm1}.{pol}'  ][n-1 if mix else n]**2
                Z1 = p[f'Z.{sm if mix else sm1}.{pol}.o'][n-1 if mix else n]**2

            Ephy = sum(erg[ :n+1])
            Eosc = sum(ergo[:n+1])

            c2 += fexp(Nt)(t,Ephy,Z0) + fexp(Nt)(t,Eosc,Z1) * (-1)**((t+1))

        return c2
    
    return aux

newf = fitfcn(corrnew)

In [38]:
def to_new(p,fcn=lambda x: x):
    pr = {}
    pr['dE']    = fcn(p['E'][::2] ) 
    pr['dE.o']  = fcn(p['E'][1::2])
    pr['dE'][0] = p['E'][0]

    for sm in ['1S-1S','d-1S','d-d']:
        for pol in ['Bot','Par']:
            sm1,sm2 = sm.split('-')
            smr = sm1 if sm1==sm2 else sm

            pr[f'Z.{smr}.{pol}']   = fcn(p[f'Z_{smr}_{pol}'][::2] )
            pr[f'Z.{smr}.{pol}.o'] = fcn(p[f'Z_{smr}_{pol}'][1::2])
    
    return pr

In [39]:
def show_par(fitp,fcn=lambda x:x):
    aux,idx = [],[]
    for k in fitp:
        if k.endswith('o'):
            continue

        tmp = {}
        for n in range(len(fitp[k])):
            ise0 = (n==0 and k.startswith('dE') and not k.endswith('o'))

            mix = '-' in k

            tmp[f'{n+1 if mix else n}']   = fcn(fitp[k][n])
            tmp[f'{n+1 if mix else n}.o'] = fcn(fitp[f'{k}.o'][n])


        aux.append(tmp)
        idx.append(k)

    return pd.DataFrame(aux,index=idx).transpose()

In [40]:
# def chi2exp(fcn):
    

# Custom from `b2heavy`

In [41]:
pr = stag.priors(Nstates,Meff=effm,Aeff=effa)
cfit = stag.fit(Nstates,(tmin,tmax_all),priors=pr,**cov_specs)

In [42]:
cpar = cfit.pmean
fc = custom(3,stag.Nt,'1S-1S','Bot')

npar = to_new(cpar)
fn = corrnew(3,stag.Nt,'1S-1S','Bot')

In [43]:
aux,idx = [],[]
for k in cfit.p:
    if k.endswith('o'):
        continue

    tmp = {}
    for n in range(len(cfit.p[k])):
        ise0 = n==0 and k.startswith('E')
        mix = '-' in k

        # fcn = lambda x: (x if ise0 else np.exp(x)**2)
        fcn = lambda x: x

        tmp[f'{n+2 if mix else n}']   = fcn(cfit.p[k][n])


    aux.append(tmp)
    idx.append(k)

cpar = pd.DataFrame(aux,index=idx).transpose()

cpar

Unnamed: 0,E,Z_1S_Bot,Z_1S_Par,Z_d-1S_Bot,Z_d-1S_Par,Z_d_Bot,Z_d_Par
0,0.9024(67),-0.148(65),-0.002(68),,,-2.551(71),-2.442(71)
1,-3.05(29),-0.25(14),-0.47(11),,,-3.06(15),-6.9(1.7)
2,-1.76(17),0.982(99),1.139(88),-0.269(35),0.318(35),0.090(14),0.101(15)
3,-2.28(17),0.60(21),0.04(1.44),0.202(37),0.249(28),0.065(10),0.0763(93)
4,-0.82(24),1.16(69),0.2(2.8),0.50(12),0.48(11),0.304(85),0.318(87)
5,-1.00(22),1.87(52),2.14(53),0.52(11),0.410(91),0.208(46),0.175(39)


# same results ? (yes!)

In [44]:
fit = lsqfit.nonlinear_fit(
    data  = (xfit,yflat),
    fcn   = newf,
    prior = to_new(pr)
)

show_par(fit.p)

TypeError: Dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

# costrained

In [None]:
def corr2(Nstates,Nt,sm,pol):
    sm1,sm2 = sm.split('-')
    mix = sm1!=sm2

    def aux(t,p):
        erg    = jnp.exp(p[f'dE'  ])
        ergo   = jnp.exp(p[f'dE.o'])
        erg [0] = p[f'dE'][0]
        ergo[0] = erg[0] + jnp.exp(p[f'dE.o'][0])

        c2 = 0.
        for n in range(Nstates):
            Z0 = jnp.exp(p[f'Z.{sm1}.{pol}'  ][n]) * jnp.exp(p[f'Z.{sm2}.{pol}'  ][n])
            Z1 = jnp.exp(p[f'Z.{sm1}.{pol}.o'][n]) * jnp.exp(p[f'Z.{sm2}.{pol}.o'][n])

            # if n>0:
            #     Z0 = p[f'Z.{sm if mix else sm1}.{pol}'  ][n-1 if mix else n]**2
            #     Z1 = p[f'Z.{sm if mix else sm1}.{pol}.o'][n-1 if mix else n]**2

            Ephy = sum(erg[ :n+1])
            Eosc = sum(ergo[:n+1])

            c2 += fexp(Nt)(t,Ephy,Z0) + fexp(Nt)(t,Eosc,Z1) * (-1)**((t+1))

        return c2
    
    return aux

In [None]:
nfit = lsqfit.nonlinear_fit(
    data  = (xfit,yflat),
    fcn   = fitfcn(corr2),
    prior = to_new(pr)
)

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

# free $E_0^{(osc)}$

In [None]:
def corr3(Nstates,Nt,sm,pol):
    sm1,sm2 = sm.split('-')
    mix = sm1!=sm2

    def aux(t,p):
        erg    = jnp.exp(p[f'dE'  ])
        ergo   = jnp.exp(p[f'dE.o'])
        erg [0] = p[f'dE'][0]
        ergo[0] = p[f'dE.o'][0]

        c2 = 0.
        for n in range(Nstates):
            Z0 = jnp.exp(p[f'Z.{sm1}.{pol}'  ][n]) * jnp.exp(p[f'Z.{sm2}.{pol}'  ][n])
            Z1 = jnp.exp(p[f'Z.{sm1}.{pol}.o'][n]) * jnp.exp(p[f'Z.{sm2}.{pol}.o'][n])

            if n>0:
                Z0 = p[f'Z.{sm if mix else sm1}.{pol}'  ][n-1 if mix else n]**2
                Z1 = p[f'Z.{sm if mix else sm1}.{pol}.o'][n-1 if mix else n]**2

            Ephy = sum(erg[ :n+1])
            Eosc = sum(ergo[:n+1])

            c2 += fexp(Nt)(t,Ephy,Z0) + fexp(Nt)(t,Eosc,Z1) * (-1)**((t+1))

        return c2
    
    return aux

In [None]:
_prior = to_new(pr)
_prior['dE.o'][0] = _prior['dE'][0] + np.exp(_prior['dE.o'][0])
_prior

{'dE': [0.907(36), -1.36(76), -2.6(2.5)],
 'dE.o': [1.031(81), -2.6(2.5), -2.6(2.5)],
 'Z.1S.Bot': [0.09(1.00), 0.5(1.5), 0.5(3.0)],
 'Z.1S.Bot.o': [-1.2(1.2), 0.5(1.5), 0.5(3.0)],
 'Z.1S.Par': [0.2(1.0), 0.5(1.5), 0.5(3.0)],
 'Z.1S.Par.o': [-1.2(1.2), 0.5(1.5), 0.5(3.0)],
 'Z.d-1S.Bot': [0.5(1.7), 0.5(1.7)],
 'Z.d-1S.Bot.o': [0.5(1.7), 0.5(1.7)],
 'Z.d-1S.Par': [0.5(1.7), 0.5(1.7)],
 'Z.d-1S.Par.o': [0.5(1.7), 0.5(1.7)],
 'Z.d.Bot': [-2.3(1.0), 0.5(1.5), 0.5(3.0)],
 'Z.d.Bot.o': [-3.0(1.5), 0.5(1.5), 0.5(3.0)],
 'Z.d.Par': [-2.2(1.0), 0.5(1.5), 0.5(3.0)],
 'Z.d.Par.o': [-5.5(2.0), 0.5(1.5), 0.5(3.0)]}

In [None]:
nfit3 = lsqfit.nonlinear_fit(
    data  = (xfit,yflat),
    fcn   = fitfcn(corr3),
    prior = _prior
)

print(nfit3)

Least Square Fit:
  chi2/dof [dof] = 0.55 [112]    Q = 1    logGBF = 2063.2

Parameters:
           dE 0   0.9040 (64)     [  0.907 (36) ]  
              1    -1.72 (16)     [  -1.36 (76) ]  
              2    -0.76 (25)     [  -2.6 (2.5) ]  
         dE.o 0    0.940 (23)     [  1.031 (81) ]  *
              1    -2.22 (29)     [  -2.6 (2.5) ]  
              2    -1.09 (25)     [  -2.6 (2.5) ]  
     Z.1S.Bot 0   -0.132 (61)     [ 0.09 (1.00) ]  
              1    0.986 (96)     [   0.5 (1.5) ]  
              2     1.22 (78)     [   0.5 (3.0) ]  
   Z.1S.Bot.o 0    -0.35 (26)     [  -1.2 (1.2) ]  
              1     0.69 (33)     [   0.5 (1.5) ]  
              2     1.69 (46)     [   0.5 (3.0) ]  
     Z.1S.Par 0    0.015 (63)     [   0.2 (1.0) ]  
              1    1.139 (80)     [   0.5 (1.5) ]  
              2     0.1 (2.9)     [   0.5 (3.0) ]  
   Z.1S.Par.o 0    -0.54 (25)     [  -1.2 (1.2) ]  
              1     0.2 (1.1)     [   0.5 (1.5) ]  
              2     1.99 (

# free $E_0^{(osc)}$ +  constrained coeff

In [None]:
def corr4(Nstates,Nt,sm,pol):
    sm1,sm2 = sm.split('-')
    mix = sm1!=sm2

    def aux(t,p):
        erg    = jnp.exp(p[f'dE'  ])
        ergo   = jnp.exp(p[f'dE.o'])
        erg [0] = p[f'dE'][0]
        ergo[0] = p[f'dE.o'][0]

        c2 = 0.
        for n in range(Nstates):
            Z0 = jnp.exp(p[f'Z.{sm1}.{pol}'  ][n]) * jnp.exp(p[f'Z.{sm2}.{pol}'  ][n])
            Z1 = jnp.exp(p[f'Z.{sm1}.{pol}.o'][n]) * jnp.exp(p[f'Z.{sm2}.{pol}.o'][n])

            Ephy = sum(erg[ :n+1])
            Eosc = sum(ergo[:n+1])

            c2 += fexp(Nt)(t,Ephy,Z0) + fexp(Nt)(t,Eosc,Z1) * (-1)**((t+1))

        return c2
    
    return aux

In [None]:
_prior = to_new(pr)
_prior['dE.o'][0] = _prior['dE'][0] + np.exp(_prior['dE.o'][0])

In [None]:
nfit4 = lsqfit.nonlinear_fit(
    data  = (xfit,yflat),
    fcn   = fitfcn(corr4),
    prior = _prior
)

print(nfit4)

Least Square Fit:
  chi2/dof [dof] = 1.5 [112]    Q = 0.00097    logGBF = 2066.8

Parameters:
           dE 0   0.9058 (24)     [  0.907 (36) ]  
              1   -1.200 (40)     [  -1.36 (76) ]  
              2    -2.78 (43)     [  -2.6 (2.5) ]  
         dE.o 0   0.9177 (87)     [  1.031 (81) ]  *
              1   -1.929 (65)     [  -2.6 (2.5) ]  
              2    -0.12 (14)     [  -2.6 (2.5) ]  *
     Z.1S.Bot 0   -0.074 (17)     [ 0.09 (1.00) ]  
              1    -2.0 (1.3)     [   0.5 (1.5) ]  *
              2    0.538 (68)     [   0.5 (3.0) ]  
   Z.1S.Bot.o 0   -0.588 (79)     [  -1.2 (1.2) ]  
              1    0.047 (26)     [   0.5 (1.5) ]  
              2    -1.7 (2.4)     [   0.5 (3.0) ]  
     Z.1S.Par 0    0.073 (17)     [   0.2 (1.0) ]  
              1     0.10 (18)     [   0.5 (1.5) ]  
              2     0.28 (21)     [   0.5 (3.0) ]  
   Z.1S.Par.o 0   -0.979 (79)     [  -1.2 (1.2) ]  
              1   -0.119 (28)     [   0.5 (1.5) ]  
              2    

# Results

In [None]:
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)


print(f'============ {ensemble} ========== {mom} ============= {mes} ============')

print('-------------------------------------- Old function ---------------------------------------')
print(show_par(fit.p))
print(f'{p_value(chi2red(fit  ),nconf,ndof(fit  )) = }')
print(f'{fit  .time = }')
print(f'{chi2red(fit  )/ndof(fit  ) = }')


print('------------------------------------ free dE.o[0] --------------------------------------')
print(show_par(nfit3.p))
print(f'{p_value(chi2red(nfit3),nconf,ndof(nfit3)) = }')
print(f'{nfit3.time = }')
print(f'{chi2red(nfit3)/ndof(nfit3) = }')

print('------------------------------------ constrained coeffs --------------------------------------')
print(show_par(nfit.p))
print(f'{p_value(chi2red(nfit ),nconf,ndof(nfit )) = }')
print(f'{nfit .time = }')
print(f'{chi2red(nfit )/ndof(nfit ) = }')

print('------------------------------------ free dE.o[0] + constr. coeff --------------------------------------')
print(show_par(nfit4.p))
print(f'{p_value(chi2red(nfit4),nconf,ndof(nfit4)) = }')
print(f'{nfit4.time = }')
print(f'{chi2red(nfit4)/ndof(nfit4) = }')

-------------------------------------- Old function ---------------------------------------
             dE    Z.1S.Bot      Z.1S.Par  Z.d-1S.Bot Z.d-1S.Par     Z.d.Bot     Z.d.Par
0    0.9024(66)  -0.147(65)  -0.0010(670)         NaN        NaN  -2.552(70)  -2.443(71)
0.o   -3.06(30)   -0.25(15)     -0.47(11)         NaN        NaN   -3.06(15)   -6.8(1.7)
1     -1.74(16)   0.995(96)     1.150(84)  -0.274(34)  0.322(34)   0.093(14)   0.103(14)
1.o   -2.26(17)    0.59(22)    0.04(1.44)   0.202(39)  0.252(29)   0.065(11)  0.0773(98)
2     -0.78(25)    1.18(77)      0.2(2.9)    0.52(14)   0.50(13)    0.32(10)    0.33(10)
2.o   -1.03(23)    1.84(53)      2.09(54)    0.51(11)  0.389(95)   0.203(47)   0.169(40)
p_value(chi2red(fit  ),nconf,ndof(fit  )) = 0.9605203983351072
fit  .time = 0.5349334159400314
chi2red(fit  )/ndof(fit  ) = 0.7631913457128525
------------------------------------ free dE.o[0] --------------------------------------
             dE    Z.1S.Bot   Z.1S.Par  Z.d-1S.Bot Z.