Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes for cutorch API changes #296

Merged
merged 1 commit into from
Jun 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/THCUNN/Abs.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void THNN_CudaAbs_updateOutput(THCState *state, THCudaTensor *input, THCudaTenso
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, absupdateOutput_functor());
THC_pointwiseApply2(state, output, input, absupdateOutput_functor());
}

struct absupdateGradInput_functor
Expand All @@ -28,5 +28,5 @@ void THNN_CudaAbs_updateGradInput(THCState *state, THCudaTensor *input, THCudaTe
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, absupdateGradInput_functor());
THC_pointwiseApply3(state, gradInput, input, gradOutput, absupdateGradInput_functor());
}
8 changes: 4 additions & 4 deletions lib/THCUNN/ELU.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ void THNN_CudaELU_updateOutput(THCState *state, THCudaTensor *input, THCudaTenso

if (inplace)
{
THCudaTensor_pointwiseApply1(state, input, ELUupdateOutputIP_functor(alpha));
THC_pointwiseApply1(state, input, ELUupdateOutputIP_functor(alpha));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, ELUupdateOutput_functor(alpha));
THC_pointwiseApply2(state, output, input, ELUupdateOutput_functor(alpha));
}
}

Expand Down Expand Up @@ -82,12 +82,12 @@ void THNN_CudaELU_updateGradInput(THCState *state, THCudaTensor *input, THCudaTe

if (inplace)
{
THCudaTensor_pointwiseApply2(state, gradOutput, output, ELUupdateGradInputIP_functor(alpha));
THC_pointwiseApply2(state, gradOutput, output, ELUupdateGradInputIP_functor(alpha));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, output);
THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, ELUupdateGradInput_functor(alpha));
THC_pointwiseApply3(state, gradInput, output, gradOutput, ELUupdateGradInput_functor(alpha));
}
}
4 changes: 2 additions & 2 deletions lib/THCUNN/HardTanh.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void THNN_CudaHardTanh_updateOutput(THCState *state, THCudaTensor *input, THCuda
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input,
THC_pointwiseApply2(state, output, input,
hardtanhupdateOutput_functor(min_val, max_val));
}

Expand Down Expand Up @@ -54,6 +54,6 @@ void THNN_CudaHardTanh_updateGradInput(THCState *state, THCudaTensor *input, THC
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);

THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
THC_pointwiseApply3(state, gradInput, input, gradOutput,
hardtanhupdateGradInput_functor(min_val, max_val));
}
8 changes: 4 additions & 4 deletions lib/THCUNN/LeakyReLU.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ void THNN_CudaLeakyReLU_updateOutput(THCState *state, THCudaTensor *input, THCud

if (inplace)
{
THCudaTensor_pointwiseApply1(state, input, LeakyReLUUpdateOutputIP(negval));
THC_pointwiseApply1(state, input, LeakyReLUUpdateOutputIP(negval));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, LeakyReLUUpdateOutput(negval));
THC_pointwiseApply2(state, output, input, LeakyReLUUpdateOutput(negval));
}

THCudaCheck(cudaGetLastError());
Expand Down Expand Up @@ -90,13 +90,13 @@ void THNN_CudaLeakyReLU_updateGradInput(THCState *state, THCudaTensor *input, TH

if (inplace)
{
THCudaTensor_pointwiseApply2(state, gradOutput, input, LeakyReLUUpdateGradInputIP(negval));
THC_pointwiseApply2(state, gradOutput, input, LeakyReLUUpdateGradInputIP(negval));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, LeakyReLUUpdateGradInput(negval));
THC_pointwiseApply3(state, gradInput, input, gradOutput, LeakyReLUUpdateGradInput(negval));
}

THCudaCheck(cudaGetLastError());
Expand Down
4 changes: 2 additions & 2 deletions lib/THCUNN/LogSigmoid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void THNN_CudaLogSigmoid_updateOutput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor());
THC_pointwiseApply2(state, output, input, logSigmoid_updateOutput_functor());
}

struct logSigmoid_updateGradInput_functor
Expand All @@ -31,5 +31,5 @@ void THNN_CudaLogSigmoid_updateGradInput(THCState *state, THCudaTensor *input, T
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor());
THC_pointwiseApply3(state, gradInput, input, gradOutput, logSigmoid_updateGradInput_functor());
}
10 changes: 5 additions & 5 deletions lib/THCUNN/PReLU.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void THNN_CudaPReLU_updateOutput(

if (nOutputPlane == 0)
{
THCudaTensor_pointwiseApply2(state, output, input, PReLUUpdateOutput(w));
THC_pointwiseApply2(state, output, input, PReLUUpdateOutput(w));
}
else
{
Expand Down Expand Up @@ -109,7 +109,7 @@ void THNN_CudaPReLU_updateGradInput(
float *w = THCudaTensor_data(state, weight);
if (nOutputPlane == 0)
{
THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, PReLUUpdateGradInput(w));
THC_pointwiseApply3(state, gradInput, gradOutput, input, PReLUUpdateGradInput(w));
}
else
{
Expand Down Expand Up @@ -189,7 +189,7 @@ void THNN_CudaPReLU_accGradParameters(

if (nOutputPlane == 0)
{
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParametersShared());
THC_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParametersShared());

// introduces a sync point
float sum = THCudaTensor_sumall(state, gradInput);
Expand All @@ -205,11 +205,11 @@ void THNN_CudaPReLU_accGradParameters(

if (ndim == 1)
{
THCudaTensor_pointwiseApply3(state, gradWeight, input, gradOutput, PReLUAccGradParameters1to1(scale));
THC_pointwiseApply3(state, gradWeight, input, gradOutput, PReLUAccGradParameters1to1(scale));
}
else
{
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParameters(scale));
THC_pointwiseApply3(state, gradInput, input, gradOutput, PReLUAccGradParameters(scale));
THCudaTensor *sumbuf = gradWeightBuf2;
THCudaTensor_resizeAs(state, gradWeightBuf, gradWeight);

Expand Down
8 changes: 4 additions & 4 deletions lib/THCUNN/RReLU.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ void THNN_CudaRReLU_updateOutput(THCState *state, THCudaTensor *input, THCudaTen
const double negSlope = (lower + upper) / 2;
if (inplace)
{
THCudaTensor_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
THC_pointwiseApply1(state, input, RReLUUpdateOutputEvalIP_functor(negSlope));
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
THC_pointwiseApply2(state, output, input, RReLUUpdateOutputEval_functor(negSlope));
}
}
}
Expand Down Expand Up @@ -169,13 +169,13 @@ void THNN_CudaRReLU_updateGradInput(THCState *state, THCudaTensor *input, THCuda
const double negSlope = (lower + upper) / 2;
if (inplace)
{
THCudaTensor_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
THC_pointwiseApply2(state, gradOutput, input, RReLUupdateGradInputEvalIP_functor(negSlope));
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
THC_pointwiseApply3(state, gradInput, gradOutput, input, RReLUupdateGradInputEval_functor(negSlope));
}
}

Expand Down
4 changes: 2 additions & 2 deletions lib/THCUNN/Sigmoid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void THNN_CudaSigmoid_updateOutput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
THC_pointwiseApply2(state, output, input, sigmoidupdateOutput_functor());
}

struct sigmoidupdateGradInput_functor
Expand All @@ -28,5 +28,5 @@ void THNN_CudaSigmoid_updateGradInput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
THC_pointwiseApply3(state, gradInput, output, gradOutput, sigmoidupdateGradInput_functor());
}
4 changes: 2 additions & 2 deletions lib/THCUNN/SoftPlus.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void THNN_CudaSoftPlus_updateOutput(THCState *state, THCudaTensor *input, THCuda
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
THC_pointwiseApply2(state, output, input, softPlusupdateOutput_functor(threshold, beta));
}

struct softPlusupdateGradInput_functor
Expand All @@ -48,5 +48,5 @@ void THNN_CudaSoftPlus_updateGradInput(THCState *state, THCudaTensor *input, THC
{
THCUNN_assertSameGPU(state, 4, input, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
THC_pointwiseApply3(state, gradInput, output, gradOutput, softPlusupdateGradInput_functor(threshold, beta));
}
4 changes: 2 additions & 2 deletions lib/THCUNN/SoftShrink.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void THNN_CudaSoftShrink_updateOutput(THCState *state, THCudaTensor *input, THCu
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
THC_pointwiseApply2(state, output, input, SoftShrinkUpdateOutput(lambda));
THCudaCheck(cudaGetLastError());
}

Expand All @@ -49,6 +49,6 @@ void THNN_CudaSoftShrink_updateGradInput(THCState *state, THCudaTensor *input, T
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
THC_pointwiseApply3(state, gradInput, input, gradOutput, SoftShrinkUpdateGradInput(lambda));
THCudaCheck(cudaGetLastError());
}
6 changes: 3 additions & 3 deletions lib/THCUNN/SpatialReflectionPadding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void THNN_CudaSpatialReflectionPadding_updateOutput(THCState *state,
int padL, int padR,
int padT, int padB
) {
THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");

int planeDim = 0;
Expand Down Expand Up @@ -139,9 +139,9 @@ void THNN_CudaSpatialReflectionPadding_updateGradInput(THCState *state,
int padL, int padR,
int padT, int padB) {

THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
THArgCheck(THC_canUse32BitIndexMath(state, gradOutput), 3,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, gradOutput), 3,
"output gradient tensor must fit into 32-bit index math");

int planeDim = 0;
Expand Down
6 changes: 3 additions & 3 deletions lib/THCUNN/SpatialReplicationPadding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void THNN_CudaSpatialReplicationPadding_updateOutput(THCState *state,
int padL, int padR,
int padT, int padB
) {
THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");

int planeDim = 0;
Expand Down Expand Up @@ -121,9 +121,9 @@ void THNN_CudaSpatialReplicationPadding_updateGradInput(THCState *state,
int padL, int padR,
int padT, int padB) {

THArgCheck(THC_canUse32BitIndexMath(state, input), 2,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, input), 2,
"input tensor must fit into 32-bit index math");
THArgCheck(THC_canUse32BitIndexMath(state, gradOutput), 3,
THArgCheck(TensorUtils<THCudaTensor>::canUse32BitIndexMath(state, gradOutput), 3,
"output gradient tensor must fit into 32-bit index math");

int planeDim = 0;
Expand Down
4 changes: 2 additions & 2 deletions lib/THCUNN/Sqrt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void THNN_CudaSqrt_updateOutput(THCState *state, THCudaTensor *input, THCudaTens
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(eps));
THC_pointwiseApply2(state, output, input, sqrtupdateOutput_functor(eps));
}

struct sqrtupdateGradInput_functor
Expand All @@ -36,5 +36,5 @@ void THNN_CudaSqrt_updateGradInput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
THC_pointwiseApply3(state, gradInput, output, gradOutput, sqrtupdateGradInput_functor());
}
4 changes: 2 additions & 2 deletions lib/THCUNN/Square.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void THNN_CudaSquare_updateOutput(THCState *state, THCudaTensor *input, THCudaTe
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, squareupdateOutput_functor());
THC_pointwiseApply2(state, output, input, squareupdateOutput_functor());
}

struct squareupdateGradInput_functor
Expand All @@ -28,5 +28,5 @@ void THNN_CudaSquare_updateGradInput(THCState *state, THCudaTensor *input, THCud
{
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
THC_pointwiseApply3(state, gradInput, input, gradOutput, squareupdateGradInput_functor());
}
4 changes: 2 additions & 2 deletions lib/THCUNN/Tanh.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void THNN_CudaTanh_updateOutput(THCState *state, THCudaTensor *input, THCudaTens
{
THCUNN_assertSameGPU(state, 2, input, output);
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
THC_pointwiseApply2(state, output, input, tanhupdateOutput_functor());
}

struct tanhupdateGradInput_functor
Expand All @@ -28,5 +28,5 @@ void THNN_CudaTanh_updateGradInput(THCState *state, THCudaTensor *input, THCudaT
{
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCudaTensor_resizeAs(state, gradInput, output);
THCudaTensor_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
THC_pointwiseApply3(state, gradInput, output, gradOutput, tanhupdateGradInput_functor());
}
8 changes: 4 additions & 4 deletions lib/THCUNN/Threshold.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ void THNN_CudaThreshold_updateOutput(THCState *state, THCudaTensor *input, THCud

if (inplace)
{
THCudaTensor_pointwiseApply1(state, input,
THC_pointwiseApply1(state, input,
ThresholdUpdateOutputIP(threshold, val)
);
THCudaTensor_set(state, output, input);
}
else
{
THCudaTensor_resizeAs(state, output, input);
THCudaTensor_pointwiseApply2(state, output, input,
THC_pointwiseApply2(state, output, input,
ThresholdUpdateOutput(threshold, val)
);
}
Expand Down Expand Up @@ -95,15 +95,15 @@ void THNN_CudaThreshold_updateGradInput(THCState *state, THCudaTensor *input, TH

if (inplace)
{
THCudaTensor_pointwiseApply2(state, gradOutput, input,
THC_pointwiseApply2(state, gradOutput, input,
ThresholdUpdateGradInputIP(threshold)
);
THCudaTensor_set(state, gradInput, gradOutput);
}
else
{
THCudaTensor_resizeAs(state, gradInput, input);
THCudaTensor_pointwiseApply3(state, gradInput, input, gradOutput,
THC_pointwiseApply3(state, gradInput, input, gradOutput,
ThresholdUpdateGradInput(threshold)
);
}
Expand Down