-
Notifications
You must be signed in to change notification settings - Fork 955
Make Backend.concat() take tensor[], instead of 2 tensors #1260
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 7 of 7 files at r1, 1 of 1 files at r2.
Reviewable status: 0 of 1 approvals obtained (waiting on @dsmilkov, @nsthorat, @tafsiri, and @nkreeger)
src/kernels/backend_webgl.ts, line 539 at r1 (raw file):
} concat(tensors: Tensor[], axis: number): Tensor {
optional, but not too hard: you can use this information gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
to determine how many samplers you can use and pass multiple texture inputs to a single concat program as different variables.
If not in this PR, let's file an issue to track this since it will speed this up for WebGL as well.
src/ops/concat_util.ts, line 20 at r1 (raw file):
import * as util from '../util'; export function assertParams(shapes: number[][], axis: number) {
can you rename this function to something like "assertParamsConsistent"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the ultra fast review!
Reviewed 5 of 7 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @dsmilkov, @tafsiri, and @nkreeger)
src/kernels/backend_webgl.ts, line 539 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
optional, but not too hard: you can use this information
gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
to determine how many samplers you can use and pass multiple texture inputs to a single concat program as different variables.If not in this PR, let's file an issue to track this since it will speed this up for WebGL as well.
I was about to file an issue. Here it is: tensorflow/tfjs#657
I don't want to do it in this PR since I wanted to unblock Nick to do more perf measurements on node while I'm here, and someone else can take this since it's straightforward.
src/ops/concat_util.ts, line 20 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
can you rename this function to something like "assertParamsConsistent"
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @dsmilkov, @tafsiri, and @nkreeger)
src/ops/concat.ts, line 175 at r3 (raw file):
const derTensors = split(dy, sizeSplits, axis);
Quick question - next step is adding a hook for concatDer() to optimize as needed - this gradient function can be rolled up for webgl/webcpu easily right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PTAL! Moved split to backend as well, and generalized the tape logic to handle multiple outputs.
Reviewed 3 of 3 files at r3, 4 of 12 files at r4.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @nkreeger, and @tafsiri)
src/ops/concat.ts, line 175 at r3 (raw file):
Previously, nkreeger (Nick Kreeger) wrote…
const derTensors = split(dy, sizeSplits, axis);
Quick question - next step is adding a hook for concatDer() to optimize as needed - this gradient function can be rolled up for webgl/webcpu easily right?
Talked offline. Moved split to backend since that's the gradient of concat().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nsthorat, @nkreeger, @dsmilkov, and @tafsiri)
src/tape.ts, line 140 at r5 (raw file):
node.outputs.forEach(o => { const gradTensor = tensorAccumulatedGradientMap[o.id]; if (gradTensor) {
!= null
src/tape_test.ts, line 37 at r5 (raw file):
name: 'node0', inputs: {x}, outputs: [intermediate1],
can you add a couple tests where there are multiple outputs, here and for backpropagate gradients (testing zeros and non zeros)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 1 of 12 files at r4, 4 of 4 files at r5.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @nkreeger and @tafsiri)
src/tape.ts, line 140 at r5 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
!= null
Done.
src/tape_test.ts, line 37 at r5 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
can you add a couple tests where there are multiple outputs, here and for backpropagate gradients (testing zeros and non zeros)
Done. I was lazy, thanks for being careful!
Due to our WebGL shader limitation, we only allowed 2 tensors to be passed to concat at a time. This results in N kernels calls when concating N tensors and slows down our Node.js backend.
This PR moves the 2 tensor limitation as an implementation detail of WebGL and changes
backend.concat()
to take an array of tensors.Also:
tf.split()
. This should result in perf benefits in Node.js sincesplit()
is a single kernel call.PERF
FEATURE
This change is