diff --git a/caffe2/python/compatibility.py b/caffe2/python/compatibility.py
new file mode 100644
index 000000000000000..9d615a30833371a
--- /dev/null
+++ b/caffe2/python/compatibility.py
@@ -0,0 +1,8 @@
+from six import PY2, PY3
+
+if PY2:
+    import collections
+    container_abcs = collections
+elif PY3:
+    import collections.abc
+    container_abcs = collections.abc
diff --git a/caffe2/python/onnx/backend.py b/caffe2/python/onnx/backend.py
index 3d9239c8b5c92bd..7eacaf327ad2643 100644
--- a/caffe2/python/onnx/backend.py
+++ b/caffe2/python/onnx/backend.py
@@ -25,6 +25,7 @@
 import caffe2
 from caffe2.python import core, workspace, rnn_cell, gru_cell
+from caffe2.python.compatibility import container_abcs
 from caffe2.python.model_helper import ModelHelper
 from caffe2.proto import caffe2_pb2
 import caffe2.python.utils
 
@@ -778,7 +779,7 @@ def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version
         ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version)
         if isinstance(ops, Caffe2Ops):
             return ops
-        if not isinstance(ops, collections.Iterable):
+        if not isinstance(ops, container_abcs.Iterable):
             ops = [ops]
         return Caffe2Ops(ops, [], [])
 
diff --git a/caffe2/python/onnx/frontend.py b/caffe2/python/onnx/frontend.py
index 5fd470c932ac59c..379ef65af904a66 100644
--- a/caffe2/python/onnx/frontend.py
+++ b/caffe2/python/onnx/frontend.py
@@ -12,11 +12,11 @@
 from __future__ import unicode_literals
 
 import itertools
-import collections
 import logging
 import re
 
 from caffe2.python import core as caffe2_core
+from caffe2.python.compatibility import container_abcs
 from caffe2.proto import caffe2_legacy_pb2
 from enum import Enum
 from onnx import (defs, checker, helper, numpy_helper, mapping,
@@ -156,7 +156,7 @@ def caffe2_op_to_onnx_node(cls, op_def, shapes):
         const_tensors = []
         if isinstance(nodes, tuple):
             nodes, const_tensors = nodes
-        if not isinstance(nodes, collections.Iterable):
+        if not isinstance(nodes, container_abcs.Iterable):
             nodes = [nodes]
         return nodes, const_tensors
 
diff --git a/caffe2/python/utils.py b/caffe2/python/utils.py
index 75124add41cecd1..5e87df8058e0177 100644
--- a/caffe2/python/utils.py
+++ b/caffe2/python/utils.py
@@ -6,13 +6,13 @@
 from __future__ import unicode_literals
 
 from caffe2.proto import caffe2_pb2
+from caffe2.python.compatibility import container_abcs
 from future.utils import viewitems
 from google.protobuf.message import DecodeError, Message
 from google.protobuf import text_format
 import sys
 import copy
-import collections
 import functools
 import numpy as np
 from six import integer_types, binary_type, text_type, string_types
@@ -120,7 +120,7 @@ def MakeArgument(key, value):
     """Makes an argument based on the value type."""
     argument = caffe2_pb2.Argument()
     argument.name = key
-    iterable = isinstance(value, collections.Iterable)
+    iterable = isinstance(value, container_abcs.Iterable)
 
     # Fast tracking common use case where a float32 array of tensor parameters
     # needs to be serialized.  The entire array is guaranteed to have the same
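Up to this point the patch covers the Caffe2 side: a new caffe2/python/compatibility.py shim plus three call sites that switch their isinstance checks to it. Background: Python 3.3 moved the container ABCs into collections.abc, later 3.x releases emit a DeprecationWarning when they are reached through the collections module, and Python 3.10 removed the old aliases outright. Below is a minimal sketch of what the shim provides, assuming a patched caffe2 package is importable; the prints are illustrative only.

# Minimal sketch: container_abcs is the collections module itself on
# Python 2 and collections.abc on Python 3, so the ABC names below
# resolve on both without any deprecation warning.
from caffe2.python.compatibility import container_abcs

print(container_abcs.__name__)                  # 'collections.abc' on Python 3
print(isinstance([], container_abcs.Iterable))  # True
print(isinstance({}, container_abcs.Mapping))   # True
print(isinstance((), container_abcs.Sequence))  # True
print(isinstance(5, container_abcs.Iterable))   # False; MakeArgument treats
                                                # such values as scalars

The remaining hunks apply the identical substitution on the torch side, using the shim added to torch/_six.py.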
diff --git a/test/test_legacy_nn.py b/test/test_legacy_nn.py
index b446920c4fec650..f3a807a0a6d6409 100644
--- a/test/test_legacy_nn.py
+++ b/test/test_legacy_nn.py
@@ -1,10 +1,10 @@
 import math
 import random
 import unittest
-import collections
 from copy import deepcopy
 
 import torch
+from torch._six import container_abcs
 import torch.legacy.nn as nn
 from common import to_gpu, freeze_rng_state, run_tests, skipIfRocm, TEST_WITH_ROCM
 from common_nn import NNTestCase, ModuleTest, CriterionTest, iter_tensors, \
@@ -701,7 +701,7 @@ def require_grad(input):
         input = input.detach()
         input.requires_grad = True
         return input
-    elif isinstance(input, collections.Iterable):
+    elif isinstance(input, container_abcs.Iterable):
         return type(input)(require_grad(e) for e in input)
     return input
 
diff --git a/torch/_six.py b/torch/_six.py
index 1d70df51830d5e3..84ba9a464891bb3 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -108,3 +108,10 @@ def exec_(_code_, _globs_=None, _locs_=None):
 else:
     def raise_from(value, from_value):
         raise value
+
+if PY2:
+    import collections
+    container_abcs = collections
+elif PY3:
+    import collections.abc
+    container_abcs = collections.abc
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 2cc4ebbfacd4a34..26dc9daf4a73506 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -1,5 +1,5 @@
 import torch
-from collections import Iterable
+from torch._six import container_abcs
 import torch.testing
 import sys
 from itertools import product
@@ -11,7 +11,7 @@ def zero_gradients(x):
         if x.grad is not None:
             x.grad.detach_()
             x.grad.data.zero_()
-    elif isinstance(x, Iterable):
+    elif isinstance(x, container_abcs.Iterable):
         for elem in x:
             zero_gradients(elem)
 
@@ -23,7 +23,7 @@ def make_jacobian(input, num_out):
         if not input.requires_grad:
             return None
         return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
-    elif isinstance(input, Iterable):
+    elif isinstance(input, container_abcs.Iterable):
         jacobians = list(filter(
             lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
         if not jacobians:
@@ -37,7 +37,7 @@ def iter_tensors(x, only_requiring_grad=False):
     if isinstance(x, torch.Tensor):
         if x.requires_grad or not only_requiring_grad:
             yield x
-    elif isinstance(x, Iterable):
+    elif isinstance(x, container_abcs.Iterable):
         for elem in x:
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index b32edf715a6c004..e4543a6be12b87b 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -6,7 +6,7 @@
 import torch.jit.annotations
 from torch._six import raise_from, with_metaclass
 import torch.testing
-from collections import defaultdict, OrderedDict, namedtuple, Iterable
+from collections import defaultdict, OrderedDict, namedtuple
 import sys
 import warnings
 import itertools
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index ef79b1ffb6a289d..01e12e621ba4ac3 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -1,5 +1,6 @@
 import warnings
-from collections import OrderedDict, Iterable, Mapping
+from collections import OrderedDict
+from torch._six import container_abcs
 from itertools import islice
 import operator
 
@@ -178,7 +179,7 @@ def extend(self, modules):
         Arguments:
             modules (iterable): iterable of modules to append
         """
-        if not isinstance(modules, Iterable):
+        if not isinstance(modules, container_abcs.Iterable):
             raise TypeError("ModuleList.extend should be called with an "
TypeError("ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__) offset = len(self) @@ -278,12 +279,12 @@ def update(self, modules): modules (iterable): a mapping (dictionary) of (string: :class:`~torch.nn.Module``) or an iterable of key/value pairs of type (string, :class:`~torch.nn.Module``) """ - if not isinstance(modules, Iterable): + if not isinstance(modules, container_abcs.Iterable): raise TypeError("ModuleDict.update should be called with an " "iterable of key/value pairs, but got " + type(modules).__name__) - if isinstance(modules, Mapping): + if isinstance(modules, container_abcs.Mapping): if isinstance(modules, OrderedDict): for key, module in modules.items(): self[key] = module @@ -292,7 +293,7 @@ def update(self, modules): self[key] = module else: for j, m in enumerate(modules): - if not isinstance(m, Iterable): + if not isinstance(m, container_abcs.Iterable): raise TypeError("ModuleDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(m).__name__) @@ -375,7 +376,7 @@ def extend(self, parameters): Arguments: parameters (iterable): iterable of parameters to append """ - if not isinstance(parameters, Iterable): + if not isinstance(parameters, container_abcs.Iterable): raise TypeError("ParameterList.extend should be called with an " "iterable, but got " + type(parameters).__name__) offset = len(self) @@ -483,12 +484,12 @@ def update(self, parameters): (string : :class:`~torch.nn.Parameter`) or an iterable of key/value pairs of type (string, :class:`~torch.nn.Parameter`) """ - if not isinstance(parameters, Iterable): + if not isinstance(parameters, container_abcs.Iterable): raise TypeError("ParametersDict.update should be called with an " "iterable of key/value pairs, but got " + type(parameters).__name__) - if isinstance(parameters, Mapping): + if isinstance(parameters, container_abcs.Mapping): if isinstance(parameters, OrderedDict): for key, parameter in parameters.items(): self[key] = parameter @@ -497,7 +498,7 @@ def update(self, parameters): self[key] = parameter else: for j, p in enumerate(parameters): - if not isinstance(p, Iterable): + if not isinstance(p, container_abcs.Iterable): raise TypeError("ParameterDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(p).__name__) diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3cff6a9e9ffba92..2b8ebd642b000a6 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,10 +1,10 @@ -import collections +from torch._six import container_abcs from itertools import repeat def _ntuple(n): def parse(x): - if isinstance(x, collections.Iterable): + if isinstance(x, container_abcs.Iterable): return x return tuple(repeat(x, n)) return parse diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 963e0bc95912556..d027267053052a2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -9,7 +9,7 @@ import torch.autograd import torch.serialization import re -import collections +from torch._six import container_abcs import contextlib import numbers import warnings @@ -354,7 +354,7 @@ def _run_symbolic_method(op_name, symbolic_fn, args): def _is_onnx_list(value): if not isinstance(value, string_classes) and \ not isinstance(value, torch.Tensor) and \ - isinstance(value, collections.Iterable): + isinstance(value, container_abcs.Iterable): return True return False diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 21f69bb82804d38..41c1e916f4d8d11 100644 --- a/torch/optim/optimizer.py +++ 
@@ -1,4 +1,5 @@
-from collections import defaultdict, Iterable
+from collections import defaultdict
+from torch._six import container_abcs
 import torch
 from copy import deepcopy
 
@@ -123,7 +124,7 @@ def cast(param, value):
                 return value
             elif isinstance(value, dict):
                 return {k: cast(param, v) for k, v in value.items()}
-            elif isinstance(value, Iterable):
+            elif isinstance(value, container_abcs.Iterable):
                 return type(value)(cast(param, v) for v in value)
             else:
                 return value
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 9d69ab4daf0fc85..0874fd1185028d7 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -6,7 +6,7 @@
 from . import SequentialSampler, RandomSampler, BatchSampler
 import signal
 import functools
-import collections
+from torch._six import container_abcs
 import re
 import sys
 import threading
@@ -187,9 +187,9 @@ def default_collate(batch):
         return torch.DoubleTensor(batch)
     elif isinstance(batch[0], string_classes):
         return batch
-    elif isinstance(batch[0], collections.Mapping):
+    elif isinstance(batch[0], container_abcs.Mapping):
         return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
-    elif isinstance(batch[0], collections.Sequence):
+    elif isinstance(batch[0], container_abcs.Sequence):
         transposed = zip(*batch)
         return [default_collate(samples) for samples in transposed]
 
@@ -201,9 +201,9 @@ def pin_memory_batch(batch):
         return batch.pin_memory()
     elif isinstance(batch, string_classes):
         return batch
-    elif isinstance(batch, collections.Mapping):
+    elif isinstance(batch, container_abcs.Mapping):
         return {k: pin_memory_batch(sample) for k, sample in batch.items()}
-    elif isinstance(batch, collections.Sequence):
+    elif isinstance(batch, container_abcs.Sequence):
         return [pin_memory_batch(sample) for sample in batch]
     else:
         return batch
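Every hunk above follows one pattern: dispatch on the abstract container type (Iterable, Mapping, Sequence) rather than on a concrete type. A closing sketch of that pattern; collate_shapes is a hypothetical helper, not part of the patch, but it follows the same dispatch order as default_collate above (strings before Sequence, because str is itself a Sequence and would otherwise recurse forever).

from torch._six import container_abcs


def collate_shapes(batch):
    # Hypothetical helper mirroring default_collate's structure:
    # Mapping recurses per key, Sequence transposes the batch, and
    # anything else is treated as a leaf and returned unchanged.
    first = batch[0]
    if isinstance(first, str):  # strings are Sequences: stop here
        return batch
    if isinstance(first, container_abcs.Mapping):
        return {key: collate_shapes([d[key] for d in batch]) for key in first}
    if isinstance(first, container_abcs.Sequence):
        return [collate_shapes(list(samples)) for samples in zip(*batch)]
    return batch


print(collate_shapes([{"x": (1, 2)}, {"x": (3, 4)}]))
# -> {'x': [[1, 3], [2, 4]]}

Because the shim resolves to the collections module on Python 2 and to collections.abc on Python 3, this sketch, like the patched call sites, runs warning-free on both interpreters.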