
PyTorchFileRecorder memory usage #1270

Closed
laggui opened this issue Feb 6, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@laggui
Member

laggui commented Feb 6, 2024


Describe the bug
When importing weights from a saved PyTorch model, incorrectly mapped keys lead to steadily increasing memory usage until the process is eventually killed.

To Reproduce
Check out this simple MRE: https://github.com/Nikaidou-Shinku/burn-1270-mre

Expected behavior
The user experience could be improved: an incorrect key mapping should not lead to an OOM; ideally the mismatch is caught and reported as an error.
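
For context, a minimal sketch of the import path involved, assuming burn-import's `PyTorchFileRecorder` with `LoadArgs::with_key_remap`. The module, file name, and regex below are made up for illustration, and the exact `Recorder::load` signature (with or without a device argument) differs between Burn versions:

```rust
use burn::module::Module;
use burn::nn::conv::Conv2d;
use burn::record::{FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

/// Toy module standing in for the real model: the PyTorch checkpoint stores
/// its weights under `conv.0.*`, while this Burn module expects `conv0.*`.
#[derive(Module, Debug)]
struct Net<B: Backend> {
    conv0: Conv2d<B>,
}

/// Load the checkpoint, remapping `conv.<i>` to `conv<i>` so the flattened
/// keys line up with the module's field names. `NetRecord` is generated by
/// the `Module` derive.
fn load_record<B: Backend>(device: &B::Device) -> NetRecord<B> {
    let args = LoadArgs::new("weights.pth".into())
        .with_key_remap(r"conv\.(\d+)", "conv$1");

    PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(args, device)
        .expect("a key mismatch should surface here as an error, not as an OOM")
}
```

If the remap pattern misses a layer, the record still expects those fields, which is the situation that currently ends in an OOM instead of a clear error.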

@laggui laggui added the bug Something isn't working label Feb 6, 2024
@antimora
Collaborator

antimora commented Feb 7, 2024

yurzhang reported the following on the Discord channel:

[7:51 PM] yurzhang: When I previously ported a PyTorch module to Burn, I used a workaround to implement Sequential, so all keys of the form conv.0 were renamed to conv0.
But when I tried to read the .pth file, key_remap was not working properly: my model expects keys named conv0, while the .pth file has no such keys and instead contains keys named conv.0.
In this case I tried to read a .pth file of about 120 MiB and used the debugger to trace it into the burn_import::pytorch::reader::from_file function. When candle_core::pickle::read_all finished, the program occupied less than 200 MiB of memory; when burn::record::serde::data::unflatten finished, it occupied less than 2 GiB; and once it entered D::deserialize, it used up all of my 8 GiB of memory and 16 GiB of swap space until it was killed by the OOM killer.
Later the key_remap problem was fixed, so I can read the .pth file normally, but I think there may still be a problem here.
[7:55 PM] yurzhang: If a minimal reproducible example would be helpful, I can try to create one.
[8:01 PM] yurzhang: I think #1270 may have hit a similar problem to mine. The model parameters I loaded were larger than resnet-18's, but after I fixed the key inconsistency I could load the file with a normal amount of memory.
[8:04 PM] yurzhang: For issue 1270, I think we can first check whether the keys after remapping match the keys the module requires.
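
For illustration, a conceptual sketch of the check yurzhang suggests: compare the remapped keys from the file against the keys the module expects before deserializing, so a mismatch is reported instead of blowing up. This is plain Rust over key sets, not Burn's API:

```rust
use std::collections::BTreeSet;

/// Report key mismatches up front: which expected keys are missing after
/// remapping, and which loaded keys nothing in the module will consume.
fn check_keys(loaded: &BTreeSet<String>, expected: &BTreeSet<String>) -> bool {
    let mut ok = true;
    for missing in expected.difference(loaded) {
        eprintln!("missing from .pth after remap: {missing}");
        ok = false;
    }
    for unused in loaded.difference(expected) {
        eprintln!("present in .pth but unused by the module: {unused}");
        ok = false;
    }
    ok
}

fn main() {
    // Toy data mirroring the situation above: the file still has `conv.0.*`
    // keys while the module wants `conv0.*`.
    let loaded = BTreeSet::from(["conv.0.weight".to_string(), "conv.0.bias".to_string()]);
    let expected = BTreeSet::from(["conv0.weight".to_string(), "conv0.bias".to_string()]);
    assert!(!check_keys(&loaded, &expected));
}
```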

@Nikaidou-Shinku
Contributor

I made a minimal reproducible example, I hope it helps.
https://github.com/Nikaidou-Shinku/burn-1270-mre

@antimora
Collaborator

antimora commented Feb 7, 2024

I made a minimal reproducible example, I hope it helps.

https://github.com/Nikaidou-Shinku/burn-1270-mre

Thank you. I will check it out.

@laggui
Member Author

laggui commented Feb 7, 2024

Based on the information provided by @Nikaidou-Shinku on Discord and the MRE, I found that one of my patterns did not cover all layers. My fix for the pattern chain in #1269 did not reuse the new name, so some names still did not match and I ran into the OOM issue described above. I'll open another PR for that.
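
For illustration, roughly what the chained remapping has to do: each rule must operate on the output of the previous rule, not on the original key. This sketch assumes the regex crate and uses made-up patterns, not the actual ones from #1269:

```rust
use regex::Regex;

/// Apply a chain of (pattern, replacement) rules to a flattened tensor key,
/// feeding each rule the result of the previous one so renames compose.
fn remap_key(key: &str, rules: &[(Regex, &str)]) -> String {
    let mut name = key.to_string();
    for (pattern, replacement) in rules {
        // Reuse `name` (not the original `key`) so later rules see
        // already-renamed segments.
        name = pattern.replace_all(&name, *replacement).into_owned();
    }
    name
}

fn main() {
    let rules = [
        (Regex::new(r"^layer(\d+)\.").unwrap(), "layer$1.blocks."),
        (Regex::new(r"blocks\.(\d+)\.").unwrap(), "blocks.b$1."),
    ];
    // The second rule only matches because it runs on the first rule's output.
    assert_eq!(
        remap_key("layer1.0.conv1.weight", &rules),
        "layer1.blocks.b0.conv1.weight"
    );
}
```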

I'm still running into an issue, though it seems to be related to something else: pickle::read_all is not returning all the tensors actually saved in the .pth file. I'm only getting the batch norm layers; none of the convolutional layer weights seem to be parsed. Will investigate further.


/edit: below are my findings

It turns out the TensorInfo is not returned by into_tensor_info(). I'm not very familiar with the internal format for tensors saved with pickle, but it looks like the class name is not matched here.

When loading the torchvision pre-trained weights for resnet-18, all of my missing parameters are of type Class { module_name: "torch._utils", class_name: "_rebuild_parameter" }, while the code explicitly checks for class_name == "_rebuild_tensor_v2". So it looks like it cannot load nn.Parameter types at all?
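
To illustrate the shape of the problem (this is not candle's actual code, just a rough sketch of the class-name check described above):

```rust
// Placeholder for the real shape/stride/dtype/storage metadata.
struct TensorInfo;

/// Only `torch._utils._rebuild_tensor_v2` entries are recognized; an
/// `nn.Parameter` is pickled as `torch._utils._rebuild_parameter` wrapping a
/// rebuilt tensor, so those entries fall through and are silently dropped.
fn into_tensor_info(module_name: &str, class_name: &str) -> Option<TensorInfo> {
    match (module_name, class_name) {
        ("torch._utils", "_rebuild_tensor_v2") => Some(TensorInfo),
        // A fix would need an arm for "_rebuild_parameter" that unwraps the
        // inner rebuilt tensor and recurses.
        _ => None,
    }
}

fn main() {
    assert!(into_tensor_info("torch._utils", "_rebuild_tensor_v2").is_some());
    // The resnet-18 parameters hit this path and vanish from the result.
    assert!(into_tensor_info("torch._utils", "_rebuild_parameter").is_none());
}
```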

@antimora
Collaborator

antimora commented Feb 8, 2024

I will be investigating the root cause, but now that @Nikaidou-Shinku has pointed to another problem (wrong keys), I suspect the NestedValue creation blows up because of a loop or something similar during unflattening. I will try to fix the two issues: memory doubling and recursion.
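
For reference, a conceptual sketch of what unflattening dot-separated keys into a nested structure involves; this is not Burn's actual NestedValue or unflatten code, just an illustration of the step where a conflicting or unexpected key has to be detected and reported rather than silently expanding the structure:

```rust
use std::collections::BTreeMap;

/// Conceptual stand-in for a nested record value: either a leaf or a map of
/// children keyed by the next path segment.
#[derive(Debug)]
enum Nested {
    Leaf(String),
    Map(BTreeMap<String, Nested>),
}

/// Turn flat `a.b.c -> value` entries into a nested tree.
fn unflatten(flat: &BTreeMap<String, String>) -> Nested {
    let mut root: BTreeMap<String, Nested> = BTreeMap::new();
    for (key, value) in flat {
        let mut parts: Vec<&str> = key.split('.').collect();
        let leaf = parts.pop().expect("keys are non-empty");
        let mut node = &mut root;
        for part in parts {
            let child = node
                .entry(part.to_string())
                .or_insert_with(|| Nested::Map(BTreeMap::new()));
            node = match child {
                Nested::Map(children) => children,
                // A prefix of this key is already a leaf: real code should
                // report this conflict instead of continuing.
                Nested::Leaf(_) => panic!("conflicting key: {key}"),
            };
        }
        node.insert(leaf.to_string(), Nested::Leaf(value.clone()));
    }
    Nested::Map(root)
}

fn main() {
    let mut flat = BTreeMap::new();
    flat.insert("conv.0.weight".to_string(), "tensor-a".to_string());
    flat.insert("conv.0.bias".to_string(), "tensor-b".to_string());
    println!("{:#?}", unflatten(&flat));
}
```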

antimora added a commit to antimora/burn that referenced this issue Feb 9, 2024
@antimora
Collaborator

antimora commented Feb 9, 2024

I have figured out the root cause of the issue: a struct field in the model was not present in the .pt file, which coincided with incorrectly renamed fields. Thanks to @Nikaidou-Shinku for catching it.

PR #1286 fixes it. Please see the PR's changes section for how I fixed it.

nathanielsimard pushed a commit that referenced this issue Feb 12, 2024
@laggui
Member Author

laggui commented Feb 29, 2024

Fixed with PR #1286

@laggui laggui closed this as completed Feb 29, 2024