
Model checkpointed using torch.save() unable to be loaded using torch.load() #12042

Closed
deepakn94 opened this issue Sep 25, 2018 · 42 comments

@deepakn94

I have created a PyTorch model checkpoint using torch.save; however, I'm unable to load this model using torch.load. I run into the following error:

>>> torch.load('model_best.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/serialization.py", line 358, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/envs/pytorch_source/lib/python3.7/site-packages/torch/serialization.py", line 549, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
RuntimeError: storage has wrong size: expected -7659745797817883467 got 512

The model was saved using code like this:

def save_checkpoint(epoch, model, best_top5, optimizer, is_best=False, filename='checkpoint.pth.tar'):
    state = {
        'epoch': epoch+1, 'state_dict': model.state_dict(),
        'best_top5': best_top5, 'optimizer': optimizer.state_dict(),
    }
    torch.save(state, filename)

if args.local_rank == 0:
    if is_best: save_checkpoint(epoch, model, best_top5, optimizer, is_best=True, filename='model_best.pth.tar')

The model was trained across multiple p3.16xlarge instances.

@deepakn94
Author

PyTorch version:

>>> print(torch.__version__)
0.5.0a0+6993e4a

Python version:

>>> python --version
Python 3.7.0

@ssnl
Collaborator

ssnl commented Sep 25, 2018

cc @ezyang

@ezyang
Contributor

ezyang commented Sep 25, 2018

Would it be possible to upload the checkpoint file somewhere, so we can look at it? (Or, if you can provide a script which generates the checkpoint file, that would work too.)

@ddkang

ddkang commented Sep 25, 2018

@ezyang
Contributor

ezyang commented Sep 25, 2018

Thanks! Looking into it.

@ezyang
Contributor

ezyang commented Sep 25, 2018

(/home/ezyang/Dev/pytorch-tmp-env) [ezyang@devgpu005.ash6 ~/Dev/pytorch-tmp] tar tf model_best.pth.tar 
tar: This does not look like a tar archive
tar: Skipping to next header
tar: Exiting with failure status due to previous errors
(/home/ezyang/Dev/pytorch-tmp-env) [ezyang@devgpu005.ash6 ~/Dev/pytorch-tmp] sha1sum model_best.pth.tar
ca1d315ffddd014ceb3895a919394369dbb8e076  model_best.pth.tar

Also, it looks like you're training ImageNet; can you make the code available to reproduce, if possible?

@ddkang

ddkang commented Sep 25, 2018

@deepakn94
Author

An easier-to-use version of the code is here: https://github.com/stanford-futuredata/pytorch-distributed/blob/master/train.py

You can reproduce using python train.py --machines 16. There's some additional setup needed to get this working on an EC2 account, and I can walk you through that if needed.

Also, to answer your previous question, you're right that this isn't a .tar file -- it was just named with a .tar extension, for whatever reason.

@ezyang
Contributor

ezyang commented Sep 26, 2018

I can reproduce the failure on load. Still investigating.

@ezyang
Contributor

ezyang commented Sep 26, 2018

One data point: the file descriptor at the time of error is misaligned:

(gdb) p/x (int)lseek(file, 0, 1)                                                                     
$7 = 0xa0099f

@ezyang
Contributor

ezyang commented Sep 26, 2018

@deepakn94 Does this repro if you run it on only one node? Basically, I want the model to be as small as possible while still reproducing the error.

@ddkang

ddkang commented Sep 26, 2018

The model serializes and deserializes fine when run on one node.

@deepakn94
Author

Here's another data point: serialization and deserialization seem to work fine for 4 nodes when using PyTorch 0.4.0.

@ezyang
Contributor

ezyang commented Sep 26, 2018

OK, I'm reading the serialization code, and I think I see an incorrect use of the write() function. Posting patch soon...

@ezyang
Contributor

ezyang commented Sep 26, 2018

Please recompile PyTorch with the following patch, which will fix write-time corruption.

diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index 2299cce24..1e5889b15 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -35,7 +35,7 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
       throw std::system_error(result, std::system_category());
   } else {
     int64_t buffer_size = std::min(size, (int64_t)5000);
-    std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
+    std::unique_ptr<char[]> le_buffer(new char[buffer_size * sizeof(scalar_t)]);
     for (int64_t i = 0; i < size; i += buffer_size) {
       size_t to_convert = std::min(size - i, buffer_size);
       if (sizeof(scalar_t) == 2) {
@@ -54,7 +54,19 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
             THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       }
-      SYSCHECK(doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t)));
+      int64_t remaining = buffer_size * sizeof(scalar_t);
+      char *bytes = le_buffer.get();
+      while (remaining > 0) {
+        ssize_t result = doWrite(fd, bytes, to_convert * sizeof(scalar_t));
+        if (result < 0) {
+          throw std::system_error(result, std::system_category());
+        }
+        bytes += result;
+        remaining -= result;
+      }
+      if (remaining != 0) {
+        throw std::system_error(result, std::system_category());
+      }
     }
   }
 }

I am not 100% sure this will fix the problem; I need to audit the rest of the sites now.
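For reference, the failure mode this loop guards against can also be sketched in Python (an illustrative analogue of the idea, not the actual serialization code): a raw write() is allowed to write fewer bytes than requested, so the caller must keep writing the remaining tail.

import os

def write_all(fd, data):
    # write(2) may return after writing only part of the buffer (for example
    # under signal interruption or pipe/socket back-pressure), so loop until
    # every byte has been handed to the kernel.
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]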

@deepakn94
Author

Okay, thanks.

There's no way to salvage the existing checkpoints, right?

@ezyang
Contributor

ezyang commented Sep 26, 2018

If the patch above fixes the problem, no, they're irretrievably corrupted.

@deepakn94
Author

Okay. I'll be a little busy for the next day or two, but will check this patch over the weekend.

@deepakn94
Author

Should I apply this patch to current master? Or to the old commit we were using?

Also, it seems like PyTorch 0.4.0 on 16 machines doesn't work.

@ezyang
Contributor

ezyang commented Sep 27, 2018

I authored this patch on master, but it should backport to older versions too. Perhaps it would be better to backport to the old commit to get a cleaner test.

@deepakn94
Author

That unfortunately didn't work. I applied the patch to the old commit (6993e4a).

>>> torch.load('model_best.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected 2930331299881099915 got 128
ubuntu@ip-172-31-93-108:~/pytorch$ git diff
diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index 42dff61..0496311 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -35,7 +35,7 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
       throw std::system_error(result, std::system_category());
   } else {
     int64_t buffer_size = std::min(size, (int64_t)5000);
-    std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
+    std::unique_ptr<char[]> le_buffer(new char[buffer_size * sizeof(real)]);
     for (int64_t i = 0; i < size; i += buffer_size) {
       size_t to_convert = std::min(size - i, buffer_size);
       if (sizeof(real) == 2) {
@@ -54,7 +54,19 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
             THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       }
-      SYSCHECK(doWrite(fd, le_buffer.get(), to_convert * sizeof(real)));
+      int64_t remaining = buffer_size * sizeof(real);
+      char *bytes = le_buffer.get();
+      while (remaining > 0) {
+        ssize_t result = doWrite(fd, bytes, to_convert * sizeof(real));
+        if (result < 0) {
+          throw std::system_error(result, std::system_category());
+        }
+        bytes += result;
+        remaining -= result;
+      }
+      if (remaining != 0) {
+        throw std::system_error(result, std::system_category());
+      }
     }
   }
 }

Uploaded checkpoint here: https://s3.amazonaws.com/distributed-pytorch-imagenet-runs/imagenet-16-new/run1/model_best.pth.tar

@ezyang
Contributor

ezyang commented Sep 27, 2018

I can't read the updated checkpoint (no permissions). I have a more complete patch which also fixes an underrun on reads, but it doesn't catch any more write side errors, so it must be a different bug.
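The read-side counterpart, again only as a Python sketch of the idea rather than the patched C++: read() can likewise return fewer bytes than requested, so a loader has to accumulate chunks until it has the full storage or hits EOF.

import os

def read_exact(fd, n):
    # read(2) may return a short read, so keep reading until n bytes arrive.
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = os.read(fd, remaining)
        if not chunk:
            raise EOFError("expected %d bytes, got only %d" % (n, n - remaining))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)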

ezyang added a commit to ezyang/pytorch that referenced this issue Sep 27, 2018
… cases.

Previously, doRead/doWrite were functions that could return partial reads/writes,
and we checked for this case inconsistently in the call sites of serialization.cpp.
Now, these functions do NOT return the amount of bytes read/written, and instead
handle the necessary checking loop themselves.

Fixes pytorch#12042.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@deepakn94
Author

I updated the permissions on the checkpoint.

@ezyang
Contributor

ezyang commented Sep 27, 2018

Thanks. Confirmed that it still seems to be a write side bug. I guess I'll have to figure something else out...

Any luck minimizing the repro?

@deepakn94
Author

This shouldn't be closed, right?

I'm working on a smaller repro -- I suspect that running this on any distributed setup causes this, but will confirm sometime over the weekend.

@soumith soumith reopened this Sep 28, 2018
@soumith
Member

soumith commented Sep 28, 2018

The closing was an accident, sorry about that. Reopened the issue.

ezyang added a commit to ezyang/pytorch that referenced this issue Sep 28, 2018
@ezyang
Contributor

ezyang commented Sep 28, 2018

So... I tried saving and loading on the small distributed setup we have in our test suite, and I got a very similar-looking error: "EOFError: Ran out of input". I'll work on debugging this case. The branch I'm testing off of is https://github.com/ezyang/pytorch/tree/test/semicharmed-kind-of-life, using python test/run_test.py -i distributed -b nccl.

EDIT: Never mind! I forgot to seek the file back to the beginning before reading it out again.

@apaszke
Contributor

apaszke commented Sep 28, 2018

If the problem only appears in the distributed setting... are you sure that all processes aren't writing to the same file at the same time? That would corrupt it for sure.

@ezyang
Contributor

ezyang commented Sep 28, 2018

The script seems to only write from local_rank == 0: https://github.com/diux-dev/imagenet18/blob/59a8f25171fb8cede51db9187a32fc8f802384a0/training/train_imagenet_nv.py#L150, so unless I misunderstand how rank works, it should be OK.

@deepakn94
Author

Yup, I don't think that's the problem -- only the "master" worker should write the checkpoint. The bug seems to be non-deterministic, because I do have a single 4-machine run that succeeded (along with perhaps 20 failures).

@ezyang
Contributor

ezyang commented Sep 28, 2018

@deepakn94 I haven't tried to get the script to run for me, but another thing to try: when you get to the save point, save the model multiple times; like, 8 times. We can then compare them and see if they're all corrupted identically, or some of them are ok, etc.
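For concreteness, the suggested experiment could look roughly like this (a sketch; state is the checkpoint dict from save_checkpoint above, and the filenames are just illustrative):

# Write the same state dict several times in a row so the resulting files
# can be compared byte-for-byte afterwards.
for i in range(8):
    torch.save(state, 'model_best.{}.pth.tar'.format(i))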

@deepakn94
Author

Links are of the form https://s3.amazonaws.com/distributed-pytorch-imagenet-runs/multi-checkpoint/model_best.0.pth.tar (replace 0 with numbers from 0 to 7)

@deepakn94
Author

This is actually interesting; it looks like some of the checkpoints are corrupted identically, but most are different (and one of the eight checkpoints is not corrupted).

>>> torch.load('model_best.3.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected -4669315570785868528 got 512
>>> torch.load('model_best.4.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected 3219564007566745640 got 256
>>> torch.load('model_best.5.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected 1618383146375255311 got 256
>>> torch.load('model_best.6.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected 3219564007566745640 got 256
>>> torch.load('model_best.7.pth.tar')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 476, in _load
    deserialized_objects[key]._set_from_file(f, offset, f_is_real_file)
RuntimeError: storage has wrong size: expected 2579182944752902985 got 128

@ezyang
Contributor

ezyang commented Sep 30, 2018

I'm not going to get around to looking at this until the work week, but my plan is to do a binary comparison on the checkpoints and see where they diverge, and what kind of corruption is happening, and then check which particular part of the serialization code was writing out that part of the file.

@deepakn94
Author

Sounds good. Let me know if there's anything else you need from my side! (is the easier-to-produce test case still useful?)

@deepakn94
Author

Sigh, I think I found the problem. It seems like local_rank is actually the ID within a worker; multiple workers have a local_rank of 0, so they're probably trampling each other's checkpoints.
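Concretely (an illustrative sketch with a made-up 2-node, 8-GPU layout, not the actual launch configuration): every node launches processes with local_rank 0..7, so a check on local_rank alone fires once per node.

# Hypothetical 2-node x 8-GPU job:
#   node A: local_rank 0..7 -> global ranks 0..7
#   node B: local_rank 0..7 -> global ranks 8..15
# args.local_rank == 0 is therefore true on two processes (global ranks 0 and 8),
# and both write model_best.pth.tar to the shared storage, clobbering each other.
if args.local_rank == 0:   # fires once per node, not once per job
    save_checkpoint(epoch, model, best_top5, optimizer, is_best=True,
                    filename='model_best.pth.tar')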

@ezyang
Contributor

ezyang commented Oct 3, 2018

Aw man, that sounds like a good one for the docs. Very happy you figured it out :)

@deepakn94
Author

Verified that this is indeed the case -- closing this. Thanks for all the help!

@ezyang
Contributor

ezyang commented Oct 5, 2018

@deepakn94 If you don't mind me asking, what change did you make to solve the problem? IIUC, you were writing to a network filesystem for the checkpoints; did you just make them stop writing to NFS?

@deepakn94
Author

I added a --global_rank command line argument as well. Full commit here: stanford-futuredata/pytorch-distributed@23990ca

@bermanmaxim
Contributor

bermanmaxim commented Oct 26, 2018

Note that the launch utility torch.distributed.launch sets a RANK environment variable which can be used to detect whether you are on the master process (with os.environ['RANK'] == '0' from Python).
EDIT: actually, it is even simpler than that: you can use torch.distributed.get_rank() to get the global rank.
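Putting the two suggestions together, gating the save on the global rank might look like this (a minimal sketch; it assumes the process group has already been initialized, e.g. by torch.distributed.launch plus init_process_group, and that state is the checkpoint dict from earlier):

import os
import torch
import torch.distributed as dist

# Option 1: environment variable set by torch.distributed.launch.
is_master = os.environ.get('RANK', '0') == '0'

# Option 2, simpler once the process group exists: the global rank
# (unlike local_rank, this is unique across all nodes).
is_master = dist.get_rank() == 0

if is_master:
    torch.save(state, 'model_best.pth.tar')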

@djstrong

djstrong commented Apr 8, 2019

I have a similar error when I only load a pretrained model. The problem does not occur if only one process is loading the model.
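One common cause of that symptom is every process downloading and writing the pretrained weights to the same cache file at once. A frequently used workaround (a sketch, assuming the default process group is initialized and the weights come from torchvision) is to let rank 0 fetch the weights first and hold the other ranks at a barrier:

import torch.distributed as dist
import torchvision

if dist.get_rank() == 0:
    # Rank 0 downloads (and caches) the pretrained weights.
    model = torchvision.models.resnet50(pretrained=True)
dist.barrier()  # everyone waits until the cached file is fully written
if dist.get_rank() != 0:
    # The remaining ranks read the now-complete cached file.
    model = torchvision.models.resnet50(pretrained=True)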
