Skip to content

Commit

Permalink
[3/n] loading meta to device
Browse files Browse the repository at this point in the history
Summary: Make it possible to `torch.jit.load(model, device)` when `model` contains weights that live on the `meta` device. Weights on `meta` are left on `meta`, while all other weights are loaded to the target device.

Reviewed By: singlaiiit, RoshanPAN, sayitmemory

Differential Revision: D45099145

fbshipit-source-id: a736a8ed3052707e986f4efdd4301df32e74d088
  • Loading branch information
qxy11 authored and facebook-github-bot committed May 2, 2023
1 parent 7caac54 commit 5ad5f4c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
44 changes: 44 additions & 0 deletions test/jit/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,50 @@ def forward(self, x):
self.assertTrue(m_buffers["buffer"].is_meta)
self.assertTrue(m_loaded_buffers["buffer"].is_meta)

def test_save_load_meta_tensors_to_device(self):
    """
    Verify that loading a module containing meta tensors to a device keeps
    the meta tensors on meta, while every other tensor is placed on the
    requested device.
    """

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # One submodule lives on meta, the other on the default device.
            self.foo = torch.nn.Linear(2, 3, device="meta")
            self.bar = torch.nn.Linear(3, 4)

        def forward(self, x):
            return self.bar(self.foo(x))

    original = Foo()
    loaded = self.getExportImportCopy(torch.jit.script(original), map_location="cpu")

    # The module hierarchy must survive the round trip unchanged.
    original_modules = dict(original.named_modules())
    loaded_modules = dict(loaded.named_modules())
    self.assertEqual(len(original_modules), len(loaded_modules))
    self.assertEqual(set(original_modules), set(loaded_modules))

    # Parameters must compare equal after reloading.
    original_params = dict(original.named_parameters())
    loaded_params = dict(loaded.named_parameters())
    self.assertEqual(len(original_params), len(loaded_params))
    self.assertEqual(original_params, loaded_params)

    # Meta weights stay on meta; everything else lands on CPU.
    for name in ("foo.weight", "foo.bias"):
        self.assertTrue(original_params[name].is_meta)
        self.assertTrue(loaded_params[name].is_meta)
    for name in ("bar.weight", "bar.bias"):
        self.assertTrue(original_params[name].is_cpu)
        self.assertTrue(loaded_params[name].is_cpu)


def test_save_load_with_saved_traced_inputs(self):
"""
Check that saving and loading with traced inputs works as expected
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/serialization/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ PickleOpCode Unpickler::readInstruction() {
const std::string& key = args.at(2).toStringRef();

at::Device device(args.at(3).toStringRef());
if (device_) {
// remap device location if it's not meta
if (device_ && !device.is_meta()) {
device = *device_;
}

Expand Down

0 comments on commit 5ad5f4c

Please sign in to comment.