# In [3]:
import torch
import torch.nn as nn

class StripedDWConv(nn.Module):
    """Striped (separable) depthwise convolution: a 1xk pass followed by a kx1 pass.

    Composing the two 1-D stripes gives a k x k receptive field using 2k
    weights per output channel instead of k*k.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. Must be a multiple of
            ``in_channels`` (depthwise channel multiplier); ``nn.Conv2d``
            raises otherwise.
        kernel_size: Length k of each 1-D stripe. Both passes use stride 1
            and padding ``k // 2``, so an odd k preserves H and W, while an
            even k grows each spatial dimension by 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Vertical stripe. It runs SECOND, so it consumes the out_channels
        # produced by conv_1xk and is purely depthwise over them. (Previously
        # it was built for in_channels inputs, which crashed whenever
        # out_channels != in_channels.) Created first to keep the original
        # parameter/state_dict order and seeded initialisation.
        self.conv_kx1 = nn.Conv2d(out_channels, out_channels, kernel_size=(kernel_size, 1),
                                  stride=1, padding=(kernel_size // 2, 0), groups=out_channels)
        # Horizontal stripe. It runs FIRST and applies the channel multiplier
        # in_channels -> out_channels.
        self.conv_1xk = nn.Conv2d(in_channels, out_channels, kernel_size=(1, kernel_size),
                                  stride=1, padding=(0, kernel_size // 2), groups=in_channels)

    def forward(self, x):
        """Apply the 1xk stripe, then the kx1 stripe.

        Args:
            x: Input tensor, expected as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H', W'); H' == H and W' == W
            for odd kernel sizes.
        """
        return self.conv_kx1(self.conv_1xk(x))

# Example usage: a 7-wide striped depthwise conv on a 64-channel feature map.
input_tensor = torch.randn(1, 64, 64, 64)  # batch_size=1, in_channels=64, height=64, width=64
striped_conv_layer = StripedDWConv(in_channels=64, out_channels=64, kernel_size=7)
output_tensor = striped_conv_layer(input_tensor)

# Odd kernel size with k // 2 padding keeps H and W unchanged.
print(output_tensor.shape)  # Expected: torch.Size([1, 64, 64, 64])


# Out: torch.Size([1, 64, 64, 64])
