[tfjs-layers] Add a Transformer layer (Multihead Attention) #4955
Conversation
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here with @googlebot I signed it!
What to do if you already signed the CLA: Individual signers / Corporate signers. ℹ️ Googlers: Go here for more info.
@googlebot I signed it!
padSize: number;
}

class TransformerLayer extends Layer {
Quick comment: TransformerLayer doesn't seem to be in the standard TensorFlow (Python) API. We only add layer types that are already in TensorFlow (Python) to the TF.js layers API.
Hi @caisq, thanks for taking the time to read my PR!
Would it be better if I follow the layer named MultiHeadAttention (which does the same thing)? It would only be minor changes.
However, I'm still not sure how to access the layers I need to make it build in tfjs.
Thanks for the feedback.
Yes, implementing MultiHeadAttention will be good. Please note that we generally strive for high fidelity in replicating the behavior of Python. The implementation of MultiHeadAttention in Python uses the einsum op, which has been recently added to tfjs-core. But note that because the gradient of the einsum op hasn't been implemented in tfjs-core yet, the first implementation of MultiHeadAttention in tfjs most likely won't be able to support training.
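For readers unfamiliar with einsum: the attention computation inside MultiHeadAttention boils down to two contractions, each of which einsum expresses in a single call with no explicit transposes or reshapes. A minimal NumPy sketch (not tfjs code; the shapes and equation strings are illustrative, not necessarily the exact ones Keras uses):

```python
import numpy as np

# Illustrative shapes: batch, query length, key length, heads, head dim.
batch, q_len, k_len, heads, dim = 2, 4, 5, 3, 8
rng = np.random.default_rng(0)
query = rng.standard_normal((batch, q_len, heads, dim))
key = rng.standard_normal((batch, k_len, heads, dim))
value = rng.standard_normal((batch, k_len, heads, dim))

# Scaled dot-product scores in one einsum call.
scores = np.einsum('bqhd,bkhd->bhqk', query, key) / np.sqrt(dim)

# Softmax over the key axis.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Weighted sum of values, again as a single einsum.
out = np.einsum('bhqk,bkhd->bqhd', weights, value)
print(out.shape)  # (2, 4, 3, 8)
```

Both contractions run forward without any gradient machinery, which is why inference can work before the einsum gradient lands in tfjs-core.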
Hi @caisq, thanks a lot for the insights.
Could we start with Dense only and then create another PR once einsum is stable for training? Or is the Python fidelity more important?
In any case I will start changing the code to match the MultiHeadAttention behaviour.
Note that I'm doing this in my free time, so I will comment here when modifications are made.
Have a great day.
Not sure what you mean by "start with Dense". Dense is already implemented in TF.js.
I still think there is value in adding MultiHeadAttention to TF.js at this time, because there is no specialized training code in that layer. So when the einsum gradient is supported in tfjs-core in the future, training for that layer should just work. Before the gradient is supported, the layer can still support inference.
Sorry if I wasn't clear: I meant using classic fully connected layers to create the MultiHeadAttention so we can train it, and then refactoring once the einsum gradient is available.
I understand, then, that you still prefer to do everything with einsum from now on, even if we cannot train it.
So I will change the code to use einsum. Do you have any ETA on when the gradient will be available?
Sorry for the delay in responding. My take is to use the einsum implementation; this will involve less total work. I don't have an ETA for the gradient of einsum yet, but the einsum grad is not terribly hard to write.
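To illustrate why the einsum gradient is tractable: the gradient of an einsum contraction is itself an einsum, obtained by moving the output indices onto the upstream gradient and contracting out everything but the differentiated input's indices. A NumPy sketch for the plain matmul case (a simplified example, not the tfjs-core implementation):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))

# Forward pass: an ordinary matmul written as einsum.
C = np.einsum('ij,jk->ik', A, B)

# Arbitrary upstream gradient dL/dC.
dC = rng.standard_normal(C.shape)

# Backward pass: each input gradient is another einsum whose
# equation pairs the upstream gradient with the *other* input.
dA = np.einsum('ik,jk->ij', dC, B)   # dL/dA = dC @ B^T
dB = np.einsum('ij,ik->jk', A, dC)   # dL/dB = A^T @ dC

assert np.allclose(dA, dC @ B.T)
assert np.allclose(dB, A.T @ dC)
```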
I replaced the K.dot(a, b, bias) with add(einsum('eq', a, b)) as required. However, I don't know if there is an equation to do the dot and the sum at once.
To be honest, I find einsum harder to read than plain methods like .add or .dot; why is this better?
I also modified things to be closer to the Python code (still not the same, and I would like your thoughts about it). The only pain point is the layerNormalisation: I don't know how to get it and apply it in tfjs-layers/core.ts, as it's another layer...
Then just back and forth to ensure everything is correct and matches your guidelines/API requirements.
Note: I don't know which formatter you're using, and I'm sorry the linting is so different.
Thanks again for your time on this PR.
Have a great day.
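On the question of doing the dot and the bias sum at once: einsum only expresses multiply-and-sum contractions, so a bias term cannot be folded into the equation string, and composing add(einsum(...)) is the usual pattern. A small NumPy illustration (the shapes and names here are hypothetical, not from the PR's code):

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal((2, 4))   # batch of inputs
W = rng.standard_normal((4, 3))   # kernel
b = rng.standard_normal((3,))     # bias

# The contraction handles the dot product; the bias stays a
# separate elementwise add, exactly as add(einsum(...)) composes it.
dense_out = np.einsum('bi,io->bo', x, W) + b

assert np.allclose(dense_out, x @ W + b)
```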
Hello there, I'm willing to continue working on it, but I don't know what to change to make it "Google standards compatible". Thanks,
Hey @ierezell! Thank you for pushing this issue, I'm sure many people are waiting for this to be added to the API.
* [webgpu] Fix matmul_small_output program index issue * remove util.assert from matmul small size program * Change all a.shape and b.shape to a3dShape and b3dShape * Correct the shape checking for useVec4
This image is based on debian:10 instead of debian:11, because gcc-10 in debian:11 segfaults on the BatchMatMul and FusedMatMul cc tests in the wasm backend. Debian:10 uses gcc-8, which does not have this problem. This PR updates the dockerfile, but it does not update the cloudbuild files to use the release image because the image must be published first.
Make all cloudbuild CI steps use the release docker, as described in #5640.
This PR changes the name with a WebGPU_ prefix to distinguish it from the WebGL one, and uses 1000 as the default threshold to enable the USE model. Otherwise, it will complain with a readSync error.
PERF The total time of ScatterNd drops from 126.76 ms to 22.68 ms in the USE batch-size-30 model in benchmarks.
Make the default ts_library macro behave like pre-4.0 versions of rules_nodejs by automatically setting package_name to module_name if it's not set. Temporarily mark webgl tests as not testonly to work around bazel-contrib/rules_nodejs#2984 until the next release. Add package_name to pkg_npm targets.
Downloading the TFLite Web API WASM module is now handled by Bazel.
…#5662) * Handle int32 * fix * address comments
PERF With this change, the total time of ScatterNd in USE (batchSize=30) is further reduced to less than 1 ms, while the previous one was larger than 20 ms.
FEATURE INTERNAL * add more bazel build files * remove the integration test directory * enumerate tests * require tests files * fix tests * fixed spyOn test errors * fix more tests * keep function name for esbuild * fix more tests * allow random flag for jasmine test * updated deps * merged * fixed bazel lint * fixed layers snippet tests * fixed tests * fix test failure * fix bazel tests * addressed comments * addressed comments * fix test * fix bazel failure
Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
Run the correct browsers in nightly and pass the correct parameters to them to match the original test-ci.sh file. Pin Firefox to 90. Fix square test comparison precision on Safari webgl1.
FEATURE INTERNAL * fix nightly layers tests * fixed webgl1 failed tests * fix failed webgl1 tests * restrict gpu tests that failed for webgl1 to webgl2 only * fix lint
) Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.10.0 to 6.12.6. - [Release notes](https://github.com/ajv-validator/ajv/releases) - [Commits](ajv-validator/ajv@v6.10.0...v6.12.6) --- updated-dependencies: - dependency-name: ajv dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [handlebars](https://github.com/wycats/handlebars.js) from 4.1.2 to 4.7.7. - [Release notes](https://github.com/wycats/handlebars.js/releases) - [Changelog](https://github.com/handlebars-lang/handlebars.js/blob/master/release-notes.md) - [Commits](handlebars-lang/handlebars.js@v4.1.2...v4.7.7) --- updated-dependencies: - dependency-name: handlebars dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…#6141) Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.10.2 to 6.12.6. - [Release notes](https://github.com/ajv-validator/ajv/releases) - [Commits](ajv-validator/ajv@v6.10.2...v6.12.6) --- updated-dependencies: - dependency-name: ajv dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [karma](https://github.com/karma-runner/karma) from 6.3.2 to 6.3.14. - [Release notes](https://github.com/karma-runner/karma/releases) - [Changelog](https://github.com/karma-runner/karma/blob/master/CHANGELOG.md) - [Commits](karma-runner/karma@v6.3.2...v6.3.14) --- updated-dependencies: - dependency-name: karma dependency-type: direct:development ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.12.1 to 1.14.8. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](follow-redirects/follow-redirects@v1.12.1...v1.14.8) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
Bumps [ua-parser-js](https://github.com/faisalman/ua-parser-js) from 0.7.20 to 0.7.31. - [Release notes](https://github.com/faisalman/ua-parser-js/releases) - [Commits](faisalman/ua-parser-js@0.7.20...0.7.31) --- updated-dependencies: - dependency-name: ua-parser-js dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.13.3 to 1.14.8. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](follow-redirects/follow-redirects@v1.13.3...v1.14.8) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [karma](https://github.com/karma-runner/karma) from 6.3.1 to 6.3.14. - [Release notes](https://github.com/karma-runner/karma/releases) - [Changelog](https://github.com/karma-runner/karma/blob/master/CHANGELOG.md) - [Commits](karma-runner/karma@v6.3.1...v6.3.14) --- updated-dependencies: - dependency-name: karma dependency-type: direct:development ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.3.0 to 6.12.3. - [Release notes](https://github.com/ajv-validator/ajv/releases) - [Commits](ajv-validator/ajv@v6.3.0...v6.12.3) --- updated-dependencies: - dependency-name: ajv dependency-type: direct:development ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…6148) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…6149) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Upgrade xnnpack and emsdk to support WASM development on Apple silicon.
…6107) * Update isNaN implementation in WebGL backend to follow IEEE 754-1985. The previous implementation of isNaN in WebGL relied on platform behaviour and could generate incorrect results (see #5800). This PR implements isNaN based on the rules in IEEE 754-1985 to make the behaviour well-defined. Issue: #5800 * Only apply the bit-level isNaN for WebGL2 * Fix 80-character line length issue
WGSL removes isNan from the spec (gpuweb/gpuweb#2311). This CL implements isNaN based on the rules in IEEE 754-1985.
Hi @jaspermolgvits, I did this work because I needed Transformers in JS a while ago, so my fork was totally usable. I just rebased onto the latest master; I will re-check this code and provide an example here (or in my forked readme). Thanks for looking at this PR after so long. Please tell me if anything isn't in accordance with your guidelines. Have a great week.
I'm sorry, I had to recreate another PR as this one was too old and the history/changes were too big. It was easier for me to just fork again and create a new PR. The new PR is here: #6212. Have a great day.
As the title says, this aims to add a transformer layer (based on multi-head attention) to the built-in tfjs-layers.
Note: this is only one layer; the full architecture with positional embeddings and encoder/decoder would be built on top of it.
- I changed tf.add(a, b) => a.add(b) as in the other layers, because add is not imported.
- Some operations are missing from tfjs-layers/src/core.ts, like matMul, sqrt, softmax, cast, scalar. I don't know the best way to integrate and use them.
- The layer also needs tf.layers.LayerNormalization and tf.layers.Dropout, which I guess could be imported from the core (as the matMul).
Please tell me how to use the core functionalities to be fully compatible with the other layers. I tried it myself as a custom tfjs layer and it seems to be working (so I wanted to make a PR to have it built-in). However, it seems a bit slow.
Have a great day.
PS: Sorry for the linting modifications; just pay attention to the new layer at the bottom of the file. I also commented the shape for each variable, but that will be removed; it just helps for reviewing and debugging.
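For reference, the LayerNormalization the description mentions normalizes each feature vector to roughly zero mean and unit variance, then applies a learned scale and shift. A minimal NumPy sketch of the math (the eps value and the identity gamma/beta are illustrative defaults, not the layer's actual configuration):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-3):
    # Normalize over the last (feature) axis, then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.default_rng(3).standard_normal((2, 4, 8))
gamma = np.ones(8)   # learned scale (identity here)
beta = np.zeros(8)   # learned shift (zero here)
y = layer_norm(x, gamma, beta)

# Each feature vector now has ~zero mean.
print(np.allclose(y.mean(axis=-1), 0, atol=1e-6))  # True
```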