# __Saving and loading a general checkpoint in PyTorch__

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html


## ___Intro___

To save multiple components in a single checkpoint (e.g. model state, optimizer state, epoch, and loss), you must:
- organize them in a dictionary
- use `torch.save()` to serialize the dictionary

A common PyTorch convention is to:
- save these checkpoints using the `.tar` file extension.
- load the items by:
  - initializing the model and optimizer
  - loading the dictionary locally using `torch.load()`
  - accessing the saved items by querying the dictionary as you would expect.

## ___Steps___

1. Import all necessary libraries for loading our data
1. Define and initialize the neural network
1. Initialize the optimizer
1. Save the general checkpoint
1. Load the general checkpoint

### Import necessary libraries for loading our data

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

### Define and initialize the neural network

In [2]:
class Net(nn.Module):
  """Small convolutional network used as the example model for checkpointing.

  Two convolution + max-pool stages are followed by three fully connected
  layers that produce 10 output scores per sample. The flatten step in
  `forward` assumes 16 channels of 5x5 features, which is what the layers
  below produce for 3-channel 32x32 inputs.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Run a batch through the network and return one row of 10 scores per sample."""
    # Reach the functional API through the already-imported `nn` namespace:
    # this notebook never imports `torch.nn.functional as F`, so calling
    # `F.relu` here would raise NameError on the first forward pass.
    relu = nn.functional.relu

    x = self.pool(relu(self.conv1(x)))
    x = self.pool(relu(self.conv2(x)))
    # Flatten the feature maps to (batch, 400) for the linear layers.
    x = x.view(-1, 16 * 5 * 5)
    x = relu(self.fc1(x))
    x = relu(self.fc2(x))
    x = self.fc3(x)
    return x

net = Net()
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


### Initialize the optimizer

In [3]:
# SGD with momentum over all parameters of the network defined above.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)

### Save the general checkpoint

In [4]:
# Example bookkeeping values stored next to the weights.
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# Collect everything needed to resume into one dictionary, then serialize
# that dictionary to PATH in a single call.
checkpoint_state = {
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}
torch.save(checkpoint_state, PATH)

### Load the general checkpoint

In [5]:
# Fresh model and optimizer instances must exist before any saved state
# can be restored into them.
model = Net()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

# Read the serialized dictionary back from disk.
checkpoint = torch.load(PATH)

# Restore the weights and the optimizer state from the dictionary ...
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# ... and recover the bookkeeping values saved with them.
epoch, loss = checkpoint['epoch'], checkpoint['loss']

In [6]:
# Switch the restored model to training mode so training can resume.
# The call's return value is what the notebook displays below.
model.train()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

## ___Play with my own model___

### Import necessary libraries for loading our data

In [18]:
import yaml
from pathlib import Path
from transformers import BertTokenizerFast, BertConfig, BertForMaskedLM, \
                         DataCollatorForLanguageModeling, TrainingArguments, \
                         Trainer
from tokenizers import BertWordPieceTokenizer
from torchinfo import summary

### Three steps skipped

- Define and initialize the neural network
- Initialize the optimizer
- Save the general checkpoint

### Load the general checkpoint

- [how to continue training from a checkpoint with Trainer?](https://github.com/huggingface/transformers/issues/7198)
  - Ok, none of the following is necessary.
  - Just do: `trainer.train(resume_from_checkpoint="path/to/checkpoint-dir")`

In [19]:
# Project layout: the checkpoint directory lives under <work_dir>/models.
work_dir = Path("/home/shius/projects/plantbert/")
model_dir = work_dir.joinpath("models")
ckpt_dir = model_dir.joinpath("checkpoint-11500")

config_file = "./config.yaml"

# Parse the YAML configuration into plain Python objects.
with open(config_file, 'r') as config_stream:
  config = yaml.safe_load(config_stream)

In [24]:
# NOTE: `ckpt_dir` points at a directory, not a file, so this call fails
# with IsADirectoryError (see this cell's recorded output). `torch.load`
# has to be given the path of a file inside the checkpoint directory.
checkpoint = torch.load(ckpt_dir)

IsADirectoryError: [Errno 21] Is a directory: '/home/shius/projects/plantbert/models/checkpoint-11500'

In [20]:
# initialize the model with the config
vocab_size   = config['tokenize']['vocab_size']
max_length   = config['tokenize']['max_length']
model_config = BertConfig(vocab_size=vocab_size, 
                          max_position_embeddings=max_length)

model = BertForMaskedLM(config=model_config)

In [27]:
# Load the optimizer file stored inside the checkpoint directory. As the
# keys displayed below show, the loaded object is a dictionary rather than
# an optimizer instance.
optimizer_file = ckpt_dir / "optimizer.pt"
optimizer = torch.load(optimizer_file)
optimizer.keys()

dict_keys(['state', 'param_groups'])

: 

In [None]:
# Fresh model and optimizer instances must exist before any saved state
# can be restored into them.
model = Net()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

# Read the serialized dictionary back from disk.
checkpoint = torch.load(PATH)

# Restore the weights and the optimizer state from the dictionary ...
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# ... and recover the bookkeeping values saved with them.
epoch, loss = checkpoint['epoch'], checkpoint['loss']

In [None]:
# Switch the restored model to training mode so training can resume.
# The call's return value is what the notebook displays below.
model.train()

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)