In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution (one filter per input channel, via
    ``groups=in_channels``) followed by a 1x1 pointwise convolution that
    mixes channels. Each convolution is followed by batch normalization
    and then ReLU, so the block output is non-negative.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the pointwise conv.
        stride: Stride of the depthwise conv; controls spatial downsampling.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int) -> None:
        super().__init__()
        # Depthwise: 3x3 spatial filtering applied independently per channel.
        # bias is omitted because the following BatchNorm has its own shift.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        # Pointwise: 1x1 conv that combines channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # One BatchNorm per convolution.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply depthwise conv -> BN -> ReLU, then pointwise conv -> BN -> ReLU."""
        # Normalize the raw convolution output *before* the non-linearity;
        # applying ReLU first would feed BatchNorm truncated activations and
        # let the block emit un-rectified (negative) values.
        x = F.relu(self.bn1(self.depthwise(x)))
        x = F.relu(self.bn2(self.pointwise(x)))
        return x

In [None]:
class MobileNetV1(nn.Module):
    """MobileNetV1-style image classifier.

    The network is a strided 3x3 stem convolution, a stack of
    ``DepthwiseSeparableConv`` blocks, global average pooling, and a linear
    classifier. The stem plus the four stride-2 blocks reduce the spatial
    resolution by a factor of 32 before pooling.

    Args:
        num_classes: Size of the classifier output (number of logits).
    """

    # (in_channels, out_channels, stride) for each depthwise-separable block,
    # in execution order.
    # NOTE(review): the MobileNetV1 paper (Table 1) repeats the 512 -> 512
    # stride-1 block five times; only three are present here. Left as-is to
    # keep the parameter layout unchanged -- confirm whether this is intended.
    _BLOCK_CFG = (
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    )

    def __init__(self, num_classes: int = 1000) -> None:
        super().__init__()

        # Stem: standard 3x3 convolution on a 3-channel input, stride 2.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        # Stack of depthwise-separable blocks, created in the order listed in
        # _BLOCK_CFG so that submodule indices (features.0, features.1, ...)
        # stay stable.
        self.features = nn.Sequential(
            *(
                DepthwiseSeparableConv(in_ch, out_ch, stride=stride)
                for in_ch, out_ch, stride in self._BLOCK_CFG
            )
        )

        # Global average pooling collapses each channel map to a single value,
        # so the classifier input size does not depend on image resolution.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Linear classifier over the 1024 pooled channels.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Input batch with 3 channels, i.e. shape ``(batch, 3, H, W)``.
        """
        # Stem uses conv -> BN -> ReLU: normalize before the non-linearity
        # rather than batch-normalizing already-rectified activations.
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.features(x)
        x = self.avgpool(x)
        # Drop the trailing 1x1 spatial dims: (batch, 1024, 1, 1) -> (batch, 1024).
        x = torch.flatten(x, 1)
        return self.fc(x)