In [None]:
# Mount Google Drive at /content/drive (Google Colab only); every path used
# below lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import glob
import json
import random
import re
from typing import List, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# =========================
# 0. Basic configuration
# =========================
# Root folder scanned for per-person sub-directories of "*_hands.json" files.
BASE_DIR = "/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints"

CKPT_DIR  = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/1D"   # weights-only checkpoints are saved here
CACHE_DIR = "/content/drive/MyDrive/cv-medislr/data/preprocessed/tensors/1D"     # preprocessing cache (X, y, split, meta) is saved here

# Create both output folders up front so later saves cannot fail on a missing directory.
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

TARGET_LEN = 16          # number of time steps every sequence is resampled to
BATCH_SIZE = 32
EPOCHS = 40
LEARNING_RATE = 3e-4     # AdamW learning rate
WEIGHT_DECAY = 1e-2      # AdamW weight decay
LABEL_SMOOTHING = 0.1    # passed to nn.CrossEntropyLoss
RANDOM_SEED = 42         # seeds python/numpy/torch and the train/val/test shuffle

# Split fractions. The test split receives whatever remains after the train and
# val counts are taken, so TEST_RATIO itself is only recorded in the cache meta.
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Seed every random number generator used in this file.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(RANDOM_SEED)

# =========================
# 1. Utility functions
# =========================
# Number of landmark rows kept per hand.
EXPECTED_HAND = 21
# Matches file stems of the form "<base_id>_s<digits>_hands":
# group 1 is the base sequence id, group 2 is the segment index.
SEQ_GROUP_RE = re.compile(r"(.*)_s(\d+)_hands$")

def load_json(path: str) -> Dict[str, Any]:
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def get_landmark_array(lst: List[Dict[str, float]], expected_len: int) -> np.ndarray:
    """
    Convert a list of landmark dicts into a fixed-size float32 array.

    Args:
        lst: Landmark dicts read with the keys "x", "y", "z" and "v"; a missing
            key counts as 0.0. ``None`` or an empty list is treated as
            "no landmarks detected".
        expected_len: Number of rows in the result. Extra landmarks are
            dropped; missing ones are left as zero rows.

    Returns:
        Array of shape (expected_len, 4) with columns (x, y, z, v). NaN and
        +/-inf values are replaced by 0.0.
    """
    arr = np.zeros((expected_len, 4), dtype=np.float32)
    if not lst:
        # A JSON `null` hand arrives here as None; slicing it would raise.
        return arr
    for i, lm in enumerate(lst[:expected_len]):
        arr[i, 0] = lm.get("x", 0.0)
        arr[i, 1] = lm.get("y", 0.0)
        arr[i, 2] = lm.get("z", 0.0)
        arr[i, 3] = lm.get("v", 0.0)
    return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

def resample_sequence(seq: np.ndarray, target_len: int) -> np.ndarray:
    """
    Resample ``seq`` along axis 0 to exactly ``target_len`` steps.

    Steps are picked by nearest-index sampling (evenly spaced positions from
    the first to the last step, rounded to integers), so no values are
    interpolated. NaN and +/-inf entries are replaced by 0.0 first.

    Args:
        seq: Array whose first axis is time.
        target_len: Desired number of steps.

    Returns:
        Array with ``target_len`` steps along axis 0; the other axes are
        unchanged. An empty input yields zeros, and a single-step input is
        repeated ``target_len`` times.
    """
    seq = np.nan_to_num(seq, nan=0.0, posinf=0.0, neginf=0.0)
    n_steps = seq.shape[0]
    if n_steps == target_len:
        return seq
    if n_steps == 0:
        # np.repeat on an empty array stays empty, which would break the
        # fixed-length contract; emit an all-zero sequence instead.
        return np.zeros((target_len,) + seq.shape[1:], dtype=seq.dtype)
    if n_steps == 1:
        return np.repeat(seq, target_len, axis=0)
    idxs = np.round(np.linspace(0, n_steps - 1, target_len)).astype(np.int32)
    idxs = np.clip(idxs, 0, n_steps - 1)
    return seq[idxs]

def add_velocity_feature(seq: np.ndarray) -> np.ndarray:
    """
    Append frame-to-frame differences to every frame of ``seq``.

    The first frame is differenced against itself, so its velocity is zero.
    Non-finite values are replaced by 0.0 before and after the computation.

    Returns:
        The cleaned input concatenated with its velocities along the last
        axis, e.g. (T, D) -> (T, 2D).
    """
    def _clean(arr):
        return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    positions = _clean(seq)
    # Previous frame for every step; step 0 is paired with itself.
    previous = np.concatenate([positions[0:1], positions[:-1]], axis=0)
    velocity = _clean(positions - previous)
    return _clean(np.concatenate([positions, velocity], axis=-1))

# =========================
# 2. 데이터 로딩 & 전처리 (HANDS ONLY)
# =========================
def collect_sequences(base_dir: str):
    """
    Build one hands-only keypoint sequence per base id found under ``base_dir``.

    Expected layout::

        <base_dir>/<person>/<base_id>_s<NN>_hands.json

    Files are grouped by ``(base_id, seg_idx)``. All frames sharing a segment
    index are averaged into one array of left-hand rows followed by right-hand
    rows, and the segments are stacked in ascending ``seg_idx`` order.
    Files whose stem does not match ``SEQ_GROUP_RE`` are ignored.

    Files and groups are processed in sorted order: ``glob`` returns paths in
    arbitrary order, and the sample order determines which sample each cached
    split index refers to.

    Args:
        base_dir: Directory containing one sub-directory per person.

    Returns:
        ``(all_seq_arrays, all_labels)``: a list of float arrays shaped
        (L, 2 * EXPECTED_HAND, 4) and a parallel list of label strings. A label
        is the "word_folder" value of the first file read for that base id, or
        "UNKNOWN" when the field is absent.
    """
    all_seq_arrays = []
    all_labels = []

    person_dirs = sorted(d for d in glob.glob(os.path.join(base_dir, "*")) if os.path.isdir(d))
    print(f"Found person dirs: {person_dirs}")

    for p_dir in person_dirs:
        person_name = os.path.basename(p_dir)
        # Sorted so grouping (and therefore sample order) is deterministic.
        json_files = sorted(glob.glob(os.path.join(p_dir, "*_hands.json")))

        groups = {}  # base_id -> seg_idx -> [paths]
        for jf in json_files:
            stem = os.path.splitext(os.path.basename(jf))[0]
            m = SEQ_GROUP_RE.match(stem)
            if not m:
                continue
            base_id, seg_str = m.groups()
            groups.setdefault(base_id, {}).setdefault(int(seg_str), []).append(jf)

        print(f"[Person {person_name}] #base sequences: {len(groups)}")

        for base_id in sorted(groups):
            seg_dict = groups[base_id]
            segment_features = []
            label_word = None

            for seg_idx in sorted(seg_dict):
                frame_kpts_list = []

                for jf in sorted(seg_dict[seg_idx]):
                    data = load_json(jf)
                    if label_word is None:
                        label_word = data.get("word_folder", "UNKNOWN")

                    # `or []` turns a JSON null hand into "no landmarks".
                    lh = get_landmark_array(data.get("left_hand") or [], EXPECTED_HAND)
                    rh = get_landmark_array(data.get("right_hand") or [], EXPECTED_HAND)

                    # Hands only: left-hand rows first, then right-hand rows.
                    frame_kpts = np.concatenate([lh, rh], axis=0)
                    frame_kpts = np.nan_to_num(frame_kpts, nan=0.0, posinf=0.0, neginf=0.0)
                    frame_kpts_list.append(frame_kpts)

                if not frame_kpts_list:
                    continue

                # Average all frames of this segment index into one array.
                seg_arr = np.stack(frame_kpts_list, axis=0).mean(axis=0)
                seg_arr = np.nan_to_num(seg_arr, nan=0.0, posinf=0.0, neginf=0.0)
                segment_features.append(seg_arr)

            if not segment_features:
                continue

            # One sequence per base id: segments stacked along a new first axis.
            seq_arr = np.stack(segment_features, axis=0)
            seq_arr = np.nan_to_num(seq_arr, nan=0.0, posinf=0.0, neginf=0.0)
            all_seq_arrays.append(seq_arr)
            all_labels.append(label_word if label_word is not None else "UNKNOWN")

    print(f"Total sequences (base_id level): {len(all_seq_arrays)}")
    return all_seq_arrays, all_labels

def build_label_mapping(labels: List[str]) -> Dict[str, int]:
    """
    Assign a contiguous integer id to every distinct label.

    Ids follow the sorted order of the distinct label strings. The resulting
    mapping is also printed, one "label -> id" line per entry.
    """
    label2idx = {name: pos for pos, name in enumerate(sorted(set(labels)))}
    report = ["Label mapping:"]
    report.extend(f"  {name} -> {pos}" for name, pos in label2idx.items())
    print("\n".join(report))
    return label2idx

def prepare_dataset(base_dir: str, target_len: int):
    """
    Load every sequence under ``base_dir`` and turn it into model-ready arrays.

    Per sequence (hands only): (L, J, C) -> flatten to (L, J*C) -> resample to
    (target_len, J*C) -> append velocities -> (target_len, 2*J*C).

    Args:
        base_dir: Root directory passed to ``collect_sequences``.
        target_len: Number of time steps each sequence is resampled to.

    Returns:
        ``(X, y, meta)`` where ``X`` is float32 of shape (N, target_len, D),
        ``y`` is int64 of shape (N,) and ``meta`` is a JSON-serializable dict
        holding the label mapping and the preprocessing configuration.

    Raises:
        ValueError: If no sequence was found under ``base_dir``.
    """
    raw_seqs, raw_labels = collect_sequences(base_dir)
    if not raw_seqs:
        # Fail here with a clear message instead of inside np.stack below.
        raise ValueError(
            f"No '*_s<NN>_hands.json' sequences found under {base_dir!r}; cannot build a dataset."
        )
    label2idx = build_label_mapping(raw_labels)

    X_list = []
    y_list = []

    for seq_arr, lab in zip(raw_seqs, raw_labels):
        n_steps, n_joints, n_ch = seq_arr.shape
        seq_flat = seq_arr.reshape(n_steps, n_joints * n_ch)
        seq_flat = np.nan_to_num(seq_flat, nan=0.0, posinf=0.0, neginf=0.0)

        seq_resampled = resample_sequence(seq_flat, target_len)
        seq_with_vel = add_velocity_feature(seq_resampled)
        seq_with_vel = np.nan_to_num(seq_with_vel, nan=0.0, posinf=0.0, neginf=0.0)

        X_list.append(seq_with_vel.astype(np.float32))
        y_list.append(label2idx[lab])

    X = np.stack(X_list, axis=0).astype(np.float32)  # (N,T,D)
    y = np.array(y_list, dtype=np.int64)             # (N,)

    input_dim = int(X.shape[-1])
    num_classes = int(len(label2idx))
    meta = {
        "label2idx": label2idx,
        "num_classes": num_classes,
        "input_dim": input_dim,
        "target_len": target_len,
        "seed": RANDOM_SEED,
        "ratios": {"train": TRAIN_RATIO, "val": VAL_RATIO, "test": TEST_RATIO},
        "mode": "hands_only",
        "landmarks": {"left_hand": 21, "right_hand": 21},
        "per_landmark_dim": 4
    }
    print(f"Final input dim D = {input_dim}, num_classes = {num_classes}")
    return X, y, meta

def compute_global_stats_from_array(X: np.ndarray, train_idx: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute per-feature normalization statistics from the training rows only.

    Args:
        X: Array of shape (N, T, D).
        train_idx: Indices into the first axis of ``X`` selecting training samples.

    Returns:
        ``(mean, std)``, both float32 of shape (D,), pooled over all selected
        samples and time steps. ``std`` has 1e-6 added so it is never zero.
    """
    feat_dim = X.shape[-1]
    # Pool samples and time steps: (Ntr, T, D) -> (Ntr*T, D).
    frames = np.nan_to_num(
        X[train_idx].reshape(-1, feat_dim), nan=0.0, posinf=0.0, neginf=0.0
    )
    mu = frames.mean(axis=0).astype(np.float32)
    sigma = (frames.std(axis=0) + 1e-6).astype(np.float32)
    # Degenerate statistics fall back to the identity transform (mean 0, std 1).
    mu = np.nan_to_num(mu, nan=0.0, posinf=0.0, neginf=0.0)
    sigma = np.nan_to_num(sigma, nan=1.0, posinf=1.0, neginf=1.0)
    return mu, sigma

# =========================
# 2.5 전처리 캐시 저장/로드
# =========================
def cache_paths():
    """
    Return the three cache file paths for the current TARGET_LEN / RANDOM_SEED.

    Returns:
        ``(data_npz, meta_json, split_npz)``, all located in CACHE_DIR.
    """
    tag = f"HANDS_T{TARGET_LEN}_seed{RANDOM_SEED}"
    data_npz, meta_json, split_npz = (
        os.path.join(CACHE_DIR, f"{prefix}_{tag}.{ext}")
        for prefix, ext in (("data", "npz"), ("meta", "json"), ("split", "npz"))
    )
    return data_npz, meta_json, split_npz

def build_or_load_cache():
    """
    Return the preprocessed dataset and its split, using the on-disk cache.

    When all three cache files exist they are loaded as-is. Otherwise the raw
    keypoints are preprocessed once, a seeded random train/val/test split is
    drawn, and data, meta and split are written to CACHE_DIR.

    Returns:
        ``(X, y, meta, train_idx, val_idx, test_idx)``: ``X`` float32 and
        ``y`` int64 arrays, ``meta`` a dict, and three int64 index arrays into
        the first axis of ``X``.
    """
    data_npz, meta_json, split_npz = cache_paths()

    # 1) Load the cache when every part of it is present.
    if all(os.path.exists(p) for p in (data_npz, meta_json, split_npz)):
        print(f"✅ Load cached preprocessing: {data_npz}")
        # np.load on an .npz keeps the archive open; `with` closes it once the
        # arrays have been copied out via astype().
        with np.load(data_npz) as pack:
            X = pack["X"].astype(np.float32)
            y = pack["y"].astype(np.int64)

        with open(meta_json, "r", encoding="utf-8") as f:
            meta = json.load(f)

        with np.load(split_npz) as split:
            train_idx = split["train_idx"].astype(np.int64)
            val_idx   = split["val_idx"].astype(np.int64)
            test_idx  = split["test_idx"].astype(np.int64)
        return X, y, meta, train_idx, val_idx, test_idx

    # 2) Otherwise preprocess once and save the result.
    print("⏳ Cache not found. Running preprocessing once and saving cache...")
    X, y, meta = prepare_dataset(BASE_DIR, TARGET_LEN)

    num_samples = len(X)
    rng = random.Random(RANDOM_SEED)
    indices = list(range(num_samples))
    rng.shuffle(indices)

    # Train and val sizes are truncated; the test split takes the remainder.
    n_train = int(num_samples * TRAIN_RATIO)
    n_val   = int(num_samples * VAL_RATIO)

    train_idx = np.array(indices[:n_train], dtype=np.int64)
    val_idx   = np.array(indices[n_train:n_train + n_val], dtype=np.int64)
    test_idx  = np.array(indices[n_train + n_val:], dtype=np.int64)

    np.savez_compressed(data_npz, X=X, y=y)
    with open(meta_json, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    np.savez_compressed(split_npz, train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)

    print(f"✅ Saved cache: {data_npz}")
    print(f"✅ Saved meta : {meta_json}")
    print(f"✅ Saved split: {split_npz}")
    return X, y, meta, train_idx, val_idx, test_idx

# =========================
# 3. Dataset & Augmentation (캐시 기반)
# =========================
def augment_seq_tensor(seq: torch.Tensor) -> torch.Tensor:
    """
    Randomly rescale and/or front-trim a 2-D (T, D) sequence tensor.

    * With probability 0.5 all values are multiplied by a factor drawn
      uniformly from [0.95, 1.05].
    * With probability 0.5, and only when T > 8, 0-2 leading frames are
      dropped and the last frame is repeated to keep the length at T.

    Randomness comes from Python's ``random`` module.
    """
    n_frames, _ = seq.shape

    if random.random() < 0.5:
        seq = seq * random.uniform(0.95, 1.05)

    # The trim decision is always drawn, even for sequences too short to trim.
    trim_draw = random.random()
    if trim_draw >= 0.5 or n_frames <= 8:
        return seq

    n_drop = random.randint(0, 2)
    if n_drop == 0:
        return seq

    kept = seq[n_drop:]
    tail = kept[-1:].repeat(n_drop, 1)
    return torch.cat([kept, tail], dim=0)

class CachedSequenceDataset(Dataset):
    """
    Dataset view over a subset of an in-memory sequence array.

    Args:
        X: (N, T, D) numpy array of sequences.
        y: (N,) numpy array of integer labels.
        indices: (K,) numpy array selecting which rows of ``X``/``y`` to serve.
        mean, std: (D,) numpy arrays used to standardize every sequence.
        augment: When True, random augmentation is applied after
            standardization.
    """
    def __init__(self, X: np.ndarray, y: np.ndarray, indices: np.ndarray,
                 mean: np.ndarray, std: np.ndarray, augment: bool):
        self.X, self.y = X, y
        self.indices = indices
        self.augment = augment
        # Converted once so __getitem__ can normalize with tensor arithmetic.
        self.mean = torch.from_numpy(mean).float()
        self.std  = torch.from_numpy(std).float()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        """Return ``(sequence, label)``: a float (T, D) tensor and a Python int."""
        row = int(self.indices[i])
        label = int(self.y[row])
        sample = (torch.from_numpy(self.X[row]).float() - self.mean) / self.std

        if not self.augment:
            return sample, label

        # Gaussian jitter (p = 0.7).
        if random.random() < 0.7:
            sample = sample + torch.randn_like(sample) * 0.01
        # Circular shift along time by -3..3 steps (p = 0.5).
        if random.random() < 0.5:
            offset = random.randint(-3, 3)
            if offset != 0:
                sample = torch.roll(sample, shifts=offset, dims=0)
        return augment_seq_tensor(sample), label

# =========================
# 4. 모델들
# =========================
class GRUClassifier(nn.Module):
    """
    GRU encoder with attention pooling over time and a linear classification head.

    Input is (B, T, input_dim); output is (B, num_classes) logits.
    """
    def __init__(self, input_dim, num_classes, hidden_dim=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        # Dropout between GRU layers is only passed on when the GRU is stacked.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )
        n_directions = 2 if bidirectional else 1
        out_dim = hidden_dim * n_directions
        self.attn_fc = nn.Linear(out_dim, 1)
        self.head = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(0.2),
            nn.Linear(out_dim, num_classes),
        )

    def forward(self, x):
        states, _ = self.gru(x)
        states = torch.nan_to_num(states, nan=0.0, posinf=0.0, neginf=0.0)
        # One score per time step, softmax-normalized across the time axis.
        attn = torch.softmax(self.attn_fc(states), dim=1)
        attn = torch.nan_to_num(attn, nan=0.0, posinf=0.0, neginf=0.0)
        pooled = (attn * states).sum(dim=1)
        return self.head(pooled)

class TemporalConvBlock(nn.Module):
    """
    Residual block of two dilated 1-D convolutions with batch norm.

    Operates on (B, C, T) tensors. When the channel count changes, the skip
    path is projected with a 1x1 convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        pad = (dilation * (kernel_size - 1)) // 2
        conv_kwargs = dict(kernel_size=kernel_size, padding=pad, dilation=dilation)
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(out_channels)
        if in_channels != out_channels:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.downsample = None

    def forward(self, x):
        hidden = self.bn1(self.conv1(x))
        hidden = self.dropout(self.relu(hidden))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection, projected only when the channel count differs.
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(hidden + skip)

class AttnPool1d(nn.Module):
    """Collapse the time axis of a (B, C, T) tensor with learned attention weights."""
    def __init__(self, in_channels):
        super().__init__()
        self.attn = nn.Linear(in_channels, 1)

    def forward(self, x):  # (B,C,T)
        frames = x.transpose(1, 2)                          # (B,T,C)
        # One scalar score per time step, normalized over time.
        attn = torch.softmax(self.attn(frames).squeeze(-1), dim=-1)  # (B,T)
        # Weighted sum over time as a batched (1,T) x (T,C) product.
        return torch.bmm(attn.unsqueeze(1), frames).squeeze(1)       # (B,C)

class TCNClassifier(nn.Module):
    """
    Temporal convolutional classifier.

    A per-frame linear projection feeds three residual blocks with dilations
    1, 2 and 4, followed by attention pooling over time and a linear
    classifier. Input is (B, T, input_dim); output is (B, num_classes) logits.
    """
    def __init__(self, input_dim, num_classes, hidden_channels=256):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_channels)
        self.tcn = nn.Sequential(*[
            TemporalConvBlock(hidden_channels, hidden_channels, kernel_size=3, dilation=rate)
            for rate in (1, 2, 4)
        ])
        self.pool = AttnPool1d(hidden_channels)
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x):  # (B,T,D)
        # Project each frame, then go channels-first for the convolutions.
        feats = self.input_proj(x).transpose(1, 2)  # (B,C,T)
        feats = self.tcn(feats)                     # (B,C,T)
        pooled = self.pool(feats)                   # (B,C)
        return self.fc(pooled)                      # (B,num_classes)

# =========================
# 5. 학습/평가
# =========================
def run_epoch(model, loader, optimizer=None, criterion=None):
    """
    Run one pass over ``loader`` and return the average loss and accuracy.

    The pass is a training epoch when ``optimizer`` is given and an evaluation
    pass otherwise. Evaluation runs with gradient tracking disabled so that no
    autograd graph is built for batches that are never back-propagated.

    Args:
        model: Module mapping a batch of sequences to class logits.
        loader: Iterable of ``(seq, label)`` batches.
        optimizer: Optimizer to step after each batch, or None to evaluate.
        criterion: Loss callable ``criterion(logits, label)``; required.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is the same as eval()

    total_loss, total_correct, total_count = 0.0, 0, 0

    with torch.set_grad_enabled(is_train):
        for seq, label in loader:
            seq = seq.to(DEVICE)
            label = label.to(DEVICE)

            if is_train:
                optimizer.zero_grad()

            logits = model(seq)
            loss = criterion(logits, label)

            if is_train:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the final mean is per-sample.
            total_loss += loss.item() * seq.size(0)
            pred = torch.argmax(logits, dim=1)
            total_correct += (pred == label).sum().item()
            total_count += seq.size(0)

    return total_loss / total_count, total_correct / total_count

# =========================
# 6. 저장/로드 (weights-only)
# =========================
def model_artifact_paths(model_name: str):
    """
    Return the artifact paths of ``model_name`` inside CKPT_DIR.

    Returns:
        ``(w_path, norm_path, meta_path)``: the weights-only checkpoint, the
        mean/std normalization file, and the label-mapping/config JSON.
    """
    w_path, norm_path, meta_path = (
        os.path.join(CKPT_DIR, f"{model_name}_{suffix}")
        for suffix in ("best.pt", "norm.npz", "meta.json")
    )
    return w_path, norm_path, meta_path

def save_artifacts(model_name: str, model: nn.Module,
                   train_mean: np.ndarray, train_std: np.ndarray,
                   meta: dict, model_cfg: dict):
    """
    Write the artifacts of a trained model under CKPT_DIR.

    Three files are written, in this order: the model ``state_dict``
    (weights only), the normalization statistics (mean/std as float32), and a
    JSON file with the label mapping, dataset dimensions and model config.
    """
    w_path, norm_path, meta_path = model_artifact_paths(model_name)

    # 1) weights only
    torch.save(model.state_dict(), w_path)

    # 2) normalization statistics
    np.savez_compressed(
        norm_path,
        mean=train_mean.astype(np.float32),
        std=train_std.astype(np.float32),
    )

    # 3) metadata; numeric fields are coerced to plain ints for JSON
    meta_out = {"label2idx": meta["label2idx"]}
    meta_out.update(
        (key, int(meta[key])) for key in ("num_classes", "input_dim", "target_len", "seed")
    )
    meta_out["model_name"] = model_name
    meta_out["model_cfg"] = model_cfg
    meta_out["mode"] = meta.get("mode", "hands_only")
    with open(meta_path, "w", encoding="utf-8") as fh:
        json.dump(meta_out, fh, ensure_ascii=False, indent=2)

    print(f"✅ Saved weights-only: {w_path}")
    print(f"✅ Saved norm       : {norm_path}")
    print(f"✅ Saved meta       : {meta_path}")

def load_weights_only(w_path: str, model: nn.Module):
    """
    Load a ``state_dict`` checkpoint into ``model`` and switch it to eval mode.

    Args:
        w_path: Path of a file written with ``torch.save(model.state_dict(), ...)``.
        model: Module whose parameter names and shapes match the checkpoint.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so the "state_dict only" assumption is enforced rather than trusted.
    state = torch.load(w_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()

# =========================
# 7. 모델별 학습
# =========================
def train_one(model_name: str, X: np.ndarray, y: np.ndarray, meta: dict,
              train_idx: np.ndarray, val_idx: np.ndarray, test_idx: np.ndarray):
    """
    Train one model ("gru" or "tcn"), keep its best checkpoint and report test metrics.

    Normalization statistics come from the training split only. After every
    epoch the validation accuracy drives the LR scheduler, and the artifacts
    are re-saved whenever that accuracy improves. The best weights are then
    reloaded and evaluated on the test split.

    Raises:
        ValueError: If ``model_name`` is neither "gru" nor "tcn".
    """
    feat_dim = int(meta["input_dim"])
    n_classes = int(meta["num_classes"])

    # Mean/std from the training split only.
    train_mean, train_std = compute_global_stats_from_array(X, train_idx)

    def make_loader(indices, is_train):
        # Only the training split is augmented and shuffled.
        dataset = CachedSequenceDataset(X, y, indices, train_mean, train_std, augment=is_train)
        return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=is_train, num_workers=0)

    train_loader = make_loader(train_idx, True)
    val_loader = make_loader(val_idx, False)
    test_loader = make_loader(test_idx, False)

    builders = {
        "gru": (GRUClassifier, {"hidden_dim": 256, "num_layers": 2, "bidirectional": True, "dropout": 0.2}),
        "tcn": (TCNClassifier, {"hidden_channels": 256}),
    }
    if model_name not in builders:
        raise ValueError("model_name must be 'gru' or 'tcn'")
    model_cls, model_cfg = builders[model_name]
    model = model_cls(input_dim=feat_dim, num_classes=n_classes, **model_cfg).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
    # mode="max" because the monitored quantity is validation accuracy.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

    best_w_path = model_artifact_paths(model_name)[0]
    best_val_acc = -1.0
    tag = model_name.upper()

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = run_epoch(model, train_loader, optimizer, criterion)
        val_loss, val_acc = run_epoch(model, val_loader, optimizer=None, criterion=criterion)

        print(f"[{tag}][Epoch {epoch:02d}] "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.3f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.3f}")

        scheduler.step(val_acc)

        if not (val_acc > best_val_acc):
            continue
        best_val_acc = val_acc
        save_artifacts(model_name, model, train_mean, train_std, meta, model_cfg)

    # Reload the best weights, then evaluate on the test split.
    load_weights_only(best_w_path, model)
    test_loss, test_acc = run_epoch(model, test_loader, optimizer=None, criterion=criterion)

    print(f"\n[{tag}][Final] Best val_acc={best_val_acc:.3f}")
    print(f"[{tag}][Final] Test loss={test_loss:.4f}, Test acc={test_acc:.3f}")
    print(f"✅ Best weights at: {best_w_path}\n")

# =========================
# 8. 메인
# =========================
def main():
    """Load (or build) the cached dataset, then train the GRU and the TCN in turn."""
    dataset_and_split = build_or_load_cache()
    for model_name in ("gru", "tcn"):
        train_one(model_name, *dataset_and_split)


if __name__ == "__main__":
    main()


Using device: cpu
⏳ Cache not found. Running preprocessing once and saving cache...
Found person dirs: ['/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/1', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/10', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/2', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/3', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/4', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/5', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/6', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/7', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/8', '/content/drive/MyDrive/cv-medislr/data/preprocessed/keypoints/9']
[Person 1] #base sequences: 107
[Person 10] #base sequences: 110
[Person 2] #base sequences: 110
[Person 3] #base sequences: 110
[Person 4] #base sequences: 110
[Person 5] #base sequences: 110
[Person 6] #base sequences: 110


### demo용 데이터 만들기

In [None]:
import os
import json
import numpy as np

# =========================
# Path settings
# =========================
# Builds a small, fixed demo subset from the cached test split and saves it
# (arrays as .npz, description as .json) under SAVE_DIR.
CACHE_DIR = "/content/drive/MyDrive/cv-medislr/data/preprocessed/tensors/1D"
SAVE_DIR  = "/content/drive/MyDrive/cv-medislr/data/samples/1D"
os.makedirs(SAVE_DIR, exist_ok=True)

SEED = 42   # seed for picking the demo samples
K = 20      # maximum number of demo samples

# Cache file names are hard-coded for target length 16 and seed 42.
data_npz  = os.path.join(CACHE_DIR, "data_HANDS_T16_seed42.npz")
split_npz = os.path.join(CACHE_DIR, "split_HANDS_T16_seed42.npz")
meta_json = os.path.join(CACHE_DIR, "meta_HANDS_T16_seed42.json")

# Stop early if any cache file is missing.
assert os.path.exists(data_npz),  f"❌ {data_npz} 없음"
assert os.path.exists(split_npz), f"❌ {split_npz} 없음"
assert os.path.exists(meta_json), f"❌ {meta_json} 없음"

print("✅ Using cache files:")
print(" data :", data_npz)
print(" split:", split_npz)
print(" meta :", meta_json)

# =========================
# 1. Load
# =========================
pack = np.load(data_npz)
X = pack["X"].astype(np.float32)   # (N,T,D)
y = pack["y"].astype(np.int64)     # (N,)

split = np.load(split_npz)
test_idx = split["test_idx"].astype(np.int64)

with open(meta_json, "r", encoding="utf-8") as f:
    meta = json.load(f)

label2idx = meta["label2idx"]
# Inverse mapping: class id -> label text.
idx2label = {int(v): k for k, v in label2idx.items()}

print(f"Total test samples: {len(test_idx)}")

# =========================
# 2. Pick a fixed sample of up to K items from the test split
# =========================
rng = np.random.default_rng(SEED)
k = min(K, len(test_idx))

# Sampling without replacement, so no test sample is picked twice.
picked_idx = rng.choice(test_idx, size=k, replace=False).astype(np.int64)

X_demo = X[picked_idx]   # (k,T,D)
y_demo = y[picked_idx]   # (k,)

# =========================
# 3. Save
# =========================
out_npz = os.path.join(SAVE_DIR, f"demo_test{k}_HANDS_T16_seed{SEED}.npz")
np.savez_compressed(
    out_npz,
    X_demo=X_demo,
    y_demo=y_demo,
    orig_idx=picked_idx
)

# Companion JSON describing the picked samples and the label mapping.
out_json = os.path.join(SAVE_DIR, f"demo_test{k}_HANDS_T16_seed{SEED}.json")
out_meta = {
    "source_cache_dir": CACHE_DIR,
    "k": int(k),
    "seed": int(SEED),
    "target_len": int(meta["target_len"]),
    "input_dim": int(meta["input_dim"]),
    "num_classes": int(meta["num_classes"]),
    "label2idx": label2idx,
    "picked_orig_idx": picked_idx.tolist(),
    "picked_labels_text": [idx2label[int(c)] for c in y_demo.tolist()],
}

with open(out_json, "w", encoding="utf-8") as f:
    json.dump(out_meta, f, ensure_ascii=False, indent=2)

print("✅ Demo samples saved:")
print(" ", out_npz)
print(" ", out_json)
print(" X_demo:", X_demo.shape, "y_demo:", y_demo.shape)


✅ Using cache files:
 data : /content/drive/MyDrive/cv-medislr/data/preprocessed/tensors/1D/data_HANDS_T16_seed42.npz
 split: /content/drive/MyDrive/cv-medislr/data/preprocessed/tensors/1D/split_HANDS_T16_seed42.npz
 meta : /content/drive/MyDrive/cv-medislr/data/preprocessed/tensors/1D/meta_HANDS_T16_seed42.json
Total test samples: 166
✅ Demo samples saved:
  /content/drive/MyDrive/cv-medislr/data/samples/1D/demo_test20_HANDS_T16_seed42.npz
  /content/drive/MyDrive/cv-medislr/data/samples/1D/demo_test20_HANDS_T16_seed42.json
 X_demo: (20, 16, 336) y_demo: (20,)
