Prevent JIT from overspecializing to every single size configuration #10844

Closed
apaszke wants to merge 10 commits into master from jit_type_unspecialize

Conversation

@apaszke (Contributor) commented on Aug 24, 2018

Please review the expects carefully to make sure there are no regressions. I tried to go over them one by one when they changed, but it's sometimes easy to miss finer details.

Summary of changes:

  • Renamed `TensorType` to `CompleteTensorType`. Added a new `TensorType` which records only the scalar type, number of dimensions, and device of a value. The rationale behind the rename is to encourage people to use `CompleteTensorType` less, as most passes will only have limited information available. To make the transition easier, `complete_type->cast<TensorType>()` works, which lets our passes handle both kinds of specialization when they don't need the extra detail.
  • Renamed `ArgumentSpec` to `CompleteArgumentSpec`. Added a new `ArgumentSpec`, which matches arguments only at the level of the new `TensorType`.
  • Shape analysis can process graphs containing both `CompleteTensorType` and `TensorType`.
  • The fuser was the part that relied most heavily on full shape information being available. Now we simply try to fuse the largest possible graphs, and do run-time checks to make sure the inputs match the code we generate. If they don't, we fall back to regular interpretation. The shape checks are implemented using an optimized method that exploits algebraic properties of shapes under broadcasting, and the way broadcasting interacts with pointwise ops. A full written proof of correctness of the shape-checking algorithm is included in a comment in `graph_fuser.cpp`. A minimal sketch of the run-time check idea follows this list.
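
To make the run-time check concrete, here is a minimal sketch of the broadcast-folding idea, assuming `at::infer_size` (which throws when two sizes are not broadcastable) is used to fold all input sizes into one "map size". This is not the code from this PR; the helper name and structure are hypothetical, and the real cache also has to account for chunk/concat descriptors.

```cpp
// Hypothetical sketch of a run-time shape check for a fused kernel.
// Returns the common broadcasted size of all inputs, or nullopt if the
// inputs no longer broadcast together (in which case the caller falls
// back to interpreting the original subgraph).
#include <ATen/ATen.h>
#include <vector>

at::optional<std::vector<int64_t>> tryGetMapSize(at::TensorList args) {
  std::vector<int64_t> map_size;  // an empty size broadcasts with anything
  for (const auto& arg : args) {
    try {
      // at::infer_size applies the usual right-aligned broadcasting rule.
      map_size = at::infer_size(map_size, arg.sizes());
    } catch (const std::exception&) {
      return at::nullopt;
    }
  }
  return map_size;
}
```

When a map size is found, inputs whose sizes differ from it are broadcast up with `expand` before the kernel launch (compare the `arg = arg.expand(map_size)` fragment further down); when it is not, the cached fallback code is run through the interpreter.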

@zdevito @ezyang @mruberry @ngimel @csarofeen

%6 : Float(4!, 4) = aten::expand(%2, %3, %4)
%7 : Float(4, 4) = prim::FusionGroup_0[device=0](%6, %0, %5)
return (%7);
graph(%0 : Float(*, *)

// - Associativity: A simple visual proof is that you can expand 3 tensors
// at the same time by stacking their sizes (with alignment to the right),
// just as you'd do in the case of 2 tensors, but with an intermediate
// (the algorithm ends up being pretty much the same).
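
As a concrete illustration of the associativity claim (a hypothetical check, not part of the PR), the grouping in which three sizes are broadcast together does not change the result:

```cpp
#include <ATen/ATen.h>
#include <cassert>
#include <vector>

// B(B(a, b), c) and B(a, B(b, c)) agree; both give (3, 4, 5) here.
void broadcast_is_associative() {
  std::vector<int64_t> a = {4, 1}, b = {1, 5}, c = {3, 1, 1};
  auto left  = at::infer_size(at::infer_size(a, b), c);  // (4, 5) -> (3, 4, 5)
  auto right = at::infer_size(a, at::infer_size(b, c));  // (3, 1, 5) -> (3, 4, 5)
  assert(left == right);
}
```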

// Lemma 4. Expands can be collapsed, i.e. E(E(x, s1), s2) = E(x, B(s1, s2)).
// Proof. A simple exercise for the reader :)
//
// Theorem. If all (pre-concat-)outputs have equal shapes, then we can push the expands to
// (pre-chunk-)inputs, and have all intermediates of the same shape
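
A tiny illustration of Lemma 4 (hypothetical, not from the PR): expanding twice is the same as a single expand to the broadcast of the two target sizes.

```cpp
#include <torch/torch.h>
#include <cassert>

// E(E(x, s1), s2) == E(x, B(s1, s2)) on a concrete tensor.
void expand_collapse() {
  auto x = torch::randn({4, 1});
  auto twice = x.expand({4, 5}).expand({3, 4, 5});           // E(E(x, s1), s2)
  auto once  = x.expand(at::infer_size({4, 5}, {3, 4, 5}));  // E(x, B(s1, s2))
  assert(twice.equal(once));
}
```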

@@ -183,22 +248,22 @@ struct TORCH_API TensorType : public Type {
}
static TypePtr fromNumberType(TypePtr typ);

static CompleteTensorTypePtr sliceSubtypes(const CompleteTensorTypePtr& type) {

@@ -63,39 +63,30 @@ IValue representativeValue(Value* v) {

void PropagateShapeOnBlock(Block * block, bool insert_expands=true);

-// for each node in the schema with type Tensor, extract the TensorType
+// for each node in the schema with type Tensor, extract the CompleteTensorType
// returns at::nullopt if any Tensor in the schema does not have a known shape
// ignores non-tensor in the list of inputs
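
A hedged sketch of the helper this comment describes; the name, iteration, and include path are guesses (the real version works off the operator schema rather than raw inputs), but it shows how `cast<TensorType>()` vs `cast<CompleteTensorType>()` separates "not a tensor" from "tensor without a known shape":

```cpp
#include "torch/csrc/jit/ir.h"  // Node, Value, and the JIT type hierarchy

using namespace torch::jit;

// Hypothetical: collect the CompleteTensorType of every tensor input,
// returning at::nullopt as soon as one tensor lacks a fully known shape.
at::optional<std::vector<CompleteTensorTypePtr>> gatherTensorTypes(Node* node) {
  std::vector<CompleteTensorTypePtr> types;
  for (Value* input : node->inputs()) {
    if (!input->type()->cast<TensorType>())
      continue;  // non-tensor input: ignore it
    auto complete = input->type()->cast<CompleteTensorType>();
    if (!complete)
      return at::nullopt;  // a tensor, but without a fully known shape
    types.push_back(complete);
  }
  return types;
}
```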

@@ -117,57 +116,123 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
// This node represents a single Tensor value with a specific size

@@ -49,6 +55,82 @@ std::vector<bool> TensorDesc::findContiguous(
return cont;
}

// Descriptor for chunk-ing an input tensor into subtensors
// OR concat-ing an output tensor from subtensors
struct PartitionDesc {
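
The body of the struct is not visible in this hunk. As rough orientation only, here is a hypothetical minimal field set, inferred from the `pdesc.nSubtensors` and `pdesc.dim` uses elsewhere in this diff (the real struct in this PR very likely carries more, e.g. a TensorDesc for the subtensors):

```cpp
// Hypothetical minimal PartitionDesc; field names follow their uses in this diff.
struct PartitionDesc {
  size_t nSubtensors = 1;  // 1 means this value is not chunked/concatenated
  int64_t dim = 0;         // dimension along which to chunk or concatenate
};
```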

@@ -28,7 +36,7 @@ struct TensorDesc {
: TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
TensorDesc(const at::Tensor& t)
: TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
-TensorDesc(TensorTypePtr type)
+TensorDesc(CompleteTensorTypePtr type)

// an output is actually a concatenation of
// many subtensors that the fusion group produces
std::vector<PartitionDesc> concat_desc;
struct FusedKernelCache {

auto uses = input->uses();
if (uses.size() == 1) {
Node *user = uses[0].user;
if (user->kind() == prim::FusedChunk) {

chunk_desc.emplace_back();
flat_inputs.emplace_back(p, agraph.input_desc[input_index++]);

outputs.clear();
outputs.reserve(outputDescriptors().size());
for(auto & od : outputDescriptors()) {
-outputs.push_back(torch::getType(backend(),od.scalar_type).tensor());
+outputs.push_back(ref_type.toScalarType(od.scalar_type).tensor());

InterpreterState(fallback_code).runOneStage(stack);
}

void FusedKernelCache::expandArgs(std::vector<at::Tensor>& args, std::vector<int64_t>& map_size) {

}

at::optional<std::vector<int64_t>> FusedKernelCache::getMapSize(at::TensorList args, at::IntList arg_subset) {
// NB: we leave this uninitialized, because an empty size is trivially

if (chunk_desc.nSubtensors == 1) {
try {
map_size = at::infer_size(map_size, arg.sizes());
} catch (std::exception& e) {

int64_t num_chunks = chunk_desc.nSubtensors;
int64_t dim = chunk_desc.dim;
if (dim < 0) {
dim += arg.dim();

if (!arg.sizes().equals(map_size)) {
arg = arg.expand(map_size);
}
map_size.at(pdesc.dim) /= pdesc.nSubtensors;

return std::all_of(tensors.begin(), tensors.end(), [&expected](Value *v) {
auto actual = v->type()->cast<TensorType>();
return actual && actual->sizes() == expected->sizes();
});

@@ -898,7 +805,7 @@ struct GraphFuser {
Node * chunked_op = block->owningGraph()->create(producer_for_chunk_node->kind());
chunked_op->copyAttributes(*producer_for_chunk_node);
// Invariant: mappable operators always produce contiguous output

@ezyang (Contributor) commented on Aug 24, 2018

Question: Shouldn't TensorType (the non-complete version) also track contiguity? It seems that we can get pretty accurate information about it once we get inside the graph.

bool PropagateCompleteShapeOnNode(
Node * node, bool insert_expands, std::vector<CompleteTensorTypePtr> types);

void PropagateCatShape(Node * cat_node) {

@ezyang (Contributor) left a comment

Nice work.

@apaszke (Contributor, Author) commented on Aug 24, 2018

We could track contiguity in TensorType, but:

  1. Contiguity is very tightly bound to strides, and not having to think about strides makes it much easier to work with the code. Most ops only have a size contract, but don't have a clear stride contract.
  2. It doesn't seem to be necessary. I don't recall even a single place where it would enable us to do something we can't already do.

@ezyang (Contributor) commented on Aug 24, 2018

Without symbolic sizes, we can't really track stride contracts in a useful way, but ops definitely have contiguity contracts that users know about, because contiguity tells you, for example, whether you can view() a tensor.

But you're right, let's add it if/when a pass actually desperately wants to know about contiguity.
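
For instance (illustration only, not tied to this PR's code), the same logical sizes may or may not be `view()`-able depending on contiguity:

```cpp
#include <torch/torch.h>

void view_needs_contiguity() {
  auto a = torch::randn({4, 6});
  a.view({24});                 // fine: a is contiguous
  auto b = a.t();               // transposed view: sizes (6, 4), non-contiguous
  // b.view({24});              // would throw: not expressible without a copy
  b.contiguous().view({24});    // works once the data is made contiguous
}
```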

@apaszke force-pushed the jit_type_unspecialize branch 2 times, most recently from 9b85a71 to f5efae2 on August 24, 2018 17:04

@facebook-github-bot left a comment

apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zdevito (Contributor) left a comment

This looks good -- I have a bunch of individual but minor comments and questions.

Higher level notes for building on top of this:

  • This puts the FusionCompiler at the breaking point of complexity. Further functionality added there is going to require refactoring some of the Chunk and Concat logic into a separate phase, given how many times we need to check nSubtensors and do other things to derive correct sizes. The core of the fusion is simple, but this chunk/concat logic is getting spread all over the place.
  • We may notice regressions from the overhead of launching fused kernels. For reference, we know that doing chunk in the interpreter adds significant overhead, and the extra checking added here seems to be on the same order of magnitude. We will need to monitor this and optimize if necessary.

@@ -52,27 +52,21 @@ bool isDifferentiable(Node * n) {
return true;

if (n->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
-return static_cast<bool>(n->input(1)->type()->cast<TensorType>());
+return static_cast<bool>(n->input(1)->type()->cast<CompleteTensorType>());

} else if (kind == TypeKind::TensorType) {
type_proto->set_denotation("TensorType");
TensorTypePtr node_type = type->cast<TensorType>();
} else if (kind == TypeKind::CompleteTensorType) {

@@ -225,6 +225,9 @@ TypePtr ModuleDecoder::buildType(const onnx::TypeProto& type_proto) {
} else if (kind == "TensorType") {
// TODO: Don't use DynamicType here
return DynamicType::get();
} else if (kind == "CompleteTensorType") {

@@ -719,8 +719,8 @@ struct InterpreterStateImpl {
current_pc = pc;
current_stage++;
}
-const TensorType & tensorTypeForInput(size_t i) const {
-return *function->preprocess.stage_input_types.at(current_stage).at(i)->expect<TensorType>();
+const CompleteTensorType & tensorTypeForInput(size_t i) const {

@@ -150,8 +150,8 @@ void BatchMMBlock(Block* block) {
std::unordered_map<Node*, TreeToken> tokens;
for (auto node : block->nodes()) {
if (node->kind() == aten::mm &&
-node->input(0)->type()->cast<TensorType>() &&
-node->input(1)->type()->cast<TensorType>()) {
+node->input(0)->type()->cast<CompleteTensorType>() &&

for(auto & i : inputs) {
agraph.input_desc.emplace_back(i);
agraph.input_desc = spec.descs();
at::optional<at::ScalarType> scalar_type;

throw std::runtime_error("cannot compile a CUDA fusion group, CUDA is not enabled.");
#endif
} else {
JIT_ASSERT(compiler.canCompileOnCPU());

@@ -252,6 +301,14 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) {
}
return;
}
// NB: We assume that all shapes are known within fused kernels

}
}

if (canPropagateShapeByRunningIt(node))

@apaszke (Contributor, Author) commented on Aug 25, 2018

The last commit forces slicing on types if the dynamic cast succeeds, but the kind doesn't match exactly. Normally we could simply use the copy constructor in this place, but because we're incorrectly comparing only addresses of types in many cases (because we assume they're used as singletons), the slicing needs some extra care. We really should either stop using shared pointers to types in most places (that would also come with a benefit of not having to incref/decref just to check the type), or have a subclass of shared pointer that uses equality on held elements in operator==.
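
A sketch of the second option (hypothetical, just to make the idea concrete): a shared_ptr-like handle whose `operator==` compares the pointees by value, so the singleton fast path and structural comparison both work.

```cpp
#include <memory>
#include <utility>

// Hypothetical wrapper: equality compares the held objects, so two structurally
// equal types compare equal even when they are not the same singleton instance.
// Assumes T provides operator==.
template <typename T>
class ValueEqualPtr {
 public:
  explicit ValueEqualPtr(std::shared_ptr<T> p) : ptr_(std::move(p)) {}
  T* operator->() const { return ptr_.get(); }
  bool operator==(const ValueEqualPtr& other) const {
    if (ptr_ == other.ptr_) return true;   // same instance: singleton fast path
    if (!ptr_ || !other.ptr_) return false;
    return *ptr_ == *other.ptr_;           // structural equality on the types
  }
  bool operator!=(const ValueEqualPtr& other) const { return !(*this == other); }
 private:
  std::shared_ptr<T> ptr_;
};
```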

@facebook-github-bot left a comment

apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

petrex added a commit to petrex/pytorch that referenced this pull request Aug 27, 2018
* upstream/master: (89 commits)
  move HeatmapMaxKeypointOp unittest to oss
  fix xfails involving literals (pytorch#10905)
  Bag of Distributions doc fixes (pytorch#10894)
  Remove FIXME_zerol() from test_jit.py (pytorch#10900)
  Increase BC for PackedSequence ctor (pytorch#9864)
  Remove ability of Scalars to hold Tensors.
  Begin a bestiary of MSVC/NVCC bugs. (pytorch#10883)
  Prevent JIT from overspecializing to every single size configuration (pytorch#10844)
  Handling failing test on ROCm.
  Update mobile predictor caller's interface
  Cache isContiguous and numel
  Create class constant for string literal 'blob_names'
  Conv BN fusion for 3D conv (pytorch#10239)
  Stop using symbolic override for tracing RNNs (pytorch#10638)
  Add registry to pybind_state (pytorch#10759)
  Remove the nanopb submodule
  Create at::linear (pytorch#10799)
  Refactor THCNumerics and add common math functions for at::Half (pytorch#10301)
  Remove Tensor constructor of Scalar. (pytorch#10852)
  Revert D9492561: [pytorch][PR] Moving the operator argument to the front for kernelPointwiseApply.
  ...
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018
Prevent JIT from overspecializing to every single size configuration (pytorch#10844)

Pull Request resolved: pytorch#10844

Differential Revision: D9498705

Pulled By: apaszke

fbshipit-source-id: 0c53c2fcebd871cc2a29c260f8d012276479cc61