In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules.block import C3k2, SPPF, C2PSA
from ultralytics.nn.modules.conv import Conv
from ultralytics.nn.modules.head import Detect

# Dummy input: one random single-channel 640x640 image for smoke-testing a model.
x = torch.randn((1, 1, 640, 640))

In [None]:
class CustomYOLO(nn.Module):
    """Single-scale YOLO-style detector for 1-channel (grayscale) input.

    The model is a straight pipeline: a convolutional backbone built from
    ``Conv`` and ``C3k2`` blocks, a neck made of ``SPPF`` followed by
    ``C2PSA``, and a ``Detect`` head fed with the single 1024-channel
    feature map the neck produces.

    Args:
        nc: Number of object classes the detection head predicts.
    """

    def __init__(self, nc):
        super().__init__()

        # Five stride-2 Conv layers give a total downsampling factor of
        # 2**5 = 32 (assuming Conv's 4th positional argument is the stride,
        # as in ultralytics' Conv(c1, c2, k, s, p)).
        self.backbone = nn.Sequential(
            Conv(1, 32, 3, 1, 1),
            Conv(32, 64, 3, 2, 1),
            C3k2(64, 64, 1),
            Conv(64, 128, 3, 2, 1),
            C3k2(128, 128, 2),
            Conv(128, 256, 3, 2, 1),
            C3k2(256, 256, 8),
            Conv(256, 512, 3, 2, 1),
            C3k2(512, 512, 8),
            Conv(512, 1024, 3, 2, 1),
            C3k2(1024, 1024, 4),
        )

        self.neck = nn.Sequential(
            # SPPF takes ONE integer kernel size; it chains the same max-pool
            # three times, which already covers the 5/9/13 receptive fields
            # that the older SPP module expressed as a list.
            SPPF(1024, 1024, 5),
            C2PSA(1024, 1024, 4),
        )

        # Detect's signature is (nc, ..., ch=<tuple of input channels>), one
        # entry per detection level. Keywords keep the call valid across
        # ultralytics versions that inserted extra parameters before `ch`.
        self.head = Detect(nc=nc, ch=(1024,))
        # Detect initialises `stride` to zeros and relies on the enclosing
        # model to fill it in; the backbone above downsamples by 32.
        self.head.stride = torch.tensor([32.0])

    def forward(self, x):
        """Run backbone, neck and detection head on a batch of images.

        Args:
            x: Input tensor with 1 channel (the first backbone Conv expects
                1 input channel); presumably shaped (batch, 1, H, W) with
                H and W divisible by 32 -- confirm against callers.

        Returns:
            Whatever ``Detect`` returns for a single feature level (its
            output differs between training and eval mode; check the
            installed ultralytics version).
        """
        x = self.backbone(x)
        x = self.neck(x)
        # Detect expects a list of feature maps, one per detection level,
        # and rewrites that list in place, so hand it a fresh one.
        return self.head([x])
        
        