## torch.no_grad() vs. param.requires_grad == False
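- `torch.no_grad()` is a context manager that disables gradient tracking for every operation run inside it.
  - No computation graph is recorded, so nothing computed inside it can be backpropagated.
  - It is typically used for inference/evaluation, not training.
  - It does not change any parameter's `requires_grad` flag.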

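- Setting `requires_grad = False` on parameters freezes that part of the model while the rest keeps training.
    - Gradients are no longer computed for the frozen parameters, so they are not updated.
    - It is usually applied per layer or per module.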

In [2]:
from transformers import BertModel
# Name of the pretrained checkpoint to load.
model_name = 'bert-base-uncased'
# Build the BertModel from that pretrained checkpoint.
bert = BertModel.from_pretrained(model_name)

In [3]:
def count_learnable_params(model):
    """Return the total element count of the model's trainable parameters.

    Walks ``model.named_parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy; parameters with the
    flag off are skipped. A model with no trainable parameters yields 0.
    """
    return sum(
        tensor.numel()
        for _, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    

In [4]:
# Baseline: number of trainable parameter elements before anything is frozen.
count_learnable_params(bert)

109482240

In [6]:
import torch

# Re-run the count inside no_grad(): the context manager leaves each
# parameter's requires_grad flag untouched, so the printed total is unchanged.
with torch.no_grad():
    print(count_learnable_params(bert))

109482240


In [9]:
# Freeze the whole model: Module.requires_grad_(False) turns off the
# requires_grad flag on every parameter in place, replacing the manual loop
# (which carried an unused name variable and a redundant flag check).
bert.requires_grad_(False)

# With every parameter frozen, nothing is counted as learnable any more.
count_learnable_params(bert)

0