In [1]:
from b2heavy.ThreePointFunctions.corr3pts import BINSIZE

from b2heavy.TwoPointFunctions.utils     import Tmax
from b2heavy.TwoPointFunctions.types2pts import CorrelatorIO, Correlator
from b2heavy.TwoPointFunctions.fitter    import StagFitter, p_value, standard_p

from b2heavy.FnalHISQMetadata import params

import os
import copy

import h5py
import lsqfit

import numpy             as np
import gvar              as gv
import matplotlib.pyplot as plt
import pandas            as pd
import corrfitter        as cf

from tqdm  import tqdm
from scipy import linalg as la

import jax 
import jax.numpy         as jnp
jax.config.update("jax_enable_x64", True)


In [9]:
# Analysis configuration: ensemble, meson, momentum, smearings and data location.
ens      = 'Coarse-Phys'
mes      = 'Dst'
mom      = '100'
binsize  = 19
# Allow the data location to be overridden without editing the notebook
# (hard-coded absolute paths break on any other machine); the fallback is
# the original author's local path, so existing behaviour is unchanged.
data_dir = os.environ.get('B2HEAVY_DATA_DIR', '/Users/pietro/code/data_analysis/BtoD/Alex/')
smlist   = ['1S-1S','d-d','d-1S'] 

mdata = params(ens)

# Keyword options forwarded (as **cov_specs) to meff() and format() below;
# presumably covariance-estimation settings — confirm against b2heavy.
cov_specs = dict(scale=True, shrink=True, cutsvd=1e-12)
# NOTE(review): jkb is looked up from BINSIZE, but the Correlator is built
# with the hard-coded `binsize` above; confirm the two are meant to agree.
jkb       = BINSIZE[ens]

In [25]:
# I/O handle for the (ens, mes, mom) correlator data located under data_dir.
io   = CorrelatorIO(ens, mes, mom, PathToDataDir=data_dir)
# NOTE(review): a module-level object named `self` is confusing, but later
# cells (data formatting, priors, fitfcn) depend on this exact name.
self = Correlator(io=io, jkBin=binsize, smearing=smlist)

In [36]:
# Fit-window settings.
teff     = (12,25)                        # time range handed to meff()
# 0.65 divided by the lattice spacing (presumably a physical distance in fm
# converted to lattice units); an earlier variant used 1.0 instead of 0.65.
tmin     = int(0.65//mdata['aSpc'].mean)
tmax_all = 28
Nstates = 3
errmax  = 0.25

# `stag` was never defined in this notebook, so the original call raised a
# NameError on a fresh kernel. Build the fitter explicitly from the same
# inputs used for the Correlator in the previous cell.
stag = StagFitter(io=io, jkBin=binsize, smearing=smlist)
# effm / effa feed priors(): effm.mean centres the ground-state energy and
# effa[(smearing, polarization)].mean the ground-state amplitudes.
effm,effa = stag.meff(teff,**cov_specs)



  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )


In [61]:
def priors(self, Nstates, meff=None, aeff=None):
    """Build the Gaussian priors for the staggered two-point fit.

    Parameters
    ----------
    self : object
        Anything exposing ``keys``: an iterable of ``(smearing, polarization)``
        pairs, where smearing is a string such as ``'1S-1S'`` or ``'d-1S'``.
    Nstates : int
        Number of states in both the non-oscillating and oscillating towers.
    meff : gvar.GVar, optional
        Effective-mass estimate; its mean centres the ground-state energy
        prior (width 0.5). If omitted, ``'1(1)'`` is used.
    aeff : mapping, optional
        Effective amplitudes keyed by ``(smearing, polarization)``. For
        diagonal smearings its mean centres the ground-state log-amplitude
        prior. If omitted, the generic ``'0.2(1.5)'`` prior is kept
        (previously ``aeff=None`` crashed despite being the default).

    Returns
    -------
    gvar.BufferDict
        Priors keyed by ``'dE'``, ``'dE.o'``, ``'Z.<s>.<pol>'`` and
        ``'Z.<s>.<pol>.o'`` (``<s>`` is a single smearing for diagonal
        correlators, the full ``'s1-s2'`` label for mixed ones).
    """
    tmp = {}

    # dE[0] is the ground-state energy itself; the remaining entries are
    # log energy splittings (corr2 exponentiates every entry except dE[0]).
    tmp['dE']    = ['-0.7(1.)'] * Nstates
    tmp['dE.o']  = ['-0.7(1.)'] * Nstates
    tmp['dE'][0] = '1(1)' if meff is None else gv.gvar(meff.mean,0.5) 

    for sm,pl in self.keys:
        s1,s2 = sm.split('-')

        if s1==s2:
            tmp[f'Z.{s1}.{pl}'  ]    = ['0.2(1.5)'] * Nstates
            tmp[f'Z.{s1}.{pl}.o']    = ['0.2(1.5)'] * Nstates
            if aeff is not None:
                # corr2 builds the diagonal ground-state amplitude as
                # exp(Z)*exp(Z), so log(a)/2 reproduces the effective amplitude a.
                a = aeff[sm,pl].mean
                tmp[f'Z.{s1}.{pl}'][0] = gv.gvar(np.log(a)/2,1.)
        else:
            # Mixed smearings only carry excited-state amplitudes; their
            # ground state is built from the diagonal Z's in corr2.
            tmp[f'Z.{sm}.{pl}'  ] = ['0.2(1.5)'] * (Nstates-1)
            tmp[f'Z.{sm}.{pl}.o'] = ['0.2(1.5)'] * (Nstates-1)

    return gv.gvar(tmp)

In [62]:
def fexp(Nt):
    """Return f(t, E, Z) = Z * (exp(-E*t) + exp(-E*(Nt - t))).

    The sum of a forward term and its mirror image about ``Nt`` (presumably
    the temporal lattice extent). Works elementwise on NumPy arrays and on
    gvar objects, since only ``np.exp`` and arithmetic are used.
    """
    def _forward_plus_backward(t, E, Z):
        forward  = np.exp(-E*t)
        backward = np.exp(-E*(Nt-t))
        return Z * (forward + backward)

    return _forward_plus_backward

In [74]:
def corr2(Nstates,Nt,smr,pol):
    """Return the staggered two-point model ``_fcn(t, p)`` for one channel.

    Parameters
    ----------
    Nstates : int
        Number of states summed in each tower.
    Nt : int
        Forwarded to ``fexp`` (backward term uses ``Nt - t``).
    smr : str
        Smearing label ``'s1-s2'``; ``s1 != s2`` marks a mixed correlator.
    pol : str
        Polarization label used in the amplitude keys.

    The returned function evaluates, for parameters ``p`` laid out as in
    ``priors``::

        sum_n  Z0_n * fexp(Nt)(t, E_n,  1) + (-1)**(t+1) * Z1_n * fexp(Nt)(t, Eo_n, 1)

    with ``E_n = dE[0] + sum_{k=1..n} exp(dE[k])`` and
    ``Eo_n = dE[0] + exp(dE.o[0]) + sum_{k=1..n} exp(dE.o[k])``.
    The ground-state amplitude is ``exp(Z_s1[0]) * exp(Z_s2[0])``; excited
    amplitudes are squared fit parameters.
    """
    s1,s2 = smr.split('-')
    mix = s1!=s2

    # Built once per model instead of twice per state on every evaluation.
    kernel = fexp(Nt)
    # Excited-state amplitudes: mixed smearings have their own
    # (Nstates-1)-long list, diagonal ones reuse the ground-state list.
    zkey = f'Z.{smr if mix else s1}.{pol}'

    def _fcn(t,p):
        # np.exp returns new arrays, so overwriting entry 0 does not touch p.
        erg    = np.exp(p['dE'  ])
        ergo   = np.exp(p['dE.o'])
        erg [0] = p['dE'][0]              # ground-state energy is not log-parametrized
        ergo[0] = erg[0] + ergo[0]        # oscillating ground state sits above E0

        sign = (-1)**(t+1)                # staggered oscillation, loop-invariant

        c2 = 0.
        for n in range(Nstates):
            if n==0:
                Z0 = np.exp(p[f'Z.{s1}.{pol}'  ][0]) * np.exp(p[f'Z.{s2}.{pol}'  ][0])
                Z1 = np.exp(p[f'Z.{s1}.{pol}.o'][0]) * np.exp(p[f'Z.{s2}.{pol}.o'][0])
            else:
                # NOTE(review): squaring forces non-negative excited-state
                # amplitudes, also for mixed smearings — confirm intended.
                idx = n-1 if mix else n
                Z0 = p[zkey       ][idx]**2
                Z1 = p[zkey + '.o'][idx]**2

            Ephy = sum(erg[ :n+1])
            Eosc = sum(ergo[:n+1])

            c2 += kernel(t,Ephy,Z0) + kernel(t,Eosc,Z1) * sign

        return c2

    return _fcn


In [75]:
Nstates = 3
# Upper end of the window: an integer is used as a fixed cutoff, anything
# else (here the float errmax) is passed to Tmax as an error threshold.
trange  = (tmin, errmax)

In [76]:
tmin,tmax = trange
# yfull presumably holds the jackknife samples (alljk=True), one row per sample.
xdata, ydata, yfull = self.format(alljk=True,**cov_specs)

xfit, yfit, yjk = {},{},{}
for smr,pol in self.keys:
    # An integer tmax is a fixed cutoff; anything else is an error threshold
    # handed to Tmax. NumPy integers are not `int` instances, so accept them
    # explicitly instead of silently treating them as an error threshold.
    if isinstance(tmax, (int, np.integer)):
        maxt = tmax
    else:
        maxt = Tmax(ydata[smr,pol],errmax=tmax)

    window = slice(tmin, maxt+1)          # inclusive of maxt
    xfit[smr,pol] = xdata[window]
    yfit[smr,pol] = ydata[smr,pol][window]
    yjk [smr,pol] = yfull[smr,pol][:,window]

# Flatten in self.keys order so the data line up with fitfcn's concatenated output.
yflat   = np.concatenate([yfit[k] for k in self.keys])
yflatjk = np.hstack(      [yjk[k] for k in self.keys])

In [77]:
def fitfcn(xd,p):
    """Evaluate the model for every ``(smearing, polarization)`` key.

    Results are concatenated in ``self.keys`` order, matching the layout of
    ``yflat``. Relies on the notebook globals ``self``, ``Nstates`` and ``corr2``.
    """
    return np.concatenate([
        corr2(Nstates, self.Nt, sm, pl)(xd[sm, pl], p)
        for sm, pl in self.keys
    ])

In [78]:
pr = priors(self, Nstates, aeff=effa, meff=effm)  # centred on the effective mass/amplitudes

In [82]:
# Simultaneous fit of every (smearing, polarization) channel.
fit = lsqfit.nonlinear_fit(data=(xfit, yflat), prior=pr, fcn=fitfcn)

In [83]:
print(fit.format())

Least Square Fit:
  chi2/dof [dof] = 0.89 [113]    Q = 0.8    logGBF = 2098.7

Parameters:
           dE 0   1.0868 (26)     [  1.09 (50) ]  
              1    -1.05 (18)     [ -0.7 (1.0) ]  
              2    -0.44 (19)     [ -0.7 (1.0) ]  
         dE.o 0    -1.79 (13)     [ -0.7 (1.0) ]  *
              1    -1.91 (70)     [ -0.7 (1.0) ]  *
              2    -0.61 (38)     [ -0.7 (1.0) ]  
     Z.1S.Bot 0    0.384 (16)     [  0.4 (1.0) ]  
              1    0.883 (94)     [  0.2 (1.5) ]  
              2     0.7 (1.4)     [  0.2 (1.5) ]  
   Z.1S.Bot.o 0     0.02 (14)     [  0.2 (1.5) ]  
              1     0.2 (1.4)     [  0.2 (1.5) ]  
              2     1.97 (83)     [  0.2 (1.5) ]  *
     Z.1S.Par 0    0.396 (16)     [  0.4 (1.0) ]  
              1     0.83 (11)     [  0.2 (1.5) ]  
              2     1.9 (1.1)     [  0.2 (1.5) ]  *
   Z.1S.Par.o 0    -0.16 (20)     [  0.2 (1.5) ]  
              1     0.85 (18)     [  0.2 (1.5) ]  
              2     0.2 (1.5)     [  0