Read through the "Attention Is All You Need" paper and figure out the details of the multi-head attention module.
- Finish the `<YOUR CODE GOES HERE>` parts in `model.py`; read the comments, they are helpful.
- Use `einsum` and `einops` for this implementation (a sketch of the einsum/einops idiom follows this list).
- The `Transformer-NMT-en-es.ipynb` notebook provides a workflow for NMT.
- The `train_tokenizer_en/es.py` scripts are for training a Unigram tokenizer from scratch (an illustrative tokenizer-training sketch follows this list).
- The `train.py` script is for training the vanilla seq2seq transformer model.
- Note: the `translate` function suffers from a performance issue / bug; try to identify and fix it (a general performance note follows this list).
- Write down what you learned and your thoughts about Attention and Jax in a separate Markdown file.
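As a minimal, hedged sketch of how `jnp.einsum` and `einops.rearrange` fit together in multi-head attention: this is not the solution to `model.py` (the names, weight layout, and masking convention there may differ), only an illustration of the idiom. The projection matrices `wq, wk, wv, wo` and the boolean `mask` are assumptions for this sketch.

```python
import jax
import jax.numpy as jnp
from einops import rearrange


def multi_head_attention(q, k, v, wq, wk, wv, wo, num_heads, mask=None):
    """Multi-head scaled dot-product attention (illustrative only).

    q, k, v:        [batch, seq, d_model]
    wq, wk, wv, wo: [d_model, d_model]   (hypothetical projection layout)
    mask:           broadcastable to [batch, heads, q_len, k_len]; True = attend
    """
    # Project, then split the feature dimension into heads with einops.
    q = rearrange(q @ wq, "b t (h d) -> b h t d", h=num_heads)
    k = rearrange(k @ wk, "b t (h d) -> b h t d", h=num_heads)
    v = rearrange(v @ wv, "b t (h d) -> b h t d", h=num_heads)

    # Attention scores: einsum contracts the per-head feature dimension.
    scores = jnp.einsum("b h q d, b h k d -> b h q k", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, -1e9)
    weights = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of values, then merge heads back and apply the output projection.
    out = jnp.einsum("b h q k, b h k d -> b h q d", weights, v)
    return rearrange(out, "b h t d -> b t (h d)") @ wo
```

With the paper's setup of `d_model=512` and `num_heads=8`, each head attends in a 64-dimensional subspace; the einsum subscripts make those contraction dimensions explicit, which is the main reason einsum/einops are suggested here.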
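The `train_tokenizer_en/es.py` scripts are the source of truth for tokenizer training. Purely as an illustration of what "training a Unigram tokenizer from scratch" looks like, here is a rough sketch using the Hugging Face `tokenizers` library; the corpus path, vocabulary size, special tokens, and output path are placeholders, not the repo's actual settings.

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Build an empty Unigram tokenizer and train it on a raw-text corpus.
tokenizer = Tokenizer(models.Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.UnigramTrainer(
    vocab_size=8000,                                  # placeholder value
    special_tokens=["<pad>", "<s>", "</s>", "<unk>"],  # placeholder token set
    unk_token="<unk>",
)
tokenizer.train(files=["corpus_en.txt"], trainer=trainer)  # placeholder corpus

# Serialize to a single JSON file, like the tokenizer.json the submission asks for.
tokenizer.save("tokenizer.json")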
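On the `translate` note above, without giving away the specific issue: two classic causes of slow autoregressive decoding in Jax code are recomputing the encoder output inside the token-by-token loop, and calling a jitted function with a different sequence length at every step, which forces a recompilation per step. Whether or not either is the actual problem in this repo's `translate`, the toy below (all names invented for the demo) shows the recompilation effect and the usual fix of decoding into a fixed-size padded buffer.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_decoder(tgt_ids):
    # Stand-in for a decoder forward pass: any jitted function recompiles
    # whenever it sees a new input shape.
    return jnp.sum(jnp.sin(tgt_ids.astype(jnp.float32)), axis=-1)


def decode_growing(steps=20):
    # Anti-pattern: the target grows by one token per step, so every call
    # presents a new shape and triggers a fresh compilation.
    tgt = jnp.zeros((1, 1), dtype=jnp.int32)
    for _ in range(steps):
        _ = toy_decoder(tgt).block_until_ready()
        tgt = jnp.concatenate([tgt, jnp.ones((1, 1), dtype=jnp.int32)], axis=1)


def decode_padded(steps=20, max_len=32):
    # Fix: keep a fixed-shape padded target and write tokens into it, so the
    # jitted function compiles once and is reused for every step.
    tgt = jnp.zeros((1, max_len), dtype=jnp.int32)
    for t in range(steps):
        _ = toy_decoder(tgt).block_until_ready()
        tgt = tgt.at[0, t].set(1)


for fn in (decode_growing, decode_padded):
    start = time.perf_counter()
    fn()
    print(fn.__name__, f"{time.perf_counter() - start:.3f}s")
```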
Submit your work via GitHub:
- Fork a "private" repo, named "jax-transformer-CNetID", from this repo: How to do that?
- Include a separate Markdown file documenting your thoughts and questions about attention and Jax.
- Add the trained `tokenizer.json` to the `vanilla-NMT` folder (train the tokenizer and you will see it).
- Include ONE `ckpt/state-{timestamp}.pickle` file (train the model and you will see it).
- Go to the repo settings and add collaborators: Oaklight.
The deadline for this assignment is Oct 10th, noon, timestamped by the collaborator invite email. I will respond to your invites ASAP.
- Attention is all you need: https://arxiv.org/abs/1706.03762
- Annotated transformer (PyTorch): https://nlp.seas.harvard.edu/2018/04/03/attention.html
- einsum:
  - videos:
    - YouTube: https://youtu.be/ULY6pncbRY8
    - Bilibili: https://www.bilibili.com/video/BV1ee411g7Sv
  - code snippets: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
- einops: https://github.com/arogozhnikov/einops/
  - einsum: https://theaisummer.com/einsum-attention/
  - videos:
- Jax:
  - Jax: https://jax.readthedocs.io/en/latest/index.html
  - Haiku: https://dm-haiku.readthedocs.io/en/latest/index.html
  - Haiku 101: the basic usage logic of the Haiku library (a Zhihu article by 谷雨): https://zhuanlan.zhihu.com/p/471892075
- Reference implementations:
  - haiku: this is cleaner but more functional, and the transformer is not complete (a minimal `hk.transform` sketch follows this list)
  - flax: this is more complex and easier to get lost in
  - elegy: perhaps use it for the training loop
  - labml annotated multi-head attention: https://nn.labml.ai/transformers/mha.html
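Since the references above describe Haiku as "cleaner but more functional", here is a minimal sketch of that pattern: `hk.transform` turns a function that builds modules into a pure `(init, apply)` pair, which is what the rest of Jax (`jit`, `grad`) expects. The network and shapes below are arbitrary, chosen only to keep the example runnable.

```python
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
    # Haiku modules can only be created inside a transformed function.
    mlp = hk.nets.MLP([64, 10])
    return mlp(x)


# hk.transform returns a pure (init, apply) pair; without_apply_rng drops the
# rng argument from apply since this network is deterministic.
net = hk.without_apply_rng(hk.transform(forward))

x = jnp.ones((2, 8))
params = net.init(jax.random.PRNGKey(42), x)  # params is a plain pytree of arrays
logits = jax.jit(net.apply)(params, x)        # pure function: params in, output out
print(logits.shape)  # (2, 10)
```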