
torch.save also saves docstrings into pickle for some reason #21745

Open
ezyang opened this issue Jun 13, 2019 · 14 comments
Labels
module: serialization (Issues related to serialization (e.g., via pickle, or otherwise) of PyTorch objects)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@ezyang
Contributor

ezyang commented Jun 13, 2019

Steps to reproduce:

  1. Run:
from torch import nn
import torch
torch.save(nn.Conv2d(1, 1, 1, 1), 'moo.pt')
  2. Open moo.pt in your editor

Expected result: a bunch of unintelligible binary gobbledygook

Actual result: I see a docstring!!!

<80>^B<8a>
lü<9c>Fù j¨P^Y.<80>^BMé^C.<80>^B}q^@(X^P^@^@^@protocol_versionq^AMé^CX^M^@^@^@little_endianq^B<88>X
^@^@^@type_sizesq^C}q^D(X^E^@^@^@shortq^EK^BX^C^@^@^@intq^FK^DX^D^@^@^@longq^GK^Duu.<80>^B(X^F^@^@^@moduleq^@ctorch.nn.modules.conv
Conv2d
q^AX7^@^@^@/data/users/ezyang/pytorch-tmp/torch/nn/modules/conv.pyq^BX¸^[^@^@class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over an input signal composed of several input
    planes.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

...
@ezyang
Contributor Author

ezyang commented Jun 13, 2019

This makes our pickles big and exacerbates #21743

@gchanan gchanan added the module: serialization and triaged labels Jun 13, 2019
@nmilosev
Contributor

nmilosev commented Sep 14, 2019

@ezyang Hello, I started working on this and on #21743.

To clarify: for this issue, the docstring should not be saved at all? And for #21743, since you commented on that one as well, do we just need to handle the error and point the user to the docs? Although if we fix this one, the other error should go away in the future, just not for older pickled files.

Thanks!

@ezyang
Contributor Author

ezyang commented Sep 16, 2019

@nmilosev I'm actually not sure how to go about solving this particular issue (is pickle supposed to serialize docstrings? I'm not too familiar with pickling, so I don't know...). #21743 is simpler, just catch the error and then give some better information in this case.

@nmilosev
Contributor

Thank you for the clarification. I will start with #21743 as it is easier.

For this issue, what about explicitly removing docstrings, or adding a parameter to save?

The docstring should be stored in __doc__, so something simple like:

import pickle

def save(something, save_docstring=True):
    if save_docstring:
        return pickle.dumps(something)
    # Temporarily clear the docstring, pickle, then restore it.
    backup_docstring = something.__doc__
    something.__doc__ = None
    try:
        return pickle.dumps(something)
    finally:
        something.__doc__ = backup_docstring

should be enough.

It's a bit "hacky" but we can try it like that? What do you think?

@ezyang
Contributor Author

ezyang commented Sep 16, 2019

It won't work. Pickle operates recursively, and the code you wrote above would only strip __doc__ from the top-level object.
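
For illustration, here is a toy sketch of that recursion (not PyTorch code, and using a plain attribute as a stand-in for __doc__): clearing the attribute on the outer object does not touch the copy carried by nested objects, because pickle serializes those too.

import pickle

class Inner:
    def __init__(self):
        # Stand-in for a docstring on a nested object.
        self.doc = "inner docstring-like payload"

class Outer:
    def __init__(self):
        self.doc = "outer docstring-like payload"
        self.child = Inner()

o = Outer()
o.doc = None                       # only the top-level attribute is cleared
data = pickle.dumps(o)
print(b"inner docstring" in data)  # True: the nested payload is still serialized
print(b"outer docstring" in data)  # False: only this one was stripped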

@nmilosev
Contributor

nmilosev commented Sep 16, 2019

Wow you're right, didn't think about that, sorry.

Another idea: per the pickle docs, if __getstate__ is present it overrides the default serialization of the whole __dict__. So if we add one to nn.Module, perhaps we could at least remove some of the docstrings from pickled files?

EDIT: Might not be a good idea after all. After looking at this example:

def __getstate__(self):
    # Copy the object's state from self.__dict__ which contains
    # all our instance attributes. Always use the dict.copy()
    # method to avoid modifying the original state.
    state = self.__dict__.copy()
    # Remove the unpicklable entries.
    del state['file']
    return state

It is suggested to copy the original state, which can be huge in our case. :(

@ezyang
Contributor Author

ezyang commented Sep 17, 2019

Maybe we can just del state['__doc__'] here; this code might be why we are pickling docs in the first place.
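
A minimal sketch of that idea (a hypothetical __getstate__ added to nn.Module, not the existing implementation):

def __getstate__(self):
    # Shallow copy so the live module is not mutated.
    state = self.__dict__.copy()
    # Drop the docstring entry if an instance-level one is present.
    state.pop('__doc__', None)
    return state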

@nmilosev
Contributor

My thoughts exactly; however, we have some issues with that as well:

  1. Some modules, e.g. quantized linear, already have their own __getstate__ implemented.
  2. If we do del state['__doc__'], we are modifying the state of the object, which shouldn't happen during save; and if we copy the state beforehand, we are essentially doubling the memory usage.

Seems more and more like a "won't fix" to me :(

@ezyang
Contributor Author

ezyang commented Sep 18, 2019

(1) is an issue, but

If we do del state['__doc__'], we are modifying the state of the object, which shouldn't happen during save; and if we copy the state beforehand, we are essentially doubling the memory usage.

We were already doing a copy, no? So it doesn't seem like a big deal to try adding it there.

@nmilosev
Contributor

Not sure about that; the documentation says:

If the __getstate__() method is absent, the instance’s __dict__ is pickled as usual.

It doesn't mention a copy; maybe it just reads directly from __dict__?

I skimmed through the source and I don't see an explicit copy operation. :/

@ezyang
Contributor Author

ezyang commented Sep 19, 2019

Oh sorry, I misunderstood, and I thought you were quoting code in our codebase. OK, so this may not be easy to fix.

@driazati
Contributor

Our serialization weirdly saves the entire module source (including the docstring) and verifies on load that it matches the source of the nn.Module class being loaded; this is non-standard pickle behavior (normally only a reference to the qualified name of the object's type is saved):

return ('module', obj, source_file, source)

I suppose that makes it easier to reproduce results exactly, since subtly different code may affect the results, but it's not normal pickle behavior.

Also, FYI, you can use pickletools to decode the binary:

import pickletools

# torch.save writes several pickle streams back to back (the dump above
# shows a magic number, the protocol version, system info, then the
# module itself), so dis() is called repeatedly to step through them.
f = open('moo.pt', 'rb')
pickletools.dis(f)
pickletools.dis(f)
pickletools.dis(f)
pickletools.dis(f)
pickletools.dis(f)

Which looks like:

...
  139: (    MARK                                                                                                             
  140: X        BINUNICODE 'module'                                                                                          
  151: q        BINPUT     0                                                                                                 
  153: c        GLOBAL     'torch.nn.modules.conv Conv2d'                                                                    
  183: q        BINPUT     1                                                                                                 
  185: X        BINUNICODE '/scratch/driazati/dev/pytorch/torch/nn/modules/conv.py'                                      
  248: q        BINPUT     2                                                                                                 
  250: X        BINUNICODE 'class Conv2d(_ConvNd):\n    r"""Applies a 2D convolution over an input signal composed of several
 input\n    planes.\n\n    In the simplest case, the output value of the layer with input size\n    :math:`(N, C_{\\text{in}}
, H, W)` and output :math:`(N, C_{\\text{out}}, H_{\\text{out}}, W_{\\text{out}})`\n    can be precisely described as:\n\n   
 .. math::\n        \\text{out}(N_i, C_{\\text{out}_j}) = \\text{bias}(C_{\\text{out}_j}) +\n        \\sum_{k = 0}^{C_{\\text
{in}} - 1} \\text{weight}(C_{  
...                                                                                
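
For comparison, here is a quick sketch (an assumed toy class, not PyTorch code) of the standard behavior: plain pickle records only a reference to the class's qualified name plus the instance state, never the class source or docstring.

import pickle
import pickletools

class Widget:
    """This docstring does not end up in a plain pickle."""
    def __init__(self):
        self.x = 1

dump = pickle.dumps(Widget())
# The disassembly shows a GLOBAL/STACK_GLOBAL opcode referencing
# '__main__ Widget' and the instance dict; no source text appears.
pickletools.dis(dump)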

@ezyang
Contributor Author

ezyang commented Sep 20, 2019

@driazati So, we could fix this problem by changing that line, right? I wonder how old this code is lol. I guess there are BC concerns too.

@driazati
Contributor

Yeah, we could get rid of it; the code is 3 years old. Due to some other issues we were thinking of doing #26567, which could remove this functionality, and we could keep the old load path in place as-is to preserve full BC.
