In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Depthwise Separable Convolution

In [17]:
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise separable convolution: per-channel spatial conv, then 1x1 channel mixing.

    Both stages run conv -> BatchNorm -> ReLU. Neither convolution has a bias;
    the BatchNorm that follows each one provides a learnable shift instead.

    Args:
        in_channels: Number of input channels (the depthwise conv preserves it).
        out_channels: Number of channels produced by the pointwise (1x1) conv.
        kernel_size: Spatial kernel size of the depthwise conv.
        stride: Stride of the depthwise conv (the pointwise conv always uses 1).
        padding: Zero padding of the depthwise conv.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1):
        super().__init__()

        # Spatial filtering: groups == in_channels gives one filter per channel.
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride=stride, padding=padding,
                                   groups=in_channels, bias=False)
        self.dw_bn = nn.BatchNorm2d(in_channels)

        # Channel mixing: a 1x1 conv projects to out_channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.pw_bn = nn.BatchNorm2d(out_channels)

        # One parameter-free activation module shared by both stages.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the depthwise then pointwise stage, each followed by BN and ReLU."""
        stages = ((self.depthwise, self.dw_bn), (self.pointwise, self.pw_bn))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x


# MobileNet

In [30]:
class MobileNet(nn.Module):
    """MobileNet (v1-style) image classifier built from depthwise separable convolutions.

    The backbone is a strided 3x3 stem convolution followed by 13
    ``DepthwiseSeparableConv2d`` blocks. The stem and four of the blocks use
    stride 2, so spatial resolution is reduced 32x overall. A global average
    pool and a linear layer then produce the class logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (3 for RGB).
        width_mult: Width multiplier (alpha) applied to every layer's channel
            count. 1.0 reproduces the baseline widths (32 -> ... -> 1024).
            Scaled widths are truncated with ``int`` and never drop below 1.

    Raises:
        ValueError: If ``width_mult`` is not positive.
    """

    # Stem width at width_mult=1.0.
    _STEM_CHANNELS = 32
    # (out_channels, stride) of each depthwise separable block at width_mult=1.0.
    _BLOCK_SPECS = (
        (64, 1),
        (128, 2),
        (128, 1),
        (256, 2),
        (256, 1),
        (512, 2),
        *([(512, 1)] * 5),
        (1024, 2),
        (1024, 1),
    )

    def __init__(self, num_classes=1000, in_channels=3, width_mult=1.0):
        super().__init__()
        if width_mult <= 0:
            raise ValueError(f"width_mult must be positive, got {width_mult}")

        def scaled(channels):
            # Apply the width multiplier; keep at least one channel per layer.
            return max(1, int(channels * width_mult))

        stem_out = scaled(self._STEM_CHANNELS)
        layers = [
            nn.Conv2d(in_channels, stem_out, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_out),
            nn.ReLU(inplace=True),
        ]
        prev_channels = stem_out
        for base_out, stride in self._BLOCK_SPECS:
            out_channels = scaled(base_out)
            layers.append(
                DepthwiseSeparableConv2d(prev_channels, out_channels,
                                         kernel_size=3, stride=stride, padding=1)
            )
            prev_channels = out_channels
        # Same module order and indices as a hand-written Sequential, so the
        # state_dict keys (features.0 ... features.15) are unchanged.
        self.features = nn.Sequential(*layers)

        # Classification head: its input width always matches the last block.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(prev_channels, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for ``x`` of shape (N, in_channels, H, W)."""
        x = self.features(x)
        x = self.pool(x)            # (N, C, 1, 1)
        x = torch.flatten(x, 1)     # (N, C)
        return self.classifier(x)

In [33]:
def load_image_as_tensor(path,
                         img_size=224,
                         mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225),
                         add_batch_dim=True,
                         device="cpu",
                         resize_margin=32):
    """Load an image file and preprocess it into a normalized float tensor.

    Steps:
        1. Convert the image to RGB.
        2. Resize its shorter side to ``img_size + resize_margin``.
        3. Center-crop an ``img_size x img_size`` square.
        4. Normalize each channel with ``mean``/``std``.

    Args:
        path: Path (str or ``pathlib.Path``) of the image file.
        img_size: Side length of the square output crop.
        mean: Per-channel means used for normalization. The defaults are the
            standard ImageNet statistics.
        std: Per-channel standard deviations used for normalization.
        add_batch_dim: If True, prepend a batch dimension of size 1.
        device: Device the returned tensor is moved to.
        resize_margin: Extra pixels added to ``img_size`` for the resize step,
            so the center crop trims the borders. The default of 32 matches
            the previous hard-coded behaviour.

    Returns:
        Float tensor of shape ``(1, 3, img_size, img_size)`` if
        ``add_batch_dim`` is True, otherwise ``(3, img_size, img_size)``.
    """
    # The context manager closes the file handle deterministically. convert()
    # returns a new, fully loaded image that remains valid after the block.
    with Image.open(path) as raw_img:
        img = raw_img.convert("RGB")

    tfm = transforms.Compose([
        transforms.Resize(img_size + resize_margin),  # shorter side -> img_size + margin
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),             # float in [0, 1], shape (C, H, W)
        transforms.Normalize(mean, std),
    ])

    tensor = tfm(img)

    if add_batch_dim:
        tensor = tensor.unsqueeze(0)       # shape (1, C, H, W)

    # non_blocking only makes the copy asynchronous for pinned host memory;
    # otherwise it behaves like a normal copy.
    return tensor.to(device, non_blocking=True)


# Example: load a sample image, using the GPU only when one is available
# (a hard-coded "cuda" would crash on CPU-only machines).
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = load_image_as_tensor("../../dent2.jpeg", img_size=256, device=device)
    print(x.shape)  # e.g. torch.Size([1, 3, 256, 256])

torch.Size([1, 3, 256, 256])


In [35]:
# Smoke test: one forward pass on the loaded image should give (1, num_classes) logits.
model = MobileNet()
# Follow the input's device instead of hard-coding "cuda".
model = model.to(x.device)
# eval() makes BatchNorm use its running stats instead of updating them on this
# dummy pass (train mode also fails for batch size 1 with 1x1 feature maps).
# inference_mode() skips recording an autograd graph.
model.eval()
with torch.inference_mode():
    out = model(x)

print(out.shape)

torch.Size([1, 1000])
