In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import CocoDetection
from torchvision import transforms
from torchvision.ops import DeformConv2d
import os

# -----------------------------
# Model Definition (with DeformConv2d fix and pooled transformer neck)
# -----------------------------
class CNNBackbone(nn.Module):
    """Three-stage convolutional feature extractor with a deformable middle layer.

    Channel progression is 3 -> 64 -> 128 -> 256. Every layer uses a 3x3
    kernel with padding 1 and the default stride, and each stage is followed
    by an in-place ReLU. The middle stage is a ``DeformConv2d`` whose sampling
    offsets are predicted from the first stage's activations by ``offset2``.
    """

    def __init__(self):
        super().__init__()
        kernel = 3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=kernel, padding=1)
        # DeformConv2d needs an (x, y) offset per kernel tap:
        # 2 * 3 * 3 = 18 offset channels for a 3x3 kernel.
        self.offset2 = nn.Conv2d(64, 2 * kernel * kernel, kernel_size=kernel, padding=1)
        self.conv2 = DeformConv2d(64, 128, kernel_size=kernel, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=kernel, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the 256-channel feature map computed from image batch ``x``."""
        stem = self.relu(self.conv1(x))
        # Offsets are predicted from the same activations that the deformable
        # convolution then samples from.
        offsets = self.offset2(stem)
        deformed = self.relu(self.conv2(stem, offsets))
        return self.relu(self.conv3(deformed))

class TransformerNeck(nn.Module):
    """Self-attention neck applied to a spatially pooled feature map.

    The input is first adaptively average-pooled to ``out_size x out_size`` so
    that the attention sequence length is fixed at ``out_size ** 2`` regardless
    of the input resolution. Each pooled location is treated as one token of
    width ``dim`` and the tokens attend to each other via multi-head
    self-attention.

    Args:
        dim: Channel count of the incoming feature map (attention embed dim).
        num_heads: Number of attention heads.
        out_size: Side length of the pooled grid.
    """

    def __init__(self, dim, num_heads=4, out_size=20):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((out_size, out_size))
        self.attn = nn.MultiheadAttention(dim, num_heads)

    def forward(self, x):
        """Pool ``x`` of shape (b, c, H, W) and return attended (b, c, h, w) features."""
        batch, channels, _, _ = x.shape
        pooled = self.pool(x)
        height, width = pooled.shape[-2:]
        # MultiheadAttention (batch_first=False) expects (seq, batch, embed),
        # so flatten the grid and move it to the front: (h*w, b, c).
        tokens = pooled.flatten(2).permute(2, 0, 1)
        attended, _ = self.attn(tokens, tokens, tokens)
        # Undo the layout change: (h*w, b, c) -> (b, c, h, w).
        return attended.permute(1, 2, 0).reshape(batch, channels, height, width)

class YOLOHead(nn.Module):
    """1x1 convolutional prediction head.

    For every spatial location the head emits ``num_anchors`` groups of
    ``num_classes + 5`` channels, laid out along the channel dimension.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of object classes.
        num_anchors: Number of anchors predicted per location. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels, num_classes, num_anchors=3):
        super().__init__()
        # Kept as plain attributes so downstream code can reshape the raw
        # output without re-deriving the channel layout.
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.conv = nn.Conv2d(in_channels, (num_classes + 5) * num_anchors, 1)

    def forward(self, x):
        """Return raw predictions of shape (b, (num_classes + 5) * num_anchors, h, w)."""
        return self.conv(x)

class HybridYOLOv11(nn.Module):
    """Detector assembled from a CNN backbone, an attention neck and a YOLO head.

    Args:
        num_classes: Number of object classes, forwarded to ``YOLOHead``.
    """

    def __init__(self, num_classes):
        super().__init__()
        feature_dim = 256  # channel width produced by the backbone
        self.backbone = CNNBackbone()
        # out_size sets the pooled grid the neck attends over; it can be adjusted.
        self.neck = TransformerNeck(dim=feature_dim, out_size=20)
        self.head = YOLOHead(in_channels=feature_dim, num_classes=num_classes)

    def forward(self, x):
        """Run backbone -> neck -> head and return the head's raw predictions."""
        return self.head(self.neck(self.backbone(x)))

# -----------------------------
# Top-level collate function
# -----------------------------
def coco_collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    A batch such as ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``. Nothing is stacked or padded, so
    per-sample targets of differing length are passed through untouched.
    An empty batch yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# -----------------------------
# Dataset Preparation
# -----------------------------
# Local COCO 2017 layout: <root>/train2017, <root>/val2017, <root>/annotations.
root = r"F:\Object Detection  Dataset\coco2017"
ann_dir = os.path.join(root, "annotations")

train_img_dir = os.path.join(root, "train2017")
train_ann_file = os.path.join(ann_dir, "instances_train2017.json")

val_img_dir = os.path.join(root, "val2017")
val_ann_file = os.path.join(ann_dir, "instances_val2017.json")

# Every image is resized to a fixed 640x640 so a batch can later be stacked
# into a single tensor.
# NOTE(review): `transform` is applied to the image only; the annotation
# boxes returned by CocoDetection are presumably still in original-image
# pixel coordinates -- confirm before using them in a real loss.
transform = transforms.Compose([transforms.Resize((640, 640)), transforms.ToTensor()])

train_dataset = CocoDetection(train_img_dir, train_ann_file, transform=transform)
val_dataset = CocoDetection(val_img_dir, val_ann_file, transform=transform)

# Targets are variable-length per image, so the default collate cannot be
# used; coco_collate_fn keeps images and targets as per-field tuples.
train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=coco_collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=coco_collate_fn
)

# -----------------------------
# Training Loop (Simplified)
# -----------------------------
num_epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridYOLOv11(num_classes=80)
model = model.to(device)

# Placeholder! Real YOLO uses a custom loss. MSE against an all-zero tensor
# only exercises the pipeline; the COCO targets are never consulted.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, targets in train_loader:
        # The collate function yields a tuple of equally sized image tensors;
        # stack them into one batch tensor on the training device.
        images = torch.stack(images, dim=0).to(device)

        outputs = model(images)
        loss = criterion(outputs, torch.zeros_like(outputs))

        # Clear stale gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}")

print("Training complete.")


loading annotations into memory...
Done (t=9.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.33s)
creating index...
index created!
