
RFC: multi-part torch.load/torch.save to support huge models and/or low CPU memory #64327

Open
stas00 opened this issue Aug 31, 2021 · 29 comments
Labels
module: nn Related to torch.nn · module: serialization Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects · onnx-needs-info needs information from the author / reporter before ONNX team can take action · onnx-triaged triaged by ONNX team · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@stas00
Contributor

stas00 commented Aug 31, 2021

🚀 Feature

The need

Models are getting bigger and there are times when loading all the params from external storage into CPU memory at once is either not possible or calls for some extra craftiness to make it work.

Examples:

  • users wanting to use large models that can't fit into the 12GB of CPU RAM on free Colab, yet fit just fine into the provided GPU (this is definitely solvable with current tools, but it can be made easier.)

  • I encounter this issue quite often when working with multi-node training, or when converting multi-GPU partitioned-param checkpoints into normal checkpoints, like what DeepSpeed does. There I had to reconstitute params partitioned across many state dicts into the main state dict in CPU memory before I could save it, and the input can easily be 100GB or bigger, calling for quite a lot of RAM to complete - 2x the checkpoint size, in fact. This is an example of needing a leaner torch.save.

    Another immediate use-case is converting from one framework to another. E.g. at BigScience we will shortly need to convert a 200B param model from Megatron-Deepspeed to Megatron-LM to HF Transformers, so such a conversion currently requires at least 800GB of CPU RAM for fp16 and 1600GB for fp32.

    In the same context, reshaping a huge checkpoint to a different TP/PP layout may also require a huge amount of CPU memory.

Proposal

The current storage format is pickled data inside a zip archive with multiple separate files internally, but it's currently not possible to load these parts selectively - it's all-or-nothing. The same goes for saving: one first has to build up the full state dict in memory and only then save it (important when params are spread out across GPUs, e.g. in model parallelism).
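
For reference, the existing archive can already be listed with standard tools. A quick sketch (the model1.pt path and the exact entry names are illustrative):

import zipfile

# List the members of a torch.save zip archive (the default format since 1.6).
# Entry names such as "archive/data.pkl" and "archive/data/0" vary by version.
with zipfile.ZipFile("model1.pt") as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)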

This RFC proposes adding an optional alternative format, while continuing to support the existing one, with finer granularity than "all tensors of the state dict at once", which reduces CPU memory requirements.

The proposal is to add support for filesystem-backed ("tied") dicts, so that the dict is never fully loaded into memory; instead values are loaded on demand, one key at a time. The same goes for saving: one can "grow" the saved state dict one key at a time.

One possible solution is to use dbm, because it's a Python built-in and provides the required functionality. A working prototype follows.

Possible Implementation

You can play with the code directly using this notebook: https://github.com/stas00/porting/blob/master/pytorch/torch-save-load-multi-part-dbm.ipynb


import dbm
import pickle
import torch
from torch import nn
from collections.abc import MutableMapping

class DBMStateDict(MutableMapping):
    def __init__(self, path):
        self.path = path
        self.db = dbm.open(path, 'c')

    def __len__(self):
        return len(self.db.keys())

    def __getitem__(self, key):
        return pickle.loads(self.db[key])

    def __setitem__(self, key, value):
        self.db[key] = pickle.dumps(value)
        # it looks like dbm syncs immediately

    def __delitem__(self, key):
        # dbm objects don't implement pop(), but they do support del
        del self.db[key]

    def keys(self):
        return [k.decode() for k in self.db.keys()]

    def __iter__(self):
        # iterate over decoded keys so iteration matches keys()
        return iter(self.keys())

    def copy(self):
        return DBMStateDict(self.path)

    def __del__(self):
        self.db.close()

def save_new(sd_dict, path):
    sd = DBMStateDict(path)
    for k, v in sd_dict.items():
        sd[k] = v

def load_new(path):
    # this doesn't load the whole sd into memory!
    return DBMStateDict(path)

class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = SubNet()

# original
m = Net()
path = "model1.pt"
sd_dict = m.state_dict()
torch.save(sd_dict, path)
sd_dict = torch.load(path)
m.load_state_dict(sd_dict)

# same but loading / saving one key at a time   
m = Net()
path = "model1.dbm"
sd_dict = m.state_dict()
save_new(sd_dict, path)
sd_new = load_new(path)
m.load_state_dict(sd_new)

This is a rough prototype and doesn't pretend to be complete.

And of course, it'll have to be made back-compatible. But the point here is to demonstrate the feature and not to provide a complete new version.

Problems

  • I can see this is not an OrderedDict, which could be a problem in some circumstances (but it's good enough as a prototype).

Other approaches

Credits

The main class has been inspired by SplitCheckpoint

Thank you!

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@driazati
Contributor

driazati commented Sep 1, 2021

Just for those interested (I haven't read too much into the proposal yet), the offending code is likely here:

https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/init.cpp#L1109-L1129

where we load all saved Tensors in as CPU tensors, then move them to their proper device later on:

def load_tensor(data_type, size, key, location):
    name = f'data/{key}'
    dtype = data_type(0).dtype
    storage = zip_file.get_storage_from_record(name, size, dtype).storage()
    loaded_storages[key] = restore_location(storage, location)

@t-vi
Collaborator

t-vi commented Sep 1, 2021

One of the bits jumping to my mind from reading the above is "why can we not use the zip file?", since the file format seems to allow extracting parts / updating the archive selectively. It would be neat to have a brief mention of why this doesn't work, or whether it could be an alternative to be investigated.
Also, I think one of the bits that is often asked for (and was a pain with the pickle format) is how to load from C++, so eventually some consideration for that would be good, too.

@stas00
Contributor Author

stas00 commented Sep 1, 2021

One of the bits jumping to my mind from reading the above is "why can we not use the zip file?", since the file format seems to allow extracting parts / updating the archive selectively. It would be neat to have a brief mention of why this doesn't work, or whether it could be an alternative to be investigated.

That would be great!

If I remember correctly from investigating the current format, that pickle zip file usually contains separate files for large tensors but combines small tensors into a single file. But perhaps there is an API to untangle that and not have to work in a piecemeal fashion!

Also, I think one of the bits that is often asked for (and was a pain with the pickle format) is how to load from C++, so eventually some consideration for that would be good, too.

Following my proposal, we could save the values in a non-Python-specific way, so that C++ (which obviously has an API to read dbm files) could read them too. After all, we have the data storage and the metadata about the tensor itself.
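
To make that concrete, here is a minimal illustration (not part of the proposal; the helper names are made up) of storing each tensor as raw bytes plus a small JSON header instead of a pickle, so that any dbm reader could reconstruct it:

import dbm
import json
import torch

def put_tensor(db, key, t):
    # store a language-neutral header plus the raw tensor bytes
    t = t.contiguous().cpu()
    db[f"{key}.meta"] = json.dumps({"dtype": str(t.dtype), "shape": list(t.shape)}).encode()
    db[f"{key}.data"] = t.numpy().tobytes()

def get_tensor(db, key):
    meta = json.loads(db[f"{key}.meta"].decode())
    dtype = getattr(torch, meta["dtype"].replace("torch.", ""))
    # torch.frombuffer needs torch >= 1.10
    return torch.frombuffer(bytearray(db[f"{key}.data"]), dtype=dtype).reshape(meta["shape"])

with dbm.open("weights.dbm", "c") as db:
    put_tensor(db, "fc1.weight", torch.ones(2, 3))
    print(get_tensor(db, "fc1.weight").shape)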

@zou3519 zou3519 added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Sep 1, 2021
@albanD albanD added the module: serialization Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects label Sep 1, 2021
@driazati
Contributor

driazati commented Sep 1, 2021

Flagging a related issue: #63720

I haven't dug too deep into this code in a while, but I don't think we need to change the format we use at all, since we can determine from the zip format where tensors are and how long they are. We also already have a method to read torch.save-ed items in C++ from the zip format (though the APIs are still undocumented). It seems to me that if we made the reader aware of this issue and more careful about loading data in smaller chunks when necessary, we wouldn't need to mess with the format.

If I remember correctly from investigating the current format, that pickle zip file usually contains separate files for large tensors but combines small tensors into a single file. But perhaps there is an API to untangle that and not have to work in a piecemeal fashion!

Do you have any pointers for where this is happening?

@stas00
Contributor Author

stas00 commented Sep 2, 2021

If I remember correctly from investigating the current format, that pickle zip file usually contains separate files for large tensors but combines small tensors into a single file. But perhaps there is an API to untangle that and not have to work in a piecemeal fashion!

Do you have any pointers for where this is happening?

It's very possible that my memory is faulty, or perhaps I happened to run into some custom checkpoint formats. But I just did an experiment with bare pytorch-1.9.0: I created a model with tiny 1-element tensors and torch.save-ed them, and each tensor had its own file in the archive/data subdir; each file was 4 bytes, which matches fp32.

Therefore I stand corrected: it appears that each tensor has its own file.

Flagging a related issue: #63720

Indeed! Both load and save need an option to do one tensor at a time.

I haven't dug too deep into this code in a while, but I don't think we need to change the format we use at all, since we can determine from the zip format where tensors are and how long they are. We also already have a method to read torch.save-ed items in C++ from the zip format (though the APIs are still undocumented). It seems to me that if we made the reader aware of this issue and more careful about loading data in smaller chunks when necessary, we wouldn't need to mess with the format.

Excellent!

Now that we've established that the zip structure used by torch save/load has each tensor as a separate file, by all means, let's go with your approach, @driazati!

I will update the OP to indicate your suggestion if others are in agreement that this is the path to take.

@ShadenSmith

ShadenSmith commented Sep 2, 2021

Big +1 to this functionality. Checkpointing is a pain point for DeepSpeed for the reasons above.

In addition to per-tensor granularity, when working with partitioned tensors (e.g., model parallelism or ZeRO) it could be useful to be able to work at the granularity of tensor slices.

I have been wanting to try out HDF5, which is used in the scientific computing community and designed for managing huge tensor data. Maybe someone has already tried it out with PyTorch? HDF can be interfaced with from both C++ and Python. I don't have personal experience with it though.
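
For anyone who wants to experiment, an untested sketch (assumes pip install h5py) of roughly the per-key and per-slice access HDF5 offers:

import h5py
import torch

sd = {"fc1.weight": torch.randn(4, 4), "fc1.bias": torch.randn(4)}

# write: one dataset per state-dict key
with h5py.File("ckpt.h5", "w") as f:
    for k, v in sd.items():
        f.create_dataset(k, data=v.cpu().numpy())

# read: only the requested slice is pulled from disk
with h5py.File("ckpt.h5", "r") as f:
    row0 = torch.from_numpy(f["fc1.weight"][0:1])
print(row0.shape)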

@mruberry
Collaborator

mruberry commented Sep 2, 2021

Thanks for the detailed issue, @stas00! It would be nice to improve on serialization.

@stas00
Contributor Author

stas00 commented Sep 8, 2021

Any practical suggestions for how we can move forward with tapping into the pickled zip structure so that each value is loaded only when its key is accessed?

@stas00
Contributor Author

stas00 commented Sep 8, 2021

If I remember correctly from investigating the current format, that pickle zip file usually contains separate files for large tensors but combines small tensors into a single file. But perhaps there is an API to untangle that and not have to work in a piecemeal fashion!

Do you have any pointers for where this is happening?

It's very possible that my memory is faulty, or perhaps I happened to run into some custom checkpoint formats. But I just did an experiment with bare pytorch-1.9.0: I created a model with tiny 1-element tensors and torch.save-ed them, and each tensor had its own file in the archive/data subdir; each file was 4 bytes, which matches fp32.

Therefore I stand corrected: it appears that each tensor has its own file.

I did some more testing and now I can see why I remembered it didn't always have a dedicated file per state dict entry.

This time I created a state dict with mixed types:

sd = dict(
    a=torch.ones(1),
    b=dict(c=torch.ones(2), d=torch.ones(3), e=torch.ones(4)),
    f=dict(g=5, h=8, i=5, k=6),
)
torch.save(sd, "/tmp/test.zip")

The resulting archive has only 4 data entries, one per tensor:

$ unzip test.zip 
$ find archive -type f
archive/data.pkl
archive/version
archive/data/0
archive/data/2
archive/data/1
archive/data/3

It did not create separate entries for the non-tensor data, which I can see is stored in data.pkl.

Therefore we can't assume that each key will have a corresponding file with the data for its value.

@driazati
Contributor

driazati commented Sep 8, 2021

Yeah, the only thing that's stored separately is the tensors (everything in data/ is a tensor storage as a binary blob), since you can torch.save anything pickle-able (e.g. an int, list, or class), so there's not always a set of "keys" available. If someone stores a giant Python list or dictionary we can't really do much with the pickle format (which you can inspect with python -m pickletools archive/data.pkl), but I think a lazy loading mode for tensors would still be worthwhile and pretty doable. When loading, we could have a flag (this should be opt-in since it would change the timing) that makes tensors loaded lazily and when actually unpickling them we insert some placeholder. Later on when that placeholder is used we can replace the data with the actual tensor by reading data/. Would that still be worthwhile?
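
As a very rough hand-rolled illustration of the placeholder idea (this is not how the unpickler hook would actually be wired up, and it assumes the tensor's member name, dtype and shape are already known from the pickle metadata):

import zipfile
import torch

class LazyTensor:
    """Stands in for a tensor; reads its bytes from the checkpoint zip on first use."""
    def __init__(self, archive_path, member, dtype, shape):
        self.archive_path, self.member = archive_path, member
        self.dtype, self.shape = dtype, shape
        self._value = None

    def materialize(self):
        if self._value is None:
            with zipfile.ZipFile(self.archive_path) as zf:
                raw = zf.read(self.member)  # only this tensor's bytes are read
            self._value = torch.frombuffer(bytearray(raw), dtype=self.dtype).reshape(self.shape)
        return self._value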

Another part to consider is maybe we could stream tensors directly to their corresponding device when loading, so GPU tensors are cudaMalloc'd in one step and then copied over in chunks that the CPU can support.
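
And a sketch of the chunked device copy (the source here is an already-materialized CPU tensor just for brevity; in practice it would be read from the file in pieces):

import torch

def copy_in_chunks(cpu_src, device="cuda", chunk_elems=1 << 20):
    out = torch.empty(cpu_src.shape, dtype=cpu_src.dtype, device=device)  # one device allocation
    flat_src, flat_dst = cpu_src.contiguous().view(-1), out.view(-1)
    for start in range(0, flat_src.numel(), chunk_elems):
        end = min(start + chunk_elems, flat_src.numel())
        flat_dst[start:end].copy_(flat_src[start:end])  # small staging copies
    return out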

@stas00
Contributor Author

stas00 commented Sep 8, 2021

[...] When loading, we could have a flag (this should be opt-in since it would change the timing) that makes tensors loaded lazily and when actually unpickling them we insert some placeholder. Later on when that placeholder is used we can replace the data with the actual tensor by reading data/. Would that still be worthwhile?

Yes, that'd address part of the problem we are trying to solve.

But it'd still require something like the proposal in #64601 to avoid multiple copies of the tensors, i.e. if the core functions end up reading the whole dict into memory as a full copy, nothing has changed.

Another part to consider is maybe we could stream tensors directly to their corresponding device when loading, so GPU tensors are cudaMalloc'd in one step and then copied over in chunks that the CPU can support.

I think that would be a great addition too, but not the default.


And then there is torch.save, which needs functionality symmetrical to the dict returned by torch.load.

When manipulating huge models, e.g. converting them from one format to another, currently at least 2x the model size is required in CPU RAM. Ideally we should be able to start saving and freeing params as soon as they are ready, bringing the total memory requirement down to 1x model size + the largest param size. So torch.save should ideally be able to save one key/value pair at a time, not the whole dict as it does now. The tied-dict proposal in the OP provides this functionality: it removes the need for torch.save altogether, since the dict gets flushed to the filesystem as soon as a value is assigned to any of its keys.
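
Roughly, using the DBMStateDict prototype from the OP, the pattern would be (sketch only):

# save-and-free as you go: the source dict shrinks while the on-disk dict grows,
# so peak CPU memory stays near one model copy plus the largest single tensor
def convert_incrementally(src_sd, dst_path):
    dst = DBMStateDict(dst_path)
    for key in list(src_sd.keys()):
        dst[key] = src_sd.pop(key)  # flushed to disk, then freed from RAM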

This feature will be badly needed in a few months, when we will have to convert 100B and 200B param checkpoints from Megatron-Deepspeed to the HF Transformers format - which for the 200B param checkpoint currently requires 800GB of CPU RAM for fp16 and 1.6TB for fp32.

Another use is reformatting the degrees of parallelism in a checkpoint: e.g. the Megatron-LM reshaper currently reads the existing checkpoint, reformats it in another process, and only then saves it, which again requires a lot more RAM than most of us have.

@stas00
Contributor Author

stas00 commented Sep 13, 2021

Further, there is an emerging need to reduce the size of the largest file in the checkpoint, because past a 20GB file size (roughly a 5B param checkpoint) all kinds of issues emerge - e.g. CloudFront won't cache such files. I was just trying to release a 26GB checkpoint and discovered there is no CDN caching.

I opened another RFC specifically on the HF Transformers side: huggingface/transformers#13548, as we increasingly see checkpoints in the tens of GBs and growing. But perhaps the need is broad enough across the PyTorch community to be addressed in the core instead. The advantage of doing it in the core is that it creates a standard; otherwise each framework concocts its own format and we run into incompatibilities, which having this as a core feature would prevent.

The reason I'm posting it here is that if one has a dedicated separate file for each param, then this whole issue becomes moot, as we can then save/load each param separately and never hold a full copy of the state_dict in memory.

@cpuhrsch
Contributor

cpuhrsch commented Sep 17, 2021

Just in case it's missing from the list of approaches: another possibility is to compress the parameters while the model is on CPU and decompress during transfer to GPU. This depends on how compressible the weights are. Maybe via something like zstd.
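
A quick way to gauge the ratio on one's own weights (untested sketch; assumes the third-party zstandard package, pip install zstandard):

import zstandard as zstd
import torch

t = torch.randn(1024, 1024)  # trained weights may compress somewhat better than random data
raw = t.contiguous().cpu().numpy().tobytes()
ratio = len(zstd.ZstdCompressor(level=3).compress(raw)) / len(raw)
print(f"compressed to {ratio:.0%} of the original size")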

@stas00
Contributor Author

stas00 commented Sep 17, 2021

not by much, e.g.:

$ zstd pytorch_model.bin
$ ls -sh1 py*
2.7G pytorch_model.bin
2.3G pytorch_model.bin.zst

Granted, I haven't tried the non-default higher compression levels or its dictionary training.

I did try compressing before and the best I got was 1/3 shaved off - this of course would depend on the model.

Unless this would perform better per-tensor and I shouldn't have tried to compress the whole zip archive.

So I unzipped the file and tried to compress individual tensors, and they only compressed to ~92% of their original size - so worse than compressing the whole file. I repeated the experiment on a much larger checkpoint's individual tensors (a 0.5GB single tensor) with the same result - only 8% shaved off.

But wow, this library is fast!

@stas00
Contributor Author

stas00 commented Mar 17, 2022

Does the PyTorch team want to participate in the design of multi-part model save/load functionality for LLMs? I was trying to have this conversation in this thread, but I have mainly been talking to myself.

We are discussing how we should approach handling the 329GB of bf16 weights for the 176B model, which is not hypothetical anymore and needs to be dealt with shortly.

Ideally we should have a single solution that is adopted by pytorch and not re-invent the wheel.

The current discussion starts here: huggingface/transformers#13548 (comment)

Thank you for your interest.

@vadimkantorov
Contributor

vadimkantorov commented Jun 16, 2022

+1 for evaluating HDF5. Maybe it can support appending tensors or even tensor chunks to an existing file. cc @albanD

@kumpera
Contributor

kumpera commented Jun 16, 2022

I'm working on a solution under torch.distributed._shard.checkpoint that can handle arbitrarily large models in an SPMD fashion, including transparently handling changes in distribution topology such as distributed training -> single-GPU inference or going from X to X*2 nodes.

@mikaylagawarecki
Contributor

mikaylagawarecki commented Jun 14, 2023

Following up here to note that an mmap option has been recently added to torch.load in #102549 that helps with

there are times when loading all the params from external storage into CPU memory at once is either not possible or calls for some extra craftiness to make it work

This is BC-compatible, i.e. checkpoints torch.save-ed with _use_new_zipfile_serialization (the default since version 1.6) can be torch.load(mmap=True)-ed.
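
Minimal usage sketch (the path is illustrative; requires a build that includes the new flag):

import torch

# tensor storages are memory-mapped from the file instead of read into RAM up front
sd = torch.load("checkpoint.pt", mmap=True, map_location="cpu")
# model.load_state_dict(sd)  # then load into your module as usual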

@mackamann

mackamann commented Jun 14, 2023

This is /awesome/, thank you @mikaylagawarecki! I am always about 1G shy of system RAM and have resorted to all kinds of tricks to get my model to load.

I just pulled nightly and tried to use the new named param (mmap) for torch.load(), and I got an error saying mmap was an invalid keyword. Is this not in nightly (yet)? I saw it was merged, but I'm not sure where it was merged to.

@mikaylagawarecki
Contributor

@mackamann It should be in the nightly; could you confirm which nightly version you are using?

@mackamann

It's working! Well, I need to re-save my model with the correct flags first, but the named keyword is working. Silly me didn't run pip uninstall torch -y before trying to install the new version in Colab. It pulled:

https://download.pytorch.org/whl/nightly/cu121/torch-2.1.0.dev20230614%2Bcu121-cp310-cp310-linux_x86_64.whl

@mikaylagawarecki
Contributor

mikaylagawarecki commented Jun 14, 2023

Ah great! Just curious, why did you need to re-save your model? It should be compatible with checkpoints that use the new zipfile serialization (which is the default post torch 1.6, IIUC).

@mackamann

mackamann commented Jun 14, 2023

@mikaylagawarecki I got this error:

905         if mmap:
--> 906             raise RuntimeError("mmap can only be used with files saved with ",
    907                                "`torch.save(_use_new_zipfile_serialization=True), "
    908                                "please torch.save your checkpoint with this option in order to use mmap."

So yeah, that's weird, as I'm pretty sure I trained and saved the model using 2.0.1. EDIT: checked, and it was definitely saved with 2.0.1.

pip3 show torch
Name: torch
Version: 2.0.1

@vadimkantorov
Contributor

vadimkantorov commented Jun 14, 2023

@mikaylagawarecki although this mmap loading of a single huge zip-formatted checkpoint probably still does not address the original concern of needing to handle multi-part checkpoint files (file-per-parameter or multiple-chunk-files-per-parameter), as outlined by @stas00, needed for huge language models and for reliable/cached downloads

@mackamann

mackamann commented Jun 15, 2023

@mikaylagawarecki any idea what might be going on or how I could troubleshoot the issue? It seems I do have a model saved as a zip file w/ 2.0.1:

file checkpoint_latest.pth.tar
checkpoint_latest.pth.tar: Zip archive data, at least v0.0 to extract, compression method=store

I also loaded and explicitly saved with that option, and I get the same error when loading with mmap=True.

@albanD
Collaborator

albanD commented Jun 15, 2023

@vadimkantorov it might be best to open a new issue to discuss handling multi-file checkpoints. I'm sure there will be a distributed axis to that discussion as well.

I think we should close this one (once the discussion with @mackamann is over).

@vadimkantorov
Contributor

vadimkantorov commented Jun 15, 2023

I think that originally (or at least quite quickly, see #64327 (comment)) this issue was exactly about multi-part and multi-file checkpoints :)

@mikaylagawarecki
Contributor

mikaylagawarecki commented Jun 15, 2023

@mackamann do you have a minimal repro of the problem? (A Google Colab would be super helpful!)

@mackamann

mackamann commented Jun 16, 2023

@mikaylagawarecki I created a small reproducible Colab, and that helped me troubleshoot: the issue was that my VQ-VAE was using the same loading code and had been saved long ago; my actual model was loading fine.

Everything with mmap is loading now; my remaining issue is running out of VRAM. I'll probably have to move some things around to get the model to load OK.

thanks again!

@thiagocrepaldi thiagocrepaldi added onnx-triaged triaged by ONNX team onnx-needs-info needs information from the author / reporter before ONNX team can take action labels Jan 23, 2024