In [6]:
import torch
import torch.nn.functional as F
from torch import nn

In [7]:
class Conv2D(nn.Module):
    """2-D convolution implemented as unfold (im2col) followed by a matmul.

    The learnable weight ``self.conv`` has shape
    ``(out_channels, in_channels, kernel_size, kernel_size)`` and is
    initialised to ones; there is no bias term.  ``self.conv1`` is a regular
    ``nn.Conv2d`` built with the same hyper-parameters and is used only as a
    shape reference inside ``forward``.

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of filters (channels of the output).
        kernel_size: side length of the square kernel (int).
        stride: step between two consecutive patches (int).
        padding: zero padding added on every side of the input (int).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2D, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Weight of the hand-written convolution.
        self.conv = nn.Parameter(torch.ones(out_channels, in_channels, kernel_size, kernel_size))

        # Reference layer, only used to cross-check the output shape.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, input_batch):
        """Convolve ``input_batch`` of shape ``(b, c, h, w)`` with ``self.conv``.

        Returns:
            Tensor of shape ``(b, out_channels, h_out, w_out)`` where
            ``h_out = (h + 2*padding - kernel_size) // stride + 1`` (same for w).
        """
        b, _, h, w = input_batch.size()
        k = self.kernel_size
        p = self.padding
        s = self.stride

        # Integer arithmetic: avoids the float round-trip of "/" followed by int().
        h_out = (h + 2 * p - k) // s + 1
        w_out = (w + 2 * p - k) // s + 1

        # Unfold into patches: (b, c*k*k, P) with P = h_out * w_out.
        # The stride has to be forwarded here, otherwise P is computed for
        # stride 1 and no longer matches (h_out, w_out) below.
        x = F.unfold(input_batch, (k, k), padding=p, stride=s)
        x = x.transpose(1, 2)  # (b, P, c*k*k)

        P = x.shape[1]

        # Flatten batch and patch dimensions: (b*P, c*k*k)
        x = torch.reshape(x, (b * P, x.shape[2]))

        # Matmul with the flattened filters: (b*P, out_channels)
        W = torch.reshape(self.conv, (self.conv.size(0), -1))
        y = x.matmul(W.t())

        # Back to (b, P, out_channels), then channels first: (b, out_channels, P)
        y = torch.reshape(y, (b, P, y.shape[1]))
        y = y.transpose(1, 2)

        # With a 1x1 kernel, fold only lays the P positions out on the grid.
        out = F.fold(y, (h_out, w_out), (1, 1))
        # Debug print; the notebook's recorded cell output relies on it.
        print(out.shape)

        # Shape sanity check against the reference layer (no graph needed).
        with torch.no_grad():
            r = self.conv1(input_batch)
        assert r.shape == out.shape

        return out


In [8]:
# Smoke test of the unfold-based layer: 3 input channels, 10 filters,
# default 3x3 kernel with stride 1 and padding 1.
conv = Conv2D(in_channels = 3, out_channels = 10)
# Random batch of 16 three-channel 32x32 inputs (unseeded, so values differ per run).
input_batch = torch.randn(16, 3, 32, 32)
# Calling the module runs Conv2D.forward on the batch.
output_batch = conv(input_batch)

torch.Size([16, 10, 32, 32])


In [9]:
class Conv2DFunc(torch.autograd.Function):
    """Convolution as unfold + matmul with a hand-written backward pass.

    Usage: ``Conv2DFunc.apply(input_batch, kernel, stride, padding)``;
    ``stride`` and ``padding`` are optional ints defaulting to 1.
    """

    @staticmethod
    def forward(ctx, input_batch, kernel, stride=1, padding=1):
        """Convolve ``input_batch`` (b, c, h, w) with ``kernel`` (C, c, k, k).

        The kernel is treated as square: only ``kernel.shape[2]`` is read for
        the window size.  Returns a tensor of shape ``(b, C, h_out, w_out)``.
        """
        b, c, h, w = input_batch.shape
        C = kernel.shape[0]
        k = kernel.shape[2]

        h_out = (h + 2 * padding - k) // stride + 1
        w_out = (w + 2 * padding - k) // stride + 1

        # Unfold into patches: (b, c*k*k, P) with P = h_out * w_out.
        # The stride must be forwarded, otherwise P does not match (h_out, w_out).
        x = F.unfold(input_batch, (k, k), padding=padding, stride=stride)
        x = x.transpose(1, 2)  # (b, P, c*k*k)

        P = x.shape[1]

        # Flatten batch and patch dimensions: (b*P, c*k*k)
        x = torch.reshape(x, (b * P, x.shape[2]))

        # save_for_backward overwrites on every call, so everything the
        # backward needs is stored in ONE call; plain Python values go on ctx.
        ctx.save_for_backward(kernel, x)
        ctx.stride = stride
        ctx.padding = padding
        ctx.input_hw = (h, w)

        # Matmul with the flattened filters: (b*P, C)
        W = torch.reshape(kernel, (C, -1))
        y = x.matmul(W.t())

        # Back to (b, P, C), then channels first: (b, C, P)
        y = torch.reshape(y, (b, P, C))
        ctx.yshape = y.shape
        y = y.transpose(1, 2)

        # With a 1x1 kernel, fold only lays the P positions out on the grid.
        out = F.fold(y, (h_out, w_out), (1, 1))
        # Debug print; the notebook's recorded cell output relies on it.
        print(out.shape)

        # Shape sanity check against the built-in convolution.
        r = F.conv2d(input_batch, kernel, stride=stride, padding=padding)
        assert r.shape == out.shape

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input_batch, kernel, stride, padding)``.

        ``grad_output`` has the shape of the forward output, (b, C, h_out, w_out).
        ``stride`` and ``padding`` are not differentiable and get ``None``.
        """
        kernel, U = ctx.saved_tensors  # U: unfolded input, (b*P, c*k*k)
        b, P, C = ctx.yshape
        k = kernel.shape[2]

        # Undo the forward layout step by step:
        # (b, C, h_out, w_out) -> (b, C, P) -> (b, P, C) -> (b*P, C)
        y_grad = grad_output.reshape(b, C, P).transpose(1, 2).reshape(b * P, C)

        input_batch_grad = None
        kernel_grad = None

        if ctx.needs_input_grad[1]:
            # y = U @ W^T  =>  dL/dW = y_grad^T @ U, reshaped like the kernel.
            kernel_grad = y_grad.t().matmul(U).reshape(kernel.shape)

        if ctx.needs_input_grad[0]:
            # dL/dU = y_grad @ W, then fold sums overlapping patches back
            # onto the image grid (the adjoint of unfold).
            W = torch.reshape(kernel, (C, -1))
            U_grad = y_grad.matmul(W)                        # (b*P, c*k*k)
            U_grad = U_grad.reshape(b, P, -1).transpose(1, 2)  # (b, c*k*k, P)
            input_batch_grad = F.fold(
                U_grad, ctx.input_hw, (k, k),
                padding=ctx.padding, stride=ctx.stride,
            )

        return input_batch_grad, kernel_grad, None, None
        
        

In [10]:
# Double-precision tensors with requires_grad=True: a random batch of 16
# three-channel 32x32 inputs and a bank of 10 filters of size 3x3x3.
input_batch = torch.randn(16, 3, 32, 32,  requires_grad=True, dtype=torch.double)
kernel = torch.randn(10, 3, 3, 3,  requires_grad=True, dtype=torch.double)
# Forward pass only, with the default stride and padding; backward is not
# invoked anywhere in this cell.
out = Conv2DFunc.apply(input_batch, kernel)

torch.Size([16, 10, 32, 32])
