Added mirrorPad kernel #3951

Merged
merged 18 commits into from
Oct 9, 2020

Conversation

AbdallaGomaa
Contributor

@AbdallaGomaa AbdallaGomaa commented Sep 20, 2020

FEATURE

Implemented 'REFLECT' and 'SYMMETRIC' modes of tf.pad in mirrorPad (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/mirror-pad)


Abdalla Ahmed added 3 commits September 20, 2020 16:57
…cpu, tfjs-backend-webgl and tfjs-backend-webgpu
# Conflicts:
#	tfjs-backend-cpu/src/register_all_kernels.ts
@googlebot

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.


@AbdallaGomaa
Contributor Author

AbdallaGomaa commented Sep 20, 2020 via email

@googlebot

We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for all the commit author(s) or Co-authors. If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google.
In order to pass this check, please resolve this problem and then comment @googlebot I fixed it.. If the bot doesn't comment, it means it doesn't think anything has changed.

@AbdallaGomaa
Contributor Author

AbdallaGomaa commented Sep 20, 2020 via email

@googlebot

CLAs look good, thanks!

Collaborator

@pyu10055 pyu10055 left a comment

Thank you for the contribution. It looks good to me for the most part; just a couple of minor comments.

Reviewable status: 0 of 1 approvals obtained (waiting on @AbdallaGomaa and @tafsiri)


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 1 at r3 (raw file):

import {GPGPUProgram} from './gpgpu_math';

Please make sure the license header is added to all new files.


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 31 at r3 (raw file):

      void main() {
        ${type} coords = getOutputCoords();
        ${type} lt = ${type}(${lessThan});

Wondering whether the behavior of bvec multiplication can be driver-dependent. Would it be safer to use conditional logic?


tfjs-core/src/ops/mirror_pad.ts, line 46 at r3 (raw file):

 * then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
 * `x.shape[D]`
 * @param mode Optional string from `'reflect' | 'symmetric'`,

It would be good to give an example for each value:
In reflect mode the padded regions do not include the borders, while in symmetric mode the padded regions do include the borders. For example, if input is [1, 2, 3] and paddings is [0, 2], then the output is [1, 2, 3, 2, 1] in reflect mode, and it is [1, 2, 3, 3, 2] in symmetric mode.
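
The difference between the two modes can be sketched for the 1-D case as follows. This is a minimal illustration in plain TypeScript, not the kernel code from this PR; `mirrorPad1D` and `MirrorMode` are hypothetical names introduced only for this example.

```typescript
type MirrorMode = 'reflect' | 'symmetric';

// Hypothetical helper (not part of tfjs): pads a 1-D array by mirroring.
function mirrorPad1D(
    x: number[], padding: [number, number], mode: MirrorMode): number[] {
  const [before, after] = padding;
  // 'reflect' skips the border element; 'symmetric' repeats it.
  const offset = mode === 'reflect' ? 1 : 0;
  const out: number[] = [];
  for (let i = 0; i < before + x.length + after; i++) {
    let src = i - before;  // Index into x; may fall outside [0, x.length).
    if (src < 0) {
      // Mirror around the left border.
      src = -src - 1 + offset;
    } else if (src >= x.length) {
      // Mirror around the right border.
      src = 2 * x.length - 1 - src - offset;
    }
    out.push(x[src]);
  }
  return out;
}
```

With input `[1, 2, 3]` and padding `[0, 2]`, this sketch yields `[1, 2, 3, 2, 1]` for `'reflect'` and `[1, 2, 3, 3, 2]` for `'symmetric'`, matching the example in the comment above. It assumes the paddings are already within the allowed range.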


tfjs-core/src/ops/mirror_pad.ts, line 70 at r3 (raw file):

        () => `Invalid number of paddings. Must be length of 2 each.`);
    util.assert(
        paddings[i][0] < $x.shape[i] + offset &&

it should be <= $x.shape[i] - offset. Also it should be greater than 0

Collaborator

@annxingyuan annxingyuan left a comment

This is an amazing contribution! Thank you!

Reviewable status: 0 of 1 approvals obtained (waiting on @AbdallaGomaa, @annxingyuan, @pyu10055, and @tafsiri)


tfjs-backend-webgl/src/backend_webgl.ts, line 1016 at r3 (raw file):

  }

  mirrorPad<T extends Tensor>(

Could you add this kernel in the modularized style, as you did with the CPU backend?


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 15 at r3 (raw file):

        (p, i) => p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */);
    const rank = xShape.length;
    const type = getCoordsDataType(rank);

our convention is to name this dtype rather than type


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 31 at r3 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

Wondering whether the behavior of bvec multiplication can be driver-dependent. Would it be safer to use conditional logic?

+1


tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts, line 21 at r3 (raw file):

import {getChannels} from './packing_util';
import {getCoordsDataType} from './shader_compiler';

I realize we haven't been doing this elsewhere :) but it would be super helpful if you could include the generated shader code for a sample input here in a comment.


tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts, line 63 at r3 (raw file):

    let mainLoop = '';
    for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {

Could you separate this into an if / else block depending on the rank? I think it will be a bit more readable that way.


tfjs-backend-webgpu/src/backend_webgpu.ts, line 521 at r3 (raw file):

  }

  mirrorPad<T extends Tensor>(

Could you add this kernel in the modular style instead?


tfjs-backend-webgpu/src/kernels/mirror_pad_webgpu.ts, line 33 at r3 (raw file):

  variableNames = ['x'];
  workPerThread = 8;
  workGroupSize: [number, number, number] = [16, 1, 1];

Curious whether you arrived at these numbers through benchmarking?


tfjs-core/src/backends/backend.ts, line 584 at r3 (raw file):

  }

  mirrorPad<T extends Tensor>(

You might not need to add this if you're adding a fully modularized kernel.


tfjs-core/src/ops/mirror_pad.ts, line 83 at r3 (raw file):

  const attrs: MirrorPadAttrs = {paddings, mode};
  const inputs: MirrorPadInputs = {x: $x};
  return ENGINE.runKernelFunc(

If you ensure that all the kernels are written in the modular style, you should just be able to call runKernel here instead - see our dilation2d op as an example.

Contributor

@tafsiri tafsiri left a comment

This is a really thorough PR, great work. Left a few comments.

Reviewed 24 of 26 files at r1, 4 of 4 files at r4.
Reviewable status: 0 of 1 approvals obtained (waiting on @AbdallaGomaa, @annxingyuan, @pyu10055, and @tafsiri)


tfjs-core/src/backends/backend.ts, line 584 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

You might not need to add this if you're adding a fully modularized kernel.

+1 to this. We should not add new kernels to the backend interface.


tfjs-core/src/ops/mirror_pad.ts, line 83 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

If you ensure that all the kernels are written in the modular style, you should just be able to call runKernel here instead - see our dilation2d op as an example.

+1


tfjs-core/src/ops/mirror_pad.ts, line 53 at r4 (raw file):

    x: T|TensorLike, paddings: Array<[number, number]>,
    mode?: 'reflect'|'symmetric'): T {
  mode = mode || 'reflect';

Why make mode optional/default to reflect? I think we should make it required to align with https://www.tensorflow.org/api_docs/python/tf/raw_ops/MirrorPad?hl=en


tfjs-core/src/ops/mirror_pad_test.ts, line 38 at r4 (raw file):

    // 2, 1, 2, 3, 2
    // 5, 4, 5, 6, 5
    // 2, 1, 2, 3, 2

These diagrams are great. Could you add assertions for the expected shape (e.g. [4,5] in this case) in all these tests?
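
The expected shape mentioned here follows from adding the before and after padding to each input dimension, as the shader file quoted earlier does. A small sketch, assuming the diagrammed test pads a `[2, 3]` input by one element on every side (an inference from the diagram, not stated in the thread); `mirrorPadOutputShape` is a hypothetical name:

```typescript
// Hypothetical helper: output shape is beforePad + size + afterPad per axis.
function mirrorPadOutputShape(
    xShape: number[], paddings: Array<[number, number]>): number[] {
  return paddings.map((p, i) => p[0] + xShape[i] + p[1]);
}

// Assumed test setup: a 2x3 input padded by 1 on each side of both axes.
const expectedShape = mirrorPadOutputShape([2, 3], [[1, 1], [1, 1]]);
```

Under that assumption `expectedShape` is `[4, 5]`, which is the value a shape assertion in the test would compare against.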


tfjs-node/src/nodejs_kernel_backend.ts, line 1347 at r4 (raw file):

  }

  mirrorPad<T extends Tensor>(

You should also be able to implement this in the 'modular' style. See Dilation2D or SquaredDifference in tfjs-node/src/kernels

Abdalla Ahmed and others added 2 commits October 1, 2020 17:57
Simplified mirror pad gpu shader
Added test of output shape for mirror pad
Removed mode being optional for mirror pad to match TF implementation
Contributor Author

@AbdallaGomaa AbdallaGomaa left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)


tfjs-backend-webgl/src/backend_webgl.ts, line 1016 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Could you add this kernel in the modularized style, as you did with the CPU backend?

Done.


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 1 at r3 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

Please make sure the license header is added to all new files.

Done.


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 15 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

our convention is to name this dtype rather than type

Done.


tfjs-backend-webgl/src/mirror_pad_gpu.ts, line 31 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

+1

Done.


tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts, line 21 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

I realize we haven't been doing this elsewhere :) but it would be super helpful if you could include the generated shader code for a sample input here in a comment.

Done. :)


tfjs-backend-webgl/src/mirror_pad_packed_gpu.ts, line 63 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Could you separate this into an if / else block depending on the rank? I think it will be a bit more readable that way.

Done.


tfjs-backend-webgpu/src/backend_webgpu.ts, line 521 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Could you add this kernel in the modular style instead?

Done.


tfjs-backend-webgpu/src/kernels/mirror_pad_webgpu.ts, line 33 at r3 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Curious whether you arrived at these numbers through benchmarking?

I just copied this from pad_webgpu.ts


tfjs-core/src/backends/backend.ts, line 584 at r3 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

+1 to this. we should not add new kernels to the backend interface.

Done.


tfjs-core/src/ops/mirror_pad.ts, line 46 at r3 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

It would be good to give an example for each value:
In reflect mode the padded regions do not include the borders, while in symmetric mode the padded regions do include the borders. For example, if input is [1, 2, 3] and paddings is [0, 2], then the output is [1, 2, 3, 2, 1] in reflect mode, and it is [1, 2, 3, 3, 2] in symmetric mode.

Done.


tfjs-core/src/ops/mirror_pad.ts, line 70 at r3 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

it should be <= $x.shape[i] - offset. Also it should be greater than 0

It actually can be 0 if you don't want to pad along that dimension or on that side, but you're right that it can't be negative. Updated it to be greater than or equal to 0.


tfjs-core/src/ops/mirror_pad.ts, line 83 at r3 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

+1

Done.


tfjs-core/src/ops/mirror_pad.ts, line 53 at r4 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

Why make mode optional/default to reflect? I think we should make it required to align with https://www.tensorflow.org/api_docs/python/tf/raw_ops/MirrorPad?hl=en

Changed it to required!


tfjs-core/src/ops/mirror_pad_test.ts, line 38 at r4 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

These diagrams are great. Could you add assertions for the expected shape (e.g. [4,5] in this case) in all these tests?

Done.


tfjs-node/src/nodejs_kernel_backend.ts, line 1347 at r4 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

You should also be able to implement this in the 'modular' style. See Dilation2D or SquaredDifference in tfjs-node/src/kernels

I've updated this to be in the modular style, but I seem to have an issue with the tests. I'm getting "Error: Backend 'tensorflow' has an internal memory leak (1 data ids) after running 'MirrorPad'" for pretty much every test, and I am not entirely sure what causes this. It would be great if you could help me with this!

Collaborator

@pyu10055 pyu10055 left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @AbdallaGomaa, @annxingyuan, @pyu10055, and @tafsiri)


tfjs-core/src/ops/mirror_pad.ts, line 70 at r3 (raw file):

offset = mode === 'reflect' ? 0 : 1;

offset = mode === 'reflect' ? 1 : 0;
I think the offset here needs to be reversed.


tfjs-core/src/ops/mirror_pad.ts, line 77 at r6 (raw file):

            paddings[i][1] >= 0 && paddings[i][1] <= $x.shape[i] - offset,
        () => `Padding in dimension ${i} cannot be greater than or equal ` +
            `to ${$x.shape[i] + offset} for input of shape ${$x.shape}`);

Please update the error message to reflect the >= 0 check, and also correct it to use ${$x.shape[i] - offset}.
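
Taken together, the review comments in this thread describe a range check of `0 <= padding <= shape[i] - offset`, with `offset` equal to 1 for `'reflect'` and 0 for `'symmetric'`. A minimal standalone sketch of that check follows; `validateMirrorPaddings` and its error wording are hypothetical, and the PR itself uses `util.assert` inside the op rather than a separate function.

```typescript
type MirrorMode = 'reflect' | 'symmetric';

// Hypothetical validator illustrating the bounds discussed in the review.
function validateMirrorPaddings(
    shape: number[], paddings: Array<[number, number]>,
    mode: MirrorMode): void {
  if (paddings.length !== shape.length) {
    throw new Error(
        `Expected ${shape.length} padding pairs, got ${paddings.length}.`);
  }
  // 'reflect' cannot reuse the border element, so one less is allowed.
  const offset = mode === 'reflect' ? 1 : 0;
  for (let i = 0; i < shape.length; i++) {
    const max = shape[i] - offset;
    const [before, after] = paddings[i];
    if (before < 0 || before > max || after < 0 || after > max) {
      throw new Error(
          `Padding in dimension ${i} must be between 0 and ${max} ` +
          `for input of shape [${shape.join(', ')}].`);
    }
  }
}
```

For a dimension of size 3, this accepts a padding of 3 in `'symmetric'` mode but rejects it in `'reflect'` mode, where the maximum is 2.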

Contributor Author

@AbdallaGomaa AbdallaGomaa left a comment

Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)


tfjs-core/src/ops/mirror_pad.ts, line 70 at r3 (raw file):

Previously, pyu10055 (Ping Yu) wrote…
offset = mode === 'reflect' ? 0 : 1;

offset = mode === 'reflect' ? 1 : 0;
I think the offset here needs to be reversed.

Good catch! Done.


tfjs-core/src/ops/mirror_pad.ts, line 77 at r6 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

Please update the error message to reflect the >= 0 check, and also correct it to use ${$x.shape[i] - offset}.

Done. Also updated tests to check these edge cases.

Collaborator

@pyu10055 pyu10055 left a comment

Reviewed 17 of 26 files at r1, 1 of 4 files at r4, 12 of 16 files at r5, 2 of 2 files at r6, 2 of 2 files at r7.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)

Contributor

@tafsiri tafsiri left a comment

Added a suggestion for fixing the memory leak.

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @AbdallaGomaa, @annxingyuan, @pyu10055, and @tafsiri)


tfjs-core/src/ops/mirror_pad.ts, line 57 at r8 (raw file):

    x: T|TensorLike, paddings: Array<[number, number]>,
    mode: 'reflect'|'symmetric'): T {
  mode = mode || 'reflect';

Remove the default assignment, and add an assertion that it is 'reflect' or 'symmetric'.
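
One way such an assertion could look, as a sketch only: `assertMirrorMode` is a hypothetical helper, and the actual change in the PR may be worded differently and would likely go through `util.assert`.

```typescript
type MirrorMode = 'reflect' | 'symmetric';

// Hypothetical helper: rejects anything other than the two supported modes
// instead of silently falling back to a default.
function assertMirrorMode(mode: string): MirrorMode {
  if (mode === 'reflect' || mode === 'symmetric') {
    return mode;
  }
  throw new Error(
      `Invalid mode. Mode must be either reflect or symmetric. Got ${mode}.`);
}
```

Failing loudly here matters for callers coming from plain JavaScript, where the `'reflect'|'symmetric'` type annotation is not enforced at runtime.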


tfjs-node/src/kernels/MirrorPad.ts, line 31 at r8 (raw file):

    const nodeBackend = backend as NodeJSKernelBackend;

    const paddingsTensor = tensor2d(paddings, [paddings.length, 2], 'int32');

I think this is the cause of the memory leak. Modular kernels are generally expected to clean up their own memory (and we are moving to a system where kernels don't create 'tensors'; they just allocate memory as needed). But for now I think this tensor just needs to be disposed after calling executeSingleOutput. You can save the result of the backend execution to a variable so that you can dispose this before returning the result.
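
The pattern suggested here (save the result, dispose the temporary, then return) can be sketched without the real backend. Everything below is a hypothetical stand-in: `FakeTensor` replaces the tensor created by `tensor2d`, `fakeExecute` replaces the call to `executeSingleOutput`, and `liveTensors` imitates the leak counter used by the test harness.

```typescript
// Hypothetical counter imitating the backend's tracking of live data ids.
let liveTensors = 0;

// Hypothetical stand-in for a tensor allocated inside the kernel.
class FakeTensor {
  values: number[][];
  constructor(values: number[][]) {
    this.values = values;
    liveTensors++;
  }
  dispose(): void {
    liveTensors--;
  }
}

// Hypothetical stand-in for the backend execution call.
function fakeExecute(paddings: FakeTensor): string {
  return `padded with ${paddings.values.length} padding pairs`;
}

function mirrorPadKernelSketch(paddings: number[][]): string {
  const paddingsTensor = new FakeTensor(paddings);
  // Save the result first so the temporary can be released before returning.
  const result = fakeExecute(paddingsTensor);
  paddingsTensor.dispose();
  return result;
}
```

After `mirrorPadKernelSketch` returns, `liveTensors` is back to its value before the call, which is the property the leak check in the tests verifies. Returning the execution result directly, without the intermediate variable, would leave the temporary allocated.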

Contributor Author

@AbdallaGomaa AbdallaGomaa left a comment

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)


tfjs-core/src/ops/mirror_pad.ts, line 57 at r8 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

Remove the default assignment, and add an assertion that it is 'reflect' or 'symmetric'.

Done.


tfjs-node/src/kernels/MirrorPad.ts, line 31 at r8 (raw file):

Previously, tafsiri (Yannick Assogba) wrote…

I think this is the cause of the memory leak. Modular kernels are generally expected to clean up their own memory (and we are moving to a system where kernels don't create 'tensors'; they just allocate memory as needed). But for now I think this tensor just needs to be disposed after calling executeSingleOutput. You can save the result of the backend execution to a variable so that you can dispose this before returning the result.

That fixed the issue. Thanks! Done.

Contributor

@tafsiri tafsiri left a comment

Reviewed 1 of 2 files at r11.
Reviewable status: :shipit: complete! 3 of 1 approvals obtained (waiting on @annxingyuan, @pyu10055, and @tafsiri)

@pyu10055 pyu10055 merged commit 7fb7446 into tensorflow:master Oct 9, 2020
@gkcalat gkcalat mentioned this pull request Oct 12, 2020