# HybridTwoWay Model (Colab Ready)


## Imports
필요한 PyTorch 모듈과 타입 힌트를 불러옵니다.

In [2]:
!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib




In [3]:
import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


In [14]:
# --- 옵션 3: Roboflow 데이터셋 사용 --- #
# !pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="HG9M6YJZpcCUgAQaKO9v")
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')


loading Roboflow workspace...
loading Roboflow project...
Roboflow dataset downloaded to: /content/detection-base-6


## 0. Utility Functions

In [4]:
def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    m = [nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
         nn.BatchNorm2d(out_ch)]
    if act:
        m.append(nn.SiLU(inplace=True))
    return nn.Sequential(*m)


## 1. Anomaly-Aware CNN Stem

In [6]:
class FixedGaussianBlur(nn.Module):
    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        grid = torch.arange(k).float() - (k - 1) / 2
        gauss = torch.exp(-(grid ** 2) / (2 * sigma ** 2))
        kernel1d = gauss / gauss.sum()
        kernel2d = torch.outer(kernel1d, kernel1d)
        weight = kernel2d[None, None, :, :].repeat(channels, 1, 1, 1)
        self.register_buffer('weight', weight)
        self.groups = channels
        self.k = k

    def forward(self, x):
        pad = (self.k // 2,) * 4
        return F.conv2d(F.pad(x, pad, mode='reflect'), self.weight, groups=self.groups)


class AnomalyAwareStem(nn.Module):
    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        C1, C2, C3 = base_ch, base_ch * 2, base_ch * 4
        self.stem = nn.Sequential(
            conv_bn_act(in_ch, C1, 3, 2, 1),
            conv_bn_act(C1, C2, 3, 2, 1),
            conv_bn_act(C2, C3, 3, 2, 1),
        )
        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, C3 // 4, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C3 // 4),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(C3 + C3 // 4, C3, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(C3)
        self.vis_head = nn.Conv2d(C3, 1, 1, 1, 0)

    @property
    def out_channels(self):
        return 4 * 48

    def forward(self, x):
        f_main = self.stem(x)
        blurred = self.blur(x)
        high = x - blurred
        high_ds = F.interpolate(high, size=f_main.shape[-2:], mode='bilinear', align_corners=False)
        f_anom = self.anom(high_ds)
        f = torch.cat([f_main, f_anom], dim=1)
        f = self.fuse_bn(self.fuse(f))
        f = F.silu(f, inplace=True)
        v = torch.sigmoid(self.vis_head(f_main))
        return f, v


## 2. Vision Transformer Encoder

In [7]:
class PatchEmbed1x1(nn.Module):
    """Map CNN features to ViT embeddings while keeping spatial resolution."""
    def __init__(self, in_ch, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        x = self.bn(x)
        x = F.silu(x, inplace=True)
        return x


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MultiheadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiheadSelfAttention(dim, num_heads, drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ViTEncoder(nn.Module):
    def __init__(self, embed_dim=512, depth=8, num_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio=4.0, drop=0.0)
            for _ in range(depth)
        ])

    def forward(self, tokens):
        for blk in self.blocks:
            tokens = blk(tokens)
        return tokens


## 3. Feedback Adapter

In [8]:
class FeedbackAdapter(nn.Module):
    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        layers = [nn.Conv2d(d_token, c_stem * 2, 1, 1, 0, bias=not use_bn)]
        if use_bn:
            layers.append(nn.BatchNorm2d(c_stem * 2))
        layers.append(nn.SiLU(inplace=True))
        self.adapter = nn.Sequential(*layers)

    def forward(self, tokens: torch.Tensor, Ht: int, Wt: int, f_stem: torch.Tensor):
        B, N, D = tokens.shape
        t2d = tokens.transpose(1, 2).reshape(B, D, Ht, Wt)
        ab = self.adapter(t2d)
        Cs = f_stem.shape[1]
        gamma, beta = torch.split(ab, Cs, dim=1)
        return f_stem * (1 + torch.tanh(gamma)) + beta


## 4. PAN-Lite Neck

In [9]:
class PANLite(nn.Module):
    def __init__(self, in_ch=512, mid=256):
        super().__init__()
        self.lateral = conv_bn_act(in_ch, mid, 1, 1, 0)
        self.down4 = conv_bn_act(mid, mid, 3, 2, 1)
        self.down5 = conv_bn_act(mid, mid, 3, 2, 1)
        self.up4 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.up3 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.down_f4 = conv_bn_act(mid, mid, 3, 2, 1)
        self.fuse4 = conv_bn_act(mid + mid, mid, 3, 1, 1)
        self.down_f5 = conv_bn_act(mid, mid, 3, 2, 1)
        self.fuse5 = conv_bn_act(mid + mid, mid, 3, 1, 1)

    def forward(self, p3):
        p3 = self.lateral(p3)
        p4 = self.down4(p3)
        p5 = self.down5(p4)
        p4u = F.interpolate(p5, size=p4.shape[-2:], mode='nearest')
        p4 = self.up4(torch.cat([p4, p4u], dim=1))
        p3u = F.interpolate(p4, size=p3.shape[-2:], mode='nearest')
        p3 = self.up3(torch.cat([p3, p3u], dim=1))
        p4b = self.down_f4(p3)
        p4 = self.fuse4(torch.cat([p4, p4b], dim=1))
        p5b = self.down_f5(p4)
        p5 = self.fuse5(torch.cat([p5, p5b], dim=1))
        return p3, p4, p5


## 5. YOLO-style Detection Head

In [18]:
class YOLOHeadLite(nn.Module):
    def __init__(self, in_ch=256, num_classes=1, reg_max=0):
        super().__init__()
        c = in_ch
        self.stem3 = conv_bn_act(c, c, 3, 1, 1)
        self.stem4 = conv_bn_act(c, c, 3, 1, 1)
        self.stem5 = conv_bn_act(c, c, 3, 1, 1)
        self.cls3 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj3 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box3 = nn.Conv2d(c, 4, 1, 1, 0)
        self.cls4 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj4 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box4 = nn.Conv2d(c, 4, 1, 1, 0)
        self.cls5 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj5 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box5 = nn.Conv2d(c, 4, 1, 1, 0)

    def forward_single(self, x, stem, cls, obj, box):
        f = stem(x)
        return cls(f), obj(f), box(f)

    def forward(self, p3, p4, p5):
        c3, o3, b3 = self.forward_single(p3, self.stem3, self.cls3, self.obj3, self.box3)
        c4, o4, b4 = self.forward_single(p4, self.stem4, self.cls4, self.obj4, self.box4)
        c5, o5, b5 = self.forward_single(p5, self.stem5, self.cls5, self.obj5, self.box5)
        return [(c3, o3, b3), (c4, o4, b4), (c5, o5, b5)]


## 6. HybridTwoWay Model

In [25]:
class HybridTwoWay(nn.Module):
    def __init__(
        self,
        in_ch=3,
        stem_base=32,
        embed_dim=256,
        vit_depth=4,
        vit_heads=4,
        num_classes=3,
        iters=1,
        detach_feedback=True,
    ):
        super().__init__()
        assert iters >= 1
        self.iters = iters
        self.detach_feedback = detach_feedback
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4
        self.patch = PatchEmbed1x1(c_stem, embed_dim)
        self.vit = ViTEncoder(embed_dim=embed_dim, depth=vit_depth, num_heads=vit_heads)
        self.feedback = FeedbackAdapter(embed_dim, c_stem, use_bn=True)
        self.neck = PANLite(in_ch=embed_dim, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

    def forward_once(self, x):
        f_stem, vis = self.stem(x)
        p = self.patch(f_stem)
        Ht, Wt = p.shape[-2:]
        tokens = p.flatten(2).transpose(1, 2)
        tokens = self.vit(tokens)
        toks_for_fb = tokens.detach() if self.detach_feedback else tokens
        f_fb = self.feedback(toks_for_fb, Ht, Wt, f_stem)
        p3_in = self.patch(f_fb)
        p3 = p3_in
        p3, p4, p5 = self.neck(p3)
        preds = self.head(p3, p4, p5)
        aux = {"P3": p3, "P4": p4, "P5": p5, "V": vis}
        return preds, aux, f_fb

    def forward(self, x):
        preds, aux, f_fb = self.forward_once(x)
        for _ in range(self.iters - 1):
            preds, aux, f_fb = self.forward_once(x)
        return preds, aux


In [26]:
from torch.utils.data import Dataset, DataLoader
import cv2

IMG_SIZE = 512

def yolo_collate_fn(batch):
    imgs = []
    targets = []

    for img, tgt in batch:
        imgs.append(img)
        targets.append(tgt)   # tensor of shape [num_boxes, 5]

    # stack images → OK
    imgs = torch.stack(imgs, 0)

    # targets는 stack 안 함 (num_boxes 다르기 때문)
    return imgs, targets

class YoloDataset(Dataset):
    def __init__(self, root):
        self.img_dir = os.path.join(root, "images")
        self.label_dir = os.path.join(root, "labels")
        self.images = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        name = self.images[idx]

        img_path = os.path.join(self.img_dir, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = torch.tensor(img).permute(2,0,1).float() / 255.0

        label_path = os.path.join(self.label_dir, name.replace(".jpg",".txt").replace(".png",".txt"))
        boxes = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f.readlines():
                    cls, x, y, w, h = map(float, line.split())
                    boxes.append([cls, x, y, w, h])

        boxes = torch.tensor(boxes, dtype=torch.float32)
        return img, boxes



In [27]:
import os
DATA_PATH = dataset.location

train_dataset = YoloDataset(os.path.join(DATA_PATH, "train"))
val_dataset   = YoloDataset(os.path.join(DATA_PATH, "valid"))

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=yolo_collate_fn
)


In [29]:
import torch.nn.functional as F

def yolo_loss(preds, targets, img_size=512):
    object_loss = 0
    class_loss = 0
    box_loss = 0

    for scale_id, (cls_pred, obj_pred, box_pred) in enumerate(preds):
        B, N, _ = obj_pred.shape

        # ----- sigmoid for predictions -----
        obj_pred = obj_pred.sigmoid()
        cls_pred = cls_pred.sigmoid()

        # ----- 매우 단순화된 YOLO matching -----
        # 여기서는 "해당 이미지에 object가 있으면 obj=1" 그냥 기본 체크
        for b in range(B):
            num_boxes = len(targets[b])
            if num_boxes == 0:
                # object 없음 → obj=0에 가까워야 함
                object_loss += (obj_pred[b] ** 2).mean()
                continue

            # object 있음 → obj=1 근처여야 함
            object_loss += ((1 - obj_pred[b]) ** 2).mean()

            # class loss (1 클래스일 때)
            tcls = targets[b][:, 0].long()
            class_loss += F.cross_entropy(cls_pred[b], tcls, reduction="mean")

            # box loss (L1)
            # normalize YOLO xywh → pixel 단위 좌표 가능
            tbox = targets[b][:, 1:].clone()  # xywh (0~1)
            tbox[:, 0] *= img_size
            tbox[:, 1] *= img_size
            tbox[:, 2] *= img_size
            tbox[:, 3] *= img_size

            pred_box = box_pred[b][:num_boxes]
            pred_box = pred_box * img_size  # normalize → pixel

            box_loss += F.l1_loss(pred_box, tbox, reduction="mean")

    total = object_loss + class_loss + box_loss
    return total


In [28]:
model = HybridTwoWay(num_classes=3).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


EPOCHS = 3

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for imgs, targets in train_loader:
        imgs = imgs.cuda()

        preds, aux = model(imgs)

        loss = yolo_loss(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | loss {total_loss:.4f}")
torch.save(model.state_dict(), "hybrid_two_way_best.pt")



Epoch 1 | loss 0.0000
Epoch 2 | loss 0.0000
Epoch 3 | loss 0.0000


In [None]:
model = HybridTwoWay(
    in_ch=3,
    stem_base=48,
    embed_dim=512,
    vit_depth=8,
    vit_heads=8,
    num_classes=3,
    iters=1,
    detach_feedback=True
).cuda()

model.load_state_dict(torch.load("hybrid_two_way_best.pt"))
model.eval()


## 7. Quick Sanity Check
Colab에서 바로 실행해 모델 입출력 형태를 확인할 수 있습니다.

In [None]:
model = HybridTwoWay(
    in_ch=3,
    stem_base=48,
    embed_dim=512,
    vit_depth=8,
    vit_heads=8,
    num_classes=1,
    iters=1,
    detach_feedback=True,
)

x = torch.randn(2, 3, 640, 640)
preds, aux = model(x)
for i, (c, o, b) in enumerate(preds, start=3):
    print(f"[TwoWay] P{i} cls:{list(c.shape)} obj:{list(o.shape)} box:{list(b.shape)}")
