From 0651409a7265828ba6907f1c42a68f56b4ed7606 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Tue, 13 Feb 2018 16:42:14 +0100 Subject: [PATCH] Use foo.new instead of torch.FloatTensor for GPU. This replaces the calls to `torch.FloatTensor` by a call to `.new` on the input tensor, such that GPU types are respected. --- advanced_source/numpy_extensions_tutorial.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/advanced_source/numpy_extensions_tutorial.py b/advanced_source/numpy_extensions_tutorial.py index 93016d60219..e14dec66a19 100644 --- a/advanced_source/numpy_extensions_tutorial.py +++ b/advanced_source/numpy_extensions_tutorial.py @@ -38,12 +38,12 @@ class BadFFTFunction(Function): def forward(self, input): numpy_input = input.numpy() result = abs(rfft2(numpy_input)) - return torch.FloatTensor(result) + return input.new(result) def backward(self, grad_output): numpy_go = grad_output.numpy() result = irfft2(numpy_go) - return torch.FloatTensor(result) + return grad_output.new(result) # since this layer does not have any parameters, we can # simply declare this as a function, rather than as an nn.Module class @@ -90,7 +90,7 @@ class ScipyConv2dFunction(Function): def forward(ctx, input, filter): result = correlate2d(input.numpy(), filter.numpy(), mode='valid') ctx.save_for_backward(input, filter) - return torch.FloatTensor(result) + return input.new(result) @staticmethod def backward(ctx, grad_output): @@ -99,8 +99,8 @@ def backward(ctx, grad_output): grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full') grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid') - return Variable(torch.FloatTensor(grad_input)), \ - Variable(torch.FloatTensor(grad_filter)) + return Variable(grad_output.new(grad_input)), \ + Variable(grad_output.new(grad_filter)) class ScipyConv2d(Module):