# HybridTwoWay Model (Colab Ready)


## Imports
ÌïÑÏöîÌïú PyTorch Î™®ÎìàÍ≥º ÌÉÄÏûÖ ÌûåÌä∏Î•º Î∂àÎü¨ÏòµÎãàÎã§.

In [None]:
!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib albumentations


In [2]:
# ============================================
# Cell 1: ÏÑ§Ïπò Î∞è Í∏∞Î≥∏ Import, Roboflow Îã§Ïö¥Î°úÎìú
# ============================================

import math
import os
from typing import List, Tuple
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from roboflow import Roboflow
from tqdm import tqdm

rf = Roboflow(api_key="HG9M6YJZpcCUgAQaKO9v")
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')


loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in detection-base-6 to yolov8:: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 111006/111006 [00:02<00:00, 49870.64it/s]





Extracting Dataset Version Zip to detection-base-6 in yolov8:: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3306/3306 [00:00<00:00, 5865.30it/s]

Roboflow dataset downloaded to: /content/detection-base-6





## 0. Utility Functions
## 1. Anomaly-Aware CNN Stem

In [3]:
# ============================================
# Cell 2: Í∏∞Î≥∏ Conv Î∏îÎ°ù + Stem
# ============================================

def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    m = [nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
         nn.BatchNorm2d(out_ch)]
    if act:
        m.append(nn.SiLU(inplace=True))
    return nn.Sequential(*m)

class FixedGaussianBlur(nn.Module):
    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        grid = torch.arange(k).float() - (k - 1) / 2
        gauss = torch.exp(-(grid ** 2) / (2 * sigma ** 2))
        kernel1d = gauss / gauss.sum()
        kernel2d = torch.outer(kernel1d, kernel1d)
        weight = kernel2d[None, None, :, :].repeat(channels, 1, 1, 1)
        self.register_buffer('weight', weight)
        self.groups = channels
        self.k = k

    def forward(self, x):
        pad = (self.k // 2,) * 4
        return F.conv2d(F.pad(x, pad, mode='reflect'),
                        self.weight, groups=self.groups)
class AnomalyAwareStem(nn.Module):
    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        C1, C2, C3 = base_ch, base_ch * 2, base_ch * 4
        self.stem = nn.Sequential(
            conv_bn_act(in_ch, C1, 3, 2, 1),
            conv_bn_act(C1, C2, 3, 2, 1),
            conv_bn_act(C2, C3, 3, 2, 1),
        )
        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, C3 // 4, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C3 // 4),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(C3 + C3 // 4, C3, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(C3)
        self.vis_head = nn.Conv2d(C3, 1, 1, 1, 0)
        self.base_ch = base_ch # Ï†ÄÏû•

    @property
    def out_channels(self):
        return 4 * self.base_ch

    def forward(self, x):
        f_main = self.stem(x)
        blurred = self.blur(x)
        high = x - blurred
        high_ds = F.interpolate(high, size=f_main.shape[-2:], mode='bilinear', align_corners=False)
        f_anom = self.anom(high_ds)
        f = torch.cat([f_main, f_anom], dim=1)
        f = self.fuse_bn(self.fuse(f))
        f = F.silu(f, inplace=True)
        v = torch.sigmoid(self.vis_head(f_main))
        return f, v


## 2. Vision Transformer Encoder

In [4]:
# ============================================
# Cell 3: ViT Encoder + Feedback Adapter
# ============================================

class PatchEmbed1x1(nn.Module):
    def __init__(self, in_ch, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        x = self.bn(x)
        x = F.silu(x, inplace=True)
        return x

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class MultiheadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop_p = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (self.qkv(x)
               .reshape(B, N, 3, self.num_heads, self.head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Flash Attention
        out = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop_p if self.training else 0.0,
            scale=self.scale
        )
        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiheadSelfAttention(dim, num_heads, drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class ViTEncoder(nn.Module):
    def __init__(self, embed_dim=512, depth=8, num_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio=4.0, drop=0.0)
            for _ in range(depth)
        ])

    def forward(self, tokens):
        for blk in self.blocks:
            tokens = blk(tokens)
        return tokens

class FeedbackAdapter(nn.Module):
    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        layers = [nn.Conv2d(d_token, c_stem * 2, 1, 1, 0, bias=not use_bn)]
        if use_bn:
            layers.append(nn.BatchNorm2d(c_stem * 2))
        layers.append(nn.SiLU(inplace=True))
        self.adapter = nn.Sequential(*layers)

    def forward(self, tokens, Ht, Wt, f_stem):
        B, N, D = tokens.shape
        t2d = tokens.transpose(1, 2).reshape(B, D, Ht, Wt)
        ab = self.adapter(t2d)
        Cs = f_stem.shape[1]
        gamma, beta = torch.split(ab, Cs, dim=1)
        return f_stem * (1 + torch.tanh(gamma)) + beta


## 4. PAN-Lite Neck

In [5]:
class PANLite(nn.Module):
    def __init__(self, in_ch=512, mid=256):
        super().__init__()
        self.lateral = conv_bn_act(in_ch, mid, 1, 1, 0)
        self.down4 = conv_bn_act(mid, mid, 3, 2, 1)
        self.down5 = conv_bn_act(mid, mid, 3, 2, 1)
        self.up4 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.up3 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.down_f4 = conv_bn_act(mid, mid, 3, 2, 1)
        self.fuse4 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.down_f5 = conv_bn_act(mid, mid, 3, 2, 1)
        self.fuse5 = conv_bn_act(mid + mid, mid, 3, 1, 1)

    def forward(self, p3):
        p3 = self.lateral(p3)
        p4 = self.down4(p3)
        p5 = self.down5(p4)
        p4u = F.interpolate(p5, size=p4.shape[-2:], mode='nearest')
        p4 = self.up4(torch.cat([p4, p4u], dim=1))
        p3u = F.interpolate(p4, size=p3.shape[-2:], mode='nearest')
        p3 = self.up3(torch.cat([p3, p3u], dim=1))
        p4b = self.down_f4(p3)
        p4 = self.fuse4(torch.cat([p4, p4b], dim=1))
        p5b = self.down_f5(p4)
        p5 = self.fuse5(torch.cat([p5, p5b], dim=1))
        return p3, p4, p5

## 5. YOLO-style Detection Head

In [6]:
class YOLOHeadLite(nn.Module):
    def __init__(self, in_ch=256, num_classes=1):
        super().__init__()
        c = in_ch
        self.stem3 = conv_bn_act(c, c, 3, 1, 1)
        self.stem4 = conv_bn_act(c, c, 3, 1, 1)
        self.stem5 = conv_bn_act(c, c, 3, 1, 1)
        self.cls3 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj3 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box3 = nn.Conv2d(c, 4, 1, 1, 0)
        nn.init.constant_(self.obj3.bias, -4.59)
        self.cls4 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj4 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box4 = nn.Conv2d(c, 4, 1, 1, 0)
        nn.init.constant_(self.obj4.bias, -4.59)
        self.cls5 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj5 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box5 = nn.Conv2d(c, 4, 1, 1, 0)
        nn.init.constant_(self.obj5.bias, -4.59)

    def forward_single(self, x, stem, cls, obj, box):
        f = stem(x)
        return cls(f), obj(f), box(f)

    def forward(self, p3, p4, p5):
        c3, o3, b3 = self.forward_single(p3, self.stem3, self.cls3, self.obj3, self.box3)
        c4, o4, b4 = self.forward_single(p4, self.stem4, self.cls4, self.obj4, self.box4)
        c5, o5, b5 = self.forward_single(p5, self.stem5, self.cls5, self.obj5, self.box5)
        return [(c3, o3, b3), (c4, o4, b4), (c5, o5, b5)]

## 6. HybridTwoWay Model

In [7]:
class HybridTwoWay(nn.Module):
    def __init__(self, in_ch=3, stem_base=32, embed_dim=256, vit_depth=4, vit_heads=4, num_classes=3, iters=1, detach_feedback=True, img_size=640):
        super().__init__()
        self.iters = iters
        self.detach_feedback = detach_feedback
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4
        self.patch = PatchEmbed1x1(c_stem, embed_dim)
        self.num_patches = (img_size // 8) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.vit = ViTEncoder(embed_dim=embed_dim, depth=vit_depth, num_heads=vit_heads)
        self.feedback = FeedbackAdapter(d_token=embed_dim, c_stem=c_stem, use_bn=True)
        self.neck = PANLite(in_ch=embed_dim, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

    def forward(self, x):
        f_stem, vis = self.stem(x)
        p0 = self.patch(f_stem)
        Ht, Wt = p0.shape[-2:]
        tokens = p0.flatten(2).transpose(1, 2)

        # Positional Embedding (Dynamic Resizing)
        if tokens.shape[1] != self.pos_embed.shape[1]:
            pos_embed = F.interpolate(
                self.pos_embed.reshape(1, int(self.num_patches**0.5), int(self.num_patches**0.5), -1).permute(0, 3, 1, 2),
                size=(Ht, Wt), mode='bicubic', align_corners=False
            ).flatten(2).transpose(1, 2)
            tokens = tokens + pos_embed
        else:
            tokens = tokens + self.pos_embed

        f_fb = f_stem
        preds, aux = None, None

        for i in range(self.iters):
            tokens = self.vit(tokens)
            toks_for_fb = tokens.detach() if self.detach_feedback else tokens
            f_fb = self.feedback(toks_for_fb, Ht, Wt, f_fb)
            p3_in = self.patch(f_fb)
            p3, p4, p5 = self.neck(p3_in)
            preds = self.head(p3, p4, p5)
            aux = {"P3": p3, "P4": p4, "P5": p5, "V": vis}
            if i != self.iters - 1:
                tokens = p3_in.flatten(2).transpose(1, 2)
                # Re-add Pos Embed for next iter
                tokens = tokens + (pos_embed if 'pos_embed' in locals() else self.pos_embed)
        return preds, aux

In [8]:
# ============================================
# Cell 6: Dataset / Dataloader
# ============================================

IMG_SIZE = 640

def yolo_collate_fn(batch):
    imgs = []
    targets = []
    for img, tgt in batch:
        imgs.append(img)
        targets.append(tgt)
    imgs = torch.stack(imgs, 0)
    return imgs, targets

class YoloDataset(Dataset):
    def __init__(self, root):
        self.img_dir = os.path.join(root, "images")
        self.label_dir = os.path.join(root, "labels")
        self.images = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        name = self.images[idx]
        img_path = os.path.join(self.img_dir, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Ïù¥ÎØ∏ 640x640Ïù¥ÎØÄÎ°ú resizeÎäî Í∞ÄÎ≥çÍ≤å ÌÜµÍ≥º
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = torch.tensor(img).permute(2,0,1).float() / 255.0

        label_path = os.path.join(self.label_dir, name.replace(".jpg",".txt").replace(".png",".txt"))
        boxes = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f.readlines():
                    cls, x, y, w, h = map(float, line.split())
                    boxes.append([cls, x, y, w, h])
        boxes = torch.tensor(boxes, dtype=torch.float32)
        return img, boxes

In [9]:
# ============================================
# Cell 7: YOLO-style Loss (Focal + GIoU)
# ============================================

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        prob = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)
        focal_term = (1 - p_t) ** self.gamma
        loss = ce * focal_term
        if self.alpha >= 0:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduction == "mean": return loss.mean()
        elif self.reduction == "sum": return loss.sum()
        else: return loss

def xywh_to_xyxy(box_xywh):
    x_c, y_c, w, h = box_xywh.unbind(-1)
    x1, y1 = x_c - w / 2, y_c - h / 2
    x2, y2 = x_c + w / 2, y_c + h / 2
    return torch.stack([x1, y1, x2, y2], dim=-1)


def giou_loss(pred_box_xyxy, tgt_box_xyxy):
    x1 = torch.max(pred_box_xyxy[:, 0], tgt_box_xyxy[:, 0])
    y1 = torch.max(pred_box_xyxy[:, 1], tgt_box_xyxy[:, 1])
    x2 = torch.min(pred_box_xyxy[:, 2], tgt_box_xyxy[:, 2])
    y2 = torch.min(pred_box_xyxy[:, 3], tgt_box_xyxy[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_p = (pred_box_xyxy[:, 2]-pred_box_xyxy[:, 0]).clamp(min=0)*(pred_box_xyxy[:, 3]-pred_box_xyxy[:, 1]).clamp(min=0)
    area_t = (tgt_box_xyxy[:, 2]-tgt_box_xyxy[:, 0]).clamp(min=0)*(tgt_box_xyxy[:, 3]-tgt_box_xyxy[:, 1]).clamp(min=0)
    union = area_p + area_t - inter + 1e-6
    iou = inter / union
    c_x1 = torch.min(pred_box_xyxy[:, 0], tgt_box_xyxy[:, 0])
    c_y1 = torch.min(pred_box_xyxy[:, 1], tgt_box_xyxy[:, 1])
    c_x2 = torch.max(pred_box_xyxy[:, 2], tgt_box_xyxy[:, 2])
    c_y2 = torch.max(pred_box_xyxy[:, 3], tgt_box_xyxy[:, 3])
    c_area = (c_x2 - c_x1).clamp(min=0) * (c_y2 - c_y1).clamp(min=0) + 1e-6
    giou = iou - (c_area - union) / c_area
    return (1.0 - giou).mean()

# focal loss Ïù∏Ïä§ÌÑ¥Ïä§ (object + class Îëò Îã§ ÏÇ¨Ïö©)
_focal_loss = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")

def yolo_loss(preds, targets, img_size=512, lambda_obj=1.0, lambda_cls=1.0, lambda_box=5.0):
    total_obj_loss, total_cls_loss, total_box_loss = 0.0, 0.0, 0.0

    # Batch Size ÌôïÏù∏
    B_size = preds[0][0].shape[0]

    for scale_id, (cls_pred, obj_pred, box_pred) in enumerate(preds):
        B, C, H, W = cls_pred.shape
        device = cls_pred.device
        cls_p = cls_pred.permute(0, 2, 3, 1).reshape(B, H * W, C)
        obj_p = obj_pred.permute(0, 2, 3, 1).reshape(B, H * W, 1)
        box_p = box_pred.permute(0, 2, 3, 1).reshape(B, H * W, 4)
        stride = img_size // H

        for b in range(B):
            gt = targets[b]
            if gt.numel() == 0:
                obj_tgt = torch.zeros((H * W, 1), device=device)
                total_obj_loss += _focal_loss(obj_p[b], obj_tgt)
                continue

            gcls = gt[:, 0].long()
            gxy_norm = gt[:, 1:3]
            gwh_norm = gt[:, 3:5]
            gxy_pix = gxy_norm * img_size
            gx = (gxy_pix[:, 0] / stride).long().clamp(0, W - 1)
            gy = (gxy_pix[:, 1] / stride).long().clamp(0, H - 1)
            gi = gy * W + gx

            obj_tgt = torch.zeros((H * W, 1), device=device)
            obj_tgt[gi] = 1.0
            total_obj_loss += _focal_loss(obj_p[b], obj_tgt)

            cls_tgt = torch.zeros((H * W, C), device=device)
            cls_tgt[gi, gcls] = 1.0
            total_cls_loss += _focal_loss(cls_p[b], cls_tgt)

            pred_raw = box_p[b][gi]
            pred_box_norm_xywh = pred_raw.sigmoid()
            tgt_box_norm_xywh = torch.cat([gxy_norm, gwh_norm], dim=1)
            pred_xyxy = xywh_to_xyxy(pred_box_norm_xywh)
            tgt_xyxy = xywh_to_xyxy(tgt_box_norm_xywh)
            total_box_loss += giou_loss(pred_xyxy, tgt_xyxy)

    # Î∞∞Ïπò ÏÇ¨Ïù¥Ï¶àÎ°ú ÎÇòÎàÑÏñ¥ Ï†ïÍ∑úÌôî (ÌïôÏäµ ÏïàÏ†ïÏÑ± ÌôïÎ≥¥)
    total = (lambda_obj * total_obj_loss + lambda_cls * total_cls_loss + lambda_box * total_box_loss) / B_size
    return total


In [11]:
# ============================================
# Cell 8: Decode Predictions + mAP Evaluation
# ============================================

def box_iou(box1, box2):
    N = box1.size(0)
    M = box2.size(0)
    if N == 0 or M == 0: return torch.zeros(N, M)
    tl = torch.max(box1[:, None, :2], box2[:, :2])
    br = torch.min(box1[:, None, 2:], box2[:, 2:])
    wh = (br - tl).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area1 = (box1[:, 2]-box1[:, 0]) * (box1[:, 3]-box1[:, 1])
    area2 = (box2[:, 2]-box2[:, 0]) * (box2[:, 3]-box2[:, 1])
    return inter / (area1[:, None] + area2 - inter + 1e-6)

def nms(boxes, scores, iou_thres=0.5):
    if boxes.numel() == 0: return torch.zeros(0, dtype=torch.long, device=boxes.device)
    idxs = scores.argsort(descending=True)
    keep = []
    while idxs.numel() > 0:
        i = idxs[0]
        keep.append(i.item())
        if idxs.numel() == 1: break
        ious = box_iou(boxes[i].unsqueeze(0), boxes[idxs[1:]]).squeeze(0)
        idxs = idxs[1:][ious < iou_thres]
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)

def decode_predictions(preds, img_size=512, conf_thres=0.25, nms_iou_thres=0.5):
    all_outputs = []
    B = preds[0][0].shape[0]
    for b in range(B):
        dets_all = []
        for (cls_pred, obj_pred, box_pred) in preds:
            B_s, C, H, W = cls_pred.shape
            cls_logits = cls_pred[b].permute(1,2,0).reshape(H*W, C)
            obj_logits = obj_pred[b].permute(1,2,0).reshape(H*W, 1)
            box_logits = box_pred[b].permute(1,2,0).reshape(H*W, 4)
            obj_scores = obj_logits.sigmoid().squeeze(-1)
            cls_scores = cls_logits.sigmoid()
            box_norm = box_logits.sigmoid()
            cls_max_scores, cls_ids = cls_scores.max(dim=-1)
            scores = obj_scores * cls_max_scores
            mask = scores > conf_thres
            if mask.sum() == 0: continue
            scores_ = scores[mask]
            cls_ids_ = cls_ids[mask]
            boxes = box_norm[mask]
            x_c, y_c, w, h = boxes[:, 0]*img_size, boxes[:, 1]*img_size, boxes[:, 2]*img_size, boxes[:, 3]*img_size
            x1, y1 = (x_c - w/2).clamp(0, img_size), (y_c - h/2).clamp(0, img_size)
            x2, y2 = (x_c + w/2).clamp(0, img_size), (y_c + h/2).clamp(0, img_size)
            boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1)
            keep = nms(boxes_xyxy, scores_, iou_thres=nms_iou_thres)
            if keep.numel() == 0: continue
            dets = torch.cat([boxes_xyxy[keep], scores_[keep].unsqueeze(1), cls_ids_[keep].float().unsqueeze(1)], dim=1)
            dets_all.append(dets)
        all_outputs.append(torch.cat(dets_all, dim=0) if len(dets_all) > 0 else [])
    return all_outputs

def compute_ap(recall, precision):
    mrec = torch.cat([torch.tensor([0.0]), recall, torch.tensor([1.0])])
    mpre = torch.cat([torch.tensor([0.0]), precision, torch.tensor([0.0])])
    for i in range(mpre.size(0)-1, 0, -1):
        mpre[i-1] = torch.max(mpre[i-1], mpre[i])
    idx = (mrec[1:] != mrec[:-1]).nonzero().squeeze()
    return ((mrec[idx+1] - mrec[idx]) * mpre[idx+1]).sum().item()

def evaluate_map(model, dataloader, num_classes=3, img_size=512, iou_thr=0.5, conf_thres=0.25):
    model.eval()
    device = next(model.parameters()).device
    all_dets = {c: [] for c in range(num_classes)}
    all_gts  = {c: [] for c in range(num_classes)}
    global_img_id = 0

    with torch.no_grad():
        for batch_i, (imgs, targets) in enumerate(dataloader):
            imgs = imgs.to(device)
            targets = [t.to(device) for t in targets]
            preds, _ = model(imgs)
            dets_list = decode_predictions(preds, img_size=img_size, conf_thres=conf_thres)

            for b in range(len(imgs)):
                dets = dets_list[b]
                gt = targets[b]
                current_img_id = global_img_id
                global_img_id += 1

                if len(gt) > 0:
                    gcls = gt[:, 0].long()
                    gxy, gwh = gt[:, 1:3] * img_size, gt[:, 3:5] * img_size
                    gx1, gy1 = gxy[:, 0] - gwh[:, 0]/2, gxy[:, 1] - gwh[:, 1]/2
                    gx2, gy2 = gxy[:, 0] + gwh[:, 0]/2, gxy[:, 1] + gwh[:, 1]/2
                    gboxes = torch.stack([gx1, gy1, gx2, gy2], dim=1)
                    for c in range(num_classes):
                        mask = (gcls == c)
                        if mask.sum() > 0: all_gts[c].append((current_img_id, gboxes[mask].cpu()))

                if dets is not None and len(dets) > 0:
                    boxes, scores, cls_ids = dets[:, :4], dets[:, 4], dets[:, 5].long()
                    for c in range(num_classes):
                        mask = (cls_ids == c)
                        if mask.sum() > 0: all_dets[c].append((current_img_id, scores[mask].cpu(), boxes[mask].cpu()))

    aps = []
    for c in range(num_classes):
        gts_c = all_gts[c]
        if len(gts_c) == 0: continue
        n_gt = sum(boxes.size(0) for _, boxes in gts_c)
        gt_dict = {}
        for img_id, boxes in gts_c:
            gt_dict.setdefault(img_id, [])
            gt_dict[img_id].append({"boxes": boxes, "matched": torch.zeros(boxes.size(0), dtype=torch.bool)})

        dets_c = all_dets[c]
        if len(dets_c) == 0:
            aps.append(0.0)
            continue

        scores_all, boxes_all, img_ids_all = [], [], []
        for img_id, scores, boxes in dets_c:
            for i in range(boxes.size(0)):
                scores_all.append(scores[i].item())
                boxes_all.append(boxes[i])
                img_ids_all.append(img_id)

        scores_all = torch.tensor(scores_all)
        boxes_all = torch.stack(boxes_all, dim=0)
        order = scores_all.argsort(descending=True)
        scores_all, boxes_all = scores_all[order], boxes_all[order]
        img_ids_all = [img_ids_all[i] for i in order]

        tps, fps = torch.zeros(len(scores_all)), torch.zeros(len(scores_all))
        for i in range(len(scores_all)):
            img_id = img_ids_all[i]
            pred_box = boxes_all[i].unsqueeze(0)
            if img_id not in gt_dict:
                fps[i] = 1; continue
            gt_entry = gt_dict[img_id][0]
            ious = box_iou(pred_box, gt_entry["boxes"]).squeeze(0)
            if ious.numel() == 0: fps[i] = 1; continue
            max_iou, max_idx = ious.max(0)
            if max_iou >= iou_thr and not gt_entry["matched"][max_idx]:
                tps[i] = 1; gt_entry["matched"][max_idx] = True
            else: fps[i] = 1

        tp_cum, fp_cum = torch.cumsum(tps, dim=0), torch.cumsum(fps, dim=0)
        recall = tp_cum / (n_gt + 1e-6)
        precision = tp_cum / (tp_cum + fp_cum + 1e-6)
        aps.append(compute_ap(recall, precision))

    mAP = sum(aps) / len(aps) if len(aps) > 0 else 0.0
    return mAP, aps

In [12]:
DATA_PATH = dataset.location
train_dataset = YoloDataset(os.path.join(DATA_PATH, "train"))
val_dataset   = YoloDataset(os.path.join(DATA_PATH, "valid"))
test_dataset  = YoloDataset(os.path.join(DATA_PATH, "test"))
# [ÏµúÏ†ÅÌôî] num_workers=2, pin_memory=True Ï∂îÍ∞Ä
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=yolo_collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=yolo_collate_fn, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=yolo_collate_fn, num_workers=2, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {device}")

Using Device: cuda


In [19]:
# ============================================
# ÌïôÏäµ Ï§ÄÎπÑ
# ============================================
cfg = dict(
    in_ch=3,
    stem_base=32,
    embed_dim=256,
    vit_depth=4,
    vit_heads=4,
    num_classes=3,
    iters=1,
    detach_feedback=True,
    img_size=IMG_SIZE
)

model = HybridTwoWay(**cfg).to(device)

# [ÏÑ†ÌÉù] torch.compile (PyTorch 2.0+, Colab T4/L4)
try:
    model = torch.compile(model)
    print("‚úÖ Model compiled with torch.compile")
except:
    print("‚ö†Ô∏è torch.compile failed/skipped")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
EPOCHS = 10

best_map = -1.0
best_epoch = -1

‚úÖ Model compiled with torch.compile


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.


In [20]:
# VRAM ÏÇ¨Ïö©Îüâ
imgs, targets = next(iter(train_loader))
imgs = imgs.to(device)
targets = [t.to(device) for t in targets]

optimizer.zero_grad()
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    preds, aux = model(imgs)
    loss = yolo_loss(preds, targets, img_size=IMG_SIZE)

loss.backward()

print("Max allocated:", torch.cuda.max_memory_allocated() / 1024**3, "GB")
print("Max reserved :", torch.cuda.max_memory_reserved() / 1024**3, "GB")


Max allocated: 1.306610107421875 GB
Max reserved : 2.177734375 GB


In [23]:
# ============================================
# Cell 9: ÌïôÏäµ Î£®ÌîÑ (AMP + Flash Attention + PosEmbed)
# ============================================

print("üöÄ Start Training...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for imgs, targets in loop:
        imgs = imgs.to(device)
        targets = [t.to(device) for t in targets]
        optimizer.zero_grad()

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            preds, aux = model(imgs)
            loss = yolo_loss(preds, targets, img_size=IMG_SIZE)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(loss=f"{loss.item():.4f}")

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Train Average Loss: {avg_loss:.4f}")

    # Validation
    val_map, val_aps = evaluate_map(model, val_loader, num_classes=cfg["num_classes"], img_size=IMG_SIZE, conf_thres=0.01)
    print(f"Epoch {epoch+1} | Val mAP@0.5: {val_map:.4f}")

    if val_map > best_map:
        best_map = val_map
        best_epoch = epoch + 1
        # compile ÏÇ¨Ïö© Ïãú ÏõêÎ≥∏ state_dict Ï†ÄÏû•ÏùÑ ÏúÑÌï¥ _orig_mod ÌôïÏù∏
        state_dict = model._orig_mod.state_dict() if hasattr(model, '_orig_mod') else model.state_dict()
        ckpt = {"state_dict": state_dict, "cfg": cfg, "epoch": best_epoch, "val_map": best_map}
        torch.save(ckpt, "hybrid_two_way_best.pt")
        print(f"‚úÖ Best model saved! (Val mAP: {best_map:.4f})")

print("üèÅ Training Finished!")

üöÄ Start Training...


Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:04<00:00,  2.67it/s, loss=11.1724]

Epoch 1 | Train Average Loss: 12.6366





Epoch 1 | Val mAP@0.5: 0.0147
‚úÖ Best model saved! (Val mAP: 0.0147)


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:04<00:00,  2.69it/s, loss=8.9998]

Epoch 2 | Train Average Loss: 10.4648





Epoch 2 | Val mAP@0.5: 0.0267
‚úÖ Best model saved! (Val mAP: 0.0267)


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.72it/s, loss=10.2649]

Epoch 3 | Train Average Loss: 9.9766





Epoch 3 | Val mAP@0.5: 0.0240


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.71it/s, loss=11.3181]

Epoch 4 | Train Average Loss: 9.6514





Epoch 4 | Val mAP@0.5: 0.0241


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:04<00:00,  2.70it/s, loss=10.1903]

Epoch 5 | Train Average Loss: 9.3809





Epoch 5 | Val mAP@0.5: 0.0216


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.72it/s, loss=10.5236]

Epoch 6 | Train Average Loss: 9.1112





Epoch 6 | Val mAP@0.5: 0.0277
‚úÖ Best model saved! (Val mAP: 0.0277)


Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.71it/s, loss=8.3596]

Epoch 7 | Train Average Loss: 8.8977





Epoch 7 | Val mAP@0.5: 0.0268


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.73it/s, loss=6.0074]

Epoch 8 | Train Average Loss: 8.7006





Epoch 8 | Val mAP@0.5: 0.0289
‚úÖ Best model saved! (Val mAP: 0.0289)


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.72it/s, loss=9.1786]

Epoch 9 | Train Average Loss: 8.5025





Epoch 9 | Val mAP@0.5: 0.0163


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 173/173 [01:03<00:00,  2.72it/s, loss=8.0174]

Epoch 10 | Train Average Loss: 8.2511





Epoch 10 | Val mAP@0.5: 0.0219
üèÅ Training Finished!


In [24]:
# ============================================
# Ï†ÄÏû•Îêú Best Î™®Îç∏ Î∂àÎü¨Ïò§Í∏∞ Î∞è ÌÖåÏä§Ìä∏ ÌèâÍ∞Ä
# ============================================

checkpoint = torch.load("hybrid_two_way_best.pt", map_location=device)
loaded_cfg = checkpoint["cfg"]
model = HybridTwoWay(**loaded_cfg).to(device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

test_map, class_aps = evaluate_map(model, test_loader, num_classes=loaded_cfg["num_classes"], img_size=IMG_SIZE, conf_thres=0.01)
print(f"\nüèÜ Final Test mAP@0.5: {test_map:.4f}")
for i, ap in enumerate(class_aps):
    print(f"   Class {i} AP@0.5: {ap:.4f}")


üèÜ Final Test mAP@0.5: 0.0402
   Class 0 AP@0.5: 0.0508
   Class 1 AP@0.5: 0.0693
   Class 2 AP@0.5: 0.0004


In [18]:
# ============================================
# Cell 10: Quick Sanity Check (ÏûÖÏ∂úÎ†• shape ÌôïÏù∏)
# ============================================

x = torch.randn(2, 3, 640, 640).to(device)
preds, aux = model(x)

for level, (c, o, b) in zip(["P3","P4","P5"], preds):
    print(f"[{level}] cls: {list(c.shape)}, obj: {list(o.shape)}, box: {list(b.shape)}")


[P3] cls: [2, 3, 80, 80], obj: [2, 1, 80, 80], box: [2, 4, 80, 80]
[P4] cls: [2, 3, 40, 40], obj: [2, 1, 40, 40], box: [2, 4, 40, 40]
[P5] cls: [2, 3, 20, 20], obj: [2, 1, 20, 20], box: [2, 4, 20, 20]
