In [None]:
%run Topo_treatment.ipynb
import glob
import struct
from pathlib import Path

class FileIO(object):
    '''
    Static helpers for raw float64 binary files, persistence images and
    persistence diagrams.

    The names ``np``, ``os``, ``random``, ``skio``, ``cv2`` and ``Edges_`` are
    not imported next to this class; they are expected to already exist in the
    surrounding namespace (presumably supplied by the ``%run`` at the top of
    the file -- confirm).

    NOTE: every method that takes ``dir_in`` changes the process working
    directory to ``dir_in`` and leaves it there. This is kept for backward
    compatibility; as a consequence a relative ``dir_out`` is resolved against
    ``dir_in``.
    '''

    @staticmethod
    def _num_elements(shape):
        '''
        Return the number of elements described by ``shape`` as a plain int.
        '''
        # np.prod(()) is the float 1.0, which would produce an invalid struct
        # format such as '1.0d'; forcing an integer dtype avoids that.
        return int(np.prod(shape, dtype=np.int64))

    @staticmethod
    def _list_files(dir_in, ext_in):
        '''
        Change into ``dir_in`` and list the files matching ``*.<ext_in>``.

        Returns ``(absolute_dir_in, sorted_file_names)``.
        '''
        # Resolve before chdir: joining a *relative* dir_in with a file name
        # after the working directory has moved would point at the wrong place.
        dir_in = os.path.abspath(dir_in)
        os.chdir(dir_in)
        # glob order is OS dependent; sorting makes shuffle=False and seeded
        # shuffles reproducible.
        return dir_in, sorted(glob.glob("*." + ext_in))

    @staticmethod
    def _pick_indices(file_num, percentage, shuffle):
        '''
        Choose which of ``file_num`` files to read for a given ``percentage``.

        The count is floor(file_num * percentage), clamped to [0, file_num].
        With ``shuffle`` the indices are a random sample, otherwise the first
        ones in listing order.
        '''
        read_num = int(np.floor(file_num * percentage))
        read_num = max(0, min(file_num, read_num))
        if shuffle:
            return random.sample(range(file_num), read_num)
        return list(range(read_num))

    @staticmethod
    def read_binary(path, shape):
        '''
        Read ``prod(shape)`` native-endian doubles from ``path`` and return
        them as an array reshaped to ``shape``.

        Raises ``struct.error`` if the file holds fewer bytes than required.
        '''
        count = FileIO._num_elements(shape)
        with open(path, "rb") as file_in:
            values = struct.unpack("%dd" % count, file_in.read(8 * count))
        return np.reshape(values, shape)

    @staticmethod
    def write_binary(path, data, shape):
        '''
        Write the flat iterable ``data`` to ``path`` as native-endian doubles.

        ``data`` must contain exactly ``prod(shape)`` values, otherwise
        ``struct.error`` is raised.
        '''
        count = FileIO._num_elements(shape)
        # Pack first so a length mismatch does not truncate an existing file.
        payload = struct.pack("%dd" % count, *data)
        with open(path, "wb") as file_out:
            file_out.write(payload)

    @staticmethod
    def read_matrix_binary(path):
        '''
        Read a matrix written by ``write_matrix_binary``.

        File layout: one unsigned int (number of dimensions), that many
        unsigned ints (the shape), then the values as doubles.
        Returns ``None`` when the stored number of dimensions is 0.
        '''
        with open(path, "rb") as file_in:
            dims = struct.unpack("I", file_in.read(4))[0]
            if dims == 0:
                return None
            shape = struct.unpack("%dI" % dims, file_in.read(4 * dims))
            count = FileIO._num_elements(shape)
            values = struct.unpack("%dd" % count, file_in.read(8 * count))
        return np.reshape(values, shape)

    @staticmethod
    def write_matrix_binary(path, mat):
        '''
        Write ``mat`` to ``path`` with a small header: number of dimensions,
        the shape (unsigned ints), then the flattened values as doubles.
        '''
        # A 0-d input would yield a zero-dimension header, which
        # read_matrix_binary reports as None; store it as a 1-element vector.
        mat = np.atleast_1d(mat)
        shape = mat.shape
        dims = len(shape)
        payload = b"".join((
            struct.pack("I", dims),
            struct.pack("%dI" % dims, *shape),
            struct.pack("%dd" % mat.size, *mat.flatten()),
        ))
        with open(path, "wb") as file_out:
            file_out.write(payload)

    @staticmethod
    def compute_pim_save(dir_in, dir_out, ext_in, params, reg=0, writeout=True, give_status=False):
        '''
        For every ``*.<ext_in>`` image in ``dir_in``, compute
        ``Edges_(params, False).persimg_batch(...)`` and optionally write the
        result to ``dir_out/<image file name>.dat`` via ``write_binary``.

        reg         : when non-zero the result is rescaled to (pim - reg) / reg;
                      when zero (the default) it is left unscaled, since the
                      rescaling would divide by zero.
        writeout    : write the result files (``dir_out`` is created if needed).
        give_status : print the overall max and min of the computed results.
        '''
        dir_in, names = FileIO._list_files(dir_in, ext_in)
        if writeout:
            Path(dir_out).mkdir(parents=True, exist_ok=True)

        max_rec = []
        min_rec = []
        et      = Edges_(params, False)
        num_    = len(names)
        for i, name in enumerate(names):
            img = np.expand_dims(skio.imread(os.path.join(dir_in, name)), 0)
            pim = np.squeeze(et.persimg_batch(img, binarize=True))
            if np.any(reg):
                pim = (pim - reg) / reg
            if writeout:
                path_out = os.path.join(dir_out, name + ".dat")
                FileIO.write_binary(path_out, pim.flatten(), pim.shape)
            if i % 49 == 0:
                print("%d/%d" %(i+1, num_))
            if give_status:
                max_rec.append(np.amax(pim))
                min_rec.append(np.amin(pim))

        # Nothing to report when no file matched; np.amax([]) would raise.
        if give_status and max_rec:
            print(np.amax(max_rec), np.amin(min_rec))

    @staticmethod
    def compute_pd_save(dir_in, dir_out, ext_in, params, dim):
        '''
        For every ``*.<ext_in>`` image in ``dir_in``, take the fifth value
        returned by ``Edges_(params, False).pd_batch(...)`` and write it to
        ``dir_out/<image file name>.dat`` via ``write_matrix_binary``.
        ``dir_out`` is created if needed.
        '''
        dir_in, names = FileIO._list_files(dir_in, ext_in)
        Path(dir_out).mkdir(parents=True, exist_ok=True)

        et   = Edges_(params, False)
        num_ = len(names)
        for i, name in enumerate(names):
            img = np.expand_dims(skio.imread(os.path.join(dir_in, name)), 0)
            _, _, _, _, pd = et.pd_batch(img, dim, debug=False, old_form=True, binarize=True)
            pd = np.squeeze(np.asarray(pd))
            FileIO.write_matrix_binary(os.path.join(dir_out, name + ".dat"), pd)
            if i % 49 == 0:
                print("%d/%d" %(i+1, num_))

    @staticmethod
    def read_pd_subset(dir_in, ext_in, percentage, shuffle=True):
        '''
        Read a fraction of the ``*.<ext_in>`` matrix files in ``dir_in``.

        Returns ``(pd_set, f_set)``: the matrices and the matching file names
        with the extension removed. Files whose header declares zero
        dimensions are skipped; 1-D matrices are promoted to a single row.
        '''
        dir_in, names = FileIO._list_files(dir_in, ext_in)
        picks = FileIO._pick_indices(len(names), percentage, shuffle)

        pd_set = []
        f_set  = []
        for j in picks:
            name = names[j]
            pd_  = FileIO.read_matrix_binary(os.path.join(dir_in, name))
            if pd_ is None:
                continue
            if len(pd_.shape) == 1:
                pd_ = np.expand_dims(pd_, axis=0)
            pd_set.append(pd_)
            # Drop the trailing ".<ext_in>".
            f_set.append(name[:-(len(ext_in) + 1)])
        return pd_set, f_set

    @staticmethod
    def read_pd_pathlist(pd_pathlist):
        '''
        Read persistence diagrams from the given path list.
        Each path should be a full path.
        Entries are None for files whose header declares zero dimensions.
        '''
        return [FileIO.read_matrix_binary(path) for path in pd_pathlist]

    @staticmethod
    def read_image_subset(dir_in, ext_in, percentage, shuffle=True):
        '''
        Read in images in grayscale.

        Reads a fraction of the ``*.<ext_in>`` files in ``dir_in`` with
        ``cv2.imread`` and returns them as a list. The value range of the
        first image is printed when it is available.
        '''
        dir_in, names = FileIO._list_files(dir_in, ext_in)
        picks = FileIO._pick_indices(len(names), percentage, shuffle)

        res = [cv2.imread(os.path.join(dir_in, names[j]), cv2.IMREAD_GRAYSCALE)
               for j in picks]
        # Skip the report when nothing was read or the first read failed
        # (cv2.imread is documented to return None for unreadable files).
        if res and res[0] is not None:
            print("Data value range: %f - %f" %(np.amax(res[0]), np.amin(res[0])))
        return res