In [1]:
import copy
import lsqfit
import h5py
import numpy             as np
import gvar              as gv
import matplotlib.pyplot as plt
import corrfitter        as cf
from scipy import linalg as la
from tqdm  import tqdm

import jax 
jax.config.update("jax_enable_x64", True)

In [42]:
from b2heavy.TwoPointFunctions.utils     import compute_covariance, Tmax, covariance_shrinking
from b2heavy.TwoPointFunctions.types2pts import CorrelatorIO
from b2heavy.TwoPointFunctions.fitter    import StagFitter

from b2heavy.ThreePointFunctions.corr3pts import Correlator3, BINSIZE

from b2heavy.FnalHISQMetadata import params

In [56]:
ENSEMBLE = 'Coarse-1'
MOMENTUM = '300'
DATA_DIR = '/Users/pietro/code/data_analysis/BtoD/Alex/'

cov_specs = dict(scale=True, shrink=True, cutsvd=1e-12)
jkb       = BINSIZE[ENSEMBLE]

tmin2 = {
    'Dst': 0.631, 
    'B'  : 0.9
} 
tmin3 = 0.3

In [57]:
mdata = params(ENSEMBLE)

Tmin2 = {k: int(tmin2[k]//mdata['aSpc'].mean) for k in tmin2}
Tmin3 = int(tmin3//mdata['aSpc'].mean)

Tmin2,Tmin3

({'Dst': 5, 'B': 7}, 2)

# Data gathering

In [58]:
corr2 = {}
corr2['Dst'] = CorrelatorIO(ENSEMBLE,'Dst',MOMENTUM,PathToDataDir=DATA_DIR)
corr2['B']   = CorrelatorIO(ENSEMBLE,'B','000',PathToDataDir=DATA_DIR)

corr2['Dst'].collect(jkBin=jkb)
corr2['B'].collect(jkBin=jkb)

In [105]:
corr3 = Correlator3( ENSEMBLE,PathToDataDir=DATA_DIR)

not_founds = []
for cur in ['A1','A2','A3','V1','V2','V3']:
    for snk in ['A1','A2','A3','V1','V2','V3']:
        for smr in ['1S','RW']:
            if cur[-1]!=snk[-1]:
                continue
            try:
                corr3.read('p5',cur,snk,smr,MOMENTUM)
            except KeyError:
                not_founds.append(f'P5_{cur}_{snk}')

In [106]:
# merge
data = {'binned':{}, 'full':{}}

data['binned'].update(corr2['Dst'].data)
data['binned'].update(corr2['B'  ].data)
data['binned'].update(corr3.data)

data['full'].update(corr2['Dst'].data_full)
data['full'].update(corr2['B'  ].data_full)
data['full'].update(corr3.data_full)

# Compute covariance + data formatting

In [107]:
ylist, treestruct = jax.tree_util.tree_flatten(data['binned'])
ylist_full, _     = jax.tree_util.tree_flatten(data['full'])

# rescaling
ysample      = np.hstack(ylist)
ysample_full = np.hstack(ylist_full)

cov_jk   = np.cov(ysample     ,rowvar=False) * (ysample.shape[0]-1)
cov_full = np.cov(ysample_full,rowvar=False) * (ysample.shape[0]-1)
scale    = np.sqrt(np.diag(cov_jk)/np.diag(cov_full))
cov      = cov_full * np.outer(scale,scale)

avg = ysample_full.mean(axis=0)

yout = gv.gvar(avg,cov)

In [108]:
yout_list = []
I = 0
for y in ylist:
    nt = y.shape[-1]
    yout_list.append(yout[I:I+nt])
    I=I+nt

ydata = jax.tree_util.build_tree(treestruct,yout_list)

In [109]:
[k for k in ydata]

['[B]->[B].000.1S.1S.Unpol',
 '[B]->[B].000.d.1S.Unpol',
 '[B]->[B].000.d.d.Unpol',
 '[B]P5->A1->[Dst]V1.12.300.1S.Par',
 '[B]P5->A1->[Dst]V1.12.300.d.Par',
 '[B]P5->A1->[Dst]V1.13.300.1S.Par',
 '[B]P5->A1->[Dst]V1.13.300.d.Par',
 '[B]P5->A2->[Dst]V2.12.300.1S.Bot',
 '[B]P5->A2->[Dst]V2.12.300.d.Bot',
 '[B]P5->A2->[Dst]V2.13.300.1S.Bot',
 '[B]P5->A2->[Dst]V2.13.300.d.Bot',
 '[B]P5->A3->[Dst]V3.12.300.1S.Bot',
 '[B]P5->A3->[Dst]V3.12.300.d.Bot',
 '[B]P5->A3->[Dst]V3.13.300.1S.Bot',
 '[B]P5->A3->[Dst]V3.13.300.d.Bot',
 '[Dst]->[Dst].300.1S.1S.Bot',
 '[Dst]->[Dst].300.1S.1S.Par',
 '[Dst]->[Dst].300.d.1S.Bot',
 '[Dst]->[Dst].300.d.1S.Par',
 '[Dst]->[Dst].300.d.d.Bot',
 '[Dst]->[Dst].300.d.d.Par']

# Model building

In [110]:
def unpack(tag):
    if tag.count('->')==1:
        corr,mom,sm1,sm2,pol = tag.split('.')
        mes = corr.split('[')[1].split(']')[0]
        return mes,mom,sm1,sm2,pol

    elif tag.count('->')==2:
        corr,tt,mom,smr,pol = tag.split('.')

        src,cur,snk = corr.split('->')

        mes1,g1 = src.split('[')[1].split(']')      
        mes2,g2 = snk.split('[')[1].split(']')   

        return mes1,g1,cur,mes2,g2,tt,mom,smr,pol

In [111]:
models = []
for tag in ydata:
    is3 = tag.count('->')==2

    if is3:
        mes1,g1,cur,mes2,g2,tt,mom,smr,pol = unpack(tag)

        if smr=='d':
            continue
        if pol=='Bot' and cur[-1]=='3':
            continue

        print(tag)
        fcn = cf.Corr3(
            datatag = tag,
            T       = int(tt),
            tmin    = Tmin3,
            dEa     = (f'dE.{mes1}',f'dE.{mes1}.o'),
            dEb     = (f'dE.{mes2}',f'dE.{mes2}.o'),
            a       = (f'Z.{mes1}.{smr}.Unpol',f'Z.{mes1}.{smr}.Unpol.o'),
            b       = (f'Z.{mes2}.{smr}.{pol}',f'Z.{mes2}.{smr}.{pol}.o'),
            sa      = (1.,-1.),
            sb      = (1.,-1.),
            Vnn     = f'Vnn.{cur}.{smr}.{pol}',
            Von     = f'Von.{cur}.{smr}.{pol}',
            Vno     = f'Vno.{cur}.{smr}.{pol}',
            Voo     = f'Voo.{cur}.{smr}.{pol}',
        )

    else:
        mes,mom,sm1,sm2,pol = unpack(tag)

        fcn = cf.Corr2(
            datatag = tag,
            tp      = corr2[mes].Nt,
            s       = (1.,-1.),
            tdata   = np.arange(corr2[mes].Nt//2),
            tmin    = Tmin2[mes1],
            tmax    = Tmax(ydata[tag]),
            dE      = (f'dE.{mes}',f'dE.{mes}.o'),
            a       = (f'Z.{mes}.{sm1}.{pol}',f'Z.{mes}.{sm1}.{pol}.o'),
            b       = (f'Z.{mes}.{sm2}.{pol}',f'Z.{mes}.{sm2}.{pol}.o')
        )

    models.append(fcn)

[B]P5->A1->[Dst]V1.12.300.1S.Par
[B]P5->A1->[Dst]V1.13.300.1S.Par
[B]P5->A2->[Dst]V2.12.300.1S.Bot
[B]P5->A2->[Dst]V2.13.300.1S.Bot


# Priors

In [112]:
Nstates = 3
meff = {}
aeff = {}

In [113]:
dst = StagFitter(corr2['Dst'],jkBin=jkb,smearing=['d-d','1S-1S','d-1S'])
m,a = dst.meff((10,18),**cov_specs)
meff['Dst']=m
aeff['Dst']=a

  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )


In [114]:
b = StagFitter(corr2['B'],jkBin=jkb,smearing=['d-d','1S-1S','d-1S'])
m,a = b.meff((10,18),**cov_specs)
meff['B']=m
aeff['B']=a

  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )
  m = np.arccosh( (y[(it+1)%len(y)]+y[(it-1)%len(y)])/y[it]/2 )


In [115]:
pr = {}
ampl = {
    '1S': '1(1)',
    'd' : '0.01(50)'
}
for mes in ['Dst','B']:
    for smr,pol in aeff[mes]:
        s1,s2 = smr.split('-')
        if s1==s2:
            aux = aeff[mes][smr,pol] # * 2 * effm

            z0 = f'{aux.mean:.6f}({"1.0" if s1=="1S" else "0.5"})'

            pr[f'Z.{mes}.{s1}.{pol}']    = ['1(1)' if s1=='1S' else '0.01(50)' for n in range(Nstates)]
            pr[f'Z.{mes}.{s1}.{pol}'][0] = z0 
            pr[f'Z.{mes}.{s1}.{pol}.o']  = ['1(1)' if s1=='1S' else '0.01(50)' for n in range(Nstates)]

    pr[f'dE.{mes}']   = ['0.5(5)' for n in range(Nstates)]; pr[f'dE.{mes}'][0] = f'{meff[mes].mean:.2f}(0.5)'
    pr[f'dE.{mes}.o'] = ['0.5(5)' for n in range(Nstates)]

In [137]:
for k in ydata:
    if k.count('->')==2:
        mes1,g1,cur,mes2,g2,tt,mom,smr,pol = unpack(k)

        tag = f'{cur}.{smr}.{pol}'
        pr[f'Vnn.{tag}'] = np.reshape(['0.5(5)' for _ in range(Nstates**2)],(Nstates,Nstates))
        pr[f'Von.{tag}'] = np.reshape(['0.5(5)' for _ in range(Nstates**2)],(Nstates,Nstates))
        pr[f'Vno.{tag}'] = np.reshape(['0.5(5)' for _ in range(Nstates**2)],(Nstates,Nstates))
        pr[f'Voo.{tag}'] = np.reshape(['0.5(5)' for _ in range(Nstates**2)],(Nstates,Nstates))


In [138]:
prior = gv.BufferDict()
for k,p in pr.items():
    prior[k] = gv.gvar(p)

# אֵלִ֣י אֵ֖לִי לָמָ֣ה עֲזַבְתָּ֑נִי

In [139]:
fitter = cf.CorrFitter(models=models)

In [140]:
fit = fitter.lsqfit(data=ydata,prior=prior)

In [141]:
print(fit)

Least Square Fit:
  chi2/dof [dof] = 7.7 [159]    Q = 1.6e-163    logGBF = 2885.2

Parameters:
   Z.B.1S.Unpol 0     1.3907 (65)      [  1.7 (1.0) ]  
                1      -0.52 (15)      [  1.0 (1.0) ]  *
                2   -0.00013 (44)      [  1.0 (1.0) ]  *
           dE.B 0     1.9200 (12)      [  1.91 (50) ]  
                1      0.748 (27)      [  0.50 (50) ]  
                2      -1.52 (28)      [  0.50 (50) ]  ****
 Z.B.1S.Unpol.o 0   -0.5(3.1)e-07      [  1.0 (1.0) ]  *
                1     0.0028 (59)      [  1.0 (1.0) ]  
                2      0.928 (44)      [  1.0 (1.0) ]  
         dE.B.o 0       0.56 (38)      [  0.50 (50) ]  
                1       0.90 (37)      [  0.50 (50) ]  
                2       0.67 (24)      [  0.50 (50) ]  
    Z.B.d.Unpol 0     0.1892 (11)      [  0.03 (50) ]  
                1       1.17 (11)      [  0.01 (50) ]  **
                2    0.9(2.6)e-05      [  0.01 (50) ]  
  Z.B.d.Unpol.o 0   -1.8(8.6)e-08      [  0.01 (50) ]  
