Update on "Rewrite implementation of faithful cpp signatures"
This rewrite follows my comments at #44087 (comment).
I did it by reverting #44087 and then reimplementing the feature on top.
You may find it easier to review by diffing against master with only #44087
reverted.

There are two main ideas.

First, we now factor C++ argument processing into two phases operating
on three representations of the data:

1. `FunctionSchema` - this is the source from native_functions.yaml
2. `Union[Argument, ThisArgument, TensorOptionsArguments]` - these are
   the arguments after some basic semantic analysis that groups
   them (for TensorOptions) or identifies the this argument (if this
   is a method).  There is only ever one of these per function.
3. `Union[CppArgument, CppThisArgument, CppTensorOptionsArguments]` -
   these are the arguments after we've elaborated them to C++.  There
   may be multiple of these per actual C++ signature.

You can think of (2) as common processing, whereas (3) bakes in specific
assumptions about whether or not you have a faithful or non-faithful
signature.
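
To make this concrete, here is a minimal sketch of what representations
(2) and (3) might look like as dataclasses; the class and field names
are simplified stand-ins for the real definitions in tools.codegen, not
the actual API:

```
from dataclasses import dataclass
from typing import Optional, Union

# Representation (2): schema arguments after semantic grouping.
@dataclass(frozen=True)
class Argument:
    name: str
    type: str                     # schema type, e.g. 'ScalarType?'

@dataclass(frozen=True)
class ThisArgument:               # the implicit `this` of a Tensor method
    argument: Argument

@dataclass(frozen=True)
class TensorOptionsArguments:     # dtype/layout/device/pin_memory, grouped
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

# Representation (3): the same arguments elaborated to the C++ API.
@dataclass(frozen=True)
class CppArgument:
    type: str                     # C++ type, e.g. 'c10::optional<ScalarType>'
    name: str
    default: Optional[str]
    argument: Union[Argument, TensorOptionsArguments]  # never ThisArgument
```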

Second, we now have CppSignature and CppSignatureGroup representing
the *total* public C++ API signature.  These dataclasses know how to
render definitions/declarations, so you no longer have to type them
out manually in the Functions/TensorMethods codegen.
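
A hedged sketch of the idea (decl and the field names here are
illustrative, not the actual methods on CppSignature):

```
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass(frozen=True)
class CppSignature:
    returns_type: str         # e.g. 'Tensor'
    name: str                 # e.g. 'empty_meta'
    arguments: Sequence[str]  # pre-rendered 'type name=default' strings

    def decl(self) -> str:
        # One place that knows how to render the public declaration.
        return f"{self.returns_type} {self.name}({', '.join(self.arguments)})"

@dataclass(frozen=True)
class CppSignatureGroup:
    signature: CppSignature                     # gathered TensorOptions form
    faithful_signature: Optional[CppSignature]  # scattered form, if distinct

sig = CppSignature('Tensor', 'empty_meta',
                   ['IntArrayRef size', 'const TensorOptions & options={}'])
print(sig.decl())
# Tensor empty_meta(IntArrayRef size, const TensorOptions & options={})
```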

Here is an exhaustive accounting of the changes.

tools.codegen.api.types

- CppSignature and CppSignatureGroup got moved to tools.codegen.api.types
- Add new CppThisArgument and CppTensorOptionsArguments (modeled on
  ThisArgument and TensorOptionsArguments) so that we can retain
  high-level semantic structure even after elaborating terms with C++
  API information.  Once this is done, we can refine
  CppArgument.argument to no longer contain a ThisArgument (a
  ThisArgument is always translated to a CppThisArgument; note that
  this doesn't apply to TensorOptionsArguments, as those may or may
  not be expanded, so you can still get a single CppArgument for
  'options')
- Add a no_default() functional mutator to easily strip default
  arguments from CppArgument and friends
- Add an explicit_arguments() method to CppArgument and friends to
  extract the (flat) list of arguments that must be written explicitly
  in the signature.  This is everything except (Cpp)ThisArgument, and
  is also convenient when you don't care about the extra structure of
  CppTensorOptionsArguments.  (A sketch of both helpers follows this
  list.)
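
Here is a sketch of both helpers under assumed field names;
dataclasses.replace gives the functional (copy, don't mutate) behavior
described above:

```
import dataclasses
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class CppArgument:
    type: str
    name: str
    default: Optional[str]

    def no_default(self) -> 'CppArgument':
        # Functional mutator: a copy of this argument, default removed.
        return dataclasses.replace(self, default=None)

    def explicit_arguments(self) -> List['CppArgument']:
        # Flat list of what must be written in the signature; a
        # CppThisArgument would return [] here instead.
        return [self]

    def str_with_default(self) -> str:
        if self.default is not None:
            return f'{self.type} {self.name}={self.default}'
        return f'{self.type} {self.name}'

arg = CppArgument(type='bool', name='pin_memory', default='false')
print(arg.str_with_default())               # bool pin_memory=false
print(arg.no_default().str_with_default())  # bool pin_memory
```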

tools.codegen.api.cpp

- group_arguments is back, but it no longer sends things directly to a
  CppSignatureGroup; instead, it moves us from representation (1) to (2)
  (perhaps it should live in model).  Here I changed my mind from my
  PR comment: I discovered it was not necessary to do classification at
  grouping time, and it was simpler and easier to do it later.  (A toy
  version follows this list.)
- argument got split into argument_not_this/argument/argument_faithful.
  argument and argument_faithful do what their names suggest, and I
  needed argument_not_this as a more refined version of argument so
  that I could get the types to work out for TensorOptionsArguments
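
To make the move from (1) to (2) concrete, here is a toy version of
the grouping pass, reusing the dataclasses from the first sketch; the
real group_arguments works over the actual FunctionSchema model, which
this only approximates:

```
from typing import List, Union

GroupedArgument = Union['Argument', 'ThisArgument', 'TensorOptionsArguments']

def group_arguments(args: List['Argument'], *, method: bool) -> List[GroupedArgument]:
    grouped: List[GroupedArgument] = []
    if method:
        # Identify the implicit self tensor as the this argument.
        grouped.append(ThisArgument(Argument(name='self', type='Tensor')))
    i = 0
    tensor_options = ('dtype', 'layout', 'device', 'pin_memory')
    while i < len(args):
        if tuple(a.name for a in args[i:i + 4]) == tensor_options:
            # Group the four TensorOptions fields into one argument.
            grouped.append(TensorOptionsArguments(*args[i:i + 4]))
            i += 4
        else:
            grouped.append(args[i])
            i += 1
    return grouped
```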

tools.codegen.api.dispatcher

- Here we start seeing the payoff.  The old version of this code had a
  "scatter" mode and a "gather" mode.  We don't need that anymore:
  cppargument_exprs is 100% type-directed via the passed-in C++
  arguments, and I was able to write the functions without any
  reference to use_c10_dispatcher.  (A sketch follows this list.)
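
A sketch of what "type-directed" means here, with hypothetical class
names echoing the earlier sketches: each kind of C++ argument
determines its own dispatcher expressions, and no use_c10_dispatcher
flag is consulted:

```
from dataclasses import dataclass
from typing import List, Union

@dataclass(frozen=True)
class CppArgument:
    name: str

@dataclass(frozen=True)
class CppThisArgument:
    pass

@dataclass(frozen=True)
class CppTensorOptionsArguments:
    name: str                     # e.g. 'options'

CppArg = Union[CppArgument, CppThisArgument, CppTensorOptionsArguments]

def cppargument_exprs(a: CppArg, *, scatter: bool) -> List[str]:
    # Everything is decided by the kind of argument we were handed.
    if isinstance(a, CppThisArgument):
        return ['const_cast<Tensor&>(*this)']
    if isinstance(a, CppTensorOptionsArguments):
        if scatter:
            # Target takes scattered dtype/layout/device/pin_memory.
            return [f'optTypeMetaToScalarType({a.name}.dtype_opt())',
                    f'{a.name}.layout_opt()',
                    f'{a.name}.device_opt()',
                    f'{a.name}.pinned_memory_opt()']
        return [a.name]           # target takes TensorOptions gathered
    return [a.name]
```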

tools.codegen.gen

- Instead of having exprs_str and types_str functions, I moved these to
  live directly on CppSignature, since it seemed pretty logical.
- The actual codegen for TensorMethods/Functions is greatly simplified,
  since (1) all of the heavy lifting now happens in
  CppSignature(Group) construction, and (2) I don't need to proxy one
  way or the other; the new dispatcher translation code handles both
  cases without trouble.  There is a little faffing about with ordering
  to reduce the diff between the old and new output, which could be
  removed afterwards.
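
For instance, the type list inside the .typed<...>() calls in the
diffs below could be rendered by something of this shape (a
hypothetical stand-in for the method now on CppSignature):

```
from typing import Sequence

def types_str(returns_type: str, argument_types: Sequence[str]) -> str:
    # Render the unboxed function type used in .typed<...>() below.
    return f"{returns_type} ({', '.join(argument_types)})"

print(types_str('Tensor', ['double', 'bool', 'int64_t',
                           'c10::optional<ScalarType>']))
# Tensor (double, bool, int64_t, c10::optional<ScalarType>)
```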

Here are the codegen diffs.  For `use_c10_dispatcher: full`:

```
+// aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, const TensorOptions & options) {
-    return _cudnn_init_dropout_state(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::_cudnn_init_dropout_state", "")
+        .typed<Tensor (double, bool, int64_t, c10::optional<ScalarType>, c10::optional<Layout>, c10::optional<Device>, c10::optional<bool>)>();
+    return op.call(dropout, train, dropout_seed, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
 }
```

Otherwise:

```
+// aten::empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
 Tensor empty_meta(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> memory_format) {
-    return empty_meta(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
+    static auto op = c10::Dispatcher::singleton()
+        .findSchemaOrThrow("aten::empty_meta", "")
+        .typed<Tensor (IntArrayRef, const TensorOptions &, c10::optional<MemoryFormat>)>();
+    return op.call(size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), memory_format);
 }
```

Things that I probably did not get right:

- The Union[Argument, TensorOptionsArguments, ThisArgument] and
  the Cpp variants are starting to get a little unwieldy.  I'm not sure
  whether this means I should add a supertype (or at the very least an
  alias); in some cases I purposely omit one of these from the Union.
- Code may not necessarily live in the most logical files.  There isn't
  much rhyme or reason to it.
- The fields on CppSignature.  They're not very well constrained, and
  it would be better if people didn't use them directly.
- Disambiguation.  We should do this properly in #44087; we shouldn't
  need special logic for removing defaults from faithful signatures,
  as there is a more general story here.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D24144035](https://our.internmc.facebook.com/intern/diff/D24144035)

[ghstack-poisoned]
ezyang committed Oct 9, 2020
2 parents 0d34b1e + a5c0dbc commit 99e115a
Showing 51 changed files with 1,175 additions and 381 deletions.
26 changes: 18 additions & 8 deletions .circleci/cimodel/data/pytorch_build_data.py
@@ -18,7 +18,11 @@
("clang", [
("5", [
("3.6", [
("asan", [XImportant(True)]),
("asan", [
(True, [
("shard_test", [XImportant(True)]),
]),
]),
]),
]),
("7", [
@@ -45,14 +49,14 @@
]),
("10.2", [
("3.6", [
("important", [X(True)]),
("shard_test", [XImportant(True)]),
("libtorch", [X(True)]),
]),
]),
("11.0", [
("3.8", [
X(True),
("libtorch", [XImportant(True)])
("libtorch", [XImportant(True)]),
]),
]),
]),
@@ -158,6 +162,7 @@ def child_constructor(self):
"libtorch": LibTorchConfigNode,
"important": ImportantConfigNode,
"build_only": BuildOnlyConfigNode,
"shard_test": ShardTestConfigNode,
"cuda_gcc_override": CudaGccOverrideConfigNode,
"coverage": CoverageConfigNode,
"pure_torch": PureTorchConfigNode,
@@ -195,7 +200,7 @@ def init2(self, node_name):
self.props["is_asan"] = node_name

def child_constructor(self):
return ImportantConfigNode
return ExperimentalFeatureConfigNode


class ONNXConfigNode(TreeConfigNode):
@@ -260,17 +265,24 @@ def init2(self, node_name):
def child_constructor(self):
return ExperimentalFeatureConfigNode

class BuildOnlyConfigNode(TreeConfigNode):

class BuildOnlyConfigNode(TreeConfigNode):
def init2(self, node_name):
self.props["build_only"] = node_name

def child_constructor(self):
return ExperimentalFeatureConfigNode


class CoverageConfigNode(TreeConfigNode):
class ShardTestConfigNode(TreeConfigNode):
def init2(self, node_name):
self.props["shard_test"] = node_name

def child_constructor(self):
return ImportantConfigNode


class CoverageConfigNode(TreeConfigNode):
def init2(self, node_name):
self.props["is_coverage"] = node_name

@@ -290,7 +302,6 @@ def get_children(self):


class XenialCompilerConfigNode(TreeConfigNode):

def modify_label(self, label):
return label or "<unspecified>"

@@ -304,7 +315,6 @@ def child_constructor(self):


class BionicCompilerConfigNode(TreeConfigNode):

def modify_label(self, label):
return label or "<unspecified>"

6 changes: 4 additions & 2 deletions .circleci/cimodel/data/pytorch_build_definitions.py
@@ -288,7 +288,6 @@ def instantiate_configs():
rocm_version = None
if compiler_name == "cuda":
cuda_version = fc.find_prop("compiler_version")
restrict_phases = ["build", "test1", "test2"]

elif compiler_name == "rocm":
rocm_version = fc.find_prop("compiler_version")
@@ -311,7 +310,6 @@ def instantiate_configs():
parms_list.append("asan")
python_version = fc.find_prop("pyver")
parms_list[0] = fc.find_prop("abbreviated_pyver")
restrict_phases = ["build", "test1", "test2"]

if is_onnx:
parms_list.append("onnx")
@@ -328,7 +326,11 @@ def instantiate_configs():
parallel_backend = fc.find_prop("parallel_backend") or None
build_only = fc.find_prop("build_only") or False
is_coverage = fc.find_prop("is_coverage") or False
shard_test = fc.find_prop("shard_test") or False
# TODO: fix pure_torch python test packaging issue.
if shard_test:
restrict_phases = ["build"] if restrict_phases is None else restrict_phases
restrict_phases.extend(["test1", "test2"])
if build_only or is_pure_torch:
restrict_phases = ["build"]
if is_coverage and restrict_phases is None:
66 changes: 8 additions & 58 deletions .circleci/config.yml
@@ -6668,7 +6668,7 @@ workflows:
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test1
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test
requires:
- pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build
filters:
@@ -6677,21 +6677,7 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test2
requires:
- pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test2"
build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6802,21 +6788,7 @@ workflows:
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test1
requires:
- pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test2
name: pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test
requires:
- pytorch_libtorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
filters:
@@ -6825,7 +6797,7 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test2"
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6842,7 +6814,7 @@
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test
requires:
- pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
filters:
@@ -6851,21 +6823,7 @@
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2
requires:
- pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2"
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
@@ -6876,18 +6834,10 @@
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test1
requires:
- pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test1"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_test:
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test2
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test
requires:
- pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test2"
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
42 changes: 23 additions & 19 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -16,7 +16,7 @@ at::Tensor embedding_bag_4bit_helper(
const at::Tensor& weight,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
@@ -38,10 +38,10 @@ at::Tensor embedding_bag_4bit_helper(
auto weight_contig = weight.contiguous();
uint8_t* input_data = weight_contig.data_ptr<uint8_t>();

// Get compressed indices for sparse op.
// Get compressed indices for pruned_weights op.
int32_t* compressed_indices_mapping_data = nullptr;
int compressed_index_size = 0;
if (sparse) {
if (pruned_weights) {
compressed_index_size = compressed_indices_mapping.value().numel();
compressed_indices_mapping_data =
compressed_indices_mapping.value().data_ptr<int32_t>();
@@ -77,7 +77,7 @@ at::Tensor embedding_bag_4bit_helper(
const int index_size = indices.numel();
constexpr int prefetch_distance = 16;
#ifdef USE_FBGEMM
if (!sparse) {
if (!pruned_weights) {
// Generate the fbgemm kernel
auto kernel_64_ = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
/*bit rate=*/4,
@@ -158,7 +158,7 @@ at::Tensor embedding_bag_4bit_helper(

for (int i = 0; i < lengths_data[m]; ++i, ++current) {
int64_t idx;
if (!sparse) {
if (!pruned_weights) {
idx = indices_data[current];
TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data");
} else {
@@ -201,7 +201,7 @@ at::Tensor embedding_bag_byte_helper(
const at::Tensor& packed_w,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
bool include_last_offset) {
TORCH_CHECK(
@@ -301,30 +301,30 @@ at::Tensor embedding_bag_byte_helper(
at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
bool include_last_offset) {
return embedding_bag_byte_helper(
packed_w,
indices,
offsets_in,
sparse,
pruned_weights,
per_sample_weights_,
include_last_offset);
}

at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
bool pruned_weights,
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
return embedding_bag_4bit_helper(
packed_w,
indices,
offsets_in,
sparse,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
@@ -340,7 +340,7 @@ Tensor embedding_bag_byte_rowwise_offsets(
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool sparse,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
bool include_last_offset) {
TORCH_CHECK(weight.scalar_type() == at::kByte);
Expand All @@ -349,7 +349,7 @@ Tensor embedding_bag_byte_rowwise_offsets(
weight,
indices,
offsets_in,
sparse,
pruned_weights,
per_sample_weights_,
include_last_offset);
}
Expand All @@ -360,15 +360,15 @@ Tensor embedding_bag_4bit_rowwise_offsets(
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool sparse,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
return embedding_bag_4bit_helper(
weight,
indices,
offsets_in,
sparse,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
@@ -383,18 +383,22 @@ class QEmbeddingBag final {
const c10::optional<Tensor>& offsets,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool sparse,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
if (bit_rate == 8) {
return packed_weight->embeddingbag_byte(
indices, offsets, sparse, per_sample_weights_, include_last_offset);
indices,
offsets,
pruned_weights,
per_sample_weights_,
include_last_offset);
} else if (bit_rate == 4) {
return packed_weight->embeddingbag_4bit(
indices,
offsets,
sparse,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
@@ -411,13 +415,13 @@ class QEmbedding final {
static at::Tensor run(
const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
const Tensor& indices,
bool sparse) {
bool pruned_weights) {
const auto offsets_size = indices.numel();
at::Tensor offsets = at::arange(0, offsets_size, at::kLong);
at::Tensor output;
if (bit_rate == 8) {
return packed_weight->embeddingbag_byte(
indices, offsets, sparse, c10::nullopt, false);
indices, offsets, pruned_weights, c10::nullopt, false);
} else {
TORCH_INTERNAL_ASSERT(
"Currently only support 8-bit embedding quantization");
