migrating deprecated calls without abc module for containers (#11515)
Summary:
Implementing #10540.
Pull Request resolved: #11515

Reviewed By: apaszke

Differential Revision: D9771045

Pulled By: jeffreyksmithjr

fbshipit-source-id: 85ea39abaa9b465805a969f122b626b11fc85ef6
jeffreyksmithjr authored and facebook-github-bot committed Sep 13, 2018
1 parent 29e29ca commit 05e06f7
Showing 13 changed files with 50 additions and 32 deletions.
8 changes: 8 additions & 0 deletions caffe2/python/compatibility.py
@@ -0,0 +1,8 @@
+from six import PY2, PY3
+
+if PY2:
+    import collections
+    container_abcs = collections
+elif PY3:
+    import collections.abc
+    container_abcs = collections.abc
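
For illustration (not part of the diff): on Python 3 the container ABCs live in collections.abc, and reaching them through the top-level collections module was deprecated in 3.3 (the aliases were eventually removed in 3.10). The shim above gives both majors a single spelling. A minimal standalone sketch of the same idea:

import sys

if sys.version_info[0] == 2:
    import collections as container_abcs        # Py2: ABCs sit on collections itself
else:
    import collections.abc as container_abcs    # Py3: canonical location

# One spelling that works on both majors, without a DeprecationWarning on Py3:
print(isinstance([1, 2, 3], container_abcs.Iterable))   # True
print(isinstance(42, container_abcs.Iterable))          # False
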
3 changes: 2 additions & 1 deletion caffe2/python/onnx/backend.py
@@ -25,6 +25,7 @@

import caffe2
from caffe2.python import core, workspace, rnn_cell, gru_cell
+from caffe2.python.compatibility import container_abcs
from caffe2.python.model_helper import ModelHelper
from caffe2.proto import caffe2_pb2
import caffe2.python.utils
@@ -778,7 +779,7 @@ def _onnx_node_to_caffe2_op(cls, init_model, pred_model, node_def, opset_version
        ops = translator(init_model, pred_model, OnnxNode(node_def), opset_version)
        if isinstance(ops, Caffe2Ops):
            return ops
-        if not isinstance(ops, collections.Iterable):
+        if not isinstance(ops, container_abcs.Iterable):
            ops = [ops]
        return Caffe2Ops(ops, [], [])
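
The hunk above is one instance of the normalize-to-list idiom this commit touches everywhere: a translator may return a single op or a sequence of ops, and the Iterable check wraps the single-op case. A standalone sketch of the idiom (illustrative names, not code from the diff):

from collections import abc as container_abcs   # Py3 spelling of the shim

def as_list(value):
    # Wrap a bare value; pass any existing iterable through as a list.
    if not isinstance(value, container_abcs.Iterable):
        return [value]
    return list(value)

assert as_list(3) == [3]
assert as_list((1, 2)) == [1, 2]
assert as_list("ab") == ["a", "b"]   # caveat: strings are iterable too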

4 changes: 2 additions & 2 deletions caffe2/python/onnx/frontend.py
@@ -12,11 +12,11 @@
from __future__ import unicode_literals

import itertools
-import collections
import logging
import re

from caffe2.python import core as caffe2_core
+from caffe2.python.compatibility import container_abcs
from caffe2.proto import caffe2_legacy_pb2
from enum import Enum
from onnx import (defs, checker, helper, numpy_helper, mapping,
@@ -156,7 +156,7 @@ def caffe2_op_to_onnx_node(cls, op_def, shapes):
        const_tensors = []
        if isinstance(nodes, tuple):
            nodes, const_tensors = nodes
-        if not isinstance(nodes, collections.Iterable):
+        if not isinstance(nodes, container_abcs.Iterable):
            nodes = [nodes]
        return nodes, const_tensors

4 changes: 2 additions & 2 deletions caffe2/python/utils.py
@@ -6,13 +6,13 @@
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2
+from caffe2.python.compatibility import container_abcs
from future.utils import viewitems
from google.protobuf.message import DecodeError, Message
from google.protobuf import text_format

import sys
import copy
-import collections
import functools
import numpy as np
from six import integer_types, binary_type, text_type, string_types
@@ -120,7 +120,7 @@ def MakeArgument(key, value):
"""Makes an argument based on the value type."""
argument = caffe2_pb2.Argument()
argument.name = key
iterable = isinstance(value, collections.Iterable)
iterable = isinstance(value, container_abcs.Iterable)

# Fast tracking common use case where a float32 array of tensor parameters
# needs to be serialized. The entire array is guaranteed to have the same
4 changes: 2 additions & 2 deletions test/test_legacy_nn.py
@@ -1,10 +1,10 @@
import math
import random
import unittest
-import collections
from copy import deepcopy

import torch
+from torch._six import container_abcs
import torch.legacy.nn as nn
from common import to_gpu, freeze_rng_state, run_tests, skipIfRocm, TEST_WITH_ROCM
from common_nn import NNTestCase, ModuleTest, CriterionTest, iter_tensors, \
@@ -701,7 +701,7 @@ def require_grad(input):
        input = input.detach()
        input.requires_grad = True
        return input
-    elif isinstance(input, collections.Iterable):
+    elif isinstance(input, container_abcs.Iterable):
        return type(input)(require_grad(e) for e in input)
    return input

7 changes: 7 additions & 0 deletions torch/_six.py
@@ -108,3 +108,10 @@ def exec_(_code_, _globs_=None, _locs_=None):
else:
    def raise_from(value, from_value):
        raise value
+
+if PY2:
+    import collections
+    container_abcs = collections
+elif PY3:
+    import collections.abc
+    container_abcs = collections.abc
8 changes: 4 additions & 4 deletions torch/autograd/gradcheck.py
@@ -1,5 +1,5 @@
import torch
-from collections import Iterable
+from torch._six import container_abcs
import torch.testing
import sys
from itertools import product
@@ -11,7 +11,7 @@ def zero_gradients(x):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.data.zero_()
-    elif isinstance(x, Iterable):
+    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            zero_gradients(elem)
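
For context, the Iterable branch is what lets zero_gradients recurse through nested containers of tensors. A usage sketch, assuming a PyTorch build from around this commit (the helper was later removed from torch.autograd.gradcheck):

import torch
from torch.autograd.gradcheck import zero_gradients

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
(x * y).sum().backward()    # now x.grad == y and y.grad == x
zero_gradients([x, y])      # the list triggers the Iterable branch
print(x.grad)               # tensor([0., 0., 0.])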

@@ -23,7 +23,7 @@ def make_jacobian(input, num_out):
        if not input.requires_grad:
            return None
        return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
-    elif isinstance(input, Iterable):
+    elif isinstance(input, container_abcs.Iterable):
        jacobians = list(filter(
            lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
        if not jacobians:
@@ -37,7 +37,7 @@ def iter_tensors(x, only_requiring_grad=False):
    if isinstance(x, torch.Tensor):
        if x.requires_grad or not only_requiring_grad:
            yield x
-    elif isinstance(x, Iterable):
+    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            for result in iter_tensors(elem, only_requiring_grad):
                yield result
2 changes: 1 addition & 1 deletion torch/jit/__init__.py
@@ -6,7 +6,7 @@
import torch.jit.annotations
from torch._six import raise_from, with_metaclass
import torch.testing
-from collections import defaultdict, OrderedDict, namedtuple, Iterable
+from collections import defaultdict, OrderedDict, namedtuple
import sys
import warnings
import itertools
19 changes: 10 additions & 9 deletions torch/nn/modules/container.py
@@ -1,5 +1,6 @@
import warnings
-from collections import OrderedDict, Iterable, Mapping
+from collections import OrderedDict
+from torch._six import container_abcs
from itertools import islice
import operator

@@ -178,7 +179,7 @@ def extend(self, modules):
        Arguments:
            modules (iterable): iterable of modules to append
        """
-        if not isinstance(modules, Iterable):
+        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
@@ -278,12 +279,12 @@ def update(self, modules):
            modules (iterable): a mapping (dictionary) of (string: :class:`~torch.nn.Module`) or
                an iterable of key/value pairs of type (string, :class:`~torch.nn.Module`)
        """
-        if not isinstance(modules, Iterable):
+        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

-        if isinstance(modules, Mapping):
+        if isinstance(modules, container_abcs.Mapping):
            if isinstance(modules, OrderedDict):
                for key, module in modules.items():
                    self[key] = module
@@ -292,7 +293,7 @@ def update(self, modules):
                    self[key] = module
        else:
            for j, m in enumerate(modules):
-                if not isinstance(m, Iterable):
+                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(m).__name__)
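
The two isinstance checks above split ModuleDict.update into a Mapping branch and an iterable-of-pairs branch. A usage sketch exercising both (illustrative module names):

import torch.nn as nn

d = nn.ModuleDict({'conv': nn.Conv2d(3, 8, 3)})
d.update({'relu': nn.ReLU()})            # Mapping branch
d.update([('pool', nn.MaxPool2d(2))])    # iterable-of-pairs branch
assert set(d.keys()) == {'conv', 'relu', 'pool'}
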
@@ -375,7 +376,7 @@ def extend(self, parameters):
        Arguments:
            parameters (iterable): iterable of parameters to append
        """
-        if not isinstance(parameters, Iterable):
+        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(parameters).__name__)
        offset = len(self)
@@ -483,12 +484,12 @@ def update(self, parameters):
                (string : :class:`~torch.nn.Parameter`) or an iterable of
                key/value pairs of type (string, :class:`~torch.nn.Parameter`)
        """
-        if not isinstance(parameters, Iterable):
+        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParametersDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(parameters).__name__)

-        if isinstance(parameters, Mapping):
+        if isinstance(parameters, container_abcs.Mapping):
            if isinstance(parameters, OrderedDict):
                for key, parameter in parameters.items():
                    self[key] = parameter
@@ -497,7 +498,7 @@ def update(self, parameters):
                    self[key] = parameter
        else:
            for j, p in enumerate(parameters):
-                if not isinstance(p, Iterable):
+                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError("ParameterDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is" +
                                    type(p).__name__)
4 changes: 2 additions & 2 deletions torch/nn/modules/utils.py
@@ -1,10 +1,10 @@
-import collections
+from torch._six import container_abcs
from itertools import repeat


def _ntuple(n):
    def parse(x):
-        if isinstance(x, collections.Iterable):
+        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse
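
_ntuple is the factory behind the _single/_pair/_triple helpers that normalize arguments such as kernel sizes and strides; the Iterable check is what lets a scalar and a tuple be passed interchangeably. A self-contained sketch (using the Py3 spelling in place of the torch._six shim):

from itertools import repeat
from collections import abc as container_abcs

def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x                    # tuples and lists pass through unchanged
        return tuple(repeat(x, n))      # scalars are broadcast to an n-tuple
    return parse

_pair = _ntuple(2)
assert _pair(3) == (3, 3)
assert _pair((1, 2)) == (1, 2)
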
4 changes: 2 additions & 2 deletions torch/onnx/utils.py
@@ -9,7 +9,7 @@
import torch.autograd
import torch.serialization
import re
-import collections
+from torch._six import container_abcs
import contextlib
import numbers
import warnings
@@ -354,7 +354,7 @@ def _run_symbolic_method(op_name, symbolic_fn, args):
def _is_onnx_list(value):
    if not isinstance(value, string_classes) and \
            not isinstance(value, torch.Tensor) and \
-            isinstance(value, collections.Iterable):
+            isinstance(value, container_abcs.Iterable):
        return True
    return False

5 changes: 3 additions & 2 deletions torch/optim/optimizer.py
@@ -1,4 +1,5 @@
-from collections import defaultdict, Iterable
+from collections import defaultdict
+from torch._six import container_abcs

import torch
from copy import deepcopy
@@ -123,7 +124,7 @@ def cast(param, value):
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
-            elif isinstance(value, Iterable):
+            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value
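
The cast helper above is a structure-preserving recursive map: dicts are rebuilt key by key, other iterables are rebuilt via type(value)(...). A generic sketch of that idiom (illustrative, not the optimizer code itself):

from collections import abc as container_abcs

def deep_map(fn, value):
    # Apply fn to leaves, rebuilding dicts and other containers around them.
    if isinstance(value, dict):
        return {k: deep_map(fn, v) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable) and not isinstance(value, str):
        return type(value)(deep_map(fn, v) for v in value)
    return fn(value)

assert deep_map(lambda v: v + 1, {'a': [1, (2, 3)]}) == {'a': [2, (3, 4)]}
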
10 changes: 5 additions & 5 deletions torch/utils/data/dataloader.py
@@ -6,7 +6,7 @@
from . import SequentialSampler, RandomSampler, BatchSampler
import signal
import functools
-import collections
+from torch._six import container_abcs
import re
import sys
import threading
@@ -187,9 +187,9 @@ def default_collate(batch):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
-    elif isinstance(batch[0], collections.Mapping):
+    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
-    elif isinstance(batch[0], collections.Sequence):
+    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
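
The Mapping and Sequence branches above are what let default_collate assemble batches of dicts and of sequences. A usage sketch, assuming a PyTorch build from around this commit (default_collate lived in torch.utils.data.dataloader at the time):

import torch
from torch.utils.data.dataloader import default_collate

batch = [{'x': torch.tensor([1.0]), 'y': 0},
         {'x': torch.tensor([2.0]), 'y': 1}]
out = default_collate(batch)    # Mapping branch collates per key
print(out['x'])                 # tensor([[1.], [2.]])
print(out['y'])                 # tensor([0, 1])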

@@ -201,9 +201,9 @@ def pin_memory_batch(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
-    elif isinstance(batch, collections.Mapping):
+    elif isinstance(batch, container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
-    elif isinstance(batch, collections.Sequence):
+    elif isinstance(batch, container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
