forked from hughperkins/clnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ported SpatialUpSamplingNearest from cunn
- Loading branch information
Showing
8 changed files
with
388 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
require 'nn' | ||
|
||
-- Keep references to the stock nn implementations so the ClTensor-aware
-- overrides defined below can delegate to them for every other tensor type.
-- The `or` guard makes this idempotent: if this file is executed a second
-- time, the slots already hold the stock methods and must not be overwritten
-- with the overrides (which would make the fallback path call itself forever).
nn.SpatialUpSamplingNearest.baseUpdateOutput =
   nn.SpatialUpSamplingNearest.baseUpdateOutput
   or nn.SpatialUpSamplingNearest.updateOutput
nn.SpatialUpSamplingNearest.baseUpdateGradInput =
   nn.SpatialUpSamplingNearest.baseUpdateGradInput
   or nn.SpatialUpSamplingNearest.updateGradInput
|
||
--- Forward pass of nearest-neighbour upsampling.
-- Inputs that are not `torch.ClTensor` are handed to the stock nn
-- implementation saved in `baseUpdateOutput`. ClTensor inputs have their last
-- two dimensions multiplied by `self.scale_factor` and are processed by the
-- THNN backend.
-- @param input 3D or 4D tensor
-- @return `self.output`, resized to the upsampled shape
function nn.SpatialUpSamplingNearest:updateOutput(input)
   if torch.type(input) ~= 'torch.ClTensor' then
      -- The base method takes only `input`; passing the undefined global
      -- `target` here was copy-paste residue.
      return self:baseUpdateOutput(input)
   end

   local ndim = input:dim()
   if ndim ~= 4 and ndim ~= 3 then
      error('SpatialUpSamplingNearest only support 3D or 4D tensors')
   end

   -- Record the input size and start the output size from the same values.
   -- NOTE(review): inputSize/outputSize are assumed to be created by the
   -- module constructor, which is not visible here.
   for i = 1, ndim do
      local size = input:size(i)
      self.inputSize[i] = size
      self.outputSize[i] = size
   end

   -- Only the last two dimensions are scaled.
   local ydim, xdim = ndim - 1, ndim
   self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
   self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor

   -- Resize the output if needed. For 3D input only the first three entries
   -- of outputSize were written above, so they are passed explicitly.
   if ndim == 3 then
      self.output:resize(self.outputSize[1], self.outputSize[2],
                         self.outputSize[3])
   else
      self.output:resize(self.outputSize)
   end

   input.THNN.SpatialUpSamplingNearest_updateOutput(
      input:cdata(), self.output:cdata(), self.scale_factor)

   return self.output
end
|
||
--- Backward pass of nearest-neighbour upsampling.
-- Inputs that are not `torch.ClTensor` are handed to the stock nn
-- implementation saved in `baseUpdateGradInput`; ClTensor inputs are processed
-- by the THNN backend after `self.gradInput` is resized to match `input`.
-- @param input the tensor that was given to updateOutput
-- @param gradOutput gradient with respect to the module output
-- @return `self.gradInput`
function nn.SpatialUpSamplingNearest:updateGradInput(input, gradOutput)
   local isClTensor = torch.type(input) == 'torch.ClTensor'
   if not isClTensor then
      return self:baseUpdateGradInput(input, gradOutput)
   end

   -- NOTE(review): resizeAs does not clear the buffer; whether the backend
   -- zeroes it before accumulating is not visible from this file.
   self.gradInput:resizeAs(input)
   input.THNN.SpatialUpSamplingNearest_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.scale_factor)

   return self.gradInput
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,4 @@ include 'CMulTable.lua' | |
|
||
include 'test.lua' | ||
|
||
include 'SpatialUpSamplingNearest.lua' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// from SpatialUpSamplingNearest.cu: | ||
|
||
/*
 * Maps a flat index into the upscaled 4-D array (trailing dimensions
 * d1, d2, d3) onto the flat index of the matching element in the source
 * array, whose last two dimensions are d2/scale_factor and d3/scale_factor.
 * The two innermost coordinates are integer-divided by scale_factor; the two
 * outer coordinates are carried over unchanged.
 */
/*__device__*/ int translate_idx(int ii, int d1, int d2, int d3, int scale_factor)
{
  /* Peel the flat index apart, innermost coordinate first. */
  const int c3 = ii % d3;
  const int rest3 = ii / d3;
  const int c2 = rest3 % d2;
  const int rest2 = rest3 / d2;
  const int c1 = rest2 % d1;
  const int c0 = rest2 / d1;

  /* Extents of the two innermost dimensions in the smaller array. */
  const int small_d2 = d2 / scale_factor;
  const int small_d3 = d3 / scale_factor;

  return (((c0 * d1 + c1) * small_d2) + c2 / scale_factor) * small_d3
         + c3 / scale_factor;
}
/*
 * Inverse of the mapping above: takes a flat index into the smaller 4-D array
 * (trailing dimensions d1, d2, d3) and returns the flat index of one element
 * in the enlarged array, whose last two dimensions are d2*scale_factor and
 * d3*scale_factor. off_x is added to the scaled innermost coordinate and
 * off_y to the scaled second-innermost coordinate.
 */
/*__device__*/ int translate_idx_inv(int ii, int d1, int d2, int d3, int scale_factor, int off_x, int off_y)
{
  /* Peel the flat index apart, innermost coordinate first. */
  const int c3 = ii % d3;
  const int rest3 = ii / d3;
  const int c2 = rest3 % d2;
  const int rest2 = rest3 / d2;
  const int c1 = rest2 % d1;
  const int c0 = rest2 / d1;

  /* Coordinates and extents of the two innermost dimensions after scaling. */
  const int big_c3 = c3 * scale_factor + off_x;
  const int big_c2 = c2 * scale_factor + off_y;
  const int big_d2 = d2 * scale_factor;
  const int big_d3 = d3 * scale_factor;

  return (((c0 * d1 + c1) * big_d2) + big_c2) * big_d3 + big_c3;
}
|
||
/*
 * Each work-item computes one flat output index and, if it is below
 * no_elements, copies the input element selected by translate_idx into that
 * output slot. input_offset / output_offset are element offsets added to the
 * respective buffer pointers before indexing.
 */
kernel void upscale(global float *input_data, int input_offset, global float *output_data, int output_offset, long no_elements,
    int scale_factor, int d1, int d2, int d3)
{
  global float *src = input_data + input_offset;
  global float *dst = output_data + output_offset;

  // Flat output index of this work-item, built from dimensions 0 and 1 of
  // the launch grid.
  long out_idx = get_local_id(0) + get_local_size(0) * get_group_id(0);
  out_idx += get_local_id(1) + get_local_size(1) * (get_local_size(0) * get_num_groups(0)) * get_group_id(1);

  // Work-items past the end of the output have nothing to do.
  if (out_idx >= no_elements) {
    return;
  }

  const int src_idx = translate_idx(out_idx, d1, d2, d3, scale_factor);
  dst[out_idx] = src[src_idx];
}
|
||
/*
 * Description:
 *   Each work-item computes one flat gradInput index and, if it is below
 *   no_elements, stores there the sum of the scale_factor * scale_factor
 *   gradOutput elements selected by translate_idx_inv for every (i, j)
 *   offset pair. The *_offset arguments are element offsets added to the
 *   respective buffer pointers before indexing.
 *
 *   The sum is built in a private variable and written once, so the result
 *   does not depend on gradInput having been zeroed beforehand. Each
 *   work-item writes only its own slot gradInput_data[ii].
 */
kernel void downscale(global float *gradInput_data_data, int gradInput_data_offset, global float *gradOutput_data_data, int gradOutput_data_offset, long no_elements,
    int scale_factor, int d1, int d2, int d3)
{
  global float *gradInput_data = gradInput_data_data + gradInput_data_offset;
  global float *gradOutput_data = gradOutput_data_data + gradOutput_data_offset;

  // Flat gradInput index of this work-item, built from dimensions 0 and 1 of
  // the launch grid.
  long ii = get_local_id(0) + get_local_size(0) * get_group_id(0);
  ii += get_local_id(1) + get_local_size(1) * (get_local_size(0) * get_num_groups(0)) * get_group_id(1);
  if (ii >= no_elements) return;

  // Accumulate privately instead of doing a global read-modify-write per term.
  float sum = 0.0f;
  for (int i = 0; i < scale_factor; i++) {
    for (int j = 0; j < scale_factor; j++) {
      int ipidx = translate_idx_inv(ii, d1, d2, d3, scale_factor, i, j);
      sum += gradOutput_data[ipidx];
    }
  }
  gradInput_data[ii] = sum;
}
Oops, something went wrong.