In [1]:
import torch
import torch.nn.functional as F
from torch import nn

In [19]:
class Conv2D(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that also records its hyper-parameters.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size, an int or a ``(kh, kw)`` pair.
        stride: convolution stride.
        padding: zero-padding added to each side of the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1):
        super().__init__()

        # Kept as attributes so the configuration can be inspected after construction.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, input_batch):
        """Apply the wrapped ``nn.Conv2d`` to ``input_batch`` and return its output.

        Input validation is left to ``nn.Conv2d`` itself, so any input layout it
        accepts is accepted here as well.
        """
        return self.conv(input_batch)

In [20]:
# Smoke test: with the default 3x3 kernel, stride 1 and padding 1, the 32x32
# spatial size is preserved. Order matters: both lines draw from the global RNG.
conv = Conv2D(3, 16)
input_batch = torch.randn((16, 3, 32, 32))
output_batch = conv(input_batch)

In [23]:
output_batch.size()

torch.Size([16, 16, 32, 32])

In [81]:
def _to_pair(value):
    """Normalize an int or a 2-element sequence to an ``(h, w)`` tuple."""
    if isinstance(value, int):
        return (value, value)
    first, second = value
    return (first, second)


class Conv2DFunc(torch.autograd.Function):
    """2-D convolution with a hand-written backward pass.

    The forward pass delegates to ``F.conv2d``. The backward pass expresses both
    gradients as convolutions, so the result can be checked against autograd
    with ``gradcheck``. The backward supports 4-D ``(N, C, H, W)`` inputs, and
    int or pair-valued stride/padding. It does not support dilation, groups, or
    string paddings such as ``"same"``.
    """

    @staticmethod
    def forward(ctx, input_batch, kernel, stride=1, padding=1):
        """Compute ``F.conv2d(input_batch, kernel, stride=stride, padding=padding)``.

        Args:
            input_batch: input tensor, assumed ``(N, C_in, H, W)``.
            kernel: weight tensor, assumed ``(C_out, C_in, kH, kW)``.
            stride: int or ``(sH, sW)`` pair.
            padding: int or ``(pH, pW)`` pair.
        """
        # save_for_backward must be called at most once; a second call replaces
        # what the first stored, so both tensors go in a single call.
        ctx.save_for_backward(input_batch, kernel)
        # Non-tensor hyper-parameters are needed by backward, so keep them on ctx.
        ctx.stride = stride
        ctx.padding = padding

        return F.conv2d(input_batch, kernel, stride=stride, padding=padding)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``(input_batch, kernel, stride, padding)``.

        ``stride`` and ``padding`` are not tensors, so their gradients are ``None``.
        """
        input_batch, kernel = ctx.saved_tensors
        stride_h, stride_w = _to_pair(ctx.stride)
        pad_h, pad_w = _to_pair(ctx.padding)
        kernel_h, kernel_w = kernel.shape[-2:]

        input_batch_grad = kernel_grad = None

        if ctx.needs_input_grad[0]:
            # dL/dx is the transposed convolution of grad_output with the same kernel.
            # When (size + 2*pad - k) is not divisible by the stride, the forward
            # pass skipped the trailing rows/cols. output_padding restores them so
            # the gradient has exactly the input's spatial size.
            height, width = input_batch.shape[-2:]
            output_padding = (
                (height + 2 * pad_h - kernel_h) % stride_h,
                (width + 2 * pad_w - kernel_w) % stride_w,
            )
            input_batch_grad = F.conv_transpose2d(
                grad_output,
                kernel,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
                output_padding=output_padding,
            )

        if ctx.needs_input_grad[1]:
            # dL/dW[o, c, i, j] = sum_{n, a, b} grad_output[n, o, a, b] * x_pad[n, c, s*a + i, s*b + j].
            # Swapping the batch and channel axes turns this into an ordinary conv2d
            # of the padded input with grad_output as the "kernel", dilated by the
            # forward stride. The result can exceed (kH, kW) when the stride does not
            # divide evenly, so it is cropped back to the kernel size.
            kernel_grad = F.conv2d(
                input_batch.transpose(0, 1),
                grad_output.transpose(0, 1),
                padding=(pad_h, pad_w),
                dilation=(stride_h, stride_w),
            ).transpose(0, 1)[:, :, :kernel_h, :kernel_w]

        return input_batch_grad, kernel_grad, None, None
        
        

In [87]:
# Keep the gradcheck inputs tiny. In its default (slow) mode, gradcheck builds
# dense Jacobians of size (#input elements x #output elements). A 16x3x32x32
# input producing a 16x16x32x32 output would need roughly 100 GB for the input
# Jacobian alone and never finish. Double precision is required for gradcheck's
# finite differences to be accurate.
torch.manual_seed(0)
input_batch = torch.randn(2, 3, 8, 8, requires_grad=True, dtype=torch.double)
kernel = torch.randn(4, 3, 3, 3, requires_grad=True, dtype=torch.double)
out = Conv2DFunc.apply(input_batch, kernel)

In [None]:
# Compare the hand-written backward against finite differences; displays True on success.
torch.autograd.gradcheck(Conv2DFunc.apply, inputs=(input_batch, kernel))