# HybridTwoWay Model (Colab Ready)


## Imports
필요한 PyTorch 모듈과 타입 힌트를 불러옵니다.

In [1]:
!pip install roboflow torch torchvision torchaudio opencv-python numpy tqdm pillow matplotlib


Collecting roboflow
  Downloading roboflow-1.2.11-py3-none-any.whl.metadata (9.7 kB)
Collecting opencv-python
  Downloading opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (19 kB)
Collecting idna==3.7 (from roboflow)
  Downloading idna-3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting opencv-python-headless==4.10.0.84 (from roboflow)
  Downloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting pi-heif<2 (from roboflow)
  Downloading pi_heif-1.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.5 kB)
Collecting pillow-avif-plugin<2 (from roboflow)
  Downloading pillow_avif_plugin-1.5.2-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Collecting python-dotenv (from roboflow)
  Downloading python_dotenv-1.2.1-py3-none-any.whl.metadata (25 kB)
Collecting requests-toolbelt (from roboflow)
  Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl.me

In [2]:
import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


In [5]:
# --- 옵션 3: Roboflow 데이터셋 사용 --- #
# !pip install roboflow

from roboflow import Roboflow
rf = Roboflow(api_key="HG9M6YJZpcCUgAQaKO9v")
project = rf.workspace("arakon").project("detection-base-hqaeg")
version = project.version(6)
dataset = version.download("yolov8")

print(f'Roboflow dataset downloaded to: {dataset.location}')


loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in detection-base-6 to yolov8:: 100%|██████████| 111006/111006 [00:01<00:00, 79402.89it/s] 





Extracting Dataset Version Zip to detection-base-6 in yolov8:: 100%|██████████| 3306/3306 [00:00<00:00, 8306.31it/s]


Roboflow dataset downloaded to: /content/detection-base-6


## 0. Utility Functions

In [3]:
def conv_bn_act(in_ch, out_ch, k=3, s=1, p=1, act=True):
    m = [nn.Conv2d(in_ch, out_ch, k, s, p, bias=False),
         nn.BatchNorm2d(out_ch)]
    if act:
        m.append(nn.SiLU(inplace=True))
    return nn.Sequential(*m)


## 1. Anomaly-Aware CNN Stem

In [4]:
class FixedGaussianBlur(nn.Module):
    def __init__(self, channels, k=5, sigma=1.0):
        super().__init__()
        grid = torch.arange(k).float() - (k - 1) / 2
        gauss = torch.exp(-(grid ** 2) / (2 * sigma ** 2))
        kernel1d = gauss / gauss.sum()
        kernel2d = torch.outer(kernel1d, kernel1d)
        weight = kernel2d[None, None, :, :].repeat(channels, 1, 1, 1)
        self.register_buffer('weight', weight)
        self.groups = channels
        self.k = k

    def forward(self, x):
        pad = (self.k // 2,) * 4
        return F.conv2d(F.pad(x, pad, mode='reflect'), self.weight, groups=self.groups)


class AnomalyAwareStem(nn.Module):
    def __init__(self, in_ch=3, base_ch=48):
        super().__init__()
        C1, C2, C3 = base_ch, base_ch * 2, base_ch * 4
        self.stem = nn.Sequential(
            conv_bn_act(in_ch, C1, 3, 2, 1),
            conv_bn_act(C1, C2, 3, 2, 1),
            conv_bn_act(C2, C3, 3, 2, 1),
        )
        self.blur = FixedGaussianBlur(in_ch, k=5, sigma=1.0)
        self.anom = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, 1, 1, groups=in_ch, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, C3 // 4, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C3 // 4),
            nn.SiLU(inplace=True),
        )
        self.fuse = nn.Conv2d(C3 + C3 // 4, C3, 1, 1, 0, bias=False)
        self.fuse_bn = nn.BatchNorm2d(C3)
        self.vis_head = nn.Conv2d(C3, 1, 1, 1, 0)

    @property
    def out_channels(self):
        return 4 * 48

    def forward(self, x):
        f_main = self.stem(x)
        blurred = self.blur(x)
        high = x - blurred
        high_ds = F.interpolate(high, size=f_main.shape[-2:], mode='bilinear', align_corners=False)
        f_anom = self.anom(high_ds)
        f = torch.cat([f_main, f_anom], dim=1)
        f = self.fuse_bn(self.fuse(f))
        f = F.silu(f, inplace=True)
        v = torch.sigmoid(self.vis_head(f_main))
        return f, v


## 2. Vision Transformer Encoder

In [6]:
class PatchEmbed1x1(nn.Module):
    """Map CNN features to ViT embeddings while keeping spatial resolution."""
    def __init__(self, in_ch, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, 1, 1, 0, bias=False)
        self.bn = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        x = self.bn(x)
        x = F.silu(x, inplace=True)
        return x


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class MultiheadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiheadSelfAttention(dim, num_heads, drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class ViTEncoder(nn.Module):
    def __init__(self, embed_dim=512, depth=8, num_heads=8):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio=4.0, drop=0.0)
            for _ in range(depth)
        ])

    def forward(self, tokens):
        for blk in self.blocks:
            tokens = blk(tokens)
        return tokens


## 3. Feedback Adapter

In [7]:
class FeedbackAdapter(nn.Module):
    def __init__(self, d_token: int, c_stem: int, use_bn: bool = True):
        super().__init__()
        layers = [nn.Conv2d(d_token, c_stem * 2, 1, 1, 0, bias=not use_bn)]
        if use_bn:
            layers.append(nn.BatchNorm2d(c_stem * 2))
        layers.append(nn.SiLU(inplace=True))
        self.adapter = nn.Sequential(*layers)

    def forward(self, tokens: torch.Tensor, Ht: int, Wt: int, f_stem: torch.Tensor):
        B, N, D = tokens.shape
        t2d = tokens.transpose(1, 2).reshape(B, D, Ht, Wt)
        ab = self.adapter(t2d)
        Cs = f_stem.shape[1]
        gamma, beta = torch.split(ab, Cs, dim=1)
        return f_stem * (1 + torch.tanh(gamma)) + beta


## 4. PAN-Lite Neck

In [17]:
class PANLite(nn.Module):
    def __init__(self, in_ch=512, mid=256):
        super().__init__()
        # p3 → mid
        self.lateral = conv_bn_act(in_ch, mid, 1, 1, 0)

        # mid(80×80) → mid(40×40)
        self.down4 = conv_bn_act(mid, mid, 3, 2, 1)

        # optional: 강화 conv
        self.refine4 = conv_bn_act(mid, mid, 3, 1, 1)

    def forward(self, p3):
        """
        input:  p3  = [B, in_ch, H, W] (ex: 80x80)
        output: p4  = [B, mid, H/2, W/2] (ex: 40x40)
        """
        # lateral transform
        p3 = self.lateral(p3)

        # downsample once → P4
        p4 = self.down4(p3)

        # optional refinement
        p4 = self.refine4(p4)

        return p4


## 5. YOLO-style Detection Head

In [18]:
class YOLOHeadLite_P4Only(nn.Module):
    def __init__(self, in_ch=256, num_classes=1):
        super().__init__()
        c = in_ch

        # P4 전용 stem
        self.stem4 = conv_bn_act(c, c, 3, 1, 1)

        # P4 detection layers
        self.cls4 = nn.Conv2d(c, num_classes, 1, 1, 0)
        self.obj4 = nn.Conv2d(c, 1, 1, 1, 0)
        self.box4 = nn.Conv2d(c, 4, 1, 1, 0)

    def forward_single(self, x):
        f = self.stem4(x)
        return self.cls4(f), self.obj4(f), self.box4(f)

    def forward(self, p4):
        c4, o4, b4 = self.forward_single(p4)
        return [(c4, o4, b4)]     # 리스트 형태 유지 (학습 코드 호환)


## 6. HybridTwoWay Model

In [10]:
class HybridTwoWay(nn.Module):
    def __init__(
        self,
        in_ch=3,
        stem_base=32,
        embed_dim=256,
        vit_depth=4,
        vit_heads=4,
        num_classes=3,
        iters=1,
        detach_feedback=True,
    ):
        super().__init__()
        assert iters >= 1
        self.iters = iters
        self.detach_feedback = detach_feedback
        self.stem = AnomalyAwareStem(in_ch=in_ch, base_ch=stem_base)
        c_stem = stem_base * 4
        self.patch = PatchEmbed1x1(c_stem, embed_dim)
        self.vit = ViTEncoder(embed_dim=embed_dim, depth=vit_depth, num_heads=vit_heads)
        self.feedback = FeedbackAdapter(embed_dim, c_stem, use_bn=True)
        self.neck = PANLite(in_ch=embed_dim, mid=256)
        self.head = YOLOHeadLite(in_ch=256, num_classes=num_classes)

    def forward_once(self, x):
        f_stem, vis = self.stem(x)
        p = self.patch(f_stem)
        Ht, Wt = p.shape[-2:]
        tokens = p.flatten(2).transpose(1, 2)
        tokens = self.vit(tokens)
        toks_for_fb = tokens.detach() if self.detach_feedback else tokens
        f_fb = self.feedback(toks_for_fb, Ht, Wt, f_stem)
        p3_in = self.patch(f_fb)
        p3 = p3_in
        p4 = self.neck(p3)
        all_preds = self.head([p4])
        preds = [all_preds[0]]
        aux = {"P4": p4, "V": vis}
        return preds, aux, f_fb

    def forward(self, x):
        preds, aux, f_fb = self.forward_once(x)
        for _ in range(self.iters - 1):
            preds, aux, f_fb = self.forward_once(x)
        return preds, aux


In [11]:
from torch.utils.data import Dataset, DataLoader
import cv2

IMG_SIZE = 512

def yolo_collate_fn(batch):
    imgs = []
    targets = []

    for img, tgt in batch:
        imgs.append(img)
        targets.append(tgt)   # tensor of shape [num_boxes, 5]

    # stack images → OK
    imgs = torch.stack(imgs, 0)

    # targets는 stack 안 함 (num_boxes 다르기 때문)
    return imgs, targets

class YoloDataset(Dataset):
    def __init__(self, root):
        self.img_dir = os.path.join(root, "images")
        self.label_dir = os.path.join(root, "labels")
        self.images = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        name = self.images[idx]

        img_path = os.path.join(self.img_dir, name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = torch.tensor(img).permute(2,0,1).float() / 255.0

        label_path = os.path.join(self.label_dir, name.replace(".jpg",".txt").replace(".png",".txt"))
        boxes = []
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f.readlines():
                    cls, x, y, w, h = map(float, line.split())
                    boxes.append([cls, x, y, w, h])

        boxes = torch.tensor(boxes, dtype=torch.float32)
        return img, boxes



In [12]:
import os
DATA_PATH = dataset.location

train_dataset = YoloDataset(os.path.join(DATA_PATH, "train"))
val_dataset   = YoloDataset(os.path.join(DATA_PATH, "valid"))

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=yolo_collate_fn
)


In [13]:
import torch
import torch.nn.functional as F

def yolo_loss(preds, targets, img_size=512):
    object_loss = 0
    class_loss = 0
    box_loss = 0

    for scale_id, (cls_pred, obj_pred, box_pred) in enumerate(preds):
        B, num_cls, H, W = cls_pred.shape

        # ----- flatten -----
        cls_pred = cls_pred.permute(0,2,3,1).reshape(B, -1, num_cls)     # [B, H*W, C]
        obj_pred = obj_pred.permute(0,2,3,1).reshape(B, -1, 1)           # [B, H*W, 1]
        box_pred = box_pred.permute(0,2,3,1).reshape(B, -1, 4)           # [B, H*W, 4]

        obj_pred = obj_pred.sigmoid()
        cls_pred = cls_pred.sigmoid()

        # ----- iterate batch -----
        for b in range(B):
            gt = targets[b]

            if len(gt) == 0:
                # object 없음 → obj는 0이어야 함
                object_loss += (obj_pred[b] ** 2).mean()
                continue

            # object 있음 → obj는 1이어야 함
            object_loss += ((1 - obj_pred[b]) ** 2).mean()

            # ----- class loss -----
            tcls = gt[:, 0].long()        # class index
            class_loss += F.cross_entropy(
                cls_pred[b],
                tcls.expand(cls_pred[b].shape[0]),
                reduction='mean'
            )

            # ----- box loss (L1) -----
            # GT scaled to pixels
            tbox = gt[:, 1:].clone()
            tbox = tbox * img_size        # [xywh normalized → pixel]

            pred_box = box_pred[b][:len(gt)] * img_size

            box_loss += F.l1_loss(pred_box, tbox, reduction='mean')

    total = object_loss + class_loss + box_loss
    return total


In [15]:
model = HybridTwoWay(num_classes=3).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def simple_yolo_loss(preds, targets):
    return preds[0][0].mean() * 0  # 일단 파이프라인 연결용

EPOCHS = 5

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for imgs, targets in train_loader:
        imgs = imgs.cuda()

        preds, aux = model(imgs)

        loss = simple_yolo_loss(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | loss {total_loss:.4f}")
torch.save(model.state_dict(), "hybrid_two_way_best.pt")



AssertionError: Torch not compiled with CUDA enabled

In [None]:
model = HybridTwoWay(
    in_ch=3,
    stem_base=48,
    embed_dim=512,
    vit_depth=8,
    vit_heads=8,
    num_classes=3,
    iters=1,
    detach_feedback=True
).cuda()

model.load_state_dict(torch.load("hybrid_two_way_best.pt"))
model.eval()


## 7. Quick Sanity Check
Colab에서 바로 실행해 모델 입출력 형태를 확인할 수 있습니다.

In [16]:
model = HybridTwoWay(
    in_ch=3,
    stem_base=48,
    embed_dim=512,
    vit_depth=8,
    vit_heads=8,
    num_classes=1,
    iters=1,
    detach_feedback=True,
)

x = torch.randn(2, 3, 640, 640)
preds, aux = model(x)
for i, (c, o, b) in enumerate(preds, start=3):
    print(f"[TwoWay] P{i} cls:{list(c.shape)} obj:{list(o.shape)} box:{list(b.shape)}")


[TwoWay] P3 cls:[2, 1, 80, 80] obj:[2, 1, 80, 80] box:[2, 4, 80, 80]
[TwoWay] P4 cls:[2, 1, 40, 40] obj:[2, 1, 40, 40] box:[2, 4, 40, 40]
[TwoWay] P5 cls:[2, 1, 20, 20] obj:[2, 1, 20, 20] box:[2, 4, 20, 20]
