In [2]:
import lmdb
import datetime
import argparse
import pandas as pd
import numpy as np
import random
import re
from collections import defaultdict
from typing import List, Tuple, Union

import scipy.io
import pickle
import numpy as np
import os
import h5py

import torch
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from tqdm import tqdm

In [3]:
def to_tensor(array):
    """Wrap a NumPy array as a float32 torch tensor."""
    tensor = torch.from_numpy(array)
    return tensor.float()

In [4]:
def random_seed(seed):
    """Seed NumPy, Python's `random`, and torch (CPU and CUDA) with `seed`.

    Prints a confirmation line once every generator has been seeded.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    print(f'set seed {seed} is done')

In [5]:
KeyT = Union[str, bytes, bytearray]

_KEY_RE = re.compile(
   r"^S(?P<sub_id>\d{3})R\d{2}-\d+$"
)

# S012R04-21

def _decode_key(k: KeyT) -> str:
    if isinstance(k, (bytes, bytearray)):
        return k.decode("utf-8", errors="ignore")
    return k

def _extract_sub_id(k: KeyT) -> str:
    s = _decode_key(k)
    m = _KEY_RE.match(s)
    if m is None:
        raise ValueError(f"Key does not match expected patterns: {s}")
    return m.group("sub_id")


def train_test_split_by_fold_num(
    fold_num: int,
    lmdb_keys: List[KeyT],
    maxFold: int,
    split_by_sub: bool = True,
    seed: int = 41
) -> Tuple[List[KeyT], List[KeyT]]:
    """
    Deterministic k-fold train/test split of LMDB keys.

    Args:
        fold_num: index of the fold used as the test set (0 <= fold_num < maxFold)
        lmdb_keys: keys to split
        maxFold: total number of folds (k), at least 2
        split_by_sub: True -> whole subjects are assigned to folds,
                      False -> individual keys are assigned to folds
        seed: seed of the NumPy generator that shuffles before folding

    Returns:
        (train_key_list, test_key_list)

    Raises:
        ValueError: on an out-of-range `maxFold` / `fold_num`, or (subject-wise
            mode) when any key does not carry a parsable subject id.
    """
    if maxFold < 2:
        raise ValueError("maxFold must be >= 2.")
    if fold_num >= maxFold or fold_num < 0:
        raise ValueError(f"fold_num must be in [0, {maxFold-1}]")

    keys = list(lmdb_keys)
    rng = np.random.default_rng(seed)

    if not split_by_sub:
        # -------- key-wise k-fold --------
        order = np.arange(len(keys))
        rng.shuffle(order)
        held_out = set(np.array_split(order, maxFold)[fold_num].tolist())

        train_keys, test_keys = [], []
        # Both outputs follow the shuffled order.
        for i in order:
            (test_keys if i in held_out else train_keys).append(keys[i])
        return train_keys, test_keys

    # -------- subject-wise k-fold --------
    sub_to_keys = defaultdict(list)
    invalid = []
    for k in keys:
        try:
            sid = _extract_sub_id(k)
        except ValueError:
            invalid.append(_decode_key(k))
        else:
            sub_to_keys[sid].append(k)

    if invalid:
        # Report every offending key count at once, with up to ten examples.
        ex = "\n".join(invalid[:10])
        raise ValueError(
            f"Found {len(invalid)} invalid keys. Examples:\n{ex}"
        )

    subjects = np.array(list(sub_to_keys), dtype=object)
    rng.shuffle(subjects)
    test_subjects = set(np.array_split(subjects, maxFold)[fold_num].tolist())

    train_keys, test_keys = [], []
    # Iterate in first-seen subject order so the output order is reproducible.
    for sid, ks in sub_to_keys.items():
        bucket = test_keys if sid in test_subjects else train_keys
        bucket.extend(ks)
    return train_keys, test_keys

In [6]:
# Directory of the LMDB environment used throughout this notebook.
LMDB = '/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID'

# Open read-only without locking and load the pickled list stored under '__keys__'.
# NOTE(review): DB is never closed in the code visible here, and the dataset class
# below opens its own environment on the same directory.
DB = lmdb.open(LMDB, readonly=True, lock=False, readahead=True, meminit=False)
with DB.begin(write=False) as txn:
    KEYS = pickle.loads(txn.get('__keys__'.encode()))

In [19]:
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb
# Data path on Ahhyun's pscratch. Fine to use as-is for now, but it should later be moved to my own pscratch, m4727, or similar.

class Physio_for_SOLID_from_lmdb(Dataset):
    """Sliding-window next-step prediction dataset read from an LMDB store.

    On construction the key list stored under '__keys__' is split into
    train/test folds, and every recording of the selected split is cut into
    (input window, target window) pairs that are kept in memory.

    Each item is a 4-tuple ``(input, target, input_meta, target_meta)`` where
    the two meta entries are the recording's channel-name list.
    """

    def __init__(
            self,
            lmdb_dir: str,
            maxfold: int,
            targetfold: int,
            seed: int,
            train: bool,
            split_by_sub: bool,
    ):
        """
        Args:
            lmdb_dir: directory of the LMDB environment.
            maxfold: number of cross-validation folds.
            targetfold: index of the fold held out as the test split.
            seed: seed for the global RNGs and for the fold shuffle.
            train: True -> expose the training split, False -> the test split.
            split_by_sub: True -> subject-wise folds, False -> key-wise folds.

        Raises:
            KeyError: if the store has no '__keys__' entry.
        """
        # Side effect: reseeds the global numpy / random / torch generators.
        random_seed(seed)
        self.seed = seed
        self.lmdb_dir = lmdb_dir
        # NOTE(review): the environment stays open for the lifetime of the object;
        # nothing in this class closes it.
        self.db = lmdb.open(lmdb_dir, readonly=True, lock=False, readahead=True, meminit=False)
        # Going through lmdb_get gives a clear KeyError when '__keys__' is absent
        # instead of a TypeError from unpickling None.
        self.lmdb_keys = self.lmdb_get(self.db, '__keys__')

        self.train = train
        self.split_by_sub = split_by_sub

        self.maxfold = maxfold
        self.targetfold = targetfold
        self.data, self.target, self.data_meta, self.target_meta = self.make_data_and_target_by_fold(
            self.targetfold, self.lmdb_keys, self.maxfold, self.split_by_sub, self.seed)

    def make_data_and_target_by_fold(self, fold, lmdb_keys, maxfold, split_by_sub, seed):
        """Build the window lists for the split selected by ``self.train``.

        Returns:
            (inputs, targets, input_meta, target_meta): four parallel lists.
        """
        # Kept because the original set this attribute; nothing in this class fills it.
        self.record = []

        train_keys, test_keys = train_test_split_by_fold_num(fold, lmdb_keys, maxfold, split_by_sub, seed)
        selected_keys = train_keys if self.train else test_keys

        inputs, targets, input_meta, target_meta = [], [], [], []
        for data_key in selected_keys:
            # TODO : get proper seg_in and seg_out by input idx
            seg_in, seg_out, seg_in_meta, seg_out_meta = self.segmentation_from_idx(data_key, self.db)

            inputs += seg_in
            targets += seg_out
            input_meta += seg_in_meta
            target_meta += seg_out_meta

        return inputs, targets, input_meta, target_meta

    def lmdb_get(self, env, key):
        """Fetch and unpickle the value stored under ``key`` (str or bytes).

        Raises:
            KeyError: if the key is not present in the store.
        """
        if isinstance(key, str):
            key = key.encode("utf-8")
        with env.begin(write=False) as txn:
            v = txn.get(key)
        if v is None:
            raise KeyError(f"Key not found: {key}")
        # NOTE(review): pickle executes arbitrary code on load; only use trusted stores.
        return pickle.loads(v)

    def segmentation_from_idx(self, key, lmdb_db, in_len=99, out_len=1, stride=1):
        """Cut one recording into sliding (input, target) windows.

        The stored sample is a 3-axis array (commented as (C, T, Fs) in the
        original) that is flattened to (C, L). For every start ``t``:

          input  = eeg[:, t : t + in_len]
          target = eeg[:, t + in_len : t + in_len + out_len]

        Returns:
            Four parallel lists (inputs, targets, input_meta, target_meta); all
            empty when the recording is shorter than ``in_len + out_len``.

        Raises:
            ValueError: if ``in_len``, ``out_len`` or ``stride`` is smaller than 1.
        """
        # Validate up front: a zero stride would otherwise fail only after the
        # read, and a negative one would silently produce no windows.
        if in_len < 1 or out_len < 1 or stride < 1:
            raise ValueError(
                f"in_len, out_len and stride must be >= 1 "
                f"(got in_len={in_len}, out_len={out_len}, stride={stride})"
            )

        sample_for_key = self.lmdb_get(lmdb_db, key)

        channel_name = sample_for_key['data_info']['channel_names']  # len=C
        eeg_data = sample_for_key['sample']                         # (C, T, Fs)
        eeg_data_ = rearrange(eeg_data, 'c t f -> c (t f)')         # (C, L)
        eeg_data_ = torch.from_numpy(eeg_data_).to(torch.float32)

        _, L = eeg_data_.shape
        total_needed = in_len + out_len

        if L < total_needed:
            return [], [], [], []

        starts = range(0, L - total_needed + 1, stride)
        # Slices are views into the per-recording tensor.
        seg_in_list = [eeg_data_[:, t : t + in_len] for t in starts]
        seg_out_list = [eeg_data_[:, t + in_len : t + total_needed] for t in starts]
        # Every window of a recording shares the same channel-name object.
        seg_in_meta_list = [channel_name] * len(starts)
        seg_out_meta_list = [channel_name] * len(starts)

        return seg_in_list, seg_out_list, seg_in_meta_list, seg_out_meta_list

    def __len__(self):
        """Number of windows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input, target, input_meta, target_meta)`` for window ``idx``."""
        i  = self.data[idx]
        o  = self.target[idx]
        im = self.data_meta[idx]
        om = self.target_meta[idx]
        return i, o, im, om

In [8]:
# 11 x 11 layout of electrode names; '-' marks a cell with no electrode
# (build_channel_to_rc skips those cells when building the name -> (row, col) map).
TORCHEEG_2DGRID = [
    ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'],
    ['-', '-', '-', '-', 'FP1', 'FPZ', 'FP2', '-', '-', '-', '-'],
    ['-', '-', 'AF7', '-', 'AF3', 'AFZ', 'AF4', '-', 'AF8', '-', '-'],
    ['F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10'],
    ['FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10'], 
    ['T9', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'T10'],
    ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10'], 
    ['P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10'],
    ['-', '-', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', '-', '-'],
    ['-', '-', '-', 'CB1', 'O1', 'OZ', 'O2', 'CB2', '-', '-', '-'],
    ['-', '-', '-', '-', '-', 'IZ', '-', '-', '-', '-', '-']
    ]

In [9]:
def build_channel_to_rc(grid_2d):
    """Index a 2-D layout of channel names by name.

    Args:
        grid_2d: list of rows; each cell is a channel name, '-' or None.

    Returns:
        (lookup, H, W): `lookup` maps the stripped, upper-cased name to its
        (row, col); H is the number of rows and W the length of the first row.
        If a name occurs more than once, the last occurrence wins.
    """
    n_rows = len(grid_2d)
    n_cols = len(grid_2d[0])

    lookup = {}
    for r, row in enumerate(grid_2d):
        for c in range(n_cols):
            label = row[c]
            if label == '-' or label is None:
                continue  # empty cell
            lookup[str(label).strip().upper()] = (r, c)
    return lookup, n_rows, n_cols

# Module-wide lookup and grid size; used below as default arguments (evaluated at definition time).
CHANNEL_TO_RC, H, W = build_channel_to_rc(TORCHEEG_2DGRID)


In [10]:
def splat_eeg_grid(eeg_cl, channel_names, channel_to_rc=CHANNEL_TO_RC, H=H, W=W):
    """
    Scatter per-channel signals onto the 2-D electrode grid.

    eeg_cl: (C, L) torch.Tensor (recommended); anything else goes through torch.as_tensor
    channel_names: list[str] len=C
    returns:
      grid: (L, H, W)  per-cell average of the channels mapped to that cell, 0 elsewhere
      mask: (H, W)     1.0 where at least one channel landed, else 0.0

    Channels whose stripped, upper-cased name is not in `channel_to_rc` are ignored.
    """
    signal = eeg_cl if torch.is_tensor(eeg_cl) else torch.as_tensor(eeg_cl)

    assert signal.dim() == 2, f"Expected (C,L), got {tuple(signal.shape)}"
    n_channels, n_steps = signal.shape

    grid = torch.zeros((n_steps, H, W), dtype=signal.dtype, device=signal.device)
    hits = torch.zeros((H, W), dtype=torch.float32, device=signal.device)

    for ci in range(n_channels):
        name = str(channel_names[ci]).strip().upper()
        if name not in channel_to_rc:
            continue
        r, c = channel_to_rc[name]
        grid[:, r, c] += signal[ci, :]
        hits[r, c] += 1.0

    occupied = hits > 0
    # Average where several channels share a cell; the clamp avoids dividing by zero.
    averaged = torch.where(occupied, grid / torch.clamp(hits, min=1.0), grid)
    return averaged, occupied.float()


In [11]:
class EEGToGridCtx9(Dataset):
    """
    Turns windowed EEG samples into grid-shaped diffusion inputs.

    base[idx] -> (i, o, im, om)
      i : (C, 99)   (torch or numpy)
      o : (C, 1)
      im: channel list
      om: channel list

    The 99 input steps are averaged in 9 consecutive bins of 11 steps, and each
    bin is splatted onto the electrode grid.

    return:
      x0       : (1,H,W)   target grid (tanh-squashed when squash_tanh is set)
      tgt_mask : (1,H,W)
      cond     : (20,H,W) = [lat_map, lon_map, past_grids(9), past_masks(9)]
      mean, std            taken from the base dataset if present, else 0.0 / 1.0
    """
    def __init__(self, base_dataset, squash_tanh=True, channel_to_rc=CHANNEL_TO_RC):
        self.base = base_dataset
        self.squash = squash_tanh
        self.channel_to_rc = channel_to_rc

        self.mean = float(getattr(self.base, "mean", 0.0))
        self.std  = float(getattr(self.base, "std",  1.0))

        # Normalised row / column coordinate planes in [0, 1].
        row_coord = torch.linspace(0, 1, H)
        col_coord = torch.linspace(0, 1, W)
        self.lat_map = row_coord.unsqueeze(1).repeat(1, W)
        self.lon_map = col_coord.unsqueeze(0).repeat(H, 1)

        self.ctx_steps = 9
        self.in_len = 99
        assert self.in_len % self.ctx_steps == 0, "can not dividable"
        self.bin = self.in_len // self.ctx_steps  # 11

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        i, o, im, om = self.base[idx]  # i:(C,99), o:(C,1), im/om: channel list

        i = i if torch.is_tensor(i) else torch.as_tensor(i)
        o = o if torch.is_tensor(o) else torch.as_tensor(o)

        # ---- past context: one grid + mask per time bin ----
        step_grids, step_masks = [], []
        for start in range(0, self.ctx_steps * self.bin, self.bin):
            bin_mean = i[:, start:start + self.bin].mean(dim=1, keepdim=True)   # (C,1)
            grid_k, mask_k = splat_eeg_grid(bin_mean, im, self.channel_to_rc, H, W)
            grid_k = grid_k.squeeze(0)                                          # (H,W)
            step_grids.append(torch.tanh(grid_k) if self.squash else grid_k)
            step_masks.append(mask_k)

        past_grids = torch.stack(step_grids, 0)  # (9,H,W)
        past_masks = torch.stack(step_masks, 0)  # (9,H,W)

        # ---- target (next step) grid/mask ----
        tgt_grid, tgt_mask = splat_eeg_grid(o, om, self.channel_to_rc, H, W)
        tgt_grid = tgt_grid.squeeze(0)           # (H,W)
        if self.squash:
            x0 = torch.tanh(tgt_grid)
        else:
            x0 = tgt_grid

        # ---- conditioning stack: coordinates, then grids, then masks ----
        cond_parts = [
            self.lat_map.unsqueeze(0),  # (1,H,W)
            self.lon_map.unsqueeze(0),  # (1,H,W)
            past_grids,                 # (9,H,W)
            past_masks,                 # (9,H,W)
        ]
        cond = torch.cat(cond_parts, dim=0)      # (20,H,W)

        return x0.unsqueeze(0), tgt_mask.unsqueeze(0), cond, self.mean, self.std


In [28]:
# Training split of fold 0 (subject-wise 5-fold, seed 41); all windows are built in the constructor.
# NOTE(review): the arguments are identical to those of `sample_train_dataset` created in a
# later cell, so the same windows end up being materialized twice.
sample_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=True,
                                            split_by_sub=True)

set seed 41 is done


In [33]:
# Smoke test: wrap the window dataset and print the shapes of a single item.
grid_ds = EEGToGridCtx9(sample_dataset)

x, mask, cond, mean, std = grid_ds[0]
print(x.shape, mask.shape, cond.shape, mean, std)

torch.Size([1, 11, 11]) torch.Size([1, 11, 11]) torch.Size([20, 11, 11]) 0.0 1.0


In [20]:
# Train and test views of the same fold: every argument except `train` is identical,
# so both constructors derive the same subject-wise split (fold 0 of 5, seed 41).
sample_train_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=True,
                                            split_by_sub=True)

sample_test_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=False,
                                            split_by_sub=True)

# Grid-shaped wrappers consumed by the diffusion model (tanh squashing on by default).
train_grid_dataset = EEGToGridCtx9(sample_train_dataset)
test_grid_dataset = EEGToGridCtx9(sample_test_dataset)

set seed 41 is done
set seed 41 is done


In [42]:
# Training batches are shuffled and the last partial batch is dropped; test batches keep
# dataset order. The batch size is the literal 16 here (the BATCH_SIZE constant is defined
# in a later cell and is not referenced by these loaders).
train_loader = DataLoader(train_grid_dataset, batch_size=16, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_grid_dataset, batch_size=16, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

In [43]:
import os, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image, make_grid
from einops import rearrange
from tqdm import tqdm
import matplotlib.pyplot as plt

In [44]:
# Run on GPU when one is visible, otherwise on CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoints, sample images and overlays are written here; created on import of this cell.
RESULT_DIR = '/pscratch/sd/t/tylee/SOLID_EEG_RESULT'
os.makedirs(RESULT_DIR, exist_ok=True)

# NOTE(review): BATCH_SIZE and LR are not referenced by the loaders / optimizer set up
# elsewhere in this file (they use the literal 16 and `max_lr` respectively).
BATCH_SIZE = 16
LR = 2e-4
TIME_STEPS = 1000                  # diffusion T
TOTAL_STEPS = 150_000              # number of optimizer steps in the training loop
LOG_EVERY  = 200                   # steps between running-loss updates on the progress bar
EVAL_EVERY = 1000                  # steps between test evaluations / checkpoint checks
SAVE_SAMPLES_EVERY = 1000          # steps between saved sample images
BG_WEIGHT = 0   # constant added to the target mask when weighting the diffusion loss

In [47]:
# Model implementation from Kevin
# Report how many windows the train / test grid datasets contain.
print(len(train_grid_dataset))
print(len(test_grid_dataset))

# ============================================================
# 3) UNet (no attention; rectangular-friendly)
# ============================================================
class ResnetBlock(nn.Module):
    """Pre-activation residual block (GroupNorm -> SiLU -> 3x3 conv, twice).

    An optional time embedding is projected to `dim_out` channels and added
    after the first convolution. The shortcut is a 1x1 convolution when the
    channel count changes, otherwise the identity.

    Args:
        dim: input channels.
        dim_out: output channels (defaults to `dim`).
        time_emb_dim: size of the time embedding; falsy disables the projection.
        dropout: dropout probability applied before the second convolution.
        groups: number of GroupNorm groups (must divide `dim` and `dim_out`).
    """
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        if dim_out is None:
            dim_out = dim

        if time_emb_dim:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.norm1 = nn.GroupNorm(groups, dim)
        self.conv1 = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, dim_out)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding=1)
        self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
        self.act = nn.SiLU()
        self.res = nn.Identity() if dim == dim_out else nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, t_emb=None):
        """x: (B, dim, H, W); t_emb: (B, time_emb_dim) or None -> (B, dim_out, H, W)."""
        h = self.norm1(x)
        h = self.act(h)
        h = self.conv1(h)

        if self.mlp is not None and t_emb is not None:
            # Broadcast the per-sample, per-channel shift over the spatial dims.
            h = h + self.mlp(t_emb)[..., None, None]

        h = self.norm2(h)
        h = self.act(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.res(x)


class SinusoidalTimeEmbedding(nn.Module):
    """Parameter-free sin/cos embedding of a scalar timestep.

    The output has `base_dim` features: the sines of `base_dim // 2`
    frequencies, then the matching cosines, zero-padded by one column when
    `base_dim` is odd.
    """
    def __init__(self, base_dim):
        super().__init__()
        self.out_dim = base_dim

    def forward(self, t):  # t: (B,)
        half = self.out_dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = -(math.log(10000.0) / max(1, half-1))
        freqs = torch.exp(torch.arange(half, device=t.device).float() * log_step)

        ang = t.float()[:, None] * freqs[None, :]                 # (B, half)
        emb = torch.cat([ang.sin(), ang.cos()], dim=1)            # (B, 2*half)

        missing = self.out_dim - emb.shape[1]
        if missing > 0:
            emb = F.pad(emb, (0, missing))
        return emb

class TimeMLP(nn.Module):
    """Two-layer MLP that maps a scalar timestep to a `4 * base_dim` embedding."""
    def __init__(self, base_dim):
        super().__init__()
        hidden = base_dim*4
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, t):  # (B,)
        # The timestep is fed as a single float feature per sample.
        as_column = t.unsqueeze(1).float()
        return self.net(as_column)

class UNet(nn.Module):
    """Attention-free UNet that predicts a single-channel map from a stacked input.

    Encoder: two ResnetBlocks per level, with a stride-2 convolution between
    levels. Decoder: two ResnetBlocks per level, each consuming one encoder
    skip, with nearest-neighbour upsampling + 3x3 conv between levels. Skips
    and the final feature map are resized with nearest interpolation whenever
    the spatial sizes do not line up, so odd / rectangular inputs work.

    Args:
        base_dim: channel width of the first level.
        dim_mults: per-level channel multipliers.
        in_channels: channels of the concatenated input.
        image_size: (height, width) of the output map.
        dropout: dropout probability inside the ResnetBlocks.
        groups: GroupNorm group count.
    """
    def __init__(self, base_dim=128, dim_mults=(1,2,4),
                 in_channels=1+20, image_size=(H,W), dropout=0.0, groups=32):
        super().__init__()
        self.image_h, self.image_w = image_size
        self.time_dim = base_dim * 4

        # Timesteps go through a plain MLP; the sinusoidal-embedding variant
        # (SinusoidalTimeEmbedding followed by an MLP) is not used.
        self.time_mlp = TimeMLP(base_dim)
        self.init = nn.Conv2d(in_channels, base_dim, 3, padding=1)

        last_level = len(dim_mults) - 1

        # ---- encoder ----
        self.downs = nn.ModuleList()
        skip_channels = []
        in_ch = base_dim
        for level, mult in enumerate(dim_mults):
            out_ch = base_dim * mult
            for _ in range(2):
                self.downs.append(ResnetBlock(in_ch, out_ch, self.time_dim, dropout, groups))
                in_ch = out_ch
                skip_channels.append(in_ch)   # every ResnetBlock output becomes a skip
            if level != last_level:
                self.downs.append(nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1))

        # ---- bottleneck ----
        self.mid1 = ResnetBlock(in_ch, in_ch, self.time_dim, dropout, groups)
        self.mid2 = ResnetBlock(in_ch, in_ch, self.time_dim, dropout, groups)

        # ---- decoder ----
        # `kinds` records, per entry of `ups`, whether it is a skip-consuming
        # ResnetBlock ('res'), an upsample ('up') or a plain conv ('conv').
        self.ups, self.kinds = nn.ModuleList(), []
        remaining_skips = list(skip_channels)
        for level, mult in enumerate(reversed(dim_mults)):
            out_ch = base_dim * mult
            for _ in range(2):
                skip_ch = remaining_skips.pop()
                self.ups.append(ResnetBlock(in_ch + skip_ch, out_ch, self.time_dim, dropout, groups))
                self.kinds.append('res')
                in_ch = out_ch
            if level != last_level:
                self.ups.append(nn.Upsample(scale_factor=2, mode='nearest'))
                self.kinds.append('up')
                self.ups.append(nn.Conv2d(in_ch, in_ch, 3, padding=1))
                self.kinds.append('conv')

        self.final = nn.Sequential(nn.GroupNorm(groups, in_ch), nn.SiLU(), nn.Conv2d(in_ch, 1, 3, padding=1))

    def forward(self, x_cat, t):
        """x_cat: (B, in_channels, H, W); t: (B,) timesteps -> (B, 1, image_h, image_w)."""
        t_emb = self.time_mlp(t)

        h = self.init(x_cat)
        skips = []
        for layer in self.downs:
            if isinstance(layer, ResnetBlock):
                h = layer(h, t_emb)
                skips.append(h)
            else:
                h = layer(h)   # stride-2 downsampling conv

        h = self.mid2(self.mid1(h, t_emb), t_emb)

        for kind, layer in zip(self.kinds, self.ups):
            if kind != 'res':
                h = layer(h)   # 'up' and 'conv' take no skip and no time embedding
                continue
            s = skips.pop()
            if s.shape[-2:] != h.shape[-2:]:
                # Upsampling by 2 can overshoot odd sizes; resize the skip to match.
                s = F.interpolate(s, size=h.shape[-2:], mode='nearest')
            h = layer(torch.cat([h, s], dim=1), t_emb)

        if h.shape[-2:] != (self.image_h, self.image_w):
            h = F.interpolate(h, size=(self.image_h, self.image_w), mode='nearest')
        return self.final(h)

# ============================================================
# 4) Diffusion core — noise only target channel; cond is clean
# ============================================================
class GaussianDiffusion(nn.Module):
    """DDPM-style noise-prediction wrapper around a conditional UNet.

    Only the single target channel is noised; the conditioning channels are
    concatenated unchanged. A linear beta schedule is precomputed and stored
    in buffers so it follows the module across devices.

    Args:
        unet: network called as ``unet(x_cat, t)`` returning the predicted noise.
        image_size: (H, W) of the target map.
        time_steps: number of diffusion steps T.
        loss_type: 'l1', 'l2', or anything else for smooth-L1.
    """
    def __init__(self, unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2'):
        super().__init__()
        self.unet = unet
        self.H, self.W = image_size
        self.T = time_steps
        self.loss_type = loss_type

        beta  = self.linear_beta_schedule(time_steps)
        alpha = 1. - beta
        abar  = torch.cumprod(alpha, dim=0)
        # alpha_bar at step t-1, with alpha_bar_{-1} defined as 1.
        abar_prev = F.pad(abar[:-1], (1,0), value=1.)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', abar)
        self.register_buffer('alpha_bar_prev', abar_prev)
        self.register_buffer('sqrt_alpha_bar', torch.sqrt(abar))
        self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - abar))
        self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / abar))
        self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / abar - 1))
        self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha))
        self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - abar))

    def linear_beta_schedule(self, T):
        """Linear betas from 1e-4 to 2e-2, rescaled by 1000 / T."""
        scale = 1000 / T
        return torch.linspace(scale*1e-4, scale*2e-2, T, dtype=torch.float32)

    def q_sample(self, x0, t, noise):
        """Forward process: sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise, per sample."""
        return self.sqrt_alpha_bar[t][:,None,None,None] * x0 + \
               self.sqrt_one_minus_alpha_bar[t][:,None,None,None] * noise

    def forward(self, x0, mask, cond):
        """
        Training loss for one batch at uniformly sampled timesteps.

        x0:   (B,1,H,W) in tanh(z) space
        mask: (B,1,H,W)  (1=observed bin in target; 0=unobserved)
        cond: (B,20,H,W) = [lat, lon, past_grids(9), past_masks(9)]

        Returns the weighted mean of the element-wise noise-prediction error,
        with weights ``mask + BG_WEIGHT``.
        """
        b = x0.size(0)
        t = torch.randint(0, self.T, (b,), device=x0.device).long()

        noise = torch.randn_like(x0)
        x_t   = self.q_sample(x0, t, noise)
        x_cat = torch.cat([x_t, cond], dim=1)  # noised target + clean cond

        pred = self.unet(x_cat, t)  # predict noise on target channel

        if self.loss_type == 'l1':
            raw = F.l1_loss(noise, pred, reduction='none')
        elif self.loss_type == 'l2':
            raw = F.mse_loss(noise, pred, reduction='none')
        else:
            raw = F.smooth_l1_loss(noise, pred, reduction='none')

        # Observed cells get weight 1 + BG_WEIGHT, all others BG_WEIGHT (module-level constant).
        w = mask + BG_WEIGHT
        return (raw * w).sum() / (w.sum() + 1e-8)

    @torch.inference_mode()
    def p_sample(self, xt, cond, t, clip=True):
        """One reverse step x_t -> x_{t-1} for an integer timestep ``t``.

        With ``clip`` the predicted x0 is clamped to [-1, 1] and the posterior
        mean is formed from (x0, x_t); otherwise the mean is computed directly
        from the predicted noise. No noise is added at ``t == 0``.
        """
        bt = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        x_cat = torch.cat([xt, cond], dim=1)
        pred_noise = self.unet(x_cat, bt)

        def bcast(x): return x.view(-1,1,1,1)
        if clip:
            x0 = bcast(self.sqrt_recip_alpha_bar[bt]) * xt - bcast(self.sqrt_recip_alpha_bar_min_1[bt]) * pred_noise
            x0 = x0.clamp(-1., 1.)
            # Posterior q(x_{t-1} | x_t, x0) mean coefficients. Both share the
            # denominator (1 - alpha_bar_t); the x_t coefficient previously divided
            # by (1 - alpha_t), which inflated it by orders of magnitude.
            one_minus_abar = 1. - self.alpha_bar[bt]
            c1 = self.beta[bt] * torch.sqrt(self.alpha_bar_prev[bt]) / one_minus_abar
            c2 = torch.sqrt(self.alpha[bt]) * (1. - self.alpha_bar_prev[bt]) / one_minus_abar
            mean = bcast(c1) * x0 + bcast(c2) * xt
        else:
            mean = bcast(self.sqrt_recip_alpha[bt]) * (xt - bcast(self.beta_over_sqrt_one_minus_alpha_bar[bt]) * pred_noise)
        # Posterior variance; it is exactly 0 at t == 0 because alpha_bar_prev is 1 there.
        var = self.beta[bt] * ((1. - self.alpha_bar_prev[bt]) / (1. - self.alpha_bar[bt]))
        noise = torch.randn_like(xt) if t > 0 else 0.
        return mean + torch.sqrt(bcast(var)) * noise

    @torch.inference_mode()
    def sample(self, cond, clip=False):  # clip=False often gives crisper fields
        """Run the full reverse chain from Gaussian noise; returns (B,1,H,W) clamped to [-1, 1]."""
        b = cond.size(0)
        x = torch.randn((b,1,self.H,self.W), device=cond.device)
        for t in reversed(range(self.T)):
            x = self.p_sample(x, cond, t, clip=clip)
        return x.clamp(-1, 1)

5490933
1404804


In [None]:
# ============================================================
# 6) Build model + diffusion + optimizer
# ============================================================
IN_CHANNELS = 1 + 2 + 9 + 9   # target(noised) + lat/lon + 9 past grids + 9 past masks = 21
unet = UNet(base_dim=128, dim_mults=(1,2,4), in_channels=IN_CHANNELS, image_size=(H,W)).to(DEVICE)
diffusion = GaussianDiffusion(unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2').to(DEVICE)

# ---- CosineAnnealingWarmupRestarts setup ----
# NOTE(review): `cosine_annealing_warmup` is a third-party package imported mid-notebook;
# it has to be installed for this cell to run.
from torch.optim import AdamW
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

# The optimizer is created with max_lr; the LR constant from the config cell is not used here.
max_lr = 4e-4
min_lr = 8e-6
TOTAL_ITERS = TOTAL_STEPS          # keep these tied
warmup_steps = max(1, int(0.1 * TOTAL_ITERS))   # 10% of the run, at least one step
weight_decay = 1e-4

# `diffusion.parameters()` covers the UNet, which is registered as a submodule of `diffusion`.
opt = AdamW(diffusion.parameters(), lr=max_lr, betas=(0.9, 0.999), weight_decay=weight_decay)

sched = CosineAnnealingWarmupRestarts(
    optimizer=opt,
    first_cycle_steps=TOTAL_ITERS,  # single full-length cosine cycle
    max_lr=max_lr,
    min_lr=min_lr,
    warmup_steps=warmup_steps,
    gamma=1.0                       # no decay across cycles since we use one cycle
)

# ============================================================
# 7) Utilities: invert tanh->raw and metrics
# ============================================================
def inv_tanh_to_raw(x_tanh, mean, std):
    """Undo the tanh squashing and the (mean, std) normalisation.

    Inputs are clamped to [-0.999, 0.999] first so that atanh stays finite.
    """
    z = torch.atanh(torch.clamp(x_tanh, -0.999, 0.999))
    return z * std + mean

@torch.no_grad()
def eval_rmse(diffusion, loader):
    """Masked RMSE of one diffusion sample per item over the whole loader.

    This is the unbounded case of `eval_rmse_minibatch`, so it delegates to it
    instead of carrying a second copy of the same loop.

    Args:
        diffusion: model exposing ``sample(cond, clip=False)``.
        loader: yields ``(x0, mask, cond, mean, std)`` batches.

    Returns:
        (rmse_raw, rmse_norm, n_obs): RMSE after undoing tanh and the
        (mean, std) normalisation, RMSE in atanh (z) space, and the number of
        observed cells the errors were averaged over.
    """
    return eval_rmse_minibatch(diffusion, loader, max_batches=None)

@torch.no_grad()
def eval_rmse_minibatch(diffusion, loader, max_batches=None):
    """Masked RMSE of one diffusion sample per item, optionally on a prefix of the loader.

    Args:
        diffusion: model exposing ``sample(cond, clip=False)``.
        loader: yields ``(x0, mask, cond, mean, std)`` batches.
        max_batches: stop after this many batches; None evaluates everything.

    Returns:
        (rmse_raw, rmse_norm, n_obs): RMSE after undoing tanh and the
        (mean, std) normalisation, RMSE in atanh (z) space, and the number of
        observed cells the errors were averaged over.
    """
    sq_err_raw = 0.0
    sq_err_norm = 0.0
    n_obs = 0.0

    for batch_idx, batch in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        x0_te, m_te, c_te, mu_te, std_te = batch

        x0_te  = x0_te.to(DEVICE)   # tanh(z)
        m_te   = m_te.to(DEVICE)
        c_te   = c_te.to(DEVICE)
        mu_te  = mu_te.to(DEVICE)[:,None,None,None]
        std_te = std_te.to(DEVICE)[:,None,None,None]

        xhat = diffusion.sample(c_te, clip=False)             # tanh(z)

        # normalized (z) space
        zhat = torch.atanh(xhat.clamp(-0.999, 0.999))
        zgt  = torch.atanh(x0_te.clamp(-0.999, 0.999))

        # raw space: undo the (mean, std) normalisation
        raw_hat = zhat * std_te + mu_te
        raw_gt  = zgt * std_te + mu_te

        sq_err_raw  += ((raw_hat - raw_gt)**2 * m_te).sum().item()
        n_obs       += m_te.sum().item()
        sq_err_norm += ((zhat - zgt)**2 * m_te).sum().item()

    # The floor keeps the division defined when no cell was observed.
    denom = max(n_obs, 1e-8)
    rmse_raw  = math.sqrt(sq_err_raw / denom)
    rmse_norm = math.sqrt(sq_err_norm / denom)
    return rmse_raw, rmse_norm, int(n_obs)

import math

@torch.no_grad()
def _crps_from_ensemble(y_flat, samples_flat):
    """
    y_flat:        (N,) ground-truth vector (masked entries only later)
    samples_flat:  (K,N) ensemble samples
    returns:       (N,) CRPS per entry
    """
    K = samples_flat.shape[0]
    # term1 = E|X - y|  ≈ (1/K) Σ_i |x_i - y|
    term1 = (samples_flat - y_flat.unsqueeze(0)).abs().mean(dim=0)  # (N,)
    # term2 = 0.5 * E|X - X'| ≈ 0.5 * (1/K^2) Σ_ij |x_i - x_j|
    diffs = samples_flat.unsqueeze(0) - samples_flat.unsqueeze(1)   # (K,K,N)
    term2 = 0.5 * diffs.abs().mean(dim=(0,1))                      # (N,)
    return term1 - term2                                            # (N,)

@torch.no_grad()
def eval_crps_and_points(diffusion, loader, K=10, clip=False):
    """
    Ensemble evaluation: draw K diffusion samples per item and score them.

    Args:
        diffusion: model exposing ``sample(cond, clip=...)``.
        loader: yields ``(x0, mask, cond, mean, std)`` batches.
        K: ensemble size (K full reverse chains per batch).
        clip: forwarded to ``diffusion.sample``.

    Returns a dict of averages over observed (mask == 1) cells:
      CRPS_raw, CRPS_norm, MAE_raw, RMSE_raw, MAE_norm, RMSE_norm, K, n_obs_bins
    "norm" is the atanh (z) space; "raw" additionally undoes the (mean, std)
    normalisation. MAE / RMSE use the ensemble mean as point forecast.
    """
    crps_raw_sum = 0.0
    crps_norm_sum = 0.0
    mae_raw_sum = 0.0
    rmse_raw_sum = 0.0
    mae_norm_sum = 0.0
    rmse_norm_sum = 0.0
    w_sum = 0.0

    for x0_te, m_te, c_te, mu_te, std_te in loader:
        x0_te  = x0_te.to(DEVICE)          # (B,1,H,W), tanh(z)
        m_te   = m_te.to(DEVICE)           # (B,1,H,W) mask
        c_te   = c_te.to(DEVICE)           # (B,20,H,W)
        mu_te  = mu_te.to(DEVICE)[:,None,None,None]
        std_te = std_te.to(DEVICE)[:,None,None,None]

        # Local names avoid shadowing the module-level grid size H / W.
        B, _, h, w = x0_te.shape
        N = B*h*w
        mask_flat = m_te.view(N).bool()

        # K samples
        samples = [diffusion.sample(c_te, clip=clip) for _ in range(K)]   # each (B,1,H,W) tanh(z)
        S = torch.stack(samples, dim=0)                          # (K,B,1,H,W)

        # normalized (z)
        z_gt  = torch.atanh(x0_te.clamp(-0.999, 0.999))          # (B,1,H,W)
        z_smp = torch.atanh(S.clamp(-0.999, 0.999))              # (K,B,1,H,W)

        # raw signal units
        raw_gt  = z_gt * std_te + mu_te                          # (B,1,H,W)
        raw_smp = z_smp * std_te + mu_te                         # (K,B,1,H,W)

        # flatten
        z_gt_f   = z_gt.view(N)
        raw_gt_f = raw_gt.view(N)
        z_smp_f  = z_smp.view(K, N)
        raw_smp_f= raw_smp.view(K, N)

        # CRPS: accumulate the masked SUM directly. The previous mean() * count
        # turned into NaN for a batch without any observed cell and poisoned
        # the running totals.
        crps_norm_sum += _crps_from_ensemble(z_gt_f,   z_smp_f)[mask_flat].sum().item()
        crps_raw_sum  += _crps_from_ensemble(raw_gt_f, raw_smp_f)[mask_flat].sum().item()

        # point forecast = ensemble mean
        z_mean   = z_smp.mean(dim=0).view(N)
        raw_mean = raw_smp.mean(dim=0).view(N)

        # MAE/RMSE (masked)
        mae_norm_sum  += (z_mean - z_gt_f).abs()[mask_flat].sum().item()
        rmse_norm_sum += ((z_mean - z_gt_f)**2)[mask_flat].sum().item()
        mae_raw_sum   += (raw_mean - raw_gt_f).abs()[mask_flat].sum().item()
        rmse_raw_sum  += ((raw_mean - raw_gt_f)**2)[mask_flat].sum().item()

        w_sum += mask_flat.sum().item()

    # The floor keeps every ratio defined when nothing was observed at all.
    denom = max(w_sum, 1e-8)
    CRPS_raw  = crps_raw_sum  / denom
    CRPS_norm = crps_norm_sum / denom
    MAE_raw   = mae_raw_sum   / denom
    MAE_norm  = mae_norm_sum  / denom
    RMSE_raw  = math.sqrt(rmse_raw_sum  / denom)
    RMSE_norm = math.sqrt(rmse_norm_sum / denom)

    return dict(
        CRPS_raw=CRPS_raw, CRPS_norm=CRPS_norm,
        MAE_raw=MAE_raw, RMSE_raw=RMSE_raw,
        MAE_norm=MAE_norm, RMSE_norm=RMSE_norm,
        K=K, n_obs_bins=int(w_sum),
    )



# ============================================================
# 8) Training loop with periodic eval + checkpoint
# ============================================================
# Lowest raw-space RMSE seen so far; a checkpoint is written whenever it improves.
best_rmse = float('inf')
ckpt = os.path.join(RESULT_DIR, 'best_ctx9.pt')

# The loop is step-based: the loader iterator is recreated whenever it runs out.
pbar = tqdm(range(TOTAL_STEPS), desc="Train(ctx9)")
run_loss = 0.0
it = iter(train_loader)

for step in pbar:
    try:
        x0, msk, cond, mu, std = next(it)
    except StopIteration:
        it = iter(train_loader)
        x0, msk, cond, mu, std = next(it)

    # mu / std are not needed for the training loss and stay on the CPU.
    x0   = x0.to(DEVICE, non_blocking=True)
    msk  = msk.to(DEVICE, non_blocking=True)
    cond = cond.to(DEVICE, non_blocking=True)

    opt.zero_grad(set_to_none=True)
    loss = diffusion(x0, msk, cond)
    loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
    opt.step()
    sched.step()

    run_loss += loss.item()
    if (step+1) % LOG_EVERY == 0:
        avg = run_loss / LOG_EVERY
        run_loss = 0.0
        # grab LR robustly
        curr_lr = opt.param_groups[0]["lr"]
        pbar.set_postfix(loss=f"{avg:.4f}", lr=f"{curr_lr:.2e}")


    if (step+1) % EVAL_EVERY == 0:
        diffusion.eval()
        with torch.inference_mode():
            # Evaluates at most 200 test batches; each batch runs the full reverse
            # chain (one network call per diffusion step).
            # rmse_raw, rmse_norm, nobs = eval_rmse(diffusion, test_loader) # FIXME : original eval code
            rmse_raw, rmse_norm, nobs = eval_rmse_minibatch(diffusion, test_loader, max_batches=200)
            # m = eval_crps_and_points(diffusion, test_loader, K=10, clip=False)  # <-- CRPS (and friends)
        # NOTE(review): the 'µg/m³' unit in this message looks carried over from another
        # project; the values are in whatever units the EEG samples are stored in.
        print(f"\n[Step {step+1}] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} (over {nobs} bins)")

        
        # print(f"\n[Step {step+1}] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} "
        #       f"(over {nobs} bins) | CRPS_raw={m['CRPS_raw']:.3f} µg/m³ (K={m['K']})")
        # NOTE(review): the best checkpoint is selected on test_loader data.
        if rmse_raw < best_rmse:
            best_rmse = rmse_raw
            torch.save({'unet': unet.state_dict(),
                        'diff': diffusion.state_dict(),
                        'H': H, 'W': W}, ckpt)
            print(f"  >> Saved best checkpoint @ {ckpt} (RMSE_raw={best_rmse:.3f})")
        diffusion.train()

    if (step+1) % SAVE_SAMPLES_EVERY == 0:
        diffusion.eval()
        with torch.inference_mode():
            # grab a test batch, sample, and save plain grids (tanh -> [0,1] for viewing)
            # A fresh iterator is built each time; test_loader is unshuffled, so this is
            # always its first batch.
            xb, mb, cb, mub, stdb = next(iter(test_loader))
            xhat = diffusion.sample(cb.to(DEVICE)).cpu()
            vis = (xhat + 1.0) * 0.5
            save_image(vis, os.path.join(RESULT_DIR, f"samples_step{step+1}.png"), nrow=4)
        diffusion.train()

print("Done. Best Test RMSE_raw:", best_rmse)

# ============================================================
# 9) Inference + overlay: GT vs Pred (only observed bins)
# ============================================================
def overlay_panel(test_loader, model, save_path=os.path.join(RESULT_DIR, 'ctx9_overlay.png'),
                  back_img='./airdelhi_background.png'):
    """Plot ground truth vs. sampled prediction for the first test batch and save it.

    Takes the first batch from ``test_loader``, samples ``model`` conditioned on the
    batch's conditioning tensor, converts both the sample and the ground truth with
    ``inv_tanh_to_raw`` and draws a 2 x B panel (top row: ground truth, bottom row:
    prediction) in which bins whose mask value is 0 are blanked out.

    Args:
        test_loader: iterable yielding ``(x, mask, cond, mu, std)`` batches.
        model: object exposing ``sample(cond, clip=False)``.
        save_path: destination of the PNG figure.
        back_img: optional background image; if it cannot be read the panel is
            drawn without a background.
    """
    import matplotlib.colors as colors
    from matplotlib.colors import LinearSegmentedColormap

    # The background is best-effort, but only ordinary errors are swallowed so that
    # KeyboardInterrupt / SystemExit still propagate (the original used a bare except).
    try:
        back = plt.imread(back_img)
    except Exception as err:
        print(f"Could not load background image {back_img!r} ({err}); plotting without it.")
        back = None

    cmap0 = LinearSegmentedColormap.from_list('', ['white', 'orange', 'red'])

    with torch.inference_mode():
        xb, mb, cb, mub, stdb = next(iter(test_loader))
        xhat = model.sample(cb.to(DEVICE), clip=False).cpu()  # tanh(z)
        raw_hat = inv_tanh_to_raw(xhat, mub[:,None,None,None], stdb[:,None,None,None]).squeeze(1)
        raw_gt  = inv_tanh_to_raw(xb,   mub[:,None,None,None], stdb[:,None,None,None]).squeeze(1)
        mb      = mb.squeeze(1)

    # Shared colour scale: batch-mean of mu plus three times the batch-mean of std.
    vmax = float(mub.mean() + 3*stdb.mean())
    lon_edges = np.linspace(0,1,W+1); lat_edges = np.linspace(0,1,H+1)

    B = min(8, raw_hat.size(0))
    # squeeze=False keeps `axes` two-dimensional even when B == 1; without it
    # plt.subplots returns a 1-D array and `axes[row, i]` raises IndexError.
    fig, axes = plt.subplots(2, B, figsize=(3.4*B, 6.8), squeeze=False)
    for i in range(B):
        unobserved = mb[i].numpy() == 0
        for row, arr in enumerate([raw_gt[i].numpy(), raw_hat[i].numpy()]):
            ax = axes[row, i]
            if back is not None: ax.imshow(back, extent=[0,1,0,1], alpha=0.6)
            arr_plot = arr.copy()
            arr_plot[unobserved] = np.nan  # show only observed bins
            pm = ax.pcolormesh(lon_edges, lat_edges, arr_plot, cmap=cmap0,
                               norm=colors.Normalize(vmin=0, vmax=vmax))
            # draw grid
            for y in lat_edges: ax.plot([0,1],[y,y], c='k', lw=0.1)
            for x in lon_edges: ax.plot([x,x],[0,1], c='k', lw=0.1)
            ax.set_axis_off()
        axes[0, i].set_title("GT (obs bins)", fontsize=11)
        axes[1, i].set_title("Pred (obs bins)", fontsize=11)

    cbar = fig.colorbar(pm, ax=axes.ravel().tolist(), fraction=0.03, pad=0.02)
    cbar.ax.tick_params(labelsize=10)
    plt.tight_layout(); plt.savefig(save_path, dpi=140, bbox_inches='tight'); plt.close(fig)
    print("Saved overlays to:", save_path)

# ---- Render the GT-vs-prediction overlay with the in-memory `diffusion` model ----
# Uses overlay_panel's default save path and background image.
overlay_panel(test_loader, diffusion)

# ============================================================
# 10) Load best checkpoint & run full test-day eval again
# ============================================================
def load_and_eval(ckpt_path, test_loader):
    """Rebuild the model from a checkpoint, evaluate it and save an overlay figure.

    Args:
        ckpt_path: path of a checkpoint holding the ``'unet'`` and ``'diff'`` state dicts.
        test_loader: loader passed to ``eval_rmse_minibatch`` and ``overlay_panel``.

    Returns:
        The restored ``GaussianDiffusion`` module, in eval mode.
    """
    ck = torch.load(ckpt_path, map_location=DEVICE)
    # NOTE(review): the model is rebuilt from the notebook-level H / W / IN_CHANNELS /
    # TIME_STEPS; the 'H' and 'W' entries stored in the checkpoint are not consulted.
    unet = UNet(base_dim=128, dim_mults=(1,2,4), in_channels=IN_CHANNELS, image_size=(H,W)).to(DEVICE)
    unet.load_state_dict(ck['unet'])
    diff = GaussianDiffusion(unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2').to(DEVICE)
    diff.load_state_dict(ck['diff'])
    diff.eval()

    # Evaluate under inference_mode, exactly as the training loop does for the same
    # call, so that no autograd state is recorded while evaluating.
    with torch.inference_mode():
        # rmse_raw, rmse_norm, nobs = eval_rmse(diff, test_loader) # FIXME : original eval code
        rmse_raw, rmse_norm, nobs = eval_rmse_minibatch(diff, test_loader, max_batches=200)
    print(f"[BEST] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} (over {nobs} bins)")
    overlay_panel(test_loader, diff, save_path=os.path.join(RESULT_DIR, 'ctx9_overlay_best.png'))
    return diff

# Example (after training): reload the checkpoint at `ckpt`, re-evaluate it on the
# test loader and keep the restored diffusion model.
diff_best = load_and_eval(ckpt,test_loader)

Train(ctx9):   0%|          | 185/150000 [00:23<5:15:35,  7.91it/s]


KeyboardInterrupt: 

In [None]:
class EEGToGrid(Dataset):
    """Wrap an EEG dataset so that items come out as (grid, mask, cond, mean, std).

    Work in progress: the grid construction and the conditioning tensor are still
    placeholders (see ``TorchEEG_Grid`` and ``__getitem__``).
    """

    def __init__(self, base_dataset,):
        self.base_dataset = base_dataset
        # Normalisation statistics are read from the wrapped dataset's attributes.
        self.mean = float(self.base_dataset.mean)
        self.std = float(self.base_dataset.std)

    def TorchEEG_Grid(self, channel_list, grid_templete=TORCHEEG_2DGRID, H=11, W=11):
        """
        2D Grid based on TorchEEG 2D Grid
        input 10-10 coord channel name index 
        output is grid of channel input

        NOTE(review): currently a stub - it returns an all-zero (H, W) grid and
        mask and does not use ``channel_list`` or ``grid_templete`` yet.
        """
        grid = torch.zeros(H, W, dtype=torch.float32)
        mask = torch.zeros(H, W, dtype=torch.float32)
        return grid, mask

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        """Return ``(target_grid, target_mask, cond, mean, std)`` for item ``idx``."""
        # Fix: the wrapped dataset is stored as `self.base_dataset`; the original
        # read `self.base`, which is never assigned and raised AttributeError.
        i,o,im,om = self.base_dataset[idx]

        # TODO: build these from (i, o, im, om); they are placeholders for now, so a
        # DataLoader using the default collate function cannot batch these items yet.
        target_grid = None
        target_mask = None
        cond = None

        return target_grid, target_mask, cond, self.mean, self.std

In [None]:
# Settings that are meant to become argparse options later.

# TUEG_1.0 path in lucy's pscratch
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb
LMDB_DIR = "/pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb"
# Batch size passed to both the train and the test DataLoader.
BATCH_SIZE = 64

In [None]:
# Set Train and Test dataset and dataloader

# Both splits are built with the same fold count and seed; only `train` differs.
train_eeg = Physio_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxFolds=5,
                         seed=41,
                         train=True,)
test_eeg = Physio_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxFolds=5,
                         seed=41,
                         train=False,
                         )

train_set = EEGToGrid(train_eeg)
test_set = EEGToGrid(test_eeg)

# Fix: the DataLoader keyword is `num_workers`; `num_worker` raises
# "TypeError: __init__() got an unexpected keyword argument".
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

In [None]:
# NOTE: this rebinds LMDB_DIR, replacing the TUEG path assigned in the earlier
# configuration cell; cells executed after this one see this path.
LMDB_DIR = "/pscratch/sd/t/tylee/Dataset/1109_Physio_500Hz"


In [None]:
class EEG_from_lmdb(Dataset):
    """Dataset reading pickled ``{'sample', 'label', 'data_info'}`` records from LMDB.

    Work in progress. ``lmdb_to_data`` is the LMDB-backed reader.

    NOTE(review): ``self.keys`` is read by ``lmdb_to_data`` but is not assigned in
    this class, and ``__len__`` / ``__getitem__`` use attributes (``data``,
    ``target``, ``normalize_z``, ``latmin`` ...) that are not defined here either -
    they look copied from the ``Delhi`` dataset. They are left as they were because
    the intended replacement cannot be determined from this cell.
    """

    def __init__(self, data_dir, transform, return_info):
        self.data_dir = data_dir
        self.transform = transform
        self.return_info = return_info
        # The LMDB environment is opened on first read and then reused.
        self.db = None

    def _get_env(self):
        """Open the LMDB environment once and return the cached handle."""
        if getattr(self, 'db', None) is None:
            self.db = lmdb.open(self.data_dir, readonly=True, lock=False, readahead=True, meminit=False)
        return self.db

    def lmdb_to_data(self, idx):
        """Load record ``self.keys[idx]`` and return ``(data / 100, label[, data_info])``.

        Raises:
            KeyError: if the key is not present in the LMDB store.
        """
        # Fix: the original called lmdb.open() on every read, creating a new
        # environment handle per sample that was never closed.
        env = self._get_env()
        key = self.keys[idx]
        # Accept both str keys (encoded here) and keys that are already bytes.
        raw_key = bytes(key) if isinstance(key, (bytes, bytearray)) else key.encode()
        with env.begin(write=False) as txn:
            payload = txn.get(raw_key)
        if payload is None:
            raise KeyError(f"Key not found in LMDB store: {key!r}")
        # SECURITY: pickle executes arbitrary code on load - only read trusted stores.
        pair = pickle.loads(payload)
        data = pair['sample']
        label = pair['label']
        data_info = pair.get('data_info', {})
        
        data = to_tensor(data)
        if self.transform is not None:
            data = self.transform(data)
        # NOTE(review): the fixed division by 100 presumably rescales the signal
        # amplitude - confirm the intended unit.
        if self.return_info:
            return data/100, label, data_info
        else:
            return data/100, label
        
    def __len__(self):
        # NOTE(review): `self.data` is not set by this class (see class docstring).
        return len(self.data)

    def __getitem__(self, idx):
        # NOTE(review): unchanged copy of Delhi.__getitem__; it does not call
        # lmdb_to_data and relies on attributes this class does not define.
        input_, target_ = self.data[idx], self.target[idx]
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)
        
        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])
        
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440
        
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)
        
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)
        
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)
        
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]
        
        return i, o, im, om

In [None]:
# Implementing torch.dataset

def xform_day(day):
    """Convert a running day number into a 'YYYY-MM-DD' date string.

    Days up to 30 map to November 2020, days 31-61 to December 2020 and anything
    above 61 to January 2021; the day of month is the day number minus the month's
    offset (0, 30 or 61).
    """
    if day <= 30:
        prefix, offset = '2020-11-', 0
    elif day <= 61:
        prefix, offset = '2020-12-', 30
    else:
        prefix, offset = '2021-01-', 61
    return f"{prefix}{day - offset:02d}"


def get_suffixes(mode):
    """Return the file suffixes selected by ``mode``.

    'train' is included when the mode contains 'C' or 'A', 'test' when it contains
    'D' or 'B'; the result keeps that order.
    """
    selected = []
    for suffix, letters in (('train', ('C', 'A')), ('test', ('D', 'B'))):
        if any(letter in mode for letter in letters):
            selected.append(suffix)
    return selected

def rename_cols(data):
    """Rename the raw sensor columns of ``data`` in place; returns nothing."""
    column_map = {
        'dateTime': 'time',
        'lat': 'latitude',
        'long': 'longitude',
        'pm2_5': 'PM25_Concentration',
        'pm10': 'PM10_Concentration',
    }
    data.rename(columns=column_map, inplace=True)



def torch1dgrid(num, bot=0, top=1):
    """Return ``num`` evenly spaced points from ``bot`` to ``top`` as a 1-D tensor."""
    points = torch.linspace(bot, top, steps=num)
    # Add a trailing axis and drop it again, so the result has shape (num,).
    column = points.unsqueeze(1)
    return column.squeeze(-1)

import torch
from torch.utils.data import Dataset
from einops import rearrange        
class Delhi(Dataset):
    def __init__(
        self,
        mode_t,
        mode_p,
        canada,
        train_days,
        maxFolds = 5,
        target_fold = 0,
        temporal_scaling=1,
        spatiotemporal=1,
        data_dir='/pscratch/sd/d/dpark1/AirDelhi/delhi/processed',
        seed=10,
        nTrainStartDay = 15,
        nTestStartDay = 75,
        nTotalDays = 91,
        train=True):
        """Store the configuration and build the windowed samples for ``target_fold``.

        ``mode_t`` / ``mode_p`` are letter codes that select which per-day CSV files
        are read (see ``get_suffixes`` and ``return_data_time``). ``canada`` is
        accepted but not used here. ``seed`` is applied to numpy's global random
        state. The samples themselves are produced by ``proc_custom``.
        """
        # Plain configuration, stored as given.
        self.mode_t, self.mode_p = mode_t, mode_p
        self.train_days = train_days
        self.train = train
        self.maxFolds = maxFolds
        self.target_fold = target_fold
        self.temporal_scaling = temporal_scaling
        self.spatiotemporal = spatiotemporal
        self.data_dir = data_dir
        self.nTestStartDay = nTestStartDay
        self.nTrainStartDay = nTrainStartDay
        self.nTotalDays = nTotalDays

        np.random.seed(seed)

        self.train_suffix = get_suffixes(mode_t)

        # Forecasting, single fold is enough. Only the local fold count is reduced;
        # self.maxFolds keeps the value that was passed in.
        forecasting = spatiotemporal < 0 and mode_t == 'AB' and mode_p == 'CD'
        if forecasting:
            maxFolds = 1

        self.folds = list(range(maxFolds))

        self.data, self.target = self.proc_custom(target_fold)
        
        
        
    def get_normalize_params(self, target):
        all_signal = []
        for a in target:
            all_signal += list(a[..., 0])
        self.mean, self.std = np.array(all_signal).mean(), np.array(all_signal).std()
    
    def get_spatial_norm_parameters(self, arr_of_days):
        """"minmax normalization"""
        latmin = 10e10
        latmax = -10e10
        longmin = 10e10
        longmax = -10e10
        for arr in arr_of_days:
            minned = arr.min(0)
            # print(minned[0])
            if minned[2] < latmin:
                latmin = minned[2]
            if minned[3] < longmin:
                longmin = minned[3]
                
            maxed = arr.max(0)
            # print(maxed)
            if maxed[2] > latmax:
                latmax = maxed[2]
            if maxed[3] > longmax:
                longmax = maxed[3]
        self.latmin, self.latmax, self.longmin, self.longmax =latmin, latmax, longmin, longmax
    
    def make_data_by_time(self, arr_of_days, t_in = 9, reverse=False, day = 0):
        seg_by_time = []
        uniq_times = np.unique(arr_of_days[..., 1])
        
        for t in uniq_times:
            idx_ = arr_of_days[..., 1] == t
            seg_by_time.append(arr_of_days[idx_])
        
        in_ = []
        out_ = []
        for i in range(len(seg_by_time) - t_in):
            temp_in = []
            for t_ in range(t_in):
                temp_in.append(seg_by_time[i + t_])
            
            # normalize time to relative scale by the last one of the encoder
            in_cand = np.copy(np.concatenate(temp_in, axis=0))
            out_cand = np.copy(seg_by_time[i + t_in])
            
            last_enc_t = in_cand[..., 1][-1]
            in_cand[..., 1] -= last_enc_t
            out_cand[..., 1] -= last_enc_t
            in_.append(in_cand)
            out_.append(out_cand)
            self.day_record.append(day)
            
            # reverse it
            if reverse:
                out_cand = np.copy(seg_by_time[i])
                temp_in = temp_in[1:]
                temp_in.append(seg_by_time[i + t_in])
                in_cand = np.copy(np.concatenate(temp_in, axis=0))
                last_enc_t = in_cand[..., 1][0]
                in_cand[..., 1] -= last_enc_t
                out_cand[..., 1] -= last_enc_t
                
                in_.append(in_cand)
                out_.append(out_cand)
                self.day_record.append(day)
        
        
        return in_, out_
    
    def proc_custom(self, fold):
        """Build the windowed train and test samples for one fold.

        For every day of the train range [nTrainStartDay, nTestStartDay) and of the
        test range [nTestStartDay, nTotalDays] the per-day CSV data is loaded through
        ``process_np`` and cut into (input, target) windows by ``make_data_by_time``.
        The normalisation statistics (``mean``/``std`` and the spatial min/max) are
        then computed from the train targets.

        Args:
            fold: fold index used in the CSV file names.

        Returns:
            (data, target): the train lists when ``self.train`` is true, otherwise
            the test lists.
        """
        
        # make_data_by_time appends one entry per generated window.
        self.day_record = []
        
        train_data = {'input':[], 'target':[]}
        test_data = {'input':[], 'target':[]}
        
        for day in range(self.nTrainStartDay, self.nTestStartDay):
            # Dates from `train_days` days back up to and including `day`.
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            # Prepend the pm2_5 values as column 0 -> [pm2_5, dateTime, lat, long].
            train_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            train_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day
            
            seg_in, seg_out = self.make_data_by_time(train_in, day = day)
            
            train_data['input'] += seg_in
            train_data['target'] += seg_out            
        
        
        
        # NOTE(review): this runs once, after the loop, so only the `train_out` of the
        # last train day is windowed (recorded with day=0), and it raises NameError if
        # the train range is empty - confirm this is intended.
        seg_in, seg_out = self.make_data_by_time(train_out)
        train_data['input'] += seg_in
        train_data['target'] += seg_out
            
        
        for day in range(self.nTestStartDay, self.nTotalDays+1):
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            test_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            test_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day

            # No `day` is passed here, so test windows are recorded with day=0.
            seg_in, seg_out = self.make_data_by_time(test_in, reverse = False)
            
            test_data['input'] += seg_in
            test_data['target'] += seg_out            
            
        # NOTE(review): as above, only the last test day's `test_out` is windowed.
        seg_in, seg_out = self.make_data_by_time(test_out, reverse = False)
        test_data['input'] += seg_in
        test_data['target'] += seg_out

        # Statistics always come from the train targets, also for the test split.
        self.get_normalize_params(train_data['target']) 
        self.get_spatial_norm_parameters(train_data['target'])
            
        if self.train:
            data = train_data['input']
            target = train_data['target']
            print(len(data), len(target))
            
        else:
            data = test_data['input']
            target = test_data['target']
            print(len(data), len(target))

        return data, target        
    
    

    def process_np(self, fold, date):
        tmStart = datetime.datetime.now()
        train_input,train_output,test_input,test_output = self.return_data_time(fold=fold, data=date, with_scaling=True)
        return train_input,train_output,test_input,test_output
    
    def return_data_time(self, fold, data, with_scaling):
        train_input = None
        if 'A' in self.mode_t or 'B' in self.mode_t:
            for idx,dt in enumerate(data[:-1]):
                for suffix in self.train_suffix:
                    input = pd.read_csv(self.data_dir+'/'+dt+'_f'+str(fold)+'_'+suffix+'.csv')
                    # if self.temporal_scaling:
                    #     input.dateTime += idx * 24 * 60
                    train_input = pd.concat((train_input, input))
                    
        if 'C' in self.mode_t:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            # if self.temporal_scaling:
            #     input.dateTime += (len(data)-1) * 24 * 60
            train_input = pd.concat((train_input, input))

        test_input = pd.read_csv(self.data_dir+'/'+data[-1]+'_f'+str(fold)+'_test.csv')
        
        if 'C' in self.mode_p:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            test_input = pd.concat((input, test_input))
            
        # if self.temporal_scaling:
        #     test_input.dateTime += (len(data)-1) * 24 * 60

        return self.return_data_0(train_input, test_input, with_scaling)

    
    def return_data_0(self, train_input, test_input, with_scaling):
        train_output = np.array(train_input['pm2_5'])
        train_input = train_input[['dateTime','lat','long']]
        test_output = np.array(test_input['pm2_5'])
        test_input = test_input[['dateTime','lat','long']]

        # if with_scaling:
        #     scaler = MinMaxScaler().fit(train_input)
        #     if self.temporal_scaling:
        #         data = scaler.transform(pd.concat((train_input, test_input)))
        #         test_input = data[len(train_input):]
        #         train_input = data[:len(train_input)]
        #     else:
        #         train_input = scaler.transform(train_input)
        #         test_input = scaler.transform(test_input)
        return train_input,train_output,test_input,test_output

    def set_target_fold(self, fold=0):
        self.fold = fold
        print('target fold set to {}'.format(self.fold))
        
    def normalize_z(self, arr):
        return (arr - self.mean) / self.std
    
    def normalize(self, data, min_, max_):
        return (data - min_) / (max_ - min_)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_, target_ = self.data[idx], self.target[idx]
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)
        
        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])
        
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440
        
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)
        
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)
        
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)
        
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]
        
        return i, o, im, om