
TODO

  • Read through the "Attention Is All You Need" paper and work out the details of the multi-head attention module.
  • Complete the <YOUR CODE GOES HERE> parts in model.py; read the comments, they are helpful.
  • Use einsum and einops for this implementation (a minimal attention sketch follows this list).
  • The Transformer-NMT-en-es.ipynb notebook walks through the NMT workflow.
  • The train_tokenizer_en/es.py scripts train a Unigram tokenizer from scratch (see the tokenizer sketch below).
  • The train.py script trains the vanilla seq2seq Transformer model.
  • Note: the translate function has a performance issue / bug; try to identify and fix it (one common JAX pitfall is sketched below).
  • Write down what you learned and your thoughts about attention and JAX in a separate Markdown file.
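
For the einsum/einops item above, here is a minimal sketch of scaled dot-product multi-head attention in JAX, assuming a single unbatched sequence; the function and argument names are illustrative and need not match model.py:

```python
import jax
import jax.numpy as jnp
from einops import rearrange

def multihead_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """Self-attention over x of shape (seq_len, d_model). All w_* are (d_model, d_model)."""
    # Project to queries/keys/values and split the feature dim into heads.
    q = rearrange(x @ w_q, "s (h d) -> h s d", h=n_heads)
    k = rearrange(x @ w_k, "s (h d) -> h s d", h=n_heads)
    v = rearrange(x @ w_v, "s (h d) -> h s d", h=n_heads)
    # Scaled dot-product attention, computed per head via einsum.
    scores = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("hqk,hkd->hqd", weights, v)
    # Merge heads back and apply the output projection.
    return rearrange(out, "h s d -> s (h d)") @ w_o
```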
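
For the tokenizer scripts, a minimal sketch of training a Unigram tokenizer with the Hugging Face tokenizers library; the corpus file name, vocab size, and special tokens here are assumptions, and the actual train_tokenizer_en/es.py scripts may differ:

```python
from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import UnigramTrainer

tokenizer = Tokenizer(Unigram())
tokenizer.pre_tokenizer = Whitespace()
trainer = UnigramTrainer(
    vocab_size=8000,  # assumed; pick what the assignment expects
    special_tokens=["<pad>", "<s>", "</s>", "<unk>"],
    unk_token="<unk>",
)
tokenizer.train(files=["corpus.en.txt"], trainer=trainer)  # hypothetical corpus file
tokenizer.save("vanilla-NMT/tokenizer.json")
```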
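
On the translate performance issue: one common JAX pitfall (an assumption here, not necessarily the actual bug) is that a jit-compiled function retraces whenever input shapes change, so decoding with a sequence that grows by one token per step recompiles at every step. A tiny illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def decode_step(tokens):
    # Stand-in for a real decoder call; retraces whenever tokens.shape changes.
    return tokens.sum()

# Calling decode_step on shapes (1,), (2,), (3,), ... compiles anew each time.
# Padding every input to a fixed MAX_LEN (with masking in a real model)
# keeps a single compiled version:
MAX_LEN = 64

def pad_to_max(tokens):
    return jnp.pad(tokens, (0, MAX_LEN - tokens.shape[0]))
```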

What/How to submit

Submit your work via GitHub:

  • Fork this repo into a private repo named "jax-transformer-CNetID": How to do that?
  • Include a separate Markdown file documenting your thoughts and questions about attention and JAX.
  • Add the trained tokenizer.json to the vanilla-NMT folder (it appears once you train the tokenizer).
  • Include ONE ckpt/state-{timestamp}.pickle file (it appears once you train the model).
  • Go to the repo settings and add Oaklight as a collaborator.

When to submit

The deadline for this assignment is Oct 10th, noon, timestamped by the collaborator-invite email. I will respond to your invites ASAP.

Resources about transformers in JAX:
