Skip to content

Commit

Permalink
Improve shape analysis to cover all of the most commonly used ops (#11358)
Browse files Browse the repository at this point in the history
Summary:
[Here's a list](https://gist.github.com/apaszke/f0821840bdcc67a977832dc58acc1b85) of ops that are in `register_aten_ops.cpp`, but aren't supported in shape prop. Everything else should work now.
Pull Request resolved: #11358

Differential Revision: D9753693

Pulled By: apaszke

fbshipit-source-id: efeae0126ce16cb56b8797fc5246405588bcae3c
  • Loading branch information
apaszke authored and facebook-github-bot committed Sep 11, 2018
1 parent f84693e commit 0ddbe66
Show file tree
Hide file tree
Showing 9 changed files with 834 additions and 90 deletions.
58 changes: 50 additions & 8 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7059,6 +7059,19 @@ def forward(self, x, y):
'test_split_dim_neg0',
}

# Names of generated tests that are run with check_types=False, i.e. for which
# the graph's output types are not compared against the reference outputs.
EXCLUDE_TYPE_CHECK = {
# slogdet tests use itemgetter to select its only differentiable output,
# but this happens outside of the graph we handle, so there are fewer
# reference outputs than graph outputs.
'test_slogdet_1x1_neg_det',
'test_slogdet_1x1_pos_det',
'test_slogdet_distinct_singular_values',
'test_slogdet_neg_det',
'test_slogdet_pos_det',
'test_slogdet_symmetric',
'test_slogdet_symmetric_pd',
}

# known to be failing in script
EXCLUDE_SCRIPT = {
# TODO: Fix var/std
Expand Down Expand Up @@ -7182,7 +7195,9 @@ def traced_fn(*inputs, **kwargs):
fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
traced = torch.jit.trace(fn_tensors, inputs_tensors)
self.assertExportImport(traced.graph, inputs_tensors)
return traced(*inputs_tensors)
output = traced(*inputs_tensors)
traced_fn.last_graph = traced.graph_for(*inputs_tensors)
return output
return traced_fn

script_template = '''
Expand Down Expand Up @@ -7222,12 +7237,30 @@ def script_fn(*args, **kwargs):
script = script_template.format(', '.join(formals), call)
CU = torch.jit.CompilationUnit(script)
self.assertExportImport(CU.the_method.graph, tensors)

return output_process_fn(CU.the_method(*tensors))
output = output_process_fn(CU.the_method(*tensors))
script_fn.last_graph = CU.the_method.graph_for(*tensors)
return output
return script_fn


def check_against_reference(self, func, reference_func, args, kwargs=None, allow_unused=True):
def check_output_types(self, func, ref_outputs, args, kwargs):
    """Check that the output types of ``func``'s recorded graph agree with ``ref_outputs``.

    ``func`` is expected to carry a ``last_graph`` attribute holding the graph
    whose outputs are inspected. The number of graph outputs must equal the
    number of reference outputs; each reference output is then checked against
    the corresponding graph output type.

    Args:
        self: the test case, used for its ``assert*`` methods.
        func: callable whose ``last_graph`` attribute holds the graph to check.
        ref_outputs: a single reference output or a tuple of them. A list
            output must be a non-empty list whose first element is a tensor.
        args: unused; kept for interface compatibility.
        kwargs: unused; kept for interface compatibility.

    Raises:
        AssertionError: if ``func`` has no recorded graph, or if the output
            count or any output type does not match.
    """
    graph = getattr(func, 'last_graph', None)
    # Without this guard a missing graph surfaces as an opaque
    # "'NoneType' object has no attribute 'outputs'" error below.
    self.assertIsNotNone(
        graph,
        "func has no 'last_graph' attribute; it must record its graph before "
        "output types can be checked")
    if not isinstance(ref_outputs, tuple):
        ref_outputs = (ref_outputs,)
    types = [o.type() for o in graph.outputs()]
    self.assertEqual(len(types), len(ref_outputs))
    for t, ref_out in zip(types, ref_outputs):
        if isinstance(ref_out, list):
            # Only the first element is inspected; the graph type just has to
            # be a subtype of "list of tensors".
            assert len(ref_out) > 0
            elem = ref_out[0]
            assert isinstance(elem, torch.Tensor)
            self.assertTrue(t.isSubtypeOf(torch._C.ListType.ofTensors()))
        else:
            # The type inferred from the concrete value must be a subtype of
            # the type the graph declares for this output.
            ref_type = torch._C.Type.inferFrom(ref_out)
            self.assertTrue(ref_type.isSubtypeOf(t))


def check_against_reference(self, func, reference_func, args, kwargs=None, allow_unused=True, check_types=True):
kwargs = kwargs if kwargs else {}

def allSum(vs):
Expand All @@ -7252,6 +7285,9 @@ def clone_inputs(requires_grad):
outputs_test = func(*nograd_inputs, **kwargs)
self.assertEqual(outputs, outputs_test)

if check_types:
check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

# test single grad case
outputs = reference_func(*recording_inputs, **kwargs)
grads = torch.autograd.grad(allSum(outputs), recording_tensors,
Expand Down Expand Up @@ -7577,15 +7613,19 @@ def fn(*inputs, **kwargs):
output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
return output_process_fn(output)

check_types = test_name not in EXCLUDE_TYPE_CHECK

if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
if test_name not in EXCLUDE_TRACED:
check_against_reference(self, create_traced_fn(self, fn),
fn, (self_variable,) + args_variable, kwargs_variable)
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)

if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
check_against_reference(self,
create_script_fn(self, name, 'method', output_process_fn),
fn, (self_variable,) + args_variable, kwargs_variable)
fn, (self_variable,) + args_variable, kwargs_variable,
check_types=check_types)

# functional interface tests
if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
Expand All @@ -7597,12 +7637,14 @@ def fn(*inputs, **kwargs):
f_args_tensor = (self_tensor,) + args_tensor

if not is_inplace and test_name not in EXCLUDE_TRACED:
check_against_reference(self, create_traced_fn(self, fn), fn, f_args_variable, kwargs_variable)
check_against_reference(self, create_traced_fn(self, fn), fn,
f_args_variable, kwargs_variable, check_types=check_types)

if not is_inplace and test_name not in EXCLUDE_SCRIPT:
check_against_reference(self,
create_script_fn(self, name, 'functional', output_process_fn),
fn, f_args_variable, kwargs_variable)
fn, f_args_variable, kwargs_variable,
check_types=check_types)

check(name)
inplace_name = name + '_'
Expand Down
6 changes: 3 additions & 3 deletions tools/jit/gen_jit_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def jit_type_of(arg):
# map from aten 'simple_type' to the function that will turn a tensor into
# that type
FROM_IVALUE = {
'Device': 'as_device({}.toIntList()->elements())',
'Device': '{}.to<at::Device>()',
'IntList': '{}.toIntList()->elements()',
'Layout': 'static_cast<at::Layout>({}.toInt())',
'Layout': '{}.to<at::Layout>()',
'Scalar': '{}.toScalar()',
'ScalarType': 'static_cast<at::ScalarType>({}.toInt())',
'ScalarType': '{}.to<at::ScalarType>()',
'Tensor': '{}.toTensor()',
'TensorList': '{}.toTensorList()->elements()',
'bool': 'bool({}.toInt())',
Expand Down
4 changes: 0 additions & 4 deletions tools/jit/templates/register_aten_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ std::array<bool, N> as_bool_array(at::ArrayRef<int64_t> vec) {
return res;
}

// Builds an at::Device from a two-element integer list: the first element is
// cast to the device type and the second is passed as the second constructor
// argument. No bounds checking is performed on `elements`.
at::Device as_device(ArrayRef<int64_t> elements) {
  const auto device_type = static_cast<at::Device::Type>(elements[0]);
  return at::Device(device_type, elements[1]);
}

RegisterOperators reg({
${constructors}
});
Expand Down
5 changes: 4 additions & 1 deletion torch/csrc/jit/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ struct GraphExecutorImpl {

// Phase 2. Propagate detailed information about the spec through the
// graph (enables more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape information
// anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
PropagateInputShapes(*opt_graph, spec);

// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites that
Expand Down Expand Up @@ -427,7 +431,6 @@ struct GraphExecutorImpl {
EliminateDeadCode(graph);
EliminateCommonSubexpression(graph);
UnrollLoops(graph);
ConstantPropagation(graph);
PeepholeOptimize(graph);
CheckInplace(graph);
BatchMM(graph);
Expand Down
29 changes: 29 additions & 0 deletions torch/csrc/jit/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,35 @@ DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)

#undef DEFINE_TO

// Generates two explicit specializations of IValue::to<type>():
//   - an rvalue-qualified (&&) overload that invokes `body` on std::move(*this)
//   - a const-lvalue-qualified (const &) overload that invokes `body` on (*this)
// `body` is a function-like macro that expands to the statements (including the
// `return`) forming the conversion; its single argument is the IValue expression.
#define DEFINE_TO_WITH_BODY(type, body) \
template<> \
inline type IValue::to<type>() && { \
body(std::move(*this)); \
} \
template<> \
inline type IValue::to<type>() const & { \
body((*this)); \
}

// Conversion bodies. Note that `this` here is only the macro parameter name; it
// is replaced by whatever expression DEFINE_TO_WITH_BODY passes in.
// ScalarType and Layout are read back with toInt() and cast to the enum.
#define SCALAR_TYPE_BODY(this) return static_cast<at::ScalarType>(this.toInt());
#define LAYOUT_BODY(this) return static_cast<at::Layout>(this.toInt());
// Device is read back from an int list that must contain exactly two elements:
// the first is cast to at::Device::Type, the second is passed as the second
// at::Device constructor argument (presumably the device index -- confirm).
#define DEVICE_BODY(this) \
/* NB: const_list might be a move of the vector, so we need to */ \
/* assign it to prevent its deallocation. */ \
auto && const_list = this.toIntList(); \
const auto & elems = const_list->elements(); \
JIT_ASSERT(elems.size() == 2); \
return at::Device(static_cast<at::Device::Type>(elems[0]), elems[1]);

DEFINE_TO_WITH_BODY(at::ScalarType, SCALAR_TYPE_BODY)
DEFINE_TO_WITH_BODY(at::Layout, LAYOUT_BODY)
DEFINE_TO_WITH_BODY(at::Device, DEVICE_BODY)

// The helper macros are only needed for the three instantiations above.
#undef DEFINE_TO_WITH_BODY
#undef SCALAR_TYPE_BODY
#undef LAYOUT_BODY
#undef DEVICE_BODY

inline IValue::IValue(c10::intrusive_ptr<Tuple> v)
: tag(Tag::Tuple), is_intrusive_ptr(true) {
as_intrusive_ptr = v.release();
Expand Down
Loading

0 comments on commit 0ddbe66

Please sign in to comment.