# Proof of concept for zero-copy loading from `.h5` files

This notebook demonstrates loading the weights of PyTorch models directly from a memory-mapped file, without copying the weights into local process memory. Loading models this way should make loading much faster and allow multiple processes to share the memory that holds the weights.

In [1]:
# Uncomment if you need to install h5py
#!pip install h5py

In [2]:
# imports go here
import h5py
import io
import mmap
import os
import psutil
import numpy as np
import pickle
import torch
import transformers
import zipfile

from typing import Dict, List, Tuple, Union

## Serialize the model

We start by loading a copy of `bert-base-uncased`, then writing it to disk with PyTorch's built-in serialization. This operation creates a file `outputs/bert.pt`, which is actually a zip archive containing a pickled graph of Python objects plus one file per tensor storage holding that storage's raw data.

In [3]:
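# torch.save() won't create the output directory, so make sure it exists
os.makedirs("outputs", exist_ok=True)

bert = transformers.BertModel.from_pretrained("bert-base-uncased")
torch.save(bert, "outputs/bert.pt")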

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
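To see this structure for yourself, you can list the archive's members with Python's standard `zipfile` module. The helper below is a minimal sketch; `describe_archive` is a hypothetical name, and it reports only each member's name, its uncompressed size, and whether the member is stored without compression.

```python
import zipfile
from typing import List, Tuple

def describe_archive(path) -> List[Tuple[str, int, bool]]:
    """List (member name, uncompressed size in bytes, stored uncompressed?) per member.

    `path` may be a filename or a binary file-like object, as with zipfile.ZipFile.
    """
    with zipfile.ZipFile(path, 'r') as zf:
        return [
            (info.filename, info.file_size,
             info.compress_type == zipfile.ZIP_STORED)
            for info in zf.infolist()
        ]

# Usage, once outputs/bert.pt has been written:
# for name, size, stored in describe_archive('outputs/bert.pt')[:5]:
#     print(f'{name}: {size} bytes, {"stored" if stored else "compressed"}')
```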


## Convert to HDF5 format

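PyTorch's zipfile-based format is not convenient for zero-copy loading from Python. Python's `zipfile` module can read each member's contents, but it doesn't directly report the byte offset where a member's data begins inside the archive. So we convert the zip file to [HDF5 format](https://www.hdfgroup.org/solutions/hdf5/). With `h5py`'s default settings, HDF5 stores each dataset's binary data directly inside the file, uncompressed, and `h5py` can report where that data starts.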

In [4]:
h5_file_name = 'outputs/bert.h5'

if os.path.exists(h5_file_name):
    os.unlink(h5_file_name)

with zipfile.ZipFile('outputs/bert.pt', 'r') as zip_file:
    with h5py.File(h5_file_name, 'w') as h5_file:
        for info in zip_file.infolist():
            with zip_file.open(info.filename, 'r') as f:
                file_data = f.read()
                print(f'Copying {len(file_data)} bytes for "{info.filename}"')
                dataset = h5_file.create_dataset(
                    info.filename, data=np.frombuffer(file_data, dtype=np.byte))

Copying 60813 bytes for "archive/data.pkl"
Copying 4096 bytes for "archive/data/0"
Copying 4096 bytes for "archive/data/1"
Copying 3072 bytes for "archive/data/10"
Copying 3072 bytes for "archive/data/100"
Copying 3072 bytes for "archive/data/101"
Copying 3072 bytes for "archive/data/102"
Copying 2359296 bytes for "archive/data/103"
Copying 3072 bytes for "archive/data/104"
Copying 2359296 bytes for "archive/data/105"
Copying 3072 bytes for "archive/data/106"
Copying 2359296 bytes for "archive/data/107"
Copying 3072 bytes for "archive/data/108"
Copying 2359296 bytes for "archive/data/109"
Copying 2359296 bytes for "archive/data/11"
Copying 3072 bytes for "archive/data/110"
Copying 3072 bytes for "archive/data/111"
Copying 3072 bytes for "archive/data/112"
Copying 9437184 bytes for "archive/data/113"
Copying 12288 bytes for "archive/data/114"
Copying 9437184 bytes for "archive/data/115"
Copying 3072 bytes for "archive/data/116"
Copying 3072 bytes for "archive/data/117"
Copying 3072 byte

In [5]:
!ls -lh outputs/bert*

-rw-r--r--  1 freiss  staff   418M Sep  2 15:13 outputs/bert.h5
-rw-r--r--  1 freiss  staff   418M Sep  2 15:13 outputs/bert.pt


## Read tensors from HDF5 files without copying data

If you write data to an HDF5 file with the `h5py` library's default settings, the data for each dataset (i.e. multidimensional array) will end up in a contiguous range of the file on disk. If we later memory-map the file, we can access this data directly without copying.
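The core idea fits in a few lines of standard-library and Numpy code, independent of HDF5. The sketch below assumes you already know the byte offset and element count of a float32 array inside some file (in the notebook, those come from `dataset.id.get_offset()` and the dataset's shape). The helper name `map_float32_region` and the 16-byte "header" in the temporary file are hypothetical.

```python
import mmap
import tempfile

import numpy as np

def map_float32_region(path: str, offset: int, count: int):
    """Return (mmap, array): a read-only float32 view of `count` values at `offset`.

    np.frombuffer() wraps the mapped pages directly, so no data is copied.
    Keep the returned mmap alive for as long as the array is in use.
    """
    with open(path, 'rb') as f:
        # mmap keeps its own duplicate of the file descriptor,
        # so closing the file object here is safe.
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    arr = np.frombuffer(mm, dtype=np.float32, count=count, offset=offset)
    return mm, arr

# Usage: a hypothetical file with 16 bytes of "metadata" followed by 6 floats
values = np.arange(6, dtype=np.float32)
with tempfile.NamedTemporaryFile(suffix='.bin', delete=False) as tmp:
    tmp.write(b'H' * 16)
    tmp.write(values.tobytes())

mm, arr = map_float32_region(tmp.name, offset=16, count=values.size)
print(arr.reshape(2, 3))
print(arr.flags.writeable)  # False: the mapping is read-only
```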

Let's start with the process for loading a single tensor.
Some of the code here is inspired by [this gist](https://gist.github.com/maartenbreddels/09e1da79577151e5f7fec660c209f06e) from Maarten Breddels.

In [6]:
# Pick the first tensor in the file that isn't filled with zeros
tensor_data_loc = 'archive/data/2'
original_tensor_name = 'embeddings.word_embeddings.weight'
tensor_shape = bert.state_dict()[original_tensor_name].shape

# Read the offset and size of the array
with h5py.File(h5_file_name, 'r') as h5_file:
    dataset = h5_file[tensor_data_loc]
    offset = dataset.id.get_offset()
    length = dataset.id.shape[0]

# Memory-map the file once.
raw_file = open(h5_file_name,'rb')
mmap_buffer = mmap.mmap(raw_file.fileno(), 0, access=mmap.ACCESS_READ)

# Wrap a portion of the memory-mapped buffer in a Numpy array, without
# copying data.
raw_array = np.frombuffer(mmap_buffer, dtype=np.byte, count=length, 
                          offset=offset)

# Add shape and dtype information back in, without copying data
array_view = raw_array.view(np.float32).reshape(tensor_shape)

# Wrap the Numpy array in a PyTorch tensor, without copying data
tensor = torch.as_tensor(array_view)

# Compare the resulting data against the original
print('Original tensor data:')
print(bert.state_dict()[original_tensor_name])
print('Data after zero-copy loading:')
print(tensor)

Original tensor data:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])
Data after zero-copy loading:
tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])


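PyTorch also emits a `UserWarning` for the `torch.as_tensor(array_view)` line, because the Numpy array wraps a read-only memory map and PyTorch does not support non-writable tensors. Treat the resulting tensor as read-only: writing to it is undefined behavior.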


We can extend this approach to read all of the weights at once, still without copying data.

But first we'll need to work around a performance limitation: reading the offset metadata for the datasets inside an HDF5 file from Python code takes a surprisingly long time. If we use the `h5py` library's Pythonic API...

In [7]:
%%time

with h5py.File(h5_file_name, 'r') as h5_file:
    # PyTorch stores all tensors under the prefix "archive/data"
    tensor_data_group = h5_file['archive/data']
    name_to_offset_and_len = {
        name: (data.id.get_offset(), data.id.shape[0])
        for name, data in tensor_data_group.items()
    }

CPU times: user 26.9 ms, sys: 3.63 ms, total: 30.5 ms
Wall time: 28.9 ms


...then reading this metadata takes 25-30 msec. That's much too long.

If we drop down to `h5py`'s low-level API...

In [8]:
%%time 

# Measure how long it takes to read offsets with h5py's low-level API
name_to_offset_and_len = {}
file_id = h5py.h5f.open(h5_file_name.encode('utf8'), h5py.h5f.ACC_RDONLY)
group_id = h5py.h5g.open(file_id, b'archive/data')
for group_name in group_id:
    dataset_id = h5py.h5d.open(group_id, group_name)
    offset = dataset_id.get_offset()
    length = dataset_id.shape[0]
    name_to_offset_and_len[group_name.decode('utf8')] = \
        (offset, length)

file_id.close()

CPU times: user 10.1 ms, sys: 2.22 ms, total: 12.3 ms
Wall time: 11.4 ms


...then reading the metadata takes 8-10 msec, which is still too long.

So instead we'll compute all the offsets and lengths once, when we convert the file, and store them as an extra dataset inside the HDF5 file.

In [9]:
if os.path.exists(h5_file_name):
    os.unlink(h5_file_name)
    
offset_info = []  # type: List[Tuple[int, int, int]]

with zipfile.ZipFile('outputs/bert.pt', 'r') as zip_file:
    with h5py.File(h5_file_name, 'w') as h5_file:
        for info in zip_file.infolist():
            with zip_file.open(info.filename, 'r') as f:
                file_data = f.read()
                dataset = h5_file.create_dataset(
                    info.filename, data=np.frombuffer(file_data, dtype=np.byte))
                if info.filename.startswith('archive/data/'):
                    # Tensor storage data file. Remember ID, offset, and length
                    # Conveniently, all the IDs are integers, so we can store them
                    # as such.
                    storage_id_str = info.filename.split('/')[-1]
                    storage_id = int(storage_id_str)
                    offset = dataset.id.get_offset()
                    length = dataset.id.shape[0]
                    offset_info.append((storage_id, offset, length))
        # Write table of storage offsets
        _ = h5_file.create_dataset('offsets_table',
                                   data=np.array(offset_info, dtype=np.int64))



Now we can recover the offset information more quickly by reading the table.

In [10]:
%%time

with h5py.File(h5_file_name, 'r') as h5_file:
    offset_info_dataset = h5_file['offsets_table']

    # Dump into a Numpy array because that method is much faster than 
    # iterating over the Python object.
    offset_info_array = np.zeros(offset_info_dataset.shape, dtype='int64')
    offset_info_dataset.read_direct(offset_info_array)

name_to_offset_and_len = {
    str(row[0]): (row[1], row[2]) for row in offset_info_array
}

CPU times: user 1.43 ms, sys: 621 µs, total: 2.05 ms
Wall time: 1.62 ms


Reading the offsets info in this way takes 1.5-2 msec, which is good enough for now.
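The dict comprehension in cell [10] produces Numpy integer scalars for the offsets and lengths. A variant that converts the table with `ndarray.tolist()` first yields plain Python `int`s and avoids creating a Numpy scalar per element. This is a sketch assuming the same `(storage_id, offset, length)` row layout as `offsets_table`; `offsets_table_to_dict` is a hypothetical name.

```python
from typing import Dict, Tuple

import numpy as np

def offsets_table_to_dict(table: np.ndarray) -> Dict[str, Tuple[int, int]]:
    """Map storage ID (as a string, matching PyTorch's storage keys) to (offset, length).

    Assumes `table` has shape (N, 3) with rows of (storage_id, byte_offset, byte_length).
    """
    if table.ndim != 2 or table.shape[1] != 3:
        raise ValueError(f'expected an (N, 3) table, got shape {table.shape}')
    # tolist() converts the whole array to nested lists of Python ints in one pass
    return {str(sid): (off, length) for sid, off, length in table.tolist()}

# Usage with a toy table of (storage_id, byte_offset, byte_length) rows
toy_table = np.array([[0, 2048, 4096], [2, 6144, 8192]], dtype=np.int64)
print(offsets_table_to_dict(toy_table))  # {'0': (2048, 4096), '2': (6144, 8192)}
```

In the notebook, the equivalent call would be `offsets_table_to_dict(offset_info_array)`.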

Now we can wrap all the weights in Numpy arrays without copying data.

In [11]:
%%time

name_to_array = {
    name: 
    np.frombuffer(
        # Memory-mapped buffer created in cell [6]
        mmap_buffer, 
        dtype=np.byte, count=tup[1], offset=tup[0])
    for name, tup in name_to_offset_and_len.items()
}

CPU times: user 321 µs, sys: 14 µs, total: 335 µs
Wall time: 340 µs


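Wrapping all of the weights takes less than a millisecond, because no tensor data is copied; the operating system reads pages of the file into memory only when they are first accessed. Let's verify that the data comes back correctly.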

In [12]:
array_view_2 = name_to_array['2'].view(np.float32).reshape(tensor_shape)
torch.as_tensor(array_view_2)

tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])

We can do the same to directly create PyTorch tensors without copying data.

In [13]:
%%time

name_to_tensor = {
    name: 
    torch.frombuffer(
        # Memory-mapped buffer created in cell [6]
        mmap_buffer, 
        dtype=torch.int8, count=tup[1], offset=tup[0])
    for name, tup in name_to_offset_and_len.items()
}

CPU times: user 612 µs, sys: 73 µs, total: 685 µs
Wall time: 649 µs




In [14]:
name_to_tensor['2'].view(torch.float32).reshape(tensor_shape)

tensor([[-0.0102, -0.0615, -0.0265,  ..., -0.0199, -0.0372, -0.0098],
        [-0.0117, -0.0600, -0.0323,  ..., -0.0168, -0.0401, -0.0107],
        [-0.0198, -0.0627, -0.0326,  ..., -0.0165, -0.0420, -0.0032],
        ...,
        [-0.0218, -0.0556, -0.0135,  ..., -0.0043, -0.0151, -0.0249],
        [-0.0462, -0.0565, -0.0019,  ...,  0.0157, -0.0139, -0.0095],
        [ 0.0015, -0.0821, -0.0160,  ..., -0.0081, -0.0475,  0.0753]])

## Deserialize entire models

Now that we know how to load weights directly from an HDF5 file, let's rework the internals of the `torch.load()` function to use this mechanism. `torch.load()` doesn't actually restore tensors itself. It leaves restoration of the `Tensor` objects to `pickle` and focuses on restoring the `_TypedStorage` objects that hold the tensors' data.
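`torch.save()` and `torch.load()` rely on pickle's persistent-ID hooks for this split: when saving, `persistent_id()` replaces each storage with a small reference, and when loading, `persistent_load()` turns each reference back into an object. The standard-library sketch below shows the same pattern in miniature, resolving references to zero-copy `memoryview` slices of a single shared buffer. `Blob`, `BlobPickler`, `BlobUnpickler`, and `round_trip` are hypothetical stand-ins for illustration, not PyTorch APIs.

```python
import io
import pickle
from typing import Any, Dict, Optional, Tuple

class Blob:
    """Hypothetical stand-in for a tensor storage: a keyed chunk of raw bytes."""
    def __init__(self, key: str, data: bytes):
        self.key = key
        self.data = data

class BlobPickler(pickle.Pickler):
    """Puts Blob payloads into a side table instead of the pickle stream."""
    def __init__(self, file, side_table: Dict[str, bytes]):
        super().__init__(file)
        self.side_table = side_table

    def persistent_id(self, obj: Any) -> Optional[Tuple[str, str]]:
        if isinstance(obj, Blob):
            self.side_table[obj.key] = obj.data
            return ('blob', obj.key)  # only this small tuple goes into the stream
        return None  # everything else is pickled normally

class BlobUnpickler(pickle.Unpickler):
    """Resolves ('blob', key) references to views over one shared buffer."""
    def __init__(self, file, buffer: memoryview,
                 offsets: Dict[str, Tuple[int, int]]):
        super().__init__(file)
        self.buffer = buffer
        self.offsets = offsets

    def persistent_load(self, pid: Tuple[str, str]) -> memoryview:
        typename, key = pid
        if typename != 'blob':
            raise pickle.UnpicklingError(f'unknown persistent id {pid!r}')
        offset, length = self.offsets[key]
        return self.buffer[offset:offset + length]  # slicing a memoryview copies nothing

def round_trip(obj: Any) -> Any:
    """Pickle obj, pack all Blob payloads back to back, then unpickle against them."""
    side_table: Dict[str, bytes] = {}
    stream = io.BytesIO()
    BlobPickler(stream, side_table).dump(obj)

    packed = bytearray()
    offsets: Dict[str, Tuple[int, int]] = {}
    for key, data in side_table.items():
        offsets[key] = (len(packed), len(data))
        packed += data

    stream.seek(0)
    return BlobUnpickler(stream, memoryview(bytes(packed)), offsets).load()

restored = round_trip({'name': 'layer0', 'weight': Blob('0', b'\x01\x02\x03\x04')})
print(bytes(restored['weight']))
```

Here the packed buffer plays the role of the memory-mapped file, and the offsets dictionary plays the role of `offsets_table`.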

Let's start by reworking the loading code above so that it creates storage objects for the tensors instead of the tensors themselves. The obvious way to do this would be with the built-in `torch.<dtype>Storage.from_buffer()` class methods, but unfortunately those methods copy data, as evidenced by the amount of time the following cell takes:

In [15]:
%%time

# This doesn't work; copies data 
name_to_storage = {
    name: 
    torch.ByteStorage.from_buffer(
        mmap_buffer, 
        count=tup[1], offset=tup[0])
    for name, tup in name_to_offset_and_len.items()
}

CPU times: user 285 ms, sys: 149 ms, total: 434 ms
Wall time: 432 ms


Falling back to the internal `_UntypedStorage` class also doesn't work; the `from_buffer()` method of that class also copies data.

In [16]:
%%time

# This doesn't work; copies data
name_to_storage = {
    name: 
    torch._UntypedStorage.from_buffer(
        mmap_buffer,
        dtype=torch.int8,
        count=tup[1], offset=tup[0])
    for name, tup in name_to_offset_and_len.items()
}

CPU times: user 195 ms, sys: 126 ms, total: 321 ms
Wall time: 319 ms


What we can do instead is to use `torch.frombuffer()` to create `Tensor` objects, then drill down to their storage objects.

In [17]:
def make_name_to_storage():
    return {
        name: 
        torch.frombuffer(
            mmap_buffer, 
            dtype=torch.int8, count=tup[1], offset=tup[0]).storage()
        for name, tup in name_to_offset_and_len.items()
    }

# %%time gives inconsistent results for this cell, so use %timeit, which averages over many runs
%timeit make_name_to_storage()

912 µs ± 4.84 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Now we're ready to modify a copy of `torch.serialization._load()` (the primary implementation of `torch.load()`) such that it loads from an HDF5 file and doesn't copy any tensor data.
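Besides `persistent_load()`, the copied function keeps a second pickle hook from `torch.serialization`: `UnpicklerWrapper` overrides `find_class()` so that pickles referring to an old module path (`torch.tensor`) resolve against the new one (`torch._tensor`). The standard-library sketch here isolates that technique. `RemappingUnpickler` and the module name `old_json` are hypothetical, and the tiny pickle stream is written by hand in protocol 0 so that it refers to a module that doesn't exist.

```python
import io
import json
import pickle
from typing import Dict

class RemappingUnpickler(pickle.Unpickler):
    """Unpickler that rewrites module paths before resolving classes or functions."""
    def __init__(self, file, module_mapping: Dict[str, str]):
        super().__init__(file)
        self.module_mapping = module_mapping

    def find_class(self, module: str, name: str):
        # Same idea as load_module_mapping in zero_copy_load(): old path -> new path
        module = self.module_mapping.get(module, module)
        return super().find_class(module, name)

# Protocol-0 stream: GLOBAL opcode ('c') naming old_json.dumps, then STOP ('.')
stale_stream = b'cold_json\ndumps\n.'

loaded = RemappingUnpickler(io.BytesIO(stale_stream), {'old_json': 'json'}).load()
print(loaded is json.dumps)  # True
```

Without the override, `pickle.loads(stale_stream)` fails with `ModuleNotFoundError`.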

In [18]:
# Function copied from torch/serialization.py
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str


def zero_copy_load(h5_file_name, pickle_module=pickle, pickle_file='data.pkl', **pickle_load_args):
    """
    Modified version of :func:`torch.serialization._load()` with the following changes:

    * Loads from an HDF5 file instead of a zip file
    * Loads tensors without copying data
    * Doesn't support restoring data to non-CPU devices

    :param h5_file_name: Path to a serialized PyTorch model written with
                         :func:`torch.save()`, then converted from zip to
                         HDF5 format together with an ``offsets_table`` dataset.
    :param pickle_module: Implementation of the Pickle protocol that was used in
                          the original call to :func:`torch.save()`
    :param pickle_file: Name of the pickled object graph inside the ``archive/``
                        group of the HDF5 file
    :param pickle_load_args: Additional arguments to pass to the unpickler
    """
    loaded_storages = {}
    
    # Extract the pickled model and the byte offsets of tensor data from the HDF5 file
    with h5py.File(h5_file_name, 'r') as h5_file:
        offset_info_dataset = h5_file['offsets_table']
        offset_info_array = np.zeros(offset_info_dataset.shape, dtype='int64')
        offset_info_dataset.read_direct(offset_info_array)
        
        pickled_data_dataset = h5_file[f'archive/{pickle_file}']
        pickled_data_array = np.zeros(pickled_data_dataset.shape, dtype='byte')
        pickled_data_dataset.read_direct(pickled_data_array)

    key_to_offset = {
        str(row[0]): row[1] for row in offset_info_array
    }
    
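    # Memory-map the entire file in read-only mode. The mmap object keeps its
    # own duplicate of the file descriptor, so the file can be closed right away.
    with open(h5_file_name, 'rb') as raw_file:
        mmap_buffer = mmap.mmap(raw_file.fileno(), 0,
                                access=mmap.ACCESS_READ)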

    # Define callbacks for deserialization of tensor storage
    def load_tensor(dtype, numel, key, location):
        # In spite of its name (retained from original PyTorch code), this callback
        # doesn't load the Tensor object but instead loads the tensor's backing 
        # TypedStorage object.
        offset = key_to_offset[key]
        loaded_storages[key] = (
            torch.frombuffer(mmap_buffer, dtype=dtype, count=numel, offset=offset)
            .storage())

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = data
        if storage_type is torch._UntypedStorage:
            dtype = torch.uint8
        else:
            dtype = storage_type.dtype

        if key not in loaded_storages:
            # The original PyTorch code passes a byte count here, because its
            # load_tensor() builds an untyped byte storage. Ours builds a typed
            # view with torch.frombuffer(), so it takes an element count.
            load_tensor(dtype, numel, key, _maybe_decode_ascii(location))

        return loaded_storages[key]

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return torch.serialization.StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(pickled_data_array.tobytes())

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()

    return result


We can use this `zero_copy_load()` function to load a second instance of BERT:

In [19]:
bert_2 = zero_copy_load(h5_file_name)

This load operation takes about 20 msec on average, with considerable run-to-run variation:

In [20]:
%timeit zero_copy_load(h5_file_name)

21.6 ms ± 6.86 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


Most of that wall-clock time goes to memory-mapping the HDF5 file:

In [21]:
%%time
raw_file = open(h5_file_name, 'rb')
# Same access mode as zero_copy_load()
_ = mmap.mmap(raw_file.fileno(), 0, access=mmap.ACCESS_READ)

CPU times: user 1.25 ms, sys: 10.4 ms, total: 11.6 ms
Wall time: 28.8 ms


This speed compares favorably to loading the original zip file with `torch.load()`, which takes roughly seven times as long:

In [22]:
%timeit torch.load('outputs/bert.pt')

148 ms ± 3.89 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


We can verify that the model loaded with `zero_copy_load()` produces the same answer as the original:

In [23]:
test_text = 'This is an example input sentence.'
tokenizer = transformers.BertTokenizerFast.from_pretrained(
    'bert-base-uncased')
test_tokens = tokenizer(test_text, return_tensors="pt")

# Run the original model and the copy that we just loaded
with torch.inference_mode():
    print("Original model's output:")
    print(bert(**test_tokens).last_hidden_state)
    print("\nModel output after zero-copy model loading:")
    print(bert_2(**test_tokens).last_hidden_state)

Original model's output:
tensor([[[-0.4056, -0.3897, -0.3079,  ..., -0.5104,  0.2806,  0.7274],
         [-0.6331, -0.5107, -0.5133,  ..., -0.5693,  0.9497,  0.1506],
         [-0.3115, -0.7275,  0.1119,  ..., -0.3893,  0.6364,  0.6212],
         ...,
         [ 0.0106, -0.1030,  0.0220,  ..., -0.3178, -0.0998,  0.0148],
         [ 0.5781,  0.1131, -0.6878,  ...,  0.3539, -0.4097, -0.2768],
         [ 0.0636, -0.2789, -0.3596,  ...,  0.5591, -0.9456,  0.1195]]])

Model output after zero-copy model loading:
tensor([[[-0.4056, -0.3897, -0.3079,  ..., -0.5104,  0.2806,  0.7274],
         [-0.6331, -0.5107, -0.5133,  ..., -0.5693,  0.9497,  0.1506],
         [-0.3115, -0.7275,  0.1119,  ..., -0.3893,  0.6364,  0.6212],
         ...,
         [ 0.0106, -0.1030,  0.0220,  ..., -0.3178, -0.0998,  0.0148],
         [ 0.5781,  0.1131, -0.6878,  ...,  0.3539, -0.4097, -0.2768],
         [ 0.0636, -0.2789, -0.3596,  ...,  0.5591, -0.9456,  0.1195]]])


There is one key difference, though: all the weights in the copy that we've loaded with `zero_copy_load()` live in a single memory-mapped region of the HDF5 file.
That region is backed by the operating system's page cache, which is shared across processes, so if this process and others on the same machine load many copies of the model, the model's weights will be stored only once in physical memory across the entire machine.
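A small standard-library sketch makes the page sharing visible within one process: a write through one mapping of a file is immediately visible through a second, independent mapping of the same file, with no flush or re-read, because both refer to the same physical pages. This assumes a unified page cache, as on Linux and macOS; the temporary file and the helper name `read_back_through_second_mapping` are hypothetical. Since the page cache is system-wide, separate processes that map the file read-only share those same pages.

```python
import mmap
import tempfile

def read_back_through_second_mapping(payload: bytes) -> bytes:
    """Write `payload` through one mapping of a file and read it back via another."""
    with tempfile.TemporaryFile() as f:
        f.write(b'\x00' * mmap.PAGESIZE)  # mmap needs a non-empty file
        f.flush()
        writer = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE)
        reader = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        try:
            writer[:len(payload)] = payload  # note: no writer.flush() call
            return reader[:len(payload)]
        finally:
            reader.close()
            writer.close()

print(read_back_through_second_mapping(b'hello'))  # b'hello'
```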

Watch what happens to the Python process's resident memory (RSS) when we load 1000 copies of `bert-base-uncased` using `zero_copy_load()`:

In [24]:
def memory_in_mb() -> float:
    """Resident set size (RSS) of this process, in megabytes (2**20 bytes)."""
    return psutil.Process(os.getpid()).memory_info().rss / 1048576

mb_before = memory_in_mb()
print(f'Memory usage before loading 1000 models: {mb_before} MB')
many_berts = [
    zero_copy_load(h5_file_name) for i in range(1000)
]
mb_after = memory_in_mb()
print(f'Memory usage after loading 1000 models: {mb_after} MB')

Memory usage before loading 1000 models: 3239.01171875 MB
Memory usage after loading 1000 models: 3479.96875 MB


With the weights living in shared memory, each copy of BERT requires only a small slice of process memory to hold its Python objects.

In [25]:
print(f'Megabytes of process memory per copy of bert-base-uncased: '
      f'{(mb_after - mb_before)/1000:0.3f}')

Megabytes of process memory per copy of bert-base-uncased: 0.241
