
[tfjs-layers] Add a Transformer layer (Multihead Attention) #4955

Closed
wants to merge 583 commits into from

Conversation


@ierezell ierezell commented Apr 19, 2021

As the title says, this aims to add a Transformer layer (based on multi-head attention) to the built-in tfjs-layers.

Note: this is only one layer; the full architecture (positional embeddings, encoder, decoder) would be built on top of it.

  • Refactored the code to use the tensors' built-in chained operations (e.g. tf.add(a, b) => a.add(b)), like the other layers, because functional ops such as add are not imported there (see the sketch below the list).
  • Some operations are not available in tfjs-layers/src/core.ts, such as matMul, sqrt, softmax, cast, and scalar. I don't know the best way to integrate and use them.
  • Needs tf.layers.LayerNormalization and tf.layers.Dropout, which I guess could be imported from the core (like matMul).
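For illustration, here is a minimal sketch (not taken from this PR) of the kind of refactor meant in the first bullet; the tensors q and k and the head size 8 are made up:

```ts
import * as tf from '@tensorflow/tfjs-core';

// Hypothetical tensors standing in for the query/key activations of the layer.
const q = tf.randomNormal([2, 4, 8]);
const k = tf.randomNormal([2, 4, 8]);

// Functional style: requires `matMul`, `div`, `sqrt`, `scalar` to be imported.
const scoresFunctional =
    tf.div(tf.matMul(q, k, false, true), tf.sqrt(tf.scalar(8)));

// Chained style used by the other tfjs-layers code: only the tensors are needed.
const scoresChained = q.matMul(k, false, true).div(tf.scalar(8).sqrt());
```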

Please tell me how to use the core functionality so the layer is fully compatible with the other layers. I tried it myself as a custom tfjs layer and it seems to work (which is why I wanted to make a PR to have it built in). However, it seems a bit slow.

Have a great day.

PS: Sorry for the linting modifications; just pay attention to the new layer at the bottom of the file.
I also commented the shape of each variable; these comments will be removed, they just help with reviewing and debugging.



@google-cla

google-cla bot commented Apr 19, 2021

Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

📝 Please visit https://cla.developers.google.com/ to sign.

Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.



@google-cla google-cla bot added the cla: no label Apr 19, 2021
@ierezell
Author

@googlebot I signed it!

@google-cla google-cla bot added cla: yes and removed cla: no labels Apr 19, 2021
@ierezell ierezell changed the title Add a Transformer layer (Multihead Attention) [tfjs-layers] Add a Transformer layer (Multihead Attention) Apr 19, 2021
padSize: number;
}

class TransformerLayer extends Layer {
Contributor

Quick comment: TransformerLayer doesn't seem to be in the standard TensorFlow (Python) API. We only add layer types that are already in TensorFlow (Python) to the TF.js layers API.

Author

Hi @caisq, thanks for taking the time to read my PR!

Would it be better if I followed the layer named MultiHeadAttention (which does the same thing)?

It would only require minor changes.

However, I'm still not sure how to access the layers I need to build it in tfjs.

Thanks for the feedback

Contributor

Yes, implementing MultiHeadAttention will be good. Please note that we generally strive for high fidelity in replicating the behavior of Python. The implementation of MultiHeadAttention in Python uses the einsum op, which has been recently added to tfjs-core. But note that because the gradient of the einsum op hasn't been implemented in tfjs-core yet, the first implementation of MultiHeadAttention in tfjs most likely won't be able to support training.
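For reference, a minimal TypeScript sketch of what einsum-based scaled dot-product attention could look like in tfjs; this is an illustration, not code from this PR, and it assumes tfjs-core's tf.einsum accepts these two-operand equations. As noted above, until the einsum gradient lands, a layer built on it would be inference-only.

```ts
import * as tf from '@tensorflow/tfjs-core';

// B = batch, T = target length, S = source length, N = heads, H = size per head.
function scaledDotProductAttention(
    query: tf.Tensor,  // [B, T, N, H]
    key: tf.Tensor,    // [B, S, N, H]
    value: tf.Tensor   // [B, S, N, H]
): tf.Tensor {
  const headDim = query.shape[3];
  // Attention logits for every (target, source) pair, per head: [B, N, T, S].
  const logits = tf.einsum('btnh,bsnh->bnts', query, key)
                     .div(tf.scalar(Math.sqrt(headDim)));
  const weights = tf.softmax(logits);  // softmax over the source axis (last axis)
  // Weighted sum of the values, back to [B, T, N, H].
  return tf.einsum('bnts,bsnh->btnh', weights, value);
}
```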

Author

Hi @caisq, thanks a lot for the insights.

Could we start with Dense only and then create another PR once einsum is stable for training? Or is fidelity to the Python implementation more important?

In any case, I will start changing the code to match the MultiHeadAttention behaviour.

Note that I'm doing this on my free time so I will comment here when modifications are made.

Have a great day.

Contributor

@caisq caisq Apr 23, 2021

Not sure what you mean by "start with Dense". Dense is already implemented in TF.js.

I still think there is value in adding MultiHeadAttention to TF.js at this time, because that layer contains no specialized training code. So when the einsum gradient is supported in tfjs-core in the future, training for that layer should just work. Before the gradient is supported, the layer can still support inference.

Author

Sorry if I wasn't clear; I meant using classic fully connected layers to build the MultiHeadAttention so we can train it, and then refactoring once the einsum gradient is available (a rough sketch of that idea follows).
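For illustration, a hypothetical single-head sketch of that kind of fully connected fallback (made-up names and sizes, using only ops whose gradients already exist today):

```ts
import * as tf from '@tensorflow/tfjs';

const modelDim = 64;

// Trainable Dense projections for query, key and value.
const queryProj = tf.layers.dense({units: modelDim, name: 'query_proj'});
const keyProj = tf.layers.dense({units: modelDim, name: 'key_proj'});
const valueProj = tf.layers.dense({units: modelDim, name: 'value_proj'});

const x = tf.randomNormal([2, 10, modelDim]);  // [batch, seqLen, modelDim]
const q = queryProj.apply(x) as tf.Tensor;
const k = keyProj.apply(x) as tf.Tensor;
const v = valueProj.apply(x) as tf.Tensor;

// Plain matMul/softmax attention, so backprop works with today's tfjs-core.
const logits =
    tf.matMul(q, k, false, true).div(tf.scalar(Math.sqrt(modelDim)));
const output = tf.matMul(tf.softmax(logits), v);  // [batch, seqLen, modelDim]
```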

That said, I understand that you still prefer to do everything with einsum from now on, even if we cannot train it yet.

So I will change the code to use einsum; do you have any ETA for when the gradient will be available?

Contributor

Sorry for the delay in responding. My take is to use the einsum implementation; it will involve less total work. I don't have an ETA for the gradient of einsum yet, but the einsum grad is not terribly hard to write.

Author

@ierezell ierezell Apr 26, 2021

I replaced the K.dot(a, b, bias) calls with add(einsum('eq', a, b)) as required.
However, I don't know whether there is a single einsum equation that does the dot product and the bias addition at once (a sketch of the pattern is below).
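A hypothetical example of the pattern (made-up shapes; the einsum equation only expresses the multiply-and-sum contraction, so the bias addition stays a separate op):

```ts
import * as tf from '@tensorflow/tfjs-core';

const x = tf.randomNormal([2, 10, 8]);    // [batch, seqLen, inDim]
const kernel = tf.randomNormal([8, 16]);  // [inDim, outDim]
const bias = tf.zeros([16]);              // [outDim]

// Previously: K.dot(x, kernel) followed by a bias add.
// Now: contract the last axis of x with the first axis of kernel via einsum,
// then add the (broadcast) bias as a separate op.
const y = tf.einsum('bsi,io->bso', x, kernel).add(bias);  // [batch, seqLen, outDim]
```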

To be honest, I find einsum harder to read than plain methods like .add or .dot; why is this better?

I also modified things to be closer to the Python code (still not identical, and I would like your thoughts on it).

The only pain point is layer normalization: I don't know how to get it and apply it in tfjs-layers/core.ts, since it's another layer...

After that, it's just back and forth to ensure everything is correct and matches your guidelines/API requirements.

Note: I don't know which formatter you're using, and I'm sorry the linting is so different.

Thanks again for your time on this PR.

Have a great day.

@ierezell
Author

Hello there,
Sorry to bump, but is there any news on merging this PR to get multi-head attention in JS?

I'm willing to keep working on it, but I don't know what to change to make it compatible with Google's standards.

Thanks,
have a great day

@jaspermolgvits

Hey @ierezell! Thank you for pushing this issue; I'm sure many people are waiting for this to be added to the API.
Would you happen to have a working fork of TFJS that people interested in transformers could use while Google figures itself out? Maybe if people start using your fork, it might expedite their attention (layers), too.

haoyunfeix and others added 21 commits February 20, 2022 20:45
* [webgpu] Fix matmul_small_output program index issue

* remove util.assert from matmul small size program

* Change all a.shape and b.shape to a3dShape and b3dShape

* Correct the shape checking for useVec4
This image is based on debian:10 instead of debian:11 because gcc-10 in debian:11 segfaults on the BatchMatMul and FusedMatMul cc tests in the wasm backend. Debian:10 uses gcc-8, which does not have this problem.

This PR updates the dockerfile, but it does not update the cloudbuild files to use the release image because the image must be published first.
Make all cloudbuild CI steps use the release docker, as described in #5640.
This PR adds a WebGPU_ prefix to the name to distinguish it from the WebGL one,
and uses 1000 as the default threshold to enable the USE model; otherwise it
complains with a readSync error.
PERF

The total time of ScatterNd drops from 126.76 ms to 22.68 ms for the USE (batch size 30) model in the benchmarks.
Make the default ts_library macro behave like pre-4.0 versions of rules_nodejs by automatically setting package_name to module_name if it's not set.

Temporarily mark webgl tests as not testonly to work around bazel-contrib/rules_nodejs#2984 until the next release.

Add package_name to pkg_npm targets.
Downloading the TFLite Web API WASM module is now handled by Bazel.
PERF

With this change, the total time of scatterND in USE (batchSize=30) is
further reduced to less than 1 ms, down from more than 20 ms previously.
FEATURE
INTERNAL
* add more bazel build files

* remove the integration test directory

* enumerate tests

* require tests files

* fix tests

* fixed spyOn test errors

* fix more tests

* keep function name for esbuild

* fix more tests

* allow random flag for jasmine test

* updated deps

* merged

* fixed bazel lint

* fixed layers snippet tests

* fixed tests

* fix test failure

* fix bazel tests

* addressed comments

* addressed comments

* fix test

* fix bazel failure
Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
Run the correct browsers in nightly and pass the correct parameters to them to match the original test-ci.sh file.

Pin Firefox to 90.

Fix square test comparison precision on Safari webgl1.
FEATURE
INTERNAL
* fix nightly layers tests

* fixed webgl1 failed tests

* fix failed webgl1 tests

* restrict gpu tests that failed for webgl1 to webgl2 only

* fix lint
dependabot bot and others added 23 commits February 20, 2022 20:45
)

Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.10.0 to 6.12.6.
- [Release notes](https://github.com/ajv-validator/ajv/releases)
- [Commits](ajv-validator/ajv@v6.10.0...v6.12.6)

---
updated-dependencies:
- dependency-name: ajv
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [handlebars](https://github.com/wycats/handlebars.js) from 4.1.2 to 4.7.7.
- [Release notes](https://github.com/wycats/handlebars.js/releases)
- [Changelog](https://github.com/handlebars-lang/handlebars.js/blob/master/release-notes.md)
- [Commits](handlebars-lang/handlebars.js@v4.1.2...v4.7.7)

---
updated-dependencies:
- dependency-name: handlebars
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…#6141)

Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.10.2 to 6.12.6.
- [Release notes](https://github.com/ajv-validator/ajv/releases)
- [Commits](ajv-validator/ajv@v6.10.2...v6.12.6)

---
updated-dependencies:
- dependency-name: ajv
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [karma](https://github.com/karma-runner/karma) from 6.3.2 to 6.3.14.
- [Release notes](https://github.com/karma-runner/karma/releases)
- [Changelog](https://github.com/karma-runner/karma/blob/master/CHANGELOG.md)
- [Commits](karma-runner/karma@v6.3.2...v6.3.14)

---
updated-dependencies:
- dependency-name: karma
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.12.1 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.12.1...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
Bumps [ua-parser-js](https://github.com/faisalman/ua-parser-js) from 0.7.20 to 0.7.31.
- [Release notes](https://github.com/faisalman/ua-parser-js/releases)
- [Commits](faisalman/ua-parser-js@0.7.20...0.7.31)

---
updated-dependencies:
- dependency-name: ua-parser-js
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.13.3 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.13.3...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.13.3 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.13.3...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.13.3 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.13.3...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [karma](https://github.com/karma-runner/karma) from 6.3.1 to 6.3.14.
- [Release notes](https://github.com/karma-runner/karma/releases)
- [Changelog](https://github.com/karma-runner/karma/blob/master/CHANGELOG.md)
- [Commits](karma-runner/karma@v6.3.1...v6.3.14)

---
updated-dependencies:
- dependency-name: karma
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [ajv](https://github.com/ajv-validator/ajv) from 6.3.0 to 6.12.3.
- [Release notes](https://github.com/ajv-validator/ajv/releases)
- [Commits](ajv-validator/ajv@v6.3.0...v6.12.3)

---
updated-dependencies:
- dependency-name: ajv
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…6148)

Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…6149)

Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.14.7 to 1.14.8.
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](follow-redirects/follow-redirects@v1.14.7...v1.14.8)

---
updated-dependencies:
- dependency-name: follow-redirects
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Upgrade xnnpack and emsdk to support WASM development on Apple silicon.
…6107)

* Update isnan implementation in WebGL backend to follow IEEE 754-1985

The previous implementation of isNaN in WebGL relies on platform behaviour
and might generate incorrect results (see #5800).
This PR implements isnan based on the rules in IEEE 754-1985.

Issue: #5800

* Only apply bit version isnan for WebGL2

* Fix lines exceeding the 80-character limit
WGSL removed isNaN from the spec (gpuweb/gpuweb#2311).
This CL implements isnan based on the rules in IEEE 754-1985.
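For context, the IEEE 754 rule in question: a float32 is NaN exactly when its exponent bits are all ones and its fraction bits are not all zero. A small TypeScript sketch of that check on the raw bits (an illustration, not the shader code in this CL):

```ts
// NaN test from the raw IEEE 754 single-precision bits:
// sign (1 bit) | exponent (8 bits) | fraction (23 bits).
function isNaNBits(x: number): boolean {
  const view = new DataView(new ArrayBuffer(4));
  view.setFloat32(0, x);
  const bits = view.getUint32(0);
  const exponent = (bits >>> 23) & 0xff;
  const fraction = bits & 0x7fffff;
  return exponent === 0xff && fraction !== 0;
}

console.log(isNaNBits(NaN));       // true
console.log(isNaNBits(Infinity));  // false (exponent all ones, fraction zero)
console.log(isNaNBits(1.5));       // false
```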
@google-cla google-cla bot added cla: no and removed cla: yes labels Feb 21, 2022
@ierezell
Author

ierezell commented Feb 21, 2022

Hi @jaspermolgvits, I did this work because I needed Transformers in JS a while ago... so my fork is totally usable.

I just rebased onto the latest master, and I will re-check this code and provide an example here (or in my fork's README).
I guess many things have changed in almost a year.

Thanks for looking at this PR after so long. Please tell me if anything doesn't follow your guidelines.

Have a great week.

@ierezell
Author

ierezell commented Mar 10, 2022

I'm sorry, I had to recreate this as another PR, since this one was too old and the history/changes were too big. It was easier for me to just fork again and open a new PR.

The new PR is here: #6212
It still needs a bit of love and insight from someone who knows the tfjs best practices, and I will be happy to put in a few more hours to close it (for example, on the best way to add missing ops, and on linting).

Have a great day.

This pull request was closed.