-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[webgpu] Support WEBGPU_PRINT_SHADER #7523
Conversation
@@ -56,6 +58,7 @@ const TUNABLE_FLAG_NAME_MAP = { | |||
if (tf.engine().backendNames().includes('webgpu')) { | |||
TUNABLE_FLAG_NAME_MAP['WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE'] = | |||
'deferred submit batch size'; | |||
TUNABLE_FLAG_NAME_MAP['WEBGPU_PRINT_SHADER'] = 'Print shader(s)'; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just “Print shader” to be consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
pipeline = webgpu_program.compileProgram( | ||
this.device, program, inputsData, output); | ||
this.device, program, inputsData, output, printShader); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To add another parameter "printShader" here couldn't save the checks. We may just put the check in compileProgram().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -65,6 +65,10 @@ export const compileProgram = | |||
layout: 'auto' | |||
}); | |||
|
|||
if (printShader) { | |||
console.log(program.constructor.name); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can put some decoration around the name so that it can be easily found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
/** | ||
* Whether print shader. | ||
*/ | ||
ENV.registerFlag('WEBGPU_PRINT_SHADER', () => false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to wait for the string flag to be landed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this change depends on #7509
cf99132
to
a61a2c7
Compare
|
||
const printShaderArray = printShaderString.split(','); | ||
if (printShaderString === 'all' || | ||
printShaderArray.some(item => program.shaderKey.includes(item))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to convert program.shaderKey
to lower case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thanks
throw new Error( | ||
`WEBGPU_PRINT_SHADER doesn't support value ` + | ||
`'${printShaderString}'. ` + | ||
`It should be 'all' or program name`); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How can the user know the supported program name
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed this to shader key.
63e1ace
to
4eb15f5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay!
The changes in WebGPU backend looks good. But I don't like the changes related to untunableFlagConfig.
@@ -54,7 +54,7 @@ export interface WebGPUProgram { | |||
|
|||
export const compileProgram = | |||
(device: GPUDevice, program: WebGPUProgram, inputsData: InputInfo[], | |||
output: TensorInfo): GPUComputePipeline => { | |||
output: TensorInfo, key: string): GPUComputePipeline => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rename this to shaderKey
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -818,7 +818,7 @@ export class WebGPUBackend extends KernelBackend { | |||
pipeline = this.pipelineCache[key]; | |||
} else { | |||
pipeline = webgpu_program.compileProgram( | |||
this.device, program, inputsData, output); | |||
this.device, program, inputsData, output, key); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rename this as shaderKey
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -113,6 +114,9 @@ <h2>TensorFlow.js Model Benchmark</h2> | |||
if (params.has('task')) { | |||
task = params.get('task'); | |||
} | |||
if (params.has('WEBGPU_PRINT_SHADER')) { | |||
untunableFlagsConfig['WEBGPU_PRINT_SHADER'] = params.get('WEBGPU_PRINT_SHADER'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just tf.env().setFlags() here and avoid all the changes related to untunableFlagsConfig?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this line https://github.com/tensorflow/tfjs/pull/7523/files#diff-622e3e249bd7bb5327a80ac8034614cdec1761a725466edce4dd57439d734fb8R756 will reset all env flags.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @gyagp , PTAL
@@ -113,6 +114,9 @@ <h2>TensorFlow.js Model Benchmark</h2> | |||
if (params.has('task')) { | |||
task = params.get('task'); | |||
} | |||
if (params.has('WEBGPU_PRINT_SHADER')) { | |||
untunableFlagsConfig['WEBGPU_PRINT_SHADER'] = params.get('WEBGPU_PRINT_SHADER'); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this line https://github.com/tensorflow/tfjs/pull/7523/files#diff-622e3e249bd7bb5327a80ac8034614cdec1761a725466edce4dd57439d734fb8R756 will reset all env flags.
@@ -818,7 +818,7 @@ export class WebGPUBackend extends KernelBackend { | |||
pipeline = this.pipelineCache[key]; | |||
} else { | |||
pipeline = webgpu_program.compileProgram( | |||
this.device, program, inputsData, output); | |||
this.device, program, inputsData, output, key); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -54,7 +54,7 @@ export interface WebGPUProgram { | |||
|
|||
export const compileProgram = | |||
(device: GPUDevice, program: WebGPUProgram, inputsData: InputInfo[], | |||
output: TensorInfo): GPUComputePipeline => { | |||
output: TensorInfo, key: string): GPUComputePipeline => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@axinging Can you remove the e2e related changes? The webgpu part LGTM. Then we can merge this PR. Offline synced with Yang, for e2e changes, we would like to support any env flag as the url flag. It will be very convenient for us if want to test any env flag which is not in the current tunable flags list. In this way, we don't have to add new flag one by one to e2e , such as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for splitting this!
To see the printed shader, turn on 'Verbose' in console.
Below is a test page(example url: index.html?WEBGPU_PRINT_SHADER=all, index.html?WEBGPU_PRINT_SHADER=binary, index.html?WEBGPU_PRINT_SHADER=binary,depth):
Bug: #7516
This change is