[core] Modularize max. #2955
Conversation
dsmilkov
left a comment
LGTM, but wait for Yannick
Reviewed 22 of 22 files at r1.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, @lina128, @nsthorat, and @tafsiri)
tfjs-core/src/backends/cpu/backend_cpu.ts, line 383 at r1 (raw file):
softmax<T extends Tensor>(logits: T, dim: number): T {
  const axes = util.parseAxisParam([dim], logits.shape);
  // TODO(annxingyuan): Call maxImpl rather than op as part of softmax kernel
just curious what's the blocker for softmax to get modularized (since it keeps yielding TODOs)
tfjs-core/src/gradients/Max_grad.ts, line 21 at r1 (raw file):
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import * as axis_util from '../ops/axis_util';
import {gradForMinAndMax} from '../ops/reduction_ops';
Can you copy-paste the gradForMinAndMax function here? This way you don't rely on reduction_ops anymore. It's fine to repeat the code for min since the function is not that large. If you do feel strongly about code-sharing, move gradForMinAndMax to a separate file (grad_for_min_max.ts).
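For context, a minimal sketch of what the inlined helper could look like, assuming it keeps the signature it has in reduction_ops (the name gradForMax and the exact helper calls are illustrative, not confirmed by this PR):

```ts
import {Tensor} from '../tensor';
import * as axis_util from '../ops/axis_util';

// Sketch: the gradient of max flows only to the positions that attained the max.
function gradForMax<T extends Tensor>(
    dy: T, y: T, xOrig: Tensor, origAxes: number[], permutedAxes: number[]) {
  // Re-expand the reduced tensors so they broadcast against xOrig.
  if (y.rank < xOrig.rank) {
    y = y.reshape(axis_util.expandShapeToKeepDim(y.shape, origAxes)) as T;
  }
  if (dy.rank < xOrig.rank) {
    dy = dy.reshape(axis_util.expandShapeToKeepDim(dy.shape, origAxes)) as T;
  }
  return {
    x: () => {
      // Mask: 1 where x equals the max, 0 elsewhere.
      const dx = dy.mul(xOrig.equal(y).cast(dy.dtype));
      return permutedAxes == null ? dx : dx.transpose(permutedAxes);
    }
  };
}
```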
nsthorat
left a comment
Reviewed 20 of 22 files at r1, 1 of 1 files at r2.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @annxingyuan, @lina128, @nsthorat, and @tafsiri)
tfjs-core/src/backends/webgl/backend_webgl.ts, line 657 at r2 (raw file):
 */
shouldExecuteOnCPU(
    inputs: Tensor[]|TensorInfo[],
this can just be TensorInfo because Tensor is a TensorInfo
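For illustration, the narrowed signature (other parameters elided):

```ts
// Before: the union is redundant, since Tensor implements TensorInfo.
shouldExecuteOnCPU(inputs: Tensor[]|TensorInfo[]): boolean { /* ... */ }

// After:
shouldExecuteOnCPU(inputs: TensorInfo[]): boolean { /* ... */ }
```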
tfjs-core/src/backends/webgl/kernels/Max_impl.ts, line 24 at r2 (raw file):
import {reshape} from '../kernel_utils/reshape';

export const maxImpl =
export function (this should always be preferred for simple function definitions)
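That is, roughly (the parameter list here is illustrative, not the actual one):

```ts
// Arrow function assigned to a const:
export const maxImpl =
    (x: TensorInfo, backend: MathBackendWebGL): TensorInfo => {
      /* ... */
    };

// Preferred: a plain function declaration (hoisted, clearer stack traces).
export function maxImpl(x: TensorInfo, backend: MathBackendWebGL): TensorInfo {
  /* ... */
}
```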
tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts, line 23 at r2 (raw file):
// Testing for presence of chained op in this file will allow us to more easily
// customize when we want this test to run. Currently it will run be default
// (And kerma will always load the chain augmentor files). But this gives us
s/kerma/karma
tafsiri
left a comment
Reviewed 21 of 22 files at r1, 1 of 1 files at r2.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, @lina128, and @tafsiri)
tfjs-core/src/kernel_names.ts, line 42 at r2 (raw file):
export const Max = 'Max';
export type MaxInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface MaxAttrs {
See comment below about adding keepDims here
tfjs-core/src/kernel_names.ts, line 43 at r2 (raw file):
export type MaxInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface MaxAttrs {
  axes: number[];
While axes is more correct, should we change this to axis to match the op API and the convention in TF?
tfjs-core/src/backends/webgl/kernel_utils/reshape.ts, line 46 at r2 (raw file):
}

export function reshape(
Just a question to confirm my mental model of the convention discussed yesterday: would this move to ReshapeImpl when the reshape kernel is implemented? (If so, no need to move it right now; just wanted to check.)
tfjs-core/src/gradients/Max_grad.ts, line 21 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
Can you copy-paste the gradForMinAndMax function here? This way you don't rely on reduction_ops anymore. It's fine to repeat the code for min since the function is not that large. If you do feel strongly about code-sharing, move gradForMinAndMax to a separate file (grad_for_min_max.ts).
+1 ultimately we want to remove the reduction ops file.
tfjs-core/src/ops/max.ts, line 60 at r2 (raw file):
let $x = convertToTensor(x, 'x', 'max');
const origAxes = util.parseAxisParam(axis, $x.shape);
Let's move this into the forward func to prevent other parts of the code (e.g. the kernel definition) from relying on transformations done here. As is, it looks like $x and axes might be mutated in the if block below in a way that a kernel might not do directly. This relates to the ability to directly call runKernel from the converter (which we may need to discuss at a high level, so I'll set up a sync on this).
tfjs-core/src/ops/max.ts, line 73 at r2 (raw file):
  return y;
}, {x: $x}, null /* gradient */, 'Max', {axes});

if (keepDims) {
I think keepDims should be part of the attrs, and this part should be done by the kernel. Otherwise, calling runKernel and bypassing the op doesn't allow use of keepDims.
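A sketch of the bypass being described, assuming a string-based ENGINE.runKernel entry point and the attr names settled on later in this review (reductionIndices, keepDims):

```ts
// Hypothetical direct kernel invocation (e.g. from the converter),
// skipping the max() op wrapper entirely:
const y = ENGINE.runKernel('Max', {x}, {reductionIndices: axis, keepDims});
// If keepDims is not an attr, this caller has no way to request it.
```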
annxingyuan
left a comment
Reviewable status: complete! 3 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, @lina128, and @tafsiri)
tfjs-backend-cpu/src/kernels/Max.ts, line 54 at r3 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Do we need this here? Should this be at the op level? Or might other backends support complex inputs?
WebGL supports complex inputs, but not CPU.
tfjs-backend-webgl/src/kernels/Max.ts, line 18 at r3 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
An optional suggestion: if you have a higher-level file in webgl that exports the shared symbols from cpu, then this file can import just the symbols you want (maxImpl in this case) by name from that file.
Done
Is this what you had in mind?
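For the record, the shape of the suggestion might look like this (the file name and import paths are assumptions, not the actual layout in this PR):

```ts
// kernel_utils/shared.ts (hypothetical): a single re-export hub in the webgl
// backend for the impls it borrows from the cpu backend.
export {maxImpl as maxImplCPU} from '../../cpu/kernels/Max_impl';
export {transposeImpl as transposeImplCPU} from '../../cpu/kernels/Transpose_impl';

// Consumers then import just the symbols they need, by name:
// import {maxImplCPU} from '../kernel_utils/shared';
```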
tfjs-backend-webgl/src/kernels/Transpose_impl.ts, line 18 at r3 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Optional suggestion: same idea as above, to import these by name directly from a re-export of the cpu impls in webgl.
Done
tfjs-core/src/ops/max.ts, line 60 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Let's move this into the forward func to prevent other parts of the code (e.g. the kernel definition) from relying on transformations done here. As is, it looks like $x and axes might be mutated in the if block below in a way that a kernel might not do directly. This relates to the ability to directly call runKernel from the converter (which we may need to discuss at a high level, so I'll set up a sync on this).
Are you saying that we should move the origAxes definition into the forward function?
tfjs-core/src/ops/max.ts, line 73 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
I think keepDims is an attribute: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/max
Done
tfjs-core/src/ops/max.ts, line 63 at r3 (raw file):
Previously, lina128 (Na Li) wrote…
nit: forward: ForwardFunc
Done
tfjs-core/src/ops/max.ts, line 77 at r3 (raw file):
Previously, lina128 (Na Li) wrote…
Do we need to define a type for inputs?
Done
tfjs-core/src/ops/max.ts, line 80 at r3 (raw file):
Previously, lina128 (Na Li) wrote…
nit: reshape(res, newShape)
Done
tfjs-core/src/ops/reduce_util.ts, line 32 at r3 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
are these used across multiple backends?
Looks like no - I moved them into webgl.
tfjs-core/src/public/chained_ops/max.ts, line 29 at r3 (raw file):
Previously, lina128 (Na Li) wrote…
this.throwIfDisposed();
Done
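For context, chained-op files follow this general pattern (a sketch; the declaration-merging details and exact paths are assumptions):

```ts
// tfjs-core/src/public/chained_ops/max.ts -- sketch of the pattern.
import {max} from '../../ops/max';
import {Tensor} from '../../tensor';
import {Rank} from '../../types';

declare module '../../tensor' {
  interface Tensor<R extends Rank = Rank> {
    max<T extends Tensor>(axis?: number|number[], keepDims?: boolean): T;
  }
}

Tensor.prototype.max = function<T extends Tensor>(
    axis?: number|number[], keepDims?: boolean): T {
  this.throwIfDisposed();  // guard against use-after-dispose
  return max(this, axis, keepDims) as T;
};
```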
tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts, line 23 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
s/kerma/karma
Done
tfjs-core/src/backends/webgl/kernels/Max_impl.ts, line 24 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
+1
Done
annxingyuan
left a comment
Reviewable status: complete! 3 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, @lina128, and @tafsiri)
tfjs-core/src/gradients/Max_grad.ts, line 21 at r1 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Looks like this is done (assuming the import changed). Ann, could you confirm?
Yes this is done.
annxingyuan
left a comment
Reviewable status: complete! 3 of 1 approvals obtained (waiting on @dsmilkov, @lina128, and @tafsiri)
tfjs-core/src/kernel_names.ts, line 42 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Any update/comment on keepDims?
I've added keepDims.
tfjs-core/src/kernel_names.ts, line 43 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
While axes is more correct, should we change this to axis to match the op API and the convention in TF?
Changed to reductionIndices.
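So the attrs end up roughly as follows (a sketch reflecting the renames agreed in this thread):

```ts
import {NamedTensorInfoMap} from '../kernel_registry';

export const Max = 'Max';
export type MaxInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface MaxAttrs {
  reductionIndices: number|number[];
  keepDims: boolean;
}
```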
tfjs-core/src/backends/cpu/backend_cpu.ts, line 383 at r1 (raw file):
Previously, dsmilkov (Daniel Smilkov) wrote…
just curious what's the blocker for softmax to get modularized (since it keeps yielding TODOs)
We're hoping to modularize all the intermediate kernels first. There are still a few remaining (subtract, exp, sum).
tfjs-core/src/backends/webgl/backend_webgl.ts, line 657 at r2 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
this can just be TensorInfo because Tensor is a TensorInfo
Done
tafsiri
left a comment
Great work Ann! I think I left one comment/suggestion. But otherwise good to go.
Reviewed 1 of 32 files at r3, 16 of 16 files at r4.
Reviewable status: complete! 4 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @lina128)
tfjs-backend-cpu/src/kernels/Max.ts, line 54 at r3 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
WebGL supports complex inputs, but not CPU.
Should we throw a more descriptive error message, so that if a user switches from webgl to cpu, they know that it is really an intended difference between the backends?
tfjs-backend-webgl/src/kernels/Max.ts, line 18 at r3 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Done
Is this what you had in mind?
Yep!
tfjs-core/src/ops/max.ts, line 60 at r2 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Are you saying that we should move the origAxes definition into the forward function?
Yep. And pass axis (as originally passed into the op :) to the kernel.
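That is, roughly (a sketch; the helper names follow the existing axis_util utilities, and the exact body is illustrative):

```ts
import {ForwardFunc} from '../engine';
import * as axis_util from './axis_util';
import * as util from '../util';

// Inside max(): the forward func owns all axis normalization, so a direct
// runKernel('Max', ...) call sees the same raw `axis` the op received.
const forward: ForwardFunc<Tensor> = (backend, save) => {
  const origAxes = util.parseAxisParam(axis, $x.shape);
  let axes = origAxes;
  const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);
  let maxInput = $x;
  if (permutedAxes != null) {
    maxInput = $x.transpose(permutedAxes);
    axes = axis_util.getInnerMostAxes(axes.length, maxInput.rank);
  }
  /* ...reduce maxInput over axes, save tensors for the gradient... */
};
```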
annxingyuan
left a comment
Reviewable status: complete! 4 of 1 approvals obtained (waiting on @dsmilkov, @lina128, and @tafsiri)
tfjs-backend-cpu/src/kernels/Max.ts, line 54 at r3 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Should we throw a more descriptive error message, so that if a user switches from webgl to cpu, they know that it is really an intended difference between the backends?
Good idea - I added a note in the error message that this is specific to the CPU backend.
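The check ends up something like this (a sketch; the exact wording, and whether it goes through an assertNotComplex-style helper, are assumptions):

```ts
// tfjs-backend-cpu/src/kernels/Max.ts -- sketch of a backend-specific error.
if (x.dtype === 'complex64') {
  throw new Error(
      'Max does not yet support complex64 input in the CPU backend. ' +
      'Note: the WebGL backend does support complex inputs.');
}
```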
tfjs-core/src/ops/max.ts, line 60 at r2 (raw file):
Previously, tafsiri (Yannick Assogba) wrote…
Yep. And pass axis (as originally passed into the op :) to the kernel.
Done