---
Paper: https://arxiv.org/pdf/2005.14165

---

---
# Section 3: Hyperparameters, AdamW, Gradient Clipping
---

- Up until now, all the changes we made to our network were about utilizing the GPU better. Now we turn to algorithmic changes that improve the optimization itself.
- To do this, we follow the hyperparameters mentioned in the GPT-2 and GPT-3 papers. There isn't much to go on in the GPT-2 paper or in the code released with GPT-2.
- So we look at the appendix of the GPT-3 paper:
    - Appendix B: Details of Model Training.

**Change1: AdamW hyperparameters**
- We change our code to use the Adam hyperparameters given in the paper (β1 = 0.9, β2 = 0.95, ε = 1e-8):

      optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

  is changed to

      optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
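To get a feel for what `betas` and `eps` control, here is a plain-Python sketch of a single AdamW update on one scalar parameter. It is written from the standard AdamW update rule (decoupled weight decay, bias-corrected first and second moments). The helper `adamw_step` is hypothetical and is not PyTorch's implementation, which applies the same math to whole tensors. `beta1` sets how much gradient history the momentum keeps, `beta2` does the same for the squared-gradient estimate (0.95 here versus PyTorch's default 0.999, so it reacts faster), and `eps` guards the division.

```python
import math


def adamw_step(p, g, m, v, t, lr=6e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0):
    """One AdamW update for a single scalar parameter (illustrative sketch).

    p: parameter value, g: its gradient, m/v: running first/second moments,
    t: 1-based step count. Returns the updated (p, m, v).
    """
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter directly, independent of g.
    p = p * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and of the squared gradient.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    # Bias correction, because m and v start at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # eps keeps the denominator away from zero when v_hat is tiny.
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v
```

On the first step the bias-corrected update has a magnitude close to `lr` whatever the gradient's scale: `adamw_step(1.0, 0.5, 0.0, 0.0, t=1, lr=0.1)` moves the parameter from 1.0 to about 0.9, and a gradient of -0.001 moves it to about 1.1.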




**Change2: Gradient Norm Clipping**
- As in the paper, we clip the global norm of the gradient at 1.0.
- This happens once we have computed the gradients with loss.backward() and before optimizer.step().
- At that point every parameter tensor has a gradient, and a common practice is to clip them so that their combined norm does not exceed some maximum:

      norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

- `norm` is the total (global) norm of the gradients, measured before clipping.
- The function computes one global norm over all the gradients: square every gradient element of every parameter, add them all up, and take the square root. That is the length of the single vector formed by concatenating all the gradients. If this length exceeds 1.0, every gradient is scaled down by the same factor so that the length becomes 1.0; the direction is unchanged.
- People use it because you can get unlucky during optimization, for example with a bad data batch. A very unlucky batch can produce a very high loss, which can lead to very large gradients, and that can shock the model and the optimization.
- Gradient norm clipping upper-bounds the gradient magnitude, which prevents the model from getting too big of a shock.
- It's a somewhat hacky solution, but people still use it quite frequently.
- The norm is also a useful diagnostic: if it is well behaved, things are good; if it keeps climbing, training is destabilizing; and a sudden spike in the norm usually signals some kind of issue or instability.
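To make the arithmetic concrete, here is a plain-Python sketch of global-norm clipping. It represents each parameter's gradient as a flat list of floats (an assumption for illustration; PyTorch works on tensors), and it mirrors, as far as we know, what `torch.nn.utils.clip_grad_norm_` computes with the default L2 norm: one norm over all gradients, then every gradient scaled by `max_norm / (total_norm + 1e-6)` when that factor is below 1. The function name `clip_grad_norm` is hypothetical.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    grads: list of gradients, one flat list of floats per parameter tensor.
    Returns the global norm measured *before* clipping (what gets printed as `norm`).
    """
    # One norm over every element of every gradient, as if concatenated into one vector.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # The same factor is applied to every gradient, so the overall direction is preserved.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [4.0]]          # global norm = sqrt(9 + 16) = 5
norm = clip_grad_norm(grads, 1.0)    # returns 5.0; grads become ~[[0.6, 0.0], [0.8]]
```

With the step 0 norm of about 29.4 in the training log below, every gradient would be multiplied by roughly 0.034; at step 49 (norm about 1.28) the factor is about 0.78, and once the norm stays below 1.0 clipping changes nothing.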





Training output with the gradient norm printed (first and last step):
# step 0 | loss: 10.933823585510254 | norm: 29.4195 | dt: 13843.83ms | tok/sec: 591.7439477574981
# step 49 | loss: 5.918039321899414 | norm: 1.2791 | dt: 145.53ms | tok/sec: 56290.43590688744

--- 

**Change3: Learning rate Scheduler**

**Cosine Decay with warmup**


- GPT-3 does not use a fixed learning rate like the one we have used so far.
- It uses a cosine decay for the learning rate.
- This is what the schedule looks like (see the graph at https://scorrea92.medium.com/cosine-learning-rate-decay-e8b50aa455b).
- The learning rate starts near zero, ramps up linearly for some number of steps, and then comes down along a cosine curve to a minimum learning rate of your choosing.
- From the paper: "we use cosine decay for learning rate down to 10% of its value, over 260 billion tokens (after 260 billion tokens, training continues at 10% of the original learning rate). There is a linear LR warmup over the first 375 million tokens"
- Check the commit for the code changes.
- PyTorch has a cosine learning-rate scheduler as well, but Andrej prefers to write his own so that he fully understands it; it's only a couple of lines of code.
- So we linearly warm up to the max learning rate and then start to decay it.
- One thing we do differently: their training horizon is 300 billion tokens, they decay to 10% of the initial learning rate by 260 billion tokens, and they keep training at 10% after that. Their decay horizon is shorter than the full training run, whereas ours is exactly equal to max_steps.

- The exact learning-rate schedule is up to you. Cosine decay has been popularized by GPT-2 and GPT-3, and learning-rate scheduling is an active area of research.
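A sketch of the warmup-plus-cosine schedule in plain Python. The settings (`max_lr = 6e-4`, `min_lr` at 10% of that, `warmup_steps = 10`, `max_steps = 50`) are assumptions chosen because they reproduce the learning rates printed in the log below; `get_lr` is a hypothetical name for the helper.

```python
import math

# Assumed settings, chosen to match the logged learning rates.
max_lr = 6e-4
min_lr = max_lr * 0.1     # decay down to 10% of the max, as in GPT-3
warmup_steps = 10
max_steps = 50


def get_lr(it):
    # 1) Linear warmup; (it + 1) so that step 0 gets a non-zero learning rate.
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    # 2) Past the decay horizon, stay at the minimum learning rate.
    if it > max_steps:
        return min_lr
    # 3) In between, cosine decay from max_lr down to min_lr.
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```

In the training loop, the value is applied by setting `param_group["lr"] = get_lr(step)` for every group in `optimizer.param_groups` before `optimizer.step()`. To follow the paper more closely, the decay would end before the last training step (260 of 300 billion tokens); the `it > max_steps` branch then keeps training at `min_lr`.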



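Training log with the cosine schedule (the learning rate warms up to 6e-4 by step 9, then decays toward 6e-5). Note: the `dt` values in this log appear to be 1000× too large, probably from converting to milliseconds twice; at about 8,192 tokens per step (338,025 tokens / 41 batches), ~57k tok/sec corresponds to roughly 143 ms per step, in line with the earlier log.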
# using device: cuda
# loaded 338025 tokens
# 1 epoch = 41 batches
# step    0 | loss: 10.933824 | lr 6.0000e-05 | norm: 29.4194 | dt: 13769564.39ms | tok/sec: 594.94
# step    1 | loss: 9.647558 | lr 1.2000e-04 | norm: 9.8480 | dt: 143231.39ms | tok/sec: 57194.17
# step    2 | loss: 8.992529 | lr 1.8000e-04 | norm: 5.8678 | dt: 142774.11ms | tok/sec: 57377.35
# step    3 | loss: 9.544296 | lr 2.4000e-04 | norm: 7.4843 | dt: 142538.31ms | tok/sec: 57472.27
# step    4 | loss: 8.959856 | lr 3.0000e-04 | norm: 4.2276 | dt: 143009.90ms | tok/sec: 57282.75
# step    5 | loss: 8.671993 | lr 3.6000e-04 | norm: 3.0261 | dt: 143142.22ms | tok/sec: 57229.79
# step    6 | loss: 8.600800 | lr 4.2000e-04 | norm: 3.4741 | dt: 143462.66ms | tok/sec: 57101.97
# step    7 | loss: 8.174708 | lr 4.8000e-04 | norm: 2.5680 | dt: 143914.94ms | tok/sec: 56922.51
# step    8 | loss: 7.772866 | lr 5.4000e-04 | norm: 2.7331 | dt: 149877.79ms | tok/sec: 54657.87
# step    9 | loss: 7.570984 | lr 6.0000e-04 | norm: 2.2930 | dt: 143419.27ms | tok/sec: 57119.24
# step   10 | loss: 7.380367 | lr 6.0000e-04 | norm: 2.0155 | dt: 143720.15ms | tok/sec: 56999.66
# step   11 | loss: 7.104572 | lr 5.9917e-04 | norm: 1.5674 | dt: 143507.00ms | tok/sec: 57084.32
# step   12 | loss: 6.973597 | lr 5.9668e-04 | norm: 1.1982 | dt: 143562.56ms | tok/sec: 57062.23
# step   13 | loss: 6.736616 | lr 5.9254e-04 | norm: 1.2940 | dt: 143640.52ms | tok/sec: 57031.26
# step   14 | loss: 6.641829 | lr 5.8679e-04 | norm: 0.8845 | dt: 143579.24ms | tok/sec: 57055.60
# step   15 | loss: 6.466947 | lr 5.7945e-04 | norm: 2.0647 | dt: 143687.96ms | tok/sec: 57012.43
# step   16 | loss: 6.580324 | lr 5.7057e-04 | norm: 1.2491 | dt: 146642.45ms | tok/sec: 55863.77
# step   17 | loss: 6.646888 | lr 5.6021e-04 | norm: 1.4116 | dt: 146014.45ms | tok/sec: 56104.04
# step   18 | loss: 6.574047 | lr 5.4843e-04 | norm: 1.4114 | dt: 144019.37ms | tok/sec: 56881.24
# step   19 | loss: 6.360060 | lr 5.3531e-04 | norm: 1.3833 | dt: 145360.95ms | tok/sec: 56356.26
# step   20 | loss: 6.478446 | lr 5.2092e-04 | norm: 1.7494 | dt: 147783.52ms | tok/sec: 55432.43
# step   21 | loss: 6.282598 | lr 5.0535e-04 | norm: 1.5478 | dt: 144555.81ms | tok/sec: 56670.16
# step   22 | loss: 6.409733 | lr 4.8870e-04 | norm: 1.1385 | dt: 148765.33ms | tok/sec: 55066.60
# step   23 | loss: 6.232314 | lr 4.7107e-04 | norm: 1.1030 | dt: 146510.36ms | tok/sec: 55914.13
# step   24 | loss: 6.251258 | lr 4.5258e-04 | norm: 1.2467 | dt: 144739.15ms | tok/sec: 56598.37
# step   25 | loss: 6.281051 | lr 4.3332e-04 | norm: 0.9518 | dt: 143994.33ms | tok/sec: 56891.13
# step   26 | loss: 6.606372 | lr 4.1343e-04 | norm: 1.1651 | dt: 145787.72ms | tok/sec: 56191.29
# step   27 | loss: 6.461005 | lr 3.9303e-04 | norm: 1.2547 | dt: 148341.42ms | tok/sec: 55223.96
# step   28 | loss: 6.730590 | lr 3.7224e-04 | norm: 1.1376 | dt: 144857.65ms | tok/sec: 56552.07
# step   29 | loss: 6.465397 | lr 3.5118e-04 | norm: 1.0399 | dt: 144737.48ms | tok/sec: 56599.02
# step   30 | loss: 6.425256 | lr 3.3000e-04 | norm: 0.9473 | dt: 145284.18ms | tok/sec: 56386.04
# step   31 | loss: 6.396163 | lr 3.0882e-04 | norm: 1.0962 | dt: 144153.36ms | tok/sec: 56828.37
# step   32 | loss: 6.242393 | lr 2.8776e-04 | norm: 1.1938 | dt: 144170.52ms | tok/sec: 56821.60
# step   33 | loss: 6.582739 | lr 2.6697e-04 | norm: 1.9381 | dt: 144218.21ms | tok/sec: 56802.81
# step   34 | loss: 6.419995 | lr 2.4657e-04 | norm: 1.4231 | dt: 145292.04ms | tok/sec: 56382.99
# step   35 | loss: 6.328524 | lr 2.2668e-04 | norm: 0.9130 | dt: 144602.54ms | tok/sec: 56651.84
# step   36 | loss: 6.373841 | lr 2.0742e-04 | norm: 1.0457 | dt: 146884.20ms | tok/sec: 55771.82
# step   37 | loss: 6.393122 | lr 1.8893e-04 | norm: 0.9558 | dt: 146669.39ms | tok/sec: 55853.51
# step   38 | loss: 6.166686 | lr 1.7130e-04 | norm: 0.9296 | dt: 144530.30ms | tok/sec: 56680.16
# step   39 | loss: 6.256642 | lr 1.5465e-04 | norm: 1.0896 | dt: 145495.41ms | tok/sec: 56304.18
# step   40 | loss: 6.440091 | lr 1.3908e-04 | norm: 1.1819 | dt: 144564.63ms | tok/sec: 56666.70
# step   41 | loss: 6.261865 | lr 1.2469e-04 | norm: 0.9343 | dt: 144588.95ms | tok/sec: 56657.17
# step   42 | loss: 6.312222 | lr 1.1157e-04 | norm: 1.1466 | dt: 146482.94ms | tok/sec: 55924.60
# step   43 | loss: 6.033404 | lr 9.9787e-05 | norm: 1.2428 | dt: 145900.96ms | tok/sec: 56147.68
# step   44 | loss: 5.987703 | lr 8.9428e-05 | norm: 0.9310 | dt: 144826.17ms | tok/sec: 56564.36
# step   45 | loss: 6.059750 | lr 8.0553e-05 | norm: 0.8940 | dt: 144172.91ms | tok/sec: 56820.66
# step   46 | loss: 6.136729 | lr 7.3215e-05 | norm: 0.7327 | dt: 145359.75ms | tok/sec: 56356.73
# step   47 | loss: 6.001136 | lr 6.7460e-05 | norm: 1.1086 | dt: 144611.84ms | tok/sec: 56648.20
# step   48 | loss: 5.896204 | lr 6.3324e-05 | norm: 0.8215 | dt: 146751.64ms | tok/sec: 55822.20
# step   49 | loss: 5.787439 | lr 6.0832e-05 | norm: 0.8931 | dt: 145449.88ms | tok/sec: 56321.81

--- 
## Next: Gradual Batch Size Increase

--- 

- From the paper: "We also gradually increase the batch size linearly from a small value (32k tokens) to the full value over the first 4-12 billion tokens of training, depending on the model size"
- We are going to skip this.
- Reason: it complicates a lot of the arithmetic, because the number of tokens processed changes at every single step of the optimization, and Andrej likes to keep that math very simple.
- Also, this is not really an algorithmic improvement; it's more of a systems and speed improvement.

--- 

From the paper: "Data are sampled without replacement during training (until an epoch boundary is reached) to minimize overfitting."

- We are already doing that: in our dataloader, once a chunk of data has been drawn, it is not eligible to be drawn again until the next epoch.
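
Here is a minimal sketch of that sequential sampling in plain Python. The class `SequentialTokenLoader` is a hypothetical stand-in for our dataloader: it walks the token list in order, so every chunk is used exactly once per epoch, and it wraps back to the start when the next batch (which needs B*T + 1 tokens, the extra one for the shifted targets) would run past the end. Assuming B*T = 8,192 tokens per batch, 338,025 tokens give (338,025 - 1) // 8,192 = 41 batches per epoch, consistent with the logged "1 epoch = 41 batches".

```python
class SequentialTokenLoader:
    """Sequential (without-replacement) batch loader over a flat token list."""

    def __init__(self, tokens, B, T):
        self.tokens = tokens
        self.B, self.T = B, T
        self.pos = 0
        # Each batch consumes B*T tokens but reads one extra token for the targets.
        self.batches_per_epoch = (len(tokens) - 1) // (B * T)

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = [buf[i * T : (i + 1) * T] for i in range(B)]          # inputs
        y = [buf[i * T + 1 : (i + 1) * T + 1] for i in range(B)]  # next-token targets
        self.pos += B * T
        # Epoch boundary: if the next batch would not fit, start over from the beginning.
        if self.pos + B * T + 1 > len(self.tokens):
            self.pos = 0
        return x, y
```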
--- 



--- 
### Next: From Paper: "All models use weight decay of 0.1 to provide a small amount of regularization"


Code changes:
- Added a new method, `configure_optimizers`, to the `GPT` class.
- The optimizer creation changes from

      optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)

  to

      optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device=device)




- The new method splits the parameters into two groups: those that get weight decay and those that don't.
- It's common not to weight-decay biases or any other 1-D tensors, such as LayerNorm scales and biases.
- You do want to weight-decay the weights that participate in matrix multiplications, as well as the embeddings.
- Why does it make sense to decay the weights?
- You can think of it as regularization. Pulling all the weights down forces the optimization to use more of the weights and doesn't allow any individual weight to become way too large.
- It forces the network to distribute the work across more channels, because weight decay acts like a pull of gravity on the weights themselves.
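A sketch of the grouping rule, written in plain Python so the parameter counts can be checked without PyTorch. Instead of a real model it takes a dict of parameter names to shapes. The names and the helpers (`split_params_for_weight_decay`, `gpt2_small_shapes`, `numel`) are hypothetical, and the shapes assume GPT-2 small with a padded vocabulary of 50304, `nn.Linear` weights stored as (out, in), and `lm_head` tied to `wte` (so it is counted once). Under those assumptions the split reproduces the counts printed in the final log: 50 decayed tensors with 124,354,560 parameters and 98 non-decayed tensors with 121,344 parameters.

```python
def split_params_for_weight_decay(named_shapes):
    """Tensors with 2+ dims (matmul weights, embeddings) are decayed;
    1-D tensors (biases, LayerNorm scales and biases) are not."""
    decay, no_decay = [], []
    for name, shape in named_shapes.items():
        (decay if len(shape) >= 2 else no_decay).append(name)
    return decay, no_decay


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


def gpt2_small_shapes(vocab_size=50304, block_size=1024, n_layer=12, n_embd=768):
    """Assumed parameter shapes of GPT-2 small (lm_head tied to wte, not listed)."""
    shapes = {"transformer.wte.weight": (vocab_size, n_embd),
              "transformer.wpe.weight": (block_size, n_embd)}
    for i in range(n_layer):
        h = f"transformer.h.{i}."
        shapes.update({
            h + "ln_1.weight": (n_embd,), h + "ln_1.bias": (n_embd,),
            h + "attn.c_attn.weight": (3 * n_embd, n_embd), h + "attn.c_attn.bias": (3 * n_embd,),
            h + "attn.c_proj.weight": (n_embd, n_embd), h + "attn.c_proj.bias": (n_embd,),
            h + "ln_2.weight": (n_embd,), h + "ln_2.bias": (n_embd,),
            h + "mlp.c_fc.weight": (4 * n_embd, n_embd), h + "mlp.c_fc.bias": (4 * n_embd,),
            h + "mlp.c_proj.weight": (n_embd, 4 * n_embd), h + "mlp.c_proj.bias": (n_embd,),
        })
    shapes["transformer.ln_f.weight"] = (n_embd,)
    shapes["transformer.ln_f.bias"] = (n_embd,)
    return shapes


shapes = gpt2_small_shapes()
decay, no_decay = split_params_for_weight_decay(shapes)
```

In PyTorch the same rule is usually written as `p.dim() >= 2` over `model.named_parameters()` (keeping only parameters with `requires_grad`), and the two lists become param groups, `[{"params": decay_params, "weight_decay": 0.1}, {"params": nodecay_params, "weight_decay": 0.0}]`, passed to `torch.optim.AdamW`.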

---
## Next: Fused AdamW 

---

The check we use to detect whether a fused implementation is available:

    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters

- Older versions of PyTorch's AdamW did not have a fused implementation, so we guard the option with this check.
- What does "fused" mean here?
- Without it, the optimizer iterates over all the parameter tensors in a for loop and updates each one, which launches a lot of kernels.
- Fused means all of those kernels are fused into a single kernel: one kernel launch updates all the parameters, which gets rid of a lot of overhead.
- Basically, it's kernel fusion for the AdamW optimizer.
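The guard itself is ordinary feature detection with `inspect.signature`. Here is a small sketch that wraps it in hypothetical helpers (`accepts_kwarg`, `adamw_kwargs`) and, to stay runnable without PyTorch, exercises it on two stand-in classes that only imitate an older and a newer AdamW constructor signature. The sketch only requests `fused=True` on CUDA; that is an assumption here, so check your PyTorch version's documentation for which devices support it.

```python
import inspect


def accepts_kwarg(fn, name):
    """True if callable `fn` (function or class) declares a parameter named `name`."""
    return name in inspect.signature(fn).parameters


def adamw_kwargs(opt_cls, device_type, lr=6e-4):
    """Keyword arguments for an AdamW-style optimizer class.

    Adds fused=True only when the class accepts it and we run on CUDA,
    so an older class never receives an argument it doesn't understand.
    """
    kwargs = {"lr": lr, "betas": (0.9, 0.95), "eps": 1e-8}
    if accepts_kwarg(opt_cls, "fused") and device_type == "cuda":
        kwargs["fused"] = True
    return kwargs


# Stand-ins that only imitate an older and a newer AdamW constructor signature.
class OldAdamW:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        pass


class NewAdamW:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, fused=None):
        pass
```

With PyTorch installed, the equivalent call would be `torch.optim.AdamW(param_groups, **adamw_kwargs(torch.optim.AdamW, "cuda"))`, where `param_groups` is the decay/no-decay list built in `configure_optimizers`.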

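Final run with weight decay and fused AdamW (only the last step is shown; as in the previous log, `dt` appears inflated by 1000× relative to what tok/sec implies):
# using device: cuda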
# loaded 338025 tokens
# 1 epoch = 41 batches
# num decayed parameter tensors: 50, with 124,354,560 parameters
# num non-decayed parameter tensors: 98, with 121,344 parameters
# using fused AdamW: True
# step   49 | loss: 5.786849 | lr 6.0832e-05 | norm: 0.9179 | dt: 142815.59ms | tok/sec: 57360.68