In [1]:
#
# Data I/O routines
# 
import numpy as np
import pandas as pds
import sys
import re
import os
from itertools import product, compress
import datetime

# Database & I/O routines for ML fits 
#   dpar: dictionary of parameters
#   key:  Output directory & prefix
class Data_IO:
    def __init__(self, dpar, key):
        if key is not None:
            self.ID = str(key)+'.'
        else:
            self.ID = ''
        self.ANY = 'ANY'
        # data file format: 'pdf' for Chroma ascii files; 'raw' for 1-d list of X & Y
        self.flist = [ 'pdf', 'raw', 'NA']
        try:
            self.form = dpar['format']
        except:
            self.form = 'NA'
        assert(self.form in self.flist)
        try:
            self.dbin = dpar['binsize']
        except:
            self.dbin = 0
        try:
            self.emax = dpar['emax']
        except:
            self.emax = None
        try:
            self.nsrc = dpar['nsrc']
        except:
            self.nsrc = 1
        try:
            self.reverse = dpar['reverse']
        except:
            self.reverse = False
        try:
            self.NA = str(dpar['spectrum'])
            if len(list(self.NA)) == 0:
                self.NA = str(None)
        except:
            self.NA = str(None)
        try:
            self.crrt = dpar['is3pt']
        except:
            self.crrt = '3pt'
        try:
            self.mp2pt = dpar['neg_p_2pt']
        except:
            self.mp2pt = False
        #print("calling make_data\n")
        self.make_data(dpar)
        try:
            self.odir = dpar['odir']
            self.osave = True
            self.dsave = True
        except:
            self.osave = False
            self.dsave = False
            self.odir = None
        self.setup_output()
        return
    
    # clear database, free memory
    def finalize(self):
        self.clear_data()
        self.__init__(None, None)
        return
    
    # Read in data files and generate database
    def make_data(self, dpar):
        self.ndata = {}
        try:
            self.ddir = dpar['ddir']
            try:
                self.xkey = dpar['dfile.x']
                if len(list(self.xkey)) == 0:
                    self.xkey = None
                elif isinstance(self.xkey, str):
                    self.xkey = [self.xkey]
            except:
                self.xkey = None
            try:
                self.ykey = dpar['dfile.y']
                if len(list(self.ykey)) == 0:
                    self.ykey = None
                elif isinstance(self.ykey, str):
                    self.ykey = [self.ykey]
            except:
                self.ykey = None
            if isinstance(self.ddir, str):
                self.ddir = [self.ddir]
            self.xfile = []
            for ddir in self.ddir:
                for d1, d2, f in os.walk(ddir): 
                    if (self.xkey is None) or (self.ykey is None):
                        self.xfile.extend([ddir+'/'+ff for ff in f])
                    else:
                        for ff in f:
                            #if np.array(self.xkey).any(self.xkey in ff) or np.array(self.ykey).any(self.ykey in ff): #if (self.xkey in ff) or (self.ykey in ff):
                            flag = False
                            for k in self.xkey:
                                kl = k
                                if isinstance(k, str):
                                    kl = [k]
                                if all([kk in ddir+'/'+ff for kk in kl]): #kl[*] in ff:
                                    self.xfile.append(ddir+'/'+ff)
                                    flag = True
                                    break
                            if flag is False:
                                for k in self.ykey:
                                    kl = k
                                    if isinstance(k, str):
                                        kl = [k]
                                    if all([kk in ddir+'/'+ff for kk in kl]):
                                        self.xfile.append(ddir+'/'+ff)
                                        #print("adding file {:}".format(ddir+'/'+ff))
                                        break
        except:
            self.ddir = None
            try:
                self.xfile = dpar['dfile.x']
                self.yfile = dpar['dfile.y']
            except:
                return
            if self.form == 'pdf':
                self.xfile.extend(self.yfile)                   
        if self.form == 'pdf':
            if self.xkey is not None:
                if self.ykey is not None:
                    self.xkey.extend(self.ykey)
            else:
                if self.ykey is not None:
                    self.xkey = self.ykey
            #print("keys are {:}".format(self.xkey))
            self.NT = {}
            self.tsp = {}
            self.p = {}
            self.z = {}
            self.nx = 0
            self.ny = 0
            self.pztag = {}
            self.pdict = {}
            self.zdict = {}
            self.tdict = {}
            self.odict = {}
            self.tspdict = {}
            self.tdict['X'] = None
            self.tdict['Y'] = None
            self.pdict['X'] = None
            self.pdict['Y'] = None
            self.zdict['X'] = None
            self.zdict['Y'] = None
            self.odict['X'] = None
            self.odict['Y'] = None
            dataframe = {}
            #print("xfile list {:}".format(self.xfile))
            for dfile in self.xfile:
                key = self.ANY
                if self.xkey is not None:
                    for k in self.xkey:
                        kl = k
                        if isinstance(k, str):
                            kl = [k]
                        if all([kk in dfile for kk in kl]): #k in dfile:
                            key = '.'.join(kl)
                            break
                self.p[key] = []
                self.z[key] = []
                self.pztag[key] = []
                # read in momentum
                mode = re.search('((?<=pz)|(?<=p))((\d+)|(-\d+))', dfile)
                try:
                    ps = str(mode.group(0))
                    #print(ps)
                    if ps not in self.p[key]:
                        self.p[key].append(ps)
                except:
                    print("Unknown PDF data filename convention for momentum p/pz in {:}\n".format(dfile))
                    sys.exit(1)
                # read in Wilson link length (z)
                mode = re.search('(?<=z)(?<![a-zA-Z]z)((\d+)|(-\d+))', dfile)
                if mode is None:
                    #print("searching using second convention")
                    mode = re.search('(?<=_l)(\d+)((((_2)|(_6))(?=_))|((_g)(?=\d+)))', dfile)
                    if mode is None: 
                        #print("Not matched")
                        zv = self.NA # spectrum data: z = None (2pt)
                    else:
                        #print("z = {:}".format(str(mode.group(0))))
                        z = str(mode.group(0)).split('_')
                        zv = z[0]
                        if z[1] == '6':
                            zv = str(-int(zv))
                else:
                    zv = str(mode.group(0))
                if zv not in self.z[key]:
                    self.z[key].append(zv)
                r = 1
                i = 2
                if self.reverse:
                    r = 2
                    i = 1 
                self.pztag[key].append(ps+'+'+zv)
                tagX = key+'+'+self.pztag[key][-1]
                is3pt = False
                # non-spectrum data
                if zv != self.NA:
                    mode = re.search('(?<=t)\d+', dfile)
                    # 3pt data
                    if mode is not None:
                        try:
                            self.tsp[tagX].append(str(mode.group(0)))
                        except:
                            self.tsp[tagX] = [str(mode.group(0))]
                        tagX = tagX+'+T+'+str(mode.group(0))
                    # 2pt data
                    else:
                        try:
                            self.tsp[tagX].append(self.NA)
                        except:
                            self.tsp[tagX] = [self.NA]
                else:
                    try:
                        self.tsp[tagX].append(self.NA)
                    except:
                        self.tsp[tagX] = [self.NA]
                # loop over t & configs.
                #print("{:} tag = {:}".format(dfile, tagX))
                with open(dfile, 'r') as df:
                    n = 0
                    if self.dbin > 1:
                        bkid = 0
                        datblk = {}
                    for line in df:
                        tl = line.split()
                        if n == 0:
                            self.ndata[tagX] = int(tl[0])
                            self.NT[tagX] = int(tl[1])
                            self.nx += 2*self.NT[tagX]
                            # generate columns : p+z+t+{R,I}
                            for tt in range(self.NT[tagX]):
                                dataframe[tagX+'+t+'+str(tt)+'+R'] = []
                                dataframe[tagX+'+t+'+str(tt)+'+I'] = []
                                if self.dbin > 1:
                                    datblk[str(tt)] = [0., 0.]
                            n+=1
                            continue
                        if self.dbin == 1:
                            dataframe[tagX+'+t+'+tl[0]+'+R'].append(float(tl[r]))
                            dataframe[tagX+'+t+'+tl[0]+'+I'].append(float(tl[i]))
                            continue
                        datblk[tl[0]][0] += float(tl[r])
                        datblk[tl[0]][1] += float(tl[i])
                        if bkid == self.dbin-1:
                            dataframe[tagX+'+t+'+tl[0]+'+R'].append(datblk[tl[0]][0]/self.dbin)
                            dataframe[tagX+'+t+'+tl[0]+'+I'].append(datblk[tl[0]][1]/self.dbin)
                            datblk[tl[0]] = [0., 0.]
                        if int(tl[0]) == self.NT[tagX]-1:
                            bkid = (bkid+1) % self.dbin
                    #print("{:} -- {:} : {:d}".format(dfile, tagX+'+t+0+R', len(dataframe[tagX+'+t+0+R'])))
            self.data = pds.DataFrame(data=dict([(k,pds.Series(v)) for k,v in dataframe.items()]))
            #print("Data is {:}".format(self.data))
            del dataframe
        elif self.form == 'raw':
            pdfile = open(self.xfile, 'r')
            data = {}
            data['X'] = np.fromfile(pdfile, dtype=float, sep=" ")
            pdfile.close()
            pdfile = open(self.yfile, 'r')
            data['Y'] = np.fromfile(pdfile, dtype=float, sep=" ")
            pdfile.close()
            self.ndata['Y'] = data['Y'].shape[0]
            self.ndata['X'] = self.ndata['Y']
            self.nx = int(data['X'].shape[0]/self.ndata['X'])
            data['X'].shape = (self.ndata['X'], self.nx)
            if self.dbin > 1:
                data['Y'] = np.array([data['Y'][i:i+self.dbin].mean() for i in range(0,self.ndata['X'],self.dbin)])
                data['X'] = [[data['X'][i:i+self.dbin,j].mean() for j in range(self.nx)]
                                           for i in range(0,self.ndata['X'],self.dbin)]#.reshape(int(self.ndata['X']/self.dbin)*self.nx)
                self.ndata['Y'] = data['Y'].shape[0]
                self.ndata['X'] = self.ndata['Y']
            #print("shape of X is {:} Y {:}\n".format(data['X'].shape, data['Y'].shape))
            self.data = pds.DataFrame(data=data, columns=['X', 'Y'], dtype=float)
            del data
        #print("self.tdict is {:}\n".format(self.tdict))
        return
       
    # Clear data
    def clear_data(self):
        del self.data
        del self.ndata
        del self.nx
        if self.form == 'pdf':
            del self.NT
            del self.p 
            del self.z  
            del self.ny
            del self.pztag
            del self.pdict
            del self.zdict
            del self.tdict
            del self.odict
            del self.tsp
            del self.tspdict
        return
    
    # Select data according to list of p, z, t, dt, return normalized data
    #   pzt: dictionary or list-like data chacacters of: 
    #        'opt' (operator or other key tags, if None or not specified use default list of keys in stored data)
    #        'p' (momentum), 'z' (z-link), 'T' (t_sep, optional), t' (Y time slices), 'dt' (X & Y time differences)
    #   tag: 'X' or 'Y'
    #   incld2pt: for ratio plots; True to select spectrum w. same 'pzt' excluding z & T 
    def select_data(self, pzt, tag = None, incld2pt = False):
        if self.form == 'raw':
            if tag == 'X':
                print("X data is {:}".format(self.data.loc[:,tag].values))
                sdata = self.data.loc[:,tag].values.reshape(self.ndata['X'],self.nx)[:,pzt[1]::pzt[0]]
            else:
                sdata = self.data.loc[:,tag].values
            smean = sdata.mean()
            sstd = sdata.std(ddof=1)
            sdata /= smean
            return None, None, sdata, smean, sstd
        # data form is 'pdf'
        if self.tdict[tag] is not None:
            del self.tdict[tag]
            self.tdict[tag] = None
        if self.pdict[tag] is not None:
            del self.pdict[tag]
        if self.zdict[tag] is not None:
            del self.zdict[tag]
        if self.odict[tag] is not None:
            del self.odict[tag]
        ol = None
        if isinstance(pzt, dict):
            try:
                ol = pzt['opt']
            except:
                pass
            pl = pzt['p']
            zl = pzt['z']
            tl = pzt['t']
            dtl = pzt['dt']
            try:
                tspl = pzt['T']
            except:
                tspl = [self.NA]
        else:
            assert(len(pzt)>=4)
            ofs = 0
            if len(pzt)>4:
                ol = pzt[0]
                ofs = 1
            print("ofs = {:}".format(ofs))
            pl = pzt[ofs]
            zl = pzt[ofs+1]
            tl = pzt[-2]
            dtl = pzt[-1]
            if len(pzt)==6:
                tspl = pzt[ofs+2]
            else:
                tspl = [self.NA]
        opfix = True
        if ol is None:
            if self.xkey is not None:
                ol = list(self.pztag.keys())
            else:
                ol = [self.ANY]
                #opfix = False
        elif isinstance(ol, str):
            ol = [ol]
        olkeys = {}
        for o in ol:
            try:
                olkeys[o] = {list(compress(list(self.pztag.keys()), [o in k for k in self.pztag ]))[0]: list(compress(list(self.pztag.keys()), [ol[o] in k for k in self.pztag ]))}
            except:
                olkeys[o] = list(compress(list(self.pztag.keys()), [o in k for k in self.pztag ]))
        print("list of olkeys is    {:}\n\n".format(olkeys))
        #flag = [True]
        #flag.extend([tmpo[k]==tmpo[k-1] for k in range(1, len(tmpo))])
        #olkeys = list(compress(tmpo, flag))
        tl0 = tl
        if pl is None: 
            pl = self.p
        if zl is None:
            zl = self.z
        if tspl is None:
            tspl = self.tsp
        #print("list of momentum {:}".format(pl))
        self.tdict[tag] = []
        self.pdict[tag] = pl
        self.zdict[tag] = zl
        self.odict[tag] = ol
        self.tspdict[tag] = tspl
        # X data with dt range
        rttag = None
        if dtl is not None:
            # maximum dt 
            if isinstance(dtl, int):
                rttag = (-dtl, dtl, 1)
                rdt = list(range(-dtl, dtl+1))
            # range or list of dt 
            elif isinstance(dtl, (tuple, list)):
                if isinstance(dtl, tuple):
                    try:
                        dt = dtl[2]
                    except:
                        dt = 1
                    rdt = list(range(dtl[0], dtl[1]+1))
                    rttag = (dtl[0], dtl[1], dt)
                else:
                    rdt = np.array(dtl).sort().tolist()
                    rttag = rdt
            else:
                print("Error: Unknown dt range dts_X !\n")
                sys.exit(1)
            self.tdict[tag] = list(rdt)
        nsrc = self.nsrc
        # choose entire NT (or T) range
        if tl is None:
            nsrc = 1
            tl = {}
            self.countXT = 0
            for kpz in self.NT:
                tl[kpz] = list(range(0,self.NT[kpz]))
        elif dtl is not None:
            ttl = []
            print("tY is {:} tX is {:}".format(tl, rdt))
            for t, dt in product(tl, rdt):
                ttl.append(t+dt)
            self.countXT = list(tl)
            tl = ttl
        lst = []
        strT = {}
        # include spectrum w. p 
        mpflag = 0
        excldspec = False
        if incld2pt:
            # include spectrum w. p, and -p iff. self.mp2pt is True
            mpflag = int(self.mp2pt)
            if self.NA not in zl:
                excldspec = True
                zl = zl.copy()
                zl.append(self.NA)
                # include spectrum w. -p / p if. self.mp2pt is True / False
                mpflag = -int(self.mp2pt)
            # include key list
            if False:#if isinstance(ol, dict):
                sign = [ 1 ]
                if mpflag == 1:
                    sign = [-1, 1]
                elif mpflag == -1:
                    sign = [ -1 ]
                for k in ol: 
                    if self.NA in k:
                        continue
                    p = k.split('+')[0]
                    for s in sign:
                        ps = str(s*int(p))
                        if ps not in ol:
                            ol[ps] = [ol[k]]
                        else:
                            ol[ps].append(ol[k])
        for o in ol:
            if isinstance(pl, dict):
                pll = pl[o]
                if isinstance(pll, str):
                    pll = [pll]
            else:
                try:
                    pll = list(pl)
                except:
                    pll = [pl]
            if isinstance(zl, dict):
                zll = zl[o]
                if isinstance(zll, str):
                    zll = [zll]
                #print("list of zll is {:}".format(zll))
            else:
                if isinstance(zl, str):
                    zll = [zl]
                else:
                    zll = zl.copy()
            print("pll is {:} zll is {:}".format(pll, zll))
            for p in pll:
                if isinstance(zll, dict):
                    zlp = zll[p]
                else:
                    zlp = zll
                for z in zlp:
                    sign = [ 1 ]
                    if (z == self.NA) and (self.mp2pt is True):
                        sign = [ -1 ]
                    for s in sign:
                        ps = str(s*int(p))       
                        for oky in olkeys[o]:
                            if z == self.NA:
                                try:
                                    kpz = olkeys[o][oky][0]+'+'+ps+'+'+str(z) # for kk in olkeys[o][oky]]
                                    ikpz = olkeys[o][oky][0]+'+'+ps+'+'+str(z) #[kk+'+'+ps+'+'+str(z) for kk in olkeys[o][oky]]
                                except:
                                    kpz = oky+'+'+ps+'+'+str(z)
                                    ikpz = oky+'+'+ps+'+'+str(z)
                            else:
                                kpz = oky+'+'+ps+'+'+str(z)
                                ikpz = oky+'+'+ps+'+'+str(z)
                            if opfix is False:
                                ikpz = ps+'+'+str(z)
                            strT[kpz] = []
                            #print('key {:}'.format(kpz))
                            if isinstance(tspl, dict):
                                tspn = list(tspl[ikpz])
                                self.tspdict[tag][kpz] = tspn.copy()
                                del self.tspdict[tag][ikpz]
                            else:
                                tspn = list(tspl)
                            for tsp in tspn:
                                if (str(tsp) != self.NA) and (self.NA not in kpz):
                                    lst.append(self.ndata[kpz+'+T+'+str(tsp)]) 
                                    strT[kpz].append('+T+'+str(tsp))
                                else:
                                    lst.append(self.ndata[kpz])
                                    strT[kpz].append(None)
        print("Number of data are {:}".format(lst))
        ndata = np.array(lst).min()
        print("Number of data is {:}".format(ndata))
        del lst
        sdata = []
        ttag = []
        xdata = []
        xtag = []
        #print(ndata)
        print(pl)
        print("list of z {:}".format(zl))
        checkerr = False
        #print("dtl is {:} tl0 is {:} self.emax is {:}".format(dtl, tl0, self.emax))
        if (dtl is None) and (tl0 is None) and (self.emax is not None):
            checkerr = True
            rttag = {}
            self.tdict[tag] = {}
        flag = False
        if self.tdict[tag] is None:
            flag = True
        for o in ol:
            if isinstance(pl, dict):
                pll = pl[o]
                if isinstance(pll, str):
                    pll = [pll]
            else:
                try:
                    pll = list(pl)
                except:
                    pll = [pl]
            if isinstance(zl, dict):
                zll = zl[o]
                if isinstance(zll, str):
                    zll = [zll]
            else:
                try:
                    zll = list(zl)
                except:
                    zll = [zl]
            for pp in pll:
                if isinstance(zll, dict):
                    zlp = zll[pp]
                else:
                    zlp = zll
                for z in zlp:#, z in product(pll, zll):
                    sign = [ 1 ]
                    if (z == self.NA) and (self.mp2pt is True):
                        sign = [ -1 ]
                    for s in sign:
                        p = s*int(pp)
                        for oky in olkeys[o]:
                            if z == self.NA:
                                try:
                                    kpz = olkeys[o][oky][0]+'+'+str(p)+'+'+str(z) # for kk in olkeys[o][oky]]
                                    ikpz = olkeys[o][oky][0]+'+'+str(p)+'+'+str(z) #[kk+'+'+ps+'+'+str(z) for kk in olkeys[o][oky]]
                                except:
                                    kpz = oky+'+'+str(p)+'+'+str(z)
                                    ikpz = oky+'+'+str(p)+'+'+str(z)
                                kpz3pt = oky+'+'+str(p)+'+'+str(z)
                            else:
                                kpz = oky+'+'+str(p)+'+'+str(z)
                                ikpz = oky+'+'+str(p)+'+'+str(z)
                            #print('key {:}'.format(kpz))
                            if opfix is False:
                                ikpz = str(p)+'+'+str(z)
                            for sT in strT[kpz]:
                                itg = ikpz
                                if sT is not None:
                                    tg = kpz+sT
                                    itg = ikpz+sT
                                    T = int(tg.split('+T+')[1])+1
                                else:
                                    tg = kpz
                                    itg = ikpz
                                    T = int(self.NT[tg]/self.nsrc)
                                if checkerr:
                                    self.tdict[tag][tg] = []
                                if isinstance(tl, dict):
                                    tt = tl[itg]
                                else: 
                                    tt = tl
                                if checkerr:
                                    rttag[tg] = []
                                for n, tmp in product(range(nsrc), tt):
                                    t = (tmp+n*int(self.NT[tg]/nsrc)+self.NT[tg]) % self.NT[tg]
                                    if t%int(self.NT[tg]/self.nsrc) >= T:
                                        continue
                                    if excldspec and (z == self.NA): #and (s == -1):
                                        if kpz3pt+'+t+'+str(t)+'+R' in xtag:
                                            continue
                                    else:
                                        if tg+'+t+'+str(t)+'+R' in ttag:
                                            continue
                                    datatmp = self.data.loc[:ndata-1,[tg+'+t+'+str(t)+'+R',tg+'+t+'+str(t)+'+I']].T.values
                                    if checkerr: 
                                        if excldspec and (z == self.NA):
                                            pass                                
                                        else:
                                            arrtmp = np.array(datatmp)
                                            #print("shape of arrtmp {:}".format(arrtmp.shape))
                                            if abs(arrtmp[0].std(ddof=1)/arrtmp[0].mean()) > self.emax and (abs(arrtmp[1].std(ddof=1)/arrtmp[1].mean()) > self.emax):
                                                del arrtmp
                                                del datatmp
                                                if t not in rttag[tg]:
                                                    rttag[tg].append(t)
                                                continue
                                        if t not in self.tdict[tag][tg]:
                                            self.tdict[tag][tg].append(t)
                                    elif t not in self.tdict[tag]:
                                        self.tdict[tag].append(t)
                                    if excldspec and (z == self.NA): #and (s == -1):
                                        xdata.extend(datatmp)
                                        xtag.extend([kpz3pt+'+t+'+str(t)+'+R', kpz3pt+'+t+'+str(t)+'+I'])
                                    else:
                                        #print("datatmp at {:} ({:}): {:}".format(tg+'+t+'+str(t), np.array(datatmp).shape, datatmp))
                                        sdata.extend(datatmp)
                                        ttag.extend([tg+'+t+'+str(t)+'+R', tg+'+t+'+str(t)+'+I'])
                                    del datatmp
                                flag = False
        #sdata.extend(xdata)
        arr = np.array(sdata)
        del sdata
        del strT
        if incld2pt:
            xarr = np.array(xdata)
        del xdata
        sh = list(arr.shape)
        tsh = 1
        for t in sh:
            tsh *= t
        smean = [ arr[i].mean() for i in range(len(arr)) ]
        sstd = [ arr[i].std(ddof=1) for i in range(len(arr)) ]
        for i in range(len(arr)):
            arr[i] /= smean[i]
       # return tuple(tgl), arr.reshape(sh).T, np.array(smean), np.array(sstd)
        if incld2pt:
            return rttag, tuple(ttag), tuple(xtag), arr.T, xarr.T, np.array(smean), np.array(sstd)
        else:
            return rttag, tuple(ttag), arr.T, np.array(smean), np.array(sstd)
                         

    # Select X data with specified time range for a fit
    #   tagX: list of X id's
    #   dtX: range (type: tuple) or list (type: list) of time slices
    #   tagY: the Y id
    #   X : list of X samples with len(tagX) variables
    def select_T(self, tagX, dtX, tagY, X = None):
        #print("nsrc = {:}".format(self.nsrc))
        if X is None:
            aX = np.array(self.data.loc[:,list(tagX)].values)#.T
        else:
            aX = np.array(X).T
        assert(aX.shape[0] == len(tagX))
        nd = aX.shape[1]
        lX = []
        tgl = tagY.split('+')
        ty = int(tgl[-2]) 
        if isinstance(dtX, tuple):
            tlist0 = list(range(dtX[0], dtX[1]+1, dtX[2]))
        else:
            tlist0 = dtX
        tgXl = []
        tlist = {}
        for tag in tagX:
            tg = tag.split('+t+')[0]
            if tg in tgXl:
                continue
            tgXl.append(tg)
            tlist[tg] = tlist0.copy()
            NTx = self.NT[tg]
            #print("X {:} NT = {:}".format(tg,NTx))
            nnt = int(NTx/self.nsrc)
            try:
                Tx = NTx + int(tg.split('+T+')[1])+1 - nnt
            except:
                Tx = NTx
            for i in range(1, self.nsrc):
                tlist[tg].extend([t+i*nnt for t in tlist0])
            for i in range(len(tlist[tg])):
                tlist[tg][i] = (tlist[tg][i]+ty+Tx)%Tx
        for i in range(len(tagX)):
            tg = tagX[i].split('+t+')[0]
            #print("list of t for tag {:}: {:}".format(tg, tlist[tg]))
            if int(tagX[i].split('+')[-2]) in tlist[tg]:
                #print("Data X at {:} : {:}".format(tagX[i], aX[i]))
                lX.append(aX[i])
        del tgXl
        del tlist
        del tlist0
        arr = np.array(lX).T
        del lX
        del aX
        return arr
                         
    
    # Select X data with specified time range for a fit
    #   tagX: list of X id's including full time range
    #   dtX: range (type: tuple) of time slices
    #   tagY: the Y id
    #   X : (optional) list of X (complex) samples with len(tagX) variables; by default use the data stored
    def select_T_fast(self, tagX, dtX, tagY, X=None):
        # range of contiguous time slices
        assert(isinstance(dtX, tuple) and (len(dtX)==2))
        if X is None:
            aX = np.array(self.data.loc[:,list(tagX)].values)#.T
        else:
            aX = np.array(X).T
        assert(aX.shape[0] == len(tagX))
        nd = aX.shape[1]
        lX = []
        tgl = tagY.split('+')
        ty = tgl[-2]    
        pzX = []
        x = 0
        while x < len(tagX):
            pz = tagX[x].split('+t+')[0]
            assert(pz not in pzX)
            pzX.append(pz)
            nnt = int(self.NT[pz]/self.nsrc)
            for i in range(self.nsrc):
                dt = [(tx+i*nnt)%self.NT[pz] for tx in dtX]
                lX.extend(aX[2*(x+dt[0]):2*(x+dt[1])+1])
            #assert(tagX[])
            x += self.NT[pz]*2
        arr = np.array(lX).T
        del lX
        del aX
        return arr
        
        
    # Generate output directories and filenames
    def setup_output(self):
        self.date = datetime.datetime.today().strftime('%m%d%Y')
        if self.osave or self.dsave:
            odir = self.odir+'/ML.'+self.date
            sfix = '.000'
            n = 1
            flag = 0
            while True:
                try:
                    os.mkdir(odir+sfix)
                    flag = 1
                except:
                    sfix = ".{:03d}".format(n)
                    flag = 0
                    n += 1
                    #print("Warming: directory"+self.odir+'/ML.'+self.date+" already exist")
                if flag == 1:
                    break
            self.odir = odir+sfix
            if self.osave:
                try:
                    os.mkdir(self.odir+'/plots')
                except:
                    print("Warming: directory"+self.odir+'/plots'+" already exist")
            if self.dsave:
                try:
                    os.mkdir(self.odir+'/preds')
                except:
                    print("Warming: directory"+self.odir+'/preds'+" already exist")
            self.oheader_tag = None
        else:
            self.ofile = None
            self.pltfile = None
        return
    
    # Make the output file header lines
    #   indx: filename surfix (digital)
    def dfile_mkheader(self, indx):
        if self.dsave is False:
            return
        if self.form == 'raw':
            return
        if self.form == 'pdf':
            self.ofile = self.odir+'/'+self.ID+str(indx)
            self.pltfile = self.odir+'/plots/'+self.ID+str(indx)
            pf = open(self.ofile, 'a+')
            self.dfile_mkheader_pdf(pf)
            pf.close()
        return
    
    def dfile_mkheader_pdf(self, pf):
        if True: 
            pf.write("PDF data analysis with ML: \n Target data: ")
            n = 1
            for tgp, tgz in product(self.pdict['Y'], self.zdict['Y']):
                if isinstance(self.odict['Y'], dict):
                    try:
                        odict = self.odict['Y'][str(tgp)+'+'+str(tgz)]
                    except:
                        odict = self.odict['Y']#[str(tgp)]
                else:
                    odict = self.odict['Y']
                for tgo in odict:
                    if self.NA == tgz: 
                        pf.write("Y{:d}( spectrum ({:}): p= {:} ),  ".format(n, tgo, tgp))
                    else:
                        if isinstance(self.tspdict['Y'], dict):
                            tspl = self.tspdict['Y'][tgo+'+'+str(tgp)+'+'+str(tgz)]
                        else:
                            tspl = self.tspdict['Y']
                        if self.NA in tspl:
                            pf.write("Y{:d}( DA ({:}): p= {:} ,z= {:} ),  ".format(n, tgo, tgp, tgz))
                            tspl.remove(self.NA)
                        if len(tspl) > 0:
                            pf.write("Y{:d}( PDF ({:}): p= {:} ,z= {:}, T= {:} ),  ".format(n, tgo, tgp, tgz, tspl))
                    n+=1
            pf.write("\n Training data: p= {:} ,z= {:} ,operators= {:}\n".format(self.pdict['X'], self.zdict['X'], self.odict['X']))
            pf.write("\n Target data at t= {:} \n Training data at dt= {:} \n\n".format(self.tdict['Y'], 
                                                                                        self.tdict['X']))
        else:
            pf.write("\n Target data at t= {:} \n Training data at dt= {:} \n\n".format(self.tdict['Y'], 
                                                                                        self.tdict['X']))
        return
        

SyntaxError: invalid syntax (<ipython-input-1-cecf936ad60a>, line 364)

In [4]:
a=['a','b','c']
a.remove('a')
print(a)
b=a.copy()
print(b)

['b', 'c']
['b', 'c']


In [12]:
a=['a', 'b', 'c']
print(all([aa in 'abdef' for aa in a]))
print(str(a))
print('-'.join(['a']))
print(str(False))
b='-0.5'
print(b.split('-')[0])

False
['a', 'b', 'c']
a
False



In [1]:
import numpy as np
a = np.array([[1,2],[3,4],[5,6]])
print(a[:2,:])
a=np.array([(1,2),[3,4],[5,6]])
print(a.reshape(3,2))

[[1 2]
 [3 4]]
[[1 2]
 [3 4]
 [5 6]]


In [17]:
a=1
print("{:03d}".format(a))
from itertools import compress
tmpo = [ 1, 1, 2, 3, 4, 4]


001
[True, <generator object <genexpr> at 0x112477de0>]
[1, 1]
