diff --git a/tfjs-backend-webgpu/src/kernels/Multinomial.ts b/tfjs-backend-webgpu/src/kernels/Multinomial.ts
new file mode 100644
index 00000000000..3f5b0d4deaa
--- /dev/null
+++ b/tfjs-backend-webgpu/src/kernels/Multinomial.ts
@@ -0,0 +1,80 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {KernelConfig, KernelFunc, Multinomial, MultinomialAttrs, MultinomialInputs, TensorInfo} from '@tensorflow/tfjs-core';
+
+import {WebGPUBackend} from '../backend_webgpu';
+import {MultinomialProgram} from '../multinomial_webgpu';
+
+import {softmax} from './Softmax';
+
+/**
+ * WebGPU kernel for `Multinomial`: draws outcome indices from `logits`.
+ *
+ * When `attrs.normalized` is false, `logits` is first passed through
+ * `softmax` over its last axis, and that intermediate tensor is disposed
+ * after `runWebGPUProgram` returns. When it is true, `logits` is handed to
+ * the sampling program unchanged.
+ *
+ * NOTE(review): `shape[0]` and `shape[1]` are read as batch size and number
+ * of outcomes, so the input is assumed to be 2-D `[batchSize, numOutcomes]`
+ * -- confirm against the caller.
+ *
+ * @param args.inputs Holds the `logits` tensor.
+ * @param args.backend WebGPU backend that runs softmax and the sampler.
+ * @param args.attrs `numSamples` (second output dimension), `seed` (sent to
+ *     the shader as a `float32` uniform) and `normalized`.
+ * @returns The `int32` result of `MultinomialProgram`, whose output shape is
+ *     `[batchSize, numSamples]`.
+ */
+export function multinomial(args: {
+  inputs: MultinomialInputs,
+  backend: WebGPUBackend,
+  attrs: MultinomialAttrs
+}): TensorInfo {
+  const {inputs, backend, attrs} = args;
+  const {logits} = inputs;
+  const {numSamples, seed, normalized} = attrs;
+
+  // Already-normalized input is used as-is; otherwise softmax over the last
+  // axis is applied to the logits first.
+  const probs = normalized ?
+      logits :
+      softmax(
+          {inputs: {logits}, backend, attrs: {dim: logits.shape.length - 1}});
+  const batchSize = probs.shape[0];
+  const numOutcomes = probs.shape[1];
+  const program = new MultinomialProgram(batchSize, numSamples);
+  // Same order and types as `MultinomialProgram.uniforms`:
+  // 'seed : f32, numOutcomes: i32,'.
+  const uniformData =
+      [{type: 'float32', data: [seed]}, {type: 'int32', data: [numOutcomes]}];
+  const res = backend.runWebGPUProgram(program, [probs], 'int32', uniformData);
+  // `probs` is a temporary only when softmax created it; when `normalized`
+  // is true it is the caller's `logits` itself, so it is left alone.
+  if (!normalized) {
+    backend.disposeData(probs.dataId);
+  }
+  return res;
+}
+
+/** Registers `multinomial` as the `Multinomial` kernel for 'webgpu'. */
+export const multinomialConfig: KernelConfig = {
+  kernelName: Multinomial,
+  backendName: 'webgpu',
+  kernelFunc: multinomial as unknown as KernelFunc
+};
diff --git a/tfjs-backend-webgpu/src/multinomial_webgpu.ts b/tfjs-backend-webgpu/src/multinomial_webgpu.ts
new file mode 100644
index 00000000000..2f50abdc35e
--- /dev/null
+++ b/tfjs-backend-webgpu/src/multinomial_webgpu.ts
@@ -0,0 +1,103 @@
+/**
+ * @license
+ * Copyright 2023 Google LLC.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
+import {computeDispatch, flatDispatchLayout} from './webgpu_util';
+
+/**
+ * WGSL program that writes one outcome index per output element.
+ *
+ * The output has shape `[batchSize, numSamples]`. For each element the shader
+ * hashes `seed` with the element's normalized output coordinates into a value
+ * `r` (a `fract(...)` result), then accumulates `getProbs(batch, i)` over the
+ * first `numOutcomes - 1` outcomes and writes the first `i` with `r < cdf`.
+ * If none qualifies, the last index `numOutcomes - 1` is written.
+ */
+export class MultinomialProgram implements WebGPUProgram {
+  variableNames: string[] = ['probs'];
+  outputShape: number[] = [];
+  shaderKey: string;
+  dispatchLayout: {x: number[]};
+  dispatch: [number, number, number];
+  // Declares two program-specific uniforms: `seed` (f32), `numOutcomes` (i32).
+  uniforms = 'seed : f32, numOutcomes: i32,';
+  workgroupSize: [number, number, number] = [64, 1, 1];
+  // NOTE(review): presumably asks the framework to provide `uniforms.size`,
+  // which the shader uses as its bounds check -- confirm in WebGPUProgram.
+  size = true;
+
+  /**
+   * @param batchSize First output dimension.
+   * @param numSamples Second output dimension.
+   */
+  constructor(batchSize: number, numSamples: number) {
+    this.outputShape = [batchSize, numSamples];
+    this.dispatchLayout = flatDispatchLayout(this.outputShape);
+    this.dispatch = computeDispatch(
+        this.dispatchLayout, this.outputShape, this.workgroupSize);
+
+    // getUserCode() interpolates no constructor argument into the shader,
+    // so the key is a constant.
+    this.shaderKey = 'multinomial';
+  }
+
+  /**
+   * Builds the WGSL source: a `random` hash helper plus the main entry point.
+   *
+   * Vector types carry an explicit `<f32>` element type; WGSL has no bare
+   * `vec2` type, so a parameter declared as `resultUV : vec2` is rejected.
+   *
+   * @returns The shader source as a string.
+   */
+  getUserCode(): string {
+    const userCode = `
+    //Based on the work of Dave Hoskins
+    //https://www.shadertoy.com/view/4djSRW
+    fn random (seed : f32, resultUV : vec2<f32>) -> f32 {
+      let HASHSCALE1 = 443.8975;
+      let p = resultUV * seed;
+      var p3 = fract(vec3<f32>(p.xyx) * HASHSCALE1);
+      p3 = p3 + dot(p3, p3.yzx + 19.19);
+      return fract((p3.x + p3.y) * p3.z);
+    }
+
+    ${main('index')} {
+      if (index < uniforms.size) {
+        let coords = getOutputCoords();
+        let batch = coords[0];
+
+        let resUV = vec2<f32>(f32(coords[1]) / f32(uniforms.outShape[1]),
+          f32(coords[0]) / f32(uniforms.outShape[0]));
+        let r = random(uniforms.seed, resUV);
+        var cdf = 0.0;
+        for (var i = 0; i < uniforms.numOutcomes - 1; i = i + 1) {
+          cdf = cdf + getProbs(batch, i);
+
+          if (r < cdf) {
+            setOutputAtIndexI32(index, i);
+            return;
+          }
+        }
+
+        // If no other event happened, last event happened.
+        setOutputAtIndexI32(index, uniforms.numOutcomes - 1);
+      }
+    }
+    `;
+    return userCode;
+  }
+}
diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts
index eeaa0f5e0e0..a08fb1ac8fc 100644
--- a/tfjs-backend-webgpu/src/register_all_kernels.ts
+++ b/tfjs-backend-webgpu/src/register_all_kernels.ts
@@ -98,6 +98,7 @@ import {minConfig} from './kernels/Min';
 import {minimumConfig} from './kernels/Minimum';
 import {mirrorPadConfig} from './kernels/MirrorPad';
 import {modConfig} from './kernels/Mod';
+import {multinomialConfig} from './kernels/Multinomial';
 import {multiplyConfig} from './kernels/Multiply';
 import {negConfig} from './kernels/Neg';
 import {nonMaxSuppressionV3Config} from './kernels/NonMaxSuppressionV3';
@@ -238,6 +239,7 @@ const kernelConfigs: KernelConfig[] = [
   minimumConfig,
   mirrorPadConfig,
   modConfig,
+  multinomialConfig,
   multiplyConfig,
   negConfig,
   nonMaxSuppressionV3Config,
diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts
index 4374ad489fc..55f48db9a8d 100644
--- a/tfjs-backend-webgpu/src/setup_test.ts
+++ b/tfjs-backend-webgpu/src/setup_test.ts
@@ -257,7 +257,6 @@ const TEST_FILTERS: TestFilter[] = [
       'maxPool3dBackprop ',
       'maxPoolBackprop ',
       'maxPoolWithArgmax ',
-      'multinomial ',
       'raggedGather ',
       'raggedRange ',
       'raggedTensorToTensor ',