In [4]:
import lmdb
import datetime
import argparse
import pandas as pd
import numpy as np
import random

import scipy.io
import pickle
import numpy as np
import os
import h5py

import torch
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from tqdm import tqdm

In [5]:
def to_tensor(array):
    """Wrap a NumPy array as a torch tensor and cast it to float32.

    For float32 input the cast is a no-op, so the tensor shares memory with `array`.
    """
    as_tensor = torch.from_numpy(array)
    return as_tensor.float()

In [6]:
def random_seed(seed):
    """Seed NumPy's global RNG, Python's `random`, and torch (CPU, current CUDA
    device and all CUDA devices) with `seed`, then print a confirmation line."""
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    print(f'set seed {seed} is done')

In [7]:
# Path of the TUEG 1.0 LMDB (500 Hz resampled, 0.3 Hz high-pass).
LMDB = "/pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb"

# Open the LMDB read-only and load the list of sample keys pickled under "__keys__".
DB = lmdb.open(LMDB, readonly=True, lock=False, readahead=True, meminit=False)
with DB.begin(write=False) as txn:
    _raw_keys = txn.get('__keys__'.encode())
if _raw_keys is None:
    # Without this check pickle.loads(None) fails with an unrelated TypeError.
    raise KeyError(f"'__keys__' entry not found in LMDB: {LMDB}")
# NOTE: pickle.loads can execute arbitrary code; only use it on trusted LMDBs.
KEYS = pickle.loads(_raw_keys)

In [45]:
# Number of sample keys stored in the LMDB.
n_keys = len(KEYS)
n_keys

468470

In [8]:
import re
from collections import Counter

def decode_key(k):
    """Return an LMDB key as str; bytes/bytearray are UTF-8 decoded, dropping undecodable bytes."""
    return k.decode("utf-8", errors="ignore") if isinstance(k, (bytes, bytearray)) else k

# Accept both key layouts (with or without a numeric prefix after "TUEG-");
# the subject id is the alphabetic run that follows.
KEY_PATTERN = re.compile(
    r"^TUEG-(?:\d+_)?(?P<sub_id>[A-Za-z]+)_s\d+_t\d+_\d+$"
)

# Split keys into parsed subject ids and keys that do not match the pattern.
invalid_keys, sub_ids = [], []
for k_str in map(decode_key, KEYS):
    m = KEY_PATTERN.match(k_str)
    if m:
        sub_ids.append(m.group("sub_id"))
    else:
        invalid_keys.append(k_str)

# ---- Key validation summary ----
print("==== KEY VALIDATION ====")
print(f"Total keys   : {len(KEYS)}")
print(f"Valid keys   : {len(sub_ids)}")
print(f"Invalid keys : {len(invalid_keys)}")

if invalid_keys:
    print("\n[Invalid key examples]")
    print(*invalid_keys[:10], sep="\n")

# ---- Per-subject statistics ----
sub_counter = Counter(sub_ids)          # subject id -> number of keys
unique_sub_ids = set(sub_counter)
print("\n==== SUBJECT STATS ====")
print(f"Total subjects : {len(unique_sub_ids)}")

ranked_subjects = sub_counter.most_common()   # descending by key count

print("\n[Top 10 subjects by #keys]")
for sub, cnt in ranked_subjects[:10]:
    print(f"{sub:12s} : {cnt}")

print("\n[Bottom 10 subjects by #keys]")
for sub, cnt in ranked_subjects[-10:]:
    print(f"{sub:12s} : {cnt}")


==== KEY VALIDATION ====
Total keys   : 468470
Valid keys   : 468470
Invalid keys : 0

==== SUBJECT STATS ====
Total subjects : 5420

[Top 10 subjects by #keys]
aaaaabhz     : 13763
aaaaahwg     : 9103
aaaaacmq     : 8775
aaaaaddm     : 7107
aaaaabfu     : 5896
aaaaahzp     : 5820
aaaaahun     : 4914
aaaaaghb     : 3729
aaaaagxr     : 3321
aaaaahyu     : 3125

[Bottom 10 subjects by #keys]
aaaaagfs     : 2
aaaaaahy     : 1
aaaaablr     : 1
aaaaabmx     : 1
aaaaabps     : 1
aaaaabyl     : 1
aaaaabyz     : 1
aaaaaclx     : 1
aaaaacqu     : 1
aaaaaenz     : 1


In [9]:
import re
import numpy as np
from collections import defaultdict
from typing import List, Tuple, Union

# Any representation an LMDB key may have here: str or raw bytes/bytearray.
KeyT = Union[str, bytes, bytearray]

# Seeds random / numpy / torch globally and prints a confirmation. The split
# function below draws from its own np.random.default_rng(seed), so its fold
# assignment does not depend on this call.
random_seed(41)

# 두 key 형태 모두 커버, sub_id는 "중간 영어 문자열"
_KEY_RE = re.compile(
    r"^TUEG-(?:\d+_)?(?P<sub_id>[A-Za-z]+)_s\d+_t\d+_\d+$"
)

def _decode_key(k: KeyT) -> str:
    if isinstance(k, (bytes, bytearray)):
        return k.decode("utf-8", errors="ignore")
    return k

def _extract_sub_id(k: KeyT) -> str:
    s = _decode_key(k)
    m = _KEY_RE.match(s)
    if m is None:
        raise ValueError(f"Key does not match expected patterns: {s}")
    return m.group("sub_id")


def train_test_split_by_fold_num(
    fold_num: int,
    lmdb_keys: List[KeyT],
    maxFold: int,
    split_by_sub: bool = True,
    seed: int = 41
) -> Tuple[List[KeyT], List[KeyT]]:
    """
    Deterministic k-fold cross-validation split of LMDB keys.

    The split units (subjects or keys) are shuffled once with
    ``np.random.default_rng(seed)`` and cut into ``maxFold`` near-equal folds
    with ``np.array_split``. Fold ``fold_num`` is the test set and the other
    folds form the training set. For the same keys (in the same order) and the
    same seed, the test folds of all ``fold_num`` values are disjoint and
    together cover every key.

    Args:
        fold_num: index of the test fold (0 <= fold_num < maxFold).
        lmdb_keys: LMDB keys (str / bytes / bytearray); returned keys keep
            their input type.
        maxFold: total number of folds (k), must be >= 2.
        split_by_sub: True -> subject-wise k-fold: all keys of a subject land
            in the same fold. Folds hold a near-equal number of *subjects*, so
            their key counts can differ a lot. False -> key-wise k-fold.
        seed: seed of the shuffling RNG.

    Returns:
        (train_key_list, test_key_list)

    Raises:
        ValueError: if maxFold < 2, fold_num is out of range, a key does not
            match the expected pattern (subject-wise only), or there are fewer
            split units than folds (some test folds would be empty).
    """
    if maxFold < 2:
        raise ValueError("maxFold must be >= 2.")
    if fold_num < 0 or fold_num >= maxFold:
        raise ValueError(f"fold_num must be in [0, {maxFold-1}]")

    keys = list(lmdb_keys)

    # One fixed-seed RNG -> the fold assignment is identical for every fold_num.
    rng = np.random.default_rng(seed)

    if split_by_sub:
        # -------- subject-wise k-fold --------
        sub_to_keys = defaultdict(list)
        invalid = []
        for k in keys:
            try:
                sub_to_keys[_extract_sub_id(k)].append(k)
            except ValueError:
                invalid.append(_decode_key(k))

        if invalid:
            ex = "\n".join(invalid[:10])
            raise ValueError(
                f"Found {len(invalid)} invalid keys. Examples:\n{ex}"
            )

        # Subjects are ordered by first appearance in `keys` before shuffling,
        # so the split depends on the input key order as well as on the seed.
        subjects = np.array(list(sub_to_keys.keys()), dtype=object)
        if len(subjects) < maxFold:
            raise ValueError(
                f"Only {len(subjects)} subjects for {maxFold} folds; "
                f"some test folds would be empty."
            )
        rng.shuffle(subjects)

        subj_folds = np.array_split(subjects, maxFold)
        test_subjects = set(subj_folds[fold_num].tolist())

        train_keys, test_keys = [], []
        for sid, ks in sub_to_keys.items():
            (test_keys if sid in test_subjects else train_keys).extend(ks)
        return train_keys, test_keys

    # -------- key-wise k-fold --------
    if len(keys) < maxFold:
        raise ValueError(
            f"Only {len(keys)} keys for {maxFold} folds; "
            f"some test folds would be empty."
        )
    idx = np.arange(len(keys))
    rng.shuffle(idx)
    folds = np.array_split(idx, maxFold)

    # Concatenating the remaining folds in order reproduces the shuffled order
    # without a membership set or a second scan over all indices.
    train_keys = [keys[i] for j, fold in enumerate(folds) if j != fold_num for i in fold]
    test_keys = [keys[i] for i in folds[fold_num]]
    return train_keys, test_keys

set seed 41 is done


In [10]:
# Key-wise (split_by_sub=False) 5-fold split with fold 0 held out as the test set.
# NOTE(review): key-wise folds ignore subjects, so keys of one subject can end up
# in both train and test.
train_keys, test_keys = train_test_split_by_fold_num(
    0, KEYS, 5, split_by_sub=False
)
print(f"train: {len(train_keys)}, test: {len(test_keys)}")

train: 374776, test: 93694


In [11]:
def lmdb_get(env, key):
    """
    Read one record from an LMDB environment and unpickle it.

    Args:
        env: an open LMDB environment.
        key: str (encoded as UTF-8) or bytes.

    Raises:
        KeyError: if the key is not present.
    """
    raw_key = key.encode("utf-8") if isinstance(key, str) else key
    with env.begin(write=False) as txn:
        payload = txn.get(raw_key)
    if payload is not None:
        return pickle.loads(payload)
    raise KeyError(f"Key not found: {raw_key}")

In [30]:
import lmdb

# Reuse the environment opened when the keys were loaded instead of calling
# lmdb.open() on the same path again: LMDB does not allow one process to have
# the same environment open twice.
db = DB

# Fetch one record to inspect its structure.
sample = lmdb_get(db, "TUEG-aaaaaaai_s001_t001_10")

In [31]:
# Shape of the raw signal array of one record.
raw_sample = sample["sample"]
raw_sample.shape

(17, 30, 500)

In [32]:
# Channel order of this record, as stored in its data_info.
channel_names = sample["data_info"]["channel_names"]
channel_names

['F3',
 'F4',
 'C3',
 'C4',
 'P3',
 'P4',
 'O1',
 'O2',
 'F7',
 'F8',
 'T3',
 'T4',
 'T5',
 'T6',
 'Fz',
 'Cz',
 'Pz']

In [60]:
import re
from collections import Counter

KEYS_TO_CHECK = test_keys

# NOTE(review): the saved output of this cell (91345 test keys, 1084 = 5420 / 5
# subjects) does not match the key-wise split printed by the split cell
# (test: 93694). It is consistent with a subject-wise split run out of order.
# Re-run the notebook top to bottom before relying on these numbers.

# Parse with the same regex / decoding as train_test_split_by_fold_num
# (_KEY_RE / _decode_key) instead of redefining an identical parser here.
invalid_keys = []
sub_ids = []

for k in KEYS_TO_CHECK:
    k_str = _decode_key(k)
    m = _KEY_RE.match(k_str)
    if m is None:
        invalid_keys.append(k_str)
    else:
        sub_ids.append(m.group("sub_id"))

# Basic summary
print("==== KEY VALIDATION ====")
print(f"Total keys   : {len(KEYS_TO_CHECK)}")
print(f"Valid keys   : {len(sub_ids)}")
print(f"Invalid keys : {len(invalid_keys)}")

if invalid_keys:
    print("\n[Invalid key examples]")
    for x in invalid_keys[:10]:
        print(x)

# Number of subjects
unique_sub_ids = set(sub_ids)
print("\n==== SUBJECT STATS ====")
print(f"Total subjects : {len(unique_sub_ids)}")

# Leakage check: subjects of this set that also have keys in train_keys.
# A subject-wise split gives 0; a key-wise split generally does not.
train_sub_ids = {
    tm.group("sub_id")
    for tm in map(_KEY_RE.match, map(_decode_key, train_keys))
    if tm is not None
}
print(f"Subjects also in train : {len(unique_sub_ids & train_sub_ids)}")

# Distribution of #keys per subject
sub_counter = Counter(sub_ids)

print("\n[Top 10 subjects by #keys]")
for sub, cnt in sub_counter.most_common(10):
    print(f"{sub:12s} : {cnt}")

print("\n[Bottom 10 subjects by #keys]")
for sub, cnt in sub_counter.most_common()[-10:]:
    print(f"{sub:12s} : {cnt}")

==== KEY VALIDATION ====
Total keys   : 91345
Valid keys   : 91345
Invalid keys : 0

==== SUBJECT STATS ====
Total subjects : 1084

[Top 10 subjects by #keys]
aaaaacmq     : 8775
aaaaabfu     : 5896
aaaaacrt     : 2828
aaaaabbn     : 1862
aaaaaath     : 1342
aaaaadns     : 1279
aaaaahzf     : 1168
aaaaabwi     : 1022
aaaaaacq     : 984
aaaaaeaw     : 535

[Bottom 10 subjects by #keys]
aaaaaaan     : 3
aaaaaait     : 3
aaaaafqc     : 3
aaaaadhl     : 2
aaaaafpg     : 2
aaaaabps     : 1
aaaaabyz     : 1
aaaaaclx     : 1
aaaaacqu     : 1
aaaaaenz     : 1


In [None]:
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb
# Data path on Ahhyun's pscratch: fine to use as-is for now, but later move the
# data to my own pscratch or to m4727 (etc.) and use it from there.

class TUEG_for_SOLID_from_lmdb(Dataset):
    """
    TUEG dataset read from an LMDB and split into k folds (work in progress).

    On construction, the key list pickled under ``__keys__`` is loaded. The keys
    are then split with the module-level ``train_test_split_by_fold_num``. The
    train keys (``train=True``) or the test keys of fold ``targetfold`` are
    turned into (input, target) segments by ``segmentation_from_idx``.

    NOTE: ``segmentation_from_idx`` and ``__getitem__`` are still TODO stubs that
    raise ``NotImplementedError``, so construction on a non-empty split
    currently raises.
    """

    def __init__(
            self,
            lmdb_dir: str,
            maxfold: int,
            targetfold: int,
            seed: int,
            train: bool,
            split_by_sub: bool,
    ):
        random_seed(seed)
        self.lmdb_dir = lmdb_dir
        self.seed = seed
        self.db = lmdb.open(lmdb_dir, readonly=True, lock=False, readahead=True, meminit=False)
        with self.db.begin(write=False) as txn:
            raw_keys = txn.get('__keys__'.encode())
        if raw_keys is None:
            raise KeyError(f"'__keys__' entry not found in LMDB: {lmdb_dir}")
        self.lmdb_keys = pickle.loads(raw_keys)

        self.train = train
        self.split_by_sub = split_by_sub

        self.maxfold = maxfold
        self.targetfold = targetfold
        self.data, self.target = self.make_data_and_target_by_fold(
            self.targetfold, self.lmdb_keys, self.maxfold, self.split_by_sub
        )

    def make_data_and_target_by_fold(self, fold, lmdb_keys, maxfold, split_by_sub):
        """
        Split `lmdb_keys` into folds and segment the keys of the selected side.

        Returns:
            (data, target): lists that concatenate the per-key outputs of
            ``segmentation_from_idx`` for the train keys (if ``self.train``) or
            the test keys of fold `fold`. Both are empty if that side has no keys.
        """
        self.record = []

        train_keys, test_keys = self.train_test_split_by_fold_num(
            fold, lmdb_keys, maxfold, split_by_sub
        )
        selected_keys = train_keys if self.train else test_keys

        # Initialised before the loop so an empty split returns ([], []) instead
        # of hitting an unbound local.
        data, target = [], []
        for key in selected_keys:
            # TODO: get proper seg_in and seg_out for this key
            seg_in, seg_out = self.segmentation_from_idx(key)
            data += seg_in
            target += seg_out
        return data, target

    def segmentation_from_idx(self, idx):
        """TODO: turn one LMDB key into (list_of_inputs, list_of_targets)."""
        raise NotImplementedError("segmentation_from_idx is not implemented yet")

    def train_test_split_by_fold_num(self, fold_num, lmdb_keys, maxfold, split_by_sub):
        """Delegate to the module-level k-fold split, using this dataset's seed."""
        # Inside a method body the bare name resolves to the module-level
        # function, not to this method.
        return train_test_split_by_fold_num(
            fold_num, lmdb_keys, maxfold, split_by_sub=split_by_sub, seed=self.seed
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # TODO: when providing the metadata, keep time as time and give the
        # spatial positions already laid out on the grid (take the grid here).
        raise NotImplementedError("__getitem__ is not implemented yet")

In [None]:
# 11 x 11 electrode layout following the TorchEEG 2D grid (upper-case 10-10
# names); '-' marks an empty cell.
# NOTE(review): TUEG records name channels in mixed case and with 10-20 names
# such as 'T3', 'T4', 'T5', 'T6', 'Fz', which do not occur here verbatim. A
# lookup will need upper-casing and presumably an old->new mapping
# (T3->T7, T4->T8, T5->P7, T6->P8) -- confirm before relying on it.
TORCHEEG_2DGRID = [
    ['-',   '-',   '-',   '-',   '-',   '-',   '-',   '-',   '-',   '-',   '-'],
    ['-',   '-',   '-',   '-',   'FP1', 'FPZ', 'FP2', '-',   '-',   '-',   '-'],
    ['-',   '-',   'AF7', '-',   'AF3', 'AFZ', 'AF4', '-',   'AF8', '-',   '-'],
    ['F9',  'F7',  'F5',  'F3',  'F1',  'FZ',  'F2',  'F4',  'F6',  'F8',  'F10'],
    ['FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10'],
    ['T9',  'T7',  'C5',  'C3',  'C1',  'CZ',  'C2',  'C4',  'C6',  'T8',  'T10'],
    ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10'],
    ['P9',  'P7',  'P5',  'P3',  'P1',  'PZ',  'P2',  'P4',  'P6',  'P8',  'P10'],
    ['-',   '-',   'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', '-',   '-'],
    ['-',   '-',   '-',   'CB1', 'O1',  'OZ',  'O2',  'CB2', '-',   '-',   '-'],
    ['-',   '-',   '-',   '-',   '-',   'IZ',  '-',   '-',   '-',   '-',   '-'],
]

In [None]:
class EEGToGrid(Dataset):
    """
    Wrap an EEG dataset and map its samples onto the 11x11 TORCHEEG_2DGRID
    layout (work in progress: grid construction is still a stub).

    The wrapped dataset must expose ``mean`` and ``std`` attributes and return
    ``(i, o, im, om)`` from ``__getitem__``.
    NOTE(review): TUEG_for_SOLID_from_lmdb does not define ``mean``/``std``, so
    wrapping it raises AttributeError in ``__init__``.
    """

    def __init__(self, base_dataset,):
        self.base_dataset = base_dataset
        self.mean = float(self.base_dataset.mean)
        self.std = float(self.base_dataset.std)

    def TorchEEG_Grid(self, channel_list, grid_templete=TORCHEEG_2DGRID, H=11, W=11):
        """
        Place channel values on a 2D grid based on the TorchEEG 2D grid.

        Intended: take channel names in 10-10 naming and return an (H, W) grid
        of channel values plus an (H, W) occupancy mask.
        TODO: not implemented yet -- currently returns an all-zero grid and mask.
        """
        grid = torch.zeros(H, W, dtype=torch.float32)
        mask = torch.zeros(H, W, dtype=torch.float32)
        return grid, mask

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # The wrapped dataset is stored as `base_dataset` (there is no `self.base`).
        i, o, im, om = self.base_dataset[idx]

        # TODO: build the target grid / mask / condition from (i, o, im, om).
        target_grid = None
        target_mask = None
        cond = None

        return target_grid, target_mask, cond, self.mean, self.std

In [None]:
# ---- Settings (to be turned into argparse options later) ----
BATCH_SIZE = 64

# TUEG_1.0 LMDB, currently on Lucy's pscratch.
LMDB_DIR = "/pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb"

In [None]:
# Set Train and Test dataset and dataloader (fold FOLD of 5).
# NOTE(review): TUEG_for_SOLID_from_lmdb's segmentation and __getitem__ are still
# TODO, and EEGToGrid reads `mean`/`std` that it does not define yet, so this
# cell cannot run to completion until those are implemented.
FOLD = 0
SPLIT_BY_SUB = True  # subject-wise folds (no subject in both train and test); False -> key-wise

train_eeg = TUEG_for_SOLID_from_lmdb(
    lmdb_dir=LMDB_DIR,
    maxfold=5,
    targetfold=FOLD,
    seed=41,
    train=True,
    split_by_sub=SPLIT_BY_SUB,
)
test_eeg = TUEG_for_SOLID_from_lmdb(
    lmdb_dir=LMDB_DIR,
    maxfold=5,
    targetfold=FOLD,
    seed=41,
    train=False,
    split_by_sub=SPLIT_BY_SUB,
)

train_set = EEGToGrid(train_eeg)
test_set = EEGToGrid(test_eeg)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

In [3]:
class EEGDelhiLike(Dataset):
    """
    LMDB에서 EEG raw를 읽어 (i, o, im, om)을 반환.
    - i:  (ctx_len*C_used, 1)   : 과거 ctx_len step의 '신호'
    - o:  (C_used, 1)           : 다음 step의 '신호'
    - im: (ctx_len*C_used, 3)   : (x, y, t_id)  (x,y는 0~1, t_id는 ctx_len개의 유니크 값)
    - om: (C_used, 3)           : (x, y, t_id=0) (t_id는 크게 중요하지 않으면 0으로 둬도 됨)

    전제:
    - sample은 (C, N) 또는 (C, T, Fs) 등인데, 마지막 축들을 펼쳐서 (C, N)으로 사용.
    - channel 좌표는 channel_pos에 dict 형태로 제공: {ch_idx: (x01, y01)} (0~1 정규화)
      -> 좌표 없는 채널은 제외하거나(기본), 혹은 (0,0)으로 처리 가능.
    """

    def __init__(
        self,
        data_dir: str,
        keys: list[str],
        channel_pos: dict,             # {ch: (x01,y01)} 0~1
        ctx_len: int = 9,
        horizon: int = 1,              # 다음 몇 step을 타깃으로 할지 (여기선 1만 쓰는 게 wrapper와 잘 맞음)
        step: int = 50,                # "1 step"이 몇 sample인지 (예: 500Hz에서 50이면 100ms)
        start_offset: int = 0,          # 윈도우 시작 오프셋
        pick_channels: list[int] | None = None,  # 특정 채널만 쓰고 싶으면
        normalize: str = "z_trial",     # "none" | "z_trial" | "z_ch_trial"
        transform=None,
        return_info: bool = False,
        amp_scale: float = 1.0          # 기존 코드의 /100 같은 스케일링
    ):
        self.data_dir = data_dir
        self.keys = keys
        self.channel_pos = channel_pos
        self.ctx_len = ctx_len
        self.horizon = horizon
        self.step = step
        self.start_offset = start_offset
        self.pick_channels = pick_channels
        self.normalize = normalize
        self.transform = transform
        self.return_info = return_info
        self.amp_scale = amp_scale

        # LMDB 핸들
        self.db = lmdb.open(
            self.data_dir, readonly=True, lock=False,
            readahead=True, meminit=False
        )

        # 사용할 채널 목록 확정 (좌표가 있는 채널만 기본 사용)
        if self.pick_channels is None:
            self.used_channels = sorted([ch for ch in self.channel_pos.keys()])
        else:
            self.used_channels = [ch for ch in self.pick_channels if ch in self.channel_pos]

        if len(self.used_channels) == 0:
            raise ValueError("used_channels가 비었습니다. channel_pos / pick_channels를 확인하세요.")

        # (C_used, 2) 좌표 텐서 캐시
        xy = [self.channel_pos[ch] for ch in self.used_channels]
        self.xy = torch.tensor(xy, dtype=torch.float32)  # (C_used,2)

    def __len__(self):
        return len(self.keys)

    def _flatten_to_CN(self, data: torch.Tensor) -> torch.Tensor:
        """
        data를 (C, N)으로 변환.
        예:
          (C, N) -> 그대로
          (C, T, Fs) -> (C, T*Fs)
          (C, ..., ...) -> (C, prod(rest))
        """
        if data.ndim < 2:
            raise ValueError(f"EEG sample ndim이 {data.ndim}입니다. 최소 (C,N) 형태가 필요합니다.")
        C = data.shape[0]
        return data.reshape(C, -1)

    def _apply_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (C_used, Nwin)
        """
        if self.normalize == "none":
            return x
        if self.normalize == "z_trial":
            mean = x.mean()
            std = x.std().clamp_min(1e-6)
            return (x - mean) / std
        if self.normalize == "z_ch_trial":
            mean = x.mean(dim=1, keepdim=True)
            std = x.std(dim=1, keepdim=True).clamp_min(1e-6)
            return (x - mean) / std
        raise ValueError(f"normalize='{self.normalize}'는 지원하지 않습니다.")

    def __getitem__(self, idx):
        key = self.keys[idx]

        with self.db.begin(write=False) as txn:
            pair = pickle.loads(txn.get(key.encode()))

        data = pair["sample"]          # (C, ...) raw
        label = pair.get("label", None)
        data_info = pair.get("data_info", {})

        data = to_tensor(data).float() / float(self.amp_scale)

        if self.transform is not None:
            data = self.transform(data)

        # (C, N)로 펼치기
        data = self._flatten_to_CN(data)

        # 채널 선택 + 좌표 없는 채널 제거
        data = data[self.used_channels, :]   # (C_used, N)

        # 한 샘플(trial) 내에서 "ctx_len + horizon" 만큼의 step을 만들기 위해 필요한 길이
        need = (self.ctx_len + self.horizon) * self.step
        if data.shape[1] < self.start_offset + need:
            # 너무 짧으면 뒤에서 잘리지 않게 start를 앞으로 당김 (최소 동작)
            start = max(0, data.shape[1] - need)
        else:
            start = self.start_offset

        # (C_used, need_samples) 구간 추출
        seg = data[:, start:start + need]     # (C_used, need)

        # 정규화 (trial 기준 / 채널별)
        seg = self._apply_normalize(seg)

        # step 단위로 reshape: (C_used, ctx_len+horizon, step)
        seg = seg.reshape(seg.shape[0], self.ctx_len + self.horizon, self.step)

        # 여기서 "각 step에서 grid로 뿌릴 스칼라"를 선택해야 함.
        # raw를 최대한 유지하려면 step 내의 대표값을 하나 뽑아야 하는데,
        # 가장 단순히 마지막 샘플(또는 평균)을 사용.
        # - 마지막샘플: seg[..., -1]
        # - 평균: seg.mean(-1)
        per_step_val = seg[..., -1]  # (C_used, ctx_len+horizon)

        # input/target 분리
        x_in = per_step_val[:, :self.ctx_len]                 # (C_used, ctx_len)
        x_out = per_step_val[:, self.ctx_len:self.ctx_len+1]  # (C_used, 1)

        # (ctx_len*C_used, 1)로 펼치기
        # time id는 -8..0 같이 ctx_len개의 유니크 값이 되도록 구성 (wrapper의 unique==9 요구 대응)
        t_ids = torch.arange(-(self.ctx_len - 1), 1, dtype=torch.float32)  # (ctx_len,)
        # 각 time에 대해 채널 C_used개씩 반복되도록 구성
        # im: (ctx_len*C_used, 3) = [x, y, t_id]
        C_used = x_in.shape[0]

        # i: (ctx_len*C_used, 1)
        i = x_in.T.contiguous().reshape(self.ctx_len * C_used, 1)  # time-major -> flatten

        # o: (C_used, 1)
        o = x_out.reshape(C_used, 1)

        # im 만들기: 시간별로 좌표를 반복
        # time 0: all channels, time 1: all channels ... 형태
        xy_rep = self.xy.unsqueeze(0).repeat(self.ctx_len, 1, 1)          # (ctx_len, C, 2)
        t_rep  = t_ids.view(self.ctx_len, 1, 1).repeat(1, C_used, 1)      # (ctx_len, C, 1)
        im = torch.cat([xy_rep, t_rep], dim=-1).reshape(self.ctx_len * C_used, 3)

        # om: (C_used, 3) = [x, y, 0]
        om = torch.cat([self.xy, torch.zeros(C_used, 1)], dim=-1)

        if self.return_info:
            return i, o, im, om, label, data_info
        return i, o, im, om

In [None]:
# LMDB of the 1109_Physio_500Hz data (on tylee's pscratch).
PHYSIO_LMDB_DIR = "/pscratch/sd/t/tylee/Dataset/1109_Physio_500Hz"
# NOTE(review): this rebinds LMDB_DIR, which the dataset cells above use for the
# TUEG LMDB; re-running those cells after this one silently loads this dataset
# instead. Prefer PHYSIO_LMDB_DIR in new code.
LMDB_DIR = PHYSIO_LMDB_DIR


In [None]:
class EEG_from_lmdb(Dataset):
    """
    Minimal LMDB-backed EEG dataset.

    Item `idx` is the pickled record stored under ``keys[idx]``, returned as
    ``(sample / 100, label)`` or, with ``return_info``, as
    ``(sample / 100, label, data_info)``. The (i, o, im, om) conversion is not
    done here (see EEGDelhiLike for that).

    Args:
        data_dir: path of the LMDB environment.
        transform: optional callable applied to the sample tensor.
        return_info: also return the record's ``data_info`` dict.
        keys: keys to serve, e.g. the train/test keys from
            ``train_test_split_by_fold_num``. If None, the list pickled under
            ``__keys__`` is used.
    """

    def __init__(self, data_dir, transform, return_info, keys=None):
        self.data_dir = data_dir
        self.transform = transform
        self.return_info = return_info
        # Opened lazily and then reused (see _get_db), instead of opening a new
        # environment on every item.
        self.db = None
        if keys is None:
            keys = self._read_key_list()
        self.keys = list(keys)

    def _get_db(self):
        """Return the LMDB environment, opening it read-only on first use."""
        if self.db is None:
            self.db = lmdb.open(self.data_dir, readonly=True, lock=False, readahead=True, meminit=False)
        return self.db

    def _read_key_list(self):
        """Load the key list pickled under ``__keys__``."""
        with self._get_db().begin(write=False) as txn:
            raw = txn.get('__keys__'.encode())
        if raw is None:
            raise KeyError(f"'__keys__' entry not found in LMDB: {self.data_dir}")
        return pickle.loads(raw)

    def lmdb_to_data(self, idx):
        """Load, decode and scale the record stored under ``self.keys[idx]``."""
        key = self.keys[idx]
        key_bytes = key.encode() if isinstance(key, str) else key
        with self._get_db().begin(write=False) as txn:
            raw = txn.get(key_bytes)
        if raw is None:
            raise KeyError(f"Key not found in LMDB: {key!r}")
        pair = pickle.loads(raw)
        data = pair['sample']
        label = pair['label']
        data_info = pair.get('data_info', {})

        data = to_tensor(data)
        if self.transform is not None:
            data = self.transform(data)
        # Fixed amplitude scaling by 1/100.
        if self.return_info:
            return data / 100, label, data_info
        return data / 100, label

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        return self.lmdb_to_data(idx)

In [None]:
# Implementing torch.dataset: helpers for the Delhi dataset below.

def xform_day(day):
    """
    Map a 1-based day index to a 'YYYY-MM-DD' date string, with day 1 = 2020-11-01.

    Dates continue across month and year boundaries, e.g. day 31 -> '2020-12-01',
    day 62 -> '2021-01-01', day 93 -> '2021-02-01'.

    Raises:
        ValueError: if day < 1 (there is no valid date before 2020-11-01 in
            this indexing).
    """
    day = int(day)
    if day < 1:
        raise ValueError(f"day must be >= 1, got {day}")
    return (datetime.date(2020, 10, 31) + datetime.timedelta(days=day)).isoformat()


def get_suffixes(mode):
    """Return the CSV suffixes selected by `mode`: 'train' if it contains 'A' or 'C',
    then 'test' if it contains 'B' or 'D'."""
    wants_train = any(ch in mode for ch in "CA")
    wants_test = any(ch in mode for ch in "DB")
    return [name for name, wanted in (("train", wants_train), ("test", wants_test)) if wanted]

def rename_cols(data):
    """Rename the raw CSV columns of `data` to descriptive names, in place (returns None)."""
    new_names = {
        'dateTime': 'time',
        'lat': 'latitude',
        'long': 'longitude',
        'pm2_5': 'PM25_Concentration',
        'pm10': 'PM10_Concentration',
    }
    data.rename(columns=new_names, inplace=True)



def torch1dgrid(num, bot=0, top=1):
    """Return a 1-D float tensor of `num` evenly spaced points from `bot` to `top` (inclusive)."""
    # Stacking into (num, 1) and squeezing the last axis, as done before, yields
    # exactly this (num,) tensor.
    return torch.linspace(bot, top, steps=num)

import torch
from torch.utils.data import Dataset
from einops import rearrange        
class Delhi(Dataset):
    """
    Delhi air-quality (PM2.5) dataset returning (i, o, im, om) windows for one fold.

    Rows come from per-day CSVs named ``{data_dir}/{date}_f{fold}_{suffix}.csv``
    (suffix 'train' / 'test', date from ``xform_day``). The columns used are
    ``pm2_5``, ``dateTime``, ``lat``, ``long``, and each row is stored internally
    as ``[pm2_5, dateTime, lat, long]``.

    Windows are built per target day:
    - days ``nTrainStartDay .. nTestStartDay-1`` go to the training set;
    - days ``nTestStartDay .. nTotalDays`` go to the test set.
    Each target day's date list also covers its ``train_days`` preceding days.
    ``train`` selects which set is exposed.

    NOTE(review): ``canada`` is never used, and ``temporal_scaling`` is only
    referenced by commented-out code.
    """
    def __init__(
        self, mode_t, mode_p, canada, train_days,
        maxFolds = 5, target_fold = 0, temporal_scaling=1, spatiotemporal=1, data_dir='/pscratch/sd/d/dpark1/AirDelhi/delhi/processed',
        seed=10, nTrainStartDay = 15, nTestStartDay = 75, nTotalDays = 91, train=True):

        self.mode_t = mode_t
        self.mode_p = mode_p
        self.train_days = train_days
        self.train = train
        self.maxFolds = maxFolds
        self.target_fold = target_fold
        self.temporal_scaling = temporal_scaling
        self.spatiotemporal = spatiotemporal
        self.data_dir = data_dir
        self.nTestStartDay = nTestStartDay
        self.nTrainStartDay = nTrainStartDay
        self.nTotalDays = nTotalDays

        # Seeds NumPy's global RNG.
        np.random.seed(seed)

        # CSV suffixes ('train' / 'test') read for the previous days, chosen by mode_t.
        self.train_suffix = get_suffixes(mode_t)

        if spatiotemporal < 0 and mode_t == 'AB' and mode_p == 'CD':
            # Forecasting, single fold is enough.
            # Only the local maxFolds (and thus self.folds) changes; self.maxFolds
            # keeps the value passed in.
            maxFolds = 1

        self.folds = [i for i in range(maxFolds)]

        # Build every window for target_fold now. This also sets the
        # normalisation statistics (see proc_custom).
        self.data, self.target = self.proc_custom(target_fold)



    def get_normalize_params(self, target):
        """Set self.mean / self.std over column 0 (pm2_5) of every array in `target`."""
        all_signal = []
        for a in target:
            all_signal += list(a[..., 0])
        self.mean, self.std = np.array(all_signal).mean(), np.array(all_signal).std()

    def get_spatial_norm_parameters(self, arr_of_days):
        """Set min-max bounds for lat (column 2) and long (column 3) over all arrays in `arr_of_days`."""
        latmin = 10e10
        latmax = -10e10
        longmin = 10e10
        longmax = -10e10
        for arr in arr_of_days:
            minned = arr.min(0)
            # print(minned[0])
            if minned[2] < latmin:
                latmin = minned[2]
            if minned[3] < longmin:
                longmin = minned[3]

            maxed = arr.max(0)
            # print(maxed)
            if maxed[2] > latmax:
                latmax = maxed[2]
            if maxed[3] > longmax:
                longmax = maxed[3]
        self.latmin, self.latmax, self.longmin, self.longmax =latmin, latmax, longmin, longmax

    def make_data_by_time(self, arr_of_days, t_in = 9, reverse=False, day = 0):
        """
        Build sliding windows over the distinct time stamps of one block of rows.

        Rows of `arr_of_days` (columns [pm2_5, dateTime, lat, long]) are grouped
        by their time value (column 1, ascending order of the unique values).
        For each position i, groups i .. i+t_in-1 form the input and group
        i+t_in the target. Times are shifted so the last input row has time 0.

        With `reverse=True`, an extra backward window is added for each
        position: input groups i+1 .. i+t_in, target group i, with times
        relative to the first input row.

        Every window appends `day` to self.day_record.

        Returns:
            (in_, out_): lists of arrays.
        """
        seg_by_time = []
        uniq_times = np.unique(arr_of_days[..., 1])

        for t in uniq_times:
            idx_ = arr_of_days[..., 1] == t
            seg_by_time.append(arr_of_days[idx_])

        in_ = []
        out_ = []
        for i in range(len(seg_by_time) - t_in):
            temp_in = []
            for t_ in range(t_in):
                temp_in.append(seg_by_time[i + t_])

            # normalize time to relative scale by the last one of the encoder
            in_cand = np.copy(np.concatenate(temp_in, axis=0))
            out_cand = np.copy(seg_by_time[i + t_in])

            last_enc_t = in_cand[..., 1][-1]
            in_cand[..., 1] -= last_enc_t
            out_cand[..., 1] -= last_enc_t
            in_.append(in_cand)
            out_.append(out_cand)
            self.day_record.append(day)

            # reverse it: predict the earliest group from the following t_in groups
            if reverse:
                out_cand = np.copy(seg_by_time[i])
                temp_in = temp_in[1:]
                temp_in.append(seg_by_time[i + t_in])
                in_cand = np.copy(np.concatenate(temp_in, axis=0))
                last_enc_t = in_cand[..., 1][0]
                in_cand[..., 1] -= last_enc_t
                out_cand[..., 1] -= last_enc_t

                in_.append(in_cand)
                out_.append(out_cand)
                self.day_record.append(day)


        return in_, out_

    def proc_custom(self, fold):
        """
        Build the train and test windows for `fold` and return the selected set.

        Side effects:
        - resets self.day_record;
        - sets self.mean / self.std and the lat/long min/max from the *training
          targets* (used by __getitem__ for both sets);
        - prints the number of windows.
        """

        self.day_record = []

        train_data = {'input':[], 'target':[]}
        test_data = {'input':[], 'target':[]}

        for day in range(self.nTrainStartDay, self.nTestStartDay):
            # Dates day-train_days .. day (the last entry is the target day).
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            # Rows [pm2_5, dateTime, lat, long] of the files selected by mode_t.
            train_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1)
            # Rows [pm2_5, dateTime, lat, long] of the target day's test file
            # (plus its train file if 'C' in mode_p).
            train_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1)

            seg_in, seg_out = self.make_data_by_time(train_in, day = day)

            train_data['input'] += seg_in
            train_data['target'] += seg_out


        # NOTE(review): `train_out` is the value left over from the last loop
        # iteration, so only the last training day's test-file rows are added
        # here (recorded with day=0). If the loop above does not run, this
        # raises NameError.
        seg_in, seg_out = self.make_data_by_time(train_out)
        train_data['input'] += seg_in
        train_data['target'] += seg_out


        for day in range(self.nTestStartDay, self.nTotalDays+1):
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            test_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1)
            test_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1)

            # NOTE: `day` is not passed here, so day_record gets 0 for these windows.
            seg_in, seg_out = self.make_data_by_time(test_in, reverse = False)

            test_data['input'] += seg_in
            test_data['target'] += seg_out

        # NOTE(review): as above, `test_out` is the last test day's value only.
        seg_in, seg_out = self.make_data_by_time(test_out, reverse = False)
        test_data['input'] += seg_in
        test_data['target'] += seg_out

        # Normalisation statistics come from the training targets only.
        self.get_normalize_params(train_data['target'])
        self.get_spatial_norm_parameters(train_data['target'])

        if self.train:
            data = train_data['input']
            target = train_data['target']
            print(len(data), len(target))

        else:
            data = test_data['input']
            target = test_data['target']
            print(len(data), len(target))

        return data, target



    def process_np(self, fold, date):
        """Load the CSV rows for `date` (list of date strings) of `fold`; see return_data_time."""
        # NOTE: tmStart is unused (leftover timing code).
        tmStart = datetime.datetime.now()
        train_input,train_output,test_input,test_output = self.return_data_time(fold=fold, data=date, with_scaling=True)
        return train_input,train_output,test_input,test_output

    def return_data_time(self, fold, data, with_scaling):
        """
        Read the CSV files for a list of dates `data` (the last entry is the target day).

        - 'A' or 'B' in mode_t: every earlier date is read, once per suffix in
          self.train_suffix.
        - 'C' in mode_t: the target day's train file is added to the inputs.
        - The target day's test file is always read as the prediction set; with
          'C' in mode_p, the target day's train file is prepended to it.

        NOTE: pd.concat drops None, so starting from train_input=None works. If
        neither branch runs, train_input stays None and return_data_0 fails.
        `input` shadows the builtin of the same name.
        """
        train_input = None
        if 'A' in self.mode_t or 'B' in self.mode_t:
            for idx,dt in enumerate(data[:-1]):
                for suffix in self.train_suffix:
                    input = pd.read_csv(self.data_dir+'/'+dt+'_f'+str(fold)+'_'+suffix+'.csv')
                    # if self.temporal_scaling:
                    #     input.dateTime += idx * 24 * 60
                    train_input = pd.concat((train_input, input))

        if 'C' in self.mode_t:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            # if self.temporal_scaling:
            #     input.dateTime += (len(data)-1) * 24 * 60
            train_input = pd.concat((train_input, input))

        test_input = pd.read_csv(self.data_dir+'/'+data[-1]+'_f'+str(fold)+'_test.csv')

        if 'C' in self.mode_p:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            test_input = pd.concat((input, test_input))

        # if self.temporal_scaling:
        #     test_input.dateTime += (len(data)-1) * 24 * 60

        return self.return_data_0(train_input, test_input, with_scaling)


    def return_data_0(self, train_input, test_input, with_scaling):
        """
        Split each frame into the pm2_5 values (numpy array) and the
        [dateTime, lat, long] columns (DataFrame).

        `with_scaling` is currently ignored (the scaling code is commented out).
        """
        train_output = np.array(train_input['pm2_5'])
        train_input = train_input[['dateTime','lat','long']]
        test_output = np.array(test_input['pm2_5'])
        test_input = test_input[['dateTime','lat','long']]

        # if with_scaling:
        #     scaler = MinMaxScaler().fit(train_input)
        #     if self.temporal_scaling:
        #         data = scaler.transform(pd.concat((train_input, test_input)))
        #         test_input = data[len(train_input):]
        #         train_input = data[:len(train_input)]
        #     else:
        #         train_input = scaler.transform(train_input)
        #         test_input = scaler.transform(test_input)
        return train_input,train_output,test_input,test_output

    def set_target_fold(self, fold=0):
        """
        Store `fold` in self.fold and print it.

        NOTE(review): the data were already built for target_fold in __init__,
        and self.fold is not read anywhere in this class, so this does not
        change the served data.
        """
        self.fold = fold
        print('target fold set to {}'.format(self.fold))

    def normalize_z(self, arr):
        """Z-score `arr` with the training pm2_5 mean / std."""
        return (arr - self.mean) / self.std

    def normalize(self, data, min_, max_):
        """Min-max scale `data` with the given bounds."""
        return (data - min_) / (max_ - min_)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Columns on entry: [pm2_5, dateTime, lat, long].
        input_, target_ = self.data[idx], self.target[idx]
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)

        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])

        # Scale time by 1440 (presumably minutes -> days; confirm the unit of dateTime).
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440

        # NOTE: timegap is computed but not used.
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)

        # Reorder columns to [pm2_5, lat, long, time].
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)

        # Min-max normalise lat / long with the training-target ranges.
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)

        # i / o: pm2_5 values; im / om: [lat, long, time] metadata.
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]

        return i, o, im, om