In [1]:
import pandas as pd
import numpy as np
import torch as pt
import multiprocessing

from bps import bps
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import h5py

In [2]:
# Project root; every other location is derived from it.
MAIN_PATH = os.path.join('aptbps-code')
data_path = os.path.join(MAIN_PATH, 'data')

# Sub-directories of the data folder, one per processing stage.
(train_path,
 hdf5_path,
 no_unlabeled_h5_path,
 encoded_hdf5_path) = (
    os.path.join(data_path, sub_dir)
    for sub_dir in ('train', 'hdf5', 'no_unlabeled_hdf5', 'encoded_hdf5')
)
hdf5_train_path = os.path.join(hdf5_path, 'train')

# Base names (without extension) of all the clouds in the training dataset
train_files = [
    # bildstein
    "bildstein_station1_xyz_intensity_rgb",
    "bildstein_station3_xyz_intensity_rgb",
    "bildstein_station5_xyz_intensity_rgb",
    # domfountain
    "domfountain_station1_xyz_intensity_rgb",
    "domfountain_station2_xyz_intensity_rgb",
    "domfountain_station3_xyz_intensity_rgb",
    # neugasse
    "neugasse_station1_xyz_intensity_rgb",
    # sg27 / sg28
    "sg27_station1_intensity_rgb",
    "sg27_station2_intensity_rgb",
    "sg27_station4_intensity_rgb",
    "sg27_station5_intensity_rgb",
    "sg27_station9_intensity_rgb",
    "sg28_station4_intensity_rgb",
    # untermaederbrunnen
    "untermaederbrunnen_station1_xyz_intensity_rgb",
    "untermaederbrunnen_station3_xyz_intensity_rgb",
]

# Clouds used for training
# train_files = [
#     "bildstein_station1_xyz_intensity_rgb", 
#     "bildstein_station5_xyz_intensity_rgb",
#     "domfountain_station1_xyz_intensity_rgb",
#     "domfountain_station3_xyz_intensity_rgb",
#     "neugasse_station1_xyz_intensity_rgb",
#     "sg27_station1_intensity_rgb",
#     "sg27_station2_intensity_rgb",
#     "sg27_station4_intensity_rgb",
#     "sg27_station5_intensity_rgb",
#     "sg27_station9_intensity_rgb",
#     "sg28_station4_intensity_rgb",
#     "untermaederbrunnen_station1_xyz_intensity_rgb",
#     "untermaederbrunnen_station3_xyz_intensity_rgb",
# ]
# 
# # Clouds used for testing
# test_files = [
#     "bildstein_station3_xyz_intensity_rgb",
#     "domfountain_station2_xyz_intensity_rgb",
#     "sg27_station4_intensity_rgb",
#     "untermaederbrunnen_station1_xyz_intensity_rgb",
# ]

In [4]:
# Peek at one raw training cloud. The `with` block closes the store
# automatically once the frame stored under key '0' has been printed.
file_path = os.path.join(data_path, 'hdf5/train/sg28_station4_intensity_rgb.h5')
with pd.HDFStore(file_path, compression='lz4', mode='r') as store:
    print(store['0'])

                   x           y      z     i    r    g    b  label
0          20.764000  -17.844000 -1.985 -1615  239  209  221      1
1          -9.202000   -0.023000 -0.816 -1132  114   89   68      4
2          -9.198000   -0.025000 -0.815 -1183  102   77   56      4
3          -9.200000   -0.026000 -0.817 -1275   98   75   54      4
4          -9.200000   -0.025000 -0.820 -1173  104   82   63      4
...              ...         ...    ...   ...  ...  ...  ...    ...
258720943  54.626999  108.795998  0.145 -1555   41   39   43      0
258720944  54.665001  108.781998  0.126 -1555   43   39   43      0
258720945  54.696999  108.757004  0.184 -1404   39   34   37      0
258720946  54.701000  108.763000  0.146 -1442   47   41   45      0
258720947  54.698002  108.759003  0.165 -1442   39   34   37      0

[258720948 rows x 8 columns]


In [None]:
# Remove all unlabeled points (label == 0) from the .h5 files and store the
# remaining points in new .h5 files under `no_unlabeled_h5_path`.

# Make sure the output directory exists before the first write.
os.makedirs(no_unlabeled_h5_path, exist_ok=True)

for f in tqdm(train_files):
    f_ext = f + '.h5'
    f_path = os.path.join(hdf5_train_path, f_ext)
    no_unlabeled_f_path = os.path.join(no_unlabeled_h5_path, f_ext)

    # Read the full cloud, then release the source store right away.
    # NOTE(review): `compression='lz4'` is not one of HDFStore's documented
    # compression arguments (`complevel` / `complib`) -- confirm it has the
    # intended effect.
    with pd.HDFStore(f_path, compression='lz4', mode='r') as store:
        cloud = store.get('0')

    # DataFrame.info() prints its report itself and returns None, so it is
    # called directly instead of being wrapped in print().
    cloud.info()

    # Boolean mask keeps the original index labels of the surviving rows
    # (later cells rely on those labels to map points back to the raw cloud).
    no_unlabeled_cloud = cloud[cloud.label != 0]
    no_unlabeled_cloud.info()

    with pd.HDFStore(no_unlabeled_f_path, compression='lz4', mode='w') as no_unlabeled_store:
        no_unlabeled_store['0'] = no_unlabeled_cloud

In [16]:
# Read back the filtered file whose path is still held in
# `no_unlabeled_f_path`; the `with` block closes the store on exit.
with pd.HDFStore(no_unlabeled_f_path, compression='lz4', mode='r') as store2:
    print(store2['0'])

                  x          y      z     i    r    g    b  label
2         20.360001  40.375999 -2.402 -1083  139  151  165      6
7         20.358999  40.374001 -2.404 -1086  139  151  165      6
8         20.354000  40.375000 -2.404 -1106  139  148  163      6
9         20.356001  40.374001 -2.404 -1059  139  151  165      6
25        20.361000  40.375000 -2.404 -1116  139  151  165      6
...             ...        ...    ...   ...  ...  ...  ...    ...
29684536  33.122002  68.304001  7.580 -1555   50   45   51      5
29684539  33.020000  68.355003  7.579 -1476   43   37   39      5
29684542  33.242001  68.287003  7.597 -1540   61   51   52      5
29684545  33.162998  68.301003  7.620 -1573   67   55   67      5
29684548  33.320000  68.275002  7.616 -1670   73   57   67      5

[9476296 rows x 8 columns]


In [4]:
from bps import bps
import h5py
# Presumably the number of points per (sub-)cloud -- TODO confirm.
# NOTE(review): not referenced by name in the visible cells, which use the
# literal 2048 instead.
NUM_POINTS = 2048

In [27]:
# Demo input: a random batch of 100 point clouds, 2048 xyz points each
x = np.random.normal(size=(100, 2048, 3))

# First cloud before normalization
print(x[0])

x_norm = bps.normalize(x)

# 'b' marks where the normalized version of the same cloud starts
print('b')
print(x_norm[0])

[[ 1.46084379  0.53316554 -1.13999441]
 [-1.61499181 -1.03422883  1.04827314]
 [ 0.15953528  0.63335278  0.24866621]
 ...
 [-0.0931188   0.11329203 -1.07427838]
 [ 1.53487635  2.07523421 -0.2291684 ]
 [ 0.25565463 -1.02660701  0.59164991]]
b
[[ 0.33417768  0.12928757 -0.27112901]
 [-0.37652509 -0.23287467  0.23449222]
 [ 0.03349725  0.15243684  0.049735  ]
 ...
 [-0.02488102  0.03227157 -0.25594466]
 [ 0.35128364  0.48559806 -0.0606735 ]
 [ 0.05570659 -0.23111357  0.12898483]]


# *Normalize* and *encode*

In [49]:
# NOTE: this rebinds `bps` (earlier cells import it from the `bps` package).
from sembps.bps import bps

BPS_POINTS = 512     # basis points per encoded sub-cloud
cloud_points = 2048  # target number of points per sub-cloud before encoding


def _split_into_subclouds(values, n_subclouds):
    """Cut a (n_points, k) array into `n_subclouds` contiguous sub-clouds.

    Sizes follow np.array_split: the first `n_points % n_subclouds` sub-clouds
    hold `n_points // n_subclouds + 1` rows, the rest hold
    `n_points // n_subclouds` rows. Because the two sizes differ, the result
    is returned as two stacked arrays `(larger, smaller)` with shapes
    (n_larger, size + 1, k) and (n_smaller, size, k). Either may be empty.
    """
    size, n_larger = divmod(values.shape[0], n_subclouds)
    split_idx = n_larger * (size + 1)
    larger, smaller = np.split(values, [split_idx])
    width = values.shape[1]
    return (larger.reshape(-1, size + 1, width),
            smaller.reshape(-1, size, width))


def _encode_subclouds(points, point_labels, point_indexes, n_bps_points):
    """BPS-encode a stack of equally sized sub-clouds.

    Returns `(encoded, kept_labels, kept_indexes)`, where the last two hold,
    per sub-cloud, the label and original index of every point that
    `bps.encode` reports as sampled (via `return_idx=True`).
    """
    encoded, sampled_idx = bps.encode(points, n_bps_points=n_bps_points, radius=1.7,
                                      verbose=False, return_idx=True)

    kept_labels = np.empty((point_labels.shape[0], n_bps_points), dtype=point_labels.dtype)
    kept_indexes = np.empty((point_indexes.shape[0], n_bps_points), dtype=point_indexes.dtype)
    for row, (labc, idxc, picks) in enumerate(zip(point_labels, point_indexes, sampled_idx)):
        # np.take works on the flattened (n, 1) column, i.e. plain positions
        kept_labels[row] = np.take(labc, picks)
        kept_indexes[row] = np.take(idxc, picks)
    return encoded, kept_labels, kept_indexes


# Make sure the output directory exists before the first write.
os.makedirs(encoded_hdf5_path, exist_ok=True)

for f in tqdm(train_files):
    f_ext = f + '.h5'
    no_unlabeled_f_path = os.path.join(no_unlabeled_h5_path, f_ext)
    encoded_f_path = os.path.join(encoded_hdf5_path, f_ext)

    # The store was written in 'fixed' format, which does not support column
    # selection (store.select(..., columns=[...])), so the whole frame must be
    # retrieved. Source:
    # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_hdf.html
    with pd.HDFStore(no_unlabeled_f_path, compression='lz4', mode='r') as store2:
        df = store2.get('0')

    if df.empty:
        raise ValueError("No labeled points in " + no_unlabeled_f_path)

    labels = df[['label']].to_numpy()

    # Keep the original row labels: unlabeled points were removed without
    # resetting the index, so these map back to the raw cloud.
    indexes = df.index.to_numpy()[:, np.newaxis]

    # Normalization must be done on the full cloud; the bps library expects a
    # leading "number of clouds" axis, which is added and removed again here.
    xyz = df[['x', 'y', 'z']].to_numpy()
    xyz = np.squeeze(bps.normalize(xyz[np.newaxis, :]), axis=0)

    # Split the cloud into sub-clouds of ~cloud_points points each. Clouds
    # smaller than cloud_points become a single sub-cloud.
    n_points = xyz.shape[0]
    n_divisions = max(1, n_points // cloud_points)

    encoded_parts, label_parts, index_parts = [], [], []
    for pts, labs, idxs in zip(_split_into_subclouds(xyz, n_divisions),
                               _split_into_subclouds(labels, n_divisions),
                               _split_into_subclouds(indexes, n_divisions)):
        if pts.shape[0] == 0:
            # No sub-cloud of this size for the current file
            continue
        enc, kept_lab, kept_idx = _encode_subclouds(pts, labs, idxs, BPS_POINTS)
        encoded_parts.append(enc)
        label_parts.append(kept_lab)
        index_parts.append(kept_idx)

    print("BPS encoding complete!")

    # After encoding, every sub-cloud has BPS_POINTS entries, so the groups
    # of different original size can be concatenated.
    df_np = np.concatenate(encoded_parts)
    labels = np.concatenate(label_parts)
    indexes = np.concatenate(index_parts)

    # One HDF5 group per sub-cloud ('0', '1', ...), each holding the encoded
    # data plus the labels and original indexes of the sampled points.
    with h5py.File(encoded_f_path, 'w') as f_to_w:
        for i in range(df_np.shape[0]):
            grp = f_to_w.create_group(str(i))
            grp.create_dataset("data", data=df_np[i], dtype='float32', chunks=True)
            grp.create_dataset("label", data=labels[i], dtype='uint8', chunks=True)
            grp.create_dataset("indexes", data=indexes[i], dtype='uint32', chunks=True)


  0%|          | 0/15 [00:00<?, ?it/s][A

BPS encoding complete!



  7%|▋         | 1/15 [00:08<01:57,  8.42s/it][A

BPS encoding complete!



 13%|█▎        | 2/15 [00:15<01:44,  8.02s/it][A

BPS encoding complete!



 20%|██        | 3/15 [00:22<01:33,  7.78s/it][A

BPS encoding complete!



 27%|██▋       | 4/15 [00:31<01:28,  8.04s/it][A

BPS encoding complete!



 33%|███▎      | 5/15 [00:41<01:27,  8.79s/it][A

BPS encoding complete!



 40%|████      | 6/15 [00:52<01:22,  9.19s/it][A

BPS encoding complete!



 47%|████▋     | 7/15 [01:46<03:02, 22.85s/it][A

BPS encoding complete!



 53%|█████▎    | 8/15 [07:36<14:06, 120.99s/it][A

BPS encoding complete!



 60%|██████    | 9/15 [15:37<22:53, 228.92s/it][A

BPS encoding complete!



 67%|██████▋   | 10/15 [20:49<21:09, 253.97s/it][A

BPS encoding complete!



 73%|███████▎  | 11/15 [24:49<16:38, 249.61s/it][A

BPS encoding complete!



 80%|████████  | 12/15 [28:55<12:25, 248.64s/it][A

BPS encoding complete!



 87%|████████▋ | 13/15 [33:41<08:39, 259.65s/it][A

BPS encoding complete!



 93%|█████████▎| 14/15 [34:14<03:11, 191.76s/it][A

BPS encoding complete!



100%|██████████| 15/15 [34:45<00:00, 139.04s/it]


In [None]:
# Sanity check: print the encoded data of group '0', then stop after the
# first training file.
for f in tqdm(train_files):
    f_ext = f + '.h5'
    no_unlabeled_f_path = os.path.join(no_unlabeled_h5_path, f_ext)
    encoded_f_path = os.path.join(encoded_hdf5_path, f_ext)

    with h5py.File(encoded_f_path, 'r') as f_to_r:
        first_group = f_to_r['0']
        print(first_group['data'][:])
    break

In [None]:
# Concatenate .h5 files into a single sliceable virtual dataset
def concatenate(file_names_to_concatenate, entry_key='data', output_path="VDS.h5"):
    """Stack the same dataset from several HDF5 files into one virtual dataset.

    The dataset shape is read from the first file and is used for every
    source, so all files must store `entry_key` with that shape. The result
    is written to `output_path` as a virtual dataset named `entry_key` of
    shape (n_files, *source_shape) and dtype float64; unmapped regions read
    back as 0.

    Args:
        file_names_to_concatenate: sequence of paths to the source files.
        entry_key: name of the dataset inside each source file.
        output_path: file that will hold the virtual dataset (overwritten).

    Raises:
        ValueError: if no file names are given.
    """
    if not file_names_to_concatenate:
        raise ValueError("file_names_to_concatenate must not be empty")

    # Open the first file only long enough to read the source shape.
    with h5py.File(file_names_to_concatenate[0], 'r') as first_file:
        source_shape = first_file[entry_key].shape

    # np.float was only an alias of the builtin float (float64) and no longer
    # exists in current NumPy releases.
    layout = h5py.VirtualLayout(shape=(len(file_names_to_concatenate),) + source_shape,
                                dtype=np.float64)

    with h5py.File(output_path, 'w', libver='latest') as f:
        for i, filename in enumerate(file_names_to_concatenate):
            vsource = h5py.VirtualSource(filename, entry_key, shape=source_shape)
            # Indexing only the first axis works for sources of any rank.
            layout[i] = vsource

        f.create_virtual_dataset(entry_key, layout, fillvalue=0)

In [None]:
# Plan (pseudo-code, kept as a comment so that the cell is valid Python):
#
#   for each file:
#       for entries in file:
#           concatenate all the entries
#       concatenate all files

# For every encoded file, show the labels of group '0' and the number of
# groups stored in the file.
for f in tqdm(train_files):
    f_ext = f + '.h5'
    no_unlabeled_f_path = os.path.join(no_unlabeled_h5_path, f_ext)
    encoded_f_path = os.path.join(encoded_hdf5_path, f_ext)

    with h5py.File(encoded_f_path, 'r') as f_to_r:
        print(f_to_r['0']['label'][:])
        print(len(f_to_r.keys()))

In [4]:
# CREATE VIRTUAL DATASET
#
# Builds VDS.h5 with two virtual datasets, "vdata" and "vlabels", that expose
# the per-group "data" and "label" datasets of every encoded file as rows of
# one (n_subclouds, BPS_POINTS) array each.

BPS_POINTS = 512

VDS_path = os.path.join(encoded_hdf5_path, 'VDS.h5')

# One VirtualSource per group, in matching order for data and labels
data_vsources = []
label_vsources = []

for f in tqdm(train_files):
    f_ext = f + '.h5'
    encoded_f_path = os.path.join(encoded_hdf5_path, f_ext)

    # The file is opened only to list its groups. Group names are the
    # sub-cloud numbers written as strings; sorting them numerically keeps
    # row k of a file's block aligned with sub-cloud k (plain name order
    # would give '0', '1', '10', '100', ...).
    with h5py.File(encoded_f_path, 'r') as f_to_r:
        group_names = sorted(f_to_r.keys(), key=int)

    for group in group_names:
        # Paths to the data and label datasets within the group
        data_s = group + '/data'
        label_s = group + '/label'
        data_vsources.append(h5py.VirtualSource(encoded_f_path, data_s, shape=(1, BPS_POINTS)))
        label_vsources.append(h5py.VirtualSource(encoded_f_path, label_s, shape=(1, BPS_POINTS)))

# The dtype is passed explicitly because recent h5py releases require it;
# float64 is NumPy's default dtype, so this should match what the implicit
# default produced before -- TODO confirm against the h5py version in use.
data_vlayout = h5py.VirtualLayout(shape=(len(data_vsources), BPS_POINTS), dtype=np.float64)
label_vlayout = h5py.VirtualLayout(shape=(len(label_vsources), BPS_POINTS), dtype=np.float64)

# Populate layouts: one row per sub-cloud
for i, (data_vsource, label_vsource) in enumerate(zip(data_vsources, label_vsources)):
    data_vlayout[i] = data_vsource
    label_vlayout[i] = label_vsource

# Add virtual datasets to output file; -5 marks rows whose source is missing
with h5py.File(VDS_path, "w", libver="latest") as f:
    f.create_virtual_dataset("vdata", data_vlayout, fillvalue=-5)
    f.create_virtual_dataset("vlabels", label_vlayout, fillvalue=-5)

# Read data back: a virtual dataset is transparent for the reader
with h5py.File(VDS_path, "r") as f:
    print("Virtual datasets:")
    print(f["vdata"].shape)
    print(f["vlabels"].shape)

100%|██████████| 15/15 [02:18<00:00,  9.24s/it]


Virtual datasets:
(952084, 512)
(952084, 512)


In [None]:
VDS_path = os.path.join(encoded_hdf5_path, 'VDS.h5')
print('we')

# Read the data back through the virtual dataset (transparent for the reader).
# The full "vdata" array is materialized in memory as `dset`.
with h5py.File(VDS_path, "r") as f:
    print("Virtual datasets:")
    vdata = f["vdata"]
    dset = vdata[...]

In [22]:
# Paths for the file whose name is still held in `f_ext`
no_unlabeled_f_path = os.path.join(no_unlabeled_h5_path, f_ext)
encoded_f_path = os.path.join(encoded_hdf5_path, f_ext)

# The store uses the 'fixed' format, which does not support column selection
# (only the 'table' format does), so the whole frame is retrieved. Source:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_hdf.html
with pd.HDFStore(no_unlabeled_f_path, compression='lz4', mode='r') as store2:
    df = store2.get('0')

# Number of points per label value, printed as rows of [label, count]
labels = df[['label']].to_numpy()
unique, counts = np.unique(labels, return_counts=True)
print(np.column_stack((unique, counts)))

[[      1 1294930]
 [      2 3508530]
 [      3 2538281]
 [      4   49687]
 [      5 1242205]
 [      6  763114]
 [      7   13899]
 [      8   65650]]


In [None]:
%who

In [None]:
# Prototype dataset class
class HDF5Set(Dataset):
    """Unfinished prototype of a dataset backed by a list of HDF5 files.

    Only the constructor is functional; item access is still a stub.
    """

    def __init__(self, h5filelist):
        # List of HDF5 files the dataset is meant to read from
        self.h5filelist = h5filelist

    def __getitem__(self, index):
        # Plan: you have n of basis points and number of keys for each file;
        # based on that, from the index you can get which file you need to
        # open to obtain the point.
        raise NotImplementedError("HDF5Set.__getitem__ is not implemented yet")

    def __len__(self):
        # NOTE(review): BASIS_POINTS is not defined in the visible cells
        # (BPS_POINTS is) -- confirm the intended constant. Per the original
        # note, the value should eventually be multiplied by the total number
        # of entries.
        x = BASIS_POINTS
        return x
        

In [44]:
# Load both arrays fully into memory, then close the file; the `with` block
# releases the handle even if a key is missing.
with h5py.File('ply_data_test0.h5', 'r') as f:
    data = f["data"][:]
    label = f["label"][:]
print(data.shape)
print(label.shape)

(2048, 2048, 3)
(2048, 1)


In [None]:
# https://discuss.pytorch.org/t/use-of-dataset-class/1620/3
#class Sem3d(Dataset):
#  def __init__(self, hdf5_list):
#        self.datasets = []
#        self.total_count = 0
#        for f in hdf5_list:
#            h5_file = h5py.File(f, 'r')
#            dataset = h5_file['YOUR DATASET NAME']
#            self.datasets.append(dataset)
#            self.total_count += len(dataset)
#
#  def __getitem__(self, index):
#     
#      dataset_index=-1
#      for i in xrange(len(self.limits)-1,-1,-1):
#        #print 'i ',i
#        if index>=self.limits[i]:
#          dataset_index=i
#          break
#      #print 'dataset_index ',dataset_index
#      assert dataset_index>=0, 'negative chunk'
#
#      in_dataset_index = index-self.limits[dataset_index]
#
#      return self.datasets[dataset_index][in_dataset_index], self.datasets_gt[dataset_index][in_dataset_index]
#
#  def __len__(self):
#      return self.total_count 

In [None]:
import h5py
import helpers
import numpy as np
from pathlib import Path
import torch
from torch.utils import data

class HDF5Dataset(data.Dataset):
    """Represents an abstract HDF5 dataset.

    Every HDF5 file is expected to contain groups, and every group to contain
    datasets whose names identify their type (e.g. 'data' and 'label').

    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
        recursive: If True, searches for h5 files in subdirectories.
        load_data: If True, loads all the data immediately into RAM. Use this if
            the dataset fits into memory. Otherwise, leave this at false and
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
        transform: PyTorch transform to apply to every data instance (default=None).
    """
    def __init__(self, file_path, recursive=False, load_data=False, data_cache_size=3, transform=None):
        super().__init__()
        # One dict per dataset found: file_path, type (dataset name), shape, cache_idx
        self.data_info = []
        # Maps file path -> list with the contents of every dataset in that file
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        self.transform = transform

        # Search for all h5 files
        p = Path(file_path)
        assert p.is_dir(), 'file_path must be an existing directory'
        if recursive:
            files = sorted(p.glob('**/*.h5'))
        else:
            files = sorted(p.glob('*.h5'))
        if len(files) < 1:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)

    def __getitem__(self, index):
        """Return the (data, label) pair at `index`.

        The data is passed through `transform` when one was given, otherwise
        it is converted to a tensor; the label is always converted to a tensor.
        """
        # get data
        x = self.get_data("data", index)
        if self.transform:
            x = self.transform(x)
        else:
            x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        """Number of datasets of type 'data' found across all files."""
        return len(self.get_data_infos('data'))

    def _add_data_infos(self, file_path, load_data):
        """Register every dataset of one file in `data_info`.

        When `load_data` is True the dataset contents are also read and put
        into the cache right away.
        """
        # Open read-only: the file is only inspected here, never modified.
        with h5py.File(file_path, 'r') as h5_file:
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # if data is not loaded its cache index is -1
                    idx = -1
                    if load_data:
                        # add data to the data cache; ds[()] reads the whole
                        # dataset (replacement for the removed `.value`)
                        idx = self._add_to_cache(ds[()], file_path)

                    # type is derived from the name of the dataset; we expect the dataset
                    # name to have a name such as 'data' or 'label' to identify its type
                    # we also store the shape of the data in case we need it
                    # (ds.shape comes from the metadata, no data is read for it)
                    self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        # find the beginning index of the hdf5 file we are looking for
        # (does not change while loading, so it is looked up once)
        file_idx = next(i for i, v in enumerate(self.data_info) if v['file_path'] == file_path)

        with h5py.File(file_path, 'r') as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds[()], file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # evict the first cached file (in insertion order) that is not
            # the file that was just loaded
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # remove invalid cache_idx
            self.data_info = [{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1} if di['file_path'] == removal_keys[0] else di for di in self.data_info]

    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == type]
        return data_info_type

    def get_data(self, type, i):
        """Call this function anytime you want to access a chunk of data from the
            dataset. This will make sure that the data is loaded in case it is
            not part of the data cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)

        # get new cache_idx assigned by _load_data
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]