Skip to content

implement double backwards for MaxPool3d #5328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 8, 2018
Merged

Conversation

li-roy
Copy link
Contributor

@li-roy li-roy commented Feb 21, 2018

fixes #4497.

test plan: added double backwards check for maxpool3d tests

@zou3519
Copy link
Contributor

zou3519 commented Mar 2, 2018

Could you run some benchmarks to see if the memory usage increase from changing MaxUnpool3d to not pack indices is significant? I think someone (@apaszke?) mentioned Resnet as a network that uses MaxUnpool3d; it would be nice to see if the memory usage increases significantly or not under resnet.

edit: Ignore this comment, the memory usage is the same

@apaszke
Copy link
Contributor

apaszke commented Mar 3, 2018

Oh, resnets use the 2d version, and I thought we’re modifying that. Will this make 3d consistent with the other pooling layers?

@li-roy
Copy link
Contributor Author

li-roy commented Mar 5, 2018

Yeah this makes it more consistent. Also, I don't see any reason for this to use more memory than the old implementation. The way indices are being stored is changed, but we're still using the same number of bits in the tensor.

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The math looks good to me. I had some minor comments.

THIndex_t idx = k*oT*oW*oH + oH*oW*(start_t+maxz) + oW*(start_h+maxy) + (start_w+maxx);
if (start_t+maxz<0 || start_h+maxy<0 || start_w+maxx<0 || start_t+maxz>=oT
|| start_h+maxy>=oH || start_w+maxx>=oW)
maxp = ind_p_k[ti * iH * iW + i * iW + j] - TH_INDEX_BASE; /* retrieve position of max */

This comment was marked as off-topic.

THIndex_t *ind_p_k = ind_p + k * iT * iH * iW;

int ti, i, j;
THIndex_t maxp;
for (ti = 0; ti < iT; ti++)

This comment was marked as off-topic.

THIndex_t *ind_p_k = ind_p + k * iT * iH * iW;

int ti, i, j;
THIndex_t maxp;
for (ti = 0; ti < iT; ti++)

This comment was marked as off-topic.


if (maxp != -1) {
/* update gradient */
gradInput_p_k[maxp] += gradOutput_p_k[index];

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More nits about the naming that existed previously in the code (before your changes). Feel free to ignore

@@ -127,6 +127,7 @@ static void THNN_(VolumetricDilatedMaxPooling_updateOutput_frame)(
{
/* loop over output */
int64_t i, j, ti;
real *ip = input_p + k * itime * iwidth * iheight;
for (ti = 0; ti < otime; ti++)

This comment was marked as off-topic.

{
for (x = 0; x < kernel_w; x++)
for (x = start_w; x < end_w; x += dilationW)

This comment was marked as off-topic.

@@ -381,16 +372,13 @@ static void THNN_(VolumetricDilatedMaxPooling_updateGradInput_frame)(
for (j = 0; j < owidth; j++)

This comment was marked as off-topic.


int t, i, j, index;
THIndex_t maxp;
for (t = 0; t < iT; t++)
{
for (i = 0; i < iH; i++)
{
for (j = 0; j < iW; j++)

This comment was marked as off-topic.


int t, i, j, index;
THIndex_t maxp;
for (t = 0; t < iT; t++)
{
for (i = 0; i < iH; i++)
{
for (j = 0; j < iW; j++)

This comment was marked as off-topic.

@zou3519
Copy link
Contributor

zou3519 commented Mar 7, 2018

This will also fix #1197 I believe, thanks @li-roy :D

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@soumith soumith merged commit 363de58 into pytorch:master Mar 8, 2018
@li-roy li-roy deleted the maxpool3d branch March 20, 2018 14:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

MaxPool3d cannot be differentiated twice
4 participants