In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### 1D

In [18]:
import os
import json
import numpy as np
import torch
import torch.nn as nn

# =========================
# 0) Í≤ΩÎ°ú ÏÑ§Ï†ï
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

DEMO_NPZ  = "/content/drive/MyDrive/cv-medislr/data/samples/1D/demo_test20_HANDS_T16_seed42.npz"

ARTIFACT_DIR = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/1D"
GRU_W    = os.path.join(ARTIFACT_DIR, "gru_best.pt")
GRU_NORM = os.path.join(ARTIFACT_DIR, "gru_norm.npz")
GRU_META = os.path.join(ARTIFACT_DIR, "gru_meta.json")

TCN_W    = os.path.join(ARTIFACT_DIR, "tcn_best.pt")
TCN_NORM = os.path.join(ARTIFACT_DIR, "tcn_norm.npz")
TCN_META = os.path.join(ARTIFACT_DIR, "tcn_meta.json")

assert os.path.exists(DEMO_NPZ), f"‚ùå demo npz not found: {DEMO_NPZ}"

for p in [GRU_W, GRU_NORM, GRU_META, TCN_W, TCN_NORM, TCN_META]:
    assert os.path.exists(p), f"‚ùå artifact not found: {p}"

# =========================
# 1) Î™®Îç∏ Ï†ïÏùò (ÌïôÏäµ ÏΩîÎìúÏôÄ ÎèôÏùº)
# =========================
class GRUClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=256, num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        out_dim = hidden_dim * (2 if bidirectional else 1)
        self.attn_fc = nn.Linear(out_dim, 1)
        self.head = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(0.2),
            nn.Linear(out_dim, num_classes),
        )

    def forward(self, x):  # (B,T,D)
        out, _ = self.gru(x)                          # (B,T,512)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        w = torch.softmax(self.attn_fc(out), dim=1)   # (B,T,1)
        w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)
        feat = (w * out).sum(dim=1)                   # (B,512)
        return self.head(feat)                        # (B,C)

class TemporalConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        padding = ((kernel_size - 1) * dilation) // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):  # (B,C,T)
        out = self.dropout(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(out + x)

class AttnPool1d(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.attn = nn.Linear(in_channels, 1)

    def forward(self, x):  # (B,C,T)
        x_perm = x.transpose(1, 2)                    # (B,T,C)
        scores = self.attn(x_perm).squeeze(-1)        # (B,T)
        weights = torch.softmax(scores, dim=-1)       # (B,T)
        pooled = torch.bmm(weights.unsqueeze(1), x_perm)  # (B,1,C)
        return pooled.squeeze(1)                      # (B,C)

class TCNClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_channels=256):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_channels)
        self.tcn = nn.Sequential(
            TemporalConvBlock(hidden_channels, hidden_channels, kernel_size=3, dilation=1),
            TemporalConvBlock(hidden_channels, hidden_channels, kernel_size=3, dilation=2),
            TemporalConvBlock(hidden_channels, hidden_channels, kernel_size=3, dilation=4),
        )
        self.pool = AttnPool1d(hidden_channels)
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x):  # (B,T,D)
        x = self.input_proj(x)       # (B,T,C)
        x = x.transpose(1, 2)        # (B,C,T)
        x = self.tcn(x)              # (B,C,T)
        x = self.pool(x)             # (B,C)
        return self.fc(x)            # (B,num_classes)

# =========================
# 2) Ïú†Ìã∏: norm/meta/weights Î°úÎìú + ÏòàÏ∏°
# =========================
def load_meta(meta_path: str):
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    label2idx = meta["label2idx"]
    idx2label = {int(v): k for k, v in label2idx.items()}
    return meta, idx2label

def load_norm(norm_path: str):
    pack = np.load(norm_path)
    mean = pack["mean"].astype(np.float32)
    std  = pack["std"].astype(np.float32)
    return mean, std

def load_weights(model: nn.Module, w_path: str):
    state = torch.load(w_path, map_location=DEVICE)   # state_dict only
    model.load_state_dict(state)
    model.to(DEVICE).eval()

@torch.no_grad()
def run_demo(model_name: str, model: nn.Module, X_demo: np.ndarray, y_demo: np.ndarray, idx2label: dict, mean: np.ndarray, std: np.ndarray):
    # normalize
    Xn = (X_demo - mean[None, None, :]) / std[None, None, :]
    Xn = np.nan_to_num(Xn, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

    x = torch.from_numpy(Xn).to(DEVICE)  # (B,T,D)
    logits = model(x)
    pred = torch.argmax(logits, dim=1).cpu().numpy().astype(np.int64)

    correct = int((pred == y_demo).sum())
    total = int(len(y_demo))
    acc = correct / total if total > 0 else 0.0

    print("\n" + "=" * 70)
    print(f"[MODEL: {model_name}]  correct={correct}/{total}  acc={acc:.3f}")
    print("=" * 70)

    for i in range(total):
        gt_idx = int(y_demo[i])
        pr_idx = int(pred[i])
        gt = idx2label.get(gt_idx, str(gt_idx))
        pr = idx2label.get(pr_idx, str(pr_idx))
        mark = "‚úÖ" if gt_idx == pr_idx else "‚ùå"
        print(f"{i:02d} {mark}  GT: {gt:<25} | PRED: {pr}")

# =========================
# 3) demo Îç∞Ïù¥ÌÑ∞ Î°úÎìú
# =========================
demo = np.load(DEMO_NPZ)
X_demo = demo["X_demo"].astype(np.float32)  # (20,16,D)
y_demo = demo["y_demo"].astype(np.int64)    # (20,)

# =========================
# 4) GRU Îç∞Î™®
# =========================
gru_meta, gru_idx2label = load_meta(GRU_META)
gru_mean, gru_std = load_norm(GRU_NORM)

gru_input_dim = int(gru_meta["input_dim"])
gru_num_classes = int(gru_meta["num_classes"])
gru_cfg = gru_meta.get("model_cfg", {"hidden_dim":256, "num_layers":2, "bidirectional":True, "dropout":0.2})

assert X_demo.shape[-1] == gru_input_dim, f"‚ùå GRU input_dim mismatch: demo D={X_demo.shape[-1]} vs meta input_dim={gru_input_dim}"

gru_model = GRUClassifier(input_dim=gru_input_dim, num_classes=gru_num_classes, **gru_cfg)
load_weights(gru_model, GRU_W)
run_demo("GRU", gru_model, X_demo, y_demo, gru_idx2label, gru_mean, gru_std)

# =========================
# 5) TCN Îç∞Î™®
# =========================
tcn_meta, tcn_idx2label = load_meta(TCN_META)
tcn_mean, tcn_std = load_norm(TCN_NORM)

tcn_input_dim = int(tcn_meta["input_dim"])
tcn_num_classes = int(tcn_meta["num_classes"])
tcn_cfg = tcn_meta.get("model_cfg", {"hidden_channels":256})

assert X_demo.shape[-1] == tcn_input_dim, f"‚ùå TCN input_dim mismatch: demo D={X_demo.shape[-1]} vs meta input_dim={tcn_input_dim}"

tcn_model = TCNClassifier(input_dim=tcn_input_dim, num_classes=tcn_num_classes, **tcn_cfg)
load_weights(tcn_model, TCN_W)
run_demo("TCN", tcn_model, X_demo, y_demo, tcn_idx2label, tcn_mean, tcn_std)


Using device: cpu

[MODEL: GRU]  correct=16/20  acc=0.800
00 ‚úÖ  GT: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°            | PRED: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°
01 ‚ùå  GT: WORD0039_·Ñá·Öß·Ü´·Ñá·Öµ            | PRED: WORD0041_·Ñá·Ö©·ÑÄ·Ö•·Ü´·Ñâ·Ö©
02 ‚úÖ  GT: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°            | PRED: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°
03 ‚ùå  GT: WORD0033_·ÑÉ·Ö°·Üº·ÑÇ·Ö≠·Ñá·Öß·Üº         | PRED: WORD0042_·Ñá·ÖÆ·ÜØ·ÑÜ·Öß·Ü´·Ñå·Ö≥·Üº
04 ‚úÖ  GT: WORD1115_·ÑÄ·Ö•·Ü´·ÑÄ·Ö°·Üº           | PRED: WORD1115_·ÑÄ·Ö•·Ü´·ÑÄ·Ö°·Üº
05 ‚ùå  GT: WORD0689_·Ñê·Ö©·Üº·Ñå·Ö≥·Üº           | PRED: WORD0029_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°
06 ‚úÖ  GT: WORD0885_·Ñé·Öµ·ÑÖ·Ö≠·Ñå·Ö¶           | PRED: WORD0885_·Ñé·Öµ·ÑÖ·Ö≠·Ñå·Ö¶
07 ‚úÖ  GT: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°            | PRED: WORD0046_·Ñâ·Ö•·ÜØ·Ñâ·Ö°
08 ‚úÖ  GT: WORD1129_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°            | PRED: WORD1129_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°
09 ‚úÖ  GT: WORD0029_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°            | PRED: WORD0029_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°
10 ‚úÖ  GT: WORD0065_·Ñé·Öµ·ÑÖ·Ö≠·Ñá·Ö•·Ü∏          | PRED: WORD0065_·Ñé·Öµ·ÑÖ·Ö≠·Ñá·Ö•·Ü∏
11 ‚ù

###2D Sequence

In [16]:
import os, glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

# =========================================================
# 0) Í≤ΩÎ°ú ÏÑ§Ï†ï (ÏÇ¨Ïö©ÏûêÍ∞Ä Ï§Ä Í≤ΩÎ°ú Í∑∏ÎåÄÎ°ú)
# =========================================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

GRU_W = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/2D Sequence/GRU_best.pt"
TCN_W = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/2D Sequence/TCN_best.pt"

GRU_DATA_DIR = "/content/drive/MyDrive/cv-medislr/data/samples/2D Sequence/GRU"
TCN_DATA_DIR = "/content/drive/MyDrive/cv-medislr/data/samples/2D Sequence/TCN"

assert os.path.exists(GRU_W), f"‚ùå GRU weight not found: {GRU_W}"
assert os.path.exists(TCN_W), f"‚ùå TCN weight not found: {TCN_W}"
assert os.path.isdir(GRU_DATA_DIR), f"‚ùå GRU data dir not found: {GRU_DATA_DIR}"
assert os.path.isdir(TCN_DATA_DIR), f"‚ùå TCN data dir not found: {TCN_DATA_DIR}"

# =========================================================
# 1) Î™®Îç∏ Ï†ïÏùò (ÌïôÏäµ ÏΩîÎìúÏôÄ ÎèôÏùºÌïú Íµ¨Ï°∞)
#   - Input: (B, T, 3, H, W)
#   - Frame encoder: MobileNetV3 Small -> 256D
#   - GRU / TCN + Attention -> Head
# =========================================================
class FrameEncoderMobileNetV3(nn.Module):
    def __init__(self, out_dim=256, pretrained=True):
        super().__init__()
        weights = models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        backbone = models.mobilenet_v3_small(weights=weights)
        self.features = backbone.features
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        in_feat = backbone.classifier[0].in_features
        self.proj = nn.Linear(in_feat, out_dim)

    def forward(self, x):  # x: (B,3,H,W)
        f = self.features(x)
        f = self.gap(f).flatten(1)
        return self.proj(f)  # (B,out_dim)

class SeqCNN_MobileNet_GRU_Attn(nn.Module):
    def __init__(self, num_classes, frame_out_dim=256, hidden_dim=192, num_layers=2, bidirectional=True, dropout=0.2, pretrained_backbone=True):
        super().__init__()
        self.encoder = FrameEncoderMobileNetV3(out_dim=frame_out_dim, pretrained=pretrained_backbone)
        self.gru = nn.GRU(
            input_size=frame_out_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        out_dim = hidden_dim * (2 if bidirectional else 1)
        self.attn_fc = nn.Linear(out_dim, 1)
        self.head = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(0.1),
            nn.Linear(out_dim, num_classes),
        )

    def forward(self, x):  # x: (B,T,3,H,W)
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        feat = self.encoder(x).view(B, T, -1)  # (B,T,256)

        out, _ = self.gru(feat)                # (B,T,H*)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

        w = torch.softmax(self.attn_fc(out), dim=1)  # (B,T,1)
        w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)

        feat_seq = (w * out).sum(dim=1)        # (B,H*)
        return self.head(feat_seq)             # (B,num_classes)

class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.2):
        super().__init__()
        padding = (kernel_size - 1) * dilation // 2  # Í∏∏Ïù¥ Ïú†ÏßÄ
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
        self.final_relu = nn.ReLU(inplace=True)

    def forward(self, x):  # (B,C,T)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.dropout(self.bn2(self.conv2(out)))
        res = x if self.downsample is None else self.downsample(x)
        return self.final_relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, input_channels, hidden_channels=256, num_layers=3, kernel_size=3, dropout=0.2):
        super().__init__()
        layers = []
        in_ch = input_channels
        for i in range(num_layers):
            layers.append(TemporalBlock(in_ch, hidden_channels, kernel_size=kernel_size, dilation=2**i, dropout=dropout))
            in_ch = hidden_channels
        self.network = nn.Sequential(*layers)
        self.out_channels = hidden_channels

    def forward(self, x):  # (B,C,T)
        return self.network(x)

class SeqCNN_MobileNet_TCN_Attn(nn.Module):
    def __init__(self, num_classes, frame_out_dim=256, tcn_hidden=256, tcn_layers=3, dropout=0.2, pretrained_backbone=True):
        super().__init__()
        self.encoder = FrameEncoderMobileNetV3(out_dim=frame_out_dim, pretrained=pretrained_backbone)
        self.tcn = TemporalConvNet(input_channels=frame_out_dim, hidden_channels=tcn_hidden, num_layers=tcn_layers, kernel_size=3, dropout=dropout)
        out_dim = self.tcn.out_channels
        self.attn_fc = nn.Linear(out_dim, 1)
        self.head = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Dropout(0.1),
            nn.Linear(out_dim, num_classes),
        )

    def forward(self, x):  # x: (B,T,3,H,W)
        B, T, C, H, W = x.shape
        x = x.view(B*T, C, H, W)
        feat = self.encoder(x).view(B, T, -1)    # (B,T,256)

        out = self.tcn(feat.permute(0,2,1))      # (B,H,T)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

        out_seq = out.permute(0,2,1)             # (B,T,H)
        w = torch.softmax(self.attn_fc(out_seq), dim=1)  # (B,T,1)
        w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)

        feat_seq = (w * out_seq).sum(dim=1)      # (B,H)
        return self.head(feat_seq)               # (B,num_classes)

# =========================================================
# 2) Ïú†Ìã∏: state_dictÏóêÏÑú num_classes ÏûêÎèô Ï∂îÎ°†
# =========================================================
def infer_num_classes_from_state(state: dict) -> int:
    # Î≥¥ÌÜµ headÏùò ÎßàÏßÄÎßâ LinearÍ∞Ä head.2.weightÎ°ú Ïû°Ìûò (Sequential: LN, Dropout, Linear)
    for k in ["head.2.weight", "head.2.bias"]:
        if k in state:
            return int(state["head.2.weight"].shape[0])
    # fallback: weight Ï§ëÏóêÏÑú shape[0]Í∞Ä "ÌÅ¥ÎûòÏä§ Ïàò"Ïùº Í∞ÄÎä•ÏÑ±Ïù¥ ÌÅ∞ Ìï≠Î™© Ï∞æÍ∏∞
    cand = []
    for k, v in state.items():
        if isinstance(v, torch.Tensor) and v.ndim == 2 and v.shape[0] < 10000:
            cand.append((k, int(v.shape[0]), int(v.shape[1])))
    # Í∑∏Ï§ë Í∞ÄÏû• "ÌÅ¥ÎûòÏä§Ï≤òÎüº Î≥¥Ïù¥Îäî" ÏûëÏùÄ out_dimÏùÑ Ïö∞ÏÑ†
    cand.sort(key=lambda x: x[1])
    if not cand:
        raise ValueError("‚ùå Could not infer num_classes from state_dict.")
    return cand[0][1]

def make_idx2label(num_classes: int):
    # ÎùºÎ≤® Ïù¥Î¶Ñ Îß§Ìïë ÌååÏùºÏù¥ ÏóÜÏùÑ ÎïåÎ•º ÎåÄÎπÑÌïú Í∏∞Î≥∏ Îß§Ìïë
    return {i: f"class_{i:04d}" for i in range(num_classes)}

# =========================================================
# 3) Dataset: pt ÌååÏùº(dict)ÏóêÏÑú x,y ÏùΩÍ∏∞
#   - ptÎäî Ïù¥ÎØ∏ NormalizeÍπåÏßÄ ÎÅùÎÇú ÌÖêÏÑúÎùºÍ≥† Í∞ÄÏ†ï (Ï†ÑÏ≤òÎ¶¨ ÏΩîÎìú Í∏∞Ï§Ä)
# =========================================================
class PTSeqDataset(Dataset):
    def __init__(self, pt_dir: str):
        self.pt_paths = sorted(glob.glob(os.path.join(pt_dir, "*.pt")))
        assert len(self.pt_paths) > 0, f"‚ùå No pt files found in: {pt_dir}"

    def __len__(self):
        return len(self.pt_paths)

    def __getitem__(self, idx):
        pack = torch.load(self.pt_paths[idx], map_location="cpu")  # {"x":(T,3,H,W), "y":int, ...}
        x = pack["x"].float()      # (T,3,H,W)
        y = int(pack["y"])
        base_id = pack.get("base_id", os.path.basename(self.pt_paths[idx]))
        return x, y, base_id

def collate_fn(batch):
    xs, ys, ids = zip(*batch)
    x = torch.stack(xs, dim=0)            # (B,T,3,H,W)
    y = torch.tensor(ys, dtype=torch.long)
    return x, y, ids

# =========================================================
# 4) Îç∞Î™® Ïã§Ìñâ (Ï†ïÌôïÎèÑ + ÏÉòÌîåÎ≥Ñ Ï∂úÎ†•)
# =========================================================
@torch.no_grad()
def run_demo(model_name: str, model: nn.Module, loader: DataLoader, idx2label: dict, max_print: int = 30):
    model.eval()
    total, correct = 0, 0

    all_rows = []
    for xb, yb, ids in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)

        logits = model(xb)
        pred = logits.argmax(dim=1)

        total += int(yb.numel())
        correct += int((pred == yb).sum().item())

        pred_cpu = pred.cpu().tolist()
        y_cpu = yb.cpu().tolist()

        for i in range(len(y_cpu)):
            all_rows.append((ids[i], y_cpu[i], pred_cpu[i]))

    acc = correct / total if total > 0 else 0.0
    print("\n" + "=" * 80)
    print(f"[MODEL: {model_name}]  correct={correct}/{total}  acc={acc:.4f}")
    print("=" * 80)

    n_show = min(len(all_rows), max_print)
    for i in range(n_show):
        base_id, gt_i, pr_i = all_rows[i]
        gt = idx2label.get(int(gt_i), str(gt_i))
        pr = idx2label.get(int(pr_i), str(pr_i))
        mark = "‚úÖ" if int(gt_i) == int(pr_i) else "‚ùå"
        print(f"{i:02d} {mark}  ID: {base_id} | GT: {gt:<12} | PRED: {pr}")

    if len(all_rows) > n_show:
        print(f"... (printed {n_show}/{len(all_rows)})")

# =========================================================
# 5) GRU Îç∞Î™®
# =========================================================
gru_state = torch.load(GRU_W, map_location="cpu")  # ÌïôÏäµ ÏΩîÎìúÏóêÏÑú state_dictÎßå Ï†ÄÏû•ÌñàÏúºÎØÄÎ°ú dictÍ∞Ä Î∞îÎ°ú state_dict
gru_num_classes = infer_num_classes_from_state(gru_state)
gru_idx2label = make_idx2label(gru_num_classes)

gru_model = SeqCNN_MobileNet_GRU_Attn(
    num_classes=gru_num_classes,
    frame_out_dim=256,
    hidden_dim=192,
    num_layers=2,
    bidirectional=True,
    dropout=0.2,
    pretrained_backbone=True,   # Îç∞Î™®ÏóêÏÑúÎäî backbone weightÍ∞Ä Ìè¨Ìï®Îèº ÏûàÏùÑ Í∞ÄÎä•ÏÑ±Ïù¥ ÎÜíÏùå (state_dictÎ°ú ÎçÆÏûÑ)
).to(DEVICE)

gru_model.load_state_dict(gru_state, strict=True)
gru_model.eval()

gru_ds = PTSeqDataset(GRU_DATA_DIR)
gru_dl = DataLoader(gru_ds, batch_size=8, shuffle=False, num_workers=0, pin_memory=(DEVICE.type=="cuda"), collate_fn=collate_fn)
run_demo("2D-Sequence GRU", gru_model, gru_dl, gru_idx2label, max_print=30)

# =========================================================
# 6) TCN Îç∞Î™®
# =========================================================
tcn_state = torch.load(TCN_W, map_location="cpu")
tcn_num_classes = infer_num_classes_from_state(tcn_state)
tcn_idx2label = make_idx2label(tcn_num_classes)

tcn_model = SeqCNN_MobileNet_TCN_Attn(
    num_classes=tcn_num_classes,
    frame_out_dim=256,
    tcn_hidden=256,
    tcn_layers=3,
    dropout=0.2,
    pretrained_backbone=True,
).to(DEVICE)

tcn_model.load_state_dict(tcn_state, strict=True)
tcn_model.eval()

tcn_ds = PTSeqDataset(TCN_DATA_DIR)
tcn_dl = DataLoader(tcn_ds, batch_size=8, shuffle=False, num_workers=0, pin_memory=(DEVICE.type=="cuda"), collate_fn=collate_fn)
run_demo("2D-Sequence TCN", tcn_model, tcn_dl, tcn_idx2label, max_print=30)


Using device: cpu
Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 9.83M/9.83M [00:00<00:00, 82.9MB/s]



[MODEL: 2D-Sequence GRU]  correct=19/20  acc=0.9500
00 ‚úÖ  ID: WORD0042_·Ñá·ÖÆ·ÜØ·ÑÜ·Öß·Ü´·Ñå·Ö≥·Üº_NIA_SL_WORD0042_REAL04_R | GT: class_0007   | PRED: class_0007
01 ‚úÖ  ID: WORD0689_·Ñê·Ö©·Üº·Ñå·Ö≥·Üº_NIA_SL_WORD0689_REAL06_F | GT: class_0016   | PRED: class_0016
02 ‚úÖ  ID: WORD1115_·ÑÄ·Ö•·Ü´·ÑÄ·Ö°·Üº_NIA_SL_WORD1115_REAL15_F | GT: class_0018   | PRED: class_0018
03 ‚úÖ  ID: WORD1496_·Ñá·Öß·Üº·Ñã·ÖØ·Ü´_NIA_SL_WORD1496_REAL08_U | GT: class_0021   | PRED: class_0021
04 ‚úÖ  ID: WORD0065_·Ñé·Öµ·ÑÖ·Ö≠·Ñá·Ö•·Ü∏_NIA_SL_WORD0065_REAL02_R | GT: class_0011   | PRED: class_0011
05 ‚úÖ  ID: WORD0041_·Ñá·Ö©·ÑÄ·Ö•·Ü´·Ñâ·Ö©_NIA_SL_WORD0041_REAL05_U | GT: class_0006   | PRED: class_0006
06 ‚úÖ  ID: WORD0036_·ÑÜ·Öß·Ü´·Ñã·Öß·Ü®_NIA_SL_WORD0036_REAL01_D | GT: class_0002   | PRED: class_0002
07 ‚úÖ  ID: WORD0042_·Ñá·ÖÆ·ÜØ·ÑÜ·Öß·Ü´·Ñå·Ö≥·Üº_NIA_SL_WORD0042_REAL02_L | GT: class_0007   | PRED: class_0007
08 ‚ùå  ID: WORD0033_·ÑÉ·Ö°·Üº·ÑÇ·Ö≠·Ñá·Öß·Üº_NIA_SL_WORD0033_REAL04_L | GT: class_0001   | PRED: c

###2D Only

In [19]:
import os, glob
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# =========================
# 0) Í≤ΩÎ°ú ÏÑ§Ï†ï
# =========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

DATA_DIR = "/content/drive/MyDrive/cv-medislr/data/samples/2d_only/skeleton_tsn"
CKPT_PT  = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/2d_only/mobilenet_tsn_hands_best.pt"

SAMPLE_META = os.path.join(DATA_DIR, "tsn_sample_meta.csv")  # ÏûàÏúºÎ©¥ Ïù¥Í±∏ Ïö∞ÏÑ† ÏÇ¨Ïö©

assert os.path.isdir(DATA_DIR), f"‚ùå DATA_DIR not found: {DATA_DIR}"
assert os.path.exists(CKPT_PT),  f"‚ùå CKPT not found: {CKPT_PT}"

# =========================
# 1) Î™®Îç∏ Ï†ïÏùò (ÌïôÏäµ ÏΩîÎìúÏôÄ ÎèôÏùº)
# =========================
class MobileNetTSN(nn.Module):
    def __init__(self, num_classes=22, pretrained=True):
        super().__init__()
        weights = MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = mobilenet_v2(weights=weights)

        self.features = backbone.features
        self.last_channel = backbone.last_channel  # 1280
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(self.last_channel, num_classes)

        # ImageNet Ï†ïÍ∑úÌôîÏö© mean/std (0~1 ÏûÖÎ†• Í∏∞Ï§Ä)
        self.register_buffer(
            "img_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1)
        )
        self.register_buffer(
            "img_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1)
        )

    def forward(self, x):
        """
        x: (B, T, C, H, W)  Ïó¨Í∏∞ÏÑú Î≥¥ÌÜµ C=1 (grayscale)
        """
        B, T, C, H, W = x.shape

        # 1Ï±ÑÎÑê ‚Üí 3Ï±ÑÎÑê replicate
        if C == 1:
            x = x.repeat(1, 1, 3, 1, 1)   # (B,T,3,H,W)

        # ImageNet Ï†ïÍ∑úÌôî
        x = (x - self.img_mean) / self.img_std

        # (B*T, 3, H, W)
        x = x.view(B * T, 3, H, W)

        feat = self.features(x)          # (B*T, 1280, h, w)
        feat = self.pool(feat)           # (B*T, 1280, 1, 1)
        feat = feat.view(B, T, self.last_channel)  # (B, T, 1280)

        # TSN: ÏãúÍ∞Ñ ÌèâÍ∑†
        feat = feat.mean(dim=1)          # (B, 1280)

        feat = self.dropout(feat)
        logits = self.fc(feat)           # (B, num_classes)
        return logits

# =========================
# 2) Dataset: pt ÌÖêÏÑú Î°úÎìú
#    - sample_meta.csv ÏûàÏúºÎ©¥ Í∑∏Í±∏ Ïì∞Í≥†,
#    - ÏóÜÏúºÎ©¥ Ìè¥ÎçîÏùò pt ÌååÏùºÎì§ÏùÑ ÏßÅÏ†ë ÏùΩÏñ¥ÏÑú yÎäî -1 Ï≤òÎ¶¨
# =========================
class SkeletonTSNDemoDataset(Dataset):
    def __init__(self, data_dir: str, meta_csv: str | None = None):
        self.items = []

        if meta_csv is not None and os.path.exists(meta_csv):
            df = pd.read_csv(meta_csv)
            # label_idx Ïª¨ÎüºÏù¥ ÏûàÏùÑ ÏàòÎèÑ ÏûàÍ≥†, ÏóÜÏùÑ ÏàòÎèÑ ÏûàÏñ¥ÏÑú ÏïàÏ†Ñ Ï≤òÎ¶¨
            label_col = "label_idx" if "label_idx" in df.columns else ("label" if "label" in df.columns else None)

            for _, row in df.iterrows():
                p = row["tensor_path"] if "tensor_path" in df.columns else row["tensor_path".strip()]
                y = int(row[label_col]) if label_col is not None else -1
                base_id = row["base_id"] if "base_id" in df.columns else os.path.basename(p)
                self.items.append((p, y, base_id))
        else:
            pts = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
            assert len(pts) > 0, f"‚ùå no .pt files in {data_dir}"
            for p in pts:
                self.items.append((p, -1, os.path.basename(p)))

        assert len(self.items) > 0, "‚ùå empty dataset"

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, y, base_id = self.items[idx]
        x = torch.load(path, map_location="cpu")   # Ïó¨Í∏∞ÏÑúÎäî seq_tensorÎßå Ï†ÄÏû•ÌñàÏóàÏùå: (T,C,H,W)
        if isinstance(x, dict) and "x" in x:
            # ÌòπÏãú dictÎ°ú Ï†ÄÏû•Îêú Í≤ΩÏö∞ÍπåÏßÄ ÎåÄÎπÑ
            x = x["x"]
            if y == -1 and "y" in x:
                y = int(x["y"])
        x = x.float()  # (T,C,H,W)
        return x, int(y), base_id

def collate_fn(batch):
    xs, ys, ids = zip(*batch)
    xb = torch.stack(xs, dim=0)  # (B,T,C,H,W)
    yb = torch.tensor(ys, dtype=torch.long)
    return xb, yb, ids

# =========================
# 3) Ï≤¥ÌÅ¨Ìè¨Ïù∏Ìä∏ Î°úÎìú
# =========================
ckpt = torch.load(CKPT_PT, map_location="cpu")
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
num_classes = int(ckpt.get("num_classes", 22)) if isinstance(ckpt, dict) else 22

model = MobileNetTSN(num_classes=num_classes, pretrained=False).to(DEVICE)
model.load_state_dict(state_dict, strict=True)
model.eval()

idx2label = {i: f"class_{i:04d}" for i in range(num_classes)}  # ÎùºÎ≤®Î™Ö ÌååÏùº ÏóÜÏúºÎãà Í∏∞Î≥∏

# =========================
# 4) Îç∞Î™® Ïã§Ìñâ
# =========================
ds = SkeletonTSNDemoDataset(DATA_DIR, SAMPLE_META if os.path.exists(SAMPLE_META) else None)
dl = DataLoader(ds, batch_size=4, shuffle=False, num_workers=0,
                pin_memory=(DEVICE.type=="cuda"), collate_fn=collate_fn)

@torch.no_grad()
def run_demo():
    total, correct = 0, 0
    rows = []

    for xb, yb, ids in dl:
        xb = xb.to(DEVICE, non_blocking=True)   # (B,T,C,H,W)
        yb = yb.to(DEVICE, non_blocking=True)

        logits = model(xb)
        pred = logits.argmax(dim=1)

        # yÍ∞Ä -1Ïù¥Î©¥(Ï†ïÎãµ ÏóÜÏùå) accÎäî Í≥ÑÏÇ∞ Î∂àÍ∞Ä ‚Üí Ï∂úÎ†•Îßå
        if (yb >= 0).all():
            total += int(yb.numel())
            correct += int((pred == yb).sum().item())

        pred_cpu = pred.cpu().tolist()
        y_cpu = yb.cpu().tolist()

        for i in range(len(ids)):
            rows.append((ids[i], y_cpu[i], pred_cpu[i]))

    print("\n" + "="*80)
    if total > 0:
        print(f"[MODEL: MobileNetTSN] correct={correct}/{total}  acc={correct/total:.4f}")
    else:
        print("[MODEL: MobileNetTSN] (no GT labels found) showing predictions only")
    print("="*80)

    for i, (base_id, gt, pr) in enumerate(rows):
        gt_s = idx2label.get(gt, str(gt)) if gt >= 0 else "N/A"
        pr_s = idx2label.get(pr, str(pr))
        mark = "‚úÖ" if (gt >= 0 and gt == pr) else ("‚ùå" if gt >= 0 else "‚Ä¢")
        print(f"{i:02d} {mark}  ID: {base_id} | GT: {gt_s:<12} | PRED: {pr_s}")

run_demo()

Using device: cpu

[MODEL: MobileNetTSN] correct=15/20  acc=0.7500
00 ‚úÖ  ID: seq_00044.pt | GT: class_0008   | PRED: class_0008
01 ‚úÖ  ID: seq_00568.pt | GT: class_0004   | PRED: class_0004
02 ‚úÖ  ID: seq_00056.pt | GT: class_0011   | PRED: class_0011
03 ‚ùå  ID: seq_00636.pt | GT: class_0017   | PRED: class_0010
04 ‚ùå  ID: seq_00486.pt | GT: class_0009   | PRED: class_0001
05 ‚úÖ  ID: seq_00096.pt | GT: class_0019   | PRED: class_0019
06 ‚úÖ  ID: seq_00761.pt | GT: class_0020   | PRED: class_0020
07 ‚úÖ  ID: seq_00051.pt | GT: class_0010   | PRED: class_0010
08 ‚úÖ  ID: seq_00107.pt | GT: class_0000   | PRED: class_0000
09 ‚ùå  ID: seq_00666.pt | GT: class_0001   | PRED: class_0004
10 ‚úÖ  ID: seq_00631.pt | GT: class_0016   | PRED: class_0016
11 ‚ùå  ID: seq_00270.pt | GT: class_0010   | PRED: class_0001
12 ‚úÖ  ID: seq_00545.pt | GT: class_0021   | PRED: class_0021
13 ‚úÖ  ID: seq_00849.pt | GT: class_0016   | PRED: class_0016
14 ‚úÖ  ID: seq_01014.pt | GT: class_0005   | PRED:

In [30]:
import os, glob
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

# =========================
# 0) Í≤ΩÎ°ú
# =========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

DEMO_DIR = "/content/drive/MyDrive/cv-medislr/data/samples/2d_only/skeleton_tiling/demo_test20_TILING_seed42"
MODEL_PT = "/content/drive/MyDrive/cv-medislr/data/preprocessed/model_weights/2d_only/2d_tiling_resnet18_best.pt"

NUM_CLASSES = 22   # üî• ÌïôÏäµ ÎãπÏãú ÌÅ¥ÎûòÏä§ Ïàò (Ï†àÎåÄ demoÏóêÏÑú Ï∂îÏ†ï X)

# =========================
# 1) demo pt Î™©Î°ù
# =========================
pt_files = sorted(glob.glob(os.path.join(DEMO_DIR, "*.pt")))
assert len(pt_files) > 0, f"‚ùå demo pt not found in: {DEMO_DIR}"

print("sample file:", pt_files[0])

# =========================
# 2) Î™®Îç∏ Ï†ïÏùò (ÌïôÏäµÍ≥º ÎèôÏùº)
# =========================
class TilingResNet18(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        base.fc = nn.Linear(base.fc.in_features, num_classes)
        self.backbone = base

    def forward(self, x):
        return self.backbone(x)

model = TilingResNet18(num_classes=NUM_CLASSES).to(DEVICE)

# =========================
# 3) checkpoint Î°úÎìú
# =========================
state_dict = torch.load(MODEL_PT, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.eval()

# =========================
# 4) DEMO Ïã§Ìñâ
# =========================
print("\n" + "="*78)
print("[DEMO] TilingResNet18 on skeleton_tiling demo pts")
print("="*78)

correct = 0
with torch.inference_mode():
    for i, p in enumerate(pt_files):
        d = torch.load(p, map_location="cpu")

        x = d["x"]                 # (3,224,224)
        y = int(d["y"])
        meta = d.get("meta", {})

        x = x.unsqueeze(0).float().to(DEVICE)   # (1,3,224,224)

        logits = model(x)
        pred = int(torch.argmax(logits, dim=1))

        ok = (pred == y)
        correct += int(ok)

        sid = meta.get("seq_id", os.path.basename(p))
        print(f"{i:02d} {'‚úÖ' if ok else '‚ùå'}  GT={y:02d} | PRED={pred:02d} | {sid}")

acc = correct / len(pt_files)
print("\n" + "-"*70)
print(f"RESULT: correct={correct}/{len(pt_files)}  acc={acc:.4f}")
print("-"*70)


Using device: cpu
sample file: /content/drive/MyDrive/cv-medislr/data/samples/2d_only/skeleton_tiling/demo_test20_TILING_seed42/sample00.pt

[DEMO] TilingResNet18 on skeleton_tiling demo pts
00 ‚ùå  GT=00 | PRED=12 | 7/WORD0029_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°/F
01 ‚ùå  GT=00 | PRED=12 | 7/WORD0029_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°/L
02 ‚úÖ  GT=19 | PRED=19 | 10/WORD1129_·ÑÄ·Ö•·Ü∑·Ñâ·Ö°/R
03 ‚úÖ  GT=05 | PRED=05 | 9/WORD0040_·Ñá·Öß·Üº·ÑÜ·Öß·Üº/U
04 ‚úÖ  GT=04 | PRED=04 | 8/WORD0039_·Ñá·Öß·Ü´·Ñá·Öµ/D
05 ‚úÖ  GT=05 | PRED=05 | 7/WORD0040_·Ñá·Öß·Üº·ÑÜ·Öß·Üº/L
06 ‚úÖ  GT=05 | PRED=05 | 10/WORD0040_·Ñá·Öß·Üº·ÑÜ·Öß·Üº/L
07 ‚úÖ  GT=17 | PRED=17 | 1/WORD0885_·Ñé·Öµ·ÑÖ·Ö≠·Ñå·Ö¶/R
08 ‚úÖ  GT=05 | PRED=05 | 6/WORD0040_·Ñá·Öß·Üº·ÑÜ·Öß·Üº/F
09 ‚úÖ  GT=05 | PRED=05 | 5/WORD0040_·Ñá·Öß·Üº·ÑÜ·Öß·Üº/F
10 ‚úÖ  GT=01 | PRED=01 | 9/WORD0033_·ÑÉ·Ö°·Üº·ÑÇ·Ö≠·Ñá·Öß·Üº/R
11 ‚úÖ  GT=03 | PRED=03 | 9/WORD0037_·ÑÄ·Ö°·Ü∑·ÑÄ·Öµ/F
12 ‚úÖ  GT=20 | PRED=20 | 4/WORD1158_·Ñë·Öµ·ÑÄ·Ö©·Ü´·Ñí·Ö°·ÑÉ·Ö°/R
13 ‚úÖ  GT=17 | PRED=17 | 7/WORD0885_·Ñé·Öµ·ÑÖ·Ö≠·Ñå·