[webgpu] Add batchNormal kernel #2652
Conversation
annxingyuan
left a comment
Thanks @NALLEIN ! Left a few comments.
| ]
| }
| },
| {include: 'deprecated batchNormalization', excludes: []}
Do the non-deprecated batchnorm tests not work?
| input.size < sizeThreshold);
| }
|
| batchNormalization(
We are moving towards a fully modular setup - would you mind extracting batchNormalization into its own kernel file, similar to the Square op? WebGL and CPU backends have not yet done this, but they will.
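The modular setup the reviewer describes can be sketched as follows: each op lives in its own file and exports a single `KernelConfig` that the backend registers. This is a minimal, self-contained sketch; the `TensorInfo` and `KernelConfig` interfaces here are simplified stand-ins for the real `@tensorflow/tfjs-core` types, and the kernel body only propagates shape metadata instead of dispatching a shader.

```typescript
// Simplified stand-ins for the @tensorflow/tfjs-core interfaces.
interface TensorInfo { shape: number[]; dtype: string; }
interface KernelConfig {
  kernelName: string;
  backendName: string;
  kernelFunc: (args: {inputs: Record<string, TensorInfo>}) => TensorInfo;
}

// One kernel per file: the file exports exactly one KernelConfig,
// which the backend's registry picks up at startup.
export const squareConfig: KernelConfig = {
  kernelName: 'Square',
  backendName: 'webgpu',
  kernelFunc: ({inputs}) => {
    const {x} = inputs;
    // A real kernel would compile and dispatch a shader here;
    // this sketch only propagates the shape/dtype metadata.
    return {shape: x.shape.slice(), dtype: x.dtype};
  },
};

const out = squareConfig.kernelFunc(
    {inputs: {x: {shape: [2, 3], dtype: 'float32'}}});
console.log(out.shape);  // the input shape, propagated unchanged
```

The payoff of this structure is that adding or removing an op never touches the backend class itself, only the list of registered configs.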
| * =============================================================================
| */
|
| import {KernelConfig, FusedBatchNorm, FusedBatchNormInputs,FusedBatchNormAttrs, Tensor} from '@tensorflow/tfjs-core';
Missing FusedBatchNormInputs and FusedBatchNormAttrs? These should be defined in tfjs-core. And why fused?
It was defined in https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/kernel_names.ts#L36, and it has the same inputs as batchNorm, so I reused it.
We are trying to match TensorFlow for kernel names as much as possible - the batchNormalization kernel is called FusedBatchNorm in TensorFlow.
| import {WebGPUBackend} from '../backend_webgpu';
| import {BatchNormProgram} from './batchnorm_webgpu';
|
| export const batchNromConfig : KernelConfig = {
Could you rename this file to be FusedBatchNorm.ts? See our WASM kernel for reference. This matches TensorFlow's kernel naming scheme.
| batchNormInputs.push(scale as Tensor);
| }
| const program = new BatchNormProgram(
|     x.shape,mean.shape,variance.shape,offsetShape,scaleShape,
Could you install our code formatter, clang-format? Otherwise whitespace diffs will appear in later commits :)
Details for installation here: https://github.com/tensorflow/tfjs/blob/master/DEVELOPMENT.md
| const {x} = inputs as RsqrtInputs;
| const webGPUBackend = backend as WebGPUBackend;
| const program = new UnaryOpProgram(x.shape, RSQRT);
| return webGPUBackend.compileAndRun(program, [x as Tensor]);
Is this cast necessary?
annxingyuan
left a comment
Left a few more comments!
I fixed some problems, please review.
| this.dispatchLayout = {x: [1, 2], y: [0], z: []};
| setOutput = 'setOutput(coords[0], coords[1], coords[2], value);';
| break;
| case 4:
Looks like case 4 and default are the same - can you use fallthrough syntax to avoid duplication?
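The fallthrough the reviewer suggests looks like this: when a `case` label has no statements of its own, control falls through to the next label, so identical branches can share one body. A minimal sketch using the dispatch layouts from the diff above (`dispatchLayoutFor` is a hypothetical helper, not the real kernel code):

```typescript
// When `case 4` and `default` share a body, an empty `case 4:` label
// falls through to `default` instead of duplicating the branch.
function dispatchLayoutFor(rank: number):
    {x: number[], y: number[], z: number[]} {
  switch (rank) {
    case 3:
      return {x: [1, 2], y: [0], z: []};
    case 4:  // identical to default: fall through instead of duplicating
    default:
      return {x: [1, 2], y: [0], z: [3]};
  }
}

console.log(dispatchLayoutFor(4));  // same layout as the default branch
```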
| default:
|   this.dispatchLayout = {x: [1, 2], y: [0], z: [3]};
|   setOutput =
|       'setOutput(coords[0], coords[1], coords[2], coords[3],value);';
Can we use if/else rather than switch? Also we know that the output and x are 4D - can we avoid this logic altogether? We often reshape tensors before passing them to kernels to make it possible to simplify shader code like this.
I fixed the syntax. I think this logic needs to be preserved: x and the output may be 2D, 3D, or 4D because the ops were modified in #2996, and batchnorm_2d, batchnorm_3d, and batchnorm_4d will all call this kernel.
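The alternative the reviewer mentions, reshaping tensors to a canonical rank before they reach the kernel so the shader handles only one case, can be sketched as follows. `as4D` is a hypothetical helper, not part of the PR; it pads leading dimensions with 1, which leaves the element count and memory layout unchanged.

```typescript
// Normalize a 2-D/3-D/4-D shape to 4-D by prepending size-1 dimensions,
// e.g. [H, W] -> [1, 1, H, W]. Element count is unchanged, so the kernel
// can assume 4-D coordinates unconditionally.
function as4D(shape: number[]): number[] {
  if (shape.length > 4) {
    throw new Error(`rank ${shape.length} not supported`);
  }
  return Array(4 - shape.length).fill(1).concat(shape);
}

console.log(as4D([5, 6]));        // [1, 1, 5, 6]
console.log(as4D([2, 5, 6]));     // [1, 2, 5, 6]
console.log(as4D([3, 2, 5, 6]));  // [3, 2, 5, 6]
```

The trade-off is an extra (cheap) shape adjustment in the op layer in exchange for a single code path in the shader.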
@NALLEIN A few more tiny comments!
| import {BatchNormProgram} from './batchnorm_webgpu';
|
| export const fusedBatchNromConfig: KernelConfig = {
typo: fusedBatchNormConfig
| // List all kernel configs here
| const kernelConfigs: KernelConfig[] =
| -    [squareConfig, squaredDifferenceConfig];
| +    [squareConfig, squaredDifferenceConfig, fusedBatchNromConfig];
typo: fusedBatchNormConfig
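The central list in the diff above feeds a one-time registration step. A minimal sketch of that pattern, with a simplified `KernelConfig` and a hypothetical `registerKernel`/`registry` in place of the real tfjs-core API:

```typescript
// Simplified stand-in for the tfjs-core KernelConfig interface.
interface KernelConfig { kernelName: string; backendName: string; }

// Registry keyed by (kernel name, backend name), so the same kernel
// name can be implemented by multiple backends.
const registry = new Map<string, KernelConfig>();

function registerKernel(config: KernelConfig): void {
  registry.set(`${config.kernelName}_${config.backendName}`, config);
}

// The central list: each kernel file contributes exactly one config.
const kernelConfigs: KernelConfig[] = [
  {kernelName: 'Square', backendName: 'webgpu'},
  {kernelName: 'SquaredDifference', backendName: 'webgpu'},
  {kernelName: 'FusedBatchNorm', backendName: 'webgpu'},
];
for (const config of kernelConfigs) {
  registerKernel(config);
}

console.log(registry.size);  // one entry per registered config
```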
annxingyuan
left a comment
Sorry for the delay on this... just two more typos and we'll be all set!
Oh also, if you wouldn't mind merging in master - we might have a few overlapping changes!
Sorry, I've fixed the name issue.
Fix #2358