In [None]:
#
# Data I/O routines
# 
import datetime
import os
import re
import sys
from itertools import product

import numpy as np
import pandas as pds

# Database & I/O routines for ML fits
class Data_IO:
    """Database & I/O routines for ML fits.

    Reads correlator data files into a pandas DataFrame (``self.data``),
    selects subsets of it for fitting, and prepares output directories and
    output-file headers.

    Args:
        dpar: dictionary of parameters (may be None).  Recognised keys, all
            optional:
              'format'   -- 'pdf' (Chroma ascii files), 'raw' (1-d list of
                            X & Y) or 'NA' (default)
              'binsize'  -- number of consecutive configurations averaged
                            into one sample (<= 1: no binning; default 0)
              'emax'     -- |std/mean| cut applied by select_data when
                            neither t nor dt is given
              'nsrc'     -- number of equal segments the NT time slices are
                            split into (default 1)
              'reverse'  -- swap the two data columns read as real/imag
              'spectrum' -- label used as "z" for files without a z in their
                            name (default 'None')
              'is3pt'    -- stored as self.crrt (not used in this class)
              'ddir'     -- data directory or list of directories to walk
              'dfile.x', 'dfile.y' -- X/Y data files, or substrings used to
                            filter file names when 'ddir' is given
              'odir'     -- output directory; when set, output sub-folders
                            are created by setup_output
        key: output directory & file prefix (None for no prefix).
    """

    def __init__(self, dpar, key):
        self.ID = '' if key is None else str(key) + '.'
        # data file format: 'pdf' for Chroma ascii files; 'raw' for 1-d list of X & Y
        self.flist = ['pdf', 'raw', 'NA']
        self.form = self._param(dpar, 'format', 'NA')
        assert self.form in self.flist
        self.dbin = self._param(dpar, 'binsize', 0)
        self.emax = self._param(dpar, 'emax', None)
        self.nsrc = self._param(dpar, 'nsrc', 1)
        self.reverse = self._param(dpar, 'reverse', False)
        # label stored as "z" for data whose file name carries no z (spectrum / 2pt)
        self.NA = str(self._param(dpar, 'spectrum', None))
        if not self.NA:
            self.NA = str(None)
        self.crrt = self._param(dpar, 'is3pt', '3pt')
        self.make_data(dpar)
        self.odir = self._param(dpar, 'odir', None)
        self.osave = self.odir is not None
        self.dsave = self.osave
        self.setup_output()

    @staticmethod
    def _param(dpar, key, default):
        """Return dpar[key], or `default` when dpar is None or lacks `key`."""
        try:
            return dpar[key]
        except (KeyError, TypeError):
            return default

    @staticmethod
    def _as_list(files):
        """Return `files` as a new list; a single path string becomes [path]."""
        return [files] if isinstance(files, str) else list(files)

    def finalize(self):
        """Clear the database, free memory and reset to an empty object."""
        self.clear_data()
        self.__init__(None, None)

    def make_data(self, dpar):
        """Locate the data files described by `dpar` and build the database.

        With 'ddir', every file below the directories whose name contains
        dpar['dfile.x'] or dpar['dfile.y'] is used (all files if either is
        missing).  Otherwise dpar['dfile.x'] / dpar['dfile.y'] name the files
        directly; for 'pdf' both lists are read.  Returns early (no data)
        when neither is given.
        """
        self.ndata = {}
        ddir = self._param(dpar, 'ddir', None)
        if ddir is not None:
            self.ddir = [ddir] if isinstance(ddir, str) else list(ddir)
            self.xkey = self._param(dpar, 'dfile.x', None)
            self.ykey = self._param(dpar, 'dfile.y', None)
            keep_all = (self.xkey is None) or (self.ykey is None)
            self.xfile = []
            for top in self.ddir:
                for root, dirs, files in os.walk(top):
                    dirs.sort()  # deterministic traversal order
                    for ff in sorted(files):
                        if keep_all or (self.xkey in ff) or (self.ykey in ff):
                            # join with the directory the file actually lives in
                            self.xfile.append(os.path.join(root, ff))
        else:
            self.ddir = None
            try:
                self.xfile = dpar['dfile.x']
                self.yfile = dpar['dfile.y']
            except (KeyError, TypeError):
                return
            if self.form == 'pdf':
                # build a new list: never mutate the caller's dpar entries
                self.xfile = self._as_list(self.xfile) + self._as_list(self.yfile)
        if self.form == 'pdf':
            self._read_pdf()
        elif self.form == 'raw':
            self._read_raw()

    def _read_pdf(self):
        """Parse every file in self.xfile ('pdf' format) into self.data.

        Momentum p, link length z and source-sink separation T are parsed
        from the file path; columns are named '<p>+<z>[+T+<T>]+t+<t>+{R,I}'.
        """
        self.NT = {}
        self.tsp = {}
        self.p = []
        self.z = []
        self.nx = 0
        self.ny = 0
        self.pztag = []
        self.pdict = {'X': None, 'Y': None}
        self.zdict = {'X': None, 'Y': None}
        self.tdict = {'X': None, 'Y': None}
        self.tspdict = {}
        # column indices of the two values on each data line (swapped by 'reverse')
        r, i = (2, 1) if self.reverse else (1, 2)
        dataframe = {}
        for dfile in self.xfile:
            mode = re.search(r'((?<=pz)|(?<=p))((\d+)|(-\d+))', dfile)
            if mode is None:
                print("Unknown PDF data filename convention for momentum p/pz\n")
                sys.exit(1)
            self.p.append(str(mode.group(0)))
            mode = re.search(r'(?<=z)(?<!pz)((\d+)|(-\d+))', dfile)
            # spectrum data: z = self.NA (2pt)
            self.z.append(self.NA if mode is None else str(mode.group(0)))
            pztag = self.p[-1] + '+' + self.z[-1]
            self.pztag.append(pztag)
            tagX = pztag
            tsp = self.NA
            if self.z[-1] != self.NA:
                mode = re.search(r'(?<=t)\d+', dfile)
                if mode is not None:
                    tsp = str(mode.group(0))
                    tagX = tagX + '+T+' + tsp
            self.tsp.setdefault(pztag, []).append(tsp)
            self._read_pdf_file(dfile, tagX, r, i, dataframe)
        self.data = pds.DataFrame(data=dataframe)

    def _read_pdf_file(self, dfile, tagX, r, i, dataframe):
        """Append the samples of one file to `dataframe` under columns tagX+'+t+<t>+{R,I}'.

        The first line holds '<ndata> <NT>'; each following non-blank line is
        '<t> ...' with columns r and i taken as the two stored values.  With
        binsize > 1, `binsize` consecutive configurations (each ending at
        t = NT-1) are averaged; an incomplete trailing bin is dropped.
        """
        with open(dfile, 'r') as df:
            header = df.readline().split()
            if not header:
                return  # empty file: no columns are created
            self.ndata[tagX] = int(header[0])
            NT = int(header[1])
            self.NT[tagX] = NT
            self.nx += 2 * NT
            # generate columns : p+z+t+{R,I}
            cols = {}
            for tt in range(NT):
                key = str(tt)
                cols[key] = (tagX + '+t+' + key + '+R', tagX + '+t+' + key + '+I')
                dataframe[cols[key][0]] = []
                dataframe[cols[key][1]] = []
            binned = self.dbin > 1
            # running sums of the current bin, per time slice
            datblk = {key: [0., 0.] for key in cols}
            bkid = 0
            for line in df:
                tl = line.split()
                if not tl:
                    continue  # tolerate blank lines, e.g. a trailing newline
                colR, colI = cols[tl[0]]
                if not binned:
                    dataframe[colR].append(float(tl[r]))
                    dataframe[colI].append(float(tl[i]))
                    continue
                blk = datblk[tl[0]]
                blk[0] += float(tl[r])
                blk[1] += float(tl[i])
                if bkid == self.dbin - 1:
                    dataframe[colR].append(blk[0] / self.dbin)
                    dataframe[colI].append(blk[1] / self.dbin)
                    datblk[tl[0]] = [0., 0.]
                # the last time slice closes one configuration
                if int(tl[0]) == NT - 1:
                    bkid = (bkid + 1) % self.dbin

    def _read_raw(self):
        """Read the 'raw' format: whitespace-separated X values and Y values."""
        data = {}
        with open(self.xfile, 'r') as pdfile:
            data['X'] = np.fromfile(pdfile, dtype=float, sep=" ")
        with open(self.yfile, 'r') as pdfile:
            data['Y'] = np.fromfile(pdfile, dtype=float, sep=" ")
        self.ndata['Y'] = data['Y'].shape[0]
        self.ndata['X'] = self.ndata['Y']
        # X holds nx values per Y sample
        self.nx = int(data['X'].shape[0] / self.ndata['X'])
        data['X'].shape = (self.ndata['X'], self.nx)
        if self.dbin > 1:
            nd = self.ndata['X']
            data['Y'] = np.array([data['Y'][k:k + self.dbin].mean()
                                  for k in range(0, nd, self.dbin)])
            data['X'] = np.array([[data['X'][k:k + self.dbin, j].mean() for j in range(self.nx)]
                                  for k in range(0, nd, self.dbin)]).reshape(int(nd / self.dbin) * self.nx)
            self.ndata['Y'] = data['Y'].shape[0]
            self.ndata['X'] = self.ndata['Y']
        # NOTE(review): pandas.DataFrame expects equal-length 1-D columns; a 2-D X
        # (no binning) or a flattened X longer than Y looks like it will be rejected
        # unless nx == 1 -- confirm the intended layout.
        self.data = pds.DataFrame(data=data)

    def clear_data(self):
        """Delete the database attributes that exist (safe when nothing was loaded)."""
        names = ['data', 'ndata', 'nx']
        if self.form == 'pdf':
            names += ['NT', 'p', 'z', 'ny', 'pztag', 'pdict', 'zdict',
                      'tdict', 'tsp', 'tspdict']
        for name in names:
            if hasattr(self, name):
                delattr(self, name)

    def select_data(self, pzt, tag=None):
        """Select data by p, z, T, t and dt; return row-normalised samples.

        Args:
            pzt: for 'pdf', a dict with keys 'p', 'z', 't', 'dt' and optional
                'T', or a sequence (p, z, [T,] t, dt).  None for p/z/T means
                all values read; None for t means the full NT range.  dt may
                be an int (|dt| <= n), a tuple (dt_min, dt_max[, step]) or a
                list of dt values; each t is shifted by each dt.  For 'raw',
                pzt = (stride, offset) applied to the columns.
            tag: 'X' or 'Y'; key under which p/z/t/T selections are recorded.

        Returns:
            (rttag, column tags, samples (ndata x ncol) divided by their
            column mean, column means, column std (ddof=1)).  rttag is the
            dt-range descriptor when dt is given, a dict of rejected t per
            data set when the 'emax' cut applies, and None otherwise.
        """
        if self.form == 'raw':
            sdata = np.array(self.data.loc[:, tag]).reshape(self.ndata['X'], self.nx)[:, pzt[1]::pzt[0]]
            smean = sdata.mean()
            sstd = sdata.std(ddof=1)
            sdata /= smean
            return None, None, sdata, smean, sstd
        # data form is 'pdf'
        rttag = None
        if isinstance(pzt, dict):
            pl = pzt['p']
            zl = pzt['z']
            tl = pzt['t']
            dtl = pzt['dt']
            tspl = pzt.get('T', [self.NA])
        else:
            pl, zl = pzt[0], pzt[1]
            tl, dtl = pzt[-2], pzt[-1]
            tspl = pzt[2] if len(pzt) == 5 else [self.NA]
        tl0 = tl
        if pl is None:
            pl = self.p
        if zl is None:
            zl = self.z
        if tspl is None:
            tspl = self.tsp
        self.tdict[tag] = []
        self.pdict[tag] = list(pl)
        self.zdict[tag] = list(zl)
        self.tspdict[tag] = tspl
        # X data with dt range
        rdt = None
        if dtl is not None:
            if isinstance(dtl, int):
                # maximum |dt|
                rttag = (-dtl, dtl, 1)
                rdt = list(range(-dtl, dtl + 1))
            elif isinstance(dtl, tuple):
                # (dt_min, dt_max[, step]); the step is only reported in rttag
                dt = dtl[2] if len(dtl) > 2 else 1
                rdt = list(range(dtl[0], dtl[1] + 1))
                rttag = (dtl[0], dtl[1], dt)
            elif isinstance(dtl, list):
                rdt = sorted(dtl)
                rttag = rdt
            else:
                print("Error: Unknown dt range dts_X !\n")
                sys.exit(1)
            self.tdict[tag] = list(rdt)
        nsrc = self.nsrc
        if tl is None:
            # choose entire NT (or T) range
            nsrc = 1
            tl = {k: list(range(0, self.NT[k])) for k in self.NT}
            self.countXT = 0
        elif rdt is not None:
            # shift every requested t by every dt of the expanded range
            ttl = [t + dt for t, dt in product(tl, rdt)]
            self.countXT = list(tl)
            tl = ttl
        lst = []
        strT = {}
        for p, z in product(pl, zl):
            pz = str(p) + '+' + str(z)
            strT[pz] = []
            # T values: per-(p,z) lists when taken from self.tsp, else one shared list
            tspn = list(tspl[pz]) if isinstance(tspl, dict) else list(tspl)
            for tsp in tspn:
                if (str(tsp) != self.NA) and (self.NA not in pz):
                    lst.append(self.ndata[pz + '+T+' + str(tsp)])
                    strT[pz].append('+T+' + str(tsp))
                else:
                    lst.append(self.ndata[pz])
                    strT[pz].append(None)
        ndata = np.array(lst).min()
        print(ndata)
        print(pl)
        print(zl)
        checkerr = (dtl is None) and (tl0 is None) and (self.emax is not None)
        if checkerr:
            rttag = {}
            self.tdict[tag] = {}
        # with neither dt range nor error cut, record the t values of the first data set
        flag = (dtl is None) and not checkerr
        sdata = []
        ttag = []
        seen = set()
        for p, z in product(pl, zl):
            pz = str(p) + '+' + str(z)
            for sT in strT[pz]:
                if sT is not None:
                    tg = pz + sT
                    T = int(tg.split('+T+')[1]) + 1
                else:
                    tg = pz
                    T = int(self.NT[tg] / self.nsrc)
                if checkerr:
                    self.tdict[tag][tg] = []
                    rttag[tg] = []
                tt = tl[tg] if isinstance(tl, dict) else tl
                NT = self.NT[tg]
                for n, tmp in product(range(nsrc), tt):
                    t = (tmp + n * int(NT / nsrc) + NT) % NT
                    if t % int(NT / self.nsrc) >= T:
                        continue
                    colR = tg + '+t+' + str(t) + '+R'
                    colI = tg + '+t+' + str(t) + '+I'
                    if colR in seen:
                        continue
                    datatmp = self.data.loc[:, [colR, colI]].T.values
                    if checkerr:
                        # NOTE(review): a slice is rejected only when BOTH parts exceed
                        # emax -- confirm `and` (rather than `or`) is intended.
                        if (abs(datatmp[0].std(ddof=1) / datatmp[0].mean()) > self.emax
                                and abs(datatmp[1].std(ddof=1) / datatmp[1].mean()) > self.emax):
                            rttag[tg].append(t)
                            continue
                        self.tdict[tag][tg].append(t)
                    elif flag:
                        self.tdict[tag].append(t)
                    sdata.extend(datatmp)
                    ttag.extend([colR, colI])
                    seen.add(colR)
                flag = False
        arr = np.array(sdata)
        smean = [row.mean() for row in arr]
        sstd = [row.std(ddof=1) for row in arr]
        for row, m in zip(arr, smean):
            row /= m  # rows of a 2-D array are views: normalises arr in place
        return rttag, tuple(ttag), arr.T, np.array(smean), np.array(sstd)

    def select_T(self, tagX, dtX, tagY, X=None):
        """Select the X columns whose time slice lies in a window around the Y time.

        Args:
            tagX: list of X column ids '...+t+<t>+<R|I>'.
            dtX: tuple (start, stop[, step]) with inclusive stop, or a list
                of time offsets.
            tagY: the Y id; its '<t>' field shifts the window.
            X: optional samples (nsamples x len(tagX)); default self.data.

        Returns:
            Array (nsamples x nselected) of the kept columns.
        """
        if X is None:
            aX = np.array(self.data.loc[:, list(tagX)].values).T
        else:
            aX = np.array(X).T
        assert aX.shape[0] == len(tagX)
        ty = int(tagY.split('+')[-2])
        if isinstance(dtX, tuple):
            step = dtX[2] if len(dtX) > 2 else 1
            tlist0 = list(range(dtX[0], dtX[1] + 1, step))
        else:
            tlist0 = list(dtX)
        tlist = {}
        for tag in tagX:
            tg = tag.split('+t+')[0]
            if tg in tlist:
                continue
            NTx = self.NT[tg]
            nnt = int(NTx / self.nsrc)
            try:
                Tx = NTx + int(tg.split('+T+')[1]) + 1 - nnt
            except (IndexError, ValueError):
                Tx = NTx
            ts = list(tlist0)
            for k in range(1, self.nsrc):
                ts.extend(t + k * nnt for t in tlist0)
            tlist[tg] = {(t + ty + Tx) % Tx for t in ts}
        lX = [row for tag, row in zip(tagX, aX)
              if int(tag.split('+')[-2]) in tlist[tag.split('+t+')[0]]]
        return np.array(lX).T

    def select_T_fast(self, tagX, dtX, tagY, X=None):
        """Select a contiguous time window from X data laid out in full-NT blocks.

        Args:
            tagX: list of X ids covering, per data set, all NT time slices as
                consecutive (R, I) pairs.
            dtX: tuple (first, last) of time slices, inclusive; wraps modulo NT.
            tagY: the Y id (unused here; kept for signature parity with select_T).
            X: optional samples (nsamples x len(tagX)); default self.data.

        Returns:
            Array (nsamples x nselected) with R and I of every selected slice.
        """
        assert isinstance(dtX, tuple) and (len(dtX) == 2)
        if X is None:
            aX = np.array(self.data.loc[:, list(tagX)].values).T
        else:
            aX = np.array(X).T
        assert aX.shape[0] == len(tagX)
        lX = []
        pzX = []
        x = 0  # index of the first tag of the current data set
        while x < len(tagX):
            pz = tagX[x].split('+t+')[0]
            assert pz not in pzX
            pzX.append(pz)
            NT = self.NT[pz]
            nnt = int(NT / self.nsrc)
            for k in range(self.nsrc):
                for tx in range(dtX[0], dtX[1] + 1):
                    t = (tx + k * nnt) % NT
                    lX.append(aX[x + 2 * t])      # real part
                    lX.append(aX[x + 2 * t + 1])  # imaginary part
            x += NT * 2
        return np.array(lX).T

    @staticmethod
    def _mkdir_warn(path):
        """Create `path`, printing a warning if it already exists."""
        try:
            os.mkdir(path)
        except FileExistsError:
            print("Warning: directory " + path + " already exists")

    def setup_output(self):
        """Create '<odir>/ML.<mmddYYYY>[.n]' plus 'plots'/'preds' sub-folders.

        A numeric suffix is appended until an unused directory name is found;
        other mkdir errors (e.g. missing parent) propagate.
        """
        self.date = datetime.datetime.today().strftime('%m%d%Y')
        if not (self.osave or self.dsave):
            self.ofile = None
            self.pltfile = None
            return
        base = self.odir + '/ML.' + self.date
        sfix = ''
        n = 1
        while True:
            try:
                os.mkdir(base + sfix)
                break
            except FileExistsError:
                sfix = '.' + str(n)
                n += 1
        self.odir = base + sfix
        if self.osave:
            self._mkdir_warn(self.odir + '/plots')
        if self.dsave:
            self._mkdir_warn(self.odir + '/preds')
        self.oheader_tag = None

    def dfile_mkheader(self, indx):
        """Set output file names for run `indx` and append the header ('pdf' only).

        Args:
            indx: filename suffix (number).
        """
        if not self.dsave or self.form != 'pdf':
            return
        self.ofile = self.odir + '/' + self.ID + str(indx)
        self.pltfile = self.odir + '/plots/' + self.ID + str(indx)
        with open(self.ofile, 'a+') as pf:
            self.dfile_mkheader_pdf(pf)

    def dfile_mkheader_pdf(self, pf):
        """Write the description of the target (Y) and training (X) data to `pf`.

        Works on copies of the recorded T lists, so the selection state is
        not modified.
        """
        pf.write("PDF data analysis with ML: \n Target data: ")
        for n, (tgp, tgz) in enumerate(product(self.pdict['Y'], self.zdict['Y']), start=1):
            if self.NA == tgz:
                pf.write("Y{:d}( spectrum: p= {:} ),  ".format(n, tgp))
                continue
            if isinstance(self.tspdict['Y'], dict):
                tspl = list(self.tspdict['Y'][str(tgp) + '+' + str(tgz)])
            else:
                tspl = list(self.tspdict['Y'])
            if self.NA in tspl:
                pf.write("Y{:d}( DA: p= {:} ,z= {:} ),  ".format(n, tgp, tgz))
                tspl.remove(self.NA)
            if len(tspl) > 0:
                pf.write("Y{:d}( PDF: p= {:} ,z= {:}, T= {:} ),  ".format(n, tgp, tgz, tspl))
        pf.write("\n Training data: p= {:} ,z= {:} \n".format(self.pdict['X'], self.zdict['X']))
        pf.write("\n Target data at t= {:} \n Training data at dt= {:} \n\n".format(self.tdict['Y'],
                                                                                    self.tdict['X']))
        

In [4]:
# Scratch check: list.remove works in place and list copies are independent.
a = list('abc')
del a[a.index('a')]  # same effect as a.remove('a'): drop first 'a' in place
print(a)
b = list(a)  # shallow copy; later edits to `a` do not touch `b`
print(b)

['b', 'c']
['b', 'c']
