From 5f875965c6d3c042c6bd8039135e8fe76cf9267b Mon Sep 17 00:00:00 2001 From: Erjia Guan Date: Tue, 5 Jan 2021 18:46:06 -0800 Subject: [PATCH 01/20] Fix doc for vmap levels (#50099) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50099 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D25783257 Pulled By: ejguan fbshipit-source-id: 7d2c7614f87e1c8adc8aefe3fe312b6c98ff6788 --- aten/src/ATen/VmapTransforms.h | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index 5063beeb08b0..8fa085245459 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap; // The levels bitset specifies which vmap levels correspond to the batch // dimensions at the front of the tensor. In particular, the number of set bits // corresponds to the number of batch dimensions on `tensor` and the rightmost -// bit of `levels` specifies the minimum number of nested vmaps we are in at +// bit of `levels` specifies the maximum number of nested vmaps we are in at // this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 struct TORCH_API VmapPhysicalView { VmapPhysicalView(Tensor&& tensor, std::bitset levels) : levels_(levels), tensor_(tensor) { From 574a15b6ccac1f25945926804facb0e677543fce Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 5 Jan 2021 18:57:15 -0800 Subject: [PATCH 02/20] [PyTorch] Reapply D25544731: Avoid extra Tensor refcounting in _cat_out_cpu (#49760) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49760 This was reverted because it landed in a stack together with D25542799 (https://github.com/pytorch/pytorch/commit/9ce1df079f6ea90dd4b7f9aa12a1a78d51a8b204), which really was broken. ghstack-source-id: 119361028 Test Plan: CI Reviewed By: bwasti Differential Revision: D25685789 fbshipit-source-id: 41e5abb4ff30acaa6f33f9c806acd652a6dd9646 --- aten/src/ATen/native/TensorShape.cpp | 34 ++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 09d50356abd9..6de79b2b8b53 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -106,15 +106,17 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor } } +static bool should_skip(const Tensor& t) { + return t.numel() == 0 && t.dim() == 1; +} + Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific // size (i.e. other empty sizes are not skipped). - // FIXME: warn if this is the case - bool allSkipped = true; + bool allContiguous = true; - Tensor notSkippedTensor; // Inputs cannot alias the output tensor for (int64_t i = 0; i < tensors.size(); i++) { @@ -126,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { } at::assert_no_internal_overlap(result); - auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; }; - for (auto const &tensor : tensors) { - if (should_skip(tensor)) { - continue; + const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* { + for (auto const &tensor : tensors) { + if (should_skip(tensor)) { + continue; + } + // we've found a non-empty tensor + return &tensor; } - // we've found a non-empty tensor - allSkipped = false; - notSkippedTensor = tensor; - break; - } - if (allSkipped) { + return nullptr; + }(tensors); + + if (!pnotSkippedTensor) { + // FIXME: warn if this is the case -- see comment about skipped + // tensors at top of function. return result; } + const Tensor& notSkippedTensor = *pnotSkippedTensor; TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors"); TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range"); @@ -196,7 +202,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (reuse_iterator && result.is_contiguous(first_tensor_mem_format) && no_type_promotion) { - auto source_slice = notSkippedTensor; + const auto& source_slice = notSkippedTensor; auto slice_dim_size = source_slice.size(dim); auto result_slice = result.narrow(dim, 0, slice_dim_size); auto result_slice_data = result_slice.data_ptr(); From 75028f28e1be4ee1b259218a48ebe38a4e22adca Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 5 Jan 2021 18:57:15 -0800 Subject: [PATCH 03/20] [PyTorch] Reapply D25545777: Use .sizes() instead of .size() in _cat_out_cpu (#49761) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49761 This was reverted because it landed in a stack together with D25542799 (https://github.com/pytorch/pytorch/commit/9ce1df079f6ea90dd4b7f9aa12a1a78d51a8b204), which really was broken. ghstack-source-id: 119361027 Test Plan: CI Reviewed By: bwasti Differential Revision: D25685855 fbshipit-source-id: b51f67ebe667199d15bfc6f8f131a6f1ab1b0352 --- aten/src/ATen/native/TensorShape.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6de79b2b8b53..d1fadd58d38d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -98,8 +98,8 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor if (dim == dimension) { continue; } - int64_t first_dim_size = first.size(dim); - int64_t second_dim_size = second.size(dim); + int64_t first_dim_size = first.sizes()[dim]; + int64_t second_dim_size = second.sizes()[dim]; TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ", dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim, " (The offending index is ", index, ")"); @@ -167,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { continue; } check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i); - cat_dim_size += tensor.size(dim); + cat_dim_size += tensor.sizes()[dim]; if (!tensor.is_contiguous(first_tensor_mem_format)) { allContiguous = false; @@ -203,7 +203,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { result.is_contiguous(first_tensor_mem_format) && no_type_promotion) { const auto& source_slice = notSkippedTensor; - auto slice_dim_size = source_slice.size(dim); + auto slice_dim_size = source_slice.sizes()[dim]; auto result_slice = result.narrow(dim, 0, slice_dim_size); auto result_slice_data = result_slice.data_ptr(); auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type()); @@ -232,7 +232,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) { if (should_skip(tensor)) { continue; } - auto slice_dim_size = tensor.size(dim); + auto slice_dim_size = tensor.sizes()[dim]; auto result_slice = result.narrow(dim, offset, slice_dim_size); auto iter = TensorIteratorConfig() From d80d38cf87a527fded154293038d252d595c7100 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 5 Jan 2021 19:02:38 -0800 Subject: [PATCH 04/20] Clean up type annotations in caffe2/torch/nn/modules (#49957) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49957 Test Plan: Sandcastle Reviewed By: xush6528 Differential Revision: D25729745 fbshipit-source-id: 85810e2c18ca6856480bef81217da1359b63d8a3 --- torch/nn/modules/activation.py | 5 ++--- torch/nn/modules/conv.py | 5 +++-- torch/nn/modules/utils.py | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 073c95c28619..837ecca6fe9d 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -922,9 +922,8 @@ def __setstate__(self, state): super(MultiheadAttention, self).__setstate__(state) - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: query, key, value: map a query and a set of key-value pairs to an output. diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index f22c35fa39ff..6a9c4dcd2ef6 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -530,8 +530,9 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, # dilation being an optional parameter is for backwards # compatibility - def _output_padding(self, input, output_size, stride, padding, kernel_size, dilation=None): - # type: (Tensor, Optional[List[int]], List[int], List[int], List[int], Optional[List[int]]) -> List[int] + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + dilation: Optional[List[int]] = None) -> List[int]: if output_size is None: ret = _single(self.output_padding) # converting to list if was not already else: diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 3e0b93c7afc0..97e4195619cb 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -26,8 +26,7 @@ def _reverse_repeat_tuple(t, n): return tuple(x for x in reversed(t) for _ in range(n)) -def _list_with_default(out_size, defaults): - # type: (List[int], List[int]) -> List[int] +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size): From 9b7f3fa146d350628b295ab9b794d64173f17da1 Mon Sep 17 00:00:00 2001 From: Meghan Lele Date: Tue, 5 Jan 2021 19:35:12 -0800 Subject: [PATCH 05/20] [fx] Add matrix multiplication fusion pass (#50120) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50120 This commit adds a graph transformation pass that merges several matrix multiplications that use the same RHS operand into one large matrix multiplication. The LHS operands from all of the smaller matrix multiplications are concatenated together and used as an input in the large matrix multiply, and the result is split in order to obtain the same products as the original set of matrix multiplications. Test Plan: This commit adds a simple unit test with two matrix multiplications that share the same RHS operand. `buck test //caffe2/test:fx_experimental` Reviewed By: jamesr66a Differential Revision: D25239967 fbshipit-source-id: fb99ad25b7d83ff876da6d19dc4abd112d13001e --- test/test_fx_experimental.py | 123 +++++++++++++++ torch/fx/experimental/merge_matmul.py | 215 ++++++++++++++++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 torch/fx/experimental/merge_matmul.py diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6e9c877b8de6..ac71d6037591 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -21,6 +21,7 @@ PartitionMode ) from torch.fx.experimental.fuser import fuse +from torch.fx.experimental import merge_matmul try: from torchvision.models import resnet18 @@ -844,6 +845,128 @@ def forward(self, a): for p_name in para_list: assert p_name in node.attrs_for_lowering + def test_merge_matmuls(self): + """ + A collection of test cases for torch.fx.experimental.merge_matmul, + a graph transformation that merges matrix multiplication operations. + """ + # Utility function for counting matmuls for test assertions. + def _count_matmuls(mod): + gm = torch.fx.symbolic_trace(mod) + + num_matmuls = 0 + for node in gm.graph.nodes: + if node.target == torch.matmul: + num_matmuls += 1 + + return num_matmuls + + # Simple test case in which there are two matmuls of the same size to merge. + class SimpleMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x, y): + a = torch.matmul(x, self.rhs) + b = torch.matmul(y, self.rhs) + return a + b + + # Initialize inputs. + a = torch.randn(3, 3) + b = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct SimpleMergeMatmulModule and call merge_matmul on it. + module = SimpleMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a, b) + after = opt_module(a, b) + before.allclose(after) + + # Basic graph structure check; original module should have 2 matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Test case in which there are multiple matmuls of different sizes to merge. + class FiveMergeMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, a, b, c, d, e): + s = torch.Tensor((0)) + matmuls = [] + + # For some reason using a list comprehension or for-loop for this + # doesn't work. + matmuls.append(torch.matmul(a, self.rhs)) + matmuls.append(torch.matmul(b, self.rhs)) + matmuls.append(torch.matmul(c, self.rhs)) + matmuls.append(torch.matmul(d, self.rhs)) + matmuls.append(torch.matmul(e, self.rhs)) + + for m in matmuls: + s += torch.sum(m) + + return s + + # Initialize inputs. + inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] + + # Initialize RHS. + rhs = torch.randn(5, 4) + + # Construct FiveMergeMatmulModule and call merge_matmul on it. + module = FiveMergeMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(*inputs) + after = opt_module(*inputs) + before.allclose(after) + + # Basic graph structure check; original module should have len(inputs) matmuls + # and optimized module should have 1. + self.assertEqual(_count_matmuls(module), len(inputs)) + self.assertEqual(_count_matmuls(opt_module), 1) + + # Simple test case in which two matmuls cannot be merged due to a data dependency between + # the LHS operands. + class UnmergeableMatmulModule(torch.nn.Module): + def __init__(self, rhs): + super().__init__() + self.rhs = rhs + + def forward(self, x): + a = torch.matmul(x, self.rhs) + a_abs = torch.abs(a) + b = torch.matmul(a_abs.transpose(1, 0), self.rhs) + return b + + # Initialize inputs. + a = torch.randn(3, 3) + + # Initialize RHS for matmuls. + rhs = torch.randn(3, 4) + + # Construct UnmergeableMatmulModule and call merge_matmul on it. + module = UnmergeableMatmulModule(rhs) + opt_module = merge_matmul.merge_matmul(module) + + # Numerical correctness check. + before = module(a) + after = opt_module(a) + before.allclose(after) + + # Basic graph structure check; the number of matrix multiplcations should not have changed. + self.assertEqual(_count_matmuls(module), 2) + self.assertEqual(_count_matmuls(opt_module), 2) if __name__ == "__main__": run_tests() diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py new file mode 100644 index 000000000000..a5bd24c84c12 --- /dev/null +++ b/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,215 @@ +import torch + +import itertools +import operator + +from typing import List + + +def get_first_dim(t: torch.Tensor) -> int: + """ + A free function primarily for use in the merge_matmul graph transformation below + that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape + is an attribute (and cannot be the target of a call_function node) and also helps save + a getitem op in the graph. + + Arguments: + t: The tensor to get the first dimension of. + + Returns: + The first dimension of t. + """ + return t.shape[0] + + +def legalize_graph(gm: torch.fx.GraphModule): + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + """ + # Build an adjacency list representation of node dependencies in the graph. This also + # serves as a list of nodes that still need to be inserted into the new, topologically + # sorted graph. + dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes} + + # Construct a new graph that will contain all nodes in topologically sorted order. + new_graph = torch.fx.Graph() + value_remap = {} + + # Copy over all nodes with no dependencies. + for node, deps in dependencies.items(): + if not deps: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + # Remove the copied over nodes from the adjacency list. + for copied_node in value_remap.keys(): + del dependencies[copied_node] + + # While there are still nodes to insert into the new graph: + while dependencies: + copied_this_round = [] + + # Copy over all nodes whose dependencies already exist in the new graph. + for node, deps in dependencies.items(): + all_deps_copied = True + for dep in deps: + if dep not in value_remap: + all_deps_copied = False + + if all_deps_copied: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + copied_this_round.append(node) + + # Delete all nodes copied over in this iteration from dependencies. + for copied_node in copied_this_round: + del dependencies[copied_node] + + # Replace the old graph with the new, topologically sorted one. + gm.graph = new_graph + + +def may_depend_on(a: torch.fx.Node, b: torch.fx.Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: List[torch.fx.Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = torch.fx.symbolic_trace(in_mod) + + rhs_users = {} + lhs_users = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_sizes = [ + gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs + ] + merge_mm_split = gm.graph.call_function( + torch.split, (merge_mm, merge_mm_sizes), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint(in_mod) + return gm From def8aa5499065ae554d2c4d692f272f868c9b42b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 5 Jan 2021 19:36:56 -0800 Subject: [PATCH 06/20] Remove cpu half and dead code from multinomial (#50063) Summary: Based on ngimel's (Thank you!) feedback, cpu half was only accidental, so I'm removing it. This lets us ditch the old codepath for without replacement in favour of the new, better one. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50063 Reviewed By: mruberry Differential Revision: D25772449 Pulled By: ngimel fbshipit-source-id: 608729c32237de4ee6d1acf7e316a6e878dac7f0 --- aten/src/ATen/native/Distributions.cpp | 17 +++---- aten/src/ATen/native/UnaryOps.h | 4 +- .../src/ATen/native/cpu/MultinomialKernel.cpp | 50 ++++++------------- .../src/ATen/native/cuda/MultinomialKernel.cu | 13 +++-- test/test_torch.py | 3 +- 5 files changed, 35 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index ef0c2e2509c1..413ea32acdef 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -118,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub); DEFINE_DISPATCH(bernoulli_scalar_stub); DEFINE_DISPATCH(cauchy_stub); DEFINE_DISPATCH(exponential_stub); -DEFINE_DISPATCH(multinomial_stub); +DEFINE_DISPATCH(multinomial_with_replacement_stub); DEFINE_DISPATCH(geometric_stub); DEFINE_DISPATCH(log_normal_stub); DEFINE_DISPATCH(uniform_stub); @@ -497,8 +497,10 @@ Tensor& multinomial_out( // Reference: // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 // Half is not supported on CPU. - if (!with_replacement && - !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) { + TORCH_CHECK( + !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half), + "multinomial is not implemented for half on CPU"); + if (!with_replacement) { // Sanity checks on `self`. auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); TORCH_CHECK( @@ -537,13 +539,8 @@ Tensor& multinomial_out( return result; } - multinomial_stub( - result.device().type(), - result, - self, - n_sample, - with_replacement, - gen); + multinomial_with_replacement_stub( + result.device().type(), result, self, n_sample, gen); return result; } diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index f732cb9a0141..d92864e6fb2a 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -77,7 +77,9 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_full DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional), random_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub); -DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional), multinomial_stub); +DECLARE_DISPATCH( + void (*)(Tensor&, const Tensor&, int64_t, c10::optional), + multinomial_with_replacement_stub); DECLARE_DISPATCH( void (*)( TensorIterator&, diff --git a/aten/src/ATen/native/cpu/MultinomialKernel.cpp b/aten/src/ATen/native/cpu/MultinomialKernel.cpp index 1f4a52084962..62f1d7b879ac 100644 --- a/aten/src/ATen/native/cpu/MultinomialKernel.cpp +++ b/aten/src/ATen/native/cpu/MultinomialKernel.cpp @@ -11,8 +11,12 @@ namespace at { namespace native { namespace { -template -void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +template +void multinomial_with_replacement_apply( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); @@ -61,8 +65,6 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl } TORCH_CHECK(sum > 0, "invalid multinomial distribution (sum of probabilities <= 0)"); - TORCH_CHECK(with_replacement || (n_categories - n_zeros >= n_sample), - "invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)"); /* normalize cumulative probability distribution so that last val is 1 i.e. doesn't assume original self row sums to one */ @@ -100,45 +102,23 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl /* store in result tensor (will be incremented for lua compat by wrapper) */ result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx; - - /* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */ - if (!with_replacement && j < n_sample - 1) { - /* update cumulative distribution so that sample cannot be drawn again */ - scalar_t diff; - scalar_t new_val = 0; - scalar_t sum; - - if (sample_idx != 0) { - new_val = cum_dist_ptr[(sample_idx - 1) * cum_dist_stride_0]; - } - /* marginal cumulative mass (i.e. original probability) of sample */ - diff = cum_dist_ptr[sample_idx * cum_dist_stride_0] - new_val; - /* new sum of marginals is not one anymore... */ - sum = 1.0 - diff; - for (int64_t k = 0; k < n_categories; k++) { - new_val = cum_dist_ptr[k * cum_dist_stride_0]; - if (k >= sample_idx) { - /* remove sampled probability mass from later cumulative probabilities */ - new_val -= diff; - } - /* make total marginals sum to one */ - new_val /= sum; - cum_dist_ptr[k * cum_dist_stride_0] = new_val; - } - } } } } -static void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional gen) { +static void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional gen) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "multinomial", [&] { - multinomial_apply(result, self, n_sample, with_replacement, gen); + multinomial_with_replacement_apply(result, self, n_sample, gen); }); } - } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); } } diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index 3d59617903b4..cc74848b632a 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -300,7 +300,11 @@ sampleMultinomialOnce(int64_t* dest, } } -void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional generator) { +void multinomial_with_replacement_kernel_impl( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { auto gen = get_generator_or_default(generator, cuda::detail::getDefaultCUDAGenerator()); int inputSize = self.dim(); @@ -371,7 +375,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n PhiloxCudaState rng_engine_inputs; - if (with_replacement) { // Binary search is warp divergent (so effectively we're running // with just a single thread), but for better utilization, // we need each block to have at least 4 warps. @@ -402,7 +405,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n prefixSum.data_ptr(), normDist.data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); - } } }); @@ -412,6 +414,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n } } -REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); - +REGISTER_DISPATCH( + multinomial_with_replacement_stub, + &multinomial_with_replacement_kernel_impl); }} diff --git a/test/test_torch.py b/test/test_torch.py index 1f85ed2fff54..72fa853e2e7c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5689,7 +5689,8 @@ def test_storage_multigpu(self, devices): x = torch.tensor([], device=device) self.assertEqual(x.dtype, x.storage().dtype) - @dtypes(torch.float, torch.double, torch.half) + @dtypesIfCUDA(torch.float, torch.double, torch.half) + @dtypes(torch.float, torch.double) def test_multinomial(self, device, dtype): def make_prob_dist(shape, is_contiguous): if is_contiguous: From 05358332b3c2cb04dfaa9152e69c9e8401a89c53 Mon Sep 17 00:00:00 2001 From: Erjia Guan Date: Tue, 5 Jan 2021 19:49:08 -0800 Subject: [PATCH 07/20] Fix mypy typing check for test_dataset (#50108) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50108 Test Plan: Imported from OSS Reviewed By: glaringlee Differential Revision: D25789184 Pulled By: ejguan fbshipit-source-id: 0eeeeeda62533e7137d56f313b7bf11406b32611 --- test/test_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_dataset.py b/test/test_dataset.py index 2caa1a248435..a72b87cca555 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -90,7 +90,7 @@ def _collate_fn(batch): y = next(ds_iter) self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float)) - collate_ds_nolen = CollateIterableDataset(ds_nolen) + collate_ds_nolen = CollateIterableDataset(ds_nolen) # type: ignore with self.assertRaises(NotImplementedError): len(collate_ds_nolen) ds_nolen_iter = iter(ds_nolen) @@ -144,7 +144,7 @@ def test_sampler_dataset(self): arrs = range(10) ds = IterDatasetWithLen(arrs) # Default SequentialSampler - sampled_ds = SamplerIterableDataset(ds) + sampled_ds = SamplerIterableDataset(ds) # type: ignore self.assertEqual(len(sampled_ds), 10) i = 0 for x in sampled_ds: @@ -152,7 +152,7 @@ def test_sampler_dataset(self): i += 1 # RandomSampler - random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True) + random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True) # type: ignore # Requires `__len__` to build SamplerDataset ds_nolen = IterDatasetWithoutLen(arrs) From f6f0fde8411882af712d53fc7f7c0bdffeb47683 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 5 Jan 2021 20:25:56 -0800 Subject: [PATCH 08/20] [reland][quant][graphmode][fx] Standalone module support {input/output}_quantized_idxs (#49754) (#50058) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50058 This PR adds the support for {input/output}_quantized_idxs for standalone module. if input_quantized_idxs = [] and output_quantized_idxs = [], the standalone module will be expecting float input and produce float output, and will quantize the input and dequantize output internally if input_quantized_idxs = [0] and otuput_qiuantized_idxs = [0], the standalone module will be expecting quantized input and produce quantized output, the input will be quantized in the parent module, and output will be dequantized in the parent module as well, this is similar to current quantized modules like nn.quantized.Conv2d For more details, please see the test case Test Plan: python test/test_quantization.py TestQuantizeFx.test_standalone_module Imported from OSS Imported from OSS Reviewed By: vkuzo Differential Revision: D25768910 fbshipit-source-id: 96c21a3456cf192c8f1400afa4e86273ee69197b --- test/quantization/test_quantize_fx.py | 126 ++++++++++++---- torch/quantization/fx/fuse.py | 11 +- torch/quantization/fx/fusion_patterns.py | 23 ++- torch/quantization/fx/observed_module.py | 10 +- .../quantization/fx/quantization_patterns.py | 4 +- torch/quantization/fx/quantize.py | 138 +++++++++++++----- torch/quantization/fx/utils.py | 6 +- torch/quantization/quantize_fx.py | 23 ++- 8 files changed, 253 insertions(+), 88 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index d014bd31f02e..7965b3cc88a4 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -573,7 +573,16 @@ def forward(self, x): m = convert_fx(m) m(tensor_input) - def test_standalone_module(self): + def _test_standalone_module( + self, + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check): + """ Test standalone module with different quantized input/quantized output + configurations + """ class StandaloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -613,45 +622,32 @@ def forward(self, x): original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) - qconfig_dict = {"": default_qconfig} - config_name = {"standalone_module_name": [("standalone", None, None)]} - config_class = {"standalone_module_class": [(StandaloneModule, None, None)]} - for prepare_config in [config_name, config_class]: + for is_name in [True, False]: + if is_name: + prepare_config = { + "standalone_module_name": [("standalone", None, interface_config)] + } + else: + prepare_config = { + "standalone_module_class": [(StandaloneModule, None, interface_config)] + } + original_m_copy = copy.deepcopy(original_m) original_ref_m_copy = copy.deepcopy(original_ref_m) + + qconfig_dict = {"": default_qconfig} # check prepared model m = prepare_fx( original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config) # calibration m(data) - # input and output of first conv, observer for standalone module - # will be inserted in the standalone module itself - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - # for input and output of conv in the standalone module - count_check = { - ns.call_module(torch.quantization.MinMaxObserver): 2 - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) # check converted/quantized model m = convert_fx(m) - count_check = { - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d) : 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) - count_check = { - # standalone module will take float as input and output - # so we'll see quantize and dequantize in the modoule - ns.call_function(torch.quantize_per_tensor) : 1, - ns.call_module(nnq.Conv2d): 1, - ns.call_method('dequantize') : 1, - } - self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=count_check) + self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) + self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) res = m(data) # quantize the reference model @@ -661,6 +657,76 @@ def forward(self, x): ref_res = ref_m(data) self.assertEqual(res, ref_res) + def test_standalone_module_float_interface(self): + float_interface_config = { + "input_quantized_idxs": [], # float input + "output_quantized_idxs": [], # float output + } + interface_config = float_interface_config + # input and output of first conv, observer for standalone module + # will be inserted in the standalone module itself + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for input and output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + convert_count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # standalone module will take float as input and output + # so we'll see quantize and dequantize in the modoule + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d): 1, + ns.call_method("dequantize") : 1, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + + def test_standalone_module_quantized_interface(self): + quantized_interface_config = { + "input_quantized_idxs": [0], # quantized input + "output_quantized_idxs": [0], # quantized output + } + interface_config = quantized_interface_config + # observer for input and output of first conv + prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + # for output of conv in the standalone module + standalone_prepare_count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + convert_count_check = { + # quantizing input for conv + ns.call_function(torch.quantize_per_tensor) : 1, + ns.call_module(nnq.Conv2d) : 1, + # dequantizing output of standalone module + ns.call_method("dequantize") : 1, + } + standalone_convert_count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + ns.call_module(nnq.Conv2d): 1, + # dequantization for output happens in parent module + ns.call_method("dequantize") : 0, + } + self._test_standalone_module( + interface_config, + prepare_count_check, + standalone_prepare_count_check, + convert_count_check, + standalone_convert_count_check) + @skipIfNoFBGEMM def test_qconfig_none(self): class M(torch.nn.Module): diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 5aabbd66c4b1..59e3851dcd57 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -21,7 +21,7 @@ from .quantization_types import Pattern -from typing import Callable, Tuple, Optional +from typing import Callable, Tuple class Fuser: @@ -59,11 +59,12 @@ def load_arg(a): model = GraphModule(input_root, self.fused_graph) return model - def _find_matches(self, root: GraphModule, graph: Graph, - patterns: Dict[Pattern, Callable] - ) -> Dict[str, Tuple[Node, Optional[Any]]]: + def _find_matches( + self, root: GraphModule, graph: Graph, + patterns: Dict[Pattern, Callable] + ) -> Dict[str, Tuple[Node, FuseHandler]]: modules = dict(root.named_modules()) - match_map = {} # node name -> (root_node, match_value?) + match_map : Dict[str, Tuple[Node, FuseHandler]] = {} # node name -> (root_node, match_value) def apply_match(pattern, node, match): if isinstance(pattern, tuple): diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index b7af6008b3f3..1749484fccec 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -6,12 +6,25 @@ from .utils import _parent_name from .quantization_types import QuantizerCls from ..fuser_method_mappings import get_fuser_method +from abc import ABC, abstractmethod from typing import Any, Callable, Dict # --------------------- -# Fusion Patterns +# Fusion Pattern Registrations # --------------------- +# Base Pattern Handler +class FuseHandler(ABC): + """ Base handler class for the fusion patterns + """ + def __init__(self, quantizer: QuantizerCls, node: Node): + pass + + @abstractmethod + def fuse(self, quantizer: QuantizerCls, load_arg: Callable, + fuse_custom_config_dict: Dict[str, Any] = None) -> Node: + pass + @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d)) @@ -27,9 +40,9 @@ @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d))) -class ConvBNReLUFusion(): +class ConvBNReLUFusion(FuseHandler): def __init__(self, quantizer: QuantizerCls, node: Node): - super().__init__() + super().__init__(quantizer, node) self.relu_node = None self.bn_node = None if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \ @@ -94,9 +107,9 @@ def fuse(self, quantizer: QuantizerCls, load_arg: Callable, @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d)) @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d)) @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d)) -class ModuleReLUFusion(): +class ModuleReLUFusion(FuseHandler): def __init__(self, quantizer: QuantizerCls, node: Node): - super().__init__() + super().__init__(quantizer, node) self.relu_node = node assert isinstance(node.args[0], Node) node = node.args[0] diff --git a/torch/quantization/fx/observed_module.py b/torch/quantization/fx/observed_module.py index a95bc184fa10..808a3b36fb4a 100644 --- a/torch/quantization/fx/observed_module.py +++ b/torch/quantization/fx/observed_module.py @@ -2,11 +2,11 @@ import copy from torch.fx import GraphModule # type: ignore from torch.fx.graph import Graph -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List class ObservedGraphModule(GraphModule): - def get_preserved_attr_names(self): + def get_preserved_attr_names(self) -> List[str]: return ['_activation_post_process_map', '_patterns', '_qconfig_map', @@ -35,6 +35,12 @@ def is_observed_module(module: Any) -> bool: return isinstance(module, ObservedGraphModule) class ObservedStandaloneGraphModule(ObservedGraphModule): + def get_preserved_attr_names(self) -> List[str] : + return super().get_preserved_attr_names() + [ + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs" + ] + def __deepcopy__(self, memo): fake_mod = torch.nn.Module() fake_mod.__dict__ = copy.deepcopy(self.__dict__) diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 46fbed74bdc8..fb5bef0bd0ad 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -755,10 +755,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, qconfig = quantizer.qconfig_map[node.name] convert = torch.quantization.quantize_fx._convert_standalone_module_fx # type: ignore observed_standalone_module = quantizer.modules[node.target] + input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist() quantized_standalone_module = convert(observed_standalone_module, debug=debug) parent_name, name = _parent_name(node.target) # update the modules dict setattr(quantizer.modules[parent_name], name, quantized_standalone_module) quantizer.modules[node.target] = quantized_standalone_module - # standalone module takes float input - return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False)) + return quantizer.quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs)) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index af9496a66a63..318295270b61 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -102,14 +102,15 @@ def insert_observer( 'call_module', observer_name, (load_arg(node),), {}) observed_node_names_set.add(node.name) -def insert_observer_for_special_module( +def maybe_insert_observer_for_special_module( quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module], - prepare_custom_config_dict: Any, qconfig: Any, node: Node): + prepare_custom_config_dict: Any, qconfig: Any, node: Node) -> Optional[List[int]]: """ Insert observer for custom module and standalone module Returns: standalone_module_input_idxs: the indexs for inputs that needs to be observed by parent module """ assert modules is not None + standalone_module_input_idxs = None if isinstance(quantize_handler, CustomModuleQuantizeHandler): custom_module = modules[node.target] # type: ignore custom_module_class_mapping = prepare_custom_config_dict.get( @@ -129,19 +130,22 @@ def insert_observer_for_special_module( class_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_class_configs} name_config_map = {x[0]: (x[1], x[2]) for x in standalone_module_name_configs} config = class_config_map.get(type(standalone_module), (None, None)) - config = name_config_map.get(node.target, (None, None)) - standalone_module_qconfig_dict = {"": qconfig} if config[0] is None else config[0] - standalone_prepare_config_dict = {} if config[1] is None else config[1] + config = name_config_map.get(node.target, config) + sm_qconfig_dict = {"": qconfig} if config[0] is None else config[0] + sm_prepare_config_dict = {} if config[1] is None else config[1] prepare = \ torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore observed_standalone_module = \ - prepare(standalone_module, standalone_module_qconfig_dict, standalone_prepare_config_dict) + prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict) + standalone_module_input_idxs = observed_standalone_module.\ + _standalone_module_input_quantized_idxs.int().tolist() observed_standalone_module = mark_observed_standalone_module( observed_standalone_module) parent_name, name = _parent_name(node.target) setattr(modules[parent_name], name, observed_standalone_module) modules[node.target] = observed_standalone_module # type: ignore + return standalone_module_input_idxs def insert_observer_for_output_of_the_node( node: Node, @@ -155,7 +159,8 @@ def insert_observer_for_output_of_the_node( observed_graph: Graph, load_arg: Callable, observed_node_names_set: Set[str], - matched_nodes: Optional[List[Node]]): + matched_nodes: Optional[List[Node]], + standalone_module_input_idxs: Optional[List[int]]): """ Insert observer/fake_quantize module for output of the observed module if needed """ @@ -215,8 +220,13 @@ def input_is_observed(arg): observed_node_names_set.add(node.name) elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler): - # output is observed in the standalone module - return + assert node.op == "call_module" + assert isinstance(node.target, str) + sm_out_qidxs = modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + output_is_quantized = 0 in sm_out_qidxs + + if output_is_quantized: + observed_node_names_set.add(node.name) elif (quantize_handler.all_node_args and input_output_observed(quantize_handler)): # observer for outputs @@ -226,6 +236,16 @@ def input_is_observed(arg): activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set) + # insert observer for input of standalone module + if standalone_module_input_idxs is not None: + for idx in standalone_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: # type: ignore + new_observer = qconfig.activation() + insert_observer( + node, new_observer, model, + activation_post_process_map, env, observed_graph, + load_arg, observed_node_names_set) + def insert_observer_for_input_arg_of_observed_node( node: Node, observed_node_names_set: Set[str], quants: Dict[str, Tuple[DefaultQuantizeHandler, Callable]], @@ -373,10 +393,19 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - When we are preparing a standalone module: - both input and output are observed in prepared standalone module + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module Returns: model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module """ if prepare_custom_config_dict is None: prepare_custom_config_dict = {} @@ -430,8 +459,6 @@ def _prepare(self, model: GraphModule, qconfig_dict: Any, def load_arg(a): return map_arg(a, lambda node: env[node.name]) - # indexes for the inputs that needs to be observed - standalone_module_observed_input_idxs: List[int] = [] graph_inputs = [] for node in model.graph.nodes: if node.op == 'placeholder': @@ -487,14 +514,15 @@ def load_arg(a): # parent if qconfig is not None: assert obj is not None - insert_observer_for_special_module( - obj, self.modules, prepare_custom_config_dict, qconfig, - node) + standalone_module_input_idxs = \ + maybe_insert_observer_for_special_module( + obj, self.modules, prepare_custom_config_dict, qconfig, + node) insert_observer_for_output_of_the_node( node, obj, qconfig, self.modules, model, pattern, self.activation_post_process_map, env, observed_graph, load_arg, observed_node_names_set, - matched_nodes) + matched_nodes, standalone_module_input_idxs) else: env[node.name] = observed_graph.node_copy(node, load_arg) @@ -516,6 +544,21 @@ def load_arg(a): model = GraphModule(model, observed_graph) self.save_state(model) model = mark_observed_module(model) + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), \ + "standalone module only supports returning simple value currently"\ + "(not tuple, dict etc.)" + # indicator for whether output is observed or not. + # This used for correctly quantize standalone modules + output_is_observed = \ + result_node.args[0].name in observed_node_names_set + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + model._standalone_module_input_quantized_idxs = \ + torch.Tensor(input_quantized_idxs) + model._standalone_module_output_quantized_idxs = torch.Tensor(output_quantized_idxs) return model def save_state(self, observed: GraphModule) -> None: @@ -569,8 +612,10 @@ def _convert(self, model: GraphModule, debug: bool = False, """ standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Returns a quantized standalone module which accepts float input - and produces float output. + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details """ if convert_custom_config_dict is None: convert_custom_config_dict = {} @@ -627,36 +672,50 @@ def load_x(n: Node) -> Node: else: return env[n.name] - def load_arg(quantized: Optional[Union[List[Any], bool, Tuple[Any, ...]]] + def load_arg(quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] ) -> Callable[[Node], Argument]: """ Input: quantized, which can be None, list, boolean or tuple - - if quantized is a list or tuple, then arg should be a list and - the args with corresponding indexes will be quantized - - if quantized is a boolean, then all args will be - quantized/not quantized - if quantized is None, then we'll load the node as long as it exists + - if quantized is a boolean, then all args will be + quantized/not quantized + - if quantized is an empty list or tuple, then it is the same as load_arg(quantized=False) + - if quantized is a list or tuple, then arg should be a list and + the args with corresponding indexes will be quantized Output: fn which takes arg_or_args, and loads them from the corresponding environment depending on the value of quantized. """ assert quantized is None or \ isinstance(quantized, (tuple, list, bool)), type(quantized) + if isinstance(quantized, (tuple, list)) and len(quantized) == 0: + # empty tuple or list means nothing is quantized + quantized = False def load_arg_impl(arg_or_args): - if quantized is None: + # we'll update the format of `quantized` + # to better match arg_or_args + updated_quantized: Optional[Union[List[int], bool, Tuple[int, ...]]] = quantized + + if isinstance(quantized, (tuple, list)) and \ + len(quantized) == 1 and isinstance(arg_or_args, Node): + # when argument is one Node instead of tuple, we just need to check + # 0 is in the quantized list + updated_quantized = 0 in quantized + + if updated_quantized is None: return map_arg(arg_or_args, load_x) - if isinstance(quantized, bool): + if isinstance(updated_quantized, bool): return map_arg( arg_or_args, - load_quantized if quantized else load_non_quantized) - elif isinstance(quantized, (tuple, list)): + load_quantized if updated_quantized else load_non_quantized) + elif isinstance(updated_quantized, (tuple, list)): assert isinstance(arg_or_args, (tuple, list)), arg_or_args loaded_args = [] # for now, we only support quantizing positional arguments for i, a in enumerate(arg_or_args): - if i in quantized: + if i in updated_quantized: loaded_args.append(map_arg(a, load_quantized)) else: loaded_args.append(map_arg(a, load_non_quantized)) @@ -690,10 +749,10 @@ def node_arg_is_quantized(node_arg: Any) -> bool: def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool: """ Check if output node is quantized or not """ assert self.modules is not None - # by default the output is expected to be quantized + # by default the output for a quantizable node is expected to be quantized quantized = True - # Need to get correct quantized/non-quantized state for the output + # Need to get correct quantized/non-quantized state forn the output # of CopyNode if type(obj) in [ CopyNode, @@ -750,7 +809,7 @@ def insert_quantize_node(node: Node) -> None: "output_quantized_idxs", []) for node in model.graph.nodes: - if node.op == 'output': + if node.op == "output": cur_output_node_idx = output_node_seen_cnt output_node_seen_cnt += 1 if cur_output_node_idx in output_quantized_idxs: @@ -775,12 +834,19 @@ def insert_quantize_node(node: Node) -> None: quantized = False else: assert obj is not None + # We will get whether the output is quantized or not before + # convert for standalone module and after convert + # for non-standalone module, since _standalone_module_output_quantized_idxs + # is only available in observed standalone module + if is_observed_standalone_module_node: + out_quant_idxs = self.modules[node.target]._standalone_module_output_quantized_idxs.tolist() # type: ignore + assert len(out_quant_idxs) <= 1, "Currently standalone only support one output" + quantized = 0 in out_quant_idxs + result = obj.convert( self, node, load_arg, debug=debug, convert_custom_config_dict=convert_custom_config_dict) - if is_observed_standalone_module_node: - quantized = False - else: + if not is_observed_standalone_module_node: quantized = is_output_quantized(node, obj) if quantized: @@ -929,7 +995,7 @@ def _find_matches( standalone_module_names = [] match_map: Dict[str, MatchResult] = {} - all_matched = set() + all_matched : Set[str] = set() def record_match(pattern, node, matched): if isinstance(pattern, tuple): diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py index c1f849803342..8285e204b1ed 100644 --- a/torch/quantization/fx/utils.py +++ b/torch/quantization/fx/utils.py @@ -9,7 +9,7 @@ Node, ) -from typing import Callable, Optional, List, Dict, Any +from typing import Callable, Optional, List, Dict, Any, Set # turn foo.bar -> ['foo', 'bar'] def _parent_name(target): @@ -140,7 +140,7 @@ def get_next_qparams_idx(module, qparams): inputs.append(graph.create_node('get_attr', qparam_full_path)) return graph.create_node('call_function', quantize_op, tuple(inputs), {}) -def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): +def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key) -> List[Any]: r""" Get all the unique custom module keys in the custom config dict e.g. Input: @@ -163,7 +163,7 @@ def get_custom_module_class_keys(custom_config_dict, custom_config_dict_key): [CustomModule1, CustomModule2, CustomModule3] """ # using set to dedup - float_custom_module_classes = set() + float_custom_module_classes : Set[Any] = set() custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) for quant_mode in ["static", "dynamic", "weight_only"]: quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index cba104b8f783..89ba877ffe78 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -107,8 +107,20 @@ def _prepare_standalone_module_fx( standalone_module means it a submodule that is not inlined in parent module, and will be quantized separately as one unit. - Both input and output of the module are observed in the - standalone module. + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + model(GraphModule): prepared standalone module + attributes: + _standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + _standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module """ return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, is_standalone_module=True) @@ -378,8 +390,9 @@ def _convert_standalone_module_fx( r""" [Internal use only] Convert a model produced by :func:`~torch.quantization.prepare_standalone_module_fx` and convert it to a quantized model - Return: - A quantized standalone module which accepts float input - and produces float output. + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config_dict, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details """ return _convert_fx(graph_module, debug, convert_custom_config_dict, is_standalone_module=True) From 57d489e43a5b915cdb4bd8a16112ac68eb792581 Mon Sep 17 00:00:00 2001 From: Michael Carilli Date: Tue, 5 Jan 2021 22:34:19 -0800 Subject: [PATCH 09/20] Fix for possible RNG offset calculation bug in cuda vectorized dropout with VEC=2 (#50110) Summary: The [offset calculation](https://github.com/pytorch/pytorch/blob/e3c56ddde67ca1a49159ffa886d889b6e65c7033/aten/src/ATen/native/cuda/Dropout.cu#L328) (which gives an estimated ceiling on the most 32-bit values in the philox sequence any thread in the launch will use) uses the hardcoded UNROLL value of 4, and assumes the hungriest threads can use every value (.x, .y, .z, and .w) their curand_uniform4 calls provide. However, the way fused_dropout_kernel_vec is currently written, that assumption isn't true in the VEC=2 case: Each iteration of the `grid x VEC` stride loop, each thread calls curand_uniform4 once, uses rand.x and rand.y, and discards rand.z and rand.w. This means (I _think_) curand_uniform4 may be called twice as many times per thread in the VEC=2 case as for the VEC=4 case or the fully unrolled code path, which means the offset calculation (which is a good estimate for the latter two cases) is probably wrong for the `fused_dropout_kernel_vec<..., /*VEC=*/2>` code path. The present PR inserts some value-reuse in fused_dropout_kernel_vec to align the number of times curand_uniform4 is called for launches with the same totalElements in the VEC=2 and VEC=4 cases. The diff should - make the offset calculation valid for all code paths - provide a very small perf boost by reducing the number of curand_uniform4 calls in the VEC=2 path - ~~make results bitwise accurate for all code paths~~ nvm, tensor elements are assigned to threads differently in the unrolled, VEC 2 and VEC 4 cases, so we're screwed here no matter what. ngimel what do you think? Pull Request resolved: https://github.com/pytorch/pytorch/pull/50110 Reviewed By: smessmer Differential Revision: D25790121 Pulled By: ngimel fbshipit-source-id: f8f533ad997268c6f323cf4d225de547144247a8 --- aten/src/ATen/native/cuda/Dropout.cu | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index 67adbaabbb84..c3e456d97056 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -57,6 +57,12 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, accscalar_t pinv = accscalar_t(1)/p; + // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements + // in the vec=2 and vec=4 cases. + bool gridxvec_loop_state = 0; + + float4 rand; + // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time for (IndexType linearIndex = idx * VEC; linearIndex < totalElements; @@ -69,12 +75,21 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4) // sets of rand. - float4 rand = curand_uniform4(&state); + if ((VEC == 4) || (gridxvec_loop_state == 0)) { + rand = curand_uniform4(&state); + } else { + // sets up the last two values we generated last iteration to be used this iteration. + rand.x = rand.z; + rand.y = rand.w; + gridxvec_loop_state ^= 1; + } rand.x = rand.x < p; rand.y = rand.y < p; - rand.z = rand.z < p; - rand.w = rand.w < p; + if (VEC == 4) { + rand.z = rand.z < p; + rand.w = rand.w < p; + } // Note: We explicitly check for is_contiguous() before launching the vectorized kernel // and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other) From 282552dde2415d3cb3e4b1f0b18356810cf1ecd4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 5 Jan 2021 22:57:12 -0800 Subject: [PATCH 10/20] [PyTorch] Reapply D25546409: Use .sizes() isntead of .size() in cat_serial_kernel_impl (#49762) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49762 This was reverted because it landed in a stack together with D25542799 (https://github.com/pytorch/pytorch/commit/9ce1df079f6ea90dd4b7f9aa12a1a78d51a8b204), which really was broken. ghstack-source-id: 119326870 Test Plan: CI Reviewed By: maratsubkhankulov Differential Revision: D25685905 fbshipit-source-id: f4ec9e114993f988d4af380677331c72dfe41c44 --- aten/src/ATen/native/cpu/CatKernel.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cpu/CatKernel.cpp b/aten/src/ATen/native/cpu/CatKernel.cpp index 299850407da3..f86adb8e6318 100644 --- a/aten/src/ATen/native/cpu/CatKernel.cpp +++ b/aten/src/ATen/native/cpu/CatKernel.cpp @@ -15,18 +15,20 @@ struct InputMeta { InputMeta(const Tensor& t, int64_t dim, int64_t inner) : data_ptr(t.data_ptr()) - , inner_size(t.size(dim) * inner) {} + , inner_size(t.sizes()[dim] * inner) {} }; template void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) { - int64_t outer = result.numel() / (result.size(dim) * result.stride(dim)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl"); + int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]); scalar_t* result_data = result.data_ptr(); int64_t ninputs = tensors.size(); std::vector inputs; inputs.reserve(ninputs); for (auto const &tensor : tensors) { - inputs.emplace_back(tensor, dim, result.stride(dim)); + inputs.emplace_back(tensor, dim, result.strides()[dim]); } using Vec = vec256::Vec256; From ad7d208ba5f2c5614679a7999918b75ae74530e9 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 5 Jan 2021 23:20:42 -0800 Subject: [PATCH 11/20] Revert D25239967: [fx] Add matrix multiplication fusion pass Test Plan: revert-hammer Differential Revision: D25239967 (https://github.com/pytorch/pytorch/commit/9b7f3fa146d350628b295ab9b794d64173f17da1) Original commit changeset: fb99ad25b7d8 fbshipit-source-id: 370167b5ade8bf2b3a6cccdf4290ea07b8347c79 --- test/test_fx_experimental.py | 123 --------------- torch/fx/experimental/merge_matmul.py | 215 -------------------------- 2 files changed, 338 deletions(-) delete mode 100644 torch/fx/experimental/merge_matmul.py diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index ac71d6037591..6e9c877b8de6 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -21,7 +21,6 @@ PartitionMode ) from torch.fx.experimental.fuser import fuse -from torch.fx.experimental import merge_matmul try: from torchvision.models import resnet18 @@ -845,128 +844,6 @@ def forward(self, a): for p_name in para_list: assert p_name in node.attrs_for_lowering - def test_merge_matmuls(self): - """ - A collection of test cases for torch.fx.experimental.merge_matmul, - a graph transformation that merges matrix multiplication operations. - """ - # Utility function for counting matmuls for test assertions. - def _count_matmuls(mod): - gm = torch.fx.symbolic_trace(mod) - - num_matmuls = 0 - for node in gm.graph.nodes: - if node.target == torch.matmul: - num_matmuls += 1 - - return num_matmuls - - # Simple test case in which there are two matmuls of the same size to merge. - class SimpleMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x, y): - a = torch.matmul(x, self.rhs) - b = torch.matmul(y, self.rhs) - return a + b - - # Initialize inputs. - a = torch.randn(3, 3) - b = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct SimpleMergeMatmulModule and call merge_matmul on it. - module = SimpleMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a, b) - after = opt_module(a, b) - before.allclose(after) - - # Basic graph structure check; original module should have 2 matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Test case in which there are multiple matmuls of different sizes to merge. - class FiveMergeMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, a, b, c, d, e): - s = torch.Tensor((0)) - matmuls = [] - - # For some reason using a list comprehension or for-loop for this - # doesn't work. - matmuls.append(torch.matmul(a, self.rhs)) - matmuls.append(torch.matmul(b, self.rhs)) - matmuls.append(torch.matmul(c, self.rhs)) - matmuls.append(torch.matmul(d, self.rhs)) - matmuls.append(torch.matmul(e, self.rhs)) - - for m in matmuls: - s += torch.sum(m) - - return s - - # Initialize inputs. - inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] - - # Initialize RHS. - rhs = torch.randn(5, 4) - - # Construct FiveMergeMatmulModule and call merge_matmul on it. - module = FiveMergeMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(*inputs) - after = opt_module(*inputs) - before.allclose(after) - - # Basic graph structure check; original module should have len(inputs) matmuls - # and optimized module should have 1. - self.assertEqual(_count_matmuls(module), len(inputs)) - self.assertEqual(_count_matmuls(opt_module), 1) - - # Simple test case in which two matmuls cannot be merged due to a data dependency between - # the LHS operands. - class UnmergeableMatmulModule(torch.nn.Module): - def __init__(self, rhs): - super().__init__() - self.rhs = rhs - - def forward(self, x): - a = torch.matmul(x, self.rhs) - a_abs = torch.abs(a) - b = torch.matmul(a_abs.transpose(1, 0), self.rhs) - return b - - # Initialize inputs. - a = torch.randn(3, 3) - - # Initialize RHS for matmuls. - rhs = torch.randn(3, 4) - - # Construct UnmergeableMatmulModule and call merge_matmul on it. - module = UnmergeableMatmulModule(rhs) - opt_module = merge_matmul.merge_matmul(module) - - # Numerical correctness check. - before = module(a) - after = opt_module(a) - before.allclose(after) - - # Basic graph structure check; the number of matrix multiplcations should not have changed. - self.assertEqual(_count_matmuls(module), 2) - self.assertEqual(_count_matmuls(opt_module), 2) if __name__ == "__main__": run_tests() diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py deleted file mode 100644 index a5bd24c84c12..000000000000 --- a/torch/fx/experimental/merge_matmul.py +++ /dev/null @@ -1,215 +0,0 @@ -import torch - -import itertools -import operator - -from typing import List - - -def get_first_dim(t: torch.Tensor) -> int: - """ - A free function primarily for use in the merge_matmul graph transformation below - that returns the first dimension of a Tensor. This is necessary because torch.Tensor.shape - is an attribute (and cannot be the target of a call_function node) and also helps save - a getitem op in the graph. - - Arguments: - t: The tensor to get the first dimension of. - - Returns: - The first dimension of t. - """ - return t.shape[0] - - -def legalize_graph(gm: torch.fx.GraphModule): - """ - Replace the graph of the given GraphModule with one that contains the same nodes as the - original, but in topologically sorted order. - - This is used by the merge_matmul transformation below, which disturbs the topologically sorted - order of its input GraphModule, so that this order is restored before further transformation. - - Arguments: - gm: The graph module to topologically sort. It is modified in-place. - - """ - # Build an adjacency list representation of node dependencies in the graph. This also - # serves as a list of nodes that still need to be inserted into the new, topologically - # sorted graph. - dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes} - - # Construct a new graph that will contain all nodes in topologically sorted order. - new_graph = torch.fx.Graph() - value_remap = {} - - # Copy over all nodes with no dependencies. - for node, deps in dependencies.items(): - if not deps: - value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) - - # Remove the copied over nodes from the adjacency list. - for copied_node in value_remap.keys(): - del dependencies[copied_node] - - # While there are still nodes to insert into the new graph: - while dependencies: - copied_this_round = [] - - # Copy over all nodes whose dependencies already exist in the new graph. - for node, deps in dependencies.items(): - all_deps_copied = True - for dep in deps: - if dep not in value_remap: - all_deps_copied = False - - if all_deps_copied: - value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) - copied_this_round.append(node) - - # Delete all nodes copied over in this iteration from dependencies. - for copied_node in copied_this_round: - del dependencies[copied_node] - - # Replace the old graph with the new, topologically sorted one. - gm.graph = new_graph - - -def may_depend_on(a: torch.fx.Node, b: torch.fx.Node, search_depth: int = 6): - """ - Determine if one node depends on another in a torch.fx.Graph. - - Arguments: - a: The node that may have a dependency on b. - b: The node that a may have a dependency on. - search_depth: In the case of an indirect dependency, this function - searches upto this many nodes away in search of a - data dependency. If none is found, the function - makes the conservative assumption that there is a - dependency. - - Returns: - True if a may depend on b, False if it definitely does not. - """ - # Equivalence is defined as dependence. - if a == b: - return True - - # If a has no inputs, it cannot depend on b. - if len(a.all_input_nodes) == 0: - return False - - # If the search depth has been exhausted and no conclusion has been - # reached, assume that there is a data dependency. - if search_depth == 0: - return True - - # Recursively check all inputs of a. - for inp in a.all_input_nodes: - if may_depend_on(inp, b, search_depth - 1): - return True - - return False - - -def are_nodes_independent(nodes: List[torch.fx.Node]): - """ - Check if all of the given nodes are pairwise-data independent. - - Arguments: - nodes: The nodes to check for data dependencies. - - Returns: - True if any pair in nodes has a data dependency. - """ - # For each pair in nodes: - for i, j in itertools.combinations(nodes, 2): - if may_depend_on(i, j) or may_depend_on(j, i): - return False - - return True - - -def merge_matmul(in_mod: torch.nn.Module): - """ - A graph transformation that merges matrix multiplication operations that share the same right-hand - side operand into one large matrix multiplication. - ____ _________ _________ - ---- | | | | M| A * C | - M| A | T| B | * K| C | = |---------| - ---- , | | | | T| B * C | - K ---- --------- --------- - K R R - """ - gm = torch.fx.symbolic_trace(in_mod) - - rhs_users = {} - lhs_users = {} - - # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to - # the matmul of which they are the LHS/RHS. - for node in gm.graph.nodes: - if node.op != "call_function" or node.target is not torch.matmul: - continue - - lhs, rhs = node.args - - # TODO: Properly handle aliasing caused by get_attr. For now, - # use the attribute name as the operand if the node is a - # get_attr. - lhs = lhs.target if lhs.op == "get_attr" else lhs - rhs = rhs.target if rhs.op == "get_attr" else rhs - - lhs_users.setdefault(lhs, []).append(node) - rhs_users.setdefault(rhs, []).append(node) - - for rhs, mms in rhs_users.items(): - # There must be at least matmuls for a merge to make sense. - if len(mms) < 2: - continue - - # All matmuls must not depend on each other directly or indirectly - # in order for the merge to be possible. - if not are_nodes_independent(mms): - continue - - lhs_vals = [mm.args[0] for mm in mms] - - # Merge the matmul. - # Collect a list of LHS operands and the single RHS operand. - lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] - rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs - - # Concatenate all the LHS operands. - merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) - - # Multiply the concatenated LHS operands with the one RHS. This will produce - # the same results as all the individual matmuls involving rhs in the original graph, - # but they will all be concatenated together. - merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) - - # Split the result of the merged matmul using the shapes of the LHS operands - # to ascertain how large each chunk should be. - merge_mm_sizes = [ - gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs - ] - merge_mm_split = gm.graph.call_function( - torch.split, (merge_mm, merge_mm_sizes), {} - ) - merge_mm_res = [ - gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) - for out in range(len(lhs)) - ] - - # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. - for old, new in zip(mms, merge_mm_res): - old.replace_all_uses_with(new) - gm.graph.erase_node(old) - - # All of the new nodes created above were inserted at the end, so we need to sort - # the nodes topologically to make sure all definitions precede uses. - legalize_graph(gm) - - gm.recompile() - gm.graph.lint(in_mod) - return gm From 0ad6f066843537d6cf86e57910f4bbf8faa60f9e Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 6 Jan 2021 06:50:56 -0800 Subject: [PATCH 12/20] drop a unneeded comma from cmakelist.txt (#50091) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50091 Reviewed By: smessmer Differential Revision: D25782083 Pulled By: ezyang fbshipit-source-id: f90f57c6c9fc0c1e68ab30dd3b56dfe971798df2 --- aten/src/ATen/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index fd3c95f2573b..6fedef185b21 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -72,7 +72,7 @@ file(GLOB metal_h "metal/*.h") file(GLOB metal_cpp "metal/*.cpp") file(GLOB_RECURSE native_metal_h "native/metal/*.h") file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm") -file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp") +file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp") EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs}) file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h") file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp") From 45ec35827ed73c27c114ba0444517baa5b3cdbee Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Wed, 6 Jan 2021 06:55:10 -0800 Subject: [PATCH 13/20] Set USE_RCCL cmake option (dependent on USE_NCCL) [REDUX] (#34683) Summary: Refiled duplicate of https://github.com/pytorch/pytorch/issues/31341 which was reverted in commit 63964175b52197a75e03b73c59bd2573df66b398. This PR enables RCCL support when building Gloo as part of PyTorch for ROCm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/34683 Reviewed By: glaringlee Differential Revision: D25540578 Pulled By: ezyang fbshipit-source-id: fcb02e5745d62e1b7d2e02048160e9e7a4b4df2d --- CMakeLists.txt | 2 ++ tools/amd_build/build_amd.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index e346087c0cdb..3df73f8a3041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option( USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_RCCL "Use RCCL" ON + USE_NCCL OFF) cmake_dependent_option( USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 026293a9281a..9d4fa54c93b3 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -131,6 +131,20 @@ def is_hip_clang(): sources.write(line) print("%s updated" % gloo_cmake_file) +gloo_cmake_file = "third_party/gloo/cmake/Modules/Findrccl.cmake" +if os.path.exists(gloo_cmake_file): + do_write = False + with open(gloo_cmake_file, "r") as sources: + lines = sources.readlines() + newlines = [line.replace('RCCL_LIBRARY', 'RCCL_LIBRARY_PATH') for line in lines] + if lines == newlines: + print("%s skipped" % gloo_cmake_file) + else: + with open(gloo_cmake_file, "w") as sources: + for line in newlines: + sources.write(line) + print("%s updated" % gloo_cmake_file) + hipify_python.hipify( project_directory=proj_dir, output_directory=out_dir, From 2ac180a5dddf04178068dba7cbced33df250eb60 Mon Sep 17 00:00:00 2001 From: Chester Liu Date: Wed, 6 Jan 2021 07:08:16 -0800 Subject: [PATCH 14/20] Fix cl.exe detection in cpu/fused_kernel.cpp (#50085) Summary: The command used here is essentially `where cl.exe`. By using `system()` we will not be able to find cl.exe unless we are using VS Developer Prompt, which makes `activate()` meaningless. Change `system()` to `run()` fixes this. Found during https://github.com/pytorch/pytorch/issues/49781. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50085 Reviewed By: smessmer Differential Revision: D25782054 Pulled By: ezyang fbshipit-source-id: e8e3cac903a73f3bd78def667ebe0e93201814c8 --- torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index 4e76dc23e55d..4f4aa0d1536b 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -45,11 +45,17 @@ constexpr int so_suffix_len = 3; constexpr int cpp_suffix_len = 4; #endif +intptr_t run(const std::string& cmd); + static bool programExists(const std::string& program) { TemplateEnv env; env.s("program", program); std::string cmd = format(check_exists_string, env); +#ifdef _MSC_VER + return (run(cmd.c_str()) == 0); +#else return (system(cmd.c_str()) == 0); +#endif } #ifdef _MSC_VER From c517e15d79b8ae672ee2a94581fc57fa62155adf Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Wed, 6 Jan 2021 07:36:12 -0800 Subject: [PATCH 15/20] Add support for converting sparse bool tensors to dense (#50019) Summary: Fixes https://github.com/pytorch/pytorch/issues/49977 Pull Request resolved: https://github.com/pytorch/pytorch/pull/50019 Reviewed By: smessmer Differential Revision: D25782045 Pulled By: ezyang fbshipit-source-id: a8389cbecb7e79099292a423a6fd8ac28631905b --- aten/src/ATen/native/sparse/SparseTensorMath.cpp | 2 +- aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu | 4 ++-- test/test_sparse.py | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 9bb679beb3d0..6c3298b72e75 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -650,7 +650,7 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen dstBuffer.add_(srcBuffer, value); } } else { - AT_DISPATCH_ALL_TYPES( + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, commonDtype, "add_dense_sparse", [&] { add_dense_sparse_worker_cpu(resultBuffer, value, sparse, indices, valuesBuffer); }); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index c8366f71618e..fce3446816e7 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -338,8 +338,8 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT if (sparse.dense_dim() == 0) { TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] { apply::sparseElementwiseKernelScalar, uint64_t, scalar_t> <<>>( TensorCAddOp(value.to()), diff --git a/test/test_sparse.py b/test/test_sparse.py index 4e982b8333d9..228c66aa403e 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -356,6 +356,11 @@ def test_to_sparse(self): sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3]) self.assertRaises(RuntimeError, lambda: sp.to_sparse()) + def test_sparse_bool(self): + a = self.value_tensor([True, False]).to(torch.bool) + b = a.to_sparse().to_dense() + self.assertEqual(a, b) + def test_scalar(self): # tensor with value a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1), 12.3, []) From 5f2ec6293d6a443b8acca1d3ff7d57f9121afcc7 Mon Sep 17 00:00:00 2001 From: Alex Henrie Date: Wed, 6 Jan 2021 08:15:08 -0800 Subject: [PATCH 16/20] Unused variables in neural net classes and functions (#50100) Summary: These unused variables were identified by [pyflakes](https://pypi.org/project/pyflakes/). They can be safely removed to simplify the code and possibly improve performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50100 Reviewed By: ezyang Differential Revision: D25797764 Pulled By: smessmer fbshipit-source-id: ced341aee692f429d2dcc3a4ef5c46c8ee99cabb --- torch/nn/modules/module.py | 1 - torch/nn/parallel/replicate.py | 1 - torch/nn/quantized/dynamic/modules/rnn.py | 2 -- torch/nn/quantized/modules/embedding_ops.py | 1 - torch/nn/quantized/modules/normalization.py | 5 ----- torch/nn/utils/prune.py | 1 - 6 files changed, 11 deletions(-) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 297a4edf15bf..f054590da66a 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -843,7 +843,6 @@ def _slow_forward(self, *input, **kwargs): if recording_scopes: name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None if name: - cur_scope_name = tracing_state.current_scope() tracing_state.push_scope(name) else: recording_scopes = False diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py index a069c6c6f939..8effeece5908 100644 --- a/torch/nn/parallel/replicate.py +++ b/torch/nn/parallel/replicate.py @@ -108,7 +108,6 @@ def replicate(network, devices, detach=False): modules = list(network.modules()) module_copies = [[] for device in devices] module_indices = {} - scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"} for i, module in enumerate(modules): module_indices[module] = i diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index df88169471ca..59c0195d7858 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -239,8 +239,6 @@ def from_float(cls, mod): _all_weight_values = [] for layer in range(qRNNBase.num_layers): for direction in range(num_directions): - layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions - suffix = '_reverse' if direction == 1 else '' def retrieve_weight_bias(ihhh): diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index d16748b3baf7..e41d55347741 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -52,7 +52,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - version = local_metadata.get('version', None) self.dtype = state_dict[prefix + 'dtype'] state_dict.pop(prefix + 'dtype') diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py index 4664120ec8b5..c12f74374863 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -29,7 +29,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.normalized_shape, mod.weight, mod.bias, float(scale), @@ -63,7 +62,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point), @@ -98,7 +96,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -133,7 +130,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), @@ -168,7 +164,6 @@ def _get_name(self): @classmethod def from_float(cls, mod): - activation_post_process = mod.activation_post_process scale, zero_point = mod.activation_post_process.calculate_qparams() new_mod = cls( mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point), diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 84fa30021ed1..851a551da0d8 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -587,7 +587,6 @@ def compute_mask(self, t, default_mask): # Compute number of units to prune: amount if int, # else amount * tensor_size nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) - nparams_tokeep = tensor_size - nparams_toprune # This should raise an error if the number of units to prune is larger # than the number of units in the tensor _validate_pruning_amount(nparams_toprune, tensor_size) From 688992c775e2eeef53f3184b2e3428ef2f3a2967 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 6 Jan 2021 08:33:26 -0800 Subject: [PATCH 17/20] [PyTorch] Additional IValue tests (#49718) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49718 Improving test coverage in preparation for updating the implementation of IValue. ghstack-source-id: 119327373 Test Plan: ivalue_test Reviewed By: hlu1 Differential Revision: D25674605 fbshipit-source-id: 37a82bb135f75ec52d2d8bd929c4329e8dcc4d25 --- aten/src/ATen/test/ivalue_test.cpp | 217 +++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 14e75205aa66..a0e2648758ff 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -51,6 +51,91 @@ TEST(IValueTest, Basic) { ASSERT_EQ(tv.use_count(), 2); } +static std::array makeSampleIValues() { + return { at::rand({3, 4}), "hello", 42, true, 1.5 }; +} + +static std::array makeMoreSampleIValues() { + return { at::rand({3, 4}), "goodbye", 23, false, 0.5 }; +} + +// IValue::operator== doesn't seem to work on Tensors. +#define EXPECT_IVALUE_EQ(a, b) \ + EXPECT_EQ((a).isTensor(), (b).isTensor()); \ + if ((a).isTensor()) { \ + EXPECT_TRUE(a.toTensor().equal(b.toTensor())); \ + } else { \ + EXPECT_EQ(a, b); \ + } + +TEST(IValueTest, Swap) { + // swap() has the following 3 cases: tensor, intrusive_ptr, or + // neither. Exercise all pairs of the three. + + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + for (const auto& input: sampleInputs) { + for (const auto& target: sampleTargets) { + IValue a(input); + IValue b(target); + EXPECT_IVALUE_EQ(a, input); + EXPECT_IVALUE_EQ(b, target); + a.swap(b); + EXPECT_IVALUE_EQ(a, target); + EXPECT_IVALUE_EQ(b, input); + } + } +} + +TEST(IValueTest, CopyConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue copy(v); + EXPECT_IVALUE_EQ(copy, v); + } +} + +TEST(IValueTest, MoveConstruct) { + auto sampleInputs = makeSampleIValues(); + for (const IValue& v: sampleInputs) { + IValue source(v); + IValue target(std::move(source)); + EXPECT_IVALUE_EQ(target, v); + EXPECT_TRUE(source.isNone()); + } +} + +TEST(IValueTest, CopyAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue copyTo(target); + IValue copyFrom(input); + copyTo = copyFrom; + EXPECT_IVALUE_EQ(copyTo, input); + EXPECT_IVALUE_EQ(copyFrom, input); + EXPECT_IVALUE_EQ(copyTo, copyFrom); + } + } +} + +TEST(IValueTest, MoveAssign) { + auto sampleInputs = makeSampleIValues(); + auto sampleTargets = makeMoreSampleIValues(); + + for (const IValue& input: sampleInputs) { + for (const IValue& target: sampleTargets) { + IValue moveTo(target); + IValue moveFrom(input); + moveTo = std::move(moveFrom); + EXPECT_IVALUE_EQ(moveTo, input); + EXPECT_TRUE(moveFrom.isNone()); + } + } +} + TEST(IValueTest, Tuple) { std::tuple t = std::make_tuple(123, at::randn({1})); auto iv = IValue(t); @@ -318,5 +403,137 @@ TEST(IValueTest, EnumEquality) { ); } +TEST(IValueTest, isPtrType) { + IValue tensor(at::rand({3, 4})); + IValue undefinedTensor((at::Tensor())); + IValue integer(42); + IValue str("hello"); + + EXPECT_TRUE(tensor.isPtrType()); + EXPECT_FALSE(undefinedTensor.isPtrType()); + EXPECT_FALSE(integer.isPtrType()); + EXPECT_TRUE(str.isPtrType()); +} + +TEST(IValueTest, isAliasOf) { + auto sampleIValues = makeSampleIValues(); + for (auto& iv: sampleIValues) { + for (auto& iv2: sampleIValues) { + if (&iv == &iv2 && iv.isPtrType()) { + EXPECT_TRUE(iv.isAliasOf(iv2)); + } else { + EXPECT_FALSE(iv.isAliasOf(iv2)); + } + } + } +} + +TEST(IValueTest, internalToPointer) { + IValue tensor(at::rand({3, 4})); + IValue str("hello"); + + EXPECT_EQ(tensor.internalToPointer(), tensor.unsafeToTensorImpl()); + EXPECT_NE(str.internalToPointer(), nullptr); + + IValue nullStr((c10::intrusive_ptr())); + ASSERT_TRUE(nullStr.isString()); + EXPECT_EQ(nullStr.internalToPointer(), nullptr); +} + +TEST(IValueTest, IdentityComparisonAndHashing) { + at::Tensor t1 = at::rand({3, 4}); + at::Tensor t2 = at::rand({3, 4}); + IValue tv1(t1), tv2(t2); + IValue tv1b(t1); + + EXPECT_EQ(tv1.hash(), tv1b.hash()); + EXPECT_NE(tv1.hash(), tv2.hash()); + + EXPECT_TRUE(tv1.is(tv1)); + EXPECT_TRUE(tv1.is(tv1b)); + EXPECT_TRUE(tv1b.is(tv1)); + EXPECT_TRUE(tv2.is(tv2)); + + EXPECT_FALSE(tv1.is(tv2)); + EXPECT_FALSE(tv2.is(tv1)); + + IValue none; + IValue undefinedTensor((at::Tensor())); + + EXPECT_TRUE(none.is(undefinedTensor)); + EXPECT_TRUE(undefinedTensor.is(none)); + + // Is this a bug? We should probably have a is b => a.hash() == b.hash() + EXPECT_NE(none.hash(), undefinedTensor.hash()); + + auto sampleIValues = makeSampleIValues(); + auto sampleIValues2 = makeSampleIValues(); + auto moreSampleIValues = makeMoreSampleIValues(); + + ASSERT_EQ(sampleIValues.size(), moreSampleIValues.size()); + for (int ii = 0; ii < sampleIValues.size(); ++ii) { + // Constant strings will have the same pointer value. + if (sampleIValues[ii].isPtrType() && !sampleIValues[ii].isString()) { + EXPECT_NE(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } else { + EXPECT_EQ(sampleIValues[ii].hash(), sampleIValues2[ii].hash()); + } + EXPECT_NE(sampleIValues[ii].hash(), moreSampleIValues[ii].hash()); + } +} + +TEST(IValueTest, getSubValues) { + // Scalars have no subvalues. + IValue integer(42), float_(1.5); + + IValue::HashAliasedIValues subvalues; + + integer.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + float_.getSubValues(subvalues); + EXPECT_TRUE(subvalues.empty()); + + subvalues.clear(); + + at::Tensor t1(at::rand({3, 4})), t2(at::rand({3, 4})); + IValue tv1(t1), tv2(t2); + IValue list(std::vector{t1, t2}); + IValue tuple(ivalue::Tuple::create({tv1, tv2})); + + std::unordered_map m; + m[1] = t1; + m[2] = t2; + + IValue dict(std::move(m)); + + auto objType = ClassType::create(nullopt, {}); + objType->addAttribute("t1", tv1.type()); + objType->addAttribute("t2", tv2.type()); + + auto o = ivalue::Object::create(StrongTypePtr(nullptr, objType), 2); + o->setSlot(0, tv1); + o->setSlot(1, tv2); + + IValue object(o); + tv1.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + + subvalues.clear(); + + for (auto& container: {list, tuple, dict, object}) { + container.getSubValues(subvalues); + EXPECT_EQ(subvalues.size(), 3); + EXPECT_EQ(subvalues.count(container), 1); + EXPECT_EQ(subvalues.count(tv1), 1); + EXPECT_EQ(subvalues.count(tv2), 1); + + subvalues.clear(); + } +} + // TODO(gmagogsfm): Add type conversion test? } // namespace c10 From 1b31e1353903eb52140aedef04c6edff5bb7b7e6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 6 Jan 2021 08:33:26 -0800 Subject: [PATCH 18/20] [PyTorch] Store Tensor explicitly in IValue (#48824) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48824 Enables following diff, which will make toTensor() return `const Tensor&` and allow callers to avoid refcounting overhead. ghstack-source-id: 119327370 Test Plan: ivalue_test Internal benchmark to ensure perf parity. Some interesting steps during the debugging process: - First version was about a 5% regression - Directly implementing move construction instead of using swap lowered the regression to 2-3% - Directly implementing move assign was maybe an 0.5% improvement - Adding C10_ALWAYS_INLINE on move assign got our regression to negligible - Fixing toTensor() to actually be correct regressed us again, but omitting the explicit dtor call as exhaustively spelled out in a comment fixed it. Reviewed By: bwasti Differential Revision: D25324617 fbshipit-source-id: 7518c1c67f6f2661f151b43310aaddf4fb6e511a --- aten/src/ATen/core/ivalue.cpp | 12 +- aten/src/ATen/core/ivalue.h | 279 +++++++++++++++++++++++--------- aten/src/ATen/core/ivalue_inl.h | 95 +++++++---- aten/src/ATen/core/jit_type.h | 8 +- c10/util/intrusive_ptr.h | 4 +- 5 files changed, 275 insertions(+), 123 deletions(-) diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 320fa6294638..1223577c59c6 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -265,7 +265,7 @@ bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) { TORCH_INTERNAL_ASSERT(lhs.is_intrusive_ptr); TORCH_INTERNAL_ASSERT(rhs.is_intrusive_ptr); return lhs.tag == rhs.tag && - lhs.payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } IValue IValue::equals(const IValue& rhs) const { @@ -325,17 +325,17 @@ size_t IValue::hash(const IValue& v) { case Tag::None: return 0; case Tag::Bool: - return c10::get_hash(v.payload.as_bool); + return c10::get_hash(v.payload.u.as_bool); case Tag::Double: - return c10::get_hash(v.payload.as_double); + return c10::get_hash(v.payload.u.as_double); case Tag::Tensor: // Tensor __hash__ is equivalent to `id()`, so take the pointer value of // the tensor to emulate it - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl()); case Tag::Storage: - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.u.as_int); case Tag::Int: - return c10::get_hash(v.payload.as_int); + return c10::get_hash(v.payload.u.as_int); case Tag::String: return c10::get_hash(v.toStringRef()); case Tag::Tuple: diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 4a7e15c4008b..5370294b2f2c 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -131,10 +131,15 @@ struct Capsule { // they are marked `@private`, which hides them on the doxygen documentation for // this page. -/// IValue (Interpreter Value) is a tagged union over the types supported by the -/// TorchScript interpreter. IValues contain their values as an -/// `IValue::Payload`, which holds primitive types (`int64_t`, `bool`, `double`, -/// `Device`), as values and all other types as a `c10::intrusive_ptr`. +/// IValue (Interpreter Value) is a tagged union over the types +/// supported by the TorchScript interpreter. IValues contain their +/// values as an `IValue::Payload`, which holds primitive types +/// (`int64_t`, `bool`, `double`, `Device`) and `Tensor` as values, +/// and all other types as a `c10::intrusive_ptr`. In order to +/// optimize performance of the destructor and related operations by +/// making the `Tensor` and `c10::intrusive_ptr` paths generate the +/// same code, we represent a null `c10::intrusive_ptr` as +/// `UndefinedTensorImpl::singleton()`, *not* `nullptr`. /// /// IValues are used as inputs to and outputs from the TorchScript interpreter. /// To retrieve the value contained within an IValue, use the `.toX()` methods, @@ -160,27 +165,35 @@ struct Capsule { struct TORCH_API IValue final { IValue(const IValue& rhs) : IValue(rhs.payload, rhs.tag, rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr); + if (is_intrusive_ptr && payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr); } } - IValue(IValue&& rhs) noexcept : IValue() { - swap(rhs); + + IValue(IValue&& rhs) noexcept : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + moveFrom(std::move(rhs)); } + /// @private [doxygen private] ~IValue() { - if (is_intrusive_ptr) { - c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr); - } + destroy(); } - IValue& operator=(IValue&& rhs) & noexcept { - IValue(std::move(rhs)).swap(*this); // this also sets rhs to None + + C10_ALWAYS_INLINE IValue& operator=(IValue&& rhs) & noexcept { + if (&rhs == this) { + return *this; + } + + destroy(); + moveFrom(std::move(rhs)); return *this; } + IValue& operator=(IValue const& rhs) & { IValue(rhs).swap(*this); return *this; } + void dump() const; /** @@ -260,13 +273,6 @@ struct TORCH_API IValue final { return false; } - if (!this->is_intrusive_ptr) { - // Primitive types don't alias anything - return false; - } - - AT_ASSERT(rhs.is_intrusive_ptr); - // Tensors should be compared based on internal storage if (this->isTensor()) { const auto thisTensor = this->toTensor(); @@ -274,22 +280,56 @@ struct TORCH_API IValue final { return thisTensor.is_alias_of(rhsTensor); } + if (!this->is_intrusive_ptr) { + // Primitive types don't alias anything + return false; + } + + AT_ASSERT(rhs.is_intrusive_ptr); + // Other types can be compared by their ptr value - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } /// @private [doxygen private] size_t use_count() const noexcept { + if (isTensor()) { + return payload.as_tensor.use_count(); + } + if (!is_intrusive_ptr) { return 1; } - return c10::raw::intrusive_ptr::use_count(payload.as_intrusive_ptr); + if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) { + return 0; + } + return c10::raw::intrusive_ptr::use_count(payload.u.as_intrusive_ptr); } /// @private [doxygen private] void swap(IValue& rhs) noexcept { - std::swap(payload, rhs.payload); + if (isTensor() && rhs.isTensor()) { + std::swap(payload.as_tensor, rhs.payload.as_tensor); + } else if (isTensor()) { + at::Tensor t = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + payload.u = rhs.payload.u; + new (&rhs.payload.as_tensor) at::Tensor(std::move(t)); + } else if (rhs.isTensor()) { + rhs.swap(*this); + return; + } else { + std::swap(payload.u, rhs.payload.u); + } std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr); std::swap(tag, rhs.tag); } @@ -298,13 +338,8 @@ struct TORCH_API IValue final { // While some of these accessors could be generated through templates, // we prefer to write them manually for clarity - IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) { - // Note: the undefined tensor is not refcounted, so while it - // is tagged as a tensor, is_intrusive_ptr is set to false. - // This is not an optional optimization: our incref call - // *will not* do the right thing when called on an - // undefined tensor. - payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl(); + IValue(at::Tensor t) : tag(Tag::Tensor), is_intrusive_ptr(false) { + new (&payload.as_tensor) at::Tensor(std::move(t)); } bool isTensor() const { return Tag::Tensor == tag; @@ -312,7 +347,7 @@ struct TORCH_API IValue final { at::Tensor toTensor() &&; at::Tensor toTensor() const&; at::TensorImpl* unsafeToTensorImpl() const { - return static_cast(payload.as_intrusive_ptr); + return payload.as_tensor.unsafeGetTensorImpl(); } IValue(at::Storage s) : tag(Tag::Storage), is_intrusive_ptr(static_cast(s)) { @@ -321,7 +356,7 @@ struct TORCH_API IValue final { // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined tensor. - payload.as_intrusive_ptr = s.unsafeReleaseStorageImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(s.unsafeReleaseStorageImpl()); } bool isStorage() const { return Tag::Storage == tag; @@ -341,7 +376,7 @@ struct TORCH_API IValue final { : tag(Tag::Blob), is_intrusive_ptr(true) { // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract // and store it as a Tensor instead. - payload.as_intrusive_ptr = blob.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); } /// @private [doxygen private] @@ -397,14 +432,14 @@ struct TORCH_API IValue final { // Double IValue(double d) : tag(Tag::Double), is_intrusive_ptr(false) { - payload.as_double = d; + payload.u.as_double = d; } bool isDouble() const { return Tag::Double == tag; } double toDouble() const { AT_ASSERT(isDouble()); - return payload.as_double; + return payload.u.as_double; } // Future @@ -433,7 +468,7 @@ struct TORCH_API IValue final { // Int IValue(int64_t i) : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = i; + payload.u.as_int = i; } // allow you to pass literals (3, 4) without ambiguity @@ -445,7 +480,7 @@ struct TORCH_API IValue final { int64_t toInt() const { AT_ASSERT(isInt()); - return payload.as_int; + return payload.u.as_int; } // Bool @@ -454,9 +489,9 @@ struct TORCH_API IValue final { // Initializing entire payload stops valgrind's from reporting // "jump or move depends on uninitialised value" in IValue copy constructor // See https://github.com/pytorch/pytorch/issues/37117 - payload.as_int = b; + payload.u.as_int = b; #else - payload.as_bool = b; + payload.u.as_bool = b; #endif } bool isBool() const { @@ -464,7 +499,7 @@ struct TORCH_API IValue final { } bool toBool() const { AT_ASSERT(isBool()); - return payload.as_bool; + return payload.u.as_bool; } // IntList @@ -580,7 +615,7 @@ struct TORCH_API IValue final { c10::intrusive_ptr toEnumHolder() const&; // None - IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {} + IValue() : tag(Tag::None), is_intrusive_ptr(false) {} bool isNone() const { return Tag::None == tag; } @@ -616,21 +651,21 @@ struct TORCH_API IValue final { // Device IValue(c10::Device d) : tag(Tag::Device), is_intrusive_ptr(false) { - payload.as_device.type = d.type(); - payload.as_device.index = d.index(); + payload.u.as_device.type = d.type(); + payload.u.as_device.index = d.index(); } bool isDevice() const { return Tag::Device == tag; } c10::Device toDevice() const { AT_ASSERT(isDevice()); - return c10::Device(payload.as_device.type, payload.as_device.index); + return c10::Device(payload.u.as_device.type, payload.u.as_device.index); } //Stream IValue(c10::Stream stream) : tag(Tag::Stream), is_intrusive_ptr(false) { - payload.as_int = stream.pack(); + payload.u.as_int = stream.pack(); } c10::Stream toStream() &&; c10::Stream toStream() const &; @@ -659,7 +694,7 @@ struct TORCH_API IValue final { // QScheme IValue(at::QScheme qscheme) : tag(Tag::Int), is_intrusive_ptr(false) { - payload.as_int = static_cast(qscheme); + payload.u.as_int = static_cast(qscheme); } at::QScheme toQScheme() const { @@ -680,7 +715,7 @@ struct TORCH_API IValue final { // This is not an optional optimization: our incref call // *will not* do the right thing when called on an // undefined generator. - payload.as_intrusive_ptr = g.unsafeReleaseGeneratorImpl(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(g.unsafeReleaseGeneratorImpl()); } bool isGenerator() const { return Tag::Generator == tag; @@ -749,14 +784,19 @@ struct TORCH_API IValue final { const IValue& v); bool isPtrType() const { - return is_intrusive_ptr; + return (isTensor() && payload.as_tensor.defined()) || is_intrusive_ptr; } /// @private [doxygen private] const void* internalToPointer() const { TORCH_INTERNAL_ASSERT( isPtrType(), "Can only call internalToPointer() for pointer types"); - return payload.as_intrusive_ptr; + if (isTensor()) { + return payload.as_tensor.unsafeGetTensorImpl(); + } else { + return payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton() + ? payload.u.as_intrusive_ptr : nullptr; + } } TypePtr type() const; @@ -770,7 +810,7 @@ struct TORCH_API IValue final { } // If it is not a Tensor, then two mutable IValues alias each other only // if they are the same pointer. - return val.payload.as_int; + return val.payload.u.as_int; } }; @@ -800,6 +840,10 @@ struct TORCH_API IValue final { IValue deepcopy(HashAliasedIValueMap& memo) const; private: + static c10::intrusive_ptr_target* null_to_undefined_tensor(c10::intrusive_ptr_target* p) { + return p ? p : static_cast(c10::UndefinedTensorImpl::singleton()); + } + static bool ptrEqual(const IValue& lhs, const IValue& rhs); // NOTE: IValue tags are intentionally private. In the future we may encode // this value different (e.g. using NaN boxing), and this would make it more @@ -822,24 +866,77 @@ struct TORCH_API IValue final { class NullType = c10::detail::intrusive_target_default_null_type> c10::intrusive_ptr toIntrusivePtr() const; - void clearToNone() { - payload.as_int = 0; + void destroy() { + // We carefully construct this call to both 1) avoid UB by using + // the "wrong" one of as_tensor and as_intrusive_ptr and 2) enable + // the compiler to generate the same code for each case. It is + // surprisingly difficult to get this right. + if (isTensor() || is_intrusive_ptr) { + c10::intrusive_ptr_target* p = isTensor() ? payload.as_tensor.unsafeGetTensorImpl() : payload.u.as_intrusive_ptr; + c10::intrusive_ptr::reclaim(p); + // No need to make this destructor call! + // payload.as_tensor.~Tensor(); + } + } + + C10_ALWAYS_INLINE void moveFrom(IValue&& rhs) noexcept { + if (rhs.isTensor()) { + new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor)); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // rhs.payload.as_tensor.~Tensor(); + } else { + payload.u = rhs.payload.u; + } + tag = rhs.tag; + is_intrusive_ptr = rhs.is_intrusive_ptr; + rhs.clearToNone(); + } + + void clearToNone() noexcept { + payload.u.as_int = 0; tag = Tag::None; is_intrusive_ptr = false; } union Payload { - int64_t as_int; - double as_double; - bool as_bool; - c10::intrusive_ptr_target* as_intrusive_ptr; - struct { - DeviceType type; - DeviceIndex index; - } as_device; + // We use a nested union here so that we can make the copy easy + // and efficient in the non-tensor (i.e., trivially copyable) + // case. Specifically, we do not have to do a switch-on-tag to + // figure out which union member to assign; we can just use + // TriviallyCopyablePayload::operator=. + union TriviallyCopyablePayload { + TriviallyCopyablePayload() : as_int(0) {} + int64_t as_int; + double as_double; + bool as_bool; + // Invariant: never nullptr; null state is represented as + // c10::UndefinedTensorImpl::singleton() for consistency of + // representation with Tensor. + c10::intrusive_ptr_target* as_intrusive_ptr; + struct { + DeviceType type; + DeviceIndex index; + } as_device; + } u; + at::Tensor as_tensor; + Payload() : u() {} + ~Payload() {} }; - IValue(Payload p, Tag t, bool i) : payload(p), tag(t), is_intrusive_ptr(i) {} + IValue(const Payload& p, Tag t, bool i) : tag(t), is_intrusive_ptr(i) { + if (isTensor()) { + new (&payload.as_tensor) at::Tensor(p.as_tensor); + } else { + payload.u = p.u; + } + } Payload payload; Tag tag; @@ -848,29 +945,36 @@ struct TORCH_API IValue final { }; struct TORCH_API WeakIValue final { - WeakIValue() : payload{0}, tag(IValue::Tag::None), is_intrusive_ptr(false) {} + WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {} WeakIValue(const WeakIValue& rhs) : payload(rhs.payload), tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); } } WeakIValue(const IValue& rhs) - : payload(rhs.payload), - tag(rhs.tag), + : tag(rhs.tag), is_intrusive_ptr(rhs.is_intrusive_ptr) { + if (rhs.isTensor()) { + payload.as_intrusive_ptr = rhs.unsafeToTensorImpl(); + is_intrusive_ptr = true; + } else { + payload = rhs.payload.u; + } if (is_intrusive_ptr) { - c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + if (payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { + c10::raw::weak_intrusive_ptr::incref(payload.as_intrusive_ptr); + } } } WeakIValue(WeakIValue&& rhs) noexcept : WeakIValue() { swap(rhs); } ~WeakIValue() { - if (is_intrusive_ptr) { + if (is_intrusive_ptr && payload.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton()) { c10::raw::weak_intrusive_ptr::decref(payload.as_intrusive_ptr); } } @@ -895,17 +999,33 @@ struct TORCH_API WeakIValue final { IValue lock() const { if (!is_intrusive_ptr) { - return IValue(payload, tag, false); + IValue::Payload newPayload; + newPayload.u = payload; + return IValue(newPayload, tag, false); } - auto temp = c10::weak_intrusive_ptr::reclaim( - payload.as_intrusive_ptr); - IValue::Payload pl; - pl.as_intrusive_ptr = temp.lock().release(); - temp.release(); - if (!pl.as_intrusive_ptr) { - return IValue(); + if (IValue::Tag::Tensor == tag) { + auto temp = c10::weak_intrusive_ptr::reclaim( + static_cast(payload.as_intrusive_ptr)); + c10::intrusive_ptr ip(temp.lock()); + temp.release(); + if (!ip) { + return IValue(); + } else { + return IValue(at::Tensor(std::move(ip))); + } } else { - return IValue(pl, tag, true); + auto temp = c10::weak_intrusive_ptr::reclaim( + payload.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? nullptr + : payload.as_intrusive_ptr); + IValue::Payload pl; + pl.u.as_intrusive_ptr = temp.lock().release(); + temp.release(); + if (!pl.u.as_intrusive_ptr) { + return IValue(); + } else { + return IValue(pl, tag, true); + } } } @@ -913,7 +1033,7 @@ struct TORCH_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.use_count(); temp.release(); @@ -924,7 +1044,7 @@ struct TORCH_API WeakIValue final { if (!is_intrusive_ptr) { return 1; } - auto temp = c10::weak_intrusive_ptr::reclaim( + auto temp = c10::weak_intrusive_ptr::reclaim( payload.as_intrusive_ptr); size_t result = temp.weak_use_count(); temp.release(); @@ -935,7 +1055,8 @@ struct TORCH_API WeakIValue final { } private: - IValue::Payload payload; + using Payload = IValue::Payload::TriviallyCopyablePayload; + Payload payload; IValue::Tag tag; bool is_intrusive_ptr; }; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 89c8e669c138..fe55d783e780 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -48,14 +48,18 @@ struct tagged_capsule { template c10::intrusive_ptr IValue::moveToIntrusivePtr() { auto t = c10::intrusive_ptr::reclaim( - static_cast(payload.as_intrusive_ptr)); + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); clearToNone(); return t; } template c10::intrusive_ptr IValue::toIntrusivePtr() const { auto r = c10::intrusive_ptr::reclaim( - static_cast(payload.as_intrusive_ptr)); + payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton() + ? NullType::singleton() + : static_cast(payload.u.as_intrusive_ptr)); auto p = r; r.release(); return p; @@ -131,12 +135,22 @@ inline c10::intrusive_ptr IValue::toEnumHolder() const& { } inline at::Tensor IValue::toTensor() && { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor( - moveToIntrusivePtr()); + auto result = std::move(payload.as_tensor); + // As far as I can tell, omitting the usual explicit destructor call + // is not UB in and of itself, and it's a slight perf win. The + // destructor is a no-op, because the moved-from Tensor is + // effectively an intrusive_ptr in the null state, so we don't need + // the behavior for correctness reasons either. Leaving this + // explanatory comment, including commented-out destructor call, to + // make this abundantly clear. + // + // payload.as_tensor.~Tensor(); + clearToNone(); + return result; } inline at::Tensor IValue::toTensor() const& { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); - return at::Tensor(toIntrusivePtr()); + return payload.as_tensor; } inline c10::Storage IValue::toStorage() && { AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind()); @@ -148,10 +162,10 @@ inline c10::Storage IValue::toStorage() const& { return c10::Storage(toIntrusivePtr()); } inline c10::Stream IValue::toStream() && { - return c10::Stream::unpack(payload.as_int); + return c10::Stream::unpack(payload.u.as_int); } inline c10::Stream IValue::toStream() const& { - return c10::Stream::unpack(payload.as_int); + return c10::Stream::unpack(payload.u.as_int); } inline c10::intrusive_ptr IValue::toBlob() && { AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind()); @@ -713,7 +727,8 @@ using _guarded_unsigned_long = std::conditional_t< inline const ivalue::Object& IValue::toObjectRef() const { AT_ASSERT(isObject(), "Expected Object but got ", tagKind()); - return *static_cast(payload.as_intrusive_ptr); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference"); + return *static_cast(payload.u.as_intrusive_ptr); } // note: when adding a DEFINE_TO case here you should also add a @@ -980,8 +995,11 @@ inline c10::List IValue::toIntList() const& { } inline std::vector IValue::toIntVector() const { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toIntVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toDoubleList() && { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); @@ -993,8 +1011,11 @@ inline c10::List IValue::toDoubleList() const& { } inline std::vector IValue::toDoubleVector() const { AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toDoubleVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toBoolList() && { AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind()); @@ -1014,8 +1035,11 @@ inline c10::List IValue::toTensorList() const& { } inline std::vector IValue::toTensorVector() const { AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toTensorVector on null intrusive_ptr IValue"); return createVectorFromList( - static_cast(payload.as_intrusive_ptr)); + static_cast(payload.u.as_intrusive_ptr)); } inline c10::List IValue::toList() && { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); @@ -1027,7 +1051,10 @@ inline c10::List IValue::toList() const& { } inline c10::ArrayRef IValue::toListRef() const { AT_ASSERT(isList(), "Expected GenericList but got ", tagKind()); - return static_cast(payload.as_intrusive_ptr) + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toListRef on null intrusive_ptr IValue"); + return static_cast(payload.u.as_intrusive_ptr) ->list; } inline c10::Dict IValue::toGenericDict() && { @@ -1049,7 +1076,7 @@ inline c10::intrusive_ptr IValue::toTuple() const& { inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Tuple), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } template < typename... Args, @@ -1065,14 +1092,14 @@ inline IValue::IValue(const std::tuple& t) inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::String), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(std::string v) : IValue(ivalue::ConstantString::create(std::move(v))) {} inline IValue::IValue(c10::impl::GenericList v) : tag(Tag::GenericList), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template > @@ -1104,7 +1131,7 @@ inline IValue::IValue(std::array v) : IValue(c10::List()) { inline IValue::IValue(c10::impl::GenericDict v) : tag(Tag::GenericDict), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.impl_.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release()); } template inline IValue::IValue(c10::Dict v) @@ -1131,17 +1158,17 @@ inline IValue::IValue(c10::nullopt_t) : IValue() {} inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Object), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::PyObject), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Enum), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue IValue::make_capsule( @@ -1149,7 +1176,7 @@ inline IValue IValue::make_capsule( IValue iv; iv.tag = Tag::Capsule; iv.is_intrusive_ptr = true; - iv.payload.as_intrusive_ptr = blob.release(); + iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release()); return iv; } @@ -1170,30 +1197,33 @@ IValue::IValue(c10::intrusive_ptr custom_class) { auto ivalue_obj = c10::ivalue::Object::create( c10::StrongTypePtr(nullptr, classType), /*num_slots=*/1); ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class))); - payload.as_intrusive_ptr = ivalue_obj.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release()); tag = Tag::Object; is_intrusive_ptr = true; } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Future), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::RRef), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline IValue::IValue(c10::intrusive_ptr v) : tag(Tag::Quantizer), is_intrusive_ptr(true) { - payload.as_intrusive_ptr = v.release(); + payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release()); } inline const std::string& IValue::toStringRef() const { AT_ASSERT(isString(), "Expected String but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toStringRef on null intrusive_ptr IValue"); return static_cast( - payload.as_intrusive_ptr) + payload.u.as_intrusive_ptr) ->string(); } inline c10::optional> IValue:: @@ -1202,8 +1232,11 @@ inline c10::optional> IValue:: return c10::nullopt; } AT_ASSERT(isString(), "Expected optional but got ", tagKind()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), + "called toOptionalStringRef on null intrusive_ptr IValue"); return std::reference_wrapper( - static_cast(payload.as_intrusive_ptr) + static_cast(payload.u.as_intrusive_ptr) ->string()); } @@ -1241,15 +1274,13 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { // for bool type, do equality check return this->toBool() == rhs.toBool(); } else if (this->isTensor() && rhs.isTensor()) { - // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr - // is false for undefined tensor - return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + return this->payload.as_tensor.is_same(rhs.payload.as_tensor); } else if (this->isTensor() && rhs.isNone()) { // special case: undefined tensor and None are the same identity - return !this->is_intrusive_ptr; + return !this->payload.as_tensor.defined(); } else if (this->isNone() && rhs.isTensor()) { // special case: undefined tensor and None are the same identity - return !rhs.is_intrusive_ptr; + return !rhs.payload.as_tensor.defined(); } else if (this->isInt() && rhs.isInt()) { return this->toInt() == rhs.toInt(); } else if (this->isDouble() && rhs.isDouble()) { @@ -1260,7 +1291,7 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const { // for objects holding in IValue, do shallow compare on pointer address to // testify the identity return this->is_intrusive_ptr && rhs.is_intrusive_ptr && - this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr; + this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr; } } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index a3ae813616e0..7d3890f582b8 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -2370,19 +2370,19 @@ struct TORCH_API AnyClassType : public Type { inline bool IValue::isDoubleList() const { // note: avoids calling type() to avoid extra referencing counting for the returned type. - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == FloatType::Kind; } inline bool IValue::isTensorList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == TensorType::Kind; } inline bool IValue::isIntList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == IntType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == IntType::Kind; } inline bool IValue::isBoolList() const { - return isList() && static_cast(payload.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; + return isList() && static_cast(payload.u.as_intrusive_ptr)->elementType->kind() == BoolType::Kind; } template<> diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 637db95991f2..790d97ee3994 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -206,7 +206,7 @@ class intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; @@ -509,7 +509,7 @@ class weak_intrusive_ptr final { "NullType must have a constexpr singleton() method"); #endif static_assert( - std::is_same::value, + std::is_base_of::type>::value, "NullType::singleton() must return a element_type* pointer"); TTarget* target_; From 480a756194f27580753a63d908393dfda3baeb25 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 6 Jan 2021 08:33:26 -0800 Subject: [PATCH 19/20] [PyTorch] IValue::toTensor can now return const Tensor& (#48868) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48868 Building on the previous diff, we can make `toTensor()` return a `const Tensor&`, which should make it easier to avoid reference counting. ghstack-source-id: 119327372 Test Plan: internal benchmarks. Reviewed By: bwasti Differential Revision: D25325379 fbshipit-source-id: ca699632901691bcee432f595f75b0a4416d55dd --- aten/src/ATen/core/ivalue.h | 7 +- aten/src/ATen/core/ivalue_inl.h | 7 +- torch/csrc/jit/codegen/cuda/kernel_cache.cpp | 2 +- torch/csrc/jit/frontend/tracer.cpp | 6 +- torch/csrc/jit/passes/freeze_module.cpp | 8 +- torch/csrc/jit/runtime/argument_spec.h | 2 +- torch/csrc/jit/runtime/interpreter.cpp | 4 +- torch/csrc/jit/runtime/profiling_record.cpp | 2 +- torch/csrc/jit/runtime/static/ops.cpp | 82 +++++++++---------- torch/csrc/jit/serialization/pickler.cpp | 2 +- torch/csrc/jit/serialization/python_print.cpp | 4 +- torch/csrc/jit/serialization/unpickler.cpp | 2 +- 12 files changed, 67 insertions(+), 61 deletions(-) diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 5370294b2f2c..ca68a8df46e1 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -275,8 +275,8 @@ struct TORCH_API IValue final { // Tensors should be compared based on internal storage if (this->isTensor()) { - const auto thisTensor = this->toTensor(); - const auto rhsTensor = rhs.toTensor(); + const auto& thisTensor = this->toTensor(); + const auto& rhsTensor = rhs.toTensor(); return thisTensor.is_alias_of(rhsTensor); } @@ -345,7 +345,8 @@ struct TORCH_API IValue final { return Tag::Tensor == tag; } at::Tensor toTensor() &&; - at::Tensor toTensor() const&; + at::Tensor& toTensor() &; + const at::Tensor& toTensor() const&; at::TensorImpl* unsafeToTensorImpl() const { return payload.as_tensor.unsafeGetTensorImpl(); } diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index fe55d783e780..b96f4b834989 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -148,7 +148,11 @@ inline at::Tensor IValue::toTensor() && { clearToNone(); return result; } -inline at::Tensor IValue::toTensor() const& { +inline at::Tensor& IValue::toTensor() & { + AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); + return payload.as_tensor; +} +inline const at::Tensor& IValue::toTensor() const& { AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind()); return payload.as_tensor; } @@ -744,6 +748,7 @@ inline const ivalue::Object& IValue::toObjectRef() const { inline type IValue::to() const& { \ return this->method_name(); \ } + DEFINE_TO(at::Tensor, toTensor) DEFINE_TO(at::Storage, toStorage) DEFINE_TO(c10::Stream, toStream) diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index f1a0a634727a..5bddc510fe56 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -209,7 +209,7 @@ InputsIdLookup::IdLookupReturn InputsIdLookup::lookupId( std::stringstream encoded_inputs; for (const auto& input : inputs) { if (input.isTensor()) { - auto input_tensor = input.toTensor(); + auto& input_tensor = input.toTensor(); encoded_inputs << ";"; auto sep = ""; diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 1bab391bd393..0c88371399de 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -137,7 +137,7 @@ Value* TracingState::getValue(const IValue& var) { return graph->insertNode(dict_node)->output(); } if (var.isTensor()) { - auto ten = var.toTensor(); + auto& ten = var.toTensor(); if (!ten.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -237,7 +237,7 @@ bool TracingState::hasValue(const IValue& var) const { Value* TracingState::getOutput(const IValue& iv, size_t i) { bool tracing_mode_strict = getTracingState()->strict; if (iv.isTensor()) { - at::Tensor var = iv.toTensor(); + const at::Tensor& var = iv.toTensor(); if (!var.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); @@ -506,7 +506,7 @@ void setValueTrace(const IValue& v, Value* value) { } void TracingState::setValue(const IValue& v, Value* value) { if (v.isTensor()) { - auto var = v.toTensor(); + auto& var = v.toTensor(); AT_ASSERT(var.defined()); env_stack.back()[v] = value; } else if (v.isTensorList()) { diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 2778c7712f23..f66f54eeb567 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -289,11 +289,11 @@ class AttributePropagator { IValue overrideGradient(IValue attr) { if (attr.isTensor()) { - auto t = attr.toTensor(); + auto& t = attr.toTensor(); if (t.requires_grad()) { - t = t.detach(); - t.set_requires_grad(false); - attr = IValue(t); + auto detached = t.detach(); + detached.set_requires_grad(false); + attr = IValue(std::move(detached)); } } else if (attr.isTuple()) { auto tuple = std::move(attr).toTuple(); diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 401933c6d67e..a0e60e879146 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -237,7 +237,7 @@ struct CompleteArgumentSpec { for (int32_t i = 0; i < num_inputs; i++) { if (!inputs[i].isTensor()) continue; - auto tensor = inputs[i].toTensor(); + auto& tensor = inputs[i].toTensor(); all_dims += tensor.defined() ? tensor.ndimension() : 0; } // allocate enough room for all TensorPODs and dimensions diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 24ca9dbf9793..ce4718becaf7 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1418,7 +1418,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // Check every input's shape against profiled (expected) shape. for (i = 0; i < num_inputs; i++) { auto& input = peek(stack, i, num_inputs); - auto t = input.toTensor(); + auto& t = input.toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X + i]; auto expected_type = expected->cast(); if (t.defined() && !expected_type->matchTensor(t)) { @@ -1439,7 +1439,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { // so it's safe to pass this guard check push(stack, true); } else { - auto t = stack.back().toTensor(); + auto& t = stack.back().toTensor(); const TypePtr& expected = frame.function->type_table_[inst.X]; auto expected_type = expected->cast(); if (t.defined() && diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index 8d276dd58b50..d233f089f187 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -165,7 +165,7 @@ void ProfilingRecord::insertShapeProfile(Node* n, size_t offset) { if (v.isTensor()) { std::lock_guard lock(this->mutex_); auto& profiled_types = profiled_types_per_frame_[frame_id]; - auto t = v.toTensor(); + auto& t = v.toTensor(); if (t.defined()) { auto pttp = tensorTypeInCurrentExecutionContext(t); GRAPH_DEBUG( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 5c118f513565..89519d3765b5 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -79,13 +79,13 @@ struct static_add final : public at::native::structured_add_out { REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); auto in2_s = p_node->Input(2, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); static_add op{out_t}; op.meta(in0_t, in1_t, in2_s); op.impl(in0_t, in1_t, in2_s, out_t); @@ -94,12 +94,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::mul_out(out_t, in0_t, in1_t); }; @@ -107,15 +107,15 @@ REGISTER_OPERATOR_FUNCTOR(aten::mul, aten_mul, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); - auto in2_t = p_node->Input(2, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); + auto& in2_t = p_node->Input(2, reg).toTensor(); auto in3_s = p_node->Input(3, reg).toScalar(); auto in4_s = p_node->Input(4, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::addmm_cpu_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); }; @@ -123,13 +123,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::addmm, aten_addmm, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_s = p_node->Input(1, reg).toScalar(); auto in2_s = p_node->Input(2, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::clamp_out(out_t, in0_t, in1_s, in2_s); }; @@ -137,12 +137,12 @@ REGISTER_OPERATOR_FUNCTOR(aten::clamp, aten_clamp, [](Node* n) -> SROperator { REGISTER_OPERATOR_FUNCTOR(aten::bmm, aten_bmm, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); - auto in1_t = p_node->Input(1, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); + auto& in1_t = p_node->Input(1, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::bmm_out_cpu(out_t, in0_t, in1_t); }; @@ -154,7 +154,7 @@ REGISTER_OPERATOR_FUNCTOR( [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { auto input_size = p_node->input_regs().size(); - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); double in1_d = input_size > 1 ? p_node->Input(1, reg).toDouble() : 0; double in2_d = input_size > 2 ? p_node->Input(2, reg).toDouble() : std::numeric_limits::infinity(); @@ -164,7 +164,7 @@ REGISTER_OPERATOR_FUNCTOR( if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::nan_to_num_out(out_t, in0_t, in1_d, in2_d, in3_d); }; @@ -176,18 +176,18 @@ REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator { if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_tl[0]); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::_cat_out_cpu(out_t, in0_tl, in1_i); }; }); REGISTER_OPERATOR_FUNCTOR(aten::tanh, aten_tanh, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::tanh_out(out_t, in0_t); }; @@ -217,7 +217,7 @@ SROperator aten_stack(Node* n) { for (auto i = 0; i < inputs.size(); i++) { inputs[i] = inputs[i].unsqueeze(dim); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::_cat_out_cpu(out_t, inputs, dim); }; @@ -230,11 +230,11 @@ REGISTER_OPERATOR_FUNCTOR( aten_sigmoid, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::sigmoid_out(out_t, in0_t); }; @@ -247,57 +247,57 @@ REGISTER_OPERATOR_FUNCTOR( if (in1) { auto in1_s = in1->toScalar(); return [=](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::leaky_relu_out(out_t, in0_t, in1_s); }; } else { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_s = p_node->Input(1, reg).toScalar(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::leaky_relu_out(out_t, in0_t, in1_s); }; } }); REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::threshold_out(out_t, in0_t, 0, 0); }; }); REGISTER_OPERATOR_FUNCTOR(aten::logit, aten_logit, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); double in1_d = p_node->input_regs().size() > 1 ? p_node->Input(1, reg).toDouble() : -1.0; if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); out_t.resize_({0}); at::native::logit_out(out_t, in0_t, in1_d); }; }); REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); if (p_node->Output(0, reg).isNone()) { p_node->Output(0, reg) = create_empty_from(in0_t); } - auto out_t = p_node->Output(0, reg).toTensor(); + auto& out_t = p_node->Output(0, reg).toTensor(); at::native::resize_as_(out_t, in0_t, c10::nullopt); at::native::copy_(out_t, in0_t, false); }; @@ -317,14 +317,14 @@ std::function&)> getNativeOperation(Node* n) { if (n->kind() == c10::Symbol::fromQualString("aten::transpose")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); p_node->Output(0, reg) = at::native::transpose(in0_t, in1_i, in2_i); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::flatten")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); p_node->Output(0, reg) = at::native::flatten(in0_t, in1_i, in2_i); @@ -386,19 +386,19 @@ getNativeOperation(Node* n) { }; } else if (n->kind() == c10::Symbol::fromQualString("aten::permute")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_iv = p_node->Input(1, reg).toIntVector(); p_node->Output(0, reg) = at::native::permute(in0_t, in1_iv); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::reshape")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_iv = p_node->Input(1, reg).toIntVector(); p_node->Output(0, reg) = at::native::reshape(in0_t, in1_iv); }; } else if (n->kind() == c10::Symbol::fromQualString("aten::slice")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toInt(); auto in2_i = p_node->Input(2, reg).toInt(); auto in3_i = p_node->Input(3, reg).toInt(); @@ -408,13 +408,13 @@ getNativeOperation(Node* n) { }; } else if (n->kind() == c10::Symbol::fromQualString("aten::narrow")) { return [](const ProcessedNode* p_node, std::vector& reg) { - auto self = p_node->Input(0, reg).toTensor(); // self + auto& self = p_node->Input(0, reg).toTensor(); // self auto dim = p_node->Input(1, reg).toInt(); // dim int64_t start = 0; if (p_node->Input(2, reg).isScalar()) { start = p_node->Input(2, reg).toInt(); } else { - auto t = p_node->Input(2, reg).toTensor(); + auto& t = p_node->Input(2, reg).toTensor(); start = t.item(); } auto length = p_node->Input(3, reg).toInt(); // length @@ -440,7 +440,7 @@ getNativeOperation(Node* n) { } else if (n->kind() == c10::Symbol::fromQualString("aten::to")) { return [](const ProcessedNode* p_node, std::vector& reg) { DCHECK(p_node->input_regs().size() == 5); - auto in0_t = p_node->Input(0, reg).toTensor(); + auto& in0_t = p_node->Input(0, reg).toTensor(); auto in1_i = p_node->Input(1, reg).toScalarType(); auto in2_i = p_node->Input(2, reg).toBool(); auto in3_i = p_node->Input(3, reg).toBool(); diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 6e5c3b927c38..811569485888 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -354,7 +354,7 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { // // The format here is the same one used by `torch.save()`. The code for the // format can be found in `torch/serialization.py`. - auto tensor = ivalue.toTensor(); + auto& tensor = ivalue.toTensor(); bool quantized = tensor.is_quantized(); // The arguments to this function are: // storage, storage_offset, size, stride, requires_grad, backward_hooks diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index c86cbc460c9c..18d656c98f32 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -309,12 +309,12 @@ struct PythonPrintImpl { // because it doesn't hash any information about the tensors. // We will probably need to optimize this at some point using hashing. if (val.isTensor()) { - auto t = val.toTensor(); + auto& t = val.toTensor(); for (size_t i = 0; i < constant_table_.size(); ++i) { if (!constant_table_[i].isTensor()) { continue; } - auto t2 = constant_table_[i].toTensor(); + auto& t2 = constant_table_[i].toTensor(); if (t.options().type_equal(t2.options()) && t.equal(t2)) { return i; } diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 3ff5da29fe1f..841e87592be9 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -632,7 +632,7 @@ void Unpickler::rebuildTensor(bool quantized) { auto tup = pop(stack_).toTuple(); const auto& elements = tup->elements(); size_t idx = 0; - auto storage_tensor = elements.at(idx++).toTensor(); + auto& storage_tensor = elements.at(idx++).toTensor(); int64_t storage_offset = elements.at(idx++).toInt(); std::vector size = tupleToIntList(elements.at(idx++)); std::vector stride = tupleToIntList(elements.at(idx++)); From 68a6e4637903dba279c60daae5cff24e191ff9b4 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 6 Jan 2021 08:39:11 -0800 Subject: [PATCH 20/20] Push anonymous namespace into codegen, not template (#49498) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49498 In the near future, I want to code generate some functions that are visible externally to this compilation unit. I cannot easily do this if all the codegen code is wrapped in a global anonymous namespace, so push the namespace in. Registration has to stay in an anonymous namespace to avoid name conflicts. This could also have been solved by making the wrapper functions have more unique names but I didn't do this in the end. Signed-off-by: Edward Z. Yang Test Plan: Imported from OSS Reviewed By: albanD, smessmer Differential Revision: D25616104 Pulled By: ezyang fbshipit-source-id: 323c0dda05a081502aab702f359a08dfac8c41a4 --- aten/src/ATen/templates/RegisterDispatchKey.cpp | 7 +++++-- tools/codegen/gen.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index e923f6d73bd0..ed4359c6883e 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -37,10 +37,13 @@ namespace at { -namespace { - ${dispatch_definitions} +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { + TORCH_LIBRARY_IMPL(aten, ${DispatchKey}, m) { ${dispatch_registrations} } diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 8f521e6651bc..4768670b6f26 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -435,6 +435,8 @@ def gen_one(f: NativeFunction) -> Optional[str]: # For an overview of what this template code looks like, see # https://github.com/pytorch/rfcs/pull/9 return f"""\ +namespace {{ + {self.gen_structured_class( f, k, class_name=class_name, @@ -448,6 +450,8 @@ def gen_one(f: NativeFunction) -> Optional[str]: {impl_call} return {ret_expr}; }} + +}} // anonymous namespace """ elif self.target is Target.REGISTRATION: @@ -540,9 +544,13 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]: """ return f"""\ +namespace {{ + {returns_type} {name}({args_str}) {{ {cuda_guard}{return_kw}{impl_name}({args_exprs_str}); }} + +}} // anonymous namespace """ elif self.target is Target.REGISTRATION: