@@ -38,12 +38,12 @@ class BadFFTFunction(Function):
3838 def forward (self , input ):
3939 numpy_input = input .numpy ()
4040 result = abs (rfft2 (numpy_input ))
41- return torch . FloatTensor (result )
41+ return input . new (result )
4242
4343 def backward (self , grad_output ):
4444 numpy_go = grad_output .numpy ()
4545 result = irfft2 (numpy_go )
46- return torch . FloatTensor (result )
46+ return grad_output . new (result )
4747
4848# since this layer does not have any parameters, we can
4949# simply declare this as a function, rather than as an nn.Module class
@@ -90,7 +90,7 @@ class ScipyConv2dFunction(Function):
9090 def forward (ctx , input , filter ):
9191 result = correlate2d (input .numpy (), filter .numpy (), mode = 'valid' )
9292 ctx .save_for_backward (input , filter )
93- return torch . FloatTensor (result )
93+ return input . new (result )
9494
9595 @staticmethod
9696 def backward (ctx , grad_output ):
@@ -99,8 +99,8 @@ def backward(ctx, grad_output):
9999 grad_input = convolve2d (grad_output .numpy (), filter .t ().numpy (), mode = 'full' )
100100 grad_filter = convolve2d (input .numpy (), grad_output .numpy (), mode = 'valid' )
101101
102- return Variable (torch . FloatTensor (grad_input )), \
103- Variable (torch . FloatTensor (grad_filter ))
102+ return Variable (grad_output . new (grad_input )), \
103+ Variable (grad_output . new (grad_filter ))
104104
105105
106106class ScipyConv2d (Module ):
0 commit comments