
Support for transformers #375

Open
dontcallmedom opened this issue Apr 5, 2023 · 27 comments

Comments

@dontcallmedom
Contributor

While our draft charter says that the group:

priority on building blocks required by well-known model architectures such as recurrent neural network (RNN), long short-term memory (LSTM) and transformers

and while the first two are directly mentioned in WebNN, transformers are not.

@anssiko anssiko added the v2 label Apr 5, 2023
@anssiko
Member

anssiko commented Apr 5, 2023

My expectation is the WG will look at transformers and related requirements and gaps as part of its v2 feature work. We considered the initial CR to be "v1", so we're good to move on this now.

@dani-lbnl

It would be valuable to briefly state the two main kinds of applications of transformers, namely predictive and generative AI, and to clarify whether one or both will be included in the charter.

@anssiko
Member

anssiko commented Apr 27, 2023

Support for transformers was discussed on today's call:
https://www.w3.org/2023/04/27-webmachinelearning-minutes.html#t11

The WG felt positive about the prospects of supporting transformers in WebNN and, in accordance with the contribution guidelines, decided to start by exploring applicable use cases in this issue first, then move on to investigating sample models, cross-framework support, and cross-platform implementability.

@anssiko
Member

anssiko commented May 9, 2023

@fdwr has been working on a Chromium WebNN prototype fdwr/chromium-src-webnn-dml#1 to inform what additional operators are needed in WebNN to support a well-known generative AI model, Stable Diffusion. I expect this prototyping effort to help inform this discussion on use cases. I believe this prototype is WIP, so @fdwr feel free to drop a comment here when appropriate to share your findings.

@dani-lbnl do you have specific predictive or generative AI models in mind that are in use in your area of research? We could look into them more closely in a similar fashion.

@anssiko
Member

anssiko commented May 19, 2023

We've discussed this topic on a few of our bi-weekly calls, and the investigation paths proposed so far include Stable Diffusion (informed by @fdwr's Chromium experiment), SegmentAnything (thanks @huningxin!), and Transformers.js/Hugging Face's transformers. Want to propose something else? Drop a comment here.

We should use this issue to discuss the most promising use cases enabled by transformers that are a good fit to run in the browser, then decompose them to see what new ops would be needed in WebNN. Based on this, we can update the use cases in the spec as appropriate.

I'm proposing we try to identify a few key use cases first to keep the spec and implementation close to each other.

I'll try to keep this topic on our bi-weekly agenda so folks can also bring their input on the call.

@anssiko anssiko changed the title Mention transformer in use cases Support for transformers May 23, 2023
@anssiko
Member

anssiko commented May 23, 2023

@xenova may be able to provide insights from Transformers.js :-)

@xenova

xenova commented May 24, 2023

@xenova may be able to provide insights from Transformers.js :-)

Thanks for the ping! I'll list a few things that I have learnt/experienced while developing Transformers.js.


I'm proposing we try to identify a few key use cases first to keep the spec and implementation close to each other.

Current abilities

Transformers.js currently supports 17 different tasks in different modalities, including:

  • 📝 Natural Language Processing: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
  • 🖼️ Computer Vision: image classification, object detection, and segmentation.
  • 🗣️ Audio: automatic speech recognition.
  • 🐙 Multimodal: zero-shot image classification.

For the full list, see https://huggingface.co/docs/transformers.js/index#tasks

We're mainly focused on adding tasks which have text-based inputs at the moment, primarily due to processing limitations. Some of the other modalities work quite well (e.g., Whisper for speech-to-text; demo), while others (e.g., image segmentation) take much longer and are not suitable for CPU-based inference. Once WebGPU support is added (see here for progress), we'll continue adding the more "demanding" tasks, like text-to-image (e.g., stable diffusion).

Limitations

First, a discussion on the limits. The largest models I have tested and gotten working reliably are between 800M and 1B parameters. The main contributing factors are:

  1. Memory requirements for exporting with ONNX/optimum. I've run into OOM issues when exporting larger models (see here for the open bug report). For example, attempting to export llama-7B currently has some issues (like this)
  2. 2GB protobuf limit for ONNX (see here for more info), and we don't yet support the "external data" format.
  3. 4GB WASM limit (32-bit address space) means there is a hard limit for running models (for loading and for inference). There are some workarounds (especially when using WebGPU; see web-llm for examples).
  4. Willingness of users to download large models. Although models can be cached (i.e., download once, use forever), the lack of a unified caching system means there's no sharing of weights across sites. This leads to duplicate weights being stored in the same browser. Also, no one wants a website to just download model weights in the background while they simply browse. Websites should clearly indicate that they will download a model (including its size), and should only start after the user provides consent.

Focusing on NLP tasks, we've been able to get relatively large models running in the browser with onnxruntime-web (using their WASM backend). This includes:

However, in practice (especially given current tech limitations), such large general-purpose models are better run outside the browser. Once WebGPU becomes more standardized and approaches native performance (currently impacted by redundant bounds checking), this might change. But as of right now, I think it's best to focus on specific use cases.

Practical use cases

I've mentioned to some people that the original reason I developed Transformers.js was the need to run an ML-powered Chrome extension to block spam YouTube comments. I tested BERT, DistilBERT, and T5, and they all worked pretty well! There are some problems with multithreading in Chrome extensions at the moment (like this), but once that is fixed, I think you'll see many more ML-powered Chrome extensions which run locally in browsers.

Anyway, here are some actually useful ideas:

  1. Text classification - for use cases similar to spam detection.
  2. Embeddings. We support sentence-transformers for computing embeddings client-side. These vectors (typically 512 or 768 dimensions) can then be sent to a backend vector database for quick lookups (meaning the server doesn't have to perform the embedding calculations). With a ton of hype currently around vector databases and semantic search (especially across modalities; e.g., text <-> image), this may prove fruitful in the future (a minimal sketch follows this list).
  3. Automatic Speech Recognition: As shown above, getting Whisper running in the browser (with pretty decent speeds, even on CPU) is a neat example of running models that do not only use textual inputs. Given the plethora of Web APIs that browsers provide to web developers in a sandboxed environment, I anticipate greater adoption of multimodal models in the future.
  4. Text-to-text/text-generation. This may include summarization, translation, or code completion. These are language models finetuned for a very specific use case (as opposed to general LLMs).
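
For context, here is a minimal sketch of the client-side embeddings flow from item 2, written against the Transformers.js pipeline API as documented at the time; the model name, the pooling options, and the vector-database endpoint are illustrative assumptions, not part of any spec:

// Sketch only: compute a sentence embedding in the browser with Transformers.js,
// then hand the vector to some backend vector database for semantic search.
import { pipeline } from '@xenova/transformers';

const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'); // illustrative model
const result = await extractor('How do I run transformers in the browser?', {
  pooling: 'mean',
  normalize: true,
});
const embedding = Array.from(result.data); // typed array -> plain array for JSON transport

// Hypothetical backend endpoint; not part of Transformers.js.
await fetch('/api/vector-search', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ embedding }),
});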

I look forward to seeing the progression of the WebNN standard, and I hope that one day we can add it as a backend of Transformers.js!

@anssiko
Member

anssiko commented Jun 9, 2023

This topic was discussed on our 8 June 2023 call where @xenova gave a well-received presentation on Transformers.js (thanks!).

The WG had an active discussion around the presented Transformers.js-enabled use cases. Transformers.js demonstrates that a number of transformer-centric generative models for various real-world tasks are now feasible in the browser:

  • text-to-text (translation, summarization)
  • constrained text-generation (code completion)
  • speech-to-text (automatic speech recognition)
  • image-to-text (captioning)
  • embeddings (semantic search, clustering, data analysis)
  • and more

These tasks will now inform the WG's WebNN v2 feature priorities, similarly to how the majority of the WG's existing WebNN v1 use cases were informed by predictive ("old school AI") models when we initiated this effort. Notably, many of the existing WebNN v1 use cases, such as the NLP-related ones, are also now improved with transformers.

This issue remains open for further feedback, comments, and contributions from other projects in this space. I expect the ongoing Stable Diffusion Chromium experiment to soon provide additional insights into the feasibility of the text-to-image use case in the browser context.

Thank you for your continued contributions everyone! 🚀

@anssiko
Member

anssiko commented Jun 29, 2023

We continued transformer-centric discussion on our 29 June 2023 call where @fdwr gave another well-received and informative presentation on Transformer models via WebNN in ORT & Chromium (thanks again!). We agreed to use this and the earlier Transformers.js presentation as input to inform our v2 op effort. We welcome further contributions from anyone interested in this space.

We discussed our intent to start with a tracking issue for v2 ops (we can reuse this issue or spin up a new one) and have detailed op-specific discussion in dedicated issues. We will use our contributing guidelines as the guide, but at a high level we want to provide a list of proposed new ops and data types to support transformer-based generative AI use cases for key models. This allows us to seek broader review outside this group for the proposed expanded op set.

@fdwr
Collaborator

fdwr commented Aug 11, 2023

so fdwr feel free to drop a comment here when appropriate to share your findings.

Based on this presentation and the prototype IDL, the following operators are needed:

Elementwise comparison

equal

Compares two inputs of the same element data type and returns an 8-bit tensor with 0=false or 1=true. It follows standard IEEE rules for NaNs. Denormal/subnormal comparison behavior is unspecified, dependent on the device CPU/GPU/NPU.

partial interface MLGraphBuilder {
  MLOperand equal(MLOperand a, MLOperand b);
}

Pseudocode:

output[i] = (a[i] == b[i]) for each element.

Element data types:

  • a: all supported data types
  • b: all supported data types (same as a)
  • output: uint8

greater

partial interface MLGraphBuilder {
  MLOperand greater(MLOperand a, MLOperand b);
}

Pseudocode:

output[i] = a[i] > b[i] for each element.

Element data types:

  • a: all supported data types
  • b: all supported data types (same as a)
  • output: uint8

lesser

partial interface MLGraphBuilder {
  MLOperand lesser(MLOperand a, MLOperand b);
}

Pseudocode:

output[i] = a[i] < b[i] for each element.

Element data types:

  • a: all supported data types
  • b: all supported data types (same as a)
  • output: uint8

Alternate names?

  • more and less? ("a is less than b" and "a is more than b" make a natural-sounding pair, whereas a mixed pairing like "a is less than b" and "a is greater than b" sounds mismatched)

Elementwise logical functions/selection

logicalNot

Inverts every element of an 8-bit tensor, not to be confused with a bitwiseNot which inverts each bit. So input of 0 yields 1, and 1-255 yields 0.

Pseudocode:

output[i] = !input[i] for each element.

Element data types:

  • input : uint8
  • output: uint8

(future: logicalAnd, logicalOr, logicalXor, bitwiseAnd, bitwiseOr, bitwiseXor...)

elementwiseIf / ternary select

A per-element immediate if (or ternary operator) that selects from one tensor or the other depending on the condition. The input to this is often the output from an earlier comparison operator.

partial interface MLGraphBuilder {
  MLOperand elementwiseIf(MLOperand condition, MLOperand trueValues, MLOperand falseValues);
}

Pseudocode:

output[i] = iif(condition[i], trueValues[i], falseValues[i]) for each element.

Decomposition: add(mul(trueValues, cast(condition, typeof(trueValues))), mul(falseValues, cast(logicalNot(condition), typeof(falseValues))))

Element data types:

  • condition: uint8 (0=false, !0=true)
  • trueValues: all supported data types
  • falseValues: all supported data types (same as trueValues)
  • output: all supported data types (same as trueValues)

Notes

Input tensors are broadcasted to the final output shape. So given condition.shape=[10,1], trueValues.shape=[10,1], and falseValues.shape=[1,10], then output.shape=[10,10]. References: NumPy, ONNX, TF, PyTorch

Alternate names: stablehlo.select, tosa.select, HLSL select, torch.where (a rather unintuitive name), ONNX.Where.
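
For illustration, a sketch of the decomposition above written against a builder-like object; the Op type and the add/mul/cast/logicalNot methods are stand-ins for the proposed builder methods, not spec text:

// Sketch only: elementwiseIf decomposed into existing/proposed builder ops.
type Op = unknown;
interface BuilderLike {
  add(a: Op, b: Op): Op;
  mul(a: Op, b: Op): Op;
  cast(x: Op, dataType: string): Op;
  logicalNot(x: Op): Op;
}

function elementwiseIf(b: BuilderLike, condition: Op, trueValues: Op, falseValues: Op,
                       dataType = "float32"): Op {
  // condition is 0/1 (uint8); cast it to the value type so it can act as a multiplicative mask.
  const keepTrue  = b.mul(trueValues,  b.cast(condition, dataType));
  const keepFalse = b.mul(falseValues, b.cast(b.logicalNot(condition), dataType));
  return b.add(keepTrue, keepFalse);
}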

greaterOrEqual/lesserOrEqual

For set completeness, these two are worth considering too (not in the models, but missing them would be an awkward gap). Note greaterOrEqual(a, b) can usually be implemented in terms of logicalNot(lesser(a, b)) (and similarly lesserOrEqual as a decomposition of logicalNot(greater(a, b))), but they are not equivalent when NaNs are involved: saying x >= NaN (false) is different from !(x < NaN) (true).
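
A quick JavaScript-level illustration of that NaN asymmetry (plain IEEE comparison semantics, not WebNN API calls):

const x = 1.0;
console.log(x >= NaN);    // false: any ordered comparison involving NaN is false
console.log(!(x < NaN));  // true: negating the comparison flips the NaN case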

More elementwise unary operations

identity

Returns the input as-is. Although it's a nop copy, having this completes the set (every ML framework has one), provides a direct mapping for frameworks, and is a useful placeholder in more complex graphs that you can insert without the caller needing to stitch up topology (e.g. swapping out an activation function with a nop, or working around issues with split where multiple inputs have the same name). We've already encountered cases where having this would have been useful when mapping from ONNX Runtime to WebNN too.

partial interface MLGraphBuilder {
  MLOperand identity(MLOperand input);
}

Pseudocode:

output[i] = input[i] for each element.

Element data types:

  • input: all supported data types
  • output: all supported data types

Alternate names?

  • copy?

sqrt

Elementwise square root. See #438. It is equivalent to pow(input, 0.5), but since sqrt is such a common case, for which many backends have optimized versions, having a dedicated operator is worthwhile. It also avoids callers needing to allocate tiny temporary scalar tensors for the 0.5.

partial interface MLGraphBuilder {
  MLOperand sqrt(MLOperand input);
}

Pseudocode:

output[i] = sqrt(input[i]) for each element.

Element data types:

  • input: float*
  • output: same as input

erf

The Gauss error function occurs frequently in probability and statistics and, relevantly here, in the GELU activation used by transformer models.

partial interface MLGraphBuilder {
  MLOperand erf(MLOperand input);
}

Pseudocode:

Polynomial expansion approximation (the classic Abramowitz & Stegun 7.1.26 form):

var a1 =  0.254829592
var a2 = -0.284496736
var a3 =  1.421413741
var a4 = -1.453152027
var a5 =  1.061405429
var p  =  0.3275911
var x  =  abs(input)
var t  =  1.0 / (1.0 + p * x)
var y  =  1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(-x * x)
output[i] = y * sign(input)

Element data types:

  • input: float*
  • output: same as input
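
A minimal scalar sketch of the polynomial approximation above (a real backend would evaluate this per element and may use a different approximation):

// Sketch only: scalar erf via the Abramowitz & Stegun 7.1.26 polynomial shown above.
function erf(x: number): number {
  const a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741;
  const a4 = -1.453152027, a5 = 1.061405429, p = 0.3275911;
  const sign = x < 0 ? -1 : 1;
  const ax = Math.abs(x);
  const t = 1 / (1 + p * ax);
  const poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
  return sign * (1 - poly * Math.exp(-ax * ax));
}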

reciprocal

This inverse is often used in conjunction with multiplication because it's faster than division. GPUs typically implement a dedicated "rcp" instruction, which is actually how division is often achieved (rather than a dedicated division instruction). So supporting this operator directly allows more efficient mapping to hardware, and it avoids the extra scalar tensor allocation and broadcasting.

partial interface MLGraphBuilder {
  MLOperand reciprocal(MLOperand input);
}

Pseudocode:

output[i] = 1 / input[i]

Element data types:

  • input: float*
  • output: same as input

Reshaping operations

Reshaping operations do not modify the values, just reinterpret the elements with a new shape. The following are a class of operators that should either all be added to the op set, or all be resolved into an explicit shape by the caller and implemented via reshape. The current situation is a bit goofy: we have squeeze in the spec but not its matching counterpart unsqueeze (issue #296). Originally I leaned toward adding all 4 of these (and I implemented them in my fork), but now I'd rather the caller resolve them: distinct operators offer no hardware advantage, the caller-side shape resolution is trivial, the other operators are internally resolved into a reshape anyway, and it reduces test complexity for WebNN implementations.

squeeze

I recommend deleting this, having callers implement this themselves instead via reshape by resolving the shape with a few lines of caller code. That reduces higher level policy (such as the empty axis case which varies per framework) and WebNN testing complexity.

Removes dimensions of size 1 from the previous shape and reshapes. This is present in WebNN as of 2023-08-10 but lacks its counterpart unsqueeze.

partial interface MLGraphBuilder {
  MLOperand squeeze(MLOperand input, optional MLSqueezeOptions options = {});
}

Pseudocode:

newShape = input.shape
if axes.empty then axes = [0,1,2,...]
foreach axis in reverse(sort(axes))
    if newShape[axis] == 1
        newShape.delete(axis)
    endif
endforeach
return reshape(input, newShape)

Element data types:

  • input: all supported data types
  • output: same as input

unsqueeze

I recommend callers implement this themselves instead via reshape by resolving the shape with a few lines of caller code, which reduces higher level policy and WebNN testing complexity.

Reinserts 1's into the new shape and reshapes.

partial interface MLGraphBuilder {
  MLOperand unsqueeze(MLOperand input, MLSqueezeOptions options);
}

Pseudocode:

newShape = input.shape
foreach axis in sort(axes)
    newShape.insert(axis, 1)
endforeach
return reshape(input, newShape)

Element data types:

  • input: all supported data types
  • output: same as input
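
As a caller-side sketch of the recommendation above, squeeze/unsqueeze reduce to computing a new shape and calling the existing reshape operator; this assumes the caller tracks operand shapes itself, and the helper names are illustrative, not spec:

// Sketch only: resolve squeeze/unsqueeze into explicit shapes for reshape().
function squeezeShape(shape: number[], axes?: number[]): number[] {
  const drop = axes && axes.length ? axes : shape.map((_, i) => i); // empty axes => consider all dims
  return shape.filter((dim, i) => !(drop.includes(i) && dim === 1));
}

function unsqueezeShape(shape: number[], axes: number[]): number[] {
  const out = [...shape];
  for (const axis of [...axes].sort((a, b) => a - b)) out.splice(axis, 0, 1);
  return out;
}

// e.g. builder.reshape(input, squeezeShape([1, 3, 1, 5], [0, 2]))  // -> [3, 5]
//      builder.reshape(input, unsqueezeShape([3, 5], [0, 2]))      // -> [1, 3, 1, 5]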

flattenTo2d

I recommend callers implement this themselves instead via reshape by resolving the shape with a few lines of caller code, which reduces higher level policy and WebNN testing complexity.

Reshapes an ND shape to 2D, using the axis as a split point to flatten each subrange (e.g. given input.shape=[2,3,4,5] and axis=2, then output.shape=[6,20]; given input.shape=[2,3,4,5] and axis=0, then output.shape=[1,120]).

partial interface MLGraphBuilder {
  MLOperand flattenTo2d(MLOperand input, unsigned long axis);
}

Pseudocode:

inputRank = input.shape.size()
elementsBefore = reduceProd(input.shape[0..axis], 1)
elementsAfter  = reduceProd(input.shape[axis..inputRank], 1)
newShape = [elementsBefore, elementsAfter]
return reshape(input, newShape)

Element data types:

  • input: all supported data types
  • output: same as input
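
Likewise, a caller can resolve flattenTo2d into a reshape by computing the 2D shape up front (illustrative helper, not spec):

// Sketch only: compute the 2D shape for flattenTo2d, then call reshape() with it.
function flattenTo2dShape(shape: number[], axis: number): [number, number] {
  const before = shape.slice(0, axis).reduce((a, b) => a * b, 1);
  const after = shape.slice(axis).reduce((a, b) => a * b, 1);
  return [before, after];
}

// flattenTo2dShape([2, 3, 4, 5], 2) -> [6, 20]
// flattenTo2dShape([2, 3, 4, 5], 0) -> [1, 120]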

Data rearrangement operations

These project/rearrange data, extending the existing ones of the class (concat/slice/split/transpose/pad/...).

expand

Expands (or broadcasts) a smaller input to a larger size, where any input dimension of size 1 is broadcast to the corresponding output dimension. Consistent with broadcasting rules, the input dimensions must be 1 or must match the corresponding output dimensions.

partial interface MLGraphBuilder {
  MLOperand expand(MLOperand input, sequence<unsigned long> newShape);
}

Pseudocode:

foreach outputCoordinate in output
    // getInputCoordinate is a function that maps any output coordinate back to the corresponding
    // input coordinate, by setting any input coordinate to 0 if the input dimension has a size of 1,
    // meaning it was broadcasted.
    inputCoordinate = getInputCoordinate(outputCoordinate, input.shape, output.shape)
    output[outputCoordinate] = input[inputCoordinate]
endforeach

Element data types:

  • input: all supported data types
  • output: same as input
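
For concreteness, a sketch of the getInputCoordinate mapping used in the pseudocode, assuming input and output have the same rank (a lower-rank input can first be reshaped with leading 1's):

// Sketch only: broadcast dimensions of size 1 always map back to input coordinate 0.
function getInputCoordinate(outputCoordinate: number[], inputShape: number[]): number[] {
  return outputCoordinate.map((c, i) => (inputShape[i] === 1 ? 0 : c));
}

// e.g. getInputCoordinate([7, 4], [10, 1]) -> [7, 0]  (second dimension was broadcast)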

gather

Gathers various parts of the input using the indices.

dictionary MLGatherOptions {
  // TODO: Should axis just be a function parameter?
  // Sometimes it is (concat), but other times it is not (e.g. MLSplitOptions).
  // So what is the WebNN rule here?
  unsigned long axis = 0;
};

partial interface MLGraphBuilder {
  MLOperand gather(MLOperand input, MLOperand indices, optional MLGatherOptions options = {});
}

Pseudocode:

Complicated to fit entirely here when considering multiple dimensions, but for the simpler 1D case:

output[i] = input[indices[i]]

References:

Element data types:

  • input: all supported data types
  • indices: int64
  • output: same as input

(per Jiewei's comment below) Implementations must not read out of bounds, either clamping the read to the end of the tensor (e.g. value = input[min(elementOffset, lastValidOffset)]) or returning a placeholder value such as 0 for out-of-bounds indices. When processing on devices like NPUs/GPUs, the data has already been submitted, and they do not "throw exceptions" deep inside during processing.
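
A 1D sketch of the clamping behavior described above (one of the allowed out-of-bounds strategies; the other is returning a fixed value such as 0):

// Sketch only: clamped 1D gather; each index is clamped into [0, input.length - 1].
function gather1d(input: number[], indices: number[]): number[] {
  return indices.map(i => input[Math.min(Math.max(i, 0), input.length - 1)]);
}

// gather1d([10, 20, 30], [0, 2, 99]) -> [10, 30, 30]  (99 is clamped to the last element)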

Normalization operations

When you look holistically at the zoo of normalization operations (layerNormalization, instanceNormalization, batchNormalization, grouped channel normalization...) you realize a few things:

  • They all involve mean and variance, with the difference being whether those are optionally passed in as precomputed tensors or computed on demand from the input.
  • They are kinda confusingly named ("batch" normalization actually normalizes the spatial dimensions too, not just across the batches).
  • Frameworks differ in their treatment of the parameters even for similarly named operators (some take a generic list of axes, some a single axis, and others interpret the axis as the upper bound for a range of dimensions), meaning WebNN should target the more generic form.
  • They are functionally all the same thing, except for which axes they normalize. Their interchangeability is further demonstrated by the realization that any one of them can implement another after a reshape/transpose.

This leads me to propose that the existing variants batchNormalization and instanceNormalization be deleted, and that layerNormalization (needed for SAM decode and SD Unet post-Olive) not be introduced. Instead we just have MVN with an axes parameter, and each of these is implemented in terms of MVN. This also simplifies other aspects: different frameworks/libraries have differing rules about which shapes are compatible and how to align them, but being explicit at the lower WebNN level eliminates higher-level policies around remapping 1D tensors and even eliminates the need for MLInputOperandLayout in instanceNormalization, a concept that need not apply at this low level once you realize there are no special semantics to dimensions here - they're all just numbered axes.

If backends have optimized implementations for specific axes, mapping the axes field to them is trivial. Otherwise, there remains a decomposition.


meanVarianceNormalization

dictionary MLMeanVarianceNormalizationOptions {
  MLOperand mean;
  MLOperand variance;
  MLOperand scale;
  MLOperand bias;
  float epsilon = 1e-5;
  sequence<[EnforceRange] unsigned long> axes;
};

partial interface MLGraphBuilder {
  // implements instanceNormalization/layerNormalization/batchNormalization/grouped channel normalization
  MLOperand meanVarianceNormalization(MLOperand input, optional MLMeanVarianceNormalizationOptions options = {});
}

Pseudocode:

func meanVarianceNormalization(input, mean, variance, scale, bias, reducedAxes)
    if no mean
        mean = reduceMean(input, reducedAxes, keepDimensions = true)
    endif
    inputRecentered = input - mean
    if no variance
        variance = reduceMean(inputRecentered * inputRecentered, reducedAxes, keepDimensions = true)
    endif
    return inputRecentered / sqrt(variance + epsilon) * scale + bias
endfunc

MVN is not limited to 4D (it can take any dimension count, just like reduce), but for example's sake, given a 4D tensor, the following mappings apply:

// "layer" normalization:
// input.shape = [N,C,H,W]
// scale.shape = input.shape[firstAxis..input.rank]
//          so = [N,C,H,W] if firstAxis = 0
//             =   [C,H,W] if firstAxis = 1
//             =     [H,W] if firstAxis = 2
//             =       [W] if firstAxis = 3
// bias.shape  = input.shape[firstAxis..input.rank]
func onnxLayerNormalization(input, scale, bias, firstAxis = 1)
    // Reduce along dimensions from the first axis onward.
    // So for the default axis = 1, that means channel, width, and height (independently per batch).
    // For 4D and axis = 1, axes = [1,2,3].
    // For 5D and axis = 1, axes = [1,2,3,4].
    // For 3D and axis = 2, axes = [2].
    reducedAxes = getSequence(firstAxis, input.rank)

    // Reshape scale and bias to same rank as input, directly broadcastable without realignment.
    // e.g. Given input shape [2,3,4,5] and axis = 2 yields [1,1,4,5].
    shapeCompatibleToInput = keepShapeGivenAxes(input.shape, reducedAxes)
    scale = reshape(scale, shapeCompatibleToInput)
    bias = reshape(bias, shapeCompatibleToInput)

    return meanVarianceNormalization(input, none, none, scale, bias, reducedAxes)
endfunc
// "instance" normalization:
// input.shape = [N,C,H,W]
// scale.shape = [C]
// bias.shape  = [C]
func instanceNormalization(input, scale, bias)
    // Reduce along width and height (independently per batch and channel).
    // For 4D, axes = [2,3]
    // For 5D, axes = [2,3,4]
    reducedAxes = getSequence(2, input.rank)

    // Reshape scale and bias compatibly to input rather than 1D tensors of size C.
    // e.g. [3] yields [1, 3, 1, 1]
    shapeCompatibleToInput = keepShapeGivenAxes(input.shape, [1])
    scale = reshape(scale, shapeCompatibleToInput)
    bias = reshape(bias, shapeCompatibleToInput)

    return meanVarianceNormalization(input, none, none, scale, bias, reducedAxes)
endfunc
// "batch" normalization:
// input.shape = [N,C,H,W]
// scale.shape = input.shape[featureAxis..featureAxis+1]
//          so = [N] if feature axis = 0
//             = [C] if feature axis = 1
//             = [H] if feature axis = 2
//             = [W] if feature axis = 3
// bias.shape  = input.shape[featureAxis..featureAxis+1]
func batchNormalization(input, mean, variance, scale, bias, featureAxis = 1)
    // Given feature axis = 1, reduce along batch, width, and height (independently per channel).
    // For 4D and axis = 1, axes = [0,2,3].
    // For 5D and axis = 2, axes = [0,1,3,4].
    reducedAxes = getComplementSequence([featureAxis], input.rank)

    // Reshape scale, bias, and others compatibly to input rather than 1D tensors of size input.shape[axis].
    // e.g. [3] and axis = 2 yields [1, 1, 3, 1]
    shapeCompatibleToInput = keepShapeGivenAxes(input.shape, [featureAxis])
    scale = reshape(scale, shapeCompatibleToInput)
    bias = reshape(bias, shapeCompatibleToInput)

    return meanVarianceNormalization(input, mean, variance, scale, bias, reducedAxes)
endfunc
// input.shape = [N,C,H,W]
// scale.shape = [channelGroupCount]
// bias.shape  = [channelGroupCount]
func groupedChannelNormalization(input, scale, bias, channelGroupCount)
    // Reduce along subchannel, width, and height (independently per batch and per channel group).
    reducedAxes = [2,3,4] // subchannel, height, width

    // Reshape the input so we can virtually access subchannels within C, reshaping from 4D up to 5D.
    oldShape = input.shape
    newShape = [oldShape.N, channelGroupCount, oldShape.C / channelGroupCount, oldShape.height, oldShape.width]
    reshapedInput = reshape(input, newShape)
  
    // Reshape scale, bias, and others compatibly to the new input rather than 1D tensors of size channelGroupCount.
    // e.g. Input original shape = [2,12,5,6]
    //      Input new shape      = [2,3,4,5,6]
    //      channelGroupCount    = 3
    //      ->
    //      compatible shape     = [1,3,1,1,1]
    shapeCompatibleToInput = keepShapeGivenAxes(newShape, [1])
    scale = reshape(scale, shapeCompatibleToInput)
    bias = reshape(bias, shapeCompatibleToInput)

    return meanVarianceNormalization(reshapedInput, none, none, scale, bias, reducedAxes)
endfunc

Normalization mini-helpers:

// Returns a sequence of contiguous values from inclusive start to exclusive end.
// For start = 2, end = 4, you get [2,3].
// For start = 1, end = 5, you get [1,2,3,4].
func getSequence(inclusiveStart, exclusiveEnd)
    newValues = []
    for i = inclusiveStart up to exclusiveEnd
        newValues.append(i)
    endfor
    return newValues
endfunc

// Returns a complementary sequence to the given values (the opposite of what was passed).
// For rank 4 and values = [1,3], you get [0,2].
// For rank 5 and values = [2], you get [0,1,3,4].
func getComplementSequence(values, exclusiveEnd)
    complementValues = []
    for i = 0 up to exclusiveEnd
        if not i in values
            complementValues.append(i)
        endif
    endfor
    return complementValues
endfunc

// Keeps all the shape values that are in axes, masking any others to 1's.
// This is useful for creating a broadcast compatible shape.
// e.g. shape = [2,3,4,5], axes = [1], yields [1,3,1,1]
// e.g. shape = [3,4,5], axes = [0,2], yields [3,1,5]
func keepShapeGivenAxes(shape, axes)
    newShape = shape
    for i in shape.rank
        if not i in axes
            newShape[i] = 1 // Not in axes, so force to 1. Otherwise preserve.
        endif
    endfor
    return newShape
endfunc

Element data types:

  • input: float*
  • output: same as input

Notes:

  • If the optional mean and variance are precomputed, then they must both be supplied (it's an error to only supply one or the other).
  • All input tensors must be broadcast compatible with each other. There are no higher level policies about reshaping 1D arrays to specific semantic dimensions.
  • Tensors are not limited to 4D, and there is no semantic meaning assigned to channels (4D usage above is just for simpler illustration).

Index seeking operations

argMin/argMax

Returns the index along the axis (not the absolute element index) of the first greatest or smallest element. Normally the scan proceeds forward, increasing along the axis until finding a better match. In the case of ties, some implementations return the first match (TensorFlow), whereas others return the last (PyTorch). The selectLastIndex field reverses the search direction, selecting the last match rather than the first.

dictionary MLArgMinMaxOptions {
  unsigned long axis = 0;
  boolean keepDimensions = false;
  boolean selectLastIndex = false; // selects the *last* index rather than the first find along axis
  //  NAMING: Maybe an enum for scanDirection or tieBreakerFirst/Last would be clearer?
};
partial interface MLGraphBuilder {
  MLOperand argMax(MLOperand input, optional MLArgMinMaxOptions options = {});
}
partial interface MLGraphBuilder {
  MLOperand argMin(MLOperand input, optional MLArgMinMaxOptions options = {});
}

Element data types:

  • input: all supported data types
  • output: int64
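
A 1D sketch of the tie-breaking behavior that selectLastIndex controls (illustrative only; the real operator works along an arbitrary axis and produces an index tensor):

// Sketch only: forward scan keeps the first of tied maxima unless selectLastIndex is set.
function argMax1d(values: number[], selectLastIndex = false): number {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (selectLastIndex ? values[i] >= values[best] : values[i] > values[best]) best = i;
  }
  return best;
}

// argMax1d([3, 7, 7, 1])        -> 1  (first of the tied maxima)
// argMax1d([3, 7, 7, 1], true)  -> 2  (last of the tied maxima)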

Miscellaneous

cast

Casts each element to the target element data type. All combinations are possible, and implementations should provide as accurate a conversion as possible (as if a static cast were executed directly from the source type to the destination type), but the exact rounding behavior and whether denormals are flushed to zero are not specified, as GPU/NPU hardware differs (there are, however, expected tolerances).

partial interface MLGraphBuilder {
  MLOperand cast(MLOperand input, MLOperandDataType operandDataType);
}

Element data types:

  • input: any
  • output: any

Notes

Since supporting every (source data type) x (target data type) incurs a cross product of implementations, backends are welcome to mitigate the number of combinations in any way such that the same result would be achieved as if there was a direct conversion, including intermediate conversions. e.g. if the source was uint8, it could be upconverted to uint32, and then converted from uint32 to float32.

fillSequence

Fills a sequence into the output tensor, starting from a lower bound and incrementing by the delta. The tensor is filled sequentially as if it were a single long 1D tensor; so a [3,2] tensor with start=4 and delta=2 would have values [[4,6],[8,10],[12,14]]. Notice this operator has no input tensor, only an output, which implies (since a graph containing only this one operator is legal) that graphs can legally have 0 bound inputs.

dictionary MLFillSequenceOptions {
  float start = 0;
  float delta = 1;
};

partial interface MLGraphBuilder {
  MLOperand fillSequence(MLOperandDataType operandDataType, sequence<unsigned long> outputShape, optional MLFillSequenceOptions options = {});
}

Element data types:

  • output: any (as given by operandDataType)
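
For clarity, a sketch of the values fillSequence produces, matching the [3,2] example above (row-major fill; the helper is illustrative):

// Sketch only: the flat sequence of values a fillSequence output would contain.
function fillSequenceValues(shape: number[], start = 0, delta = 1): number[] {
  const count = shape.reduce((a, b) => a * b, 1);
  return Array.from({ length: count }, (_, i) => start + i * delta);
}

// fillSequenceValues([3, 2], 4, 2) -> [4, 6, 8, 10, 12, 14], i.e. [[4, 6], [8, 10], [12, 14]]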

triangularMatrix

Keeps either the upper or the lower triangle of a matrix (relative to a diagonal), filling the other half with zeros.

// If upper, then the upper diagonal is preserved, and the bottom half is filled with zeros.
enum MLTriangularPart { "upper", "lower" };

dictionary MLTriangularMatrixOptions {
  MLTriangularPart triangularPart;
  long diagonalDelta;
  // diagonalDelta is a horizontal shift. So a positive delta means that for an upper triangular matrix,
  // the value mask is shifted rightward.
};

partial interface MLGraphBuilder {
  MLOperand triangularMatrix(MLOperand input, optional MLTriangularMatrixOptions options = {});
}

Pseudocode:

for every coordinate
    if mode enum keep upper half ◥
        preserveInput = ((coordinate.x - coordinate.y) >= shift)
    else mode enum keep lower half ◣
        preserveInput = ((coordinate.x - coordinate.y) < shift)
    endif
    output[coordinate] = preserveInput ? input[coordinate] : 0
endfor
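
Working the pseudocode through a small example (triangularPart = "upper", diagonalDelta = 0), elements below the diagonal are zeroed:

input  = [[1,2,3], [4,5,6], [7,8,9]]
output = [[1,2,3], [0,5,6], [0,0,9]]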

Related:

Element data types:

  • input: float*
  • output: same as input

shape() method

Reports the current shape of an MLOperand from the MLGraphBuilder. See details here for justification: #379.

interface MLOperand {
   sequence<unsigned long> shape(MLOperand input);
}

New data types

The lack of int64/uint64 caused pain since all the indices (gather, argMin, argMax...) in the models were 64-bit, and even if we forced those relevant operators to int32/uint32 by casting at runtime, the cast operator itself had no way to express such a conversion without the enum. So, int64/uint64 type support was added to the fork.

enum MLOperandDataType { // formerly MLOperandType: https://github.com/webmachinelearning/webnn/issues/454
...
  "int64",
  "uint64",
...
};

Backends which do not support int64/uint64 natively can emulate them. It's not expected that all operators would support them, just the basic arithmetic ones (add/sub/mul/div/...) and data manipulation ones (slice/concat/transpose/...). It's also not expected that, just because the indices are 64-bit, individual tensors will be greater than 4GB, as there may be other limitations in the underlying hardware/APIs. As a work-around for missing arithmetic support, backends could read only the lower 32 bits of each element using a doubled-strides trick, but this is not recommended: the ORT DirectML execution provider originally did this (it worked okay for a few years, but became increasingly problematic as models grew over time), so it's wiser to emulate via two reads of the lower and upper 32 bits.

Data type support in libraries

  • Apple MPS (Metal Performance Shaders) - https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphtensor/3564663-datatype
    • float16
    • int16
    • int8
    • normalizedBit
    • signedBit
    • uInt16
    • uInt32
    • uInt8
    • unorm1
    • unorm8
    • alternateEncodingBit
    • bFloat16
    • bool
    • complexBit
    • complexFloat16
    • complexFloat32
    • int32
    • int64
    • uInt64
  • Apple BNNS (Basic Neural Network Subroutines) - bnns_constants.h
    • BNNSDataTypeFloat16 - 16-bit half precision floating point
    • BNNSDataTypeFloat32 - 32-bit single precision floating point
    • BNNSDataTypeBFloat16 - 16-bit brain floating Point
    • BNNSDataTypeIntBit - Common bit to signed integer types, this constant is not a valid type
    • BNNSDataTypeInt8 - 8-bit signed integer
    • BNNSDataTypeInt16 - 16-bit signed integer
    • BNNSDataTypeInt32 - 32-bit signed integer
    • BNNSDataTypeInt64 - 64-bit signed integer
    • BNNSDataTypeUIntBit - Common bit to unsigned integer types, this constant is not a valid type
    • BNNSDataTypeUInt8 - 8-bit unsigned integer
    • BNNSDataTypeUInt16 - 16-bit unsigned integer
    • BNNSDataTypeUInt32 - 32-bit unsigned integer
    • BNNSDataTypeUInt64 - 64-bit unsigned integer
    • BNNSDataTypeIndexed2 - 2-bit unsigned indices into a floating point conversion table (4 values)
    • BNNSDataTypeIndexed4 - 4-bit unsigned indices into a floating point conversion table (16 values)
    • BNNSDataTypeIndexed8 - 8-bit unsigned indices into a floating point conversion table (256 values)
    • BNNSDataTypeBoolean
  • Apple CoreML - https://chromium-review.googlesource.com/c/chromium/src/+/5075312/12/services/webnn/coreml/ModelFormat/FeatureTypes.proto
    • FLOAT32
    • DOUBLE
    • INT32
    • FLOAT16
  • NumPy - https://numpy.org/doc/stable/user/basics.types.html
    • numpy.bool_ - bool, Boolean (True or False) stored as a byte
    • numpy.byte - signed char
    • numpy.ubyte - unsigned char
    • numpy.short - short
    • numpy.ushort - unsigned short
    • numpy.intc - int
    • numpy.uintc - unsigned int
    • numpy.int_ - long
    • numpy.uint - unsigned long
    • numpy.longlong - long long
    • numpy.ulonglong - unsigned long long
    • numpy.half / numpy.float16 - Half precision float: sign bit, 5 bits exponent, 10 bits mantissa
    • numpy.single - float, Platform-defined single precision float: typically sign bit, 8 bits exponent, 23 bits mantissa
    • numpy.double - double, Platform-defined double precision float: typically sign bit, 11 bits exponent, 52 bits mantissa.
    • numpy.longdouble - long double, Platform-defined extended-precision float
    • numpy.csingle - float complex, Complex number, represented by two single-precision floats (real and imaginary components)
    • numpy.cdouble - double complex, Complex number, represented by two double-precision floats (real and imaginary components).
    • numpy.clongdouble - long double complex, Complex number, represented by two extended-precision floats (real and imaginary components).
  • PyTorch - https://pytorch.org/docs/stable/tensor_attributes.html
    • torch.float32 or torch.float - 32-bit floating point
    • torch.float64 or torch.double - 64-bit floating point
    • torch.complex64 or torch.cfloat - 64-bit complex
    • torch.complex128 or torch.cdouble - 128-bit complex
    • torch.float16 or torch.half - 16-bit floating point 1
    • torch.bfloat16 - 16-bit floating point 2
    • torch.uint8 - 8-bit integer (unsigned)
    • torch.int8 - 8-bit integer (signed)
    • torch.int16 or torch.short - 16-bit integer (signed)
    • torch.int32 or torch.int - 32-bit integer (signed)
    • torch.int64 or torch.long - 64-bit integer (signed)
    • torch.bool - Boolean
  • TensorFlow - https://www.tensorflow.org/api_docs/python/tf/dtypes
    • bfloat16 - brain floating point
    • bool - Boolean
    • complex128 - 128-bit complex
    • complex64 - 64-bit complex
    • double - 64-bit (double precision) floating-point
    • float16 - 16-bit (half precision) floating-point
    • float32 - 32-bit (single precision) floating-point
    • float64 - 64-bit (double precision) floating-point
    • half - 16-bit (half precision) floating-point
    • int16 - Signed 16-bit integer
    • int32 - Signed 32-bit integer
    • int64 - Signed 64-bit integer
    • int8 - Signed 8-bit integer
    • qint16 - Signed quantized 16-bit integer
    • qint32 - signed quantized 32-bit integer
    • qint8 - Signed quantized 8-bit integer
    • quint16 - Unsigned quantized 16-bit integer
    • quint8 - Unsigned quantized 8-bit integer
    • resource - Handle to a mutable, dynamically allocated resource
    • string - Variable-length string, represented as byte array
    • uint16 - Unsigned 16-bit (word) integer
    • uint32 - Unsigned 32-bit (dword) integer
    • uint64 - Unsigned 64-bit (qword) integer
    • uint8 - Unsigned 8-bit (byte) integer
    • variant - Data of arbitrary type (known at runtime)
  • StableHLO (high level optimizer) - https://github.com/openxla/stablehlo/blob/main/docs/spec.md#types
    • Boolean
    • ui4
    • ui8
    • ui16
    • ui32
    • ui64
    • i4
    • i8
    • i16
    • i32
    • i64
    • f8E4M3FN
    • f8E5M2
    • f8E4M3FNUZ
    • f8E5M2FNUZ
    • f8E4M3B11FNUZ
    • bf16
    • f16
    • f32
    • f64
  • TOSA - https://www.mlplatform.org/tosa/tosa_spec.html
    • bool_t, - to -, Boolean value. Size implementation defined. The TOSA reference model implements this as int8_t with 0 for false and 1 for true. All non-zero values are accepted on input as true.
    • i4_t, - to -, Signless 4-bit integer type. Will be interpreted as int4_t by all operators
    • int4_t, -7 to +7, Signed 4-bit two’s-complement value. Excludes -8 to maintain a symmetric about zero range for weights.
    • i8_t, - to -, Signless 8-bit integer value. Will be interpreted as int8_t unless otherwise specified by an operator.
    • int8_t, -128 to +127, Signed 8-bit two’s-complement value.
    • uint8_t, 0 to 255, Unsigned 8-bit integer value.
    • i16_t, - to -, Signless 16-bit integer type. Will be interpreted as int16_t unless otherwise specified by an operator.
    • int16_t, -32768 to +32767, Signed 16-bit two’s-complement value.
    • uint16_t, 0 to 65535, Unsigned 16-bit value.
    • i32_t, - to -, Signless 32-bit integer value. Will be interpreted as int32_t by all operators.
    • int32_t, -(1<<31) to (1<<31)-1, Signed 32-bit two’s-complement value.
    • i48_t, - to -, Signless 32-bit integer value. Will be interpreted as int48_t by all operators.
    • int48_t, -(1<<47) to (1<<47)-1, Signed 48-bit two’s-complement value.
    • fp16_t, -infinity to +infinity, 16-bit half-precision floating-point defined by Other publications[1]. Normal values must be supported. Denormal values must either be supported or flushed to zero. Positive and negative infinity must be supported. At least one NaN encoding must be supported. Signed zero must be supported.
    • bf16_t, -infinity to +infinity, 16-bit brain floating-point defined as bits [31:16] of the fp32_t format. Normal values must be supported. Denormal values must either be supported or flushed to zero. Positive and negative infinity must be supported. At least one NaN encoding must be supported. Signed zero must be supported.
    • fp32_t, -infinity to +infinity, 32-bit single-precision floating-point defined by Other publications[1]. Normal values must be supported. Denormal values must either be supported or flushed to zero. Positive and negative infinity must be supported. At least one NaN encoding must be supported. Signed zero must be supported.
    • fp64_t, -infinity to + infinity, 64-bit double-precision floating-point defined by Other publications[1]. Normal values must be supported. Denormal values must either be supported or flushed to zero. Positive and negative infinity must be supported. At least one NaN encoding must be supported. Signed zero must be supported.
  • XNNPack - https://github.com/google/XNNPACK/blob/03d2a24b53f18103a2bb0f62e2c7123c5cea8890/include/xnnpack.h#L211-L232
    • xnn_datatype_fp32 - IEEE754 single-precision floating-point.
    • xnn_datatype_fp16 - IEEE754 half-precision floating-point.
    • xnn_datatype_qint8 - Quantized 8-bit signed integer with shared per-Value quantization parameters.
    • xnn_datatype_quint8 - Quantized 8-bit unsigned integer with shared per-Value quantization parameters.
    • xnn_datatype_qint32 - Quantized 32-bit signed integer with shared per-Value quantization parameters.
    • xnn_datatype_qcint8 - Quantized 8-bit signed integer with shared per-channel quantization parameters.
    • xnn_datatype_qcint32 - Quantized 32-bit signed integer with shared per-channel quantization parameters.
    • xnn_datatype_qcint4 - Quantized 4-bit signed integer with shared per-channel quantization parameters.
    • xnn_datatype_qdint8 - Dynamically quantized 8-bit signed integer with per-batch quantization parameters.
  • ANN - https://developer.android.com/ndk/reference/group/neural-networks#operandcode
    • ANEURALNETWORKS_BOOL - An 8 bit boolean scalar value. Values of this operand type are either true or false. A zero value represents false; any other value represents true.
    • ANEURALNETWORKS_FLOAT16 - An IEEE 754 16 bit floating point scalar value.
    • ANEURALNETWORKS_FLOAT32 - A 32 bit floating point scalar value.
    • ANEURALNETWORKS_INT32 - A signed 32 bit integer scalar value.
    • ANEURALNETWORKS_TENSOR_BOOL8- A tensor of 8 bit boolean values. Values of this operand type are either true or false. A zero value represents false; any other value represents true.
    • ANEURALNETWORKS_TENSOR_FLOAT16 - A tensor of IEEE 754 16 bit floating point values.
    • ANEURALNETWORKS_TENSOR_FLOAT32 - A tensor of 32 bit floating point values.
    • ANEURALNETWORKS_TENSOR_INT32 - A tensor of 32 bit integer values.
    • ANEURALNETWORKS_TENSOR_QUANT16_ASYMM - A tensor of 16 bit unsigned integers that represent real numbers. Attached to this tensor are two numbers that can be used to convert the 16 bit integer to the real value and vice versa. These two numbers are: scale: a 32 bit floating point value greater than zero. zeroPoint: a 32 bit integer, in range [0, 65535]. The formula is: real_value = (integer_value - zeroPoint) * scale.
    • ANEURALNETWORKS_TENSOR_QUANT16_SYMM - A tensor of 16 bit signed integers that represent real numbers. Attached to this tensor is a number representing real value scale that is used to convert the 16 bit number to a real value in the following way: realValue = integerValue * scale. scale is a 32 bit floating point with value greater than zero.
    • ANEURALNETWORKS_TENSOR_QUANT8_ASYMM - A tensor of 8 bit unsigned integers that represent real numbers. Attached to this tensor are two numbers that can be used to convert the 8 bit integer to the real value and vice versa. These two numbers are: scale: a 32 bit floating point value greater than zero. zeroPoint: a 32 bit integer, in range [0, 255]. The formula is: real_value = (integer_value - zeroPoint) * scale.
    • ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED - A tensor of 8 bit signed integers that represent real numbers. Attached to this tensor are two numbers that can be used to convert the 8 bit integer to the real value and vice versa. These two numbers are: scale: a 32 bit floating point value greater than zero. zeroPoint: a 32 bit integer, in range [-128, 127]. The formula is: real_value = (integer_value - zeroPoint) * scale.
    • ANEURALNETWORKS_TENSOR_QUANT8_SYMM - A tensor of 8 bit signed integers that represent real numbers. Attached to this tensor is a number representing real value scale that is used to convert the 8 bit number to a real value in the following way: realValue = integerValue * scale. scale is a 32 bit floating point with value greater than zero.
    • ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL - A tensor of 8 bit signed integers that represent real numbers. This tensor is associated with additional fields that can be used to convert the 8 bit signed integer to the real value and vice versa. These fields are: channelDim: a 32 bit unsigned integer indicating channel dimension. scales: an array of positive 32 bit floating point values. The size of the scales array must be equal to dimensions[channelDim]. ANeuralNetworksModel_setOperandSymmPerChannelQuantParams must be used to set the parameters for an Operand of this type. The channel dimension of this tensor must not be unknown (dimensions[channelDim] != 0). The formula is: realValue[..., C, ...] = integerValue[..., C, ...] * scales[C] where C is an index in the Channel dimension.
    • ANEURALNETWORKS_UINT32 - An unsigned 32 bit integer scalar value.

Shape issues

We had several issues implementing these models due to WebNN casually treating 0D scalars as if they are 1D arrays (see details: #390). It's important that the semantic distinction be preserved between them.

TOSA/StableHLO Mappings

Most of these proposed operators are fairly primitive and have mappings to primitives in TOSA or StableHLO.

WebNN operator              TOSA                                        StableHLO
argMax                      tosa.argmax                                 ---
argMin                      tosa.argmin                                 ---
cast                        tosa.cast                                   stablehlo.convert
elementwiseIf               tosa.select                                 stablehlo.select
equal                       tosa.equal                                  stablehlo.compare EQ
erf                         tosa.erf                                    ---
expand                      tosa.tile(a, tosa.div(b.shape, a.shape))    stablehlo.broadcast_in_dim?
fillSequence                ---                                         stablehlo.iota (lacks start and step)
gather                      tosa.gather                                 stablehlo.gather? (much more complicated)
greater                     tosa.greater                                stablehlo.compare GT
greaterOrEqual              tosa.greater_equal                          stablehlo.compare GE
identity                    tosa.identity                               ---
lesser                      ---                                         stablehlo.compare LT
lesserOrEqual               ---                                         stablehlo.compare LE
logicalNot                  tosa.logical_not                            stablehlo.not with bool
meanVarianceNormalization   ---                                         (nearest is stablehlo.batch_norm_inference)
reciprocal                  tosa.reciprocal                             ---
reshape                     tosa.reshape                                stablehlo.reshape
sqrt                        tosa.reciprocal(tosa.rsqrt) or tosa.pow     stablehlo.sqrt
triangularMatrix            ---                                         ---
  • triangularMatrix has no direct/indirect mapping to either TOSA or StableHLO, and there is no known decomposition from smaller primitives.
  • meanVarianceNormalization has no exact mapping, but there exists an operator of equivalent complexity and similarity in stablehlo.batch_norm_inference.

Questions for readers:

  • Do you see any operators that make sense to add to this list for set completeness (even if not actually used in the models)?
  • Are there any naming incongruities you see intra-spec? As a global web standard, I expect naming to be more holistically and rigorously thought out than in some existing libraries where operators were added more ad hoc over time, but it's also important for framework developers to have a clear mapping from their library to WebNN, so including alternate known names directly in the specification for trivial Ctrl+F searchability is wise.

Chai will create a review for it - we can comment on the details there...

@anssiko anssiko added the opset label Aug 11, 2023
@anssiko
Member

anssiko commented Aug 24, 2023

The WG would also like to add text-to-text models.

@xenova, given your relevant Transformers.js experience, please feel free to propose text-to-text model(s) you think are the most appropriate targets.

(This topic was discussed at WebML WG Teleconference – 24 August 2023.)

@xenova

xenova commented Aug 25, 2023

@xenova, given your relevant Transformers.js experience, please feel free to propose text-to-text model(s) you think are the most appropriate targets.

I'd love to! I'll break them up into text-to-text generation (encoder-decoder) models and text-generation (decoder-only) models:

Text-to-text generation (encoder-decoder)

Of the trending text-to-text models on the Hugging Face Hub, the majority are t5-based or some variation thereof (e.g., flan-t5-base, t5-v1_1, mt5-base). Even newer models like musicgen-small use a t5 text-encoder. Other non-t5 architectures include m2m100 (e.g., m2m100_418M) and bart (e.g., rebel-large).

[Image: most-downloaded text-to-text models in the past 30 days]

see latest list: https://huggingface.co/models?pipeline_tag=text2text-generation&sort=downloads

In fact, I've already added support for each of these architectures (except musicgen) to Transformers.js, for tasks including translation, summarization, and even instruction-finetuned text2text generation.

Text-generation (decoder-only)

On the other hand, taking a look at the trending text-generation models, we unsurprisingly see a ton of llama models pop up: base models (e.g., llama-2-7b) as well as finetuned versions (e.g., Platypus2-70B-instruct) for conversational use-cases. However, in a web context, I haven't seen anything larger than 7 billion parameters run in the browser. To see what's currently "state-of-the-art", check out the Open LLM Leaderboard; you can also sort by model sizes (<1B, ~3B, and ~7B are most relevant to our discussions).

As a side note, for those unaware, other projects (particularly mlc.ai) have demonstrated that it is possible to run 7-billion parameter models in the browser with WebGPU, with 4-bit quantization.

[Image: most-downloaded text-generation models in the past 30 days]

see latest list: https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads

On the smaller (and more reasonable/manageable) side of things, models like gpt2, bloom, and gpt-neo are strong contenders for use in web browsers, and are very useful when finetuned on specific use cases. My favorite use case right now is code completion. In these cases, the best models I've seen are StarCoder, CodeGen, and DeciCoder. I created a little playground/demo to showcase these models being run in-browser with Transformers.js, which you can play around with here (or see the demo video).

@anssiko anssiko pinned this issue Aug 28, 2023
@anssiko
Member

anssiko commented Sep 7, 2023

Per WebML WG Teleconference – 7 September 2023 discussions, the following architectures were proposed as additional targets to provide a baseline for a range of browser-friendly use cases:
• Text-to-text generation (encoder-decoder) -> t5 and m2m100
• Text-generation (decoder-only) -> llama

Thanks @xenova for sharing all the insights and data from Hugging Face Hub that helps the group make informed decisions!

With the model targets in place and agreed, we will follow a similar path as with the first-wave models: do an op breakdown to better understand what is common across these architectures so we can make informed decisions on priorities. We'll conduct this evaluation in public and welcome contributions.

@Honry
Contributor

Honry commented Sep 25, 2023

All these transformer models contain dynamic shapes, and the good news is that ONNX Runtime Web recently enabled a really useful feature, sessionOptions.freeDimensionOverrides, which supports dynamic-shape models by letting the caller fix the free dimensions.
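
For readers unfamiliar with that option, this is roughly how it is used with onnxruntime-web; the dimension names here are illustrative and must match the free dimension names in the actual model:

// Sketch only: pin the free (dynamic) dimensions of an ONNX model at session creation,
// so the graph handed to the WebNN EP has static shapes.
import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['wasm'],
  freeDimensionOverrides: {
    batch_size: 1,        // illustrative free-dimension names
    sequence_length: 128,
  },
});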

Besides, ONNX Runtime provides various graph optimizations to improve performance, such as constant folding, shape inference, node elimination, node fusion, and so on. These are enabled by default when initializing an inference session, i.e. ONNX Runtime Web applies all enabled graph optimizations before performing model inference.

That means the WebNN EP will actually run the optimized graphs rather than the original dynamic-shape models. After graph optimization, I found that a number of nodes are eliminated or fused. The comparison table below gives a complete view of how the ops change during graph optimization across all these transformer models.

[Image: op comparison table across the transformer models, before and after graph optimization]

@Honry
Contributor

Honry commented Oct 7, 2023

Please check the original source data for the above spreadsheet at https://docs.google.com/spreadsheets/d/1ELfHuv2UqP2LoXWLgqsC0L8T_qqfBx48KxzFighl8d8/edit#gid=86906292.

As a next step, I will integrate the data from @fdwr's comment into the table for the TOSA and StableHLO mapping.

@wacky6

wacky6 commented Oct 12, 2023

Regarding gather op:

This looks identical to an index-based accessor and could lead to out-of-bounds reads. I don't think ahead-of-time checks are possible because the indices are dynamic.

I think the spec should explicitly define such behavior.

Backends seem to have different opinions on how this should be handled:

  • Return 0 (TF GPU)
  • Runtime error (TF CPU, torch, ONNX?)

I brought this up based on a recently published WebGPU security technical report: https://chromium.googlesource.com/chromium/src/+/main/docs/security/research/graphics/webgpu_technical_report.md

I'd recommend a read because it provides insight into the kinds of security challenges we could face by exposing access to low-level hardware (e.g. anything that can DMA).

My takeaways:

  • Insecure "third-party" (from a browser vendor's perspective) drivers/OS/firmware should be part of a Web API's threat model, because these parts are complex.
  • Ideally an API should be "safe" to lower layers (e.g. OS and drivers), even if it is used inappropriately.
  • If the above is infeasible (i.e. dynamic indices have to be supported), runtime checks should be applied before passing commands to lower layers (WebGPU does this in its WGSL shader compiler). This incurs overhead; a small sketch of the idea follows this list.
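
To make the runtime-check option concrete, here is a minimal sketch (mine, not from the spec) of clamping dynamic gather indices into range. In a real implementation this would have to happen where the index tensor is materialized inside the runtime, not in JavaScript, but the semantics are the same; clamping is just one possible behavior alongside returning 0 or raising a runtime error.

```js
// A sketch of the "clamp" mitigation: force every dynamic index into
// [0, axisSize - 1] so it can never produce an out-of-bounds read.
// This trades a small per-element cost for memory safety.
function clampGatherIndices(indices, axisSize) {
  const clamped = new Int32Array(indices.length);
  for (let i = 0; i < indices.length; i++) {
    clamped[i] = Math.min(Math.max(indices[i], 0), axisSize - 1);
  }
  return clamped;
}
```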

@sushraja-msft
Copy link

sushraja-msft commented Oct 31, 2023

Just checking, regarding Sqrt - the supported data types section reads

Element data types:
input: all supported data types
output: all supported data types

Is that intentional, or shouldn't this be just floats?

@fdwr
Copy link
Collaborator

fdwr commented Nov 1, 2023

regarding Sqrt ... Is that intentional, or shouldn't this be just floats?

@sushraja-msft: Thanks - fixed typo.

aarongable pushed a commit to chromium/chromium that referenced this issue Nov 3, 2023
reciprocal

This CL implements directml nodes for unary operators logical not,
identity, erf, and reciprocal. Spec for these operators is available here
webmachinelearning/webnn#375 (comment)

Unit test is added for the operators, a follow up change will test corner
cases such as data type support, sqrt(-1) etc. that requires additional
test helper methods.

Bug: 1273291
Change-Id: I3ffbacdff7b7e0a0604c53c1869519bc3b761026
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4990981
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Sushanth Rajasankar <Sushraja@microsoft.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Alex Gough <ajgo@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1219715}
@anssiko
Copy link
Member

anssiko commented Nov 9, 2023

This issue has been and will be discussed actively on our calls to keep us focused on this priority task. Please continue to use this issue for feedback, suggestions, questions.

The WG currently believes the following 7 models across 4 broad modalities represent good targets for this push:

  • Text-to-image: stable-diffusion-v1-5
  • Image segmentation: segment-anything
  • Speech-to-text: whisper-tiny.en
  • Text generation: t5-small, m2m100_418M, gpt2, llama-2-7b

The WG also believes these models are implementable on major platforms and address diverse browser-friendly use cases with user demand. The WG participants continue implementation efforts to help inform this work. We may adjust this selection of models if new information emerges from these experiments.

Please see the op breakdown table for a detailed mapping of these models to proposed new ops. Included is also ONNX, TOSA, StableHLO op set mapping for comparison. Many thanks to @Honry for maintaining this table.

@tqchen
Copy link

tqchen commented Nov 9, 2023

Great to see this thread. As part of the WebLLM project https://github.com/mlc-ai/web-llm we are also building related compilation flows for WebGPU, with the ability to run llama up to 70b (with the latest M2 Max): https://webllm.mlc.ai/

There are great synergies with WebNN-related projects, possibly enabling future hybrid execution of models (e.g. WebGPU for customized ops and the rest through WebNN).

aarongable pushed a commit to chromium/chromium that referenced this issue Nov 9, 2023
kSqrt

This change implements blink side changes to enable building ml graphs
with Erf, Identity, LogicalNot, Reciprocal operators.

The spec for these operators is available here
webmachinelearning/webnn#375 (comment)

Bug: 1273291
Change-Id: Idb6d6d82428f4773c782850908cf42ae8502943a
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5015083
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Commit-Queue: Sushanth Rajasankar <Sushraja@microsoft.com>
Cr-Commit-Position: refs/heads/main@{#1222600}
@anssiko
Copy link
Member

anssiko commented Nov 10, 2023

@tqchen thanks for your feedback! This WG has discussed WebLLM and is hugely inspired by the project. We'd be happy to have a high-bandwidth discussion with you to hear your learnings and suggestions around hybrid execution on one of our future bi-weekly calls when it fits your schedule. We meet Thu 7 am Pacific.

The feedback from Transformers.js, ONNX Runtime Web, TF.js and other frameworks has informed WebNN API development and direction. We'd like to get your first-hand insights considered too, including hearing about any browser-specific workarounds and optimizations you've had to make to get large language models running in today's browser builds. We can help get any such issues looked at with the help of the browser engineers who participate in this WG.

Edit: Spun off into its own issue: #480

aarongable pushed a commit to chromium/chromium that referenced this issue Nov 14, 2023
This operator expands input shapes to new shapes, and the input shapes
must be broadcastable according to numpy-broadcasting-rule [1].

The spec for the operator is available here [2].

[1]
https://www.w3.org/TR/webnn/#biblio-numpy-broadcasting-rule
[2]
webmachinelearning/webnn#375 (comment)

Bug: 1273291
Change-Id: Ic8b16c293e175bc0865883fb2b4cc93473ddf039
Cq-Include-Trybots: luci.chromium.try:win11-blink-rel
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5022177
Commit-Queue: Junwei Fu <junwei.fu@intel.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1224053}
@philloooo
Copy link
Contributor

philloooo commented Nov 16, 2023

hi! I did an experiment to try to convert from the original pytorch whisper model to webnn by hand, here are some gaps I find needed for the whisper model, aside from the new ops you are proposing:

  • WebNN only supports 2d convolution. To achieve conv1d with conv2d we need to add a dimension to the input and reshape it back later. Would it be more efficient to support conv1d or a generic conv?
  • The GELU activation function is not supported. I see the Transformer Models Analysis proposes adding erf so we can compose GELU; should we support GELU as a high-level op too? E.g. Apple Core ML and ONNX both support GELU. I know there is still an ongoing conversation about how many high-level ops we should support, so I am curious to learn what the decision process is.
  • linear takes scalar values for alpha and beta and computes alpha * x + beta; torch.linear takes weight and bias tensors, which makes a lot more sense. What do you think?
  • softmax only takes 2d input; should it be generalized to take a dimension argument? E.g. the softmax used in whisper.

@anssiko @wchao1115 @fdwr I am still new to this space, let me know if I misunderstood anything :)

@anssiko
Copy link
Member

anssiko commented Nov 17, 2023

@philloooo thanks for this experiment, @xenova may have insights on this given his recent work on whisper-tiny and also distil-whisper.

aarongable pushed a commit to chromium/chromium that referenced this issue Nov 22, 2023
This CL implements directml nodes for unary cast operator. The spec
for this operator can be found here
webmachinelearning/webnn#375 (comment)

Unit test is added for all data types the operator supports.

Bug: 1273291
Change-Id: I1a3753a1eb41dddf1fdbcef036124de1151b7edf
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5050358
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Sushanth Rajasankar <Sushraja@microsoft.com>
Reviewed-by: Alex Gough <ajgo@chromium.org>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Cr-Commit-Position: refs/heads/main@{#1228121}
@fdwr
Copy link
Collaborator

fdwr commented Nov 23, 2023

hi! I did an experiment to try to convert from the original pytorch whisper model to webnn by hand, here are some gaps I find needed for the whisper model, aside from the new ops you are proposing:

@philloooo Hi Yajing. Cool, you're trying out WebNN. Thanks for your observations.

  • "WebNN only supports 2d convolution" - That's okay, as conv1d is just a reshape away, and a reshape doesn't actually copy the tensor. It just lightweightly reinterprets the tensor description.
  • "The models GELU activation function is not supported." - True. It decomposes into {erf, mul, add}, and since I only saw a few occurrences of it (4x in whisper tiny), I didn't propose it my comment above, but we welcome new useful ops - the steps for adding a new is here: https://github.com/webmachinelearning/webnn/blob/main/CONTRIBUTING.md#proposing-and-adding-a-new-operation
  • "linear takes scaler value of alpha and beta" - So an elementwise fused multiply-add then?
  • "softmax only takes 2d input" - indeed 😉 Softmax axis absent #466 .

@philloooo
Copy link
Contributor

philloooo commented Nov 30, 2023

thanks! @fdwr
Another question: for the dynamically shaped inputs of the autoregressive decoders in transformer models, do we have a way to handle that through WebNN?

@fdwr
Copy link
Collaborator

fdwr commented Nov 30, 2023

thanks! @fdwr Another question: for the dynamically shaped inputs of the autoregressive decoders in transformer models, do we have a way to handle that through WebNN?

@philloooo: WebNN focuses on statically compiled graphs (at least as of this comment; for a different input size you need to rebuild the graph), but one common technique in such cases is to round up. For example with Stable Diffusion, the text prompt is a variable size, but the token id input is padded with empty tokens and rounded up to 77 tokens, so the same graph can work with multiple different prompts.
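
A minimal sketch of that padding step, assuming a fixed graph input of 77 token ids; the pad token id here is a placeholder, since the actual value depends on the tokenizer (CLIP, for instance, pads with its end-of-text token).

```js
// Pad a variable-length prompt's token ids up to the graph's fixed input size,
// so one statically shaped graph serves prompts of different lengths.
function padTokenIds(tokenIds, maxLength = 77, padTokenId = 0 /* tokenizer-specific */) {
  if (tokenIds.length > maxLength) {
    throw new Error(`Prompt longer than the graph's fixed input of ${maxLength} tokens`);
  }
  const padded = new Int32Array(maxLength).fill(padTokenId);
  padded.set(tokenIds);
  return padded;
}
```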

@anssiko
Copy link
Member

anssiko commented Dec 11, 2023

It is my honour to announce the WG has just reached another major milestone by adding support for operations needed for well-known transformers! 👏 🚀

This work happened in PR #478 and is now delivered as a new W3C Candidate Recommendation Draft published on 11 December 2023 at https://www.w3.org/TR/webnn/

For a summary of changes, see:

Thank you everyone, in particular the editors @wchao1115 @huningxin & co who diligently worked on this PR addressing in total 195 review comments, @fdwr @xenova for key contributions that helped shape and formulate the initial scope in #375 (comment) and #375 (comment), @Honry for the transformer models analysis, @BruceDai @mei1127 for WPT and webnn-baseline contributions, @wacky6 for continued careful review and comments also via Chromium CLs, @inexorabletash @zolkis for contributions that helped keep this PR aligned with the latest spec authoring conventions, @miaobin @shiyi9801 @RafaelCintron for all the implementation-informed insights, and all the other contributors whose GH handles escaped me right now -- your contributions are equally appreciated!

Please join us to celebrate this major milestone on our 14 December 2023 teleconference! 🥳 🍿

(We will keep this meta issue open for discussion on future enhancements.)
