This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Make Backend.concat() take tensor[], instead of 2 tensors #1260

Merged (10 commits) on Aug 31, 2018

Conversation

dsmilkov
Contributor

@dsmilkov dsmilkov commented Aug 30, 2018

Due to a WebGL shader limitation, we only allowed 2 tensors to be passed to concat at a time. This results in roughly N kernel calls when concatenating N tensors, which slows down our Node.js backend.

This PR makes the 2-tensor limitation an implementation detail of the WebGL backend and changes backend.concat() to take an array of tensors.
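
As a rough illustration of keeping the 2-tensor limit inside one backend, here is a minimal sketch. It uses plain number arrays instead of tensors, and the names `Vec`, `concat2`, and `concatMany` are hypothetical, not taken from the PR.

```typescript
// Hypothetical stand-in for a tensor: a 1-D array of numbers.
type Vec = number[];

// Assumed primitive that can only join two inputs, mirroring the shader limit.
function concat2(a: Vec, b: Vec): Vec {
  return [...a, ...b];
}

// Entry point that accepts any number of inputs and hides the pairwise limit.
function concatMany(tensors: Vec[]): Vec {
  if (tensors.length === 0) {
    throw new Error('concatMany() needs at least one input');
  }
  // Fold left to right; N inputs cost N - 1 calls to concat2.
  return tensors.slice(1).reduce((acc, t) => concat2(acc, t), tensors[0]);
}
```

A backend without the limit could implement the same array-taking signature as a single kernel call, which is the case this change is meant to enable for Node.js.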

Also:

  • Speeds up the CPU implementation by using flat indexing.
  • Simplifies the gradient of concat by calling tf.split(). This should result in perf benefits in Node.js since split() is a single kernel call.
  • Removes an unused utility in concat_util.
  • Moves split() from a high-level op to a backend kernel, so we can call it directly in Node.
  • Since split() has multiple outputs, we need to backpropagate through those. This required the tape to be generalized to multi-output nodes.
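
To make the flat-indexing point concrete, here is a sketch of concatenating row-major 2-D buffers along their column axis by computing flat offsets directly. The `Matrix` type and `concatColumns` name are assumptions for illustration; the actual CPU kernel in the PR may be organized differently.

```typescript
// Hypothetical row-major 2-D buffer; not the library's Tensor type.
interface Matrix {
  rows: number;
  cols: number;
  values: Float32Array;
}

// Concatenates matrices along the column axis using flat offsets only.
function concatColumns(inputs: Matrix[]): Matrix {
  if (inputs.length === 0) {
    throw new Error('concatColumns() needs at least one input');
  }
  const rows = inputs[0].rows;
  for (const m of inputs) {
    if (m.rows !== rows) {
      throw new Error('all inputs must have the same number of rows');
    }
  }
  const totalCols = inputs.reduce((sum, m) => sum + m.cols, 0);
  const out = new Float32Array(rows * totalCols);
  let colOffset = 0;
  for (const m of inputs) {
    let src = 0;  // Flat read position in the input buffer.
    for (let r = 0; r < rows; r++) {
      // Flat write position: start of row r plus this input's column offset.
      let dst = r * totalCols + colOffset;
      for (let c = 0; c < m.cols; c++) {
        out[dst++] = m.values[src++];
      }
    }
    colOffset += m.cols;
  }
  return {rows, cols: totalCols, values: out};
}
```

Each element is read and written once through a running flat index, so no per-element coordinate lookup is needed.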

PERF
FEATURE



Contributor

@nsthorat nsthorat left a comment

Reviewed 7 of 7 files at r1, 1 of 1 files at r2.
Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nsthorat, @tafsiri, and @nkreeger)


src/kernels/backend_webgl.ts, line 539 at r1 (raw file):

  }

  concat(tensors: Tensor[], axis: number): Tensor {

optional, but not too hard: you can use gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS) to determine how many samplers are available, and pass multiple texture inputs to a single concat program as different variables.

If not in this PR, let's file an issue to track this since it will speed this up for WebGL as well.
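
A sketch of how that suggestion might look: query the sampler limit once, then split the inputs into groups small enough for one program each. `GLLike`, `maxSamplers`, and `groupBySamplerLimit` are hypothetical names; only `getParameter` and `MAX_TEXTURE_IMAGE_UNITS` come from the WebGL API.

```typescript
// Minimal slice of the WebGL context used here, so the sketch runs without a
// browser. A real WebGLRenderingContext should satisfy this shape.
interface GLLike {
  readonly MAX_TEXTURE_IMAGE_UNITS: number;
  getParameter(pname: number): unknown;
}

// Reads how many texture samplers a fragment shader may use.
function maxSamplers(gl: GLLike): number {
  const value = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
  if (typeof value !== 'number') {
    throw new Error('MAX_TEXTURE_IMAGE_UNITS did not return a number');
  }
  return value;
}

// Splits inputs into consecutive groups of at most `limit` items, so each
// group could be bound as separate samplers in one concat program.
function groupBySamplerLimit<T>(inputs: T[], limit: number): T[][] {
  if (limit < 1) {
    throw new Error('limit must be at least 1');
  }
  const groups: T[][] = [];
  for (let i = 0; i < inputs.length; i += limit) {
    groups.push(inputs.slice(i, i + limit));
  }
  return groups;
}
```

With a limit of k samplers, N inputs would need about N / k program runs per pass instead of N - 1 pairwise runs.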


src/ops/concat_util.ts, line 20 at r1 (raw file):

import * as util from '../util';

export function assertParams(shapes: number[][], axis: number) {

can you rename this function to something like "assertParamsConsistent"

Contributor Author

@dsmilkov dsmilkov left a comment

Thanks for the ultra fast review!

Reviewed 5 of 7 files at r1.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @dsmilkov, @tafsiri, and @nkreeger)


src/kernels/backend_webgl.ts, line 539 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

optional, but not too hard: you can use gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS) to determine how many samplers are available, and pass multiple texture inputs to a single concat program as different variables.

If not in this PR, let's file an issue to track this since it will speed this up for WebGL as well.

I was about to file an issue. Here it is: tensorflow/tfjs#657
I don't want to do it in this PR since I wanted to unblock Nick to do more perf measurements on node while I'm here, and someone else can take this since it's straightforward.


src/ops/concat_util.ts, line 20 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

can you rename this function to something like "assertParamsConsistent"

Done.

Contributor

@nkreeger nkreeger left a comment

:lgtm:

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @dsmilkov, @tafsiri, and @nkreeger)


src/ops/concat.ts, line 175 at r3 (raw file):

    const derTensors = split(dy, sizeSplits, axis);
 

Quick question - next step is adding a hook for concatDer() to optimize as needed - this gradient function can be rolled up for webgl/webcpu easily right?

Contributor Author

@dsmilkov dsmilkov left a comment

PTAL! Moved split to backend as well, and generalized the tape logic to handle multiple outputs.
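
For readers unfamiliar with the tape, here is a minimal sketch of backpropagating through nodes that have several outputs. It uses scalar values keyed by id; `TapeNode`, `GradMap`, and `backprop` are hypothetical names and this is not the code in src/tape.ts.

```typescript
// Hypothetical tape entry: values are scalars identified by numeric ids.
interface TapeNode {
  inputs: number[];   // Ids of the node's input values.
  outputs: number[];  // Ids of the node's output values (may be several).
  // Given one upstream gradient per output, returns one gradient per input.
  gradient: (dys: number[]) => number[];
}

type GradMap = {[id: number]: number};

// Walks the tape backwards from the value `seedId`, accumulating gradients.
function backprop(tape: TapeNode[], seedId: number): GradMap {
  const grads: GradMap = {[seedId]: 1};
  for (let i = tape.length - 1; i >= 0; i--) {
    const node = tape[i];
    // Skip nodes whose outputs received no gradient at all.
    if (!node.outputs.some(id => grads[id] != null)) {
      continue;
    }
    // An output that received no gradient contributes zero.
    const dys = node.outputs.map(id => {
      const g = grads[id];
      return g != null ? g : 0;
    });
    const dxs = node.gradient(dys);
    node.inputs.forEach((id, j) => {
      const prev = grads[id];
      grads[id] = (prev != null ? prev : 0) + dxs[j];
    });
  }
  return grads;
}
```

The explicit `!= null` checks matter in this scalar sketch because a gradient of 0 is a valid value that a truthiness check would treat as missing.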

Reviewed 3 of 3 files at r3, 4 of 12 files at r4.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @nkreeger, and @tafsiri)


src/ops/concat.ts, line 175 at r3 (raw file):

Previously, nkreeger (Nick Kreeger) wrote…
    const derTensors = split(dy, sizeSplits, axis);
 

Quick question - next step is adding a hook for concatDer() to optimize as needed - this gradient function can be rolled up for webgl/webcpu easily right?

Talked offline. Moved split to backend since that's the gradient of concat().
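
A small sketch of why split is the gradient of concat, using 1-D number arrays: the upstream gradient dy is cut back into pieces whose sizes match the original inputs. `splitBySizes` and `concatGrad` are hypothetical names, not the library's API.

```typescript
// Cuts `dy` into consecutive pieces with the given sizes.
function splitBySizes(dy: number[], sizes: number[]): number[][] {
  const total = sizes.reduce((sum, n) => sum + n, 0);
  if (total !== dy.length) {
    throw new Error('sizes must sum to the length of dy');
  }
  const pieces: number[][] = [];
  let offset = 0;
  for (const size of sizes) {
    pieces.push(dy.slice(offset, offset + size));
    offset += size;
  }
  return pieces;
}

// Gradient of 1-D concat: each input receives the slice of dy it produced.
function concatGrad(dy: number[], inputs: number[][]): number[][] {
  return splitBySizes(dy, inputs.map(t => t.length));
}
```

Because the split produces every input gradient in one call, a backend that implements split as a kernel can compute the whole concat gradient at once.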

@dsmilkov dsmilkov changed the title from "Make Backend.concat() take an array of tensors, instead of 2" to "Make Backend.concat() take tensor[], instead of 2 tensors" Aug 30, 2018
Contributor

@nsthorat nsthorat left a comment

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @nkreeger, @dsmilkov, and @tafsiri)


src/tape.ts, line 140 at r5 (raw file):

    node.outputs.forEach(o => {
      const gradTensor = tensorAccumulatedGradientMap[o.id];
      if (gradTensor) {

!= null


src/tape_test.ts, line 37 at r5 (raw file):

        name: 'node0',
        inputs: {x},
        outputs: [intermediate1],

can you add a couple tests where there are multiple outputs, here and for backpropagate gradients (testing zeros and non zeros)

Contributor Author

@dsmilkov dsmilkov left a comment

Reviewed 1 of 12 files at r4, 4 of 4 files at r5.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @nkreeger and @tafsiri)


src/tape.ts, line 140 at r5 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

!= null

Done.


src/tape_test.ts, line 37 at r5 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

can you add a couple tests where there are multiple outputs, here and for backpropagate gradients (testing zeros and non zeros)

Done. I was lazy, thanks for being careful!

@dsmilkov dsmilkov merged commit ce2c543 into master Aug 31, 2018
@dsmilkov dsmilkov deleted the concat branch August 31, 2018 02:59