In [8]:
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ============================================================
# 1. IoU (2D bounding box) – pred ↔ GT 매칭에 사용
# ============================================================
def iou_2d(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interW = max(0, xB - xA)
    interH = max(0, yB - yA)
    interArea = interW * interH
    boxAArea = (boxA[2]-boxA[0])*(boxA[3]-boxA[1])
    boxBArea = (boxB[2]-boxB[0])*(boxB[3]-boxB[1])
    return interArea / (boxAArea + boxBArea - interArea + 1e-6)

# ============================================================
# 2. KITTI loader (prediction & label)
# ============================================================
def load_kitti_file(path, with_score=True):
    dets = []
    if not os.path.exists(path): return dets
    with open(path,'r') as f:
        for line in f:
            parts=line.strip().split()
            if len(parts)<15: continue
            if parts[0]!="Car": continue
            x1,y1,x2,y2=map(float,parts[4:8])
            h,w,l=map(float,parts[8:11])
            x,y,z,ry=map(float,parts[11:15])
            score=float(parts[15]) if with_score and len(parts)>15 else 1.0
            dets.append({
                "box2d":[x1,y1,x2,y2],
                "box3d":[x,y,z,w,h,l,ry],
                "score":score
            })
    return dets

# ============================================================
# 3. Sample builder (pred ↔ gt 매칭 후 residual 생성)
# ============================================================
def build_samples(output_dir, label_dir, iou_thresh=0.7):
    samples=[]
    out_files=sorted(glob.glob(os.path.join(output_dir,"*.txt")))
    for out_path in out_files:
        image_id=os.path.splitext(os.path.basename(out_path))[0]
        label_path=os.path.join(label_dir,image_id+".txt")

        outputs=load_kitti_file(out_path,with_score=True)
        labels=load_kitti_file(label_path,with_score=False)

        if len(labels)==0:   # 🚨 GT에 Car가 없는 경우 스킵
            continue

        for out in outputs:
            # labels이 비어있지 않을 때만 max() 수행
            if len(labels)>0:
                gt=max(labels,key=lambda g:iou_2d(out["box2d"],g["box2d"]))
                if iou_2d(out["box2d"],gt["box2d"])>=iou_thresh:
                    samples.append({
                        "init_3d": torch.tensor(out["box3d"],dtype=torch.float32),
                        "gt_3d": torch.tensor(gt["box3d"],dtype=torch.float32),
                    })
    return samples

# ============================================================
# 4. Dataset
# ============================================================
class GraphDataset(Dataset):
    def __init__(self,samples): self.samples=samples
    def __len__(self): return len(self.samples)
    def __getitem__(self,idx): return self.samples[idx]

# ============================================================
# 5. Graph Utilities
# ============================================================
def build_distance_adj_matrix(pos, threshold=3.0):
    N=pos.size(0)
    dists=torch.cdist(pos,pos)
    adj=(dists<threshold).float()
    adj.fill_diagonal_(0)
    deg=adj.sum(1,keepdim=True)
    adj=adj/(deg+1e-6)
    return adj

class GCNLayer(nn.Module):
    def __init__(self,in_dim,out_dim):
        super().__init__()
        self.linear=nn.Linear(in_dim,out_dim)
    def forward(self,x,adj):
        return F.relu(self.linear(adj@x))

# ============================================================
# 6. Graph Refinement Model (입력은 3D box 7차원)
# ============================================================
class RelationalRefinement(nn.Module):
    def __init__(self,input_dim=7,hidden_dim=64,out_dim=7):
        super().__init__()
        self.fc1=nn.Linear(input_dim,hidden_dim)
        self.fc2=nn.Linear(hidden_dim,hidden_dim)
        self.gcn1=GCNLayer(hidden_dim,hidden_dim)
        self.gcn2=GCNLayer(hidden_dim,hidden_dim)
        self.regressor=nn.Linear(hidden_dim,out_dim)
    def forward(self,features,centers):
        adj=build_distance_adj_matrix(centers,threshold=3.0)
        x=F.relu(self.fc1(features))
        x=F.relu(self.fc2(x))
        x=self.gcn1(x,adj)
        x=self.gcn2(x,adj)
        return self.regressor(x)

# ============================================================
# 7. Training
# ============================================================
def train_model(samples,epochs=20,batch_size=16,lr=1e-3,save_path="../outputs/graph_refine.pth"):
    dataset=GraphDataset(samples)
    loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
    model=RelationalRefinement(input_dim=7)
    optimz=torch.optim.Adam(model.parameters(),lr=lr)
    best=float("inf")

    for ep in range(epochs):
        model.train(); total=0
        for batch in loader:
            init3d=batch["init_3d"]
            gt=batch["gt_3d"]
            features=init3d
            centers=init3d[:,:3]
            delta_pred=model(features,centers)
            delta_gt=gt-init3d
            loss=F.smooth_l1_loss(delta_pred,delta_gt)
            optimz.zero_grad(); loss.backward(); optimz.step()
            total+=loss.item()
        avg=total/len(loader)
        print(f"[Epoch {ep+1}/{epochs}] Loss={avg:.4f}")
        if avg<best:
            best=avg; torch.save(model.state_dict(),save_path)
            print(f"  ↳ Best 갱신: {best:.4f}")
    return model

# ============================================================
# 8. Inference (refinement 후 저장)
# ============================================================
def refine_and_save(model,output_dir,save_dir):
    os.makedirs(save_dir,exist_ok=True)
    out_files=sorted(glob.glob(os.path.join(output_dir,"*.txt")))
    for out_path in out_files:
        image_id=os.path.splitext(os.path.basename(out_path))[0]
        outputs=load_kitti_file(out_path,with_score=True)
        refined=[]
        for out in outputs:
            init3d=torch.tensor(out["box3d"],dtype=torch.float32).unsqueeze(0)
            centers=init3d[:,:3]
            delta=model(init3d,centers).detach().numpy()[0]
            refined_box=init3d.numpy()[0]+delta
            refined.append({"cls":"Car","box2d":out["box2d"],"box3d":refined_box,"score":out["score"]})
        save_path=os.path.join(save_dir,image_id+".txt")
        with open(save_path,"w") as f:
            for det in refined:
                x1,y1,x2,y2=det["box2d"]
                x,y,z,w,h,l,ry=det["box3d"]
                score=det["score"]
                line=f"Car 0.00 0 -1.67 {x1:.2f} {y1:.2f} {x2:.2f} {y2:.2f} {h:.2f} {w:.2f} {l:.2f} {x:.2f} {y:.2f} {z:.2f} {ry:.2f} {score:.3f}\n"
                f.write(line)
    print(f"✅ Refinement 완료: {save_dir}")


In [9]:
if __name__ == "__main__":
    # ----------------------------
    # 1. 경로 설정
    # ----------------------------
    train_output_dir = "../dataset/merge_output_train"   # detector 예측 결과 (KITTI txt, train set)
    train_label_dir  = "../dataset/label_2_train"        # GT 라벨 (KITTI label_2, train set)
    val_output_dir   = "../dataset/merge_output_val"     # detector 예측 결과 (KITTI txt, val set)
    val_label_dir    = "../dataset/label_2_val"          # GT 라벨 (KITTI label_2, val set)

    # ----------------------------
    # 2. 학습 샘플 생성
    # ----------------------------
    print("📂 Train 샘플 생성 중...")
    train_samples = build_samples(train_output_dir, train_label_dir, iou_thresh=0.7)
    print(f"✅ Train 샘플 개수: {len(train_samples)}")

    # ----------------------------
    # 3. 모델 학습
    # ----------------------------
    print("🚀 Graph Refinement 학습 시작")
    model = train_model(
        train_samples,
        epochs=10,                  # 학습 epoch 수
        batch_size=16,
        lr=1e-3,
        save_path="../outputs/graph_refine_best.pth"
    )

    # ----------------------------
    # 4. 학습된 best 모델 불러오기
    # ----------------------------
    print("📂 학습된 모델 로드")
    model = RelationalRefinement(input_dim=7)
    model.load_state_dict(torch.load("../outputs/graph_refine_best.pth"))
    model.eval()

    # ----------------------------
    # 5. Validation refinement 실행
    # ----------------------------
    save_dir = "../dataset/graph_refined_output_val"
    refine_and_save(
        model,
        output_dir=val_output_dir,
        save_dir=save_dir
    )


📂 Train 샘플 생성 중...
✅ Train 샘플 개수: 11355
🚀 Graph Refinement 학습 시작
[Epoch 1/10] Loss=0.0247
  ↳ Best 갱신: 0.0247
[Epoch 2/10] Loss=0.0246
  ↳ Best 갱신: 0.0246
[Epoch 3/10] Loss=0.0246
  ↳ Best 갱신: 0.0246
[Epoch 4/10] Loss=0.0246
[Epoch 5/10] Loss=0.0246
  ↳ Best 갱신: 0.0246
[Epoch 6/10] Loss=0.0246
  ↳ Best 갱신: 0.0246
[Epoch 7/10] Loss=0.0246
  ↳ Best 갱신: 0.0246
[Epoch 8/10] Loss=0.0246
[Epoch 9/10] Loss=0.0246
[Epoch 10/10] Loss=0.0245
  ↳ Best 갱신: 0.0245
📂 학습된 모델 로드
✅ Refinement 완료: ../dataset/graph_refined_output_val
