Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Support WEBGPU_PRINT_SHADER #7523

Merged
merged 2 commits into from
Apr 10, 2023
Merged

Conversation

axinging
Copy link
Contributor

@axinging axinging commented Mar 28, 2023

To see the printed shader, turn on 'Verbose' in console.
Below is a test page(example url: index.html?WEBGPU_PRINT_SHADER=all, index.html?WEBGPU_PRINT_SHADER=binary, index.html?WEBGPU_PRINT_SHADER=binary,depth):

async function testWebGPUPrintShader() {
  tf.env().set('WEBGPU_CPU_FORWARD', false);
  await tf.setBackend('webgpu');
  await tf.ready();
  const re = getURLState(location.search);
  tf.env().set('WEBGPU_PRINT_SHADER', re);
  console.log(tf.env().get('WEBGPU_PRINT_SHADER'));
  // depthwise, matches 'depth'.
  {
    const fSize = 2;
    const pad = 'valid';
    const stride = 1;
    const chMul = 1;
    const inDepth = 1;

    const x = tf.tensor4d(
        [
          0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641,
          0.0741907, 0.409265, 0.351377
        ],
        [1, 3, 3, inDepth]);
    const w = tf.tensor4d(
        [0.303873, 0.229223, 0.144333, 0.803373],
        [fSize, fSize, inDepth, chMul],
    );

    const result = tf.depthwiseConv2d(x, w, stride, pad);
  }

  // add(sub,mul), matches 'binary'(Full binary list: https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-webgpu/src/binary_op_util.ts).
  {
    const a = tf.tensor2d([1, 2], [1, 2]);
    const b = tf.tensor2d([1, 2], [1, 2]);
    const c = tf.add(a, b);
  }

  // maxPool, matches 'pool'.
  {
    const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 9, 8], [3, 3, 1]);

    const result = tf.maxPool(x, 2, 1, 0);
  }
}

function getURLState(url) {
  let params = new URLSearchParams(url);
  const keys = [...params.keys()];
  if (keys.length === 0) return '';
  let printShaderString = '';
  if (params.has('WEBGPU_PRINT_SHADER')) {
    printShaderString = params.get('WEBGPU_PRINT_SHADER');
  }
  return printShaderString;
}

Bug: #7516


This change is Reviewable

@axinging axinging changed the title Support WEBGPU_PRINT_SHADER [webgpu] Support WEBGPU_PRINT_SHADER Mar 28, 2023
@@ -56,6 +58,7 @@ const TUNABLE_FLAG_NAME_MAP = {
if (tf.engine().backendNames().includes('webgpu')) {
TUNABLE_FLAG_NAME_MAP['WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE'] =
'deferred submit batch size';
TUNABLE_FLAG_NAME_MAP['WEBGPU_PRINT_SHADER'] = 'Print shader(s)';
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just “Print shader” to be consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

pipeline = webgpu_program.compileProgram(
this.device, program, inputsData, output);
this.device, program, inputsData, output, printShader);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add another parameter "printShader" here couldn't save the checks. We may just put the check in compileProgram().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -65,6 +65,10 @@ export const compileProgram =
layout: 'auto'
});

if (printShader) {
console.log(program.constructor.name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can put some decoration around the name so that it can be easily found.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

/**
* Whether print shader.
*/
ENV.registerFlag('WEBGPU_PRINT_SHADER', () => false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to wait for the string flag to be landed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this change depends on #7509

@axinging axinging force-pushed the print_shader branch 2 times, most recently from cf99132 to a61a2c7 Compare March 30, 2023 01:34
@axinging axinging marked this pull request as ready for review March 30, 2023 01:44
@axinging
Copy link
Contributor Author

@gyagp @qjia7 @xhcao PTAL


const printShaderArray = printShaderString.split(',');
if (printShaderString === 'all' ||
printShaderArray.some(item => program.shaderKey.includes(item))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to convert program.shaderKey to lower case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

throw new Error(
`WEBGPU_PRINT_SHADER doesn't support value ` +
`'${printShaderString}'. ` +
`It should be 'all' or program name`);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can the user know the supported program name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this to shader key.

@axinging axinging force-pushed the print_shader branch 4 times, most recently from 63e1ace to 4eb15f5 Compare March 31, 2023 07:45
Copy link
Contributor

@gyagp gyagp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay!
The changes in WebGPU backend looks good. But I don't like the changes related to untunableFlagConfig.

@@ -54,7 +54,7 @@ export interface WebGPUProgram {

export const compileProgram =
(device: GPUDevice, program: WebGPUProgram, inputsData: InputInfo[],
output: TensorInfo): GPUComputePipeline => {
output: TensorInfo, key: string): GPUComputePipeline => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to shaderKey

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -818,7 +818,7 @@ export class WebGPUBackend extends KernelBackend {
pipeline = this.pipelineCache[key];
} else {
pipeline = webgpu_program.compileProgram(
this.device, program, inputsData, output);
this.device, program, inputsData, output, key);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this as shaderKey

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -113,6 +114,9 @@ <h2>TensorFlow.js Model Benchmark</h2>
if (params.has('task')) {
task = params.get('task');
}
if (params.has('WEBGPU_PRINT_SHADER')) {
untunableFlagsConfig['WEBGPU_PRINT_SHADER'] = params.get('WEBGPU_PRINT_SHADER');
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just tf.env().setFlags() here and avoid all the changes related to untunableFlagsConfig?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@axinging axinging left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @gyagp , PTAL

@@ -113,6 +114,9 @@ <h2>TensorFlow.js Model Benchmark</h2>
if (params.has('task')) {
task = params.get('task');
}
if (params.has('WEBGPU_PRINT_SHADER')) {
untunableFlagsConfig['WEBGPU_PRINT_SHADER'] = params.get('WEBGPU_PRINT_SHADER');
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -818,7 +818,7 @@ export class WebGPUBackend extends KernelBackend {
pipeline = this.pipelineCache[key];
} else {
pipeline = webgpu_program.compileProgram(
this.device, program, inputsData, output);
this.device, program, inputsData, output, key);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -54,7 +54,7 @@ export interface WebGPUProgram {

export const compileProgram =
(device: GPUDevice, program: WebGPUProgram, inputsData: InputInfo[],
output: TensorInfo): GPUComputePipeline => {
output: TensorInfo, key: string): GPUComputePipeline => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@qjia7
Copy link
Contributor

qjia7 commented Apr 7, 2023

@axinging Can you remove the e2e related changes? The webgpu part LGTM. Then we can merge this PR.

Offline synced with Yang, for e2e changes, we would like to support any env flag as the url flag. It will be very convenient for us if want to test any env flag which is not in the current tunable flags list. In this way, we don't have to add new flag one by one to e2e , such as WEBGPU_PRINT_SHADER. For implementation, we can always treat unrecognized url flags as env flags and cache them like you did in this PR. You may need to carefully process the env flags if they are both in url flags and tunable flags list.

Copy link
Contributor

@gyagp gyagp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for splitting this!

@gyagp gyagp merged commit 8db53b0 into tensorflow:master Apr 10, 2023
@axinging axinging deleted the print_shader branch May 9, 2023 05:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants