In [None]:
import os

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

In [None]:
# Root directory of the dataset; every path below is relative to it.
data_dir = '../data/'

# Sub-directories holding the per-scan descriptor files. Each one ends with a
# slash so it can be concatenated directly with a file name.
train_scan_shot = 'training/scan_shot/'
train_eigen = 'training/scan_lb/'
test_scan_shot = 'test/scan_shot/'
test_eigen = 'test/scan_lb/'

# Text files (directly under data_dir) listing the test pairs to evaluate.
inter = 'inter_challenge.txt'
intra = 'intra_challenge.txt'

In [None]:
class FAUSTDataset(Dataset):
    """Dataset of per-scan descriptors, loaded eagerly into memory.

    In training mode each item is a ``(shot, eigen)`` tuple for one scan.
    In test mode each item is ``(shot_src, shot_tar, eigen_src, eigen_tar)``
    for one source/target pair listed in the challenge files.

    File locations are built from the module-level path constants
    (``data_dir``, ``train_scan_shot``, ``train_eigen``, ``test_scan_shot``,
    ``test_eigen``, ``intra``, ``inter``).
    """

    # Training files are numbered 000 .. NUM_TRAIN_SCANS - 1.
    NUM_TRAIN_SCANS = 100

    def __init__(self, train, eignum=120):
        """Load every descriptor file for the selected split.

        Args:
            train: truthy to load the training scans, falsy to load the
                test pairs.
            eignum: integer embedded in the eigen-descriptor file names
                (``..._<eignum>_<scan id>.h5``).
        """
        super().__init__()
        self.train = train
        self.shot_des = []
        self.eigen_des = []
        self.eignum = eignum

        # Both branches need this, so it must be computed before branching.
        streig = '%d' % eignum

        if self.train:
            for i in range(self.NUM_TRAIN_SCANS):
                strnum = '%03d' % i
                fn = os.path.join(
                    data_dir, train_scan_shot,
                    'tr_scan_d_res_' + strnum + '.txt')
                fn_eig = os.path.join(
                    data_dir, train_eigen,
                    'tr_scan_' + streig + '_' + strnum + '.h5')
                self.shot_des.append(self.load_shot(fn))
                self.eigen_des.append(self.load_eig(fn_eig))
        else:
            # Source is stored at an even index and its target right after
            # it; __getitem__ relies on this layout.
            for pair in self.load_pairs():
                for scan_id in pair[:2]:
                    fn = os.path.join(
                        data_dir, test_scan_shot,
                        'test_scan_d_res_' + scan_id + '.txt')
                    fn_eig = os.path.join(
                        data_dir, test_eigen,
                        'test_scan_' + streig + '_' + scan_id + '.h5')
                    self.shot_des.append(self.load_shot(fn))
                    self.eigen_des.append(self.load_eig(fn_eig))

    def __len__(self):
        """Return the number of scans (training) or of pairs (test)."""
        if self.train:
            return len(self.shot_des)
        # Two consecutive entries are stored per test pair.
        return len(self.shot_des) // 2

    def __getitem__(self, idx):
        """Return the descriptors of scan ``idx`` or of pair ``idx``.

        Returns:
            ``(shot, eigen)`` in training mode, otherwise
            ``(shot_src, shot_tar, eigen_src, eigen_tar)``.
        """
        if self.train:
            return self.shot_des[idx], self.eigen_des[idx]

        src_idx = 2 * idx
        tar_idx = src_idx + 1
        return (
            self.shot_des[src_idx],
            self.shot_des[tar_idx],
            self.eigen_des[src_idx],
            self.eigen_des[tar_idx],
        )

    @staticmethod
    def _read_pair_file(fname):
        """Parse one challenge file into a list of ``[src, tar, ...]`` lists.

        Each non-blank line is split on ``'_'`` after surrounding whitespace
        (including the newline) is removed.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        pairs = []
        with open(fname) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                fields = line.split('_')
                if len(fields) < 2:
                    raise ValueError(
                        'malformed pair line %r in %s' % (line, fname))
                pairs.append(fields)
        return pairs

    def load_pairs(self):
        """Return all test pairs: the intra-file pairs, then the inter ones."""
        intra_fname = os.path.join(data_dir, intra)
        inter_fname = os.path.join(data_dir, inter)
        return (self._read_pair_file(intra_fname)
                + self._read_pair_file(inter_fname))

    def load_shot(self, fname):
        """Read a whitespace-separated numeric text file into a tensor.

        Blank lines are skipped. Rows are reordered so that the first column
        is ascending.

        Args:
            fname: path of the text file.

        Returns:
            A 2-D float64 tensor with one row per non-blank line.

        Raises:
            ValueError: if the file contains no data rows.
        """
        rows = []
        with open(fname) as f:
            for line in f:
                fields = line.split()
                if fields:
                    rows.append([float(tok) for tok in fields])
        if not rows:
            raise ValueError('no descriptor rows found in %s' % fname)

        shot_np = np.array(rows, dtype=float)
        shot_np = shot_np[shot_np[:, 0].argsort()]
        # NOTE(review): the previous code computed shot_np[:, 4:] and then
        # immediately overwrote it, so all columns were effectively
        # returned. That behavior is kept; confirm whether the leading
        # columns should be dropped.
        return torch.from_numpy(shot_np)

    def load_eig(self, fname):
        """Read the first dataset of an HDF5 file and return its data rows.

        The dataset is transposed; element ``[0][0]`` of the result is read
        as the row count ``n``, row 1 is skipped, and rows ``2 .. n + 1`` are
        returned.

        Args:
            fname: path of the ``.h5`` file.

        Returns:
            A float64 tensor holding the selected rows.
        """
        # Read-only, and closed as soon as the data is copied into memory.
        with h5py.File(fname, 'r') as h5_file:
            dkey = list(h5_file.keys())[0]
            data = np.transpose(np.array(h5_file[dkey]))

        num_vert = int(data[0][0])
        rows = np.array(data[2:num_vert + 2]).astype(float)
        return torch.from_numpy(rows)