In [23]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, "../..")
import torch
import torch.nn as nn
from pathlib import Path

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
from src.data import data_tools
# Location of the raw flower images (relative path).
image_folder = Path("../../data/raw/datasets/flower_photos")

# Batch generator over the image folder; every option is passed by keyword.
datagen = data_tools.data_generator(
    batch_size=32, path=image_folder, image_size=(150, 150), channels=3, shuffle=True
)

In [25]:
# Pull one batch from the generator. The displayed shapes show X is
# channels-last, (32, 150, 150, 3), and y has one entry per image, (32,).
X, y = next(datagen)
X.shape, y.shape

((32, 150, 150, 3), (32,))

In [34]:
# Convert the numpy batch to a float tensor and move the channel axis from the
# last position to position 1: (32, 150, 150, 3) -> (32, 3, 150, 150), which is
# the layout nn.Conv2d expects.
# NOTE(review): `input` shadows the Python builtin of the same name; later
# cells read this variable, so a rename has to be done across the notebook.
input = torch.from_numpy(X).float().permute(0, 3, 1, 2)
input.shape

torch.Size([32, 3, 150, 150])

In [36]:
# Sanity check: after .float() the per-image tensors are float32.
input[0].dtype

torch.float32

In [131]:
# Network stem, tried out step by step on the sample batch.
conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
relu = nn.ReLU()
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# Spatial size: 150 -> 75 (stride-2 conv) -> 38 (stride-2 max-pool).
x = maxpool(relu(conv1(input)))
x.shape

torch.Size([32, 64, 38, 38])

In [129]:
# Check that a stride-2 3x3 conv (main path) and a stride-2 1x1 conv (shortcut)
# produce tensors of the same shape, so they can be added in a residual block.
# The names `conv3x3` / `conv1x1` are bound to module instances here; a later
# cell rebinds them to factory functions.
in_channels, out_channels, stride = 64, 128, 2

conv3x3 = nn.Conv2d(
    in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
)
conv1x1 = nn.Conv2d(
    in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False
)

# Both branches halve the 38x38 feature map to 19x19 and widen it to 128 channels.
x2 = conv3x3(x)
identity = conv1x1(x)
x2.shape, identity.shape

(torch.Size([32, 128, 19, 19]), torch.Size([32, 128, 19, 19]))

In [140]:
def conv3x3(in_channels: int, out_channels: int, stride: int=1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels the convolution produces.
        stride: Step of the convolution window; defaults to 1.

    Returns:
        A freshly initialised ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
    return layer

def conv1x1(in_channels: int, out_channels: int, stride: int=2) -> nn.Conv2d:
    """Build a bias-free 1x1 convolution without padding.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels the convolution produces.
        stride: Step of the convolution window. Note the default is 2, not 1.

    Returns:
        A freshly initialised ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
    return layer


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    When ``in_channels`` differs from ``out_channels`` the block uses stride 2
    on its first convolution and projects the shortcut with a strided 1x1
    convolution followed by batch norm, so both branches end up with the same
    shape. When the channel counts match, the stride is 1 and the input is
    added back unchanged.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels the block produces.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # A change in channel count is the signal to also downsample spatially.
        changes_width = in_channels != out_channels
        stride = 2 if changes_width else 1

        if changes_width:
            # Pass the stride explicitly so the shortcut always matches the
            # main path instead of relying on conv1x1's default value.
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.downsample = None

        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-bn-relu-conv-bn, add the shortcut, then apply ReLU.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            The block output with ``out_channels`` channels.
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut only when the main path changed the shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [141]:
# Reminder of the stem output shape that the blocks below receive.
x.shape

torch.Size([32, 64, 38, 38])

In [142]:
# Equal in/out channels: the block uses stride 1 and no downsample branch,
# so the output shape equals the input shape.
basicblock = BasicBlock(64, 64)
x1 = basicblock(x)
x1.shape

torch.Size([32, 64, 38, 38])

In [143]:
# Differing in/out channels: the block uses stride 2 with a projected
# shortcut, so channels go 64 -> 128 and the feature map shrinks 38 -> 19.
# Note: `basicblock` and `x2` are rebound here (both were assigned earlier).
basicblock = BasicBlock(64, 128)
x2 = basicblock(x1)
x2.shape

torch.Size([32, 128, 19, 19])

In [169]:
# Channel plan of the network: every consecutive pair is one block's
# (in_channels, out_channels).
layers = [64, 64, 64, 128, 128, 128, 256, 256, 512]

# Show the pairs that will be turned into blocks.
for block_in, block_out in zip(layers, layers[1:]):
    print(block_in, block_out)

64 64
64 64
64 128
128 128
128 128
128 256
256 256
256 512


In [182]:
class ResNet(nn.Module):
    """Small ResNet-style classifier assembled from ``BasicBlock`` modules.

    The network is: 7x7 stride-2 convolution -> ReLU -> 3x3 stride-2 max-pool
    -> one ``BasicBlock`` per consecutive pair in ``layers`` -> global average
    pool -> flatten -> linear classifier.

    Args:
        num_classes: Number of output scores produced per image.
        layers: Non-empty sequence of channel widths. ``layers[0]`` is the
            width produced by the stem, each consecutive pair
            ``(layers[i], layers[i+1])`` defines one block, and ``layers[-1]``
            is the feature size fed to the classifier.

    Raises:
        ValueError: If ``layers`` is empty.
    """

    def __init__(self, num_classes: int, layers) -> None:
        super().__init__()
        if len(layers) == 0:
            raise ValueError("layers must contain at least one channel width")

        # The stem width and the classifier input size are taken from the
        # channel plan, so plans that do not start at 64 or end at 512 also
        # produce a consistent network.
        stem_channels = layers[0]
        feature_channels = layers[-1]

        # NOTE(review): there is no BatchNorm between conv1 and the ReLU even
        # though the convolution is bias-free; left as is so the parameter
        # count of existing models does not change.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=stem_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        block_num = len(layers) - 1

        self.blocks = nn.ModuleList([BasicBlock(layers[i], layers[i+1]) for i in range(block_num)])

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.flat = nn.Flatten()
        self.fc = nn.Linear(feature_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of 3-channel images.

        Args:
            x: Batch of images laid out as (batch, 3, height, width).

        Returns:
            Tensor of shape (batch, num_classes) with the raw classifier output.
        """
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for block in self.blocks:
            x = block(x)

        # Collapse each feature map to a single value before the classifier.
        x = self.avgpool(x)
        x = self.flat(x)
        out = self.fc(x)

        return out

In [183]:
# Build the model from the channel plan in `layers`.
# num_classes=5 is presumably the number of flower categories -- confirm against the dataset.
resnet = ResNet(num_classes=5, layers=layers)

In [184]:
def count_parameters(model):
    """Return how many scalar values in ``model`` are trainable.

    Only parameters with ``requires_grad`` set are counted; frozen parameters
    are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Number of trainable parameters in the model built above.
count_parameters(resnet)

6753733

In [185]:
# Smoke test: push the sample batch through the full model. The result has
# one row of 5 class scores per image, i.e. shape (32, 5).
out = resnet(input)
out.shape

torch.Size([32, 5])