In [1]:
"""
Write a code using pytorch to replicate a grouped 2D convolution layer based on the original 2D convolution. 

The common way of using grouped 2D convolution layer in Pytorch is to use 
torch.nn.Conv2d(groups=n), where n is the number of groups.

However, it is possible to use a stack of n torch.nn.Conv2d(groups=1) to replicate the same
result. The wights must be copied and be split between the convs in the stack.

You can use:
    - use default values for anything unspecified  
    - all available functions in NumPy and Pytorch
    - the custom layer must be able to take all parameters of the original nn.Conv2d 
"""

import numpy as np
import torch
import torch.nn as nn


torch.manual_seed(8)    # DO NOT MODIFY!
np.random.seed(8)   # DO NOT MODIFY!

# random input (batch, channels, height, width)
x = torch.randn(2, 64, 100, 100)

# original 2d convolution
grouped_layer = nn.Conv2d(64, 128, 3, stride=1, padding=1, groups=16, bias=True)

# weights and bias
w_torch = grouped_layer.weight
b_torch = grouped_layer.bias

y = grouped_layer(x)


In [3]:

# Define the custom grouped 2D convolution layer
class CustomGroupedConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, groups, stride, padding, dilation=1, bias=True):
        super(CustomGroupedConv2D, self).__init__()
        assert in_channels % groups == 0 and out_channels % groups == 0, "in_channels and out_channels must be divisible by groups"
        self.groups = groups
        self.convs = nn.ModuleList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups, kernel_size, stride, padding, dilation, bias))
        
    def forward(self, x):
        x_splits = torch.chunk(x, self.groups, dim=1)
        out_splits = []
        for conv, x_split in zip(self.convs, x_splits):
            out_splits.append(conv(x_split))
        out = torch.cat(out_splits, dim=1)
        return out

custom_grouped_layer = CustomGroupedConv2D(64, 128, 3, stride=1, padding=1, groups=16, bias=True)

# Set the weights of the layers to be equal
grouped_layer.weight.data = custom_grouped_layer.convs[0].weight.data.clone()
grouped_layer.bias.data = custom_grouped_layer.convs[0].bias.data.clone()

# Compute the outputs of the layers
grouped_output = grouped_layer(x)
custom_grouped_output = custom_grouped_layer(x)

# Verify that the outputs are equal
assert torch.allclose(grouped_output, custom_grouped_output)


RuntimeError: ignored