In [27]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# Build_block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, is_skip, type_block="BasicBlock"):
        super().__init__()
        self.is_skip = is_skip
        self.type_block = type_block
        # First connections block stride = 2,
        self.stride = 1 + is_skip
        
        if type_block == "BasicBlock":
            self.block = self._build_basic_block(in_channels, out_channels)
        else:
            self.block = self._build_bottleneck_block(in_channels, out_channels)
        
        if self.is_skip:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, padding=0, bias=False)
            self.shortcut_bn = nn.BatchNorm2d(out_channels)
        else:
            self.shortcut = nn.Identity()
    #ResNet18,34
    def _build_basic_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
    # ResNet 50 >>
    def _build_bottleneck_block(self, in_channels, out_channels):
        inter_channels = out_channels // 4
        return nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, kernel_size=1, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, inter_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.block(x)
        shortcut = self.shortcut(x)
        # Skip block y = f(x) + x
        if self.is_skip:
            shortcut = self.shortcut_bn(shortcut)
        out += shortcut
        out = nn.ReLU(inplace=True)(out)  # Adding ReLU at end of block
        return out


In [33]:
class ResNet(nn.Module):
    def __init__(self, block_out, block_depth, type_block="BasicBlock"):
        super().__init__()
        self.in_channels = block_out[0]
        # First layers
        layers = [
            nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.in_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        ]
        # Blocks
        for out_channels, depth in zip(block_out[1:], block_depth):
            for i in range(depth):
                is_skip = (i == 0)
                block = ResidualBlock(self.in_channels, out_channels, is_skip, type_block)
                layers.append(block)
                self.in_channels = out_channels
        # Fully connected
        layers += [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(block_out[-1], 1000)
        ]
        
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)


In [34]:
class ResNet18(ResNet):
    def __init__(self):
        super().__init__(
            block_out=[64, 64, 128, 256, 512],
            block_depth=[2, 2, 2, 2],
            type_block="BasicBlock"
        )

class ResNet34(ResNet):
    def __init__(self):
        super().__init__(
            block_out=[64, 64, 128, 256, 512],
            block_depth=[3, 4, 6, 3],
            type_block="BasicBlock"
        )

class ResNet50(ResNet):
    def __init__(self):
        super().__init__(
            block_out=[64, 256, 512, 1024, 2048],
            block_depth=[3, 4, 6, 3],
            type_block="BottleneckBlock"
        )


model = ResNet18()
print(model)

ResNet18(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (shortcut_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): ResidualBlock(
      (blo

In [35]:
def test_resnet18():
    # Khởi tạo mô hình ResNet18
    model = ResNet18()
    
    # Tạo dữ liệu giả làm với kích thước (batch_size=4, 3 kênh màu, chiều cao và chiều rộng 224x224)
    x = torch.randn(4, 3, 224, 224)
    
    # Thực hiện phép tính forward pass
    out = model(x)
    
    # In ra shape của đầu ra và cấu trúc mô hình
    print("Output shape:", out.shape)
    print("Model architecture:\n", model)

# Chạy hàm test
test_resnet18()

Output shape: torch.Size([4, 1000])
Model architecture:
 ResNet18(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): ResidualBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (shortcut_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run