In [1]:
import torch
from torch.autograd import Function
from torch.autograd import Variable

In [2]:
# Parameter-less example

from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    """Parameter-less legacy autograd Function built on NumPy's real 2-D FFT.

    forward:  magnitude of the real 2-D FFT of ``input`` (``abs(rfft2(x))``),
              returned as a ``torch.FloatTensor``.
    backward: inverse real 2-D FFT (``irfft2``) of ``grad_output``, returned
              as a ``torch.FloatTensor``.

    As the name says, ``backward`` is intentionally *not* the true gradient of
    ``forward`` (for one thing, it does not differentiate through ``abs``); the
    class only demonstrates how NumPy code can be plugged into autograd.

    NOTE(review): ``irfft2`` with its default output shape produces an even
    last dimension (2 * (m - 1) for an m-column spectrum), so the returned
    gradient only matches ``input``'s shape when ``input``'s last dimension is
    even (e.g. 8x8 -> 8x5 spectrum -> 8x8 gradient).
    """

    def forward(self, input):
        spectrum = rfft2(input.numpy())
        magnitude = abs(spectrum)
        return torch.FloatTensor(magnitude)

    def backward(self, grad_output):
        return torch.FloatTensor(irfft2(grad_output.numpy()))

def incorrect_fft(input):
    """Apply ``BadFFTFunction`` to ``input`` and return its result.

    This layer has no learnable parameters, so it is declared as a plain
    function rather than as an ``nn.Module`` class. A new ``BadFFTFunction``
    instance is created for every call.
    """
    fft_fn = BadFFTFunction()
    return fft_fn(input)

In [3]:
# Run the FFT "layer" forward on a random 8x8 input, then push a random
# upstream gradient (shaped like the result) back through it.
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
upstream_grad = torch.randn(result.size())
result.backward(upstream_grad)
print(input.grad)


  6.8588   9.5525   7.9841   2.8806  12.0924
  1.7423   2.5442   5.2907  14.8600   4.1375
 11.7613   3.1875   2.2128   7.2097   6.7645
  7.3674  10.4123   3.8119   5.6222   7.8309
  7.4967   0.8123   5.8453   8.1193  16.5027
  7.3674  12.5861   6.2285   5.2712   7.8309
 11.7613   1.2707   6.8687   5.4283   6.7645
  1.7423  11.2106   4.1893   7.1905   4.1375
[torch.FloatTensor of size 8x5]

Variable containing:
 0.1086  0.1009 -0.0471 -0.0078  0.1361 -0.0078 -0.0471  0.1009
-0.0674  0.0455  0.0523 -0.0335  0.1619 -0.1309 -0.0858  0.2233
 0.2692 -0.0024 -0.0193  0.2196 -0.2154 -0.3030 -0.0429 -0.1246
 0.0246  0.0781  0.1144 -0.1674 -0.0332  0.2231 -0.0808  0.0858
-0.2110 -0.0919 -0.1999  0.1392  0.3285  0.1392 -0.1999 -0.0919
 0.0246  0.0858 -0.0808  0.2231 -0.0332 -0.1674  0.1144  0.0781
 0.2692 -0.1246 -0.0429 -0.3030 -0.2154  0.2196 -0.0193 -0.0024
-0.0674  0.2233 -0.0858 -0.1309  0.1619 -0.0335  0.0523  0.0455
[torch.FloatTensor of size 8x8]



In [4]:
# Parametrized example

from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    """Legacy-style autograd Function: 2-D "valid" cross-correlation via SciPy.

    ``forward(input, filter)`` computes
    ``correlate2d(input, filter, mode='valid')`` on the NumPy views of both
    tensors and saves them for ``backward``. ``backward(grad_output)`` returns
    ``(grad_input, grad_filter)`` in the same order as the forward arguments.

    With y = correlate2d(x, w, 'valid'), i.e.
    y[i, j] = sum_{a, b} x[i + a, j + b] * w[a, b], the gradients are:

    * dL/dx = convolve2d(dL/dy, w, 'full')   (same filter, not transposed)
    * dL/dw = correlate2d(x, dL/dy, 'valid')
    """

    @staticmethod
    def _numpy_grads(input_np, filter_np, grad_np):
        """Return ``(grad_input, grad_filter)`` as NumPy arrays.

        Args:
            input_np: 2-D array that was the forward ``input``.
            filter_np: 2-D array that was the forward ``filter``.
            grad_np: 2-D upstream gradient, same shape as the forward output.

        Returns:
            ``grad_input`` with ``input_np``'s shape and ``grad_filter`` with
            ``filter_np``'s shape.
        """
        # dL/dx[p, q] = sum_{i,j} g[i, j] * w[p - i, q - j]: a full
        # *convolution* of g with the untransposed filter (convolve2d's built-in
        # flip is exactly the adjoint of the forward correlation).
        grad_input = convolve2d(grad_np, filter_np, mode='full')
        # dL/dw[a, b] = sum_{i,j} g[i, j] * x[i + a, j + b]: a valid
        # *correlation* of x with g (convolve2d would wrongly flip g).
        grad_filter = correlate2d(input_np, grad_np, mode='valid')
        return grad_input, grad_filter

    def forward(self, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        self.save_for_backward(input, filter)
        return torch.FloatTensor(result)

    def backward(self, grad_output):
        input, filter = self.saved_tensors
        grad_input, grad_filter = self._numpy_grads(
            input.numpy(), filter.numpy(), grad_output.numpy())
        return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)


class ScipyConv2d(Module):
    """Module applying ``ScipyConv2dFunction`` with a learnable filter.

    Args:
        kh: number of filter rows.
        kw: number of filter columns.

    The filter is drawn from ``torch.randn(kh, kw)`` and wrapped in a
    ``Parameter``, so it is listed by ``module.parameters()``.
    """

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        initial_weights = torch.randn(kh, kw)
        self.filter = Parameter(initial_weights)

    def forward(self, input):
        # A fresh Function instance per call: ScipyConv2dFunction keeps its
        # saved tensors on the instance (save_for_backward / saved_tensors).
        conv_fn = ScipyConv2dFunction()
        return conv_fn(input, self.filter)

In [5]:
module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
# Seed backward with a random gradient shaped like the output (8x8 here, since
# a 3x3 "valid" correlation of a 10x10 input gives 10 - 3 + 1 = 8) instead of
# hard-coding a size that silently depends on the input and filter sizes.
output.backward(torch.randn(output.size()))
print(input.grad)

[Parameter containing:
 0.2725 -0.2817  0.1157
-0.1808  0.3620 -0.0478
-0.9202  0.1223 -0.2714
[torch.FloatTensor of size 3x3]
]
Variable containing:
 0.1253  0.6888 -1.1041  0.2030  0.1879 -0.1882  0.0959  0.2377
-0.1341 -0.5954 -0.9443 -1.2079 -0.0058 -0.2759  1.3677  0.0510
 0.6825 -0.6854  0.0778  0.3036  0.2354 -1.0283  1.1369 -0.7911
-0.4420  1.1664 -0.8897  1.9818  2.2564  0.4604 -1.4295  0.5643
-0.5910 -0.4363  0.6603 -0.9424  1.0148  0.8080 -0.5390  1.4721
-0.1394 -1.2209 -0.3980  0.0518 -0.2115 -0.6316  0.4584  1.0215
 1.5436  1.4514  0.0699  2.6585  0.5521  0.7643  1.3757  1.4045
-0.1154 -2.0398  0.2820  0.1142 -2.1098 -0.1444 -1.2546  2.4765
[torch.FloatTensor of size 8x8]

Variable containing:
 0.1534  0.3503 -0.5697 -1.6099 -0.8139 -0.0340 -0.6373 -0.6907  0.4703 -0.6276
-0.2122 -0.4150  0.4637  1.0650  0.7771  0.5896  1.4773  0.1563  1.3614  0.3508
-0.1531  0.0284  1.1692  0.1714 -0.6412  1.5286  1.9874 -0.0075 -1.0116  0.4074
 0.2928 -0.4348 -0.6240  2.6402 -0.5040 -2.0