In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
train_pmt_dit.py

- PMTSignalsH5 데이터셋 로딩
- PMTDit + GaussianDiffusion 학습
- NaN/Inf 디버깅 도구 포함:
  * 입력(raw), q_sample(x_t), 모델 출력(eps_hat), loss, grad 순서로 점검
  * step 번호와 에폭 진행 퍼센트 출력
"""

import os
import torch
import torch.nn as nn
import h5py 
# === 프로젝트 모듈 import ===
from dataloader import make_dataloader
from models import PMTDit, GaussianDiffusion, DiffusionConfig
from utils import print_h5_structure

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Location of the HDF5 dataset. The GENESIS_H5_PATH environment variable
# overrides the default, so the notebook is not tied to one machine's
# directory layout; without it the original path is used unchanged.
h5_path = os.environ.get(
    "GENESIS_H5_PATH", "/home/work/GENESIS/GENESIS-data/22644_0921.h5"
)
# NOTE(review): the cell that builds the DataLoader passes batch_size=64 as a
# literal, so this value is currently not what the loader uses.
batch_size = 8
num_epochs = 5   # upper bound of the epoch loop
lr = 2e-4        # learning rate handed to the AdamW optimizer

In [2]:
# Dump the layout of the HDF5 file (helper imported from utils); the output
# below lists each dataset with its shape and dtype.
print_h5_structure(h5_path)

HDF5 file: /home/work/GENESIS/GENESIS-data/22644_0921.h5
[Dataset] info - shape: (178056, 9), dtype: float32
[Dataset] input - shape: (178056, 2, 5160), dtype: float32
[Dataset] label - shape: (178056, 6), dtype: float32
[Dataset] xpmt - shape: (5160,), dtype: float32
[Dataset] ypmt - shape: (5160,), dtype: float32
[Dataset] zpmt - shape: (5160,), dtype: float32


In [17]:
import numpy
# Peek at the first event and at the PMT geometry arrays stored in the file.
with h5py.File(h5_path, "r") as f:
    # "input" is (events, 2, 5160) according to the structure dump above.
    # The original cell read channel 0 into both `npe` and `time`, so `time`
    # was just a copy of `npe`; read the second channel for `time` instead.
    # NOTE(review): channel order (0 = npe, 1 = time) is assumed — confirm
    # against the dataloader implementation.
    npe = f["input"][0, 0, :]
    time = f["input"][0, 1, :]
    # geometry (optional): each array is None when the dataset is absent
    xpmt = f["xpmt"][:] if "xpmt" in f else None
    print(xpmt)
    ypmt = f["ypmt"][:] if "ypmt" in f else None
    print(ypmt)
    zpmt = f["zpmt"][:] if "zpmt" in f else None

[-256.14 -256.14 -256.14 ...  -10.97  -10.97  -10.97]
[-521.08 -521.08 -521.08 ...    6.72    6.72    6.72]


In [5]:
# Build the training DataLoader from the HDF5 file.
# NOTE(review): batch_size is passed as the literal 64 here, so the
# `batch_size = 8` value from the configuration cell is ignored.
loader = make_dataloader(
    h5_path=h5_path,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    replace_time_inf_with=0.0,   # replace inf in the time channel with 0.0 (may be adjusted once the cause is understood)
    channel_first=True,
)

In [6]:
def _print_nonfinite_summary(name: str, x: torch.Tensor) -> None:
    """Print shape, dtype, NaN/Inf counts and the finite value range of ``x``."""
    nan_count = int(torch.isnan(x).sum())
    inf_count = int(torch.isinf(x).sum())
    finite_vals = x[torch.isfinite(x)]
    if finite_vals.numel() > 0:
        value_range = (
            f"finite_min={finite_vals.min().item():.6g} "
            f"finite_max={finite_vals.max().item():.6g}"
        )
    else:
        # min()/max() raise on an empty selection, so report it explicitly.
        value_range = "no finite values"
    print(
        f"[{name}] shape={tuple(x.shape)} dtype={x.dtype} "
        f"nan={nan_count} inf={inf_count} {value_range}"
    )


def assert_finite(name: str, x: torch.Tensor):
    """Verify that every element of ``x`` is finite.

    Prints a confirmation line when the tensor is clean. Otherwise prints a
    short summary of the offending tensor and raises.

    The summary is produced by a helper defined alongside this function
    rather than by ``tensor_stats``: that name is not defined or imported in
    any visible cell, so the failure path used to die with ``NameError``
    instead of the ``RuntimeError`` that callers catch.

    Args:
        name: Label used in the printed messages and in the error text.
        x: Tensor to check.

    Raises:
        RuntimeError: If ``x`` contains at least one NaN or Inf value.
    """
    if torch.isfinite(x).all():
        print(f"[{name}] ✅ all finite")
        return
    _print_nonfinite_summary(name, x)
    raise RuntimeError(f"Non-finite values detected in {name}")

In [7]:
# ------------------------
# Model, diffusion wrapper and optimizer
# ------------------------
# Sequence length handed to the model; equals the 5160-long last axis of the
# `input` dataset shown in the structure dump.
L = 5160
model = PMTDit(
    seq_len=L,
    hidden=64,
    depth=8,
    heads=8,
    dropout=0.1,
    fusion="FiLM",   # or "SUM"
    label_dim=6,
    t_embed_dim=128,
).to(device)

# Wrap the network in the diffusion process: 1000 timesteps, "eps" objective.
diffusion = GaussianDiffusion(
    model, DiffusionConfig(timesteps=1000, objective="eps")
).to(device)

# AdamW over the model parameters, using the learning rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

# Debugging options (uncomment when needed)
# NOTE(review): register_nan_hooks is not defined in the cells visible here.
# torch.autograd.set_detect_anomaly(True)
# register_nan_hooks(model)

# ------------------------
# Training loop (with NaN debugging)
# ------------------------
# Runs a single batch (both loops end in `break`) through every check:
# raw inputs -> q_sample -> forward-only model output -> loss -> gradients.
steps_per_epoch = len(loader)
print(steps_per_epoch)
global_step = 0

# Make training mode explicit; the model was built with dropout=0.1.
model.train()

for epoch in range(num_epochs):
    for step_in_epoch, (x_sig, geom, label, idx) in enumerate(loader, start=1):
        # Align the geom shape: a shared (3,L) geometry gets a batch axis.
        if geom.ndim == 2:  # (3,L)
            geom = geom.unsqueeze(0).expand(x_sig.size(0), -1, -1)

        x_sig = x_sig.to(device)    # (B,2,L)
        geom  = geom.to(device)     # (B,3,L)
        print(geom[0][1])
        print(geom[1][1])
        label = label.to(device)    # (B,6)

        # 1) Check the raw inputs
        try:
            print(f"step in epoch {epoch} : {step_in_epoch}")
            assert_finite("x_sig(raw)", x_sig)
            assert_finite("geom(raw)", geom)
            assert_finite("label(raw)", label)
        except RuntimeError:
            # Report which dataset indices made up the bad batch, then re-raise.
            print(f"[BAD INPUT] epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch} idx={idx.tolist()}")
            raise

        # 2) Check q_sample / model output (forward only, no autograd graph)
        with torch.no_grad():
            B = x_sig.size(0)
            print(f"batch size of {B}")
            print(f"diffusion.cfg.timesteps {diffusion.cfg.timesteps}")
            t = torch.randint(0, diffusion.cfg.timesteps, (B,), device=device, dtype=torch.long)
            print(f"time {t}")
            x_sig_t = diffusion.q_sample(x_sig, t)
            assert_finite("x_sig_t", x_sig_t)
            eps_hat = model(x_sig_t, geom, t, label)
            assert_finite("eps_hat(fwd-only)", eps_hat)
            print("__________________________________")

        # 3) Actual loss computation (extra diagnostics if it is NaN/Inf)
        # NOTE(review): tensor_stats is not defined in the cells visible in
        # this notebook; make sure it exists before relying on this branch.
        loss = diffusion.loss(x_sig, geom, label)
        print(loss)
        if not torch.isfinite(loss):
            print("\n[WARN] Non-finite loss detected!")
            print(f"  epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch}, idx={idx.tolist()}")
            tensor_stats("loss", loss)
            tensor_stats("x_sig(raw)", x_sig)
            tensor_stats("geom(raw)", geom)
            tensor_stats("label(raw)", label)
            # Save the offending batch if needed
            # torch.save({"x_sig": x_sig.cpu(), "geom": geom.cpu(), "label": label.cpu()}, "bad_batch.pt")
            raise RuntimeError("Non-finite loss")

        # Backpropagate. This loop previously went straight from the loss to
        # the gradient check and optimizer.step() without ever calling
        # backward(), so every p.grad was None, the check below could never
        # fire, and the optimizer step updated nothing.
        optimizer.zero_grad(set_to_none=True)   # drop gradients left from earlier steps
        loss.backward()

        # 4) Check the gradients (must run after backward)
        bad_grad = None
        for n, p in model.named_parameters():
            if p.grad is not None and not torch.isfinite(p.grad).all():
                bad_grad = n
                tensor_stats(f"grad:{n}", p.grad)
                break
        if bad_grad is not None:
            print(f"[BAD GRAD] epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch}, param={bad_grad}")
            raise RuntimeError("Non-finite gradient")

        # Clip the global gradient norm to 1.0, then apply the update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        global_step += 1
        # Progress through the current epoch, in percent
        pct = 100.0 * step_in_epoch / steps_per_epoch
        if global_step % 50 == 0:
            print(f"[epoch {epoch+1}/{num_epochs}] "
                    f"step {step_in_epoch}/{steps_per_epoch} ({pct:5.1f}%) "
                    f"| loss = {loss.item():.6f}")
        break  # debug only: stop after the first batch
    break      # debug only: stop after the first epoch

2783
tensor([-521.0800, -521.0800, -521.0800,  ...,    6.7200,    6.7200,
           6.7200], device='cuda:0')
tensor([-521.0800, -521.0800, -521.0800,  ...,    6.7200,    6.7200,
           6.7200], device='cuda:0')
step in epoch 0 : 1
[x_sig(raw)] ✅ all finite
[geom(raw)] ✅ all finite
[label(raw)] ✅ all finite
batch size of 64
diffusion.cfg.timesteps 1000
time tensor([721, 898, 117, 719, 359, 117, 857, 295, 135, 582, 399, 263, 505,  10,
        936, 875, 733, 292,  77,  25, 769,  91, 585, 844, 424, 404, 140, 942,
        200, 432, 847, 881, 698, 223, 741, 573, 486, 203, 584, 629, 658, 442,
        843, 381,  25,  74, 573, 317, 883, 610, 555,   5, 369, 533, 182, 346,
        813, 687, 331, 269, 720, 245, 224, 148], device='cuda:0')
[x_sig_t] ✅ all finite
[eps_hat(fwd-only)] ✅ all finite
__________________________________
tensor(1.4243, device='cuda:0', grad_fn=<MseLossBackward0>)


In [9]:
# ------------------------
# Training loop (with NaN debugging)
# ------------------------
# Per batch: validate the raw inputs, run a forward-only sanity pass, compute
# the loss, backpropagate, validate the gradients, clip, and step the optimizer.
steps_per_epoch = len(loader)
print(steps_per_epoch)
global_step = 0

model.train()  # training mode

for epoch in range(num_epochs):
    for step_in_epoch, (x_sig, geom, label, idx) in enumerate(loader, start=1):
        # Align the geom shape: a shared (3,L) geometry gets a batch axis.
        if geom.ndim == 2:  # (3,L)
            geom = geom.unsqueeze(0).expand(x_sig.size(0), -1, -1)

        x_sig = x_sig.to(device)    # (B,2,L)
        geom  = geom.to(device)     # (B,3,L)
        print(geom[0][1])
        print(geom[1][1])
        label = label.to(device)    # (B,6)

        # 1) Check the raw inputs
        try:
            print(f"step in epoch {epoch} : {step_in_epoch}")
            assert_finite("x_sig(raw)", x_sig)
            assert_finite("geom(raw)", geom)
            assert_finite("label(raw)", label)
        except RuntimeError:
            # Report which dataset indices made up the bad batch, then re-raise.
            print(f"[BAD INPUT] epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch} idx={idx.tolist()}")
            raise

        # 2) Check q_sample / model output (forward only, no autograd graph)
        with torch.no_grad():
            B = x_sig.size(0)
            print(f"batch size of {B}")
            print(f"diffusion.cfg.timesteps {diffusion.cfg.timesteps}")
            t = torch.randint(0, diffusion.cfg.timesteps, (B,), device=device, dtype=torch.long)
            print(f"time {t}")
            x_sig_t = diffusion.q_sample(x_sig, t)
            assert_finite("x_sig_t", x_sig_t)
            eps_hat = model(x_sig_t, geom, t, label)
            assert_finite("eps_hat(fwd-only)", eps_hat)

        # 3) Actual loss computation / backprop (extra diagnostics if it is NaN/Inf)
        # NOTE(review): tensor_stats is not defined in the cells visible in
        # this notebook; make sure it exists before relying on this branch.
        loss = diffusion.loss(x_sig, geom, label)
        print(loss)
        if not torch.isfinite(loss):
            print("\n[WARN] Non-finite loss detected!")
            print(f"  epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch}, idx={idx.tolist()}")
            tensor_stats("loss", loss)
            tensor_stats("x_sig(raw)", x_sig)
            tensor_stats("geom(raw)", geom)
            tensor_stats("label(raw)", label)
            raise RuntimeError("Non-finite loss")

        # ====== Steps required for an actual parameter update ======
        optimizer.zero_grad(set_to_none=True)   # reset accumulated gradients
        # (optional) check the sum of parameter norms before the update
        # with torch.no_grad():
        #     w_before = sum(p.norm().item() for p in model.parameters() if p.requires_grad)

        loss.backward()                         # compute gradients

        # 4) Check the gradients (after backward)
        bad_grad = None
        for n, p in model.named_parameters():
            if p.grad is not None and not torch.isfinite(p.grad).all():
                bad_grad = n
                tensor_stats(f"grad:{n}", p.grad)
                break
        if bad_grad is not None:
            print(f"[BAD GRAD] epoch={epoch+1}, step={step_in_epoch}/{steps_per_epoch}, param={bad_grad}")
            raise RuntimeError("Non-finite gradient")

        # Clip the global gradient norm to 1.0 before stepping.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                        # <- the actual parameter update

        # (optional) compare the sum of parameter norms after the update
        # with torch.no_grad():
        #     w_after = sum(p.norm().item() for p in model.parameters() if p.requires_grad)
        # print(f"param norm sum: before={w_before:.6f} -> after={w_after:.6f}")

        global_step += 1
        # Progress through the current epoch, in percent
        pct = 100.0 * step_in_epoch / steps_per_epoch
        if global_step % 50 == 0:
            print(f"[epoch {epoch+1}/{num_epochs}] "
                  f"step {step_in_epoch}/{steps_per_epoch} ({pct:5.1f}%) "
                  f"| loss = {loss.item():.6f}")

        break  # debug only: stop after the first batch of each epoch; remove for real training. There is no outer break, so the epoch loop still runs num_epochs times.


2783
tensor([-521.0800, -521.0800, -521.0800,  ...,    6.7200,    6.7200,
           6.7200], device='cuda:0')
tensor([-521.0800, -521.0800, -521.0800,  ...,    6.7200,    6.7200,
           6.7200], device='cuda:0')
step in epoch 0 : 1
[x_sig(raw)] ✅ all finite
[geom(raw)] ✅ all finite
[label(raw)] ✅ all finite
batch size of 64
diffusion.cfg.timesteps 1000
time tensor([611, 167, 402, 375, 388, 804, 795, 667, 819, 798, 446, 517, 972, 643,
        306, 952, 148, 922, 887, 707, 479, 206, 259, 735, 861, 272, 357, 335,
        843, 211, 711, 943, 573, 867, 461, 819, 948, 610, 995,  77, 552, 536,
        959, 504, 812,  42, 761, 529, 482, 243,  39, 648,  79, 364, 535, 526,
        248, 205, 786, 888, 901, 853, 369, 940], device='cuda:0')
[x_sig_t] ✅ all finite
[eps_hat(fwd-only)] ✅ all finite
tensor(2.4717, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor([-521.0800, -521.0800, -521.0800,  ...,    6.7200,    6.7200,
           6.7200], device='cuda:0')
tensor([-521.0800, -521.0800, -521.