# Welcome!

### Purpose:
Implement a research paper!

[Memorizing Transformers](https://arxiv.org/pdf/2203.08913.pdf) from Google by Wu, Rabe, Hutchins, Szegedy (seh-guh-dee)
- [Google implementation](https://github.com/google-research/meliad/tree/main/transformer)
- [Pytorch implementation - lucidrains](https://github.com/lucidrains/memorizing-transformers-pytorch)


### Goal:


"What I cannot create, I do not understand." - Richard Feynman

- Teach you to implement this paper.
- Give you a **process/framework** that enables you to implement new research papers and, more importantly, your own ideas
- Lots of ML concepts along the way
- PyTorch practice

### Why learn to implement papers?
- Put your own ideas into the world (research, tool, startup)
- Relevant right now (2023): open source LLMs like LLaMA, quantized models, parameter-efficient fine-tuning methods... You can do a lot of interesting things without spending millions of dollars on GPUs (I would not have said so just a year or two ago)
- Valuable (career) skill to have


### Why this video series?
- There's nowhere else to go! Very few resources cover this specific skill despite it being so valuable (Karpathy videos are by far the best resource I've seen)
- Implementing a research paper can be a huge pain, and there's no manual for this stuff: you either sit next to someone who's done it before (an advisor or coworker) or you do it yourself.
- I learned how to do this by diving in on my own. I want to share some tips and the stupid mistakes I made to save you time and make your life easier
- There are lots of resources on how to apply a model using a high-level library, how deep learning models work, or how to build a toy model, but very few on taking a paper and implementing it yourself









### Format (Karpathy inspired!)
- Build piece by piece and explain what we do along the way
- Show the problems and dead ends we run into instead of just presenting the finished solution. Expect lots of fumbling, looking at documentation, and testing things; **seeing the finished code doesn't teach you nearly as much as seeing the process**
- **For you: pause and try to implement each piece before you watch me do it. MAKE IT CHALLENGING FOR YOURSELF**

### Prerequisites
- I will explain as we go, but this is a fairly advanced series
- Familiar with Python / PyTorch
- Familiar with deep learning (Transformers)
- If you understood and enjoyed [Andrej Karpathy's YouTube lectures](https://www.youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ) then this is a great place to continue







# Framework for implementing a research paper, step by step
### High level plan

1. Pick a paper
2. Identify the paper's main idea, how the authors implement it, and how they test it
3. Identify components of the paper
4. Have a (rough) pseudocode understanding of how things work and fit together
5. Build the individual components
6. Assemble them into a model
7. Run and test the model




### Tips for picking a paper
- Paper selection is important!
- Pick something you're interested in: this will probably take a while!
- You have to hunt for details. Details are often missing from a paper: they are forgotten, or it's assumed you know them, or they're somewhere in the code, or you'll have to dig through the referenced papers to find them. (A paper that tries something new on top of an old, established architecture won't elaborate on the details of that architecture; you'll have to go look them up.)
- If possible **choose something that has an existing implementation** so that you can cross-reference your work and find hidden details.
    - Obviously, if an implementation already exists there's little need to write a new one. However, right now, for the purposes of practicing and developing your skills, having something to work against is extremely helpful. This is your "labeled" data :)





Be aware of the limitations of reading the source code:
- The code can contain lots of performance tricks that will obscure the main ideas.
- Or it will be part of a large existing library with a lot of surrounding machinery, making it very hard to see/understand the core ideas. (e.g. try reading Hugging Face code)
- Not guaranteed to have good / any comments or documentation




# How to read a paper effectively
([Video presentation of Memorizing Transformers by author Yuhuai Wu](https://www.youtube.com/watch?v=5AoOpFFjW28&t=2973s))

### Multiple passes approach
Read the paper multiple times with increasing levels of detail. Aim for a very brief, high-level first pass. With each successive reading, fill in more details. This is easier and more efficient than going line by line, top to bottom, trying to digest every detail and equation. You can usually get the main point of the paper in a few minutes (abstract, diagrams, results). In later passes you can read more of the introduction, architecture details, experiments, related works, etc. and deepen your understanding.

### Checklist questions to make sure you understand
- What's the idea in a sentence or two?
- What's the motivation as framed by the authors? What's the problem that they are solving?
- How do they attempt to solve it?
- What is the main contribution of the paper?
- How do they measure success?
- Were they successful?
- Keep a running list of questions and knowledge gaps for yourself




**MEMORIZING TRANSFORMERS**

- **What's the idea in a sentence or two?**

Language models would benefit from a long-term memory mechanism. Approximate kNN lookup into non-differentiable memory of recent key-value pairs improves language modeling.

- **What's the motivation as framed by the authors? What's the problem they are solving?**

Attending to far-away tokens is important in many situations. However, transformer performance is limited by context window size: self-attention has quadratic (O(n^2)) time and space complexity in the sequence length, so doubling the context window quadruples the compute and memory requirements of attention. This is why we cannot just naively make the context window bigger. Many approaches try to work around this problem and create long-range attention.
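
To make the quadratic cost concrete, here is a tiny sketch that counts the query-key scores full self-attention computes for one sequence. The function name `attention_score_count` is made up for this illustration, and it ignores constant factors such as the number of heads, the head dimension, and causal masking (masking roughly halves the count but keeps it quadratic).

```python
def attention_score_count(context_length: int) -> int:
    """Number of query-key scores in full self-attention over one sequence.

    Every token attends to every token, so the score matrix is n x n.
    """
    return context_length * context_length


# Doubling the context length quadruples the number of scores.
for n in (512, 1024, 2048):
    print(n, attention_score_count(n))

ratio = attention_score_count(1024) / attention_score_count(512)
print("cost ratio when doubling the context:", ratio)
```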

- **How do they attempt to solve it?**

Store key-value pairs at one layer of the network. On future passes through the model, perform approximate kNN search into the stored key-value pairs using the current query projection. At this layer, perform regular QKV attention over the current context, and also perform attention a second time using the current Q and the retrieved KV. Then combine the outputs of these two attention operations and pass the result into the next layer as usual.
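
Here is a rough, single-query sketch of that flow in plain Python (the real thing would use batched PyTorch tensors). Everything named here is hypothetical and chosen for illustration: `top_k` does an exact search as a stand-in for the approximate kNN lookup, and `gate` is a fixed number, whereas (as I read the paper) the mixing weight is learned per head.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([dot(query, k) * scale for k in keys])
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


def top_k(query, mem_keys, mem_values, k):
    """Exact top-k by dot product; a stand-in for approximate kNN search."""
    order = sorted(range(len(mem_keys)),
                   key=lambda i: dot(query, mem_keys[i]), reverse=True)[:k]
    return [mem_keys[i] for i in order], [mem_values[i] for i in order]


def memory_augmented_attention(query, local_keys, local_values,
                               mem_keys, mem_values, k=2, gate=0.5):
    # 1) regular attention over the current context
    local_out = attend(query, local_keys, local_values)
    # 2) attention over the k retrieved key-value pairs from memory
    top_keys, top_values = top_k(query, mem_keys, mem_values, k)
    memory_out = attend(query, top_keys, top_values)
    # 3) combine the two results (fixed gate here; an assumption for the sketch)
    return [gate * m + (1.0 - gate) * c for m, c in zip(memory_out, local_out)]


query = [1.0, 0.0]
local_keys = [[1.0, 0.0], [0.0, 1.0]]
local_values = [[1.0, 0.0], [0.0, 1.0]]
mem_keys = [[2.0, 0.0], [0.0, 2.0], [-1.0, 0.0]]
mem_values = [[5.0, 5.0], [6.0, 6.0], [7.0, 7.0]]

out = memory_augmented_attention(query, local_keys, local_values,
                                 mem_keys, mem_values, k=2, gate=0.5)
print(out)
```

Note that nothing in the memory is updated by this computation: the stored keys and values are only read, which matches the "non-differentiable memory" idea above.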

- **What is the main contribution of the paper?**

New approach: 1) we retrieve the exact historical key-value matrices rather than averaged or summarized versions, 2) gradients do not flow back into the kNN memory, making it fast and scalable.

Improvement to language modeling: this approach allows for better performance with fewer parameters than vanilla transformers (on long-range datasets). It can also be integrated into existing architectures and even existing pre-trained models (via fine-tuning).

- **How will they measure success?**

Does this improve the performance on datasets designed to test long-range capability (long documents)? -> perplexity / performance on long range datasets.
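
Since perplexity is the metric throughout, it helps to have the definition in code form: the exponential of the average negative log-probability the model assigns to the correct tokens. This is a minimal sketch; the function name and the list-of-probabilities input are choices made for the illustration, not anything taken from the paper's code.

```python
import math


def perplexity(token_probs):
    """exp of the mean negative log-probability of the correct tokens.

    token_probs: probability the model assigned to each correct next token.
    """
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


# A model that always spreads its probability evenly over 4 choices
# is "as confused as" picking among 4 options.
print(perplexity([0.25] * 8))
```

Lower is better: a model that assigns probability 1 to every correct token has perplexity 1.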

- **Were they successful?**

Main result: a memorizing transformer with 1 billion parameters has the same perplexity (on arXiv-math dataset) as a vanilla transformer with 8 billion parameters.

Other: increasing the memory size provided consistent benefits up to 262,000 tokens

Other: a vanilla transformer that is fine-tuned with external memory performs as well as a memorizing transformer that is pretrained from scratch. Instead of needing to pretrain a memorizing transformer, you can just fine-tune an existing vanilla one with memory.




- **What are the questions / knowledge gaps for me?**

Question: did they outperform other models on long document datasets? Is this approach better? No comparison to other long-range models

Question: all metrics are perplexity. Can we measure a new capability that memory would unlock? E.g. classification accuracy on a task that requires answering questions about a novel, very long document, on which regular transformers would score very low.

Question: how does external memory compare to expanding the context window? In terms of accuracy and in terms of compute.

Question: how does this compare to other models for "regular" length documents?


### Components to build
- Vanilla decoder model
- XL attention layer
- KNN augmented layer
- KNN memory (add/remove functionality)
- T5 relative position embedding scheme
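
To give a feel for the "KNN memory (add/remove functionality)" component, here is a toy sketch in plain Python. The class name, the method names, and the oldest-first eviction policy are all assumptions made for this illustration; a real version would store tensors and use an approximate nearest-neighbor index instead of sorting every entry.

```python
from collections import deque


class KNNMemory:
    """Toy fixed-size key-value memory with exact top-k lookup."""

    def __init__(self, max_entries):
        # deque with maxlen drops the oldest entry automatically when full
        self.entries = deque(maxlen=max_entries)

    def add(self, keys, values):
        for key, value in zip(keys, values):
            self.entries.append((key, value))

    def clear(self):
        # e.g. when a new, unrelated document starts (an assumption here)
        self.entries.clear()

    def search(self, query, k):
        """Return the k (key, value) pairs whose keys best match the query."""
        def score(entry):
            key, _ = entry
            return sum(q * x for q, x in zip(query, key))
        return sorted(self.entries, key=score, reverse=True)[:k]

    def __len__(self):
        return len(self.entries)


memory = KNNMemory(max_entries=3)
memory.add(keys=[[1.0, 0.0], [0.0, 1.0]], values=["a", "b"])
memory.add(keys=[[3.0, 0.0], [0.0, 4.0]], values=["c", "d"])  # evicts "a"
print(len(memory))
print(memory.search([1.0, 0.0], k=1))
```

Starting with something this simple makes it easy to test the rest of the model first, then swap in a faster index later.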


### Tips for building effectively
- Try to build the simplest thing possible, then add in complexity / features
- Make things work first, focus on performance last
- How will you test it? Work out the metric and goal before you start anything else.
- Make testing and benchmarking repeatable and fast so that you can quickly measure progress
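
As one way to make a benchmark repeatable, the sketch below fixes the random seed for the evaluation inputs and times the run. The names `run_benchmark` and `model_fn` and the toy error metric are placeholders for illustration; in a real project you would plug in your model and your actual metric (e.g. perplexity on a small fixed validation set).

```python
import random
import time


def run_benchmark(model_fn, num_examples=100, seed=0):
    """Evaluate model_fn on the same seeded inputs every time; also time it."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    inputs = [rng.random() for _ in range(num_examples)]

    start = time.perf_counter()
    total_error = sum(abs(model_fn(x) - x) for x in inputs)  # toy metric
    elapsed = time.perf_counter() - start

    return total_error / num_examples, elapsed


# Same seed means same inputs, so the score is directly comparable across runs.
score_a, seconds_a = run_benchmark(lambda x: 0.9 * x, seed=0)
score_b, seconds_b = run_benchmark(lambda x: 0.9 * x, seed=0)
print(score_a, score_b)
```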

# Papers to code resources
(Feel free to submit additional resources)

### Main resources
- fast.ai course
- Andrej Karpathy YouTube series (highly recommended)
- https://nn.labml.ai/ - simple, readable, implementations
- https://github.com/lucidrains - hundreds of high quality PyTorch implementations (highly recommended)

### Blogs / forums with advice
- https://jsatml.blogspot.com/2014/10/beginner-advice-on-learning-to.html
- https://blog.briankitano.com/llama-from-scratch/
- https://www.reddit.com/r/MachineLearning/comments/2h94uj/comment/ckqrn1t/
- https://machinelearningmastery.com/dont-start-with-open-source-code-when-implementing-machine-learning-algorithms/
- https://www.reddit.com/r/MachineLearning/comments/ilqa9a/d_how_do_you_approach_implementing_research_papers/
- https://www.reddit.com/r/deeplearning/comments/i86y6v/how_to_start_implementing_papers/
- https://www.reddit.com/r/MachineLearning/comments/y0dk5c/d_recent_ml_papers_to_implement_from_scratch/
- https://news.ycombinator.com/item?id=34503362