From 961ebca2255f902477e9ea7060b8f28781e3c0cd Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 21 Oct 2022 01:09:50 +0000 Subject: [PATCH] Add `weights_only` option to `torch.load` (#86812) This addresses the security issue in Python's default `unpickler` that allows arbitrary code execution while unpickling. Restrict the classes allowed to be unpickled to `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict`, as well as `torch.Size`, `torch.nn.Parameter`, and the `torch.Tensor` and `torch.Storage` variants. By default `weights_only` is set to `False`, but a global override to safe-only loading is available via the `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable. To some extent, addresses https://github.com/pytorch/pytorch/issues/52596 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86812 Approved by: https://github.com/ezyang --- test/test_serialization.py | 91 +++++++--- torch/_weights_only_unpickler.py | 288 +++++++++++++++++++++++++++++++ torch/serialization.py | 41 ++++- 3 files changed, 391 insertions(+), 29 deletions(-) create mode 100644 torch/_weights_only_unpickler.py diff --git a/test/test_serialization.py b/test/test_serialization.py index 68b59a5cb9a2e..d8cfd08aea084 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -148,17 +148,17 @@ def test(name_or_buffer): test(io.BytesIO()) - def test_serialization(self): + def _test_serialization(self, weights_only): # Test serialization with a real file b = self._test_serialization_data() with tempfile.NamedTemporaryFile() as f: torch.save(b, f) f.seek(0) - c = torch.load(f) + c = torch.load(f, weights_only=weights_only) self._test_serialization_assert(b, c) with TemporaryFileName() as fname: torch.save(b, fname) - c = torch.load(fname) + c = torch.load(fname, weights_only=weights_only) self._test_serialization_assert(b, c) # test non-ascii encoding of bytes arrays/strings # The following bytes are produced by serializing @@ -180,12 +180,18 @@ def 
test_serialization(self): buf = io.BytesIO(serialized) utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc' utf8_str = utf8_bytes.decode('utf-8') - loaded_utf8 = torch.load(buf, encoding='utf-8') + loaded_utf8 = torch.load(buf, weights_only=weights_only, encoding='utf-8') self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2]) buf.seek(0) - loaded_bytes = torch.load(buf, encoding='bytes') + loaded_bytes = torch.load(buf, weights_only=weights_only, encoding='bytes') self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2]) + def test_serialization(self): + self._test_serialization(False) + + def test_serialization_safe(self): + self._test_serialization(True) + def test_serialization_filelike(self): # Test serialization (load and save) with a filelike object b = self._test_serialization_data() @@ -279,7 +285,7 @@ def test_serialization_offset_gzip(self): self.assertTrue(torch.equal(a, b)) self.assertEqual(i, j) - def test_serialization_sparse(self): + def _test_serialization_sparse(self, weights_only): def _test_serialization(conversion): x = torch.zeros(3, 3) x[1][1] = 1 @@ -287,11 +293,17 @@ def _test_serialization(conversion): with tempfile.NamedTemporaryFile() as f: torch.save({"tensor": x}, f) f.seek(0) - y = torch.load(f) + y = torch.load(f, weights_only=weights_only) self.assertEqual(x, y["tensor"]) _test_serialization(lambda x: x.to_sparse()) _test_serialization(lambda x: x.to_sparse_csr()) + def test_serialization_sparse(self): + self._test_serialization(False) + + def test_serialization_sparse_safe(self): + self._test_serialization(True) + def test_serialization_sparse_invalid(self): x = torch.zeros(3, 3) x[1][1] = 1 @@ -358,13 +370,13 @@ def test_serialize_device(self): device_copied = copy.deepcopy(device) self.assertEqual(device, device_copied) - def test_serialization_backwards_compat(self): + def _test_serialization_backwards_compat(self, weights_only): a = [torch.arange(1 + i, 26 + 
i).view(5, 5).float() for i in range(2)] b = [a[i % 2] for i in range(4)] b += [a[0].storage()] b += [a[0].reshape(-1)[1:4].clone().storage()] path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') - c = torch.load(path) + c = torch.load(path, weights_only=weights_only) self.assertEqual(b, c, atol=0, rtol=0) self.assertTrue(isinstance(c[0], torch.FloatTensor)) self.assertTrue(isinstance(c[1], torch.FloatTensor)) @@ -403,12 +415,17 @@ def __reduce__(self): old_x = old_cls(x) torch.save(old_x, f) f.seek(0) - load_x = torch.load(f) + load_x = torch.load(f, weights_only=weights_only) self.assertEqual(x.storage(), load_x.storage()) self.assertEqual(x.storage_offset(), load_x.storage_offset()) self.assertEqual(x.size(), load_x.size()) self.assertEqual(x.stride(), load_x.stride()) + def test_serialization_backwards_compat(self): + self._test_serialization_backwards_compat(False) + + def test_serialization_backwards_compat_safe(self): + self._test_serialization_backwards_compat(True) def test_serialization_save_warnings(self): with warnings.catch_warnings(record=True) as warns: @@ -680,25 +697,31 @@ def wrapper(*args, **kwargs): def __exit__(self, *args, **kwargs): torch.save = self.torch_save +@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") class TestBothSerialization(TestCase): - @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") - def test_serialization_new_format_old_format_compat(self, device): + def _test_serialization_new_format_old_format_compat(self, device, weights_only): x = [torch.ones(200, 200, device=device) for i in range(30)] def test(f_new, f_old): torch.save(x, f_new, _use_new_zipfile_serialization=True) f_new.seek(0) - x_new_load = torch.load(f_new) + x_new_load = torch.load(f_new, weights_only=weights_only) self.assertEqual(x, x_new_load) torch.save(x, f_old, _use_new_zipfile_serialization=False) f_old.seek(0) - x_old_load = torch.load(f_old) + x_old_load = torch.load(f_old, 
weights_only=weights_only) self.assertEqual(x_old_load, x_new_load) with tempfile.NamedTemporaryFile() as f_new, tempfile.NamedTemporaryFile() as f_old: test(f_new, f_old) + def test_serialization_new_format_old_format_compat(self, device): + self._test_serialization_new_format_old_format_compat(device, False) + + def test_serialization_new_format_old_format_compat_safe(self, device): + self._test_serialization_new_format_old_format_compat(device, True) + class TestOldSerialization(TestCase, SerializationMixin): # unique_key is necessary because on Python 2.7, if a warning passed to @@ -721,7 +744,7 @@ def import_module(name, filename): module = import_module(tmpmodule_name, fname) torch.save(module.Net(), checkpoint) - # First check that the checkpoint can be loaded without warnings + # First check that the checkpoint can be loaded without warning about unsafe loads checkpoint.seek(0) with warnings.catch_warnings(record=True) as w: loaded = torch.load(checkpoint) @@ -771,7 +794,8 @@ def test_serialization_offset(self): self.assertEqual(i, i_loaded) self.assertEqual(j, j_loaded) - def test_serialization_offset_filelike(self): + @parametrize('weights_only', (True, False)) + def test_serialization_offset_filelike(self, weights_only): a = torch.randn(5, 5) b = torch.randn(1024, 1024, 512, dtype=torch.float32) i, j = 41, 43 @@ -783,9 +807,9 @@ def test_serialization_offset_filelike(self): self.assertTrue(f.tell() > 2 * 1024 * 1024 * 1024) f.seek(0) i_loaded = pickle.load(f) - a_loaded = torch.load(f) + a_loaded = torch.load(f, weights_only=weights_only) j_loaded = pickle.load(f) - b_loaded = torch.load(f) + b_loaded = torch.load(f, weights_only=weights_only) self.assertTrue(torch.equal(a, a_loaded)) self.assertTrue(torch.equal(b, b_loaded)) self.assertEqual(i, i_loaded) @@ -797,7 +821,8 @@ def run(self, *args, **kwargs): class TestSerialization(TestCase, SerializationMixin): - def test_serialization_zipfile(self): + @parametrize('weights_only', (True, False)) + def 
test_serialization_zipfile(self, weights_only): data = self._test_serialization_data() def test(name_or_buffer): @@ -806,7 +831,7 @@ def test(name_or_buffer): if hasattr(name_or_buffer, 'seek'): name_or_buffer.seek(0) - result = torch.load(name_or_buffer) + result = torch.load(name_or_buffer, weights_only=weights_only) self.assertEqual(result, data) with tempfile.NamedTemporaryFile() as f: @@ -832,24 +857,40 @@ def test_serialization_2gb_file(self): f.seek(0) state = torch.load(f) - def test_pathlike_serialization(self): + @parametrize('weights_only', (True, False)) + def test_pathlike_serialization(self, weights_only): model = torch.nn.Conv2d(20, 3200, kernel_size=3) with TemporaryFileName() as fname: path = pathlib.Path(fname) torch.save(model.state_dict(), path) - torch.load(path) + torch.load(path, weights_only=weights_only) - def test_meta_serialization(self): + @parametrize('weights_only', (True, False)) + def test_meta_serialization(self, weights_only): big_model = torch.nn.Conv2d(20000, 320000, kernel_size=3, device='meta') with BytesIOContext() as f: torch.save(big_model.state_dict(), f) f.seek(0) - state = torch.load(f) + state = torch.load(f, weights_only=weights_only) self.assertEqual(state['weight'].size(), big_model.weight.size()) + def test_weights_only_assert(self): + class HelloWorld: + def __reduce__(self): + return (print, ("Hello World!",)) + + with BytesIOContext() as f: + torch.save(HelloWorld(), f) + f.seek(0) + # Unsafe load should work + self.assertIsNone(torch.load(f, weights_only=False)) + f.seek(0) + # Safe load should assert + with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"): + torch.load(f, weights_only=True) def run(self, *args, **kwargs): with serialization_method(use_zip=True): @@ -983,6 +1024,8 @@ def test_empty_class_serialization(self): instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_parametrized_tests(TestSubclassSerialization) 
+instantiate_parametrized_tests(TestOldSerialization) +instantiate_parametrized_tests(TestSerialization) if __name__ == '__main__': run_tests() diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py new file mode 100644 index 0000000000000..ee00db937fc3d --- /dev/null +++ b/torch/_weights_only_unpickler.py @@ -0,0 +1,288 @@ +# Unpickler restricted to loading only state dicts +# Restrict constructing types to a list defined in _get_allowed_globals() +# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only +# Restrict APPEND/APPENDS to `list` +# In `GLOBALS` operation do not do class lookup by name, but rather rely on dictionary +# defined by `_get_allowed_globals()` method, that contains: +# - torch types (Storage, dtypes, Tensor, `torch.Size`), +# - `torch._utils._rebuild` functions. +# - `torch.nn.Parameter` +# - `collections.OrderedDict` + +# Based of https://github.com/python/cpython/blob/main/Lib/pickle.py +# Expected to be useful for loading PyTorch model weights +# For example: +# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read() +# buf = io.BytesIO(data) +# weights = torch.load(buf, weights_only = True) + +import functools as _functools +from collections import OrderedDict +from pickle import ( + APPEND, + APPENDS, + BINGET, + BININT, + BININT1, + BININT2, + BINPERSID, + BINPUT, + BINUNICODE, + BUILD, + bytes_types, + decode_long, + EMPTY_DICT, + EMPTY_LIST, + EMPTY_SET, + EMPTY_TUPLE, + GLOBAL, + LONG1, + LONG_BINGET, + LONG_BINPUT, + MARK, + NEWFALSE, + NEWOBJ, + NEWTRUE, + NONE, + PROTO, + REDUCE, + SETITEM, + SETITEMS, + SHORT_BINSTRING, + STOP, + TUPLE, + TUPLE1, + TUPLE2, + TUPLE3, + UnpicklingError, +) +from struct import unpack +from sys import maxsize +from typing import Any, Dict, List + +import torch + + +# Unpickling machinery +@_functools.lru_cache(maxsize=1) +def _get_allowed_globals(): + rc: Dict[str, Any] = { + "collections.OrderedDict": 
OrderedDict, + "torch.nn.parameter.Parameter": torch.nn.Parameter, + "torch.serialization._get_layout": torch.serialization._get_layout, + "torch.Size": torch.Size, + "torch.Tensor": torch.Tensor, + } + # dtype + for t in [ + torch.complex32, + torch.complex64, + torch.complex128, + torch.float16, + torch.float32, + torch.float64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ]: + rc[str(t)] = t + # Tensor classes + for tt in torch._tensor_classes: + rc[f"{tt.__module__}.{tt.__name__}"] = tt + # Storage classes + for ts in torch._storage_classes: + rc[f"{ts.__module__}.{ts.__name__}"] = ts + # Rebuild functions + for f in [ + torch._utils._rebuild_parameter, + torch._utils._rebuild_tensor, + torch._utils._rebuild_tensor_v2, + torch._utils._rebuild_sparse_tensor, + torch._utils._rebuild_meta_tensor_no_storage, + torch._utils._rebuild_sparse_csr_tensor, + ]: + rc[f"torch._utils.{f.__name__}"] = f + return rc + + +class Unpickler: + def __init__(self, file, *, encoding: str = "bytes"): + self.encoding = encoding + self.readline = file.readline + self.read = file.read + self.memo: Dict[int, Any] = {} + + def load(self): + """Read a pickled object representation from the open file. + + Return the reconstituted object hierarchy specified in the file. 
+ """ + self.metastack = [] + self.stack: List[Any] = [] + self.append = self.stack.append + read = self.read + readline = self.readline + while True: + key = read(1) + if not key: + raise EOFError + assert isinstance(key, bytes_types) + # Risky operators + if key[0] == GLOBAL[0]: + module = readline()[:-1].decode("utf-8") + name = readline()[:-1].decode("utf-8") + full_path = f"{module}.{name}" + if full_path in _get_allowed_globals(): + self.append(_get_allowed_globals()[full_path]) + else: + raise RuntimeError(f"Unsupported class {full_path}") + elif key[0] == NEWOBJ[0]: + args = self.stack.pop() + cls = self.stack.pop() + if cls is not torch.nn.Parameter: + raise RuntimeError(f"Trying to instantiate unsupported class {cls}") + self.append(torch.nn.Parameter(*args)) + elif key[0] == REDUCE[0]: + args = self.stack.pop() + func = self.stack[-1] + if func not in _get_allowed_globals().values(): + raise RuntimeError( + f"Trying to call reduce for unrecognized function {func}" + ) + self.stack[-1] = func(*args) + elif key[0] == BUILD[0]: + state = self.stack.pop() + inst = self.stack[-1] + if type(inst) is torch.Tensor: + # Legacy unpickling + inst.set_(*state) + elif type(inst) is torch.nn.Parameter: + inst.__setstate__(state) + elif type(inst) is OrderedDict: + inst.__dict__.update(state) + else: + raise RuntimeError( + f"Can only build Tensor, parameter or dict objects, but got {type(inst)}" + ) + # Stack manipulation + elif key[0] == APPEND[0]: + item = self.stack.pop() + list_obj = self.stack[-1] + if type(list_obj) is not list: + raise RuntimeError( + f"Can only append to lists, but got {type(list_obj)}" + ) + list_obj.append(item) + elif key[0] == APPENDS[0]: + items = self.pop_mark() + list_obj = self.stack[-1] + if type(list_obj) is not list: + raise RuntimeError( + f"Can only extend lists, but got {type(list_obj)}" + ) + list_obj.extend(items) + elif key[0] == SETITEM[0]: + (v, k) = (self.stack.pop(), self.stack.pop()) + self.stack[-1][k] = v + elif key[0] 
== SETITEMS[0]: + items = self.pop_mark() + for i in range(0, len(items), 2): + self.stack[-1][items[i]] = items[i + 1] + elif key[0] == MARK[0]: + self.metastack.append(self.stack) + self.stack = [] + self.append = self.stack.append + elif key[0] == TUPLE[0]: + items = self.pop_mark() + self.append(tuple(items)) + elif key[0] == TUPLE1[0]: + self.stack[-1] = (self.stack[-1],) + elif key[0] == TUPLE2[0]: + self.stack[-2:] = [(self.stack[-2], self.stack[-1])] + elif key[0] == TUPLE3[0]: + self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] + # Basic types construction + elif key[0] == NONE[0]: + self.append(None) + elif key[0] == NEWFALSE[0]: + self.append(False) + elif key[0] == NEWTRUE[0]: + self.append(True) + elif key[0] == EMPTY_TUPLE[0]: + self.append(()) + elif key[0] == EMPTY_LIST[0]: + self.append([]) + elif key[0] == EMPTY_DICT[0]: + self.append({}) + elif key[0] == EMPTY_SET[0]: + self.append(set()) + elif key[0] == BININT[0]: + self.append(unpack(" maxsize: + raise RuntimeError("String is too long") + strval = str(read(strlen), "utf-8", "surrogatepass") + self.append(strval) + elif key[0] == SHORT_BINSTRING[0]: + strlen = read(1)[0] + strdata = read(strlen) + if self.encoding != "bytes": + strdata = strdata.decode(self.encoding, "strict") + self.append(strdata) + elif key[0] == BINPERSID[0]: + pid = self.stack.pop() + # Only allow persistent load of storage + if type(pid) is not tuple and not type(pid) is not int: + raise RuntimeError( + f"persistent_load id must be tuple or int, but got {type(pid)}" + ) + if ( + type(pid) is tuple + and len(pid) > 0 + and torch.serialization._maybe_decode_ascii(pid[0]) != "storage" + ): + raise RuntimeError( + f"Only persistent_load of storage is allowed, but got {pid[0]}" + ) + self.append(self.persistent_load(pid)) + elif key[0] in [BINGET[0], LONG_BINGET[0]]: + idx = (read(1) if key[0] == BINGET[0] else unpack(" None: pickle_module: module used for pickling metadata and objects ''' - if 
pickle_module.__name__ == 'dill': + if pickle_module is not None and pickle_module.__name__ == 'dill': required_dill_version = (0, 3, 1) if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): raise ValueError(( @@ -652,7 +653,9 @@ def persistent_id(obj): def load( f: FILE_LIKE, map_location: MAP_LOCATION = None, - pickle_module: Any = pickle, + pickle_module: Any = None, + *, + weights_only: bool = False, **pickle_load_args: Any ) -> Any: # Reference: https://github.com/pytorch/pytorch/issues/54354 @@ -660,7 +663,7 @@ def load( # documentation. We need it so that Sphinx doesn't leak `pickle`s path from # the build environment (e.g. `>> torch.load('module.pt', encoding='ascii') """ + UNSAFE_MESSAGE = ( + "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" + " will likely succeed, but it can result in arbitrary code execution." + "Do it only if you get the file from a trusted source. WeightsUnpickler error: " + ) + # Add ability to force safe only weight loads via environment variable + if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: + weights_only = True + + if weights_only: + if pickle_module is not None: + raise RuntimeError("Can not safely load weights when expiclit picke_module is specified") + else: + pickle_module = pickle + _check_dill_version(pickle_module) if 'encoding' not in pickle_load_args.keys(): @@ -760,7 +781,17 @@ def load( " silence this warning)", UserWarning) opened_file.seek(orig_position) return torch.jit.load(opened_file, map_location=map_location) + if weights_only: + try: + return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args) + except RuntimeError as e: + raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args) + if weights_only: + try: + return _legacy_load(opened_file, map_location, _weights_only_unpickler, 
**pickle_load_args) + except RuntimeError as e: + raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)