Skip to content

Commit

Permalink
Add THNN conversion of {SoftShrink, Sqrt, Square, Tanh, Threshold}
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 1, 2016
1 parent ccd562a commit 83a17b2
Show file tree
Hide file tree
Showing 12 changed files with 273 additions and 247 deletions.
13 changes: 11 additions & 2 deletions SoftShrink.lua
Expand Up @@ -6,11 +6,20 @@ function SoftShrink:__init(lam)
end

function SoftShrink:updateOutput(input)
   -- Forward pass: output = x - lambda for x > lambda, x + lambda for
   -- x < -lambda, and 0 otherwise. Dispatches to the THNN backend bound
   -- to the tensor type of `input`.
   input.THNN.SoftShrink_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.lambda
   )
   return self.output
end

function SoftShrink:updateGradInput(input, gradOutput)
   -- Backward pass: gradient passes through where |input| > lambda and is
   -- zero elsewhere. Computed by the THNN backend.
   input.THNN.SoftShrink_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.lambda
   )
   return self.gradInput
end
15 changes: 13 additions & 2 deletions Sqrt.lua
Expand Up @@ -7,9 +7,20 @@ end

function Sqrt:updateOutput(input)
   -- Default eps to 0 for modules deserialized from checkpoints that
   -- predate the eps field.
   self.eps = self.eps or 0
   -- Forward pass: output = sqrt(input), via the THNN backend.
   input.THNN.Sqrt_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.eps
   )
   return self.output
end

function Sqrt:updateGradInput(input, gradOutput)
   -- Backward pass: d/dx sqrt(x) = 1 / (2 * sqrt(x)), so the cached
   -- forward output is passed to the backend to reuse sqrt(input).
   input.THNN.Sqrt_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.output:cdata()
   )
   return self.gradInput
end
15 changes: 12 additions & 3 deletions Square.lua
@@ -1,13 +1,22 @@
-- Element-wise square module.
local Square, parent = torch.class('nn.Square', 'nn.Module')

function Square:__init(args)
   -- No parameters of its own; delegate to nn.Module's constructor.
   parent.__init(self)
end

function Square:updateOutput(input)
   -- Forward pass: output = input^2 (element-wise), via the THNN backend.
   input.THNN.Square_updateOutput(
      input:cdata(),
      self.output:cdata()
   )
   return self.output
end

function Square:updateGradInput(input, gradOutput)
   -- Backward pass: d/dx x^2 = 2x, computed by the THNN backend from the
   -- original input.
   input.THNN.Square_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata()
   )
   return self.gradInput
end
92 changes: 66 additions & 26 deletions THNN.lua
Expand Up @@ -153,7 +153,6 @@ TH_API void THNN_(LookupTable_accGradParameters)(
THTensor *sorted,
THTensor *indices);
TH_API void THNN_(MarginCriterion_updateOutput)(
THNNState *state,
THTensor *input,
Expand Down Expand Up @@ -199,7 +198,7 @@ TH_API void THNN_(MultiMarginCriterion_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *target,
THTensor* output,
THTensor *output,
bool sizeAverage,
int p);
TH_API void THNN_(MultiMarginCriterion_updateGradInput)(
Expand All @@ -224,10 +223,10 @@ TH_API void THNN_(PReLU_updateGradInput)(
THTensor *weight,
THIndex_t nOutputPlane);
TH_API void THNN_(PReLU_accGradParameters)(
THNNState* state,
THTensor* input,
THTensor* gradOutput,
THTensor* gradInput,
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *weight,
THTensor *gradWeight,
THTensor *gradWeightBuf,
Expand Down Expand Up @@ -306,13 +305,73 @@ TH_API void THNN_(SoftPlus_updateGradInput)(
real beta,
real threshold);
TH_API void THNN_(SoftShrink_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
real lambda);
TH_API void THNN_(SoftShrink_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
real lambda);
TH_API void THNN_(Sqrt_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
real eps);
TH_API void THNN_(Sqrt_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *output);
TH_API void THNN_(Square_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output);
TH_API void THNN_(Square_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput);
TH_API void THNN_(Tanh_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output);
TH_API void THNN_(Tanh_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
THTensor *output);
TH_API void THNN_(Threshold_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
real threshold,
real val,
bool inplace);
TH_API void THNN_(Threshold_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
real threshold,
bool inplace);
TH_API void THNN_(SpatialConvolutionMM_updateOutput)(
THNNState *state,
THTensor *input,
THTensor *output,
THTensor *weight,
THTensor *bias,
THTensor* finput,
THTensor *finput,
THTensor *fgradInput,
int kW, int kH,
int dW, int dH,
Expand Down Expand Up @@ -394,25 +453,6 @@ TH_API void THNN_(SpatialMaxPooling_updateGradInput)(
int dW, int dH,
int padW, int padH,
bool ceil_mode);
TH_API void THNN_(unfolded_acc)(
THTensor *finput,
THTensor *input,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int nInputPlane,
int inputWidth, int inputHeight,
int outputWidth, int outputHeight);
TH_API void THNN_(unfolded_copy)(
THTensor *finput,
THTensor *input,
int kW, int kH,
int dW, int dH,
int padW, int padH,
int nInputPlane,
int inputWidth, int inputHeight,
int outputWidth, int outputHeight);
]]

-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h
Expand Down
14 changes: 12 additions & 2 deletions Tanh.lua
@@ -1,9 +1,19 @@
-- Element-wise hyperbolic tangent module.
local Tanh = torch.class('nn.Tanh', 'nn.Module')

function Tanh:updateOutput(input)
   -- Forward pass: output = tanh(input), via the THNN backend.
   input.THNN.Tanh_updateOutput(
      input:cdata(),
      self.output:cdata()
   )
   return self.output
end

function Tanh:updateGradInput(input, gradOutput)
   -- Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, so the cached forward
   -- output is handed to the backend instead of recomputing tanh.
   input.THNN.Tanh_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.output:cdata()
   )
   return self.gradInput
end
16 changes: 14 additions & 2 deletions Threshold.lua
Expand Up @@ -17,13 +17,25 @@ end

function Threshold:updateOutput(input)
   -- Check threshold/val/inplace consistency before dispatching.
   self:validateParameters()
   -- Forward pass: output = input where input > threshold, else val.
   -- When self.inplace is true the backend writes into `input`.
   input.THNN.Threshold_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.threshold,
      self.val,
      self.inplace
   )
   return self.output
end

function Threshold:updateGradInput(input, gradOutput)
   -- Check threshold/val/inplace consistency before dispatching.
   self:validateParameters()
   -- Backward pass: gradient passes through where input > threshold,
   -- zero elsewhere. May reuse gradOutput's storage when inplace.
   input.THNN.Threshold_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.threshold,
      self.inplace
   )
   return self.gradInput
end

Expand Down
52 changes: 16 additions & 36 deletions lib/THNN/generic/SoftShrink.c
Expand Up @@ -2,49 +2,29 @@
#define TH_GENERIC_FILE "generic/SoftShrink.c"
#else

static int nn_(SoftShrink_updateOutput)(lua_State *L)
void THNN_(SoftShrink_updateOutput)(THNNState *state, THTensor *input, THTensor *output, real lambda)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
real lambda = luaT_getfieldchecknumber(L, 1, "lambda");
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

THTensor_(resizeAs)(output, input);

TH_TENSOR_APPLY2(real, output, real, input, \
if ((*input_data) > lambda) *output_data = *input_data - lambda; \
else if ((*input_data) < -lambda) *output_data = *input_data + lambda; \
else *output_data = 0;);
return 1;
TH_TENSOR_APPLY2(real, output, real, input,
if ((*input_data) > lambda)
*output_data = *input_data - lambda;
else if ((*input_data) < -lambda)
*output_data = *input_data + lambda;
else
*output_data = 0;
);
}

static int nn_(SoftShrink_updateGradInput)(lua_State *L)
void THNN_(SoftShrink_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, real lambda)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
real lambda = luaT_getfieldchecknumber(L, 1, "lambda");
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

THTensor_(resizeAs)(gradInput, input);
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
if ((*input_data) > lambda || (*input_data) < -lambda) \
*gradInput_data = (*gradOutput_data); \
else \
*gradInput_data = 0; \
);
return 1;
}

static const struct luaL_Reg nn_(SoftShrink__) [] = {
{"SoftShrink_updateOutput", nn_(SoftShrink_updateOutput)},
{"SoftShrink_updateGradInput", nn_(SoftShrink_updateGradInput)},
{NULL, NULL}
};

static void nn_(SoftShrink_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nn_(SoftShrink__), "nn");
lua_pop(L,1);
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
if ((*input_data) > lambda || (*input_data) < -lambda)
*gradInput_data = (*gradOutput_data);
else
*gradInput_data = 0;
);
}

#endif
47 changes: 12 additions & 35 deletions lib/THNN/generic/Sqrt.c
Expand Up @@ -2,63 +2,40 @@
#define TH_GENERIC_FILE "generic/Sqrt.c"
#else

static int nn_(Sqrt_updateOutput)(lua_State *L)
void THNN_(Sqrt_updateOutput)(THNNState *state, THTensor *input, THTensor *output, real eps)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
real bias = luaT_getfieldchecknumber(L,1,"eps");
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

THTensor_(resizeAs)(output, input);
THTensor_(sqrt)(output, input);
return 1;
}

static int nn_(Sqrt_updateGradInput)(lua_State *L)
void THNN_(Sqrt_updateGradInput)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *output)
{
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

THTensor_(resizeAs)(gradInput, input);

if (output->nDimension == 1 ||
!THTensor_(isContiguous)(output) ||
!THTensor_(isContiguous)(gradOutput) ||
!THTensor_(isContiguous)(gradInput))
{
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output, \
*gradInput_data = ((*output_data == 0.0) ? 0.0 : \
(0.5 * (*gradOutput_data / *output_data))););
TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
*gradInput_data = (*output_data == 0.0) ? 0.0 : (0.5 * (*gradOutput_data / *output_data));
);
}
else
{
real* gradOutput_data = THTensor_(data)(gradOutput);
real* gradInput_data = THTensor_(data)(gradInput);
real* output_data = THTensor_(data)(output);
real *gradOutput_data = THTensor_(data)(gradOutput);
real *gradInput_data = THTensor_(data)(gradInput);
real *output_data = THTensor_(data)(output);
long i;
#pragma omp parallel for private(i)
for(i = 0; i < THTensor_(nElement)(output); i++)
if (output_data[i] == 0.0) {
{
if (output_data[i] == 0.0)
gradInput_data[i] = 0.0;
} else {
else
gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]);
}
}
}
return 1;
}

static const struct luaL_Reg nn_(Sqrt__) [] = {
{"Sqrt_updateOutput", nn_(Sqrt_updateOutput)},
{"Sqrt_updateGradInput", nn_(Sqrt_updateGradInput)},
{NULL, NULL}
};

static void nn_(Sqrt_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nn_(Sqrt__), "nn");
lua_pop(L,1);
}

#endif

0 comments on commit 83a17b2

Please sign in to comment.