In [92]:
import torch
from torch.autograd import Function

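# Parameter-less example

A `Function` with no learnable parameters. `BadFFTFunction` is intentionally not mathematically correct: its `backward` is not the true gradient of its `forward`. It only illustrates how NumPy code can be wrapped in an autograd `Function`.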

In [93]:
from numpy.fft import rfft2, irfft2
class BadFFTFunction(Function):
    """Legacy (instance-style) autograd Function: magnitude of the 2-D real FFT.

    Deliberately *not* a correct autograd op. ``backward`` applies the inverse
    real FFT to the incoming gradient and ignores the ``abs`` taken in
    ``forward``, so the value it returns is not the true gradient.
    Note also that ``irfft2`` with its default output size yields an even last
    dimension, so an input with an odd width would not round-trip in shape.
    """

    def forward(self, input):
        # Complex half-spectrum of the input; only its magnitude is kept.
        spectrum = rfft2(input.numpy())
        return torch.FloatTensor(abs(spectrum))

    def backward(self, grad_output):
        # Send the gradient back to the spatial domain via the inverse real FFT.
        spatial = irfft2(grad_output.numpy())
        return torch.FloatTensor(spatial)

def incorrect_fft(input):
    """Apply ``BadFFTFunction`` to ``input`` using the legacy Function call style.

    Returns the magnitude of ``rfft2(input)``. Gradients that flow back through
    the result come from ``BadFFTFunction.backward``, which is intentionally
    not the true gradient.
    """
    # The Function defined in this cell is BadFFTFunction; the previous name
    # `FFTFunction` is not defined and raised NameError on a fresh kernel.
    return BadFFTFunction()(input)

In [94]:
# `Variable` was used here but never imported, so this cell failed under
# Restart Kernel -> Run All.
from torch.autograd import Variable

# Run the FFT-magnitude Function on a random 8x8 input, then backpropagate a
# random upstream gradient shaped like the output.
input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)


  3.0878   7.1403   7.5860   1.7596   3.0176
  6.3160  15.2517  11.1081   0.9172   6.8577
  8.6503   2.2013   6.3555  11.1981   1.9266
  3.9919   6.8862   8.8132   5.7938   4.2413
 12.2501  10.7839   6.7181  12.1096   1.1942
  3.9919   9.3072   2.6704   3.3263   4.2413
  8.6503   6.8158  12.4148   2.6462   1.9266
  6.3160  15.2663   9.8261   5.8583   6.8577
[torch.FloatTensor of size 8x5]


 0.0569 -0.3193  0.0401  0.1293  0.0318  0.1293  0.0401 -0.3193
 0.0570  0.0161 -0.0421 -0.1272  0.0414  0.0121 -0.0592 -0.0874
-0.1144 -0.0146  0.0604 -0.0023  0.0222  0.0622  0.0825 -0.1057
-0.0451  0.1061  0.0329 -0.0274  0.0302 -0.0347  0.0227 -0.1079
 0.1287  0.1796 -0.0766 -0.0698  0.0929 -0.0698 -0.0766  0.1796
-0.0451 -0.1079  0.0227 -0.0347  0.0302 -0.0274  0.0329  0.1061
-0.1144 -0.1057  0.0825  0.0622  0.0222 -0.0023  0.0604 -0.0146
 0.0570 -0.0874 -0.0592  0.0121  0.0414 -0.1272 -0.0421  0.0161
[torch.FloatTensor of size 8x8]



# Parametrized example

In [95]:
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module

class ScipyConv2dFunction(Function):
    """2-D 'valid' cross-correlation of ``input`` with ``filter``, computed by SciPy.

    Legacy (instance-style) autograd Function. Forward computes
        out[i, j] = sum_{a, b} input[i + a, j + b] * filter[a, b]
    so an (H, W) input and a (kh, kw) filter give an (H - kh + 1, W - kw + 1)
    output.
    """

    def forward(self, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        # Both operands are needed to form the gradients in backward.
        self.save_for_backward(input, filter)
        return torch.FloatTensor(result)

    def backward(self, grad_output):
        """Return ``(grad_input, grad_filter)`` for the forward correlation.

        grad_input  = full convolution of grad_output with the filter, shape (H, W).
                      convolve2d's kernel flip is exactly what undoes the
                      correlation's unflipped indexing; transposing the
                      filter instead is wrong for non-symmetric filters.
        grad_filter = 'valid' correlation of the input with grad_output, shape
                      (kh, kw). A convolution here would flip the result.
        """
        input, filter = self.saved_tensors
        grad_np = grad_output.numpy()
        grad_input = convolve2d(grad_np, filter.numpy(), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_np, mode='valid')
        return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)


class ScipyConv2d(Module):
    """Module holding a learnable (kh, kw) filter and applying ScipyConv2dFunction.

    Uses an older ``Module.__init__(**parameters)`` form. The keyword argument
    becomes a module parameter, as the ``module.parameters()`` printout in the
    next cell shows (a single 3x3 Variable).
    NOTE(review): newer torch versions may not accept keyword arguments here;
    confirm the targeted torch version.
    """

    def __init__(self, kh, kw):
        # Randomly initialised filter of shape (kh, kw).
        initial_filter = torch.randn(kh, kw)
        super(ScipyConv2d, self).__init__(filter=initial_filter)

    def forward(self, input):
        # Legacy Function API: instantiate, then call with (input, filter).
        conv = ScipyConv2dFunction()
        return conv(input, self.filter)

In [96]:
# `Variable` was used here but never imported, so this cell failed under
# Restart Kernel -> Run All.
from torch.autograd import Variable

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
# The upstream gradient must match the output's shape. Derive it from the
# output instead of hard-coding 8x8 (10 - 3 + 1 per side for a 'valid'
# correlation), so changing the input or filter size doesn't break this cell.
output.backward(torch.randn(output.size()))
print(input.grad)


[Variable containing:
-1.5070  1.2195  0.3059
-0.9716 -1.6591  0.0582
 0.3959  1.4859  0.5762
[torch.FloatTensor of size 3x3]
]
Variable containing:
  0.8031  -2.6673  -3.7764   0.3957  -3.7494  -1.7617  -1.0052  -5.8402
  1.3038   6.2255   3.8769   2.4016  -1.7805  -3.1314   4.7049  11.2956
 -3.4491   0.1618  -2.5647   2.3304  -0.2030   0.9072  -3.5095  -1.4599
  1.7574   0.6292   0.5140  -0.9045  -0.7373  -1.2061  -2.2977   3.6035
  0.4435  -1.0651  -0.5496   0.6387   1.7522   4.5231  -0.5720  -3.3034
 -0.8580  -0.4809   2.4041   7.1462  -6.4747  -5.3665   2.0541   4.8248
 -3.3959   0.2333  -0.2029  -2.6130   2.9378   2.5276  -0.8665  -2.6157
  4.6814  -5.2214   5.0351   0.9138  -5.0147  -3.1597   1.9054  -1.2458
[torch.FloatTensor of size 8x8]


 0.1741 -1.9989 -0.2740  3.8120  0.3502  0.6712  3.0274  1.7058  0.4150 -0.3298
-1.8919 -2.6355 -3.2564  3.6947  2.5255 -6.7857  0.2239 -1.5672 -0.2663 -1.1211
 2.8815  2.5121 -4.7712  3.5822 -4.3752  0.7339 -0.7228 -1.7776 -2.0243  0.5019
-