In [25]:
import torch
import torch.nn as nn
from torch.nn import Conv2d

In [82]:
class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise separable 2-D convolution: a depthwise conv followed by a 1x1 pointwise conv.

    The depthwise stage convolves every input channel independently
    (``groups=in_channels``) with ``depth`` filters each, producing
    ``in_channels * depth`` channels. The pointwise 1x1 stage then mixes those
    channels into exactly ``out_channels`` output channels.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Spatial kernel size of the depthwise stage (int or (kh, kw)).
        depth: Depth multiplier, i.e. number of depthwise filters per input channel (>= 1).
        bias: Whether both convolutions carry a bias term.
        auto_padding: If True, ignore ``padding`` and use ``kernel_size // 2`` per
            dimension (keeps spatial size for odd kernels at stride 1).
        padding: Padding of the depthwise stage when ``auto_padding`` is False.
        stride: Stride of the depthwise stage.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth = 1, bias = False, auto_padding = False, padding = 0, stride = 1):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        if auto_padding:
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2

        hidden_channels = in_channels * depth
        # Depth multiplier lives here (as in Keras/MobileNet): each input channel gets
        # `depth` spatial filters; groups=in_channels keeps channels independent.
        depthwise = nn.Conv2d(in_channels, hidden_channels,
                              kernel_size = kernel_size, padding = padding, stride = stride,
                              groups = in_channels, bias = bias)
        # 1x1 conv mixes channels and fixes the output width to out_channels.
        pointwise = nn.Conv2d(hidden_channels, out_channels, kernel_size = 1, bias = bias)

        # Attribute name and ordering kept so state_dict keys stay compatible.
        self.depthwise_separable_convolution = nn.Sequential(depthwise,
                                                             pointwise)

    def forward(self, X):
        """Apply the depthwise then pointwise convolution to ``X``."""
        return self.depthwise_separable_convolution(X)
    
    
class SeparableConv2D(nn.Module):
    """Spatially separable 2-D convolution: a (k x 1) conv followed by a (1 x k) conv.

    The first stage is a per-channel (``groups=in_channels``) k x 1 convolution
    that slides along the height axis. The second stage is a full 1 x k
    convolution along the width axis that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Length k of both 1-D kernels.
        bias: Whether both convolutions carry a bias term.
        auto_padding: If True, ignore ``padding`` and use ``kernel_size // 2``.
        padding: Padding along the convolved axis of each stage.
        stride: Overall stride, int or (stride_h, stride_w). The height stride is
            applied only by the first stage and the width stride only by the second,
            so the block downsamples exactly like one k x k conv with this stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size: int, bias = False, auto_padding = False, padding = 0, stride = 1):
        super().__init__()
        if auto_padding:
            padding = kernel_size // 2

        if isinstance(stride, (tuple, list)):
            stride_h, stride_w = stride
        else:
            stride_h = stride_w = stride

        # (k, 1) kernel: convolves along height only, so it must only stride height.
        vertical_convolution = nn.Conv2d(in_channels, in_channels,
                                         kernel_size = (kernel_size, 1), padding = (padding, 0),
                                         stride = (stride_h, 1),
                                         groups = in_channels, bias = bias)
        # (1, k) kernel: convolves along width only, so it must only stride width.
        horizontal_convolution = nn.Conv2d(in_channels, out_channels,
                                           kernel_size = (1, kernel_size), padding = (0, padding),
                                           stride = (1, stride_w),
                                           groups = 1, bias = bias)

        # Attribute name and ordering kept so state_dict keys stay compatible.
        self.separable_convolution = nn.Sequential(vertical_convolution,
                                                   horizontal_convolution)

    def forward(self, X):
        """Apply the (k x 1) then (1 x k) convolution to ``X``."""
        return self.separable_convolution(X)

In [87]:
# Compare output shapes and trainable-parameter counts of a standard 3x3 convolution,
# a depthwise separable convolution and a spatially separable convolution.
# NOTE: not a like-for-like comparison: nn.Conv2d defaults to bias=True while both
# custom layers are built with bias=False, and the separable layer here uses
# out_channels=10 with auto_padding=True (hence its different output shape).
torch.manual_seed(0)  # reproducible input values and weight initialisation


def count_trainable_parameters(module: nn.Module) -> int:
    """Return the number of trainable (requires_grad) parameter elements in ``module``."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


x = torch.rand(5, 10, 50, 50)

conv = nn.Conv2d(in_channels=10, out_channels=32, kernel_size=3)
out = conv(x)

depthwise_separable_conv = DepthwiseSeparableConv2D(10, 32, depth = 1, kernel_size = 3)
out_depthwise = depthwise_separable_conv(x)

separable_convolution = SeparableConv2D(10, 10, 3, auto_padding = True)
out_separable = separable_convolution(x)

params = count_trainable_parameters(conv)
params_depthwise = count_trainable_parameters(depthwise_separable_conv)
params_separable = count_trainable_parameters(separable_convolution)

print(f"Output shape of standard convolution: {out.shape}")
print(f"Output shape of depthwise separated convolution: {out_depthwise.shape}")
print(f"Output shape of separable convolution: {out_separable.shape}")
print(f"The standard convolution uses {params} parameters.")
print(f"The depthwise separable convolution uses {params_depthwise} parameters.")
print(f"The separable convolution uses {params_separable} parameters.")

Output shape of standard convolution: torch.Size([5, 32, 48, 48])
Output shape of depthwise separated convolution: torch.Size([5, 32, 48, 48])
Output shape of separable convolution: torch.Size([5, 10, 50, 50])
The standard convolution uses 2912 parameters.
The depthwise separable convolution uses 410 parameters.
The separable convolution uses 330 parameters.
