In [1]:
import lmdb
import datetime
import argparse
import pandas as pd
import numpy as np
import random
import re
from collections import defaultdict
from typing import List, Tuple, Union

import scipy.io
import pickle
import numpy as np
import os
import h5py

import torch
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
from tqdm import tqdm

In [2]:
def to_tensor(array):
    return torch.from_numpy(array).float()

In [3]:
def random_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    print(f'set seed {seed} is done')

In [4]:
KeyT = Union[str, bytes, bytearray]

_KEY_RE = re.compile(
   r"^S(?P<sub_id>\d{3})R\d{2}-\d+$"
)

# S012R04-21

def _decode_key(k: KeyT) -> str:
    if isinstance(k, (bytes, bytearray)):
        return k.decode("utf-8", errors="ignore")
    return k

def _extract_sub_id(k: KeyT) -> str:
    s = _decode_key(k)
    m = _KEY_RE.match(s)
    if m is None:
        raise ValueError(f"Key does not match expected patterns: {s}")
    return m.group("sub_id")


def train_test_split_by_fold_num(
    fold_num: int,
    lmdb_keys: List[KeyT],
    maxFold: int,
    split_by_sub: bool = True,
    seed: int = 41
) -> Tuple[List[KeyT], List[KeyT]]:
    """
    True k-fold cross-validation split.

    Args:
        fold_num: test fold index (0 <= fold_num < maxFold)
        lmdb_keys: LMDB key list
        maxFold: total number of folds (k)
        split_by_sub: True → subject-wise k-fold, False → key-wise k-fold

    Returns:
        train_key_list, test_key_list
    """
    if maxFold < 2:
        raise ValueError("maxFold must be >= 2.")
    if fold_num < 0 or fold_num >= maxFold:
        raise ValueError(f"fold_num must be in [0, {maxFold-1}]")

    keys = list(lmdb_keys)

    rng = np.random.default_rng(seed)

    if split_by_sub:
        # -------- subject-wise k-fold --------
        sub_to_keys = defaultdict(list)
        invalid = []

        for k in keys:
            try:
                sid = _extract_sub_id(k)
                sub_to_keys[sid].append(k)
            except ValueError:
                invalid.append(_decode_key(k))

        if invalid:
            ex = "\n".join(invalid[:10])
            raise ValueError(
                f"Found {len(invalid)} invalid keys. Examples:\n{ex}"
            )

        subjects = np.array(list(sub_to_keys.keys()), dtype=object)
        rng.shuffle(subjects)

        subj_folds = np.array_split(subjects, maxFold)
        test_subjects = set(subj_folds[fold_num].tolist())

        train_keys, test_keys = [], []
        for sid, ks in sub_to_keys.items():
            (test_keys if sid in test_subjects else train_keys).extend(ks)

        return train_keys, test_keys

    else:
        # -------- key-wise k-fold --------
        idx = np.arange(len(keys))
        rng.shuffle(idx)

        folds = np.array_split(idx, maxFold)
        test_idx = set(folds[fold_num].tolist())

        train_keys = [keys[i] for i in idx if i not in test_idx]
        test_keys  = [keys[i] for i in idx if i in test_idx]

        return train_keys, test_keys

In [5]:
LMDB = '/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID'

DB = lmdb.open(LMDB, readonly=True, lock=False, readahead=True, meminit=False)
with DB.begin(write=False) as txn:
    KEYS = pickle.loads(txn.get('__keys__'.encode()))

In [6]:
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb
# 아현썜 pscratch의 데이터 경로 당장은 그냥 써도 되지만 추후 내 pscratch나 m4727 등으로 옮겨서 사용할 것

class Physio_for_SOLID_from_lmdb(Dataset):
    def __init__(
            self,
            lmdb_dir: str,
            maxfold: int,
            targetfold: int,
            seed: int,
            train: bool,
            split_by_sub: bool,
    ):
        random_seed(seed)
        self.seed = seed
        self.lmdb_dir = lmdb_dir
        self.db = lmdb.open(lmdb_dir, readonly=True, lock=False, readahead=True, meminit=False)
        with self.db.begin(write=False) as txn:
            self.lmdb_keys = pickle.loads(txn.get('__keys__'.encode()))

        self.train = train
        self.split_by_sub = split_by_sub

        self.maxfold = maxfold
        self.targetfold = targetfold
        self.data, self.target, self.data_meta, self.target_meta = self.make_data_and_target_by_fold(self.targetfold, self.lmdb_keys, 
                                                                   self.maxfold, self.split_by_sub, self.seed)

    def make_data_and_target_by_fold(self, fold, lmdb_keys, maxfold, split_by_sub, seed):
        self.record = []

        train_data = {'input': [], 'target': [], 'input_meta': [], 'target_meta': []}
        test_data  = {'input': [], 'target': [], 'input_meta': [], 'target_meta': []}

        train_data_keys_in_lmdb, test_data_keys_in_lmdb = train_test_split_by_fold_num(fold, lmdb_keys, maxfold, split_by_sub, seed)


        if self.train:
            for train_data_key in train_data_keys_in_lmdb:

                # TODO : get proper seg_in and seg_out by input idx
                seg_in, seg_out, seg_in_meta, seg_out_meta = self.segmentation_from_idx(train_data_key, self.db)

                train_data['input'] += seg_in
                train_data['target'] += seg_out
                train_data['input_meta'] += seg_in_meta
                train_data['target_meta']+= seg_out_meta

            return (train_data['input'], train_data['target'],
                    train_data['input_meta'], train_data['target_meta'])
        
        else:
            for test_data_key in test_data_keys_in_lmdb:

                seg_in, seg_out, seg_in_meta, seg_out_meta = self.segmentation_from_idx(test_data_key, self.db)

                test_data['input'] += seg_in
                test_data['target'] += seg_out
                test_data['input_meta']  += seg_in_meta
                test_data['target_meta'] += seg_out_meta
                
            return (test_data['input'], test_data['target'],
                    test_data['input_meta'], test_data['target_meta'])



    def lmdb_get(self, env, key):
        if isinstance(key, str):
            key = key.encode("utf-8")
        with env.begin(write=False) as txn:
            v = txn.get(key)
        if v is None:
            raise KeyError(f"Key not found: {key}")
        return pickle.loads(v)

    def segmentation_from_idx(self, key, lmdb_db, in_len=99, out_len=1, stride=1):
        """
        eeg_data_ : (C, L)  e.g. (64, 800)

        sliding:
          input  = eeg[:, t:t+99]
          target = eeg[:, t+99:t+100]  (1 step)
        """
        sample_for_key = self.lmdb_get(lmdb_db, key)

        channel_name = sample_for_key['data_info']['channel_names']  # len=C
        eeg_data = sample_for_key['sample']                         # (C, T, Fs)
        eeg_data_ = rearrange(eeg_data, 'c t f -> c (t f)')         # (C, L)
        eeg_data_ = torch.from_numpy(eeg_data_).to(torch.float32)

        C, L = eeg_data_.shape
        total_needed = in_len + out_len

        if L < total_needed:
            return [], [], [], []

        seg_in_list = []
        seg_out_list = []
        seg_in_meta_list = []
        seg_out_meta_list = []

        for t in range(0, L - total_needed + 1, stride):
            x = eeg_data_[:, t : t + in_len]                         # (C, 99)
            y = eeg_data_[:, t + in_len : t + in_len + out_len]      # (C, 1)

            seg_in_list.append(x)
            seg_out_list.append(y)

            seg_in_meta_list.append(channel_name)
            seg_out_meta_list.append(channel_name)

        return seg_in_list, seg_out_list, seg_in_meta_list, seg_out_meta_list

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        i  = self.data[idx]
        o  = self.target[idx]
        im = self.data_meta[idx]
        om = self.target_meta[idx]
        return i, o, im, om

In [7]:
TORCHEEG_2DGRID = [
    ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'],
    ['-', '-', '-', '-', 'FP1', 'FPZ', 'FP2', '-', '-', '-', '-'],
    ['-', '-', 'AF7', '-', 'AF3', 'AFZ', 'AF4', '-', 'AF8', '-', '-'],
    ['F9', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'F10'],
    ['FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10'], 
    ['T9', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'T10'],
    ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10'], 
    ['P9', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'P10'],
    ['-', '-', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', '-', '-'],
    ['-', '-', '-', 'CB1', 'O1', 'OZ', 'O2', 'CB2', '-', '-', '-'],
    ['-', '-', '-', '-', '-', 'IZ', '-', '-', '-', '-', '-']
    ]

In [8]:
def build_channel_to_rc(grid_2d):
    ch2rc = {}
    H = len(grid_2d)
    W = len(grid_2d[0])
    for r in range(H):
        for c in range(W):
            ch = grid_2d[r][c]
            if ch != '-' and ch is not None:
                ch2rc[str(ch).strip().upper()] = (r, c)
    return ch2rc, H, W

CHANNEL_TO_RC, H, W = build_channel_to_rc(TORCHEEG_2DGRID)


In [9]:
def splat_eeg_grid(eeg_cl, channel_names, channel_to_rc=CHANNEL_TO_RC, H=H, W=W):
    """
    eeg_cl: (C, L) torch.Tensor (권장)
    channel_names: list[str] len=C
    returns:
      grid: (L, H, W)
      mask: (H, W)
    """
    if not torch.is_tensor(eeg_cl):
        eeg_cl = torch.as_tensor(eeg_cl)

    assert eeg_cl.dim() == 2, f"Expected (C,L), got {tuple(eeg_cl.shape)}"
    C, L = eeg_cl.shape
    device = eeg_cl.device

    grid = torch.zeros((L, H, W), dtype=eeg_cl.dtype, device=device)
    cnt  = torch.zeros((H, W), dtype=torch.float32, device=device)

    for ci in range(C):
        ch = str(channel_names[ci]).strip().upper()
        if ch in channel_to_rc:
            r, c = channel_to_rc[ch]
            grid[:, r, c] += eeg_cl[ci, :]
            cnt[r, c] += 1.0

    mask = (cnt > 0).float()
    grid = torch.where(cnt > 0, grid / torch.clamp(cnt, min=1.0), grid)
    return grid, mask


In [10]:
class EEGToGridCtx9(Dataset):
    """
    base[idx] -> (i, o, im, om)
      i : (C, 99)   (torch or numpy)
      o : (C, 1)
      im: channel list
      om: channel list

    return:
      x0       : (1,H,W)
      tgt_mask : (1,H,W)
      cond     : (20,H,W) = [lat_map, lon_map, past_grids(9), past_masks(9)]
      mean, std
    """
    def __init__(self, base_dataset, squash_tanh=True, channel_to_rc=CHANNEL_TO_RC):
        self.base = base_dataset
        self.squash = squash_tanh
        self.channel_to_rc = channel_to_rc

        self.mean = float(getattr(self.base, "mean", 0.0))
        self.std  = float(getattr(self.base, "std",  1.0))

        lat = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)
        lon = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)
        self.lat_map = lat
        self.lon_map = lon

        self.ctx_steps = 9
        self.in_len = 99
        assert self.in_len % self.ctx_steps == 0, "can not dividable"
        self.bin = self.in_len // self.ctx_steps  # 11

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        i, o, im, om = self.base[idx]  # i:(C,99), o:(C,1), im/om: channel list

        if not torch.is_tensor(i): i = torch.as_tensor(i)
        if not torch.is_tensor(o): o = torch.as_tensor(o)

        # ---- build past 9 step grids/masks from encoder ----
        past_grids, past_masks = [], []

        for k in range(self.ctx_steps):
            seg = i[:, k*self.bin:(k+1)*self.bin]      # (C,11)
            vals = seg.mean(dim=1, keepdim=True)       # (C,1)

            grid_k, mask_k = splat_eeg_grid(vals, im, self.channel_to_rc, H, W)  # grid:(1,H,W)
            grid_k = grid_k.squeeze(0)  # (H,W)

            if self.squash:
                grid_k = torch.tanh(grid_k)

            past_grids.append(grid_k)
            past_masks.append(mask_k)

        past_grids = torch.stack(past_grids, 0)  # (9,H,W)
        past_masks = torch.stack(past_masks, 0)  # (9,H,W)

        # ---- target (next slot) grid/mask ----
        tgt_grid, tgt_mask = splat_eeg_grid(o, om, self.channel_to_rc, H, W)  # tgt_grid:(1,H,W)
        tgt_grid = tgt_grid.squeeze(0)    # (H,W)

        x0 = torch.tanh(tgt_grid) if self.squash else tgt_grid

        # ---- cond channels ----
        cond = torch.cat([
            self.lat_map.unsqueeze(0),  # (1,H,W)
            self.lon_map.unsqueeze(0),  # (1,H,W)
            past_grids,                 # (9,H,W)
            past_masks                  # (9,H,W)
        ], dim=0)  # (20,H,W)

        return x0.unsqueeze(0), tgt_mask.unsqueeze(0), cond, self.mean, self.std


In [28]:
sample_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=True,
                                            split_by_sub=True)

set seed 41 is done


In [33]:
grid_ds = EEGToGridCtx9(sample_dataset)

x, mask, cond, mean, std = grid_ds[0]
print(x.shape, mask.shape, cond.shape, mean, std)

torch.Size([1, 11, 11]) torch.Size([1, 11, 11]) torch.Size([20, 11, 11]) 0.0 1.0


In [12]:
sample_train_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=True,
                                            split_by_sub=True)

sample_test_dataset = Physio_for_SOLID_from_lmdb(lmdb_dir='/pscratch/sd/t/tylee/Dataset/PhysioNet_200Hz_for_SOLID',
                                            maxfold = 5,
                                            targetfold=0,
                                            seed=41,
                                            train=False,
                                            split_by_sub=True)

train_grid_dataset = EEGToGridCtx9(sample_train_dataset)
test_grid_dataset = EEGToGridCtx9(sample_test_dataset)

set seed 41 is done
set seed 41 is done


In [13]:
train_loader = DataLoader(train_grid_dataset, batch_size=16, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_grid_dataset, batch_size=16, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

In [14]:
import os, math, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image, make_grid
from einops import rearrange
from tqdm import tqdm
import matplotlib.pyplot as plt

In [18]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RESULT_DIR = '/pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220'
os.makedirs(RESULT_DIR, exist_ok=True)

BATCH_SIZE = 16
LR = 2e-4
TIME_STEPS = 1000                  # diffusion T
TOTAL_STEPS = 150_000
LOG_EVERY  = 200
EVAL_EVERY = 1000
SAVE_SAMPLES_EVERY = 1000
BG_WEIGHT = 0   

In [19]:
# Model implementation from Kevin
print(len(train_grid_dataset))
print(len(test_grid_dataset))

# ============================================================
# 3) UNet (no attention; rectangular-friendly)
# ============================================================
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=0.0, groups=32):
        super().__init__()
        dim_out = dim if dim_out is None else dim_out
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if time_emb_dim else None
        self.norm1 = nn.GroupNorm(groups, dim);     self.conv1 = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, dim_out); self.conv2 = nn.Conv2d(dim_out, dim_out, 3, padding=1)
        self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
        self.act = nn.SiLU()
        self.res = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    def forward(self, x, t_emb=None):
        h = self.conv1(self.act(self.norm1(x)))
        if self.mlp is not None and t_emb is not None:
            h = h + self.mlp(t_emb)[..., None, None]
        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return h + self.res(x)


class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, base_dim):
        super().__init__()
        self.out_dim = base_dim
    def forward(self, t):  # t: (B,)
        # classic transformer-style PE on scalar t
        half = self.out_dim // 2
        device = t.device
        freqs = torch.exp(torch.arange(half, device=device).float()
                          * -(math.log(10000.0) / max(1, half-1)))
        ang = t.float().unsqueeze(1) * freqs.unsqueeze(0)  # (B, half)
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=1)  # (B, 2*half)
        if emb.shape[1] < self.out_dim:
            emb = F.pad(emb, (0, self.out_dim - emb.shape[1]))
        return emb

class TimeMLP(nn.Module):
    def __init__(self, base_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, base_dim*4), nn.SiLU(),
            nn.Linear(base_dim*4, base_dim*4)
        )
    def forward(self, t):  # (B,)
        return self.net(t[:,None].float())

class UNet(nn.Module):
    def __init__(self, base_dim=128, dim_mults=(1,2,4),
                 in_channels=1+20, image_size=(H,W), dropout=0.0, groups=32):
        super().__init__()
        self.image_h, self.image_w = image_size
        self.time_dim = base_dim * 4

        # self.time_pe  = SinusoidalTimeEmbedding(base_dim)
        # self.time_mlp = nn.Sequential(
        #     nn.Linear(base_dim, self.time_dim),
        #     nn.SiLU(),
        #     nn.Linear(self.time_dim, self.time_dim)
        # )
        self.time_mlp = TimeMLP(base_dim)
        self.init = nn.Conv2d(in_channels, base_dim, 3, padding=1)

        self.downs = nn.ModuleList()
        in_ch = base_dim
        skip_channels = []
        for li, m in enumerate(dim_mults):
            out_ch = base_dim * m
            rb1 = ResnetBlock(in_ch, out_ch, self.time_dim, dropout, groups); self.downs.append(rb1); in_ch = out_ch; skip_channels.append(in_ch)
            rb2 = ResnetBlock(in_ch, out_ch, self.time_dim, dropout, groups); self.downs.append(rb2); in_ch = out_ch; skip_channels.append(in_ch)
            if li != len(dim_mults) - 1:
                self.downs.append(nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1))

        self.mid1 = ResnetBlock(in_ch, in_ch, self.time_dim, dropout, groups)
        self.mid2 = ResnetBlock(in_ch, in_ch, self.time_dim, dropout, groups)

        self.ups, self.kinds = nn.ModuleList(), []
        sc = skip_channels.copy()
        for li, m in enumerate(reversed(dim_mults)):
            out_ch = base_dim * m
            for _ in range(2):
                skip_ch = sc.pop()
                self.ups.append(ResnetBlock(in_ch + skip_ch, out_ch, self.time_dim, dropout, groups)); self.kinds.append('res')
                in_ch = out_ch
            if li != len(dim_mults) - 1:
                self.ups.append(nn.Upsample(scale_factor=2, mode='nearest')); self.kinds.append('up')
                self.ups.append(nn.Conv2d(in_ch, in_ch, 3, padding=1));       self.kinds.append('conv')

        self.final = nn.Sequential(nn.GroupNorm(groups, in_ch), nn.SiLU(), nn.Conv2d(in_ch, 1, 3, padding=1))

    def forward(self, x_cat, t):
        # t_emb = self.time_mlp(self.time_pe(t))
        t_emb = self.time_mlp(t)
        skips, h = [], self.init(x_cat)
        for layer in self.downs:
            if isinstance(layer, ResnetBlock):
                h = layer(h, t_emb); skips.append(h)
            else:
                h = layer(h)
        h = self.mid1(h, t_emb); h = self.mid2(h, t_emb)
        for kind, layer in zip(self.kinds, self.ups):
            if kind == 'res':
                s = skips.pop()
                if s.shape[-2:] != h.shape[-2:]:
                    s = F.interpolate(s, size=h.shape[-2:], mode='nearest')
                h = layer(torch.cat([h, s], dim=1), t_emb)
            elif kind == 'up':
                h = layer(h)
            else:
                h = layer(h)
        if h.shape[-2:] != (self.image_h, self.image_w):
            h = F.interpolate(h, size=(self.image_h, self.image_w), mode='nearest')
        return self.final(h)

# ============================================================
# 4) Diffusion core — noise only target channel; cond is clean
# ============================================================
class GaussianDiffusion(nn.Module):
    def __init__(self, unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2'):
        super().__init__()
        self.unet = unet
        self.H, self.W = image_size
        self.T = time_steps
        self.loss_type = loss_type

        beta  = self.linear_beta_schedule(time_steps)
        alpha = 1. - beta
        abar  = torch.cumprod(alpha, dim=0)
        abar_prev = F.pad(abar[:-1], (1,0), value=1.)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', abar)
        self.register_buffer('alpha_bar_prev', abar_prev)
        self.register_buffer('sqrt_alpha_bar', torch.sqrt(abar))
        self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - abar))
        self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / abar))
        self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / abar - 1))
        self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha))
        self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - abar))

    def linear_beta_schedule(self, T):
        scale = 1000 / T
        return torch.linspace(scale*1e-4, scale*2e-2, T, dtype=torch.float32)

    def q_sample(self, x0, t, noise):
        return self.sqrt_alpha_bar[t][:,None,None,None] * x0 + \
               self.sqrt_one_minus_alpha_bar[t][:,None,None,None] * noise

    def forward(self, x0, mask, cond):
        """
        x0:   (B,1,H,W) in tanh(z) space
        mask: (B,1,H,W)  (1=observed bin in target; 0=unobserved)
        cond: (B,20,H,W) = [lat, lon, past_grids(9), past_masks(9)]
        """
        b = x0.size(0)
        t = torch.randint(0, self.T, (b,), device=x0.device).long()

        noise = torch.randn_like(x0)
        x_t   = self.q_sample(x0, t, noise)
        x_cat = torch.cat([x_t, cond], dim=1)  # noised target + clean cond

        pred = self.unet(x_cat, t)  # predict noise on target channel

        if self.loss_type == 'l1':
            raw = F.l1_loss(noise, pred, reduction='none')
        elif self.loss_type == 'l2':
            raw = F.mse_loss(noise, pred, reduction='none')
        else:
            raw = F.smooth_l1_loss(noise, pred, reduction='none')

        w = mask + BG_WEIGHT  # supervise observed bins + tiny everywhere
        return (raw * w).sum() / (w.sum() + 1e-8)

    @torch.inference_mode()
    def p_sample(self, xt, cond, t, clip=True):
        bt = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        x_cat = torch.cat([xt, cond], dim=1)
        pred_noise = self.unet(x_cat, bt)

        def bcast(x): return x.view(-1,1,1,1)
        if clip:
            x0 = bcast(self.sqrt_recip_alpha_bar[bt]) * xt - bcast(self.sqrt_recip_alpha_bar_min_1[bt]) * pred_noise
            x0 = x0.clamp(-1., 1.)
            c1 = self.beta[bt] * torch.sqrt(self.alpha_bar_prev[bt]) / (1. - self.alpha_bar[bt])
            c2 = torch.sqrt(self.alpha[bt]) * (1. - self.alpha_bar_prev[bt]) / (1. - self.alpha[bt])
            mean = bcast(c1) * x0 + bcast(c2) * xt
        else:
            mean = bcast(self.sqrt_recip_alpha[bt]) * (xt - bcast(self.beta_over_sqrt_one_minus_alpha_bar[bt]) * pred_noise)
        var = self.beta[bt] * ((1. - self.alpha_bar_prev[bt]) / (1. - self.alpha_bar[bt]))
        noise = torch.randn_like(xt) if t > 0 else 0.
        return mean + torch.sqrt(bcast(var)) * noise

    @torch.inference_mode()
    def sample(self, cond, clip=False):  # clip=False often gives crisper fields
        b = cond.size(0)
        x = torch.randn((b,1,self.H,self.W), device=cond.device)
        for t in reversed(range(self.T)):
            x = self.p_sample(x, cond, t, clip=clip)
        return x.clamp(-1, 1)

5490933
1404804


In [20]:
# ============================================================
# 6) Build model + diffusion + optimizer
# ============================================================
IN_CHANNELS = 1 + 2 + 9 + 9   # target(noised) + lat/lon + 9 past grids + 9 past masks = 21
unet = UNet(base_dim=128, dim_mults=(1,2,4), in_channels=IN_CHANNELS, image_size=(H,W)).to(DEVICE)
diffusion = GaussianDiffusion(unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2').to(DEVICE)

# ---- CosineAnnealingWarmupRestarts setup ----
from torch.optim import AdamW
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts

max_lr = 4e-4
min_lr = 8e-6
TOTAL_ITERS = TOTAL_STEPS          # keep these tied
warmup_steps = max(1, int(0.1 * TOTAL_ITERS))
weight_decay = 1e-4

opt = AdamW(diffusion.parameters(), lr=max_lr, betas=(0.9, 0.999), weight_decay=weight_decay)

sched = CosineAnnealingWarmupRestarts(
    optimizer=opt,
    first_cycle_steps=TOTAL_ITERS,  # single full-length cosine cycle
    max_lr=max_lr,
    min_lr=min_lr,
    warmup_steps=warmup_steps,
    gamma=1.0                       # no decay across cycles since we use one cycle
)

# ============================================================
# 7) Utilities: invert tanh->raw and metrics
# ============================================================
def inv_tanh_to_raw(x_tanh, mean, std):
    z = torch.atanh(x_tanh.clamp(-0.999, 0.999))
    return z * std + mean

@torch.no_grad()
def eval_rmse(diffusion, loader): # add mini batch to check fast
    mse_sum_raw = 0.0; w_sum = 0.0
    mse_sum_norm = 0.0
    for x0_te, m_te, c_te, mu_te, std_te in loader:
        x0_te  = x0_te.to(DEVICE)   # tanh(z)
        m_te   = m_te.to(DEVICE)
        c_te   = c_te.to(DEVICE)
        mu_te  = mu_te.to(DEVICE)[:,None,None,None]
        std_te = std_te.to(DEVICE)[:,None,None,None]

        xhat = diffusion.sample(c_te, clip=False)             # tanh(z)
        # raw µg/m³
        raw_hat = inv_tanh_to_raw(xhat,  mu_te, std_te)
        raw_gt  = inv_tanh_to_raw(x0_te, mu_te, std_te)

        mse_sum_raw  += ((raw_hat - raw_gt)**2 * m_te).sum().item()
        w_sum        += m_te.sum().item()
        # normalized (z) space RMSE (mask)
        zhat = torch.atanh(xhat.clamp(-0.999, 0.999))
        zgt  = torch.atanh(x0_te.clamp(-0.999, 0.999))
        mse_sum_norm += ((zhat - zgt)**2 * m_te).sum().item()

    rmse_raw  = math.sqrt(mse_sum_raw / max(w_sum, 1e-8))
    rmse_norm = math.sqrt(mse_sum_norm / max(w_sum, 1e-8))
    return rmse_raw, rmse_norm, int(w_sum)

@torch.no_grad()
def eval_rmse_minibatch(diffusion, loader, max_batches=None): # add mini batch to check fast
    mse_sum_raw = 0.0; w_sum = 0.0
    mse_sum_norm = 0.0
    for i, batch in enumerate(loader):
        if (max_batches is not None) and (i >= max_batches):
            break
        x0_te, m_te, c_te, mu_te, std_te = batch
    # for x0_te, m_te, c_te, mu_te, std_te in loader:
        x0_te  = x0_te.to(DEVICE)   # tanh(z)
        m_te   = m_te.to(DEVICE)
        c_te   = c_te.to(DEVICE)
        mu_te  = mu_te.to(DEVICE)[:,None,None,None]
        std_te = std_te.to(DEVICE)[:,None,None,None]

        xhat = diffusion.sample(c_te, clip=False)             # tanh(z)
        # raw µg/m³
        raw_hat = inv_tanh_to_raw(xhat,  mu_te, std_te)
        raw_gt  = inv_tanh_to_raw(x0_te, mu_te, std_te)

        mse_sum_raw  += ((raw_hat - raw_gt)**2 * m_te).sum().item()
        w_sum        += m_te.sum().item()
        # normalized (z) space RMSE (mask)
        zhat = torch.atanh(xhat.clamp(-0.999, 0.999))
        zgt  = torch.atanh(x0_te.clamp(-0.999, 0.999))
        mse_sum_norm += ((zhat - zgt)**2 * m_te).sum().item()

    rmse_raw  = math.sqrt(mse_sum_raw / max(w_sum, 1e-8))
    rmse_norm = math.sqrt(mse_sum_norm / max(w_sum, 1e-8))
    return rmse_raw, rmse_norm, int(w_sum)

import math

@torch.no_grad()
def _crps_from_ensemble(y_flat, samples_flat):
    """
    y_flat:        (N,) ground-truth vector (masked entries only later)
    samples_flat:  (K,N) ensemble samples
    returns:       (N,) CRPS per entry
    """
    K = samples_flat.shape[0]
    # term1 = E|X - y|  ≈ (1/K) Σ_i |x_i - y|
    term1 = (samples_flat - y_flat.unsqueeze(0)).abs().mean(dim=0)  # (N,)
    # term2 = 0.5 * E|X - X'| ≈ 0.5 * (1/K^2) Σ_ij |x_i - x_j|
    diffs = samples_flat.unsqueeze(0) - samples_flat.unsqueeze(1)   # (K,K,N)
    term2 = 0.5 * diffs.abs().mean(dim=(0,1))                      # (N,)
    return term1 - term2                                            # (N,)

@torch.no_grad()
def eval_crps_and_points(diffusion, loader, K=10, clip=False):
    """
    Returns masked dataset-averaged:
      CRPS_raw, CRPS_norm, MAE_raw, RMSE_raw, MAE_norm, RMSE_norm, n_obs_bins
    """
    crps_raw_sum = 0.0
    crps_norm_sum = 0.0
    mae_raw_sum = 0.0
    rmse_raw_sum = 0.0
    mae_norm_sum = 0.0
    rmse_norm_sum = 0.0
    w_sum = 0.0

    for x0_te, m_te, c_te, mu_te, std_te in loader:
        x0_te  = x0_te.to(DEVICE)          # (B,1,H,W), tanh(z)
        m_te   = m_te.to(DEVICE)           # (B,1,H,W) mask
        c_te   = c_te.to(DEVICE)           # (B,20,H,W)
        mu_te  = mu_te.to(DEVICE)[:,None,None,None]
        std_te = std_te.to(DEVICE)[:,None,None,None]

        B, _, H, W = x0_te.shape
        N = B*H*W
        mask_flat = m_te.view(N).bool()

        # K samples
        samples = []
        for _ in range(K):
            xhat = diffusion.sample(c_te, clip=clip)             # (B,1,H,W) tanh(z)
            samples.append(xhat)
        S = torch.stack(samples, dim=0)                          # (K,B,1,H,W)

        # normalized (z)
        z_gt  = torch.atanh(x0_te.clamp(-0.999, 0.999))          # (B,1,H,W)
        z_smp = torch.atanh(S.clamp(-0.999, 0.999))              # (K,B,1,H,W)

        # raw μg/m³
        raw_gt  = z_gt * std_te + mu_te                          # (B,1,H,W)
        raw_smp = z_smp * std_te + mu_te                         # (K,B,1,H,W)

        # flatten
        z_gt_f   = z_gt.view(N)
        raw_gt_f = raw_gt.view(N)
        z_smp_f  = z_smp.view(K, N)
        raw_smp_f= raw_smp.view(K, N)

        # CRPS (masked mean)
        crps_norm = _crps_from_ensemble(z_gt_f,   z_smp_f)[mask_flat].mean()
        crps_raw  = _crps_from_ensemble(raw_gt_f, raw_smp_f)[mask_flat].mean()
        crps_norm_sum += crps_norm.item() * mask_flat.sum().item()
        crps_raw_sum  += crps_raw.item()  * mask_flat.sum().item()

        # point forecast = ensemble mean
        z_mean   = z_smp.mean(dim=0).view(N)
        raw_mean = raw_smp.mean(dim=0).view(N)

        # MAE/RMSE (masked)
        mae_norm_sum  += (z_mean - z_gt_f).abs()[mask_flat].sum().item()
        rmse_norm_sum += ((z_mean - z_gt_f)**2)[mask_flat].sum().item()
        mae_raw_sum   += (raw_mean - raw_gt_f).abs()[mask_flat].sum().item()
        rmse_raw_sum  += ((raw_mean - raw_gt_f)**2)[mask_flat].sum().item()

        w_sum += mask_flat.sum().item()

    CRPS_raw  = crps_raw_sum  / max(w_sum, 1e-8)
    CRPS_norm = crps_norm_sum / max(w_sum, 1e-8)
    MAE_raw   = mae_raw_sum   / max(w_sum, 1e-8)
    MAE_norm  = mae_norm_sum  / max(w_sum, 1e-8)
    RMSE_raw  = math.sqrt(rmse_raw_sum  / max(w_sum, 1e-8))
    RMSE_norm = math.sqrt(rmse_norm_sum / max(w_sum, 1e-8))

    return dict(
        CRPS_raw=CRPS_raw, CRPS_norm=CRPS_norm,
        MAE_raw=MAE_raw, RMSE_raw=RMSE_raw,
        MAE_norm=MAE_norm, RMSE_norm=RMSE_norm,
        K=K, n_obs_bins=int(w_sum),
    )



# ============================================================
# 8) Training loop with periodic eval + checkpoint
# ============================================================
best_rmse = float('inf')
ckpt = os.path.join(RESULT_DIR, 'best_ctx9.pt')

pbar = tqdm(range(TOTAL_STEPS), desc="Train(ctx9)")
run_loss = 0.0
it = iter(train_loader)

for step in pbar:
    try:
        x0, msk, cond, mu, std = next(it)
    except StopIteration:
        it = iter(train_loader)
        x0, msk, cond, mu, std = next(it)

    x0   = x0.to(DEVICE, non_blocking=True)
    msk  = msk.to(DEVICE, non_blocking=True)
    cond = cond.to(DEVICE, non_blocking=True)

    opt.zero_grad(set_to_none=True)
    loss = diffusion(x0, msk, cond)
    loss.backward()
    nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
    opt.step()
    sched.step()

    run_loss += loss.item()
    if (step+1) % LOG_EVERY == 0:
        avg = run_loss / LOG_EVERY
        run_loss = 0.0
        # grab LR robustly
        curr_lr = opt.param_groups[0]["lr"]
        pbar.set_postfix(loss=f"{avg:.4f}", lr=f"{curr_lr:.2e}")


    if (step+1) % EVAL_EVERY == 0:
        diffusion.eval()
        with torch.inference_mode():
            # rmse_raw, rmse_norm, nobs = eval_rmse(diffusion, test_loader) # FIXME : original eval code
            rmse_raw, rmse_norm, nobs = eval_rmse_minibatch(diffusion, test_loader, max_batches=2)
            # m = eval_crps_and_points(diffusion, test_loader, K=10, clip=False)  # <-- CRPS (and friends)
        # print(f"\n[Step {step+1}] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} (over {nobs} bins)")
        print(f"\n[Step {step+1}] Test RMSE_raw={rmse_raw:.3f} µV (over {nobs} bins)")
        
        # print(f"\n[Step {step+1}] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} "
        #       f"(over {nobs} bins) | CRPS_raw={m['CRPS_raw']:.3f} µg/m³ (K={m['K']})")
        if rmse_raw < best_rmse:
            best_rmse = rmse_raw
            torch.save({'unet': unet.state_dict(),
                        'diff': diffusion.state_dict(),
                        'H': H, 'W': W}, ckpt)
            print(f"  >> Saved best checkpoint @ {ckpt} (RMSE_raw={best_rmse:.3f})")
        diffusion.train()

    if (step+1) % SAVE_SAMPLES_EVERY == 0:
        diffusion.eval()
        with torch.inference_mode():
            # grab a test batch, sample, and save plain grids (tanh -> [0,1] for viewing)
            xb, mb, cb, mub, stdb = next(iter(test_loader))
            xhat = diffusion.sample(cb.to(DEVICE)).cpu()
            vis = (xhat + 1.0) * 0.5
            save_image(vis, os.path.join(RESULT_DIR, f"samples_step{step+1}.png"), nrow=4)
        diffusion.train()

print("Done. Best Test RMSE_raw:", best_rmse)

# ============================================================
# 9) Inference + overlay: GT vs Pred (only observed bins)
# ============================================================
def overlay_panel(test_loader, model, save_path=os.path.join(RESULT_DIR, 'ctx9_overlay.png'),
                  back_img='./airdelhi_background.png'):
    try:
        back = plt.imread(back_img)
    except:
        back = None

    import matplotlib.colors as colors
    from matplotlib.colors import LinearSegmentedColormap
    cmap0 = LinearSegmentedColormap.from_list('', ['white', 'orange', 'red'])

    with torch.inference_mode():
        xb, mb, cb, mub, stdb = next(iter(test_loader))
        xhat = model.sample(cb.to(DEVICE), clip=False).cpu()  # tanh(z)
        raw_hat = inv_tanh_to_raw(xhat, mub[:,None,None,None], stdb[:,None,None,None]).squeeze(1)
        raw_gt  = inv_tanh_to_raw(xb,   mub[:,None,None,None], stdb[:,None,None,None]).squeeze(1)
        mb      = mb.squeeze(1)

    vmax = float(mub.mean() + 3*stdb.mean())
    lon_edges = np.linspace(0,1,W+1); lat_edges = np.linspace(0,1,H+1)

    B = min(8, raw_hat.size(0))
    fig, axes = plt.subplots(2, B, figsize=(3.4*B, 6.8))
    for i in range(B):
        for row, arr in enumerate([raw_gt[i].numpy(), raw_hat[i].numpy()]):
            ax = axes[row, i]
            if back is not None: ax.imshow(back, extent=[0,1,0,1], alpha=0.6)
            arr_plot = arr.copy()
            arr_plot[mb[i].numpy() == 0] = np.nan  # show only observed bins
            pm = ax.pcolormesh(lon_edges, lat_edges, arr_plot, cmap=cmap0,
                               norm=colors.Normalize(vmin=0, vmax=vmax))
            # draw grid
            for y in lat_edges: ax.plot([0,1],[y,y], c='k', lw=0.1)
            for x in lon_edges: ax.plot([x,x],[0,1], c='k', lw=0.1)
            ax.set_axis_off()
        axes[0, i].set_title("GT (obs bins)", fontsize=11)
        axes[1, i].set_title("Pred (obs bins)", fontsize=11)

    cbar = fig.colorbar(pm, ax=axes.ravel().tolist(), fraction=0.03, pad=0.02)
    cbar.ax.tick_params(labelsize=10)
    plt.tight_layout(); plt.savefig(save_path, dpi=140, bbox_inches='tight'); plt.close(fig)
    print("Saved overlays to:", save_path)

# ---- Run an overlay on current (or best-loaded) model ----
overlay_panel(test_loader, diffusion)

# ============================================================
# 10) Load best checkpoint & run full test-day eval again
# ============================================================
def load_and_eval(ckpt_path, test_loader):
    ck = torch.load(ckpt_path, map_location=DEVICE)
    unet = UNet(base_dim=128, dim_mults=(1,2,4), in_channels=IN_CHANNELS, image_size=(H,W)).to(DEVICE)
    unet.load_state_dict(ck['unet'])
    diff = GaussianDiffusion(unet, image_size=(H,W), time_steps=TIME_STEPS, loss_type='l2').to(DEVICE)
    diff.load_state_dict(ck['diff'])
    diff.eval()

    # rmse_raw, rmse_norm, nobs = eval_rmse(diff, test_loader) # FIXME : original eval code
    rmse_raw, rmse_norm, nobs = eval_rmse_minibatch(diff, test_loader, max_batches=2)
    # print(f"[BEST] Test RMSE_raw={rmse_raw:.3f} µg/m³ | RMSE_norm={rmse_norm:.3f} (over {nobs} bins)")
    print(f"[BEST] Test RMSE_raw={rmse_raw:.3f} µV (over {nobs} bins)")
    overlay_panel(test_loader, diff, save_path=os.path.join(RESULT_DIR, 'ctx9_overlay_best.png'))
    return diff

# Example (after training): 
diff_best = load_and_eval(ckpt,test_loader)

Train(ctx9):   1%|          | 998/150000 [01:47<4:31:11,  9.16it/s, loss=0.3803, lr=3.41e-05]


[Step 1000] Test RMSE_raw=5.135 µV (over 2048 bins)
  >> Saved best checkpoint @ /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/best_ctx9.pt (RMSE_raw=5.135)


Train(ctx9):   1%|▏         | 1999/150000 [03:56<4:13:21,  9.74it/s, loss=0.2231, lr=6.03e-05]  


[Step 2000] Test RMSE_raw=4.596 µV (over 2048 bins)
  >> Saved best checkpoint @ /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/best_ctx9.pt (RMSE_raw=4.596)


Train(ctx9):   2%|▏         | 2998/150000 [06:03<4:21:21,  9.37it/s, loss=0.1899, lr=8.64e-05]  


[Step 3000] Test RMSE_raw=4.689 µV (over 2048 bins)


Train(ctx9):   3%|▎         | 3999/150000 [08:12<4:20:47,  9.33it/s, loss=0.1496, lr=1.13e-04]  


[Step 4000] Test RMSE_raw=4.706 µV (over 2048 bins)


Train(ctx9):   3%|▎         | 4999/150000 [10:19<4:28:00,  9.02it/s, loss=0.1342, lr=1.39e-04]  


[Step 5000] Test RMSE_raw=4.567 µV (over 2048 bins)
  >> Saved best checkpoint @ /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/best_ctx9.pt (RMSE_raw=4.567)


Train(ctx9):   4%|▍         | 5999/150000 [12:26<4:01:46,  9.93it/s, loss=0.1223, lr=1.65e-04]  


[Step 6000] Test RMSE_raw=4.595 µV (over 2048 bins)


Train(ctx9):   5%|▍         | 6999/150000 [14:33<4:11:07,  9.49it/s, loss=0.1225, lr=1.91e-04]  


[Step 7000] Test RMSE_raw=4.642 µV (over 2048 bins)


Train(ctx9):   5%|▌         | 7999/150000 [16:41<4:13:15,  9.34it/s, loss=0.1177, lr=2.17e-04]  


[Step 8000] Test RMSE_raw=4.725 µV (over 2048 bins)


Train(ctx9):   6%|▌         | 8998/150000 [18:45<4:03:51,  9.64it/s, loss=0.1154, lr=2.43e-04]  


[Step 9000] Test RMSE_raw=4.768 µV (over 2048 bins)


Train(ctx9):   7%|▋         | 9999/150000 [20:54<4:14:31,  9.17it/s, loss=0.1141, lr=2.69e-04]  


[Step 10000] Test RMSE_raw=4.682 µV (over 2048 bins)


Train(ctx9):   7%|▋         | 10999/150000 [23:03<4:04:26,  9.48it/s, loss=0.1149, lr=2.95e-04]  


[Step 11000] Test RMSE_raw=4.667 µV (over 2048 bins)


Train(ctx9):   8%|▊         | 11999/150000 [25:07<3:57:11,  9.70it/s, loss=0.1109, lr=3.22e-04]  


[Step 12000] Test RMSE_raw=4.496 µV (over 2048 bins)
  >> Saved best checkpoint @ /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/best_ctx9.pt (RMSE_raw=4.496)


Train(ctx9):   9%|▊         | 12999/150000 [27:13<4:33:42,  8.34it/s, loss=0.1144, lr=3.48e-04]  


[Step 13000] Test RMSE_raw=4.705 µV (over 2048 bins)


Train(ctx9):   9%|▉         | 13999/150000 [29:18<3:53:02,  9.73it/s, loss=0.1069, lr=3.74e-04]  


[Step 14000] Test RMSE_raw=4.646 µV (over 2048 bins)


Train(ctx9):  10%|▉         | 14999/150000 [31:21<4:06:09,  9.14it/s, loss=0.1105, lr=4.00e-04]  


[Step 15000] Test RMSE_raw=4.555 µV (over 2048 bins)


Train(ctx9):  11%|█         | 15999/150000 [33:28<3:50:59,  9.67it/s, loss=0.1036, lr=4.00e-04]  


[Step 16000] Test RMSE_raw=4.682 µV (over 2048 bins)


Train(ctx9):  11%|█▏        | 16999/150000 [35:33<3:49:22,  9.66it/s, loss=0.1085, lr=4.00e-04]  


[Step 17000] Test RMSE_raw=4.607 µV (over 2048 bins)


Train(ctx9):  12%|█▏        | 17999/150000 [37:39<4:15:38,  8.61it/s, loss=0.1077, lr=4.00e-04]  


[Step 18000] Test RMSE_raw=4.728 µV (over 2048 bins)


Train(ctx9):  13%|█▎        | 18999/150000 [39:50<4:15:01,  8.56it/s, loss=0.1026, lr=3.99e-04]  


[Step 19000] Test RMSE_raw=4.664 µV (over 2048 bins)


Train(ctx9):  13%|█▎        | 19999/150000 [41:59<3:41:58,  9.76it/s, loss=0.1018, lr=3.99e-04]  


[Step 20000] Test RMSE_raw=4.800 µV (over 2048 bins)


Train(ctx9):  14%|█▍        | 20999/150000 [44:10<4:03:51,  8.82it/s, loss=0.1025, lr=3.98e-04]  


[Step 21000] Test RMSE_raw=4.665 µV (over 2048 bins)


Train(ctx9):  15%|█▍        | 21999/150000 [46:17<4:10:56,  8.50it/s, loss=0.0988, lr=3.97e-04]  


[Step 22000] Test RMSE_raw=4.732 µV (over 2048 bins)


Train(ctx9):  15%|█▌        | 22999/150000 [48:24<3:35:42,  9.81it/s, loss=0.1032, lr=3.97e-04]  


[Step 23000] Test RMSE_raw=4.664 µV (over 2048 bins)


Train(ctx9):  16%|█▌        | 23998/150000 [50:32<3:35:11,  9.76it/s, loss=0.0997, lr=3.96e-04]  


[Step 24000] Test RMSE_raw=4.831 µV (over 2048 bins)


Train(ctx9):  17%|█▋        | 24998/150000 [52:42<3:49:23,  9.08it/s, loss=0.1021, lr=3.95e-04]  


[Step 25000] Test RMSE_raw=4.735 µV (over 2048 bins)


Train(ctx9):  17%|█▋        | 25998/150000 [54:54<3:42:02,  9.31it/s, loss=0.1023, lr=3.94e-04]  


[Step 26000] Test RMSE_raw=4.824 µV (over 2048 bins)


Train(ctx9):  18%|█▊        | 26998/150000 [57:02<3:34:03,  9.58it/s, loss=0.0948, lr=3.92e-04]  


[Step 27000] Test RMSE_raw=4.742 µV (over 2048 bins)


Train(ctx9):  19%|█▊        | 27998/150000 [59:11<3:38:17,  9.31it/s, loss=0.0994, lr=3.91e-04]  


[Step 28000] Test RMSE_raw=4.863 µV (over 2048 bins)


Train(ctx9):  19%|█▉        | 28998/150000 [1:01:20<3:45:31,  8.94it/s, loss=0.1051, lr=3.90e-04]


[Step 29000] Test RMSE_raw=4.660 µV (over 2048 bins)


Train(ctx9):  20%|█▉        | 29999/150000 [1:03:26<3:34:36,  9.32it/s, loss=0.0974, lr=3.88e-04]  


[Step 30000] Test RMSE_raw=4.900 µV (over 2048 bins)


Train(ctx9):  21%|██        | 30999/150000 [1:05:34<3:31:41,  9.37it/s, loss=0.0982, lr=3.87e-04]  


[Step 31000] Test RMSE_raw=4.699 µV (over 2048 bins)


Train(ctx9):  21%|██▏       | 31998/150000 [1:07:42<3:27:32,  9.48it/s, loss=0.0969, lr=3.85e-04]  


[Step 32000] Test RMSE_raw=4.840 µV (over 2048 bins)


Train(ctx9):  22%|██▏       | 32998/150000 [1:09:47<3:20:06,  9.74it/s, loss=0.0972, lr=3.83e-04]  


[Step 33000] Test RMSE_raw=4.672 µV (over 2048 bins)


Train(ctx9):  23%|██▎       | 33999/150000 [1:11:55<3:30:09,  9.20it/s, loss=0.1015, lr=3.81e-04]  


[Step 34000] Test RMSE_raw=4.689 µV (over 2048 bins)


Train(ctx9):  23%|██▎       | 34998/150000 [1:14:02<3:19:42,  9.60it/s, loss=0.0997, lr=3.79e-04]  


[Step 35000] Test RMSE_raw=4.742 µV (over 2048 bins)


Train(ctx9):  24%|██▍       | 35998/150000 [1:16:07<3:13:16,  9.83it/s, loss=0.0967, lr=3.77e-04]  


[Step 36000] Test RMSE_raw=4.812 µV (over 2048 bins)


Train(ctx9):  25%|██▍       | 36998/150000 [1:18:16<3:21:10,  9.36it/s, loss=0.0949, lr=3.75e-04]  


[Step 37000] Test RMSE_raw=4.850 µV (over 2048 bins)


Train(ctx9):  25%|██▌       | 37999/150000 [1:20:24<3:16:28,  9.50it/s, loss=0.0990, lr=3.73e-04]  


[Step 38000] Test RMSE_raw=4.729 µV (over 2048 bins)


Train(ctx9):  26%|██▌       | 38998/150000 [1:22:30<3:36:32,  8.54it/s, loss=0.0959, lr=3.70e-04]  


[Step 39000] Test RMSE_raw=4.818 µV (over 2048 bins)


Train(ctx9):  27%|██▋       | 39998/150000 [1:24:40<3:16:08,  9.35it/s, loss=0.0980, lr=3.68e-04]  


[Step 40000] Test RMSE_raw=4.804 µV (over 2048 bins)


Train(ctx9):  27%|██▋       | 40998/150000 [1:26:47<3:19:46,  9.09it/s, loss=0.0968, lr=3.65e-04] 


[Step 41000] Test RMSE_raw=4.766 µV (over 2048 bins)


Train(ctx9):  28%|██▊       | 41999/150000 [1:28:56<3:31:41,  8.50it/s, loss=0.0958, lr=3.63e-04] 


[Step 42000] Test RMSE_raw=4.737 µV (over 2048 bins)


Train(ctx9):  29%|██▊       | 42998/150000 [1:31:03<3:04:45,  9.65it/s, loss=0.0931, lr=3.60e-04]  


[Step 43000] Test RMSE_raw=4.718 µV (over 2048 bins)


Train(ctx9):  29%|██▉       | 43998/150000 [1:33:09<3:02:15,  9.69it/s, loss=0.0925, lr=3.57e-04] 


[Step 44000] Test RMSE_raw=4.616 µV (over 2048 bins)


Train(ctx9):  30%|██▉       | 44998/150000 [1:35:16<3:02:56,  9.57it/s, loss=0.0913, lr=3.54e-04] 


[Step 45000] Test RMSE_raw=4.675 µV (over 2048 bins)


Train(ctx9):  31%|███       | 45998/150000 [1:37:21<3:13:22,  8.96it/s, loss=0.0980, lr=3.51e-04] 


[Step 46000] Test RMSE_raw=4.563 µV (over 2048 bins)


Train(ctx9):  31%|███▏      | 46998/150000 [1:39:26<3:16:15,  8.75it/s, loss=0.0977, lr=3.48e-04] 


[Step 47000] Test RMSE_raw=4.677 µV (over 2048 bins)


Train(ctx9):  32%|███▏      | 47998/150000 [1:41:32<2:53:35,  9.79it/s, loss=0.0962, lr=3.45e-04] 


[Step 48000] Test RMSE_raw=4.636 µV (over 2048 bins)


Train(ctx9):  33%|███▎      | 48999/150000 [1:43:40<3:04:44,  9.11it/s, loss=0.0938, lr=3.42e-04] 


[Step 49000] Test RMSE_raw=4.588 µV (over 2048 bins)


Train(ctx9):  33%|███▎      | 49999/150000 [1:45:45<2:54:26,  9.55it/s, loss=0.0915, lr=3.39e-04]  


[Step 50000] Test RMSE_raw=4.681 µV (over 2048 bins)


Train(ctx9):  34%|███▍      | 50999/150000 [1:47:51<2:49:59,  9.71it/s, loss=0.0980, lr=3.35e-04]  


[Step 51000] Test RMSE_raw=4.717 µV (over 2048 bins)


Train(ctx9):  35%|███▍      | 51999/150000 [1:49:56<2:56:36,  9.25it/s, loss=0.0944, lr=3.32e-04] 


[Step 52000] Test RMSE_raw=4.648 µV (over 2048 bins)


Train(ctx9):  35%|███▌      | 52998/150000 [1:52:00<2:44:38,  9.82it/s, loss=0.0943, lr=3.28e-04] 


[Step 53000] Test RMSE_raw=4.882 µV (over 2048 bins)


Train(ctx9):  36%|███▌      | 53998/150000 [1:54:06<2:49:17,  9.45it/s, loss=0.0900, lr=3.25e-04] 


[Step 54000] Test RMSE_raw=4.745 µV (over 2048 bins)


Train(ctx9):  37%|███▋      | 54998/150000 [1:56:12<2:50:41,  9.28it/s, loss=0.0930, lr=3.21e-04] 


[Step 55000] Test RMSE_raw=4.698 µV (over 2048 bins)


Train(ctx9):  37%|███▋      | 55998/150000 [1:58:17<2:55:27,  8.93it/s, loss=0.0944, lr=3.17e-04] 


[Step 56000] Test RMSE_raw=4.813 µV (over 2048 bins)


Train(ctx9):  38%|███▊      | 56998/150000 [2:00:23<2:49:29,  9.15it/s, loss=0.0932, lr=3.14e-04] 


[Step 57000] Test RMSE_raw=4.724 µV (over 2048 bins)


Train(ctx9):  39%|███▊      | 57998/150000 [2:02:29<2:37:14,  9.75it/s, loss=0.0932, lr=3.10e-04] 


[Step 58000] Test RMSE_raw=4.823 µV (over 2048 bins)


Train(ctx9):  39%|███▉      | 58998/150000 [2:04:45<2:51:23,  8.85it/s, loss=0.0956, lr=3.06e-04] 


[Step 59000] Test RMSE_raw=4.779 µV (over 2048 bins)


Train(ctx9):  40%|███▉      | 59999/150000 [2:07:00<2:37:08,  9.55it/s, loss=0.0909, lr=3.02e-04] 


[Step 60000] Test RMSE_raw=4.649 µV (over 2048 bins)


Train(ctx9):  41%|████      | 60999/150000 [2:09:06<2:50:20,  8.71it/s, loss=0.0935, lr=2.98e-04] 


[Step 61000] Test RMSE_raw=4.542 µV (over 2048 bins)


Train(ctx9):  41%|████▏     | 61999/150000 [2:11:18<2:34:10,  9.51it/s, loss=0.0915, lr=2.94e-04] 


[Step 62000] Test RMSE_raw=4.776 µV (over 2048 bins)


Train(ctx9):  42%|████▏     | 62999/150000 [2:13:22<2:38:15,  9.16it/s, loss=0.0904, lr=2.90e-04] 


[Step 63000] Test RMSE_raw=4.767 µV (over 2048 bins)


Train(ctx9):  43%|████▎     | 63999/150000 [2:15:27<2:34:56,  9.25it/s, loss=0.0896, lr=2.86e-04] 


[Step 64000] Test RMSE_raw=4.534 µV (over 2048 bins)


Train(ctx9):  43%|████▎     | 64999/150000 [2:17:32<2:26:07,  9.69it/s, loss=0.0966, lr=2.82e-04] 


[Step 65000] Test RMSE_raw=4.664 µV (over 2048 bins)


Train(ctx9):  44%|████▍     | 65999/150000 [2:19:37<2:31:55,  9.22it/s, loss=0.0893, lr=2.77e-04] 


[Step 66000] Test RMSE_raw=4.716 µV (over 2048 bins)


Train(ctx9):  45%|████▍     | 66999/150000 [2:21:41<2:21:08,  9.80it/s, loss=0.0906, lr=2.73e-04] 


[Step 67000] Test RMSE_raw=4.569 µV (over 2048 bins)


Train(ctx9):  45%|████▌     | 67999/150000 [2:23:44<2:20:02,  9.76it/s, loss=0.0865, lr=2.69e-04] 


[Step 68000] Test RMSE_raw=4.758 µV (over 2048 bins)


Train(ctx9):  46%|████▌     | 68999/150000 [2:25:48<2:18:54,  9.72it/s, loss=0.0929, lr=2.65e-04] 


[Step 69000] Test RMSE_raw=4.683 µV (over 2048 bins)


Train(ctx9):  47%|████▋     | 69999/150000 [2:27:52<2:14:33,  9.91it/s, loss=0.0873, lr=2.60e-04] 


[Step 70000] Test RMSE_raw=4.743 µV (over 2048 bins)


Train(ctx9):  47%|████▋     | 70999/150000 [2:29:59<2:21:23,  9.31it/s, loss=0.0903, lr=2.56e-04] 


[Step 71000] Test RMSE_raw=4.635 µV (over 2048 bins)


Train(ctx9):  48%|████▊     | 71999/150000 [2:32:04<2:12:56,  9.78it/s, loss=0.0889, lr=2.51e-04] 


[Step 72000] Test RMSE_raw=4.702 µV (over 2048 bins)


Train(ctx9):  49%|████▊     | 72999/150000 [2:34:08<2:10:31,  9.83it/s, loss=0.0922, lr=2.47e-04] 


[Step 73000] Test RMSE_raw=4.793 µV (over 2048 bins)


Train(ctx9):  49%|████▉     | 73998/150000 [2:36:13<2:08:22,  9.87it/s, loss=0.0916, lr=2.43e-04] 


[Step 74000] Test RMSE_raw=4.693 µV (over 2048 bins)


Train(ctx9):  50%|████▉     | 74999/150000 [2:38:18<2:10:00,  9.62it/s, loss=0.0924, lr=2.38e-04] 


[Step 75000] Test RMSE_raw=4.744 µV (over 2048 bins)


Train(ctx9):  51%|█████     | 75999/150000 [2:40:21<2:15:44,  9.09it/s, loss=0.0861, lr=2.34e-04] 


[Step 76000] Test RMSE_raw=4.728 µV (over 2048 bins)


Train(ctx9):  51%|█████▏    | 76998/150000 [2:42:26<2:06:02,  9.65it/s, loss=0.0929, lr=2.29e-04] 


[Step 77000] Test RMSE_raw=4.648 µV (over 2048 bins)


Train(ctx9):  52%|█████▏    | 77998/150000 [2:44:30<2:06:31,  9.48it/s, loss=0.0920, lr=2.24e-04] 


[Step 78000] Test RMSE_raw=4.694 µV (over 2048 bins)


Train(ctx9):  53%|█████▎    | 78999/150000 [2:46:34<2:06:29,  9.36it/s, loss=0.0890, lr=2.20e-04] 


[Step 79000] Test RMSE_raw=4.692 µV (over 2048 bins)


Train(ctx9):  53%|█████▎    | 79999/150000 [2:48:37<1:58:10,  9.87it/s, loss=0.0869, lr=2.15e-04] 


[Step 80000] Test RMSE_raw=4.627 µV (over 2048 bins)


Train(ctx9):  54%|█████▍    | 80999/150000 [2:50:40<1:55:50,  9.93it/s, loss=0.0931, lr=2.11e-04] 


[Step 81000] Test RMSE_raw=4.778 µV (over 2048 bins)


Train(ctx9):  55%|█████▍    | 81999/150000 [2:52:44<1:53:44,  9.96it/s, loss=0.0873, lr=2.06e-04] 


[Step 82000] Test RMSE_raw=4.692 µV (over 2048 bins)


Train(ctx9):  55%|█████▌    | 82999/150000 [2:54:46<1:56:49,  9.56it/s, loss=0.0930, lr=2.02e-04] 


[Step 83000] Test RMSE_raw=4.665 µV (over 2048 bins)


Train(ctx9):  56%|█████▌    | 83999/150000 [2:56:49<1:53:28,  9.69it/s, loss=0.0851, lr=1.97e-04] 


[Step 84000] Test RMSE_raw=4.649 µV (over 2048 bins)


Train(ctx9):  57%|█████▋    | 84999/150000 [2:58:52<1:50:19,  9.82it/s, loss=0.0884, lr=1.93e-04] 


[Step 85000] Test RMSE_raw=4.540 µV (over 2048 bins)


Train(ctx9):  57%|█████▋    | 85999/150000 [3:00:54<1:48:15,  9.85it/s, loss=0.0883, lr=1.88e-04] 


[Step 86000] Test RMSE_raw=4.555 µV (over 2048 bins)


Train(ctx9):  58%|█████▊    | 86999/150000 [3:02:58<1:47:35,  9.76it/s, loss=0.0936, lr=1.84e-04] 


[Step 87000] Test RMSE_raw=4.693 µV (over 2048 bins)


Train(ctx9):  59%|█████▊    | 87999/150000 [3:05:02<1:49:57,  9.40it/s, loss=0.0906, lr=1.79e-04] 


[Step 88000] Test RMSE_raw=4.702 µV (over 2048 bins)


Train(ctx9):  59%|█████▉    | 88999/150000 [3:07:04<1:42:51,  9.88it/s, loss=0.0905, lr=1.74e-04] 


[Step 89000] Test RMSE_raw=4.680 µV (over 2048 bins)


Train(ctx9):  60%|█████▉    | 89999/150000 [3:09:08<1:42:57,  9.71it/s, loss=0.0917, lr=1.70e-04] 


[Step 90000] Test RMSE_raw=4.631 µV (over 2048 bins)


Train(ctx9):  61%|██████    | 90999/150000 [3:11:11<1:44:05,  9.45it/s, loss=0.0870, lr=1.65e-04] 


[Step 91000] Test RMSE_raw=4.667 µV (over 2048 bins)


Train(ctx9):  61%|██████▏   | 91999/150000 [3:13:14<1:38:58,  9.77it/s, loss=0.0897, lr=1.61e-04] 


[Step 92000] Test RMSE_raw=4.689 µV (over 2048 bins)


Train(ctx9):  62%|██████▏   | 92999/150000 [3:15:17<1:40:01,  9.50it/s, loss=0.0907, lr=1.57e-04] 


[Step 93000] Test RMSE_raw=4.645 µV (over 2048 bins)


Train(ctx9):  63%|██████▎   | 93999/150000 [3:17:21<1:35:14,  9.80it/s, loss=0.0886, lr=1.52e-04] 


[Step 94000] Test RMSE_raw=4.776 µV (over 2048 bins)


Train(ctx9):  63%|██████▎   | 94999/150000 [3:19:24<1:38:47,  9.28it/s, loss=0.0898, lr=1.48e-04] 


[Step 95000] Test RMSE_raw=4.702 µV (over 2048 bins)


Train(ctx9):  64%|██████▍   | 95999/150000 [3:21:27<1:33:51,  9.59it/s, loss=0.0869, lr=1.43e-04] 


[Step 96000] Test RMSE_raw=4.690 µV (over 2048 bins)


Train(ctx9):  65%|██████▍   | 96999/150000 [3:23:31<1:29:46,  9.84it/s, loss=0.0908, lr=1.39e-04] 


[Step 97000] Test RMSE_raw=4.592 µV (over 2048 bins)


Train(ctx9):  65%|██████▌   | 97999/150000 [3:25:34<1:32:55,  9.33it/s, loss=0.0890, lr=1.35e-04] 


[Step 98000] Test RMSE_raw=4.693 µV (over 2048 bins)


Train(ctx9):  66%|██████▌   | 98999/150000 [3:27:37<1:26:16,  9.85it/s, loss=0.0898, lr=1.31e-04] 


[Step 99000] Test RMSE_raw=4.734 µV (over 2048 bins)


Train(ctx9):  67%|██████▋   | 99999/150000 [3:29:41<1:24:42,  9.84it/s, loss=0.0914, lr=1.26e-04] 


[Step 100000] Test RMSE_raw=4.729 µV (over 2048 bins)


Train(ctx9):  67%|██████▋   | 100999/150000 [3:31:45<1:25:47,  9.52it/s, loss=0.0885, lr=1.22e-04] 


[Step 101000] Test RMSE_raw=4.678 µV (over 2048 bins)


Train(ctx9):  68%|██████▊   | 101999/150000 [3:33:47<1:25:46,  9.33it/s, loss=0.0876, lr=1.18e-04] 


[Step 102000] Test RMSE_raw=4.716 µV (over 2048 bins)


Train(ctx9):  69%|██████▊   | 102999/150000 [3:35:52<1:19:48,  9.81it/s, loss=0.0912, lr=1.14e-04] 


[Step 103000] Test RMSE_raw=4.578 µV (over 2048 bins)


Train(ctx9):  69%|██████▉   | 103999/150000 [3:37:56<1:24:11,  9.11it/s, loss=0.0852, lr=1.10e-04] 


[Step 104000] Test RMSE_raw=4.613 µV (over 2048 bins)


Train(ctx9):  70%|██████▉   | 104999/150000 [3:40:00<1:16:27,  9.81it/s, loss=0.0888, lr=1.06e-04] 


[Step 105000] Test RMSE_raw=4.580 µV (over 2048 bins)


Train(ctx9):  71%|███████   | 105998/150000 [3:42:04<1:17:01,  9.52it/s, loss=0.0875, lr=1.02e-04] 


[Step 106000] Test RMSE_raw=4.607 µV (over 2048 bins)


Train(ctx9):  71%|███████▏  | 106999/150000 [3:44:08<1:16:02,  9.43it/s, loss=0.0842, lr=9.82e-05] 


[Step 107000] Test RMSE_raw=4.572 µV (over 2048 bins)


Train(ctx9):  72%|███████▏  | 107999/150000 [3:46:11<1:11:58,  9.73it/s, loss=0.0868, lr=9.44e-05] 


[Step 108000] Test RMSE_raw=4.723 µV (over 2048 bins)


Train(ctx9):  73%|███████▎  | 108998/150000 [3:48:15<1:10:23,  9.71it/s, loss=0.0870, lr=9.06e-05] 


[Step 109000] Test RMSE_raw=4.752 µV (over 2048 bins)


Train(ctx9):  73%|███████▎  | 109999/150000 [3:50:18<1:07:27,  9.88it/s, loss=0.0894, lr=8.70e-05] 


[Step 110000] Test RMSE_raw=4.591 µV (over 2048 bins)


Train(ctx9):  74%|███████▍  | 110999/150000 [3:52:20<1:10:40,  9.20it/s, loss=0.0861, lr=8.33e-05] 


[Step 111000] Test RMSE_raw=4.707 µV (over 2048 bins)


Train(ctx9):  75%|███████▍  | 111999/150000 [3:54:23<1:07:56,  9.32it/s, loss=0.0845, lr=7.98e-05] 


[Step 112000] Test RMSE_raw=4.734 µV (over 2048 bins)


Train(ctx9):  75%|███████▌  | 112999/150000 [3:56:27<1:04:22,  9.58it/s, loss=0.0858, lr=7.63e-05] 


[Step 113000] Test RMSE_raw=4.575 µV (over 2048 bins)


Train(ctx9):  76%|███████▌  | 113999/150000 [3:58:30<1:04:31,  9.30it/s, loss=0.0866, lr=7.29e-05] 


[Step 114000] Test RMSE_raw=4.689 µV (over 2048 bins)


Train(ctx9):  77%|███████▋  | 114999/150000 [4:00:33<59:24,  9.82it/s, loss=0.0846, lr=6.95e-05]   


[Step 115000] Test RMSE_raw=4.642 µV (over 2048 bins)


Train(ctx9):  77%|███████▋  | 115999/150000 [4:02:37<56:46,  9.98it/s, loss=0.0871, lr=6.62e-05]   


[Step 116000] Test RMSE_raw=4.607 µV (over 2048 bins)


Train(ctx9):  78%|███████▊  | 116999/150000 [4:04:40<56:16,  9.78it/s, loss=0.0912, lr=6.30e-05]   


[Step 117000] Test RMSE_raw=4.655 µV (over 2048 bins)


Train(ctx9):  79%|███████▊  | 117998/150000 [4:06:43<54:09,  9.85it/s, loss=0.0850, lr=5.99e-05]   


[Step 118000] Test RMSE_raw=4.673 µV (over 2048 bins)


Train(ctx9):  79%|███████▉  | 118999/150000 [4:08:46<52:19,  9.87it/s, loss=0.0863, lr=5.68e-05]   


[Step 119000] Test RMSE_raw=4.766 µV (over 2048 bins)


Train(ctx9):  80%|███████▉  | 119999/150000 [4:10:51<54:17,  9.21it/s, loss=0.0856, lr=5.39e-05]   


[Step 120000] Test RMSE_raw=4.584 µV (over 2048 bins)


Train(ctx9):  81%|████████  | 120999/150000 [4:12:53<49:51,  9.69it/s, loss=0.0826, lr=5.10e-05]   


[Step 121000] Test RMSE_raw=4.599 µV (over 2048 bins)


Train(ctx9):  81%|████████▏ | 121999/150000 [4:14:56<49:14,  9.48it/s, loss=0.0863, lr=4.82e-05]   


[Step 122000] Test RMSE_raw=4.688 µV (over 2048 bins)


Train(ctx9):  82%|████████▏ | 122999/150000 [4:16:59<45:49,  9.82it/s, loss=0.0845, lr=4.54e-05]   


[Step 123000] Test RMSE_raw=4.695 µV (over 2048 bins)


Train(ctx9):  83%|████████▎ | 123998/150000 [4:19:01<44:22,  9.77it/s, loss=0.0895, lr=4.28e-05]   


[Step 124000] Test RMSE_raw=4.586 µV (over 2048 bins)


Train(ctx9):  83%|████████▎ | 124998/150000 [4:21:05<42:09,  9.88it/s, loss=0.0862, lr=4.02e-05]   


[Step 125000] Test RMSE_raw=4.710 µV (over 2048 bins)


Train(ctx9):  84%|████████▍ | 125999/150000 [4:23:09<41:12,  9.71it/s, loss=0.0853, lr=3.78e-05]   


[Step 126000] Test RMSE_raw=4.648 µV (over 2048 bins)


Train(ctx9):  85%|████████▍ | 126999/150000 [4:25:11<39:25,  9.72it/s, loss=0.0886, lr=3.54e-05]   


[Step 127000] Test RMSE_raw=4.555 µV (over 2048 bins)


Train(ctx9):  85%|████████▌ | 127998/150000 [4:27:15<37:30,  9.78it/s, loss=0.0889, lr=3.31e-05]   


[Step 128000] Test RMSE_raw=4.847 µV (over 2048 bins)


Train(ctx9):  86%|████████▌ | 128998/150000 [4:29:19<35:24,  9.89it/s, loss=0.0843, lr=3.09e-05]   


[Step 129000] Test RMSE_raw=4.563 µV (over 2048 bins)


Train(ctx9):  87%|████████▋ | 129999/150000 [4:31:21<35:35,  9.37it/s, loss=0.0827, lr=2.88e-05]   


[Step 130000] Test RMSE_raw=4.765 µV (over 2048 bins)


Train(ctx9):  87%|████████▋ | 130999/150000 [4:33:25<32:23,  9.77it/s, loss=0.0814, lr=2.68e-05]   


[Step 131000] Test RMSE_raw=4.475 µV (over 2048 bins)
  >> Saved best checkpoint @ /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/best_ctx9.pt (RMSE_raw=4.475)


Train(ctx9):  88%|████████▊ | 131999/150000 [4:35:28<30:16,  9.91it/s, loss=0.0873, lr=2.49e-05]   


[Step 132000] Test RMSE_raw=4.671 µV (over 2048 bins)


Train(ctx9):  89%|████████▊ | 132999/150000 [4:37:31<30:15,  9.37it/s, loss=0.0835, lr=2.31e-05]   


[Step 133000] Test RMSE_raw=4.645 µV (over 2048 bins)


Train(ctx9):  89%|████████▉ | 133999/150000 [4:39:34<27:39,  9.64it/s, loss=0.0897, lr=2.14e-05]   


[Step 134000] Test RMSE_raw=4.657 µV (over 2048 bins)


Train(ctx9):  90%|████████▉ | 134999/150000 [4:41:38<27:18,  9.16it/s, loss=0.0822, lr=1.98e-05]   


[Step 135000] Test RMSE_raw=4.664 µV (over 2048 bins)


Train(ctx9):  91%|█████████ | 135999/150000 [4:43:42<23:47,  9.81it/s, loss=0.0856, lr=1.83e-05]   


[Step 136000] Test RMSE_raw=4.515 µV (over 2048 bins)


Train(ctx9):  91%|█████████▏| 136998/150000 [4:45:45<22:03,  9.83it/s, loss=0.0854, lr=1.69e-05]   


[Step 137000] Test RMSE_raw=4.709 µV (over 2048 bins)


Train(ctx9):  92%|█████████▏| 137999/150000 [4:47:48<20:23,  9.81it/s, loss=0.0856, lr=1.56e-05]   


[Step 138000] Test RMSE_raw=4.770 µV (over 2048 bins)


Train(ctx9):  93%|█████████▎| 138998/150000 [4:49:51<18:41,  9.81it/s, loss=0.0862, lr=1.44e-05]   


[Step 139000] Test RMSE_raw=4.683 µV (over 2048 bins)


Train(ctx9):  93%|█████████▎| 139999/150000 [4:51:54<17:55,  9.30it/s, loss=0.0876, lr=1.33e-05]  


[Step 140000] Test RMSE_raw=4.656 µV (over 2048 bins)


Train(ctx9):  94%|█████████▍| 140999/150000 [4:53:58<15:31,  9.67it/s, loss=0.0841, lr=1.23e-05]   


[Step 141000] Test RMSE_raw=4.551 µV (over 2048 bins)


Train(ctx9):  95%|█████████▍| 141999/150000 [4:56:02<13:38,  9.77it/s, loss=0.0880, lr=1.14e-05]  


[Step 142000] Test RMSE_raw=4.621 µV (over 2048 bins)


Train(ctx9):  95%|█████████▌| 142999/150000 [4:58:06<12:01,  9.71it/s, loss=0.0866, lr=1.06e-05]  


[Step 143000] Test RMSE_raw=4.697 µV (over 2048 bins)


Train(ctx9):  96%|█████████▌| 143999/150000 [5:00:09<10:11,  9.81it/s, loss=0.0856, lr=9.91e-06]  


[Step 144000] Test RMSE_raw=4.598 µV (over 2048 bins)


Train(ctx9):  97%|█████████▋| 144999/150000 [5:02:11<08:34,  9.71it/s, loss=0.0852, lr=9.33e-06]  


[Step 145000] Test RMSE_raw=4.708 µV (over 2048 bins)


Train(ctx9):  97%|█████████▋| 145999/150000 [5:04:13<06:48,  9.80it/s, loss=0.0823, lr=8.85e-06]  


[Step 146000] Test RMSE_raw=4.680 µV (over 2048 bins)


Train(ctx9):  98%|█████████▊| 146999/150000 [5:06:16<05:10,  9.67it/s, loss=0.0832, lr=8.48e-06]  


[Step 147000] Test RMSE_raw=4.513 µV (over 2048 bins)


Train(ctx9):  99%|█████████▊| 147999/150000 [5:08:18<03:22,  9.87it/s, loss=0.0869, lr=8.21e-06]  


[Step 148000] Test RMSE_raw=4.736 µV (over 2048 bins)


Train(ctx9):  99%|█████████▉| 148999/150000 [5:10:20<01:48,  9.20it/s, loss=0.0834, lr=8.05e-06]  


[Step 149000] Test RMSE_raw=4.576 µV (over 2048 bins)


Train(ctx9): 100%|█████████▉| 149999/150000 [5:12:23<00:00,  9.91it/s, loss=0.0853, lr=8.00e-06]  


[Step 150000] Test RMSE_raw=4.718 µV (over 2048 bins)


Train(ctx9): 100%|██████████| 150000/150000 [5:12:42<00:00,  7.99it/s, loss=0.0853, lr=8.00e-06]

Done. Best Test RMSE_raw: 4.47469478962677



  plt.tight_layout(); plt.savefig(save_path, dpi=140, bbox_inches='tight'); plt.close(fig)


Saved overlays to: /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/ctx9_overlay.png
[BEST] Test RMSE_raw=4.672 µV (over 2048 bins)
Saved overlays to: /pscratch/sd/t/tylee/SOLID_EEG_RESULT/small_check_1220/ctx9_overlay_best.png


In [21]:
class EEGToGrid(Dataset):
    def __init__(self, base_dataset,):
        self.base_dataset = base_dataset
        self.mean = float(self.base_dataset.mean)
        self.std = float(self.base_dataset.std)

    def TorchEEG_Grid(self, channel_list, grid_templete=TORCHEEG_2DGRID, H=11, W=11):
        """
        2D Grid based on TorchEEG 2D Grid
        input 10-10 coord channel name index 
        output is grid of channel input
        """
        grid = torch.zeros(H, W, dtype=torch.float32)
        mask = torch.zeros(H, W, dtype=torch.float32)
        return grid, mask

    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        i,o,im,om = self.base[idx]

        target_grid = None
        target_mask = None
        cond = None

        return target_grid, target_mask, cond, self.mean, self.std

In [22]:
# Set some future argparse

# TUEG_1.0 path in lucy's pscratch
# /pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb
LMDB_DIR = "/pscratch/sd/a/ahhyun/EcoGFound/DATA/scaling_data_V2_Sep_2025/striped_EEG_lmdb/TUEG_1.0/1.0_TUEG/all_resample-500_highpass-0.3_lowpass-None.lmdb"
BATCH_SIZE = 64

In [23]:
# Set Train and Test dataset and dataloader

train_eeg = Physio_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxFolds=5,
                         seed=41,
                         train=True,)
test_eeg = Physio_for_SOLID_from_lmdb(lmdb_dir=LMDB_DIR,
                         maxFolds=5,
                         seed=41,
                         train=False,
                         )

train_set = EEGToGrid(train_eeg)
test_set = EEGToGrid(test_eeg)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_worker=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_worker=2, pin_memory=True)

TypeError: Physio_for_SOLID_from_lmdb.__init__() got an unexpected keyword argument 'maxFolds'

In [None]:
LMDB_DIR = "/pscratch/sd/t/tylee/Dataset/1109_Physio_500Hz"


In [None]:
class EEG_from_lmdb(Dataset):
    def __init__(self, data_dir, transform, return_info):
        self.data_dir = data_dir
        self.transform = transform
        self.return_info = return_info

    def lmdb_to_data(self, idx):
        self.db = lmdb.open(self.data_dir, readonly=True, lock=False, readahead=True, meminit=False)
        key = self.keys[idx]
        with self.db.begin(write=False) as txn:
            pair = pickle.loads(txn.get(key.encode()))
        data = pair['sample']
        label = pair['label']
        data_info = pair.get('data_info', {})
        
        data = to_tensor(data)
        if self.transform is not None:
            data = self.transform(data)
        if self.return_info:
            return data/100, label, data_info
        else:
            return data/100, label
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_, target_ = self.data[idx], self.target[idx]
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)
        
        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])
        
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440
        
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)
        
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)
        
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)
        
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]
        
        return i, o, im, om

In [None]:
# Implementing torch.dataset

def xform_day(day):
    arr = [0, 30, 61]
    w = 0 if day <= 30 else 1 if day <= 61 else 2
    mon = ['2020-11-', '2020-12-', '2021-01-'][w]
    date = mon + '{:02d}'.format(day - arr[w])
    return date


def get_suffixes(mode):
    suffixes = []
    if 'C' in mode or 'A' in mode:
        suffixes.append('train')
    if 'D' in mode or 'B' in mode:
        suffixes.append('test')
    return suffixes

def rename_cols(data):
    data.rename(
        columns={'dateTime': 'time', 'lat': 'latitude', 'long': 'longitude', 'pm2_5': 'PM25_Concentration',
                 'pm10': 'PM10_Concentration'}, inplace=True)



def torch1dgrid(num, bot=0, top=1):
    arr = torch.linspace(bot, top, steps=num)
    mesh = torch.stack([arr], dim=1)
    return mesh.squeeze(-1)

import torch
from torch.utils.data import Dataset
from einops import rearrange        
class Delhi(Dataset):
    def __init__(
        self, mode_t, mode_p, canada, train_days, 
        maxFolds = 5, target_fold = 0, temporal_scaling=1, spatiotemporal=1, data_dir='/pscratch/sd/d/dpark1/AirDelhi/delhi/processed', 
        seed=10, nTrainStartDay = 15, nTestStartDay = 75, nTotalDays = 91, train=True):
        
        self.mode_t = mode_t
        self.mode_p = mode_p
        self.train_days = train_days
        self.train = train
        self.maxFolds = maxFolds
        self.target_fold = target_fold
        self.temporal_scaling = temporal_scaling
        self.spatiotemporal = spatiotemporal
        self.data_dir = data_dir
        self.nTestStartDay = nTestStartDay
        self.nTrainStartDay = nTrainStartDay
        self.nTotalDays = nTotalDays
        
        np.random.seed(seed)        
        
        self.train_suffix = get_suffixes(mode_t)
                
        if spatiotemporal < 0 and mode_t == 'AB' and mode_p == 'CD':
            # Forecasting, single fold is enough
            maxFolds = 1
    
        self.folds = [i for i in range(maxFolds)]
        
        self.data, self.target = self.proc_custom(target_fold)
        
        
        
    def get_normalize_params(self, target):
        all_signal = []
        for a in target:
            all_signal += list(a[..., 0])
        self.mean, self.std = np.array(all_signal).mean(), np.array(all_signal).std()
    
    def get_spatial_norm_parameters(self, arr_of_days):
        """"minmax normalization"""
        latmin = 10e10
        latmax = -10e10
        longmin = 10e10
        longmax = -10e10
        for arr in arr_of_days:
            minned = arr.min(0)
            # print(minned[0])
            if minned[2] < latmin:
                latmin = minned[2]
            if minned[3] < longmin:
                longmin = minned[3]
                
            maxed = arr.max(0)
            # print(maxed)
            if maxed[2] > latmax:
                latmax = maxed[2]
            if maxed[3] > longmax:
                longmax = maxed[3]
        self.latmin, self.latmax, self.longmin, self.longmax =latmin, latmax, longmin, longmax
    
    def make_data_by_time(self, arr_of_days, t_in = 9, reverse=False, day = 0):
        seg_by_time = []
        uniq_times = np.unique(arr_of_days[..., 1])
        
        for t in uniq_times:
            idx_ = arr_of_days[..., 1] == t
            seg_by_time.append(arr_of_days[idx_])
        
        in_ = []
        out_ = []
        for i in range(len(seg_by_time) - t_in):
            temp_in = []
            for t_ in range(t_in):
                temp_in.append(seg_by_time[i + t_])
            
            # normalize time to relative scale by the last one of the encoder
            in_cand = np.copy(np.concatenate(temp_in, axis=0))
            out_cand = np.copy(seg_by_time[i + t_in])
            
            last_enc_t = in_cand[..., 1][-1]
            in_cand[..., 1] -= last_enc_t
            out_cand[..., 1] -= last_enc_t
            in_.append(in_cand)
            out_.append(out_cand)
            self.day_record.append(day)
            
            # reverse it
            if reverse:
                out_cand = np.copy(seg_by_time[i])
                temp_in = temp_in[1:]
                temp_in.append(seg_by_time[i + t_in])
                in_cand = np.copy(np.concatenate(temp_in, axis=0))
                last_enc_t = in_cand[..., 1][0]
                in_cand[..., 1] -= last_enc_t
                out_cand[..., 1] -= last_enc_t
                
                in_.append(in_cand)
                out_.append(out_cand)
                self.day_record.append(day)
        
        
        return in_, out_
    
    def proc_custom(self, fold):
        
        self.day_record = []
        
        train_data = {'input':[], 'target':[]}
        test_data = {'input':[], 'target':[]}
        
        for day in range(self.nTrainStartDay, self.nTestStartDay):
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            train_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            train_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day
            
            seg_in, seg_out = self.make_data_by_time(train_in, day = day)
            
            train_data['input'] += seg_in
            train_data['target'] += seg_out            
        
        
        
        seg_in, seg_out = self.make_data_by_time(train_out)
        train_data['input'] += seg_in
        train_data['target'] += seg_out
            
        
        for day in range(self.nTestStartDay, self.nTotalDays+1):
            date = []
            for i in range(self.train_days,-1,-1):
                date.append(xform_day(day-i))

            train_input,train_output,test_input,test_output = self.process_np(fold, date)
            test_in = np.concatenate([train_output[..., np.newaxis], train_input], axis=1) # 1 days
            test_out = np.concatenate([test_output[..., np.newaxis], test_input], axis=1) # 1 day

            seg_in, seg_out = self.make_data_by_time(test_in, reverse = False)
            
            test_data['input'] += seg_in
            test_data['target'] += seg_out            
            
        seg_in, seg_out = self.make_data_by_time(test_out, reverse = False)
        test_data['input'] += seg_in
        test_data['target'] += seg_out

        self.get_normalize_params(train_data['target']) 
        self.get_spatial_norm_parameters(train_data['target'])
            
        if self.train:
            data = train_data['input']
            target = train_data['target']
            print(len(data), len(target))
            
        else:
            data = test_data['input']
            target = test_data['target']
            print(len(data), len(target))

        return data, target        
    
    

    def process_np(self, fold, date):
        tmStart = datetime.datetime.now()
        train_input,train_output,test_input,test_output = self.return_data_time(fold=fold, data=date, with_scaling=True)
        return train_input,train_output,test_input,test_output
    
    def return_data_time(self, fold, data, with_scaling):
        train_input = None
        if 'A' in self.mode_t or 'B' in self.mode_t:
            for idx,dt in enumerate(data[:-1]):
                for suffix in self.train_suffix:
                    input = pd.read_csv(self.data_dir+'/'+dt+'_f'+str(fold)+'_'+suffix+'.csv')
                    # if self.temporal_scaling:
                    #     input.dateTime += idx * 24 * 60
                    train_input = pd.concat((train_input, input))
                    
        if 'C' in self.mode_t:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            # if self.temporal_scaling:
            #     input.dateTime += (len(data)-1) * 24 * 60
            train_input = pd.concat((train_input, input))

        test_input = pd.read_csv(self.data_dir+'/'+data[-1]+'_f'+str(fold)+'_test.csv')
        
        if 'C' in self.mode_p:
            input = pd.read_csv(self.data_dir + '/' + data[-1] + '_f' + str(fold) + '_train.csv')
            test_input = pd.concat((input, test_input))
            
        # if self.temporal_scaling:
        #     test_input.dateTime += (len(data)-1) * 24 * 60

        return self.return_data_0(train_input, test_input, with_scaling)

    
    def return_data_0(self, train_input, test_input, with_scaling):
        train_output = np.array(train_input['pm2_5'])
        train_input = train_input[['dateTime','lat','long']]
        test_output = np.array(test_input['pm2_5'])
        test_input = test_input[['dateTime','lat','long']]

        # if with_scaling:
        #     scaler = MinMaxScaler().fit(train_input)
        #     if self.temporal_scaling:
        #         data = scaler.transform(pd.concat((train_input, test_input)))
        #         test_input = data[len(train_input):]
        #         train_input = data[:len(train_input)]
        #     else:
        #         train_input = scaler.transform(train_input)
        #         test_input = scaler.transform(test_input)
        return train_input,train_output,test_input,test_output

    def set_target_fold(self, fold=0):
        self.fold = fold
        print('target fold set to {}'.format(self.fold))
        
    def normalize_z(self, arr):
        return (arr - self.mean) / self.std
    
    def normalize(self, data, min_, max_):
        return (data - min_) / (max_ - min_)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_, target_ = self.data[idx], self.target[idx]
        in_ = torch.from_numpy(input_.astype(np.float32))
        out_ = torch.from_numpy(target_.astype(np.float32))
        # print(in_)
        
        # Normalize pm2.5 values
        in_[..., 0] = self.normalize_z(in_[..., 0])
        out_[..., 0] = self.normalize_z(out_[..., 0])
        
        in_[..., 1] = in_[..., 1] / 1440
        out_[..., 1] = out_[..., 1] / 1440
        
        timegap = out_[..., 1:2][0] # get the gap between t+1 and t (ignoring t-1, and t-2)
        
        in_ = torch.cat([in_[..., 0:1], in_[..., 2:], in_[..., 1:2], ], dim=-1)
        out_ = torch.cat([out_[..., 0:1], out_[..., 2:], out_[..., 1:2], ], dim=-1)
        
        in_[..., 1] = self.normalize(in_[..., 1], self.latmin, self.latmax)
        out_[..., 1] = self.normalize(out_[..., 1], self.latmin, self.latmax)
        in_[..., 2] = self.normalize(in_[..., 2], self.longmin, self.longmax)
        out_[..., 2] = self.normalize(out_[..., 2], self.longmin, self.longmax)
        
        i = in_[..., 0:1]
        im = in_[..., 1:]
        o = out_[..., 0:1]
        om = out_[..., 1:]
        
        return i, o, im, om