In [17]:
"""
YOLOv9-like full model implementation (PyTorch) with 3-scale outputs (P3, P4, P5)
"""

import torch
import torch.nn as nn


def autopad(k, p=None):
    """Return the padding for a kernel of size ``k``.

    An explicit ``p`` is passed through untouched; when ``p`` is None the
    "same"-style value ``k // 2`` is used (for odd ``k`` at stride 1 this keeps
    the spatial size unchanged).
    """
    return k // 2 if p is None else p


class Conv(nn.Module):
    """Convolution -> BatchNorm -> activation, the basic unit of this model.

    Args:
        in_ch: input channel count.
        out_ch: output channel count.
        k: square kernel size. Padding is ``k // 2`` (the same rule as
            ``autopad``), so odd kernels give an output size of ``ceil(in / s)``.
        s: stride.
        g: number of convolution groups.
        act: SiLU activation when truthy, identity otherwise.
    """

    def __init__(self, in_ch, out_ch, k=1, s=1, g=1, act=True):
        super().__init__()
        # No conv bias: the BatchNorm that follows subtracts the mean and adds
        # its own learned shift, so a bias here would be redundant.
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=k,
            stride=s,
            padding=k // 2,
            groups=g,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        if act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)


class Bottleneck(nn.Module):
    """1x1 Conv followed by a 3x3 Conv, with an optional identity shortcut.

    Args:
        in_ch: input channel count.
        out_ch: output channel count.
        shortcut: request a residual add.
        g: groups for the 3x3 Conv.
        e: hidden-width ratio; the hidden width is ``int(out_ch * e)``.

    The residual add is only applied when ``shortcut`` is truthy and
    ``in_ch == out_ch``; otherwise the input and output shapes would not match.
    """

    def __init__(self, in_ch, out_ch, shortcut=True, g=1, e=0.5):
        super().__init__()
        mid = int(out_ch * e)
        self.conv1 = Conv(in_ch, mid, 1, 1)
        self.conv2 = Conv(mid, out_ch, 3, 1, g=g)
        self.use_add = shortcut and in_ch == out_ch

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.use_add:
            out = x + out
        return out


class CSPBlock(nn.Module):
    """Cross-stage-partial block.

    The input is projected by two parallel 1x1 Convs. One branch is refined by
    ``n`` Bottlenecks (full width, with residual adds); the other is kept as is.
    The two are concatenated along channels and fused by a final 1x1 Conv.

    Args:
        in_ch: input channel count.
        out_ch: output channel count.
        n: number of Bottlenecks in the refined branch.
        e: hidden-width ratio; each branch has ``int(out_ch * e)`` channels.
    """

    def __init__(self, in_ch, out_ch, n=1, e=0.5):
        super().__init__()
        mid = int(out_ch * e)
        self.conv1 = Conv(in_ch, mid, 1, 1)  # branch that goes through the Bottlenecks
        self.conv2 = Conv(in_ch, mid, 1, 1)  # bypass branch
        stack = []
        for _ in range(n):
            stack.append(Bottleneck(mid, mid, shortcut=True, e=1.0))
        self.blocks = nn.Sequential(*stack)
        self.conv3 = Conv(2 * mid, out_ch, 1, 1)

    def forward(self, x):
        refined = self.conv1(x)
        refined = self.blocks(refined)
        bypass = self.conv2(x)
        merged = torch.cat([refined, bypass], dim=1)
        return self.conv3(merged)


class SPPF(nn.Module):
    """Fast spatial pyramid pooling.

    A 1x1 Conv halves the channels (``in_ch // 2``). The reduced map and three
    successively max-pooled copies of it (stride 1, padding ``k // 2``) are
    concatenated, giving ``4 * (in_ch // 2)`` channels, and projected to
    ``out_ch`` by another 1x1 Conv. For odd ``k`` the spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch, k=5):
        super().__init__()
        half = in_ch // 2
        self.conv1 = Conv(in_ch, half, 1, 1)
        self.conv2 = Conv(4 * half, out_ch, 1, 1)
        self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        levels = [self.conv1(x)]
        # Each pass pools the previous result, widening the receptive field.
        for _ in range(3):
            levels.append(self.pool(levels[-1]))
        return self.conv2(torch.cat(levels, dim=1))


class Backbone(nn.Module):
    """Five stride-2 stages; returns the features after stages 3, 4 and 5.

    With ``base = int(64 * width_mult)`` and ``d = max(round(3 * depth_mult), 1)``:

    - ``c3``: stride 8, ``base * 8`` channels (``3 * d`` Bottlenecks in its CSP)
    - ``c4``: stride 16, ``base * 16`` channels
    - ``c5``: stride 32, ``base * 32`` channels (after SPPF)
    """

    def __init__(self, depth_mult=1.0, width_mult=1.0):
        super().__init__()
        ch = int(64 * width_mult)
        d = max(round(3 * depth_mult), 1)

        # stride 2
        self.conv1 = Conv(3, ch, 3, 2)
        self.csp1 = CSPBlock(ch, 2 * ch, n=d)
        # stride 4
        self.conv2 = Conv(2 * ch, 4 * ch, 3, 2)
        self.csp2 = CSPBlock(4 * ch, 4 * ch, n=2 * d)
        # stride 8 -> c3
        self.conv3 = Conv(4 * ch, 8 * ch, 3, 2)
        self.csp3 = CSPBlock(8 * ch, 8 * ch, n=3 * d)
        # stride 16 -> c4
        self.conv4 = Conv(8 * ch, 16 * ch, 3, 2)
        self.csp4 = CSPBlock(16 * ch, 16 * ch, n=d)
        # stride 32 -> c5
        self.conv5 = Conv(16 * ch, 32 * ch, 3, 2)
        self.sppf = SPPF(32 * ch, 32 * ch)

    def forward(self, x):
        for down, stage in ((self.conv1, self.csp1), (self.conv2, self.csp2)):
            x = stage(down(x))
        c3 = self.csp3(self.conv3(x))
        c4 = self.csp4(self.conv4(c3))
        c5 = self.sppf(self.conv5(c4))
        return c3, c4, c5


class PANNeck(nn.Module):
    """Path-aggregation neck fusing three feature levels into equal-width maps.

    Each input is first projected to ``out_channels`` by a 1x1 Conv. A top-down
    pass (upsample, concat, 3x3 Conv) is followed by a bottom-up pass (stride-2
    3x3 Conv, concat, 3x3 Conv).

    Args:
        in_channels: channel counts of ``(c3, c4, c5)``.
        out_channels: channel count of every returned map.

    Returns:
        ``(p3, p4, p5)``, each with ``out_channels`` channels and the spatial
        size of the corresponding input level.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.reduce5 = Conv(in_channels[2], out_channels, 1, 1)
        self.reduce4 = Conv(in_channels[1], out_channels, 1, 1)
        self.reduce3 = Conv(in_channels[0], out_channels, 1, 1)

        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.fuse4 = Conv(out_channels * 2, out_channels, 3, 1)
        self.fuse3 = Conv(out_channels * 2, out_channels, 3, 1)

        self.down4 = Conv(out_channels, out_channels, 3, 2)
        self.fuse4_out = Conv(out_channels * 2, out_channels, 3, 1)

        self.down5 = Conv(out_channels, out_channels, 3, 2)
        self.fuse5_out = Conv(out_channels * 2, out_channels, 3, 1)

    def _upsample_to(self, x, ref):
        """Nearest-neighbour upsample ``x`` to the spatial size of ``ref``.

        When ``ref`` is exactly twice the size of ``x``, the fixed 2x upsampler
        is used, which is the original behaviour. Otherwise the map is resized
        to the exact target size. This happens when a stride-2 stage rounded an
        odd size up (e.g. 75 -> 38, and 2 * 38 != 75), and without it the
        channel concat would fail.
        """
        if x.shape[-2] * 2 == ref.shape[-2] and x.shape[-1] * 2 == ref.shape[-1]:
            return self.up(x)
        return nn.functional.interpolate(x, size=ref.shape[-2:], mode='nearest')

    def forward(self, c3, c4, c5):
        p5 = self.reduce5(c5)
        p4 = self.reduce4(c4)
        p3 = self.reduce3(c3)

        # Top-down: propagate coarse context to the finer levels.
        p4 = self.fuse4(torch.cat((p4, self._upsample_to(p5, p4)), dim=1))
        p3 = self.fuse3(torch.cat((p3, self._upsample_to(p4, p3)), dim=1))

        # Bottom-up: push the refined fine detail back to the coarser levels.
        p4 = self.fuse4_out(torch.cat((self.down4(p3), p4), dim=1))
        p5 = self.fuse5_out(torch.cat((self.down5(p4), p5), dim=1))

        return p3, p4, p5


class DetectHead(nn.Module):
    """Decoupled per-level prediction head.

    For every input level, a classification branch and a regression branch
    (two 3x3 Convs each) feed 1x1 predictors. Each level's output concatenates,
    along channels, box regression (4), objectness (1, computed from the
    regression branch) and class scores (``num_classes``).

    Args:
        in_channels: channel count of each input level.
        num_classes: number of class-score channels.
        width: channel width of both branches.
    """

    def __init__(self, in_channels, num_classes, width=256):
        super().__init__()
        self.num_classes = num_classes

        def branch(c):
            return nn.Sequential(Conv(c, width, 3, 1), Conv(width, width, 3, 1))

        self.cls_convs = nn.ModuleList([branch(c) for c in in_channels])
        self.reg_convs = nn.ModuleList([branch(c) for c in in_channels])

        self.cls_preds = nn.ModuleList([nn.Conv2d(width, num_classes, 1) for _ in in_channels])
        self.obj_preds = nn.ModuleList([nn.Conv2d(width, 1, 1) for _ in in_channels])
        self.reg_preds = nn.ModuleList([nn.Conv2d(width, 4, 1) for _ in in_channels])

    def _predict_level(self, level, feat):
        """Run the head modules of index ``level`` on one feature map."""
        cls_feat = self.cls_convs[level](feat)
        reg_feat = self.reg_convs[level](feat)
        cls_map = self.cls_preds[level](cls_feat)
        obj_map = self.obj_preds[level](reg_feat)
        box_map = self.reg_preds[level](reg_feat)
        return torch.cat((box_map, obj_map, cls_map), dim=1)

    def forward(self, features):
        return [self._predict_level(level, feat) for level, feat in enumerate(features)]


class YOLOv9(nn.Module):
    """Backbone + PAN neck + decoupled head with three output scales.

    Args:
        num_classes: number of class-score channels per output.
        width_mult: channel multiplier shared by backbone, neck and head.
        depth_mult: Bottleneck-repeat multiplier forwarded to the backbone.

    Returns (forward):
        A list of three tensors for P3, P4 and P5. Each has
        ``4 + 1 + num_classes`` channels (box, objectness, classes).
    """

    def __init__(self, num_classes=80, width_mult=1.0, depth_mult=1.0):
        super().__init__()
        self.backbone = Backbone(depth_mult=depth_mult, width_mult=width_mult)
        # Derive the backbone output widths exactly as Backbone does: truncate
        # the base width first, then scale. Computing int(64 * w * 8) instead
        # disagrees for many multipliers (e.g. w=0.3 gives 153 vs 152) and
        # crashes the neck with a channel mismatch.
        base = int(64 * width_mult)
        c3_ch, c4_ch, c5_ch = base * 8, base * 16, base * 32
        neck_ch = int(256 * width_mult)
        self.neck = PANNeck([c3_ch, c4_ch, c5_ch], out_channels=neck_ch)
        self.detect = DetectHead([neck_ch, neck_ch, neck_ch], num_classes, width=neck_ch)

    def forward(self, x):
        c3, c4, c5 = self.backbone(x)
        p3, p4, p5 = self.neck(c3, c4, c5)
        return self.detect([p3, p4, p5])





if __name__ == '__main__':
    # Smoke test: one forward pass on a random 640x640 RGB batch, then report
    # each output's shape and the total parameter count.
    model = YOLOv9(num_classes=80)
    dummy = torch.randn(1, 3, 640, 640)
    outs = model(dummy)
    for o in outs:
        print(o.shape)
    n_params = sum(t.numel() for t in model.parameters())
    print(f'Params: {n_params}')

In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

# YOLOv9 model definition (with P3, P4, P5)
# ... [Model code from before remains here] ...

class YOLODataset(Dataset):
    """Image/label dataset laid out in the ``images/<split>`` + ``labels/<split>`` convention.

    Images are read from ``<images_dir>/images/<mode>/*.*``. Labels come from
    ``<images_dir>/labels/<mode>/<stem>.txt``, one ``cls cx cy w h`` row per
    box, parsed as floats.

    Each item is ``(img, boxes)``:

    - ``img``: float32 tensor ``(3, img_size, img_size)`` with values in [0, 1].
      The image is stretched (not letterboxed) to the square size. Assuming
      the labels use normalized coordinates, as the YOLO format does, they stay
      valid under this resize (TODO confirm for this dataset).
    - ``boxes``: float32 tensor ``(N, 5)``; ``N`` may be 0.

    Because ``N`` varies per image, batching with the default collate function
    fails when ``batch_size > 1``; use a collate function that concatenates
    the boxes instead of stacking them.

    Raises:
        ValueError: if ``mode`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self, images_dir, mode='train', img_size=640):
        if mode not in ('train', 'val'):
            raise ValueError("mode must be 'train' or 'val'")
        self.mode = mode
        # Both splits share the same layout; only the sub-directory differs.
        self.img_files = sorted(glob.glob(os.path.join(images_dir, 'images', mode, '*.*')))
        self.label_dir = os.path.join(images_dir, 'labels', mode)
        self.img_size = img_size

    def __len__(self):
        return len(self.img_files)

    @staticmethod
    def _read_labels(label_path):
        """Parse a label file into a float32 ``(N, 5)`` tensor.

        A missing file yields an empty ``(0, 5)`` tensor, and blank lines
        (e.g. a trailing empty line) are skipped.

        Raises:
            ValueError: if a non-blank line does not have exactly 5 fields.
        """
        rows = []
        if os.path.exists(label_path):
            with open(label_path, 'r', encoding='utf-8') as f:
                for lineno, line in enumerate(f, start=1):
                    fields = line.split()
                    if not fields:
                        continue
                    if len(fields) != 5:
                        raise ValueError(
                            f"{label_path}:{lineno}: expected 5 values "
                            f"'cls cx cy w h', got {len(fields)}"
                        )
                    rows.append([float(v) for v in fields])
        if not rows:
            return torch.zeros((0, 5), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        stem = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(self.label_dir, stem + '.txt')

        # The context manager releases the file handle immediately instead of
        # leaving it to garbage collection.
        with Image.open(img_path) as src:
            rgb = src.convert('RGB').resize((self.img_size, self.img_size))
        img = np.array(rgb).astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW

        return img, self._read_labels(label_path)

def train_model(model, dataloader, epochs=10, lr=1e-3, device='cuda'):
    """Run a minimal Adam training loop and print the mean loss of each epoch.

    WARNING: the objective is a placeholder. It is the MSE between every tensor
    the model returns and zeros. ``labels`` are moved to ``device`` but never
    used, so the loop only pushes the outputs towards zero and does not learn
    detection. Replace it with a real detection loss (box, objectness and class
    terms with target assignment) before relying on the result.

    Args:
        model: ``nn.Module`` whose forward returns an iterable of tensors.
        dataloader: iterable of ``(imgs, labels)`` batches that supports
            ``len()``; both elements must provide ``.to(device)``.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device the model and the batches are moved to.

    Returns:
        list[float]: mean loss per epoch (0.0 for an empty dataloader).
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    num_batches = len(dataloader)
    history = []

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)  # not consumed by the placeholder loss
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = sum(criterion(out, torch.zeros_like(out)) for out in outputs)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Avoid ZeroDivisionError when the dataloader yields no batches.
        mean_loss = total_loss / num_batches if num_batches else 0.0
        history.append(mean_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")
    return history

def yolo_collate_fn(batch):
    """Collate ``(img, boxes)`` samples whose box counts differ.

    The default collate function stacks every field, which fails as soon as two
    samples in a batch have a different number of boxes.

    Returns:
        ``(imgs, targets)``, where ``imgs`` is the stacked image tensor
        ``(B, C, H, W)`` and ``targets`` is a single ``(sum N_i, 6)`` tensor.
        Column 0 of ``targets`` is the index of the sample inside the batch,
        so each box can still be matched to its image.
    """
    imgs, targets = zip(*batch)
    indexed = [
        torch.cat((torch.full((t.shape[0], 1), float(i), dtype=t.dtype), t), dim=1)
        for i, t in enumerate(targets)
    ]
    return torch.stack(imgs, dim=0), torch.cat(indexed, dim=0)


if __name__ == "__main__":
    DATA_ROOT = r"D:\pcb_defect\ProdVision_django\PCB_DATASET_yolo_version"

    train_dataset = YOLODataset(DATA_ROOT, mode='train')
    train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=yolo_collate_fn)
    val_dataset = YOLODataset(DATA_ROOT, mode='val')
    # Validation reads the val split (it previously reused train_dataset) and
    # needs no shuffling.
    val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=yolo_collate_fn)

    model = YOLOv9(num_classes=6)
    # NOTE(review): training is not started here; call
    # train_model(model, train_dataloader, ...) to run it.


YOLOv9(
  (backbone): Backbone(
    (conv1): Conv(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (csp1): CSPBlock(
      (conv1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (conv2): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (blocks): Sequential(
        (0): Bottleneck(
          (conv1): Conv(
            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a