## Track 2 (raw data)

In [None]:
import ast
import copy
import json
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import pyhrv
import scipy
import torch
from matplotlib import pyplot as plt
from monai.config import KeysCollection
from monai.data import CacheDataset, DataLoader, DistributedSampler
from monai.transforms import MapTransform
from monai.transforms import Compose, ToTensorD
from tqdm.notebook import tqdm

# Per-channel (lower, upper) bounds, keyed by the sensor column names that
# LoadDataD reads from each data.csv.
# NOTE(review): units are not stated anywhere in this notebook, and no visible
# code reads this dict -- confirm where (or whether) the bounds are applied.
valid_range = {
    "acc_X" : (-19.6, 19.6),
    "acc_Y" : (-19.6, 19.6),
    "acc_Z" : (-19.6, 19.6),
    "gyr_X" : (-573, 573),
    "gyr_Y" : (-573, 573),
    "gyr_Z" : (-573, 573),
    "heartRate" : (0, 255),
    "rRInterval" : (0, 2000),
}



In [None]:
# Root of the challenge release; every path returned by get_paths is relative to it.
root_dir = Path("../../datasets/SPGC_challenge_track_2_release/")

def get_paths(root_dir, split):
    """Collect the sample directories of one split, relative to ``root_dir``.

    For ``split != 'test'`` the layout walked is
    ``training_data/<user>/<split>/<status>/<sample>``; for ``'test'`` it is
    ``test_data/<user>/test/<sample>`` (no status level).

    Args:
        root_dir: directory containing ``training_data`` and ``test_data``.
        split: name of the split sub-directory inside every user directory.

    Returns:
        A list of ``Path`` objects in ``os.listdir`` order (i.e. unsorted).
    """
    is_test = split == 'test'
    base_dir = Path('test_data') if is_test else Path('training_data')

    collected = []
    for user in os.listdir(root_dir/base_dir):
        split_dir = base_dir/user/split
        if is_test:
            # Test samples sit directly below the split directory.
            collected.extend(split_dir/sample for sample in os.listdir(root_dir/split_dir))
            continue
        # Train/val samples are grouped by an extra status directory.
        for status in os.listdir(root_dir/split_dir):
            status_dir = split_dir/status
            collected.extend(status_dir/sample for sample in os.listdir(root_dir/status_dir))
    return collected

def parse_path(path):
    """Decode a relative sample path into ``(user, split, status, sample_id)``.

    Expected layouts (relative to the dataset root):
    ``<base>/user_<NN>/<split>/<status>/<sample>`` (five components) or
    ``<base>/user_<NN>/<split>/<sample>`` (no status directory).

    Args:
        path: ``str`` or ``Path`` in one of the layouts above.

    Returns:
        Tuple ``(user, split, status, sample_id)`` where ``user`` and
        ``sample_id`` are ints, ``status`` is 1 for a ``relapse`` directory,
        0 for any other status directory and -1 when there is no status level.
    """
    # Path.parts is separator-agnostic and ignores a trailing separator, unlike
    # splitting the string on "/" by hand.
    parts = Path(path).parts
    user = int(parts[1].split("_")[1])
    split = parts[2]
    if len(parts) == 5:
        status = 1 if parts[3] == 'relapse' else 0
        sample_id = int(parts[4])
    else:
        status = -1
        sample_id = int(parts[3])
    return user, split, status, sample_id

In [None]:
# Map each split name to the list of its sample directories (relative to root_dir).
paths = {split: get_paths(root_dir, split) for split in ['train', 'val', 'test']}

In [None]:
# Print, for every user, how many validation samples each status directory holds.
base_dir = Path('training_data')
split = 'val'
for user in os.listdir(root_dir/base_dir):
    user_dir = base_dir/Path(user)/Path(split)
    for status in os.listdir(root_dir/user_dir):
        print(f"User: {user}, Status: {status}, Samples:{len(os.listdir(root_dir/user_dir/Path(status)))}")

User: user_02, Status: non-relapse, Samples:25
User: user_02, Status: relapse, Samples:13
User: user_08, Status: non-relapse, Samples:13
User: user_08, Status: relapse, Samples:3
User: user_00, Status: non-relapse, Samples:31
User: user_00, Status: relapse, Samples:9
User: user_01, Status: non-relapse, Samples:22
User: user_01, Status: relapse, Samples:57
User: user_09, Status: non-relapse, Samples:21
User: user_09, Status: relapse, Samples:73
User: user_05, Status: non-relapse, Samples:27
User: user_05, Status: relapse, Samples:22
User: user_07, Status: non-relapse, Samples:29
User: user_07, Status: relapse, Samples:93
User: user_06, Status: non-relapse, Samples:26
User: user_06, Status: relapse, Samples:4
User: user_04, Status: non-relapse, Samples:22
User: user_04, Status: relapse, Samples:3
User: user_03, Status: non-relapse, Samples:21
User: user_03, Status: relapse, Samples:17


### Create Dataset Dataframes

In [None]:
root_dir = Path("../../datasets/SPGC_challenge_track_2_release/")
# NOTE: `splits` and `paths` are module-level globals that save_dataset reads
# directly (they are not passed in as arguments), so this cell must run first.
splits = ['train', 'val', 'test']
paths = {split: get_paths(root_dir, split) for split in splits}

def validate(window):
    """Return the fraction of rows in ``window`` that contain no NaN value.

    Args:
        window: ``pandas.DataFrame`` slice of a recording.

    Returns:
        A float in ``[0, 1]``. An empty window has no usable row and yields
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    n_rows = len(window)
    if n_rows == 0:
        return 0.0
    # A row is invalid as soon as any of its columns is NaN.
    n_invalid = int(window.isna().any(axis=1).sum())
    return 1 - n_invalid / n_rows

def get_observations(root_dir, path, w_size_h=4, w_stride_h=1, val_percentage=0.25):
    """Cut one recording into sliding windows and describe each as a record.

    Args:
        root_dir: dataset root (``Path``).
        path: sample directory relative to ``root_dir`` (``Path``); it must
            contain a ``data.csv`` file.
        w_size_h: window length in hours.
        w_stride_h: window stride in hours.
        val_percentage: minimum fraction of NaN-free rows (see ``validate``)
            for a window to be flagged ``valid``.

    Returns:
        A list of dicts with keys ``data_file``, ``user_id``, ``sample_id``,
        ``label``, ``valid``, ``start_data_row`` and ``end_data_row`` (end
        excluded). A recording shorter than one window gives ``[]`` for the
        ``train`` split and a single record spanning the whole file otherwise.
    """
    data = pd.read_csv(root_dir/path/"data.csv")
    user, split, status, sample_id = parse_path(path)
    # Hours are converted to rows with a fixed factor of 12*60 rows per hour
    # (presumably one row every 5 seconds -- confirm against the data spec).
    w_size = int(w_size_h*12*60)
    w_stride = int(w_stride_h*12*60)
    data_file = Path(path/"data.csv")

    def _record(window, start, stop):
        # Single place where the record schema is defined.
        return {
            'data_file' : data_file,
            'user_id' : user,
            'sample_id' : sample_id,
            'label' : status,
            'valid' : validate(window) >= val_percentage,
            'start_data_row' : start,
            'end_data_row' : stop,
        }

    # Treat short sequences: dropped for training, kept whole for val/test.
    if len(data) < w_size:
        if split == 'train':
            return []
        return [_record(data, 0, len(data))]

    # Slide windows. The "+ 1" keeps the last full window, which starts at
    # len(data) - w_size; without it a recording of exactly w_size rows would
    # produce no observation at all.
    return [
        _record(data.iloc[start:start + w_size], start, start + w_size)
        for start in range(0, len(data) - w_size + 1, w_stride)
    ]

def create_dataset_list(root_dir, paths,  w_size_h=4, w_stride_h=1, val_percentage=0.25):
    """Concatenate the window records of every sample directory in ``paths``.

    Each path is handed to ``get_observations`` together with the windowing
    parameters; the per-sample lists are flattened into one list, preserving
    the order of ``paths``.
    """
    return [
        observation
        for sample in paths
        for observation in get_observations(
            root_dir,
            sample,
            w_size_h=w_size_h,
            w_stride_h=w_stride_h,
            val_percentage=val_percentage,
        )
    ]

def _create_offsets(x):
    """Return ``(start_data_row, end_data_row)`` pairs for the rows of ``x``.

    Only rows flagged ``valid`` are used; when no row is valid, every row is
    used instead so that the result is never empty for a non-empty ``x``.
    """
    selected = x[x.valid]
    if len(selected) == 0:
        # Fall back to all windows rather than returning nothing.
        selected = x
    return list(zip(selected.start_data_row, selected.end_data_row))

def save_dataset(root_dir, output_dir, w_size_h=4, w_stride_h=1, val_percentage={'train': 2.5/3, 'val':1/3, 'test':1/3}):
    """Build the window index of every split and write it to ``<split>_dataset.csv``.

    The ``train`` CSV has one row per window. For the other splits the windows
    are grouped by ``data_file`` into one row per recording, with an
    ``offsets`` column listing the ``(start, end)`` rows to evaluate.

    Args:
        root_dir: dataset root passed on to ``create_dataset_list``.
        output_dir: directory receiving the CSV files; created if missing.
        w_size_h: window length in hours.
        w_stride_h: window stride in hours.
        val_percentage: mapping ``split -> minimum valid-row fraction``. The
            default dict is only read, never mutated.

    Note:
        The module-level globals ``splits`` and ``paths`` are read directly;
        ``paths`` must have been built for the same ``root_dir``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for split in splits:
        # create records
        dataset_list = create_dataset_list(root_dir, paths[split], w_size_h=w_size_h, w_stride_h=w_stride_h, val_percentage=val_percentage[split])
        # create dataframe
        dataset = pd.DataFrame(dataset_list)
        if split != 'train':
            # One record per recording. Iterating the groupby (instead of
            # groupby.apply) keeps the grouping column available in each
            # group, which apply() has deprecated in recent pandas releases.
            records = [
                {
                    'data_file' : group.data_file.iloc[0],
                    'user_id' : group.user_id.iloc[0],
                    'sample_id' : group.sample_id.iloc[0],
                    'label' : group.label.iloc[0],
                    'valid' : 1,
                    'offsets' : _create_offsets(group),
                }
                for _, group in dataset.groupby('data_file')
            ]
            dataset = pd.DataFrame.from_records(records)
        dataset.to_csv(output_dir/f"{split}_dataset.csv")

In [None]:
root_dir = Path("../../datasets/SPGC_challenge_track_2_release")
output_dir = Path("../data/track2/raw_volund")

# Window length and stride in hours. get_observations converts hours to rows
# with int(h*12*60), so 2.8445 h becomes 2048 rows; equal size and stride means
# consecutive windows do not overlap.
w_size_h = 2.8445
w_stride_h = 2.8445
# Minimum fraction of NaN-free rows a window needs to be flagged valid, per split.
val_percentage = {'train': 2.5/3, 'val':1/3, 'test':1/3}

save_dataset(root_dir, output_dir, w_size_h=w_size_h, w_stride_h=w_stride_h, val_percentage=val_percentage)

### Compute per-subject Statistics 

In [None]:
# Compute per-user mean/std of the first 8 columns of every non-relapse
# training recording. The result is stored in `stats` as
# {user_id: {column: {'mean': ..., 'std': ...}}}.
root_dir = Path("../../datasets/SPGC_challenge_track_2_release/training_data")
stats = {}

for user in tqdm(os.listdir(root_dir)):
    user_id = int(user.split("_")[1])
    user_dir = root_dir/Path(f"{user}/train/non-relapse")
    arrays = []
    for sample in tqdm(os.listdir(user_dir)):
        # read data
        df = pd.read_csv(user_dir/Path(sample)/"data.csv")
        # infinities are treated as missing values
        df = df.replace([np.inf, -np.inf], np.nan)
        arrays.append(df.to_numpy())
    total_array = np.concatenate(arrays)
    # Only the first 8 columns are summarised; cast once so both statistics
    # work on the same float array.
    signals = total_array[:, :8].astype(float)
    # NaN-aware mean, consistent with the NaN-aware std below: a plain mean
    # would turn the whole column statistic into NaN as soon as one value is missing.
    mean = np.nanmean(signals, axis=0)
    std = np.nanstd(signals, axis=0)
    columns = list(df.columns)[:8]
    record = {column: {'mean': m, 'std': s} for column, m, s in zip(columns, mean, std)}
    stats[user_id] = record


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/204 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.07740479162575463, 'std': 0.5343033072045137}, 'acc_Y': {'mean': -0.04565059638375774, 'std': 0.4632659545956309}, 'acc_Z': {'mean': 0.02541218994184653, 'std': 1.017124106303795}, 'gyr_X': {'mean': 0.04991267045793881, 'std': 5.508729026718897}, 'gyr_Y': {'mean': 0.19329504866729633, 'std': 5.835455661815294}, 'gyr_Z': {'mean': 0.02172364578913043, 'std': 7.335775080155805}, 'heartRate': {'mean': 76.91158360006465, 'std': 37.287505305403634}, 'rRInterval': {'mean': 719.9911809838902, 'std': 205.81339896547001}}


  0%|          | 0/105 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.0850953928717194, 'std': 0.3949884243991635}, 'acc_Y': {'mean': 0.016234903817932095, 'std': 0.49339803718533104}, 'acc_Z': {'mean': 0.0020008037716371238, 'std': 0.2982161607613874}, 'gyr_X': {'mean': 0.013499659163641242, 'std': 4.992049295122978}, 'gyr_Y': {'mean': 0.05001058881494166, 'std': 4.326428229719058}, 'gyr_Z': {'mean': -0.07173644678538918, 'std': 4.009702150871599}, 'heartRate': {'mean': 51.14093382088202, 'std': 43.983636478762236}, 'rRInterval': {'mean': 641.4522993438326, 'std': 385.5320666573568}}


  0%|          | 0/248 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.048440933939075295, 'std': 9.55237469473406}, 'acc_Y': {'mean': 0.07818054431582433, 'std': 0.4179387651682867}, 'acc_Z': {'mean': -0.05413172496386478, 'std': 0.4755626480131258}, 'gyr_X': {'mean': -0.022104286782690885, 'std': 6.422744574550959}, 'gyr_Y': {'mean': 0.116121176780356, 'std': 5.249311763286557}, 'gyr_Z': {'mean': -0.009221273915754976, 'std': 5.362507778972771}, 'heartRate': {'mean': 80.50294353467253, 'std': 39.6601647168719}, 'rRInterval': {'mean': 687.1113597150423, 'std': 218.96576478907238}}


  0%|          | 0/179 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.10997362242580694, 'std': 0.7058907120698802}, 'acc_Y': {'mean': 0.09147887657910124, 'std': 0.5197327561841306}, 'acc_Z': {'mean': -0.10899774189309082, 'std': 0.4518665285638489}, 'gyr_X': {'mean': 0.08776794049723474, 'std': 7.539833390757674}, 'gyr_Y': {'mean': 0.10356399750790797, 'std': 5.585229424425513}, 'gyr_Z': {'mean': -0.03976771107317502, 'std': 6.321834138993106}, 'heartRate': {'mean': 82.6555533299491, 'std': 30.416062624724187}, 'rRInterval': {'mean': 840.3548580496051, 'std': 219.2047892267108}}


  0%|          | 0/169 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.14136787738247256, 'std': 0.5426113879587267}, 'acc_Y': {'mean': 0.023768131889762057, 'std': 0.4689673054625551}, 'acc_Z': {'mean': -0.012439029132779127, 'std': 0.4406461071283258}, 'gyr_X': {'mean': -0.03419923962560961, 'std': 8.348854037356908}, 'gyr_Y': {'mean': 0.29665035106841764, 'std': 6.749337848374998}, 'gyr_Z': {'mean': 0.04557008292909999, 'std': 6.504126240479325}, 'heartRate': {'mean': 88.81403669694498, 'std': 21.499448769234718}, 'rRInterval': {'mean': 713.7407634794719, 'std': 158.07935502391476}}


  0%|          | 0/217 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.10273361268933424, 'std': 0.6690189783109256}, 'acc_Y': {'mean': -0.004094247613722982, 'std': 0.42740236057568937}, 'acc_Z': {'mean': 0.05233723099241764, 'std': 3.1957801219838227}, 'gyr_X': {'mean': 0.13855119924909468, 'std': 5.481149317239829}, 'gyr_Y': {'mean': -0.1075337870131346, 'std': 3.981962705961871}, 'gyr_Z': {'mean': 0.05803836175868217, 'std': 57.86694218699615}, 'heartRate': {'mean': 73.2514684077721, 'std': 21.377218454198236}, 'rRInterval': {'mean': 907.002603854491, 'std': 186.34437718637585}}


  0%|          | 0/230 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.09319438383430002, 'std': 1.0549163715363512}, 'acc_Y': {'mean': -0.1318285128188079, 'std': 0.7824211155774202}, 'acc_Z': {'mean': 0.0581258162208312, 'std': 2.8133538480466758}, 'gyr_X': {'mean': 2.2043800008858794, 'std': 92.648168681159}, 'gyr_Y': {'mean': -0.47524791633864766, 'std': 12.29713208371014}, 'gyr_Z': {'mean': 0.04163716593449852, 'std': 89.8694566892467}, 'heartRate': {'mean': 72.3729479198667, 'std': 20.268623292508792}, 'rRInterval': {'mean': 861.8180264035229, 'std': 191.6807804814167}}


  0%|          | 0/210 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.02092114265507649, 'std': 0.4154525036337986}, 'acc_Y': {'mean': -0.014012288149161977, 'std': 0.3781170387917316}, 'acc_Z': {'mean': 0.053518889409039976, 'std': 0.30976595106201094}, 'gyr_X': {'mean': 0.06170811909847024, 'std': 4.959383620928154}, 'gyr_Y': {'mean': 0.05460819371298155, 'std': 4.12906968431291}, 'gyr_Z': {'mean': 0.06711260025394976, 'std': 4.114395839546679}, 'heartRate': {'mean': 72.66138217353134, 'std': 29.686907125119014}, 'rRInterval': {'mean': 829.1710687354415, 'std': 268.86139018980003}}


  0%|          | 0/176 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.10977938591748378, 'std': 0.5849905651281385}, 'acc_Y': {'mean': 0.09671424934674658, 'std': 0.42242881138851734}, 'acc_Z': {'mean': 0.01227884443702697, 'std': 0.39596732436439214}, 'gyr_X': {'mean': 0.05048087384967409, 'std': 6.405825547399702}, 'gyr_Y': {'mean': 0.026655853000925794, 'std': 5.303059696431313}, 'gyr_Z': {'mean': -0.016025140677655692, 'std': 5.184809574545775}, 'heartRate': {'mean': 77.04926626391975, 'std': 27.502978136510016}, 'rRInterval': {'mean': 890.2228859012833, 'std': 214.60999275381027}}


  0%|          | 0/168 [00:00<?, ?it/s]

{'acc_X': {'mean': -0.22596949728009436, 'std': 1.4858015668008786}, 'acc_Y': {'mean': -0.06796499415827657, 'std': 4.585748126660567}, 'acc_Z': {'mean': 0.07151734662930911, 'std': 0.3360983510393319}, 'gyr_X': {'mean': 0.03623044481526213, 'std': 5.547699728914393}, 'gyr_Y': {'mean': -0.13753716699566793, 'std': 4.168085396206269}, 'gyr_Z': {'mean': 0.12295767572916634, 'std': 7.9929711279947835}, 'heartRate': {'mean': 85.51817947300428, 'std': 24.432683891364164}, 'rRInterval': {'mean': 784.9783992496891, 'std': 144.53229810610463}}


In [None]:
output_dir = Path("../data/track2/raw")

# One-off export of the statistics computed above (left disabled).
#with open(output_dir/"subject_stats.json", "w") as f:
#    json.dump(stats, f)

# Reload the statistics from disk. This rebinds the global `stats` that
# ImputeMedianD and StandardizeD read at call time.
with open(output_dir/"subject_stats.json", "r") as f:
    stats = json.load(f)

# JSON object keys are always strings, hence the lookup with '9' rather than 9.
user = '9'
print(stats[user])

{'acc_X': {'mean': -0.14136787738247256, 'std': 0.5426113879587267}, 'acc_Y': {'mean': 0.023768131889762057, 'std': 0.4689673054625551}, 'acc_Z': {'mean': -0.012439029132779127, 'std': 0.4406461071283258}, 'gyr_X': {'mean': -0.03419923962560961, 'std': 8.348854037356908}, 'gyr_Y': {'mean': 0.29665035106841764, 'std': 6.749337848374998}, 'gyr_Z': {'mean': 0.04557008292909999, 'std': 6.504126240479325}, 'heartRate': {'mean': 88.81403669694498, 'std': 21.499448769234718}, 'rRInterval': {'mean': 713.7407634794719, 'std': 158.07935502391476}}


### Dataset class and Transforms

In [None]:
class EPreventionDataset(CacheDataset):
    """``CacheDataset`` built from a ``<split>_dataset.csv`` index file.

    Args:
        split_path: directory (``Path``) holding the ``<split>_dataset.csv`` files.
        split: split name used to pick the CSV file.
        transforms: transform (or ``Compose``) forwarded to ``CacheDataset``.
        max_samples: forwarded to ``pandas.read_csv`` as ``nrows``; the cap is
            therefore applied to the CSV rows *before* the subject and
            validity filters.
        subject: if given, keep only rows whose ``user_id`` equals it.
        cache_num, cache_rate, num_workers: forwarded to ``CacheDataset``.
    """

    def __init__(self, split_path, split, transforms, max_samples=None, subject=None, cache_num=sys.maxsize, cache_rate=1.0, num_workers=1):
        # These attributes are read by _generate_data_list, so they must be
        # set before the parent constructor runs.
        self.split = split
        self.max_samples = max_samples
        self.subject = subject

        records = self._generate_data_list(split_path/f"{split}_dataset.csv")

        super().__init__(records, transforms, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

    def _generate_data_list(self, split_path):
        """Load the CSV index, filter it and return it as a list of row dicts.

        Also stores the relative frequency of each label (sorted by label
        value) in ``self.distribution``.
        """
        # open csv with observations
        table = pd.read_csv(split_path, index_col=0, nrows=self.max_samples)
        # filter subject
        if self.subject is not None:
            table = table[table['user_id'] == self.subject]
        # filter valid
        table = table[table.valid.astype(bool)]
        # save label distribution
        label_counts = table.label.value_counts().sort_index().to_numpy()
        self.distribution = label_counts / len(table)

        return table.to_dict('records')

    def get_label_proportions(self):
        """Return the label distribution computed when the index was loaded."""
        return self.distribution

In [None]:
# Smoke test without transforms. max_samples is passed to read_csv as nrows, so
# the 100-row cap is applied before the subject filter.
ds = EPreventionDataset(split_path=Path("../data/track2/raw"), split='val', subject=0, transforms=None, max_samples=100)
ds[0]

Loading dataset: 100%|██████████| 40/40 [00:00<00:00, 544714.81it/s]


{'data_file': 'training_data/user_00/val/non-relapse/00/data.csv',
 'user_id': 0,
 'sample_id': 0,
 'label': 0,
 'valid': 1,
 'offsets': '[(0, 2160), (2160, 4320), (4320, 6480), (6480, 8640), (8640, 10800), (10800, 12960)]'}

In [None]:
class AppendRootDirD(MapTransform):
    """Prefix the path stored under each key with a fixed root directory."""

    def __init__(self, keys: KeysCollection, root_dir):
        super().__init__(keys)
        self.root_dir = root_dir

    def __call__(self, data):
        # Work on a copy so the input sample is left untouched.
        sample = copy.deepcopy(data)
        for key in self.keys:
            relative = sample[key]
            sample[key] = os.path.join(self.root_dir, relative)
        return sample
        
class LoadDataD(MapTransform):
    """Read the CSV referenced by each key into ``sample['data']``.

    For the ``train`` split only the rows between ``start_data_row`` and
    ``end_data_row`` (end excluded) are read; other splits load the whole
    file. The path key and the bookkeeping keys ``valid``, ``start_data_row``
    and ``end_data_row`` are removed from the returned sample.
    """

    def __init__(self, keys: KeysCollection, split, use_sleeping):
        super().__init__(keys)
        self.split = split
        columns = ['acc_X', 'acc_Y', 'acc_Z', 'gyr_X', 'gyr_Y', 'gyr_Z', 'heartRate', 'rRInterval', 'timecol']
        if use_sleeping:
            columns.append('sleeping')
        self.cols = columns

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            if self.split == 'train':
                first = sample['start_data_row']
                last = sample['end_data_row']
                sample['data'] = pd.read_csv(
                    sample[key],
                    # File line 0 is the header; lines 1..first are the data
                    # rows that precede the window.
                    skiprows=lambda line: 1 <= line <= first,
                    nrows=last - first,
                    usecols=self.cols,
                )
            else:
                sample['data'] = pd.read_csv(sample[key], usecols=self.cols)
            if self.split == 'test':
                # The sample id is taken from the directory containing the CSV.
                sample['sample_id'] = sample['data_file'].split("/")[-2]
            del sample[key]
        for leftover in ('valid', 'start_data_row', 'end_data_row'):
            sample.pop(leftover, None)
        return sample

class DeleteTimeD(MapTransform):
    """Remove the entries named by ``keys`` from a copy of the sample."""

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            # Raises KeyError when the key is absent.
            sample.pop(key)
        return sample

class ImputeMedianD(MapTransform):
    """Fill missing values of a DataFrame with its per-column median.

    Infinities are first mapped to NaN. A column that is still entirely NaN
    afterwards (its median is NaN too) is filled with the subject's mean taken
    from the module-level ``stats`` dict, which is keyed by ``str(user_id)``.
    """

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            # impute median
            frame = sample[key].replace([np.inf, -np.inf], np.nan)
            frame = frame.fillna(frame.median())
            # columns without a single observed value fall back to the subject mean
            subject = str(sample['user_id'])
            empty_columns = [name for name in frame.columns if frame[name].isna().all()]
            for name in empty_columns:
                frame[name] = stats[subject][name]['mean']
            sample[key] = frame
        return sample

class ToNumpyD(MapTransform):
    """Replace each keyed object by the result of its ``to_numpy()`` method."""

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            frame = sample[key]
            sample[key] = frame.to_numpy()
        return sample

class StandardizeD(MapTransform):
    """Standardise each keyed tensor with the subject's per-column statistics.

    Means and standard deviations are read from the module-level ``stats``
    dict (keyed by ``str(user_id)``) in its insertion order and broadcast
    against the tensor, so the tensor's last dimension must match the number
    of entries in ``stats[user]``.
    """

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            per_column = stats[str(sample['user_id'])].values()
            means = torch.tensor([entry['mean'] for entry in per_column])
            stds = torch.tensor([entry['std'] for entry in per_column])
            sample[key] = (sample[key] - means)/stds
        return sample

class TransposeD(MapTransform):
    """Replace each keyed tensor by its transpose (``Tensor.t()``)."""

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            tensor = sample[key]
            sample[key] = tensor.t()
        return sample

class FlattenD(MapTransform):
    """Flatten each keyed tensor.

    A 2-D tensor is flattened completely; any other tensor is flattened from
    dimension 1 onwards, keeping its first dimension.
    """

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            tensor = sample[key]
            is_matrix = len(tensor.shape) == 2
            sample[key] = tensor.flatten() if is_matrix else tensor.flatten(start_dim=1)
        return sample

class ExtractTimeD(MapTransform):
    """Move the ``timecol`` column of each keyed DataFrame into ``sample['time']``.

    The column is converted to ``datetime64[ns]`` and dropped from the frame.
    """

    def __call__(self, data):
        sample = copy.deepcopy(data)
        for key in self.keys:
            frame = sample[key]
            sample['time'] = frame.timecol.astype('datetime64[ns]')
            frame.drop(columns='timecol', inplace=True)
        return sample

root_dir = Path("../../datasets/SPGC_challenge_track_2_release")

# Training pipeline: one index row (a single window) -> standardised tensor.
transforms = [
        ToTensorD(['label'],dtype=torch.long),
        AppendRootDirD(['data_file'], root_dir),
        # 'train' mode reads only the rows start_data_row..end_data_row of the CSV
        LoadDataD(['data_file'], 'train', use_sleeping=False),
        # ExtractTimeD creates the 'time' entry, which DeleteTimeD removes again
        ExtractTimeD(['data']),
        DeleteTimeD(['time']),
        ImputeMedianD(['data']),
        ToNumpyD(['data']),
        ToTensorD(['data'], dtype=torch.float),
        StandardizeD(['data']),
        # transpose the 2-D tensor so that rows and columns are swapped
        TransposeD(['data']),
]

transforms = Compose(transforms)

train_data = EPreventionDataset(Path("../data/track2/raw"), 'train', subject=0, transforms=transforms)

Loading dataset: 100%|██████████| 956/956 [01:03<00:00, 15.04it/s]


In [None]:
train_data[0]

{'user_id': 0,
 'sample_id': 37,
 'label': tensor(0),
 'data': tensor([[ 0.0091,  0.0091,  0.0087,  ...,  0.0066,  0.0064,  0.0064],
         [ 0.1761,  0.1379,  0.2219,  ...,  0.0625,  0.0470,  0.0612],
         [-0.2325, -0.2580, -0.2016,  ..., -0.2353, -0.2301, -0.2318],
         ...,
         [-0.0070, -0.0609,  0.0038,  ..., -0.0098, -0.0121, -0.0106],
         [-0.4928, -0.4141, -0.3778,  ..., -0.3980, -0.4343, -0.5069],
         [ 1.1063,  0.9673,  1.5217,  ...,  1.7636,  1.7084,  1.6666]])}

### Add Transforms for Validation and Test

In [None]:
from torch.nn import ConstantPad1d, ReplicationPad1d

class PadShortSequenceD(MapTransform):
    """Pad the last dimension of each keyed tensor up to ``output_size``.

    Args:
        keys: entries of the sample to pad.
        output_size: minimum length of the last dimension after padding.
        padding: ``'zero'`` (constant 0) or ``'replication'`` (repeat the edge value).
        mode: where the padding goes: ``'head'`` (before the data), ``'tail'``
            (after it) or ``'center'`` (split on both sides; when the amount
            is odd the extra element goes at the end).

    Tensors whose last dimension is already ``>= output_size`` are returned
    unchanged.
    """

    def __init__(self, keys: KeysCollection, output_size, padding, mode):
        super().__init__(keys)
        assert padding in ['replication', 'zero'], "Select Proper Padding Mode: Allowed replication and zero"
        assert mode in ['head', 'center', 'tail'], "Select Proper Mode: Allowed head, center and tail"
        self.output_size = output_size
        self.padding = padding
        self.mode = mode

    def __call__(self, data):
        d = copy.deepcopy(data)
        for k in self.keys:
            # Measure each configured key itself rather than a hard-coded
            # 'data' entry, so the transform also works for other keys.
            pad_size = self.output_size - d[k].shape[-1]
            if pad_size <= 0:
                continue
            pad_fn = self._get_pad_fn(self._split_padding(pad_size))
            d[k] = pad_fn(d[k])
        return d

    def _split_padding(self, pad_size):
        # Return the (left, right) amounts of padding for the configured mode.
        if self.mode == 'head':
            return (pad_size, 0)
        if self.mode == 'tail':
            return (0, pad_size)
        left = pad_size // 2
        return (left, pad_size - left)

    def _get_pad_fn(self, padding):
        # Build the torch padding module matching the configured padding kind.
        return ConstantPad1d(padding, 0) if self.padding == 'zero' else ReplicationPad1d(padding)

class CreateVotingBatchD(MapTransform):
    """Stack the windows listed in ``sample['offsets']`` into one batch.

    ``offsets`` is a list of ``(start, stop)`` column ranges (stop excluded),
    either as a Python list or as its string representation as read back from
    the dataset CSV. Each keyed 2-D tensor is sliced along its second
    dimension and the slices are concatenated along a new leading dimension.
    The ``offsets`` entry is removed from the returned sample.
    """

    def __init__(self, keys: KeysCollection):
        super().__init__(keys)

    def __call__(self, data):
        d = copy.deepcopy(data)
        offsets = d['offsets']
        if isinstance(offsets, str):
            # The CSV stores the list as text. literal_eval parses plain
            # literals only, so a tampered cell cannot execute code the way
            # eval() would.
            offsets = ast.literal_eval(offsets)
        for k in self.keys:
            windows = [d[k][:, start:stop].unsqueeze(0) for (start, stop) in offsets]
            d[k] = torch.cat(windows, dim=0)
        del d['offsets']
        return d

# Evaluation pipeline: one index row (a whole recording plus its window
# offsets) -> batch of standardised windows.
eval_transforms = [
        ToTensorD(['label'],dtype=torch.long),
        AppendRootDirD(['data_file'], root_dir),
        # any split other than 'train' makes LoadDataD read the entire CSV
        LoadDataD(['data_file'], 'val', use_sleeping=False),
        ExtractTimeD(['data']),
        DeleteTimeD(['time']),
        ImputeMedianD(['data']),
        ToNumpyD(['data']),
        ToTensorD(['data'], dtype=torch.float),
        StandardizeD(['data']),
        TransposeD(['data']),
        # slice the recording at every offset and stack the windows on a new first dimension
        CreateVotingBatchD(['data']),
        # recordings shorter than one window are padded on the last dimension up to 2160
        PadShortSequenceD(['data'], output_size=2160, padding='replication', mode='center'),
        #FlattenD(['data'])
]

eval_transforms = Compose(eval_transforms)

val_data = EPreventionDataset(Path("../data/track2/raw"), 'val', subject=0, transforms=eval_transforms)

Loading dataset: 100%|██████████| 40/40 [00:10<00:00,  3.68it/s]


In [None]:
# Sanity check: report any validation item whose last dimension is still
# shorter than 2160 after PadShortSequenceD (nothing should be printed).
for sample in val_data:
    # unpacking into three names requires a 3-D tensor
    B, F, T = sample['data'].size()
    if T<2160:
        print(sample['data'].size())