Skip to content

Commit

Permalink
[JS/WebGPU] Add GatherBlockQuantized op support (microsoft#21734)
Browse files Browse the repository at this point in the history
### Description
Add GatherBlockQuantized operator to JSEP.



### Motivation and Context
The Gemma model requires this operator.
  • Loading branch information
satyajandhyala authored Aug 26, 2024
1 parent ad38212 commit af18824
Show file tree
Hide file tree
Showing 14 changed files with 567 additions and 14 deletions.
4 changes: 3 additions & 1 deletion js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ export class Tensor implements TensorInterface {
type !== 'int64' &&
type !== 'uint32' &&
type !== 'uint8' &&
type !== 'bool'
type !== 'bool' &&
type !== 'uint4' &&
type !== 'int4'
) {
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
Expand Down
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Do not modify directly.*
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherBlockQuantized | com.microsoft(1+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| Gelu | ai.onnx(20+); com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum';
import { expand } from './ops/expand';
import { fastGelu } from './ops/fast-gelu';
import { gather, parseGatherAttributes } from './ops/gather';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
Expand Down Expand Up @@ -96,6 +97,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
Expand Down
5 changes: 4 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s
throw new Error('bool must be vec4');
}
return ['u32', 'vec4<bool>'];

case DataType.int4:
return 'i32';
case DataType.uint4:
return 'u32';
default:
throw new Error(`Unknown data type: ${type}`);
}
Expand Down
196 changes: 196 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
createTensorShapeVariables,
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglValueType,
UniformsArrayType,
} from './common';

/**
 * Attributes of the GatherBlockQuantized contrib operator.
 */
export interface GatherBlockQuantizedAttributes extends AttributeWithCacheKey {
  // Axis of the data tensor along which indices are gathered.
  gatherAxis: number;
  // Axis of the data tensor along which block quantization was applied.
  quantizeAxis: number;
  // Number of consecutive elements along quantizeAxis that share one scale (and zero point).
  blockSize: number;
}

/**
 * Validates the inputs of GatherBlockQuantized.
 *
 * Expected inputs:
 *   inputs[0] (data):       block-quantized tensor (int4/uint4 per the shader codegen in this file).
 *   inputs[1] (indices):    gather indices (bounds check currently disabled, see TODO below).
 *   inputs[2] (scales):     same rank as data; equal to data's dims on every axis except the
 *                           quantize axis, where it is ceil(dim / blockSize) — one scale per block.
 *   inputs[3] (zero_points, optional): same data type as data and exactly the same dims as scales.
 *
 * @throws Error when the input count, shapes, or data types are inconsistent.
 */
export const validateInputs = (inputs: readonly TensorView[], attributes: GatherBlockQuantizedAttributes): void => {
  if (inputs.length < 3 || inputs.length > 4) {
    throw new Error('GatherBlockQuantized requires 3 or 4 inputs.');
  }
  const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputs[0].dims.length);
  const blockSize = attributes.blockSize;
  const data = inputs[0];
  const scales = inputs[2];
  const zeroPoint = inputs.length === 4 ? inputs[3] : undefined;
  // Scales carry one value per block along the quantize axis and match data elsewhere.
  if (
    scales.dims.length !== data.dims.length ||
    !data.dims.every((d, i) => (i === quantizeAxis ? Math.ceil(d / blockSize) === scales.dims[i] : d === scales.dims[i]))
  ) {
    // The mismatch (if any) is on the quantize axis, not the gather axis.
    throw new Error(
      'Scales must have the same rank as the input tensor and the dims should match except on quantizeAxis.',
    );
  }
  // TODO Uncomment the following check once the test case creation code is fixed to create data correctly aligned.
  // const indices = inputs[1];
  // const validIndex = (index: number) => index >= 0 && index < data.dims[attributes.gatherAxis];
  // if (indices.dataType === DataType.int32 && indices.getInt32Array().some((v) => !validIndex(v)) ||
  //     indices.dataType === DataType.int64 && indices.getBigInt64Array().some((v) => !validIndex(Number(v)))) {
  //   throw new Error('Indices must be within the bounds of the gatherAxis.');
  // }
  if (zeroPoint) {
    if (zeroPoint.dataType !== data.dataType) {
      throw new Error('Zero point must have the same data type as the input tensor.');
    }
    // Zero points are laid out exactly like the scales: one per block, identical shape.
    if (
      zeroPoint.dims.length !== scales.dims.length ||
      !zeroPoint.dims.every((d, i) => d === scales.dims[i])
    ) {
      throw new Error('Zero point must have the same rank and dims as the scales tensor.');
    }
  }
};

/**
 * Builds the WebGPU ProgramInfo (uniforms, WGSL shader, dispatch geometry) for
 * GatherBlockQuantized: one invocation per output element gathers a packed 4-bit
 * value from `data`, then dequantizes it with the per-block scale (and optional
 * zero point) into the output type.
 */
const createGatherBlockQuantizedProgramInfo = (
  inputs: readonly TensorView[],
  attributes: GatherBlockQuantizedAttributes,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const indicesShape = inputs[1].dims;
  const inputRank = inputShape.length;
  const gatherAxis = ShapeUtil.normalizeAxis(attributes.gatherAxis, inputRank);
  const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputRank);
  // Output shape: the data shape with the gather axis replaced by the indices shape.
  const outputShape = inputShape.slice(0);
  outputShape.splice(gatherAxis, 1, ...indicesShape);
  const outputSize = ShapeUtil.size(outputShape);
  // The output takes the scales' (dequantized) data type; the data itself is 4-bit.
  const outputType = inputs[2].dataType;
  const inputType = inputs[0].dataType;
  const isSigned = inputType === DataType.int4; // input data type is either int4 or uint4.
  // Scalar uniforms first, then the shape/stride uniforms for all inputs and the output.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: quantizeAxis },
    { type: DataType.uint32, data: gatherAxis },
    { type: DataType.uint32, data: attributes.blockSize },
    ...createTensorShapeVariables(...inputs.map((input, _) => input.dims), outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
    const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length);
    const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
    // Zero point is the optional 4th input.
    const zeroPoint =
      inputs.length > 3 ? inputVariable('zeroPoint', inputs[3].dataType, inputs[3].dims.length) : undefined;
    const output = outputVariable('output', outputType, outputShape.length);
    const inputVariables = [data, indices, scales];
    if (zeroPoint) {
      inputVariables.push(zeroPoint);
    }
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'quantize_axis', type: 'u32' },
      { name: 'gather_axis', type: 'u32' },
      { name: 'block_size', type: 'u32' },
    ];
    return `
        ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart()}
        let output_indices = ${output.offsetToIndices('global_idx')};
        var indices_indices = ${indices.type.indices}(0);
        ${(() => {
          // The indices coordinates occupy the output dims starting at gather_axis.
          if (indicesShape.length > 1) {
            return `
          for (var i: u32 = 0; i < ${indicesShape.length}; i++) {
            let index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')};
            ${indices.indicesSet('indices_indices', 'i', 'index')};
          }`;
          } else {
            return `indices_indices = ${output.indicesGet('output_indices', 'uniforms.gather_axis')};`;
          }
        })()};
        var data_indices = ${data.type.indices}(0);
        for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
          let index = ${output.indicesGet('output_indices', 'i')};
          ${data.indicesSet('data_indices', 'i', 'index')};
        }
        var index_from_indices = ${indices.getByIndices('indices_indices')};
        if (index_from_indices < 0) {
          index_from_indices += ${inputShape[gatherAxis]};
        }
        ${data.indicesSet('data_indices', 'uniforms.gather_axis', 'u32(index_from_indices)')};
        for (var i = uniforms.gather_axis + 1; i < ${outputShape.length}; i++) {
          let index = ${output.indicesGet('output_indices', `i + ${indicesShape.length} - 1`)};
          ${data.indicesSet('data_indices', 'i', 'index')};
        }
        let data_offset = ${data.indicesToOffset('data_indices')};
        let data_index = data_offset % 8;
        // Convert 4-bit packed data to 8-bit packed data.
        let packed_4bit_quantized_data = ${data.getByOffset('data_offset / 8')};
        let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
        let quantized_data_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_quantized_data));
        let quantized_data = quantized_data_vec[data_index / 2];
        var scale_indices = data_indices;
        let quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size;
        ${scales.indicesSet('scale_indices', 'uniforms.quantize_axis', 'quantize_axis_index')};
        var scale = ${scales.getByIndices('scale_indices')};
        ${(() => {
          // Without a zero-point tensor the zero point is 0; otherwise unpack it from
          // the same 4-bit packing scheme used for the data tensor above.
          if (!zeroPoint) {
            return 'var zero_point = 0';
          } else {
            return `
          let zero_point_indices = scale_indices;
          let zero_point_offset = ${zeroPoint.indicesToOffset('zero_point_indices')};
          let zero_point_index = zero_point_offset % 8;
          let packed_4bit_zero_points = ${zeroPoint.getByOffset('zero_point_offset / 8')};
          let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;
          let zero_point_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_zero_points));
          let zero_point = zero_point_vec[zero_point_index / 2];`;
          }
        })()};
        let dequantized_data = ${tensorTypeToWsglValueType(outputType)}(quantized_data - zero_point) * scale;
        ${output.setByOffset('global_idx', 'dequantized_data')};
    }`;
  };
  return {
    name: 'GatherBlockQuantized',
    shaderCache: {
      // Indices (input 1) affect only uniforms, so their dims are excluded from the hint.
      hint: `${attributes.cacheKey};${inputs
        .filter((_, i) => i !== 1)
        .map((input) => input.dims.join('_'))
        .join(';')}`,
      inputDependencies: Array.from({ length: inputs.length }, (_v, _i) => 'rank'),
    },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};

/** Entry point: validates the context's inputs, then schedules the GatherBlockQuantized program. */
export const gatherBlockQuantized = (context: ComputeContext, attributes: GatherBlockQuantizedAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs, attributes);
  context.compute(createGatherBlockQuantizedProgramInfo(inputs, attributes));
};

/** Converts the raw attribute record into a cache-keyed GatherBlockQuantizedAttributes. */
export const parseGatherBlockQuantizedAttributes = (
  attributes: Record<string, unknown>,
): GatherBlockQuantizedAttributes => {
  const blockSize = attributes.blockSize as number;
  const gatherAxis = attributes.gatherAxis as number;
  const quantizeAxis = attributes.quantizeAxis as number;
  return createAttributeWithCacheKey({ blockSize, gatherAxis, quantizeAxis });
};
18 changes: 9 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
let pastSequenceLength = 0;
let maxSequenceLength = 0;
const headSize = Math.floor(hiddenSize / attributes.numHeads);
if (pastKey && pastValue) {
if (pastKey && pastValue && ShapeUtil.size(pastKey.dims) && ShapeUtil.size(pastValue.dims)) {
if (pastKey.dims.length !== 4) {
throw new Error('Input "past_key" is expected to have 4 dimensions');
}
Expand All @@ -107,12 +107,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
}
pastSequenceLength = pastKey.dims[2];
maxSequenceLength = pastKey.dims[2];
} else if (pastKey || pastValue) {
} else if ((pastKey && ShapeUtil.size(pastKey.dims)) || (pastValue && ShapeUtil.size(pastValue.dims))) {
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
}

let qkvFormat: AttentionQkvFormat;
if (key) {
if (key && ShapeUtil.size(key.dims) > 0) {
if (query.dims.length !== 3) {
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
}
Expand Down Expand Up @@ -159,7 +159,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
qkvFormat = AttentionQkvFormat.qkvBSN3H;
}

if (bias) {
if (bias && ShapeUtil.size(bias.dims) > 0) {
if (bias.dims.length !== 1) {
throw new Error('Input "bias" is expected to have 1 dimension');
}
Expand All @@ -174,7 +174,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
const totalSequenceLength = pastSequenceLength + kvSequenceLength;

let maskType: AttentionMaskType = AttentionMaskType.none;
if (keyPaddingMask) {
if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) {
maskType = AttentionMaskType.maskUnknown;
const maskDims = keyPaddingMask.dims;
if (maskDims.length === 1) {
Expand All @@ -194,7 +194,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr

let passPastInKv = false;
let vHiddenSize = hiddenSize;
if (value) {
if (value && ShapeUtil.size(value.dims) > 0) {
if (value.dims.length !== 3 && value.dims.length !== 4) {
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
}
Expand All @@ -220,11 +220,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr

const broadcastResPosBias = false;

if (keyPaddingMask) {
if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) {
throw new Error('Key padding mask is not supported');
}

if (attentionBias) {
if (attentionBias && ShapeUtil.size(attentionBias.dims) > 0) {
if (attentionBias.dims.length !== 4) {
throw new Error('Input "attention_bias" is expected to have 4 dimensions');
}
Expand Down Expand Up @@ -334,7 +334,7 @@ export const maybeTransposeToBNSHAndAddBias = (
// const newDims = [];

let reshapedInput = input;
if (!bias) {
if (!(bias && ShapeUtil.size(bias.dims) > 0)) {
if (input.dims.length === 3) {
reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
}
Expand Down
4 changes: 3 additions & 1 deletion js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB
type === 'int64' ||
type === 'uint32' ||
type === 'uint8' ||
type === 'bool';
type === 'bool' ||
type === 'uint4' ||
type === 'int4';

/**
* Map string data location to integer value
Expand Down
1 change: 1 addition & 0 deletions js/web/script/generate-webgpu-operator-md.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const MATCHERS = [
/class ONNX_OPERATOR_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersionStart>\d+),\s*(?<opsetVersionEnd>\d+),\s*(?<type>\w+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<type>\w+),\s*(?<op>\w+)\)/g,
/class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME\(\s*(?<ep>\w+),\s*(?<opsetDomain>\w+),\s*(?<opsetVersion>\d+),\s*(?<type1>\w+),\s*(?<type2>\w+),\s*(?<op>\w+)\)/g,
];
/* eslint-enable max-len */

Expand Down
Loading

0 comments on commit af18824

Please sign in to comment.