# TinyCenterSpeed: Training with DataLoader tuning + W&B logging
이 노트북은 최적화된 `CenterSpeedDataset`을 사용해 **학습/검증 루프**를 구성하고,  
**Weights & Biases (wandb)**로 `train_loss`와 `val_loss`를 시각화합니다.

> ⚠️ 주의: 인터넷 환경/로그인 이슈가 있으면 자동으로 오프라인 모드로 전환됩니다.


In [1]:
import os, sys, math, random, datetime, json, gc, time
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

# 프로젝트 경로(필요 시 수정)
current_dir = os.path.dirname(os.path.abspath(''))
two_up_dir = os.path.dirname(os.path.dirname(current_dir))
if two_up_dir not in sys.path:
    sys.path.append(two_up_dir)

# 모델/데이터셋/로스 임포트 (경로는 사용 환경에 맞게 구성되어 있다고 가정)
from TinyCenterSpeed.src.models.CenterSpeed import CenterSpeedDense
from TinyCenterSpeed.dataset.CenterSpeed_dataset import CenterSpeedDataset, RandomRotation, RandomFlip
from TinyCenterSpeed.src.models.losses import *  # (필요 시 내부 함수 사용)


In [2]:
# --- W&B 설정 (온라인/오프라인 자동 처리) ---
use_wandb = True  # 시각화 사용 여부
project_name = "TinyCenterSpeed_redbull"
run_name = "train_" + datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

try:
    import wandb
    if use_wandb:
        # 환경에 따라 자동 로그인/오프라인 모드 전환
        if os.environ.get("WANDB_MODE","").lower() == "offline":
            wandb.init(project=project_name, name=run_name, mode="offline")
        else:
            try:
                wandb.init(project=project_name, name=run_name)  # 로그인 되어 있으면 정상 시작
            except Exception:
                # 로그인 안 되어 있거나 인터넷 X -> 오프라인
                wandb.init(project=project_name, name=run_name, mode="offline")
except Exception as e:
    print(f"[wandb] 사용 불가: {type(e).__name__}: {e}")
    use_wandb = False
    wandb = None


[34m[1mwandb[0m: Currently logged in as: [33mwhdaudpark[0m ([33mwhdaudpark-dongguk-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# --- 재현성/디바이스 ---
seed = 42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# 입력 크기 고정이면 cudnn benchmark 권장
torch.backends.cudnn.benchmark = True


Device: cuda


In [4]:
# --- 하이퍼파라미터/경로 ---
dataset_path = "/home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT"  # <-- 필요 시 수정
image_size   = 128               # CenterSpeed_dataset 기본값과 일치
pixelsize    = 0.1               # 0.1 m/pixel
sigma_px     = 2.0               # 40cm @ 0.1m/px -> 반지름 2px -> σ≈2.0

epochs       = 100               # 필요 시 변경
batch_size   = 32
learning_rate= 5e-4
train_ratio  = 0.8

# DataLoader 최적화
NUM_WORKERS  = min(os.cpu_count() or 4, 8)
PIN_MEMORY   = torch.cuda.is_available()
PERSIST      = NUM_WORKERS > 0

print({
    "dataset_path": dataset_path,
    "image_size": image_size,
    "pixelsize": pixelsize,
    "sigma_px": sigma_px,
    "epochs": epochs,
    "batch_size": batch_size,
    "lr": learning_rate,
    "workers": NUM_WORKERS
})


{'dataset_path': '/home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT', 'image_size': 128, 'pixelsize': 0.1, 'sigma_px': 2.0, 'epochs': 100, 'batch_size': 32, 'lr': 0.0005, 'workers': 8}


In [5]:
# --- 변환/데이터셋 구성 ---
from torchvision import transforms as T

transform = T.Compose([
    RandomRotation(45, image_size=image_size),
    RandomFlip(0.5),
])

dataset = CenterSpeedDataset(dataset_path=dataset_path, transform=None, dense=True)
# 픽셀/이미지/시그마 설정(필요 시)
dataset.change_image_size(image_size)
dataset.change_pixel_size(pixelsize)
dataset.sx = sigma_px
dataset.sy = sigma_px

print("총 샘플:", len(dataset))

# Train/Val split
gen = torch.Generator().manual_seed(42)
train_size = int(len(dataset) * train_ratio)
val_size   = len(dataset) - train_size
train_set, val_set = random_split(dataset, [train_size, val_size], generator=gen)
print("train:", len(train_set), "val:", len(val_set))


Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs1_0812_no_intensity.csv
Entries     :  11217
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs1_0812_no_intensity_more.csv
Entries     :  4190
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs2_0812_no_intensity.csv
Entries     :  11970
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs2_0812_no_intensity_more.csv
Entries     :  3951
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs3_0812_no_intensity.csv
Entries     :  11140
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/TrainGT_redbull_obs3_0812_no_intensity_more.csv
Entries     :  3931
Reading file:  /home/harry/sim_ws/src/f1tenth_gym_ros/Train_GT/redbull_testobs1.csv
Entries     :  7330
Total rows :  53729
File index :  [(0, 11217), (11217, 15407), (15407, 27377), (27377, 31328), (31328, 42468), (42468, 463

In [6]:
# --- DataLoader ---
def worker_init_fn(_):
    try:
        import torch, os
        torch.set_num_threads(1)
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        os.environ.setdefault("MKL_NUM_THREADS", "1")
    except Exception:
        pass

train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSIST,
    prefetch_factor=4,
    drop_last=True,
    worker_init_fn=worker_init_fn,
)

val_loader = DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSIST,
    prefetch_factor=2,
    drop_last=False,
    worker_init_fn=worker_init_fn,
)

len(train_loader), len(val_loader)


(1343, 336)

In [7]:
# --- 모델/옵티마이저/스케줄러/손실 ---
model = CenterSpeedDense(input_channels=4, image_size=image_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 일정 기준으로 LR 감소 (검증 손실 기반)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5) #, verbose=True)

# 네가 쓰던 형태와 호환되는 손실 (출력 [B,4,H,W], gts [B,H,W], dense [B,H,W,3])
def dense_loss(output, gt_heatmap, gt_dense_data, is_free, alpha=0.99, decay=1.0):
    # output: [B,4,H,W] -> [B,H,W,4]
    preds = output.permute(0,2,3,1)
    # gt reshape
    w = gt_heatmap.unsqueeze(-1)  # [B,H,W,1]
    # 가중 MSE
    loss_occ   = (alpha     * (1 + w) * (preds[...,0:1] - gt_heatmap.unsqueeze(-1))**2).sum()
    loss_dense = ((1-alpha) * (1 + w) * (preds[...,1:]  - gt_dense_data)**2).sum()
    batch_size = output.shape[0]
    return (loss_occ + loss_dense) / batch_size

print("Model/optimizer/loss 준비 완료")


Model/optimizer/loss 준비 완료


In [None]:
# --- 학습/검증 루프 ---
from contextlib import nullcontext

best_val = float('inf')
train_hist, val_hist = [], []

scaler_ctx = nullcontext  # (필요 시 autocast 추가 가능)

for epoch in range(1, epochs+1):
    model.train()
    running = 0.0
    for batch in train_loader:
        inputs, gts, data_vec, dense_feats, is_free = batch
        inputs      = inputs.to(device, non_blocking=True)
        gts         = gts.to(device, non_blocking=True)
        dense_feats = dense_feats.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        out = model(inputs)
        loss = dense_loss(out, gts, dense_feats, is_free, alpha=0.99)
        loss.backward()
        optimizer.step()
        running += loss.item()

    train_loss = running / max(1, len(train_loader))
    train_hist.append(train_loss)

    # Validation
    model.eval()
    v_running = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs, gts, data_vec, dense_feats, is_free = batch
            inputs      = inputs.to(device, non_blocking=True)
            gts         = gts.to(device, non_blocking=True)
            dense_feats = dense_feats.to(device, non_blocking=True)

            out = model(inputs)
            v_loss = dense_loss(out, gts, dense_feats, is_free, alpha=0.99)
            v_running += v_loss.item()

    val_loss = v_running / max(1, len(val_loader))
    val_hist.append(val_loss)

    # 스케줄러 업데이트
    scheduler.step(val_loss if len(val_loader) > 0 else train_loss)

    # 로그
    msg = f"[{epoch:03d}/{epochs}] train_loss={train_loss:.6f}"
    if len(val_loader) > 0:
        msg += f" | val_loss={val_loss:.6f}"
    print(msg)

    if 'wandb' in globals() and wandb and use_wandb:
        log_data = {"epoch": epoch, "train_loss": train_loss}
        if len(val_loader) > 0:
            log_data["val_loss"] = val_loss
        wandb.log(log_data)

    # 베스트 저장 (val이 있으면 val 기준, 없으면 train 기준)
    score = val_loss if len(val_loader) > 0 else train_loss
    save_dir = "/home/harry/ros2_ws/src/TinyCenterSpeed/src/pt"
    if score < best_val:
        best_val = score
        torch.save(model.state_dict(), os.path.join(save_dir, f"centerspeed_best_epoch_1{epoch}.pt"))
        print(f"  ↳ Best model saved at {os.path.join(save_dir, f'centerspeed_best_epoch_1{epoch}.pt')} (score={best_val:.6f})")


# 학습 완료 후 마지막 모델 저장
save_dir = "/home/harry/ros2_ws/src/TinyCenterSpeed/src/pt"
torch.save(model.state_dict(), os.path.join(save_dir, f"centerspeed_last_epoch_1{epoch}.pt"))
print(f"Training finished. Model saved at {os.path.join(save_dir, f'centerspeed_last_epoch_1{epoch}.pt')}")



[001/100] train_loss=8626.357082 | val_loss=2876.241244
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_1.pt (score=2876.241244)
[002/100] train_loss=1208.952339 | val_loss=175.602081
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_2.pt (score=175.602081)
[003/100] train_loss=41.543420 | val_loss=2.462254
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_3.pt (score=2.462254)
[004/100] train_loss=2.091351 | val_loss=2.076904
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_4.pt (score=2.076904)
[005/100] train_loss=1.903253 | val_loss=2.043565
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_5.pt (score=2.043565)
[006/100] train_loss=1.782974 | val_loss=1.766138
  ↳ Best model saved at /home/harry/ros2_ws/src/TinyCenterSpeed/src/pt/centerspeed_best_epoch_6.pt (sco

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x747334a37340>> (for post_run_cell), with arguments args (<ExecutionResult object at 747334a377c0, execution_count=8 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 747334a376a0, raw_cell="# --- 학습/검증 루프 ---
from contextlib import nullcont.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/harry/ros2_ws/src/TinyCenterSpeed/src/train/1Train_fast.ipynb#X11sZmlsZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
# --- 로컬 손실 곡선 ---
plt.figure(figsize=(6,4))
plt.plot(train_hist, label='train')
if len(val_hist) > 0:
    plt.plot(val_hist, label='val')
plt.xlabel('epoch'); plt.ylabel('loss'); plt.legend(); plt.title('Loss Curves')
plt.show()


In [None]:
# --- W&B 종료 ---
try:
    if 'wandb' in globals() and wandb and use_wandb:
        wandb.finish()
except Exception as e:
    print(f"[wandb] finish skipped: {type(e).__name__}: {e}")
