# Multi-part torch params saving/loading

This proposal suggests replacing the current approach to saving/loading model params, which requires all of the data to be present in CPU memory at once, with one that saves or loads just one param at a time. This allows much smaller memory usage when CPU memory is tight.

I'm proposing to use `dbm`, which is part of the Python standard library and needs nothing extra to install. But any other simple built-in key-value db interface would work too.
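A minimal sketch of the `dbm` calls this relies on, assuming only the standard library (the db path and key names are made up for the demo). Values must be bytes, which is why params get pickled, and keys always come back as bytes, which is why `DBMStateDict` decodes them.

```python
import dbm
import os
import pickle
import tempfile

# a throwaway base path for the demo; depending on the backend, dbm may create one or more files from it
path = os.path.join(tempfile.mkdtemp(), "demo")

# 'n' always creates a new, empty database opened for reading and writing
with dbm.open(path, "n") as db:
    # values must be bytes, so arbitrary Python objects are pickled first
    db["fc1.weight"] = pickle.dumps([0.5])
    db["fc1.bias"] = pickle.dumps([0.1])

# 'r' opens the existing database read-only
with dbm.open(path, "r") as db:
    # keys come back as bytes even though str keys were used to store them
    keys = sorted(db.keys())
    bias = pickle.loads(db["fc1.bias"])  # only this one value is read and unpickled

print(keys)  # [b'fc1.bias', b'fc1.weight']
print(bias)  # [0.1]
```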

This is a rough prototype and doesn't claim to be complete.

Alternative solutions: 
- [multi-file checkpoint splitting](https://github.com/finetuneanon/transformers/#checkpoint-splitting)
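For comparison, a rough sketch of the multi-file idea: each param goes into its own `torch.save` file and a small index records the names. The helper names `save_split`/`load_split` and the file layout (`p{i}.pt`, `index.json`) are hypothetical and do not reproduce the linked implementation.

```python
import json
import os
import tempfile

import torch
from torch import nn

def save_split(state_dict, directory):
    # hypothetical layout: one file per param, plus a JSON index of the param names
    os.makedirs(directory, exist_ok=True)
    names = list(state_dict)
    for i, name in enumerate(names):
        torch.save(state_dict[name], os.path.join(directory, f"p{i}.pt"))
    with open(os.path.join(directory, "index.json"), "w") as f:
        json.dump(names, f)

def load_split(directory):
    # yields (name, tensor) pairs, reading a single param file at a time
    with open(os.path.join(directory, "index.json")) as f:
        names = json.load(f)
    for i, name in enumerate(names):
        yield name, torch.load(os.path.join(directory, f"p{i}.pt"))

directory = os.path.join(tempfile.mkdtemp(), "model1_split")
sd = nn.Linear(3, 1).state_dict()
save_split(sd, directory)
loaded = dict(load_split(directory))  # collects everything, just for the demo
print(sorted(loaded))  # ['bias', 'weight']
```

The trade-off versus the `dbm` approach is many small files on disk instead of a single db file.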

Credits: the main class was inspired by [SplitCheckpoint](https://github.com/finetuneanon/transformers/blob/ca5d90ac1965982db122a649c2c9c902bde74a03/src/transformers/modeling_utils.py#L417-L443).

Here is the corresponding [PyTorch RFC](https://github.com/pytorch/pytorch/issues/64327).

```python
import dbm
import pickle
import torch
from torch import nn
from collections.abc import MutableMapping
```

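```python
class DBMStateDict(MutableMapping):
    """A dict-like state_dict backed by a dbm file: each param is pickled
    and stored under its own key, so it can be read or written on its own."""

    def __init__(self, path, flag='c'):
        self.path = path
        # flag: 'r' read-only, 'w' read/write, 'c' create if missing, 'n' always new and empty
        self.db = dbm.open(path, flag)

    def __len__(self):
        return len(self.db.keys())

    def __getitem__(self, key):
        # only this one param is unpickled into memory
        return pickle.loads(self.db[key])

    def __setitem__(self, key, value):
        self.db[key] = pickle.dumps(value)
        # whether this hits disk right away depends on the dbm backend; close() flushes pending writes

    def __delitem__(self, key):
        # not every dbm backend implements pop(), but all of them support del
        del self.db[key]

    def __iter__(self):
        # dbm keys are bytes: decode them so iteration yields the same str keys as a state_dict
        return (k.decode() for k in self.db.keys())

    def keys(self):
        return list(self)

    def copy(self):
        # load_state_dict() may modify the copy it makes, so return an in-memory dict
        # rather than a second handle to the same on-disk db
        return dict(self.items())

    def close(self):
        # safe to call more than once, and from __del__ even if __init__ failed
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()
            self.db = None

    def __del__(self):
        self.close()
```

```python
def save_new(sd_dict, path):
    # 'n' always creates a new, empty db, so keys left over from an earlier save can't linger
    # note: the state_dict's _metadata attribute (per-module version info) is not stored
    sd = DBMStateDict(path, flag='n')
    try:
        for k, v in sd_dict.items():
            sd[k] = v  # pickled and written one param at a time
    finally:
        sd.close()

def load_new(path):
    # this doesn't load the whole sd into memory: opening the db reads no params,
    # each one is unpickled only when its key is accessed
    return DBMStateDict(path, flag='r')
```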

```python
class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = SubNet()
```

```python
m = Net()
m
```

```
Net(
  (net): SubNet(
    (fc1): Linear(in_features=1, out_features=1, bias=True)
    (fc2): Linear(in_features=1, out_features=1, bias=True)
  )
)
```

```python
# original
m = Net()
path = "model1.pt"
sd_dict = m.state_dict()
torch.save(sd_dict, path)
sd_dict = torch.load(path)
m.load_state_dict(sd_dict)
```

```
<All keys matched successfully>
```

```python
# same, but loading / saving one key at a time
m = Net()
path = "model1.dbm"
sd_dict = m.state_dict()
save_new(sd_dict, path)
sd_new = load_new(path)
m.load_state_dict(sd_new)
```

```
<All keys matched successfully>
```
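Depending on the PyTorch version, `load_state_dict()` may first copy the mapping it receives into a regular in-memory dict, in which case every param still gets unpickled during that call. A minimal sketch of a fully streaming load, assuming a plain module whose `state_dict()` tensors share storage with its params. The helper names `save_params` and `load_params_inplace` are hypothetical, and the sketch talks to `dbm` directly so it runs on its own.

```python
import dbm
import os
import pickle
import tempfile

import torch
from torch import nn

def save_params(module, path):
    # one pickled tensor per state_dict key, written one at a time
    with dbm.open(path, "n") as db:
        for name, tensor in module.state_dict().items():
            db[name] = pickle.dumps(tensor)

def load_params_inplace(module, path):
    # state_dict() returns detached tensors that share storage with the module's params,
    # so copy_() into them updates the params; only one loaded param is alive at a time
    targets = module.state_dict()
    with dbm.open(path, "r") as db:
        stored = {k.decode() for k in db.keys()}
        if stored != set(targets):
            raise KeyError(f"key mismatch: {sorted(stored ^ set(targets))}")
        for name, target in targets.items():
            target.copy_(pickle.loads(db[name]))

path = os.path.join(tempfile.mkdtemp(), "model1")
src, dst = nn.Linear(4, 2), nn.Linear(4, 2)  # two independently initialized modules
save_params(src, path)
load_params_inplace(dst, path)
print(torch.equal(src.weight, dst.weight), torch.equal(src.bias, dst.bias))  # True True
```

Unlike `load_state_dict()`, this bypasses per-module `_load_from_state_dict` logic and load hooks, so it only suits plain modules like the ones above.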