In [1]:
import lmdb
import datetime
import argparse
import pandas as pd
import numpy as np
import random
import re
from collections import defaultdict
from typing import List, Tuple, Union

import scipy.io
import pickle
import numpy as np
import os
import h5py

import torch
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from tqdm import tqdm

In [2]:
def to_tensor(array):
    """Wrap a NumPy array as a float32 torch tensor.

    Args:
        array: ``numpy.ndarray`` accepted by ``torch.from_numpy``.

    Returns:
        ``torch.Tensor`` with the same values, cast to ``torch.float32``.
    """
    tensor = torch.from_numpy(array)
    return tensor.to(dtype=torch.float32)

In [3]:
def random_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds, in order: NumPy's global RNG, Python's ``random`` module, torch's
    CPU generator, and torch's CUDA generators (current device, then all
    devices). Prints a confirmation line when done.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    print(f'set seed {seed} is done')

In [4]:
# Type of a raw LMDB key as handled below: text, or a bytes-like object.
KeyT = Union[str, bytes, bytearray]

# TUEG keys appear in two layouts; this single pattern accepts both:
#   TUEG-<letters>_s<digits>_t<digits>_<digits>
#   TUEG-<digits>_<letters>_s<digits>_t<digits>_<digits>
# The <letters> run is captured as the named group ``sub_id`` (the subject id).
_KEY_RE = re.compile(
    r"^TUEG-(?:\d+_)?(?P<sub_id>[A-Za-z]+)_s\d+_t\d+_\d+$"
)

def _decode_key(k: KeyT) -> str:
    """Return ``k`` as text.

    ``bytes``/``bytearray`` keys are decoded as UTF-8, silently dropping any
    undecodable bytes; any other value is returned unchanged.
    """
    is_binary = isinstance(k, (bytes, bytearray))
    return k.decode("utf-8", errors="ignore") if is_binary else k

def _extract_sub_id(k: KeyT) -> str:
    """Pull the subject id out of a TUEG LMDB key.

    Args:
        k: key as text or bytes; it is decoded with ``_decode_key`` first.

    Returns:
        The ``sub_id`` group captured by ``_KEY_RE``.

    Raises:
        ValueError: if the decoded key does not match ``_KEY_RE``.
    """
    s = _decode_key(k)
    match = _KEY_RE.match(s)
    if match is not None:
        return match.group("sub_id")
    raise ValueError(f"Key does not match expected patterns: {s}")


def train_test_split_by_fold_num(
    fold_num: int,
    lmdb_keys: List[KeyT],
    maxFold: int,
    split_by_sub: bool = True,
    seed: int = 41
) -> Tuple[List[KeyT], List[KeyT]]:
    """
    Deterministic k-fold split of LMDB keys; fold ``fold_num`` is the test set.

    The units being split (subjects or individual keys) are shuffled with a
    generator seeded by ``seed`` and cut into ``maxFold`` chunks with
    ``np.array_split``, so every ``fold_num`` of the same (keys, maxFold, seed)
    sees the same partition. If there are fewer units than ``maxFold`` some
    chunks are empty, and the returned test list can be empty.

    Args:
        fold_num: index of the fold used as the test set (0 <= fold_num < maxFold)
        lmdb_keys: LMDB keys to split
        maxFold: total number of folds (k)
        split_by_sub: True → all keys of a subject land on the same side,
            False → keys are split individually
        seed: seed of the shuffling generator

    Returns:
        train_key_list, test_key_list

    Raises:
        ValueError: if ``maxFold`` < 2, if ``fold_num`` is out of range, or
            (subject-wise only) if any key has no extractable subject id.
    """
    if maxFold < 2:
        raise ValueError("maxFold must be >= 2.")
    if fold_num >= maxFold or fold_num < 0:
        raise ValueError(f"fold_num must be in [0, {maxFold-1}]")

    all_keys = list(lmdb_keys)
    generator = np.random.default_rng(seed)

    if not split_by_sub:
        # -------- key-wise k-fold --------
        order = np.arange(len(all_keys))
        generator.shuffle(order)
        held_out = set(np.array_split(order, maxFold)[fold_num].tolist())

        # One pass over the shuffled positions keeps both lists in shuffled order.
        train_keys, test_keys = [], []
        for pos in order:
            bucket = test_keys if pos in held_out else train_keys
            bucket.append(all_keys[pos])
        return train_keys, test_keys

    # -------- subject-wise k-fold --------
    grouped = defaultdict(list)
    invalid = []
    for key in all_keys:
        try:
            subject = _extract_sub_id(key)
        except ValueError:
            # Collect every offending key so the error can show several examples.
            invalid.append(_decode_key(key))
        else:
            grouped[subject].append(key)

    if invalid:
        ex = "\n".join(invalid[:10])
        raise ValueError(
            f"Found {len(invalid)} invalid keys. Examples:\n{ex}"
        )

    subject_ids = np.array(list(grouped.keys()), dtype=object)
    generator.shuffle(subject_ids)
    held_out = set(np.array_split(subject_ids, maxFold)[fold_num].tolist())

    # Subjects are visited in first-seen order, as are the keys within a subject.
    train_keys = [k for sid, ks in grouped.items() if sid not in held_out for k in ks]
    test_keys = [k for sid, ks in grouped.items() if sid in held_out for k in ks]
    return train_keys, test_keys

In [None]:
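# Data location: /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb
# This path is in Ahhyun's pscratch. It is fine to read from it for now, but the data
# should later be moved to my own pscratch (or m4727, etc.) and read from there.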

class TUEG_for_SOLID_from_lmdb(Dataset):
    """TUEG EEG samples read from an LMDB store and split by k-fold.

    On construction the key index stored under ``'__keys__'`` is read, the keys
    are split with ``train_test_split_by_fold_num`` and the selected side
    (train or test) is segmented eagerly into ``self.data`` / ``self.target``.

    NOTE(review): ``segmentation_from_idx`` and ``__getitem__`` are still stubs.
    Until the segmentation is implemented it returns ``(None, None)``, so
    ``make_data_and_target_by_fold`` raises ``TypeError`` on the first key of a
    non-empty split.

    NOTE: values are un-pickled from the store; only open trusted LMDB files.

    Args:
        lmdb_dir: path of the LMDB environment.
        maxfold: total number of folds (k).
        targetfold: index of the fold held out as the test set.
        seed: seed for ``random_seed`` and for the fold shuffling.
        train: True → keep the training side of the split, False → the test side.
        split_by_sub: True → subject-wise split, False → key-wise split.
    """

    def __init__(
            self,
            lmdb_dir: str,
            maxfold: int,
            targetfold: int,
            seed: int,
            train: bool,
            split_by_sub: bool,
    ):
        random_seed(seed)
        self.seed = seed
        self.lmdb_dir = lmdb_dir
        self.db = lmdb.open(lmdb_dir, readonly=True, lock=False, readahead=True, meminit=False)
        # Go through lmdb_get so a missing '__keys__' entry raises a clear
        # KeyError instead of failing inside pickle.loads(None).
        self.lmdb_keys = self.lmdb_get(self.db, '__keys__')

        self.train = train
        self.split_by_sub = split_by_sub

        self.maxfold = maxfold
        self.targetfold = targetfold
        self.data, self.target = self.make_data_and_target_by_fold(self.targetfold, self.lmdb_keys,
                                                                   self.maxfold, self.split_by_sub, self.seed)

    def make_data_and_target_by_fold(self, fold, lmdb_keys, maxfold, split_by_sub, seed):
        """Segment every key on the selected side of the split.

        Args:
            fold: index of the test fold.
            lmdb_keys: all keys of the store.
            maxfold: total number of folds.
            split_by_sub: subject-wise (True) or key-wise (False) split.
            seed: seed of the fold shuffling.

        Returns:
            ``(data, target)`` lists; both are empty when the selected side of
            the split contains no keys.
        """
        self.record = []

        train_keys, test_keys = train_test_split_by_fold_num(fold, lmdb_keys, maxfold, split_by_sub, seed)
        selected_keys = train_keys if self.train else test_keys

        # Bound before the loop so an empty split yields empty lists rather
        # than an UnboundLocalError at the return statement.
        data, target = [], []
        for key in selected_keys:
            # TODO : get proper seg_in and seg_out by input idx
            seg_in, seg_out = self.segmentation_from_idx(key, self.db)
            data += seg_in
            target += seg_out

        return data, target

    def lmdb_get(self, env, key):
        """Fetch and un-pickle the value stored under ``key``.

        Args:
            env: LMDB environment (anything with a compatible ``begin``).
            key: ``str`` keys are UTF-8 encoded; other keys are passed as-is.

        Raises:
            KeyError: if the store has no entry for ``key``.
        """
        if isinstance(key, str):
            key = key.encode("utf-8")
        with env.begin(write=False) as txn:
            v = txn.get(key)
        if v is None:
            raise KeyError(f"Key not found: {key}")
        return pickle.loads(v)

    def segmentation_from_idx(self, key, lmdb_db):
        """Load one sample and cut it into input/target segments (stub).

        Currently only loads the sample and flattens its last two axes; the
        segmentation itself is not implemented and ``(None, None)`` is returned.
        """
        sample_for_key = self.lmdb_get(lmdb_db, key)

        channel_name = sample_for_key['data_info']['channel_name'] # (C,) list
        eeg_data = sample_for_key['sample'] # (C, T, Fs)
        eeg_data_ = rearrange(eeg_data, 'c t f -> c (t f)')

        seg_in = None
        seg_out = None
        return seg_in, seg_out

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(i, o, im, om)`` for sample ``idx`` (stub: all four are None)."""
        input_, target_ = self.data[idx], self.target[idx]
        # TODO : when building the meta tensors, keep time as time but provide the
        # spatial part already laid out on the grid / accept the grid here.
        i = None
        o = None
        im = None
        om = None
        return i, o, im, om

In [None]:
# 11 x 11 layout of EEG channel names used as the 2D grid template
# (default ``grid_templete`` of ``EEGToGrid.TorchEEG_Grid``).
# Each entry is a channel name; '-' presumably marks a cell with no electrode — TODO confirm.
TORCHEEG_2DGRID = [
    ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'],
    ['-', '-', '-', '-', 'FP1', 'FPZ', 'FP2', '-', '-', '-', '-'],
    ['-', '-', 'AF7', '-', 'AF3', 'AFZ', 'AF4', '-', 'AF8', '-', '-'],
    ['F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10'],
    ['FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10'], 
    ['T9', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'T10'],
    ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10'], 
    ['P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10'],
    ['-', '-', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', '-', '-'],
    ['-', '-', '-', 'CB1', 'O1', 'OZ', 'O2', 'CB2', '-', '-', '-'],
    ['-', '-', '-', '-', '-', 'IZ', '-', '-', '-', '-', '-']
    ]

In [None]:
class EEGToGrid(Dataset):
    """Wrap a base EEG dataset and expose its samples on a 2D channel grid.

    The wrapper copies the normalization statistics of the base dataset and
    returns them with every item.

    NOTE(review): the grid construction is still a stub. ``TorchEEG_Grid``
    returns all-zero tensors, and ``__getitem__`` returns ``None`` for the
    grid, mask and condition.

    Args:
        base_dataset: map-style dataset whose items unpack into
            ``(i, o, im, om)`` and which exposes ``mean`` and ``std``
            attributes convertible to ``float``.
    """

    def __init__(self, base_dataset,):
        self.base_dataset = base_dataset
        self.mean = float(self.base_dataset.mean)
        self.std = float(self.base_dataset.std)

    def TorchEEG_Grid(self, channel_list, grid_templete=None, H=11, W=11):
        """
        2D Grid based on TorchEEG 2D Grid
        input 10-10 coord channel name index
        output is grid of channel input

        Args:
            channel_list: channel names of the sample (unused by the stub).
            grid_templete: 2D list of channel names; ``None`` (default) selects
                ``TORCHEEG_2DGRID``, looked up at call time.
            H: grid height.
            W: grid width.

        Returns:
            ``(grid, mask)``: two float32 tensors of shape ``(H, W)``,
            currently all zeros.
        """
        if grid_templete is None:
            # Resolved here rather than in the signature, so the module-level list
            # is not bound as a shared default when the class is defined.
            grid_templete = TORCHEEG_2DGRID
        grid = torch.zeros(H, W, dtype=torch.float32)
        mask = torch.zeros(H, W, dtype=torch.float32)
        return grid, mask

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(target_grid, target_mask, cond, mean, std)`` for ``idx``."""
        # The attribute is ``base_dataset``; ``self.base`` does not exist.
        i, o, im, om = self.base_dataset[idx]

        target_grid = None
        target_mask = None
        cond = None

        return target_grid, target_mask, cond, self.mean, self.std

In [None]:
# Notebook-level settings; these are meant to become argparse options later.

# Location of the TUEG 1.0 LMDB (currently inside lucy's pscratch):
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb
LMDB_DIR = "/pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb"
BATCH_SIZE = 64

In [None]:
# Build the train/test datasets and their data loaders.

# Both datasets must use the same fold configuration so that they are the two
# sides of one split.
# NOTE(review): targetfold=0 and split_by_sub=True are assumed values for the
# two required arguments that were missing — confirm before running experiments.
train_eeg = TUEG_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxfold=5,
                         targetfold=0,
                         seed=41,
                         train=True,
                         split_by_sub=True,)
test_eeg = TUEG_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxfold=5,
                         targetfold=0,
                         seed=41,
                         train=False,
                         split_by_sub=True,
                         )

# NOTE(review): EEGToGrid reads ``mean`` and ``std`` from the wrapped dataset;
# make sure the base dataset defines them.
train_set = EEGToGrid(train_eeg)
test_set = EEGToGrid(test_eeg)

# The DataLoader keyword is ``num_workers`` (plural).
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

In [None]:
# NOTE(review): this rebinds LMDB_DIR, which an earlier cell set to the TUEG store.
# Any earlier cell re-run after this one will silently read from this path instead.
LMDB_DIR = "/pscratch/sd/t/tylee/Dataset/1109_Physio_500Hz"


In [None]:
class EEG_from_lmdb(Dataset):
    """Map-style dataset that reads pickled EEG samples from an LMDB store.

    Each stored value is expected to be a pickled dict with a ``'sample'``
    array, a ``'label'`` and optionally a ``'data_info'`` dict.

    The LMDB environment and the key index are opened lazily on first access
    and then reused; constructing the dataset performs no I/O.

    NOTE(review): the key index is assumed to be a pickled list stored under
    ``'__keys__'``, like the TUEG store read elsewhere in this notebook —
    confirm for this dataset. ``self.keys`` may also be assigned by the caller
    before first use, in which case the store's index is not read.

    NOTE: values are un-pickled from the store; only open trusted LMDB files.

    Args:
        data_dir: path of the LMDB environment.
        transform: optional callable applied to the float32 sample tensor.
        return_info: if True, items also carry the sample's ``data_info``.
    """

    def __init__(self, data_dir, transform, return_info):
        self.data_dir = data_dir
        self.transform = transform
        self.return_info = return_info
        self.db = None    # LMDB environment, opened by _ensure_open
        self.keys = None  # list of sample keys, loaded by _ensure_open

    def _ensure_open(self):
        """Open the LMDB environment and load the key index, once each."""
        if self.db is None:
            self.db = lmdb.open(self.data_dir, readonly=True, lock=False, readahead=True, meminit=False)
        if self.keys is None:
            with self.db.begin(write=False) as txn:
                raw_keys = txn.get('__keys__'.encode())
            if raw_keys is None:
                raise KeyError(f"'__keys__' entry not found in LMDB: {self.data_dir}")
            self.keys = pickle.loads(raw_keys)

    def lmdb_to_data(self, idx):
        """Load, convert and scale the sample at position ``idx``.

        Returns:
            ``(data / 100, label)``, or ``(data / 100, label, data_info)`` when
            ``return_info`` is set. ``data`` is a float32 tensor, after
            ``transform`` if one was given.

        Raises:
            KeyError: if the store has no entry for the key at ``idx``.
        """
        self._ensure_open()
        key = self.keys[idx]
        if isinstance(key, str):
            key = key.encode()
        with self.db.begin(write=False) as txn:
            raw = txn.get(key)
        if raw is None:
            raise KeyError(f"Key not found: {key}")
        pair = pickle.loads(raw)
        data = pair['sample']
        label = pair['label']
        data_info = pair.get('data_info', {})

        data = to_tensor(data)
        if self.transform is not None:
            data = self.transform(data)
        # NOTE(review): fixed 1/100 scaling, presumably an amplitude normalization — confirm.
        if self.return_info:
            return data/100, label, data_info
        else:
            return data/100, label

    def __len__(self):
        self._ensure_open()
        return len(self.keys)

    def __getitem__(self, idx):
        # Delegates to lmdb_to_data; this class has no self.data/self.target
        # nor any of the normalizers the previous body relied on.
        return self.lmdb_to_data(idx)

In [None]:
# Implementing torch.dataset

def xform_day(day):
    """Convert a 1-based day offset into a ``YYYY-MM-DD`` date string.

    Day 1 is 2020-11-01, day 30 is 2020-11-30, day 31 is 2020-12-01 and day 92
    is 2021-01-31. Offsets outside 1..92 continue along the calendar (day 0 is
    2020-10-31, day 93 is 2021-02-01) instead of producing impossible dates
    such as ``2020-11-00`` or ``2021-01-32``.

    Args:
        day: integer day offset, counted from 2020-11-01 as day 1.

    Returns:
        The ISO formatted date string.
    """
    # Day 0 of the numbering, i.e. the day before 2020-11-01.
    origin = datetime.date(2020, 10, 31)
    # int() lets NumPy integers through, which timedelta does not accept directly.
    return (origin + datetime.timedelta(days=int(day))).isoformat()


def get_suffixes(mode):
    """Translate a mode string into the CSV file suffixes it selects.

    ``'train'`` is selected when ``mode`` contains ``'C'`` or ``'A'``;
    ``'test'`` when it contains ``'D'`` or ``'B'``.

    Returns:
        List with ``'train'`` and/or ``'test'`` (in that order), possibly empty.
    """
    flag_groups = (('train', ('C', 'A')), ('test', ('D', 'B')))
    return [
        suffix
        for suffix, flags in flag_groups
        if any(flag in mode for flag in flags)
    ]

def rename_cols(data):
    """Rename the raw CSV columns of ``data`` in place.

    ``dateTime``/``lat``/``long``/``pm2_5``/``pm10`` become
    ``time``/``latitude``/``longitude``/``PM25_Concentration``/``PM10_Concentration``.
    Other columns are left alone. The frame is modified in place and nothing
    is returned.
    """
    column_aliases = {
        'dateTime': 'time',
        'lat': 'latitude',
        'long': 'longitude',
        'pm2_5': 'PM25_Concentration',
        'pm10': 'PM10_Concentration',
    }
    data.rename(columns=column_aliases, inplace=True)



def torch1dgrid(num, bot=0, top=1):
    """Return ``num`` evenly spaced points from ``bot`` to ``top`` inclusive.

    Args:
        num: number of grid points.
        bot: first value of the grid.
        top: last value of the grid.

    Returns:
        1-D tensor of shape ``(num,)``.
    """
    # Stacking a single 1-D tensor on dim=1 and squeezing that dim again gives
    # the linspace result back, so it is returned directly.
    return torch.linspace(bot, top, steps=num)

import torch
from torch.utils.data import Dataset
from einops import rearrange        
class Delhi(Dataset):
    """Sliding-window dataset built from per-day, per-fold CSV files.

    Files are read from ``data_dir`` as ``<date>_f<fold>_<train|test>.csv`` and
    must provide the columns ``pm2_5``, ``dateTime``, ``lat`` and ``long``.
    Each day's rows are turned into arrays with columns
    ``[pm2_5, dateTime, lat, long]`` and cut into (input, target) samples by
    ``make_data_by_time``. All samples are built in ``__init__``; ``train``
    selects whether the training or the test samples are kept.

    Items are ``(i, o, im, om)``: the z-normalized pm2_5 column of the input
    and target, and their ``[lat, long, time]`` columns (see ``__getitem__``).
    """
    def __init__(
        self, mode_t, mode_p, canada, train_days, 
        maxFolds = 5, target_fold = 0, temporal_scaling=1, spatiotemporal=1, data_dir='/pscratch/sd/d/dpark1/AirDelhi/delhi/processed', 
        seed=10, nTrainStartDay = 15, nTestStartDay = 75, nTotalDays = 91, train=True):
        """Store the configuration and build all samples.

        Args:
            mode_t: string of mode letters; selects which files feed the
                "train" side of each day (see ``get_suffixes`` and
                ``return_data_time``).
            mode_p: string of mode letters; if it contains ``'C'`` the current
                day's train file is prepended to that day's test data.
            canada: accepted but not used anywhere in this class.
            train_days: number of days before the current day included in each
                day's date list.
            maxFolds: number of folds; only used to build ``self.folds``.
            target_fold: fold index used in the CSV file names.
            temporal_scaling: stored only; its uses are commented out.
            spatiotemporal: stored; a negative value together with
                ``mode_t == 'AB'`` and ``mode_p == 'CD'`` reduces ``self.folds``
                to a single fold.
            data_dir: directory holding the CSV files.
            seed: seed for NumPy's global RNG.
            nTrainStartDay: first day offset (inclusive) of the training range.
            nTestStartDay: end of the training range (exclusive) and first day
                (inclusive) of the test range.
            nTotalDays: last day offset (inclusive) of the test range.
            train: keep the training samples (True) or the test samples (False).
        """
        
        self.mode_t = mode_t
        self.mode_p = mode_p
        self.train_days = train_days
        self.train = train
        self.maxFolds = maxFolds
        self.target_fold = target_fold
        self.temporal_scaling = temporal_scaling
        self.spatiotemporal = spatiotemporal
        self.data_dir = data_dir
        self.nTestStartDay = nTestStartDay
        self.nTrainStartDay = nTrainStartDay
        self.nTotalDays = nTotalDays
        
        # Seeds NumPy's *global* RNG (side effect outside this object).
        np.random.seed(seed)        
        
        self.train_suffix = get_suffixes(mode_t)
                
        if spatiotemporal < 0 and mode_t == 'AB' and mode_p == 'CD':
            # Forecasting, single fold is enough
            # NOTE(review): only the local variable changes here; self.maxFolds
            # keeps the value passed in.
            maxFolds = 1
    
        self.folds = [i for i in range(maxFolds)]
        
        # Reads every CSV and builds all samples up front.
        self.data, self.target = self.proc_custom(target_fold)
        
        
        
    def get_normalize_params(self, target):
        """Set ``self.mean`` / ``self.std`` from column 0 of every array in ``target``."""
        all_signal = []
        for a in target:
            all_signal += list(a[..., 0])
        self.mean, self.std = np.array(all_signal).mean(), np.array(all_signal).std()
    
    def get_spatial_norm_parameters(self, arr_of_days):
        """Set the min/max bounds used for min-max normalization of lat/long.

        Scans columns 2 (latitude) and 3 (longitude) of every array in
        ``arr_of_days`` and stores the overall extremes in ``self.latmin``,
        ``self.latmax``, ``self.longmin`` and ``self.longmax``.
        """
        # Sentinels that any real coordinate will replace.
        latmin = 10e10
        latmax = -10e10
        longmin = 10e10
        longmax = -10e10
        for arr in arr_of_days:
            minned = arr.min(0)
            # print(minned[0])
            if minned[2] < latmin:
                latmin = minned[2]
            if minned[3] < longmin:
                longmin = minned[3]
                
            maxed = arr.max(0)
            # print(maxed)
            if maxed[2] > latmax:
                latmax = maxed[2]
            if maxed[3] > longmax:
                longmax = maxed[3]
        self.latmin, self.latmax, self.longmin, self.longmax =latmin, latmax, longmin, longmax
    
    def make_data_by_time(self, arr_of_days, t_in = 9, reverse=False, day = 0):
        """Cut an array of rows into sliding-window (input, target) samples.

        Rows are grouped by the unique values of column 1 (time). For every
        window of ``t_in`` consecutive groups, the concatenated window is the
        input and the following group is the target. Times of both are shifted
        so that the last row of the input has time 0.

        With ``reverse=True`` a second sample is added per window: the target
        is the first group of the window and the input is the ``t_in`` groups
        after it; times are shifted relative to the first row of that input.

        ``day`` is appended to ``self.day_record`` once per produced sample.

        Returns:
            ``(in_, out_)``: two lists of arrays of equal length.
        """
        seg_by_time = []
        uniq_times = np.unique(arr_of_days[..., 1])
        
        # One array of rows per distinct timestamp, in ascending time order.
        for t in uniq_times:
            idx_ = arr_of_days[..., 1] == t
            seg_by_time.append(arr_of_days[idx_])
        
        in_ = []
        out_ = []
        for i in range(len(seg_by_time) - t_in):
            temp_in = []
            for t_ in range(t_in):
                temp_in.append(seg_by_time[i + t_])
            
            # normalize time to relative scale by the last one of the encoder
            # (copies, so the shared seg_by_time arrays are not modified)
            in_cand = np.copy(np.concatenate(temp_in, axis=0))
            out_cand = np.copy(seg_by_time[i + t_in])
            
            last_enc_t = in_cand[..., 1][-1]
            in_cand[..., 1] -= last_enc_t
            out_cand[..., 1] -= last_enc_t
            in_.append(in_cand)
            out_.append(out_cand)
            self.day_record.append(day)
            
            # reverse it
            if reverse:
                out_cand = np.copy(seg_by_time[i])
                temp_in = temp_in[1:]
                temp_in.append(seg_by_time[i + t_in])
                in_cand = np.copy(np.concatenate(temp_in, axis=0))
                # Here the reference time is the first row of the input.
                last_enc_t = in_cand[..., 1][0]
                in_cand[..., 1] -= last_enc_t
                out_cand[..., 1] -= last_enc_t
                
                in_.append(in_cand)
                out_.append(out_cand)
                self.day_record.append(day)
        
        
        return in_, out_
    
    def proc_custom(self, fold):
        """Build every training and test sample for ``fold``.

        For each day in the training range and in the test range, the CSVs of
        that day's date list are loaded and windowed. After each loop, the
        array built from the *last* day's own test CSV is windowed as well.
        Normalization statistics are always computed from the training
        targets, whichever side is returned.

        NOTE(review): ``self.day_record`` receives an entry for every sample of
        both sides (0 for all samples built without an explicit ``day``), so
        its length is not ``len(self.data)``.

        Returns:
            ``(data, target)`` lists of the side selected by ``self.train``.
        """
        
        self.day_record = []
        
        train_data = {'input':[], 'target':[]}
        test_data = {'input':[], 'target':[]}
        
        for day in range(self.nTrainStartDay, self.nTestStartDay):
            # Dates from (day - train_days) up to and including day.
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            # Columns after concatenation: [pm2_5, dateTime, lat, long].
            train_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            train_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day
            
            seg_in, seg_out = self.make_data_by_time(train_in, day = day)
            
            train_data['input'] += seg_in
            train_data['target'] += seg_out            
        
        
        
        # Uses train_out of the last loop iteration only (recorded with day=0).
        seg_in, seg_out = self.make_data_by_time(train_out)
        train_data['input'] += seg_in
        train_data['target'] += seg_out
            
        
        for day in range(self.nTestStartDay, self.nTotalDays+1):
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            test_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            test_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day

            seg_in, seg_out = self.make_data_by_time(test_in, reverse = False)
            
            test_data['input'] += seg_in
            test_data['target'] += seg_out            
            
        # Uses test_out of the last loop iteration only.
        seg_in, seg_out = self.make_data_by_time(test_out, reverse = False)
        test_data['input'] += seg_in
        test_data['target'] += seg_out

        # Statistics come from the training targets even when train=False.
        self.get_normalize_params(train_data['target']) 
        self.get_spatial_norm_parameters(train_data['target'])
            
        if self.train:
            data = train_data['input']
            target = train_data['target']
            print(len(data), len(target))
            
        else:
            data = test_data['input']
            target = test_data['target']
            print(len(data), len(target))

        return data, target        
    
    
    

    def process_np(self, fold, date):
        """Load the data of one date list; thin wrapper around ``return_data_time``."""
        # NOTE(review): tmStart is never read.
        tmStart = datetime.datetime.now()
        train_input,train_output,test_input,test_output = self.return_data_time(fold=fold, data=date, with_scaling=True)
        return train_input,train_output,test_input,test_output
    
    def return_data_time(self, fold, data, with_scaling):
        """Read the CSV files of a date list and split them into inputs/outputs.

        Args:
            fold: fold index used in the file names.
            data: list of date strings; the last entry is the current day.
            with_scaling: forwarded to ``return_data_0`` (its use there is
                commented out).

        Returns:
            ``(train_input, train_output, test_input, test_output)`` as
            produced by ``return_data_0``.

        NOTE(review): if ``mode_t`` contains none of 'A', 'B', 'C',
        ``train_input`` stays ``None`` and ``return_data_0`` fails on it.
        """
        train_input = None
        if 'A' in self.mode_t or 'B' in self.mode_t:
            # All days before the current one, for every selected suffix.
            for idx,dt in enumerate(data[:-1]):
                for suffix in self.train_suffix:
                    # NOTE(review): ``input`` shadows the builtin of the same name.
                    input = pd.read_csv(self.data_dir+'/'+dt+'_f'+str(fold)+'_'+suffix+'.csv')
                    # if self.temporal_scaling:
                    #     input.dateTime += idx * 24 * 60
                    # pd.concat drops a None entry, so the first call just yields ``input``.
                    train_input = pd.concat((train_input, input))
                    
        if 'C' in self.mode_t:
            # The current day's train file also feeds the train side.
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            # if self.temporal_scaling:
            #     input.dateTime += (len(data)-1) * 24 * 60
            train_input = pd.concat((train_input, input))

        test_input = pd.read_csv(self.data_dir+'/'+data[-1]+'_f'+str(fold)+'_test.csv')
        
        if 'C' in self.mode_p:
            # The current day's train file is prepended to its test data.
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            test_input = pd.concat((input, test_input))
            
        # if self.temporal_scaling:
        #     test_input.dateTime += (len(data)-1) * 24 * 60

        return self.return_data_0(train_input, test_input, with_scaling)

    
    def return_data_0(self, train_input, test_input, with_scaling):
        """Split two frames into feature frames and pm2_5 arrays.

        Returns:
            ``(train_input, train_output, test_input, test_output)``: the
            inputs are frames with the columns ``dateTime``, ``lat``, ``long``;
            the outputs are NumPy arrays of the ``pm2_5`` column.
            ``with_scaling`` is currently ignored (scaling is commented out).
        """
        train_output = np.array(train_input['pm2_5'])
        train_input = train_input[['dateTime','lat','long']]
        test_output = np.array(test_input['pm2_5'])
        test_input = test_input[['dateTime','lat','long']]

        # if with_scaling:
        #     scaler = MinMaxScaler().fit(train_input)
        #     if self.temporal_scaling:
        #         data = scaler.transform(pd.concat((train_input, test_input)))
        #         test_input = data[len(train_input):]
        #         train_input = data[:len(train_input)]
        #     else:
        #         train_input = scaler.transform(train_input)
        #         test_input = scaler.transform(test_input)
        return train_input,train_output,test_input,test_output

    def set_target_fold(self, fold=0):
        """Store ``fold`` in ``self.fold`` and print it.

        NOTE(review): nothing in this class reads ``self.fold``; the data built
        in ``__init__`` is not rebuilt.
        """
        self.fold = fold
        print('target fold set to {}'.format(self.fold))
        
    def normalize_z(self, arr):
        """Z-score ``arr`` with ``self.mean`` and ``self.std``."""
        return (arr - self.mean) / self.std
    
    def normalize(self, data, min_, max_):
        """Min-max scale ``data`` with the given bounds."""
        return (data - min_) / (max_ - min_)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the normalized sample ``idx`` as ``(i, o, im, om)``.

        ``i`` / ``o`` hold the z-normalized pm2_5 column of the input / target;
        ``im`` / ``om`` hold their remaining columns in the order
        ``[lat, long, time]``, with lat/long min-max normalized and time
        divided by 1440.
        """
        input_, target_ = self.data[idx], self.target[idx]
        # astype() copies, so the stored arrays are not modified by the
        # in-place normalization below.
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)
        
        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])
        
        # 1440 = 24 * 60; presumably converts minutes to days — TODO confirm the unit of dateTime.
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440
        
        # NOTE(review): timegap is computed but never used.
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)
        
        # Reorder columns from [pm2_5, time, lat, long] to [pm2_5, lat, long, time].
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)
        
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)
        
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]
        
        return i, o, im, om