In [68]:
import os
import numpy as np
from PIL import Image
import pandas as pd 

import torch
from torch.utils.data import Dataset
import glob
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import platform
import pickle
import json

In [69]:
class Discretizer:
    """Turn an irregular event table into a fixed-timestep numeric matrix.

    Each input row is ``[hours, value_1, ..., value_k]`` (strings, empty string
    meaning "no observation"). Rows are assigned to bins of ``timestep`` hours;
    categorical channels are one-hot encoded, numeric channels are cast to
    float, and bins without an observation are filled according to
    ``impute_strategy``.

    The channel layout is read from a JSON config with the keys
    ``id_to_channel`` (ordered channel names), ``is_categorical_channel``,
    ``possible_values`` and ``normal_values`` (each indexed by channel name).
    """

    def __init__(self, timestep=0.8, store_masks=True, impute_strategy='zero', start_time='zero',
                 config_path= '/scratch/se1525/MedFuse/ehr_utils/resources/discretizer_config.json'):
        """Load the channel configuration and remember the discretization options.

        :param timestep:        Bin width, in the same unit as the "Hours" column.
        :param store_masks:     If true, ``transform`` appends one observed/imputed
                                mask column per channel to its output.
        :param impute_strategy: One of 'zero', 'normal_value', 'previous', 'next'
                                (validated inside ``transform``).
        :param start_time:      'zero' bins from t=0; 'relative' bins from the first
                                row's timestamp (validated inside ``transform``).
        :param config_path:     Path of the JSON channel configuration.
        """
        with open(config_path) as f:
            config = json.load(f)
            self._id_to_channel = config['id_to_channel']
            self._channel_to_id = dict(zip(self._id_to_channel, range(len(self._id_to_channel))))
            self._is_categorical_channel = config['is_categorical_channel']
            self._possible_values = config['possible_values']
            self._normal_values = config['normal_values']

        self._header = ["Hours"] + self._id_to_channel
        self._timestep = timestep
        self._store_masks = store_masks
        self._start_time = start_time
        self._impute_strategy = impute_strategy

        # Running totals reported by print_statistics().
        self._done_count = 0
        self._empty_bins_sum = 0
        self._unused_data_sum = 0

    def transform(self, X, header=None, end=None):
        """Discretize one example.

        :param X:      Iterable of rows; ``row[0]`` is the timestamp, the other
                       entries are string values ("" for missing) aligned with
                       ``header``. Timestamps must be non-decreasing.
        :param header: Column names for ``X``; defaults to "Hours" followed by
                       the configured channels. ``header[0]`` must be "Hours".
        :param end:    End time of the window; when ``None`` the latest
                       timestamp in ``X`` is used.
        :return: ``(data, new_header)`` where ``data`` is a float array with one
                 row per bin (value columns, then mask columns if
                 ``store_masks``) and ``new_header`` is the matching
                 comma-joined column-name string.
        :raises ValueError: for an unknown ``start_time`` or ``impute_strategy``.
        """
        if header is None:
            header = self._header
        assert header[0] == "Hours"
        eps = 1e-6

        N_channels = len(self._id_to_channel)
        ts = [float(row[0]) for row in X]
        for i in range(len(ts) - 1):
            assert ts[i] < ts[i+1] + eps

        if self._start_time == 'relative':
            first_time = ts[0]
        elif self._start_time == 'zero':
            first_time = 0
        else:
            raise ValueError("start_time is invalid")

        if end is None:
            max_hours = max(ts) - first_time
        else:
            max_hours = end - first_time

        # The eps keeps a window that ends exactly on a bin boundary from
        # opening one extra bin.
        N_bins = int(max_hours / self._timestep + 1.0 - eps)

        # Column range [begin_pos[i], end_pos[i]) of every channel in `data`:
        # categorical channels take one column per possible value, numeric one.
        cur_len = 0
        begin_pos = [0 for i in range(N_channels)]
        end_pos = [0 for i in range(N_channels)]
        for i in range(N_channels):
            channel = self._id_to_channel[i]
            begin_pos[i] = cur_len
            if self._is_categorical_channel[channel]:
                end_pos[i] = begin_pos[i] + len(self._possible_values[channel])
            else:
                end_pos[i] = begin_pos[i] + 1
            cur_len = end_pos[i]

        data = np.zeros(shape=(N_bins, cur_len), dtype=float)
        mask = np.zeros(shape=(N_bins, N_channels), dtype=int)
        original_value = [["" for j in range(N_channels)] for i in range(N_bins)]
        total_data = 0
        unused_data = 0

        def write(data, bin_id, channel, value, begin_pos):
            """Store `value` for `channel` in row `bin_id` (one-hot if categorical)."""
            channel_id = self._channel_to_id[channel]
            if self._is_categorical_channel[channel]:
                category_id = self._possible_values[channel].index(value)
                N_values = len(self._possible_values[channel])
                one_hot = np.zeros((N_values,))
                one_hot[category_id] = 1
                for pos in range(N_values):
                    data[bin_id, begin_pos[channel_id] + pos] = one_hot[pos]
            else:
                data[bin_id, begin_pos[channel_id]] = float(value)

        for row in X:
            t = float(row[0]) - first_time
            if t > max_hours + eps:
                continue
            bin_id = int(t / self._timestep - eps)
            assert 0 <= bin_id < N_bins

            for j in range(1, len(row)):
                if row[j] == "":
                    continue
                channel = header[j]
                channel_id = self._channel_to_id[channel]

                total_data += 1
                # A second observation in the same bin overwrites the first;
                # the overwritten one is counted as unused.
                if mask[bin_id][channel_id] == 1:
                    unused_data += 1
                mask[bin_id][channel_id] = 1

                write(data, bin_id, channel, row[j], begin_pos)
                original_value[bin_id][channel_id] = row[j]

        # impute missing values

        if self._impute_strategy not in ['zero', 'normal_value', 'previous', 'next']:
            raise ValueError("impute strategy is invalid")

        if self._impute_strategy in ['normal_value', 'previous']:
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if self._impute_strategy == 'normal_value':
                        imputed_value = self._normal_values[channel]
                    if self._impute_strategy == 'previous':
                        # Forward fill; fall back to the configured normal value
                        # until the channel has been observed at least once.
                        if len(prev_values[channel_id]) == 0:
                            imputed_value = self._normal_values[channel]
                        else:
                            imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        if self._impute_strategy == 'next':
            # Backward fill: walk the bins from last to first.
            prev_values = [[] for i in range(len(self._id_to_channel))]
            for bin_id in range(N_bins-1, -1, -1):
                for channel in self._id_to_channel:
                    channel_id = self._channel_to_id[channel]
                    if mask[bin_id][channel_id] == 1:
                        prev_values[channel_id].append(original_value[bin_id][channel_id])
                        continue
                    if len(prev_values[channel_id]) == 0:
                        imputed_value = self._normal_values[channel]
                    else:
                        imputed_value = prev_values[channel_id][-1]
                    write(data, bin_id, channel, imputed_value, begin_pos)

        empty_bins = np.sum([1 - min(1, np.sum(mask[i, :])) for i in range(N_bins)])
        self._done_count += 1
        self._empty_bins_sum += empty_bins / (N_bins + eps)
        self._unused_data_sum += unused_data / (total_data + eps)

        if self._store_masks:
            data = np.hstack([data, mask.astype(np.float32)])

        # create new header
        new_header = []
        for channel in self._id_to_channel:
            if self._is_categorical_channel[channel]:
                values = self._possible_values[channel]
                for value in values:
                    new_header.append(channel + "->" + value)
            else:
                new_header.append(channel)

        if self._store_masks:
            for i in range(len(self._id_to_channel)):
                channel = self._id_to_channel[i]
                new_header.append("mask->" + channel)

        new_header = ",".join(new_header)

        return (data, new_header)

    def print_statistics(self):
        """Print how many examples were converted and the average loss figures.

        The averages are skipped when no example has been converted yet, since
        they would otherwise divide by zero.
        """
        print("statistics of discretizer:")
        print("\tconverted {} examples".format(self._done_count))
        if self._done_count == 0:
            return
        print("\taverage unused data = {:.2f} percent".format(100.0 * self._unused_data_sum / self._done_count))
        print("\taverage empty  bins = {:.2f} percent".format(100.0 * self._empty_bins_sum / self._done_count))

In [70]:
class Normalizer:
    """Column-wise standardizer: ``(x - mean) / std``.

    Statistics are either accumulated from data with ``_feed_data`` and
    written with ``_save_params``, or read back with ``load_params``.
    """

    def __init__(self, fields=None):
        """
        :param fields: Iterable of column indices to normalize. ``None`` means
                       every column of the array passed to ``transform``.
        """
        self._means = None
        self._stds = None
        self._fields = None
        if fields is not None:
            self._fields = [col for col in fields]

        # Running sums used to derive mean and std without keeping the data.
        self._sum_x = None
        self._sum_sq_x = None
        self._count = 0

    def _feed_data(self, x):
        """Accumulate the rows of ``x`` (2-D, rows are samples) into the running sums.

        The data is converted to float first: with integer input the running
        sums would be integer arrays and a later in-place ``+=`` of a float
        chunk would raise a casting error.
        """
        x = np.asarray(x, dtype=float)
        self._count += x.shape[0]
        if self._sum_x is None:
            self._sum_x = np.sum(x, axis=0)
            self._sum_sq_x = np.sum(x**2, axis=0)
        else:
            self._sum_x += np.sum(x, axis=0)
            self._sum_sq_x += np.sum(x**2, axis=0)

    def _save_params(self, save_file_path):
        """Compute mean / sample std from the fed data and pickle them.

        :param save_file_path: Destination file; written as
                               ``{'means': ..., 'stds': ...}`` with pickle protocol 2.
        :raises ValueError: if fewer than two rows were fed (the sample
                            standard deviation is undefined).
        """
        eps = 1e-7
        N = self._count
        if self._sum_x is None or N < 2:
            raise ValueError("Normalizer needs at least two rows of data to compute statistics.")

        # Compute everything before opening the file so a failure cannot
        # leave a truncated parameter file behind.
        self._means = 1.0 / N * self._sum_x
        variance = 1.0/(N - 1) * (self._sum_sq_x - 2.0 * self._sum_x * self._means + N * self._means**2)
        # Rounding can push the variance of a (near-)constant column slightly
        # below zero; sqrt would then give NaN, which the eps clamp below
        # would not catch.
        self._stds = np.sqrt(np.maximum(variance, 0.0))
        self._stds[self._stds < eps] = eps

        with open(save_file_path, "wb") as save_file:
            pickle.dump(obj={'means': self._means,
                             'stds': self._stds},
                        file=save_file,
                        protocol=2)

    def load_params(self, load_file_path):
        """Load ``means`` and ``stds`` from a pickle written by ``_save_params``.

        NOTE: unpickling executes arbitrary code from the file; only load
        parameter files from a trusted source.
        """
        with open(load_file_path, "rb") as load_file:
            if platform.python_version()[0] == '2':
                dct = pickle.load(load_file)
            else:
                # latin1 lets Python 3 read pickles written by Python 2.
                dct = pickle.load(load_file, encoding='latin1')
            self._means = dct['means']
            self._stds = dct['stds']

    def transform(self, X):
        """Return a standardized float copy of the 2-D array ``X``.

        Only the columns listed in ``fields`` (all columns if it was ``None``)
        are changed. Requires the statistics to have been set by
        ``load_params`` or ``_save_params``.
        """
        if self._fields is None:
            fields = range(X.shape[1])
        else:
            fields = self._fields
        ret = 1.0 * X
        for col in fields:
            ret[:, col] = (X[:, col] - self._means[col]) / self._stds[col]
        return ret

In [71]:
# Read every line of the decompensation test listfile (header line included) into `data`.
with open('/scratch/fs999/shamoutlab/data/mimic-iv-extracted/decompensation/test_listfile.csv','r') as file:
    data = file.readlines()

In [72]:
# The first listfile line is the column header.
header = data[0]
# print(header)
# Columns from the fourth onward are taken as the label (class) names.
classes = header.strip().split(',')[3:]
# print(classes)
# The remaining lines are the data rows (still raw, unsplit strings at this point).
data1 = data[1:]
# print(data1)

In [73]:
# Split every data row of the listfile into its comma-separated fields.
# The rows are taken from `data` (the raw listfile lines) instead of from `data1`
# itself, so re-running this cell does not fail on rows that are already split.
# `data` must still hold the listfile lines when this cell runs.
data1 = [line.split(',') for line in data[1:]]
# data1

In [74]:
# Index the listfile rows by their first column (the time-series file name).
# Column 2 is stored as 'time', column 3 as 'stay_id', and every remaining
# column as a float label. Rows sharing a file name overwrite one another,
# so only the last such row is kept.
data_map = {
    fields[0]: {
        'labels': [float(label) for label in fields[3:]],
        'stay_id': float(fields[2]),
        'time': float(fields[1]),
    }
    for fields in data1
}

In [75]:
# Peek at the first 10 (file name, metadata) entries of data_map.
list(data_map.items())[:10]

[('10001884_episode1_timeseries.csv',
  {'labels': [1.0], 'stay_id': 37510196.0, 'time': 216.0}),
 ('10002155_episode1_timeseries.csv',
  {'labels': [0.0], 'stay_id': 33685454.0, 'time': 148.0}),
 ('10002155_episode2_timeseries.csv',
  {'labels': [0.0], 'stay_id': 31090461.0, 'time': 93.0}),
 ('10002155_episode3_timeseries.csv',
  {'labels': [1.0], 'stay_id': 32358465.0, 'time': 20.0}),
 ('10002348_episode1_timeseries.csv',
  {'labels': [0.0], 'stay_id': 32610785.0, 'time': 235.0}),
 ('10002428_episode1_timeseries.csv',
  {'labels': [0.0], 'stay_id': 34807493.0, 'time': 48.0}),
 ('10002930_episode1_timeseries.csv',
  {'labels': [0.0], 'stay_id': 37049133.0, 'time': 27.0}),
 ('10002930_episode2_timeseries.csv',
  {'labels': [0.0], 'stay_id': 35629889.0, 'time': 16.0}),
 ('10003400_episode1_timeseries.csv',
  {'labels': [0.0], 'stay_id': 32128372.0, 'time': 309.0}),
 ('10003400_episode2_timeseries.csv',
  {'labels': [0.0], 'stay_id': 34577403.0, 'time': 70.0})]

In [76]:
def read_timeseries(ts_file, time_bound=48):
    """Read one time-series CSV into an array of string cells.

    :param ts_file:    Path of a CSV file whose first column is "Hours".
    :param time_bound: Rows are read until the first one whose timestamp
                       exceeds this bound (rows are assumed to be in time
                       order); ``None`` reads the whole file.
    :return: ``(X, header)`` where ``X`` is a 2-D array of strings (one row per
             CSV line, values unparsed) and ``header`` the list of column names.
    :raises ValueError: from ``np.stack`` if no row falls inside the bound.
    """
    ret = []
    with open(ts_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        assert header[0] == "Hours"
        for line in tsfile:
            # Blank lines (e.g. a trailing newline) carry no event and would
            # otherwise break the float() parse / np.stack below.
            if not line.strip():
                continue
            mas = line.strip().split(',')
            if time_bound is not None:
                t = float(mas[0])
                if t > time_bound + 1e-6:
                    break
            ret.append(np.array(mas))
    return (np.stack(ret), header)

In [77]:
# Smoke test: read one phenotyping test episode with the default time_bound (48).
read_timeseries('/scratch/fs999/shamoutlab/data/mimic-iv-extracted/phenotyping/test/14851532_episode3_timeseries.csv')

(array([['0.11666666666666667', '', '', '', '', '', '', '', '', '109', '',
         '', '', '30', '', '', '', ''],
        ['0.16666666666666666', '', '61.0', '', '', '', '', '', '', '109',
         '', '64', '97.0', '29', '74.0', '', '', ''],
        ['0.6666666666666666', '', '', '', '', '', '', '', '167.0', '',
         '', '', '', '', '', '', '', ''],
        ['0.9333333333333333', '', '', '', '', '', '', '', '', '', '', '',
         '', '', '', '39.05555555555556', '', ''],
        ['1.0833333333333333', '', '', '', 'Spontaneously',
         'Obeys Commands', '', 'Oriented', '', '', '', '', '', '', '', '',
         '', ''],
        ['1.1666666666666667', '', '48.0', '', '', '', '', '', '', '93',
         '', '60', '99.0', '26', '93.0', '', '', ''],
        ['2.1666666666666665', '', '45.0', '', '', '', '', '', '115.0',
         '81', '', '54', '98.0', '22', '83.0', '', '', ''],
        ['3.1666666666666665', '', '52.0', '', '', '', '', '', '', '80',
         '', '63', '100.0', '22

In [78]:
def read_by_file_name(index, time_bound=5):
    """Load one episode together with its listfile metadata.

    :param index:      Path of the time-series file; its last '/'-separated
                       component must be a key of the global ``data_map``.
    :param time_bound: Upper time limit handed to ``read_timeseries``
                       (default 5 here). When ``None``, the listfile 'time'
                       value is reported as ``t`` instead.
    :return: dict with keys "X", "t", "y", 'stay_id', "header" and "name".
    """
    record = data_map[index.split('/')[-1]]
    t = time_bound if time_bound is not None else record['time']
    y = record['labels']
    stay_id = record['stay_id']
    X, header = read_timeseries(index, time_bound=time_bound)

    result = {"X": X,
              "t": t,
              "y": y,
              'stay_id': stay_id,
              "header": header,
              "name": index}
    return result

In [79]:
# Smoke test with the default time_bound (5); the file's base name must be present in data_map.
# NOTE(review): the path is under phenotyping/ while data_map was built from the decompensation listfile — confirm this is intended.
read_by_file_name('/scratch/fs999/shamoutlab/data/mimic-iv-extracted/phenotyping/test/14851532_episode3_timeseries.csv')

{'X': array([['0.11666666666666667', '', '', '', '', '', '', '', '', '109', '',
         '', '', '30', '', '', '', ''],
        ['0.16666666666666666', '', '61.0', '', '', '', '', '', '', '109',
         '', '64', '97.0', '29', '74.0', '', '', ''],
        ['0.6666666666666666', '', '', '', '', '', '', '', '167.0', '',
         '', '', '', '', '', '', '', ''],
        ['0.9333333333333333', '', '', '', '', '', '', '', '', '', '', '',
         '', '', '', '39.05555555555556', '', ''],
        ['1.0833333333333333', '', '', '', 'Spontaneously',
         'Obeys Commands', '', 'Oriented', '', '', '', '', '', '', '', '',
         '', ''],
        ['1.1666666666666667', '', '48.0', '', '', '', '', '', '', '93',
         '', '60', '99.0', '26', '93.0', '', '', ''],
        ['2.1666666666666665', '', '45.0', '', '', '', '', '', '115.0',
         '81', '', '54', '98.0', '22', '83.0', '', '', ''],
        ['3.1666666666666665', '', '52.0', '', '', '', '', '', '', '80',
         '', '63', '100.0'

In [80]:
class EHRdataset(Dataset):
    def __init__(self, discretizer, normalizer, listfile, dataset_dir, return_names=True, period_length=24.0):
        self.return_names = return_names
        self.discretizer = discretizer
        self.normalizer = normalizer
        self._period_length = period_length

        self._dataset_dir = dataset_dir
        listfile_path = listfile
        with open(listfile_path, "r") as lfile:
            self._data = lfile.readlines()
        # column names
        self._listfile_header = self._data[0]
        # label columns
        self.CLASSES = self._listfile_header.strip().split(',')[3:]
        # files list with lables 
        self._data = self._data[1:]
        
        ### pay attention
        self._data = [line.split(',') for line in self._data]
        self.data_map = {
            mas[0]: {
                'labels': list(map(float, mas[3:])),
                'stay_id': float(mas[2]),
                'time': float(mas[1]),
                }
                for mas in self._data
        }
        self.names = list(self.data_map.keys())
        ### pay attention
    
    def _read_timeseries(self, ts_filename, time_bound=None):
        
        ret = []
        with open(os.path.join(self._dataset_dir, ts_filename), "r") as tsfile:
            header = tsfile.readline().strip().split(',')
            assert header[0] == "Hours"
            for line in tsfile:
                mas = line.strip().split(',')
                if time_bound is not None:
                    t = float(mas[0])
                    if t > time_bound + 1e-6:
                        break
                ret.append(np.array(mas))
        return (np.stack(ret), header)
    
    def read_by_file_name(self, index, time_bound=None):
        t = self.data_map[index]['time'] if time_bound is None else time_bound
        y = self.data_map[index]['labels']
        stay_id = self.data_map[index]['stay_id']
        (X, header) = self._read_timeseries(index, time_bound=time_bound)

        return {"X": X,
                "t": t,
                "y": y,
                'stay_id': stay_id,
                "header": header,
                "name": index}

    def get_decomp_los(self, index, time_bound=None):
        # name = self._data[index][0]
        # time_bound = self._data[index][1]
        # ys = self._data[index][3]

        # (data, header) = self._read_timeseries(index, time_bound=time_bound)
        # data = self.discretizer.transform(data, end=time_bound)[0] 
        # if (self.normalizer is not None):
        #     data = self.normalizer.transform(data)
        # ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
        # return data, ys

        # data, ys = 
        return self.__getitem__(index, time_bound)


    def __getitem__(self, index, time_bound=None):
        if isinstance(index, int):
            index = self.names[index]
        ret = self.read_by_file_name(index, time_bound)
        data = ret["X"]
        ts = ret["t"] if ret['t'] > 0.0 else self._period_length
        ys = ret["y"]
        names = ret["name"]
        data = self.discretizer.transform(X=data, header=None, end=ts)[0] 
        if (self.normalizer is not None):
            data = self.normalizer.transform(X=data)
        ys = np.array(ys, dtype=np.int32) if len(ys) > 1 else np.array(ys, dtype=np.int32)[0]
        return data, ys

    
    def __len__(self):
        return len(self.names)

In [81]:
def _pad_with_lengths(arr, min_length=None):
    """Zero-pad the arrays in `arr` along axis 0 to a common length.

    :param arr:        Non-empty sequence of arrays that agree on every axis
                       except the first.
    :param min_length: Optional lower bound for the padded length.
    :return: ``(padded, seq_length)`` — the stacked padded array and the list
             of original first-axis lengths.
    """
    dtype = arr[0].dtype
    seq_length = [x.shape[0] for x in arr]
    max_len = max(seq_length)
    ret = [np.concatenate([x, np.zeros((max_len - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
           for x in arr]
    if (min_length is not None) and ret[0].shape[0] < min_length:
        ret = [np.concatenate([x, np.zeros((min_length - x.shape[0],) + x.shape[1:], dtype=dtype)], axis=0)
               for x in ret]
    return np.array(ret), seq_length


def my_collate(batch):
    """Collate ``(data, target)`` items into ``[padded_x, targets, seq_length]``.

    Padding goes through the private helper rather than the module-level name
    ``pad_zeros``: that name is redefined further down in this notebook with a
    different return type, which would break the tuple unpacking here.
    """
    x = [item[0] for item in batch]
    x, seq_length = _pad_with_lengths(x)
    targets = np.array([item[1] for item in batch])
    return [x, targets, seq_length]

def pad_zeros(arr, min_length=None):
    """Zero-pad `arr` to a common length; returns ``(padded, seq_length)``."""
    return _pad_with_lengths(arr, min_length)

In [82]:
# In-hospital-mortality task: train listfile and the directory holding its time-series files.
listfile1 = '/scratch/fs999/shamoutlab/data/mimic-iv-extracted/in-hospital-mortality/train_listfile.csv'
data_dir1 = '/scratch/fs999/shamoutlab/data/mimic-iv-extracted/in-hospital-mortality/train/'

In [83]:
# Discretizer with its defaults (timestep=0.8, impute_strategy='zero', start_time='zero').
discretizer = Discretizer()
normalizer = Normalizer()
# Load precomputed per-column means/stds.
# NOTE(review): the file name mentions "input_str:previous" while the discretizer above
# imputes with 'zero' — confirm the statistics match the discretizer settings.
normalizer.load_params('/scratch/se1525/mml-ssl/ph_ts0.8.input_str:previous.start_time:zero.normalizer')

In [84]:
# NOTE(review): named test_set, but listfile1/data_dir1 point at the *train* split — confirm.
test_set = EHRdataset(discretizer,normalizer,listfile1, data_dir1)

In [85]:
 # Batches of 32 shuffled episodes, padded to a common length by my_collate; 16 worker processes.
 test_dataloader = DataLoader(test_set, 32, shuffle=True, collate_fn=my_collate, pin_memory=True, num_workers=16)

In [86]:
# Fetch one batch and show the shape of its targets (second element returned by my_collate).
next(iter(test_dataloader))[1].shape

(32,)

In [87]:
import threading
import os
import numpy as np
import random

In [88]:
class Reader(object):
    """Base class for listfile-driven example readers.

    Loads the listfile once, keeps its header line and data lines, and
    offers sequential (wrapping) access through ``read_next``. Subclasses
    implement ``read_example``.
    """

    def __init__(self, dataset_dir, listfile=None):
        """
        :param dataset_dir: Directory where the example files are stored.
        :param listfile:    Path to a listfile; when ``None``,
                            ``dataset_dir/listfile.csv`` is used.
        """
        self._dataset_dir = dataset_dir
        self._current_index = 0
        listfile_path = listfile if listfile is not None else os.path.join(dataset_dir, "listfile.csv")
        with open(listfile_path, "r") as lfile:
            lines = lfile.readlines()
        # First line is the header, everything after it is data.
        self._listfile_header = lines[0]
        self._data = lines[1:]

    def get_number_of_examples(self):
        """Return how many data lines the listfile holds."""
        return len(self._data)

    def random_shuffle(self, seed=None):
        """Shuffle the examples in place, seeding the global RNG first if ``seed`` is given."""
        if seed is not None:
            random.seed(seed)
        random.shuffle(self._data)

    def read_example(self, index):
        """Read the example at ``index``; must be provided by subclasses."""
        raise NotImplementedError()

    def read_next(self):
        """Read the example at the cursor and advance it, wrapping to 0 at the end."""
        position = self._current_index
        following = position + 1
        self._current_index = 0 if following == self.get_number_of_examples() else following
        return self.read_example(position)



In [91]:
# Count the data rows (header excluded) of the in-hospital-mortality train listfile.
Reader(dataset_dir=data_dir1,listfile=listfile1).get_number_of_examples()

18845

In [92]:
class DecompensationReader(Reader):
    def __init__(self, dataset_dir, listfile=None):
        """ Reader for decompensation prediction task.

        Every listfile row must have exactly four comma-separated fields:
        file name, time (float), stay id (int) and label (int).

        :param dataset_dir: Directory where timeseries files are stored.
        :param listfile:    Path to a listfile. If this parameter is left `None` then
                            `dataset_dir/listfile.csv` will be used.
        """
        Reader.__init__(self, dataset_dir, listfile)
        parsed = []
        for line in self._data:
            (ts_name, t, stay_id, y) = line.split(',')
            parsed.append((ts_name, float(t), int(stay_id), int(y)))
        self._data = parsed

    def _read_timeseries(self, ts_filename, time_bound):
        """ Read the rows of `ts_filename` up to `time_bound` hours.

        :return: tuple (X, header): X is a 2D array of string cells, header the
                 list of column names.
        """
        rows = []
        ts_path = os.path.join(self._dataset_dir, ts_filename)
        with open(ts_path, "r") as tsfile:
            header = tsfile.readline().strip().split(',')
            assert header[0] == "Hours"
            for line in tsfile:
                fields = line.strip().split(',')
                # stop at the first row past the bound
                if float(fields[0]) > time_bound + 1e-6:
                    break
                rows.append(np.array(fields))
        return (np.stack(rows), header)

    def read_example(self, index):
        """ Read the example with given index.

        :param index: Index of the line of the listfile to read (counting starts from 0).
        :return: Dictionary with the following keys:
            X : np.array
                2D array containing all events. Each row corresponds to a moment.
                First column is the time and other columns correspond to different
                variables.
            t : float
                Length of the data in hours. Note, in general, it is not equal to the
                timestamp of last event.
            id : int
                Stay id taken from the third listfile column.
            y : int (0 or 1)
                Mortality within next 24 hours.
            header : array of strings
                Names of the columns. The ordering of the columns is always the same.
            name: Name of the sample.
        :raises ValueError: if `index` is out of range.
        """
        if not (0 <= index < len(self._data)):
            raise ValueError("Index must be from 0 (inclusive) to number of examples (exclusive).")

        ts_name, t, stay_id, y = self._data[index]
        (X, header) = self._read_timeseries(ts_name, t)

        example = {"X": X,
                   "t": t,
                   "id": stay_id,
                   "y": y,
                   "header": header,
                   "name": ts_name}
        return example

In [93]:
# Decompensation task: test listfile and the directory holding its time-series files.
listfile = '/scratch/fs999/shamoutlab/data/mimic-iv-extracted/decompensation/test_listfile.csv'
datadir = '/scratch/fs999/shamoutlab/data/mimic-iv-extracted/decompensation/test'

In [94]:
# Parses the whole decompensation test listfile into (name, time, stay id, label) tuples.
reader = DecompensationReader(datadir,listfile)

In [95]:
# Read a single example by listfile row index and show the shape of its raw event table.
reader.read_example(811430)['X'].shape

(239, 18)

In [96]:
def preprocess_chunk(data, ts, discretizer, normalizer=None):
    """Discretize, then optionally normalize, every example of a chunk.

    :param data:        Sequence of raw examples.
    :param ts:          Sequence of end times, paired with `data` by position.
    :param discretizer: Object whose ``transform(X, end=t)`` returns a tuple;
                        only its first element is kept.
    :param normalizer:  Optional object with ``transform(X)``.
    :return: list of processed examples, in input order.
    """
    discretized = []
    for example, end_time in zip(data, ts):
        discretized.append(discretizer.transform(example, end=end_time)[0])
    if normalizer is None:
        return discretized
    normalized = []
    for example in discretized:
        normalized.append(normalizer.transform(example))
    return normalized

In [97]:
def read_chunk(reader, chunk_size):
    """Read `chunk_size` consecutive examples and regroup them by key.

    Each call to ``reader.read_next()`` yields a dict; the result maps every
    key to the list of values seen, except "header", which is reduced to the
    header of the first example.
    """
    chunk = {}
    for _ in range(chunk_size):
        example = reader.read_next()
        for key, value in example.items():
            chunk.setdefault(key, []).append(value)
    first_header = chunk["header"][0]
    chunk["header"] = first_header
    return chunk

In [98]:
def sort_and_shuffle(data, batch_size):
    """ Group examples of similar length into batches and shuffle the batches.

        `data` is a sequence (X1, X2, ..., Xn) of equally long sequences,
        usually (X, y). The examples are shuffled, the largest prefix that is
        a multiple of `batch_size` is sorted by the first-axis length of X1
        and cut into batches whose order is then shuffled; the leftover
        examples are appended unsorted at the end.
        Returns the reordered columns as a list of tuples.
    """
    assert len(data) >= 2
    rows = list(zip(*data))

    random.shuffle(rows)

    leftover = len(rows) % batch_size
    cut = len(rows) - leftover
    head, tail = rows[:cut], rows[cut:]

    head.sort(key=lambda row: row[0].shape[0])

    batches = [head[start:start + batch_size] for start in range(0, len(head), batch_size)]
    random.shuffle(batches)

    ordered = [row for batch in batches for row in batch]
    ordered.extend(tail)

    return list(zip(*ordered))

In [99]:
def pad_zeros(arr, min_length=None):
    """
    Zero-pad a sequence of `np.array`s to a common first-axis length.

    Every array in `arr` is extended with zeros (of the dtype of `arr[0]`)
    up to the longest first-axis length found, or up to `min_length` if that
    is larger. Returns the padded arrays stacked into a single `np.array`.
    """
    elem_dtype = arr[0].dtype
    longest = max(item.shape[0] for item in arr)

    padded = []
    for item in arr:
        filler = np.zeros((longest - item.shape[0],) + item.shape[1:], dtype=elem_dtype)
        padded.append(np.concatenate([item, filler], axis=0))

    if min_length is not None and padded[0].shape[0] < min_length:
        extended = []
        for item in padded:
            filler = np.zeros((min_length - item.shape[0],) + item.shape[1:], dtype=elem_dtype)
            extended.append(np.concatenate([item, filler], axis=0))
        padded = extended

    return np.array(padded)

In [100]:
class BatchGen(object):
    """Endless batch iterator over a ``Reader``.

    Examples are read in chunks, grouped into batches of similar length by
    ``sort_and_shuffle`` and zero-padded with ``pad_zeros``. ``next()`` is
    guarded by a lock so one instance can be shared between threads.
    """

    def __init__(self, reader, discretizer, normalizer,
                 batch_size, steps, shuffle, return_names=False):
        """
        :param reader:       Reader providing ``get_number_of_examples``,
                             ``random_shuffle`` and ``read_next``.
        :param discretizer:  Stored on the instance (preprocessing is currently
                             disabled in ``_generator``).
        :param normalizer:   Stored on the instance (see ``discretizer``).
        :param batch_size:   Number of examples per batch.
        :param steps:        Batches per epoch; ``None`` derives it from the
                             number of examples in the reader (rounded up).
        :param shuffle:      Shuffle the reader at the start of every epoch.
        :param return_names: Yield dicts with "data", "names" and "ts" instead
                             of plain ``(X, y)`` tuples.
        """
        self.reader = reader
        self.discretizer = discretizer
        self.normalizer = normalizer
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.return_names = return_names

        if steps is None:
            self.n_examples = reader.get_number_of_examples()
            self.steps = (self.n_examples + batch_size - 1) // batch_size
        else:
            self.n_examples = steps * batch_size
            self.steps = steps

        # Read at most 1024 batches' worth of examples at a time.
        self.chunk_size = min(1024, self.steps) * batch_size
        self.lock = threading.Lock()
        self.generator = self._generator()

    def _generator(self):
        """Yield batches forever, one pass over ``n_examples`` per epoch."""
        B = self.batch_size
        while True:
            if self.shuffle:
                self.reader.random_shuffle()
            remaining = self.n_examples
            while remaining > 0:
                current_size = min(self.chunk_size, remaining)
                remaining -= current_size

                ret = read_chunk(self.reader, current_size)
                Xs = ret["X"]
                ts = ret["t"]
                ys = ret["y"]
                names = ret["name"]

                # Preprocessing is disabled here, so the batches carry the raw
                # (string-valued) event tables rather than discretized data.
                #Xs = preprocess_chunk(Xs, ts, self.discretizer, self.normalizer)
                (Xs, ys, ts, names) = sort_and_shuffle([Xs, ys, ts, names], B)

                for i in range(0, current_size, B):
                    X = pad_zeros(Xs[i:i + B])
                    y = np.array(ys[i:i + B])
                    batch_names = names[i:i+B]
                    batch_ts = ts[i:i+B]
                    batch_data = (X, y)
                    if not self.return_names:
                        yield batch_data
                    else:
                        yield {"data": batch_data, "names": batch_names, "ts": batch_ts}

    def __iter__(self):
        # Return the wrapper itself, not the raw generator, so that iteration
        # via iter()/for goes through the locked next() below.
        return self

    def next(self):
        """Return the next batch; serialized across threads by ``self.lock``."""
        with self.lock:
            return next(self.generator)

    def __next__(self):
        return self.next()

In [101]:
# Ceiling division (n + batch_size - 1) // batch_size, the formula BatchGen uses for `steps`,
# evaluated for n = 811440 and batch_size = 2.
(811440+2-1)//2

405720

In [102]:
# One example per batch, steps derived from the reader size, shuffled each epoch;
# return_names=True makes every batch a dict with "data", "names" and "ts".
batch_loader = BatchGen(reader,discretizer,normalizer,batch_size=1,steps=None,shuffle=True,return_names=True)

In [103]:
# Pull a single batch from the generator.
# NOTE(review): this rebinds `data`, which earlier in the notebook held the raw listfile lines.
data = next(iter(batch_loader))

In [104]:
# Display the batch dict ("data", "names", "ts").
data

{'data': (array([[['0.3', '', '', ..., '', '', ''],
          ['0.48333333333333334', '', '85.0', ..., '', '', ''],
          ['0.5', '', '', ..., '', '', ''],
          ...,
          ['27.683333333333334', '', '', ..., '', '', ''],
          ['28.466666666666665', '', '', ..., '', '', ''],
          ['28.483333333333334', '', '60.0', ..., '', '', '']]],
        dtype='<U19'),
  array([0])),
 'names': ('15048951_episode3_timeseries.csv',),
 'ts': (29.0,)}

In [105]:
# Shape of the padded X array of the batch (first element of the "data" tuple).
data['data'][0].shape

(1, 69, 18)