In [316]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import torch.nn.functional as F
import math
import torch.optim as optim

from torchvision.transforms import CenterCrop

In [362]:
class Conv2dDF(nn.Conv2d):
    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride=1, padding=0, alpha=0.5, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(Conv2dDF, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode)
        
        # unpair kernel_size from init super method
#         self.kernel_size = self.kernel_size[0]
        self.kernel_size = Parameter(torch.Tensor(self.kernel_size))
#         self.register_parameter(name='kernel_size', param=self.kernel_size)
        
        self.k_plus = (math.floor((self.kernel_size[0] + 1) / 2) * 2 + 1, math.floor((self.kernel_size[1] + 1) / 2) * 2 + 1)
        self.k_minus = (math.ceil((self.kernel_size[0] + 1) / 2) * 2 - 1, math.ceil((self.kernel_size[1] + 1) / 2) * 2 - 1)
        
        self.weight_plus = torch.Tensor(
                out_channels, in_channels // groups, *self.k_plus)
        
        self.weight_minus = CenterCrop(self.k_minus)(self.weight_plus)
        
        self.delta_w = self.weight_plus - F.pad(self.weight_minus, (1,1,1,1,0,0), mode='constant', value=0)
        self.weight = Parameter(alpha * self.delta_w + F.pad(self.weight_minus, (1,1,1,1,0,0), mode='constant', value=0))
        
        self.register_backward_hook(self.backward_func)
        
    def backward_func(self, module, grad_input, grad_output):
        print(grad_input)
#         print(grad_output)
        
    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
#     @staticmethod
    def forward(self, input):
        return self._conv_forward(input, self.weight)
    
# #     @staticmethod
#     def backward(self, grad_output):
#         """
#         In the backward pass we receive a Tensor containing the gradient of the loss
#         with respect to the output, and we need to compute the gradient of the loss
#         with respect to the input.
#         """
# #         print(ctx, grad_output)
#         print(grad_output)
#         return super(Conv2dDF, self).backward(grad_output)
       

In [363]:


conv = Conv2dDF(3, 64, 3, padding=2, stride=1)
# conv = nn.Conv2d(3, 64, 3, padding=1, stride=1)
optimizer = optim.Adam(conv.parameters(), 0.001)

x = conv(torch.rand(12, 3, 48, 48))
lasst_w = conv.weight

(torch.ones((12, 64, 48, 48)) - x).mean().backward()
optimizer.step()

cur_w = conv.weight

(None, tensor([[[[-0.0072, -0.0074, -0.0075, -0.0074, -0.0072],
          [-0.0074, -0.0075, -0.0077, -0.0075, -0.0074],
          [-0.0075, -0.0077, -0.0078, -0.0077, -0.0075],
          [-0.0074, -0.0075, -0.0077, -0.0075, -0.0074],
          [-0.0072, -0.0074, -0.0075, -0.0074, -0.0072]],

         [[-0.0072, -0.0074, -0.0075, -0.0074, -0.0072],
          [-0.0074, -0.0075, -0.0077, -0.0075, -0.0074],
          [-0.0075, -0.0077, -0.0078, -0.0077, -0.0075],
          [-0.0074, -0.0075, -0.0077, -0.0075, -0.0074],
          [-0.0072, -0.0073, -0.0075, -0.0074, -0.0072]],

         [[-0.0072, -0.0073, -0.0075, -0.0073, -0.0072],
          [-0.0073, -0.0075, -0.0076, -0.0075, -0.0073],
          [-0.0075, -0.0076, -0.0078, -0.0076, -0.0075],
          [-0.0073, -0.0075, -0.0076, -0.0075, -0.0073],
          [-0.0072, -0.0073, -0.0075, -0.0073, -0.0072]]],


        [[[-0.0072, -0.0074, -0.0075, -0.0074, -0.0072],
          [-0.0074, -0.0075, -0.0077, -0.0075, -0.0074],
          [-0.00

In [365]:
from torch.autograd import Function

class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

In [366]:
linear = LinearFunction.apply

In [374]:
t = linear(torch.rand(12, 23), torch.rand(12, 23))
(torch.ones(12,12) - t).mean().backward(torch.rand(12,12))

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([12, 12]) and output[0] has a shape of torch.Size([]).