[TorchScript] Support user defined classes as constants (#5062)
Summary:
Pull Request resolved: pytorch/glow#5062

Pull Request resolved: #45556

User-defined classes can now be used as constants. This is useful when freezing a module and removing it from the graph.
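
A rough Python-level sketch of what this enables (the custom class name and its methods are illustrative stand-ins for the MyStackClass test fixture, not APIs added by this PR):

```python
import torch

# Assumes a C++ custom class has already been registered, e.g. via
# torch::class_<MyStack>("my_classes", "MyStack"); the name is hypothetical.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stack = torch.classes.my_classes.MyStack(["foo", "bar"])

    def forward(self) -> str:
        return self.stack.pop()

scripted = torch.jit.script(M()).eval()
# Freezing folds module attributes into the graph; with this change the
# custom-class attribute is emitted as a prim::Constant instead of failing.
frozen = torch.jit.freeze(scripted)  # torch._C._freeze_module on older builds
print(frozen.graph)
```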

Test Plan: waitforsadcastle

Reviewed By: eellison

Differential Revision: D23994974

fbshipit-source-id: 6c494269e222b7e1b5ecddf0c460ae8c09ac0556
bwasti authored and facebook-github-bot committed Nov 16, 2020
1 parent 95ea778 commit b71b970
Showing 17 changed files with 144 additions and 26 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
@@ -508,6 +508,9 @@ std::ostream& IValue::repr(
return out << enum_holder->qualifiedClassName() << "." <<
enum_holder->name();
}
case IValue::Tag::Object: {
TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?");
}
default:
TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
}
43 changes: 43 additions & 0 deletions test/cpp/jit/test_custom_class.cpp
@@ -1,6 +1,7 @@
#include <gtest/gtest.h>

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

@@ -86,5 +87,47 @@ TEST(CustomClassTest, TestDocString) {
method_doc_string);
}

TEST(CustomClassTest, Serialization) {
script::Module m("m");

// test make_custom_class API
auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
std::vector<std::string>{"foo", "bar"});
m.register_attribute(
"s",
custom_class_obj.type(),
custom_class_obj,
/*is_parameter=*/false);
m.define(R"(
def forward(self):
return self.s.return_a_tuple()
)");

auto test_with_obj = [](script::Module& mod) {
auto res = mod.run_method("forward");
auto tup = res.toTuple();
AT_ASSERT(tup->elements().size() == 2);
auto i = tup->elements()[1].toInt();
AT_ASSERT(i == 123);
};

auto frozen_m = torch::jit::freeze_module(m.clone());

test_with_obj(m);
test_with_obj(frozen_m);

std::ostringstream oss;
m.save(oss);
std::istringstream iss(oss.str());
caffe2::serialize::IStreamAdapter adapter{&iss};
auto loaded_module = torch::jit::load(iss, torch::kCPU);

std::ostringstream oss_frozen;
frozen_m.save(oss_frozen);
std::istringstream iss_frozen(oss_frozen.str());
caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen};
auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
}

} // namespace jit
} // namespace torch
6 changes: 3 additions & 3 deletions test/quantization/test_quantize_jit.py
@@ -2674,6 +2674,7 @@ def forward(self, x):
num_quantize_per_tensor = 1 # for output
for num_quant, num_op in num_op_by_num_quant.items():
num_quantize_per_tensor += num_op * num_quant
num_quantize_per_tensor -= 4 # constant propagation removes some prepacks
FileCheck().check_count("aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True) \
.run(m1.graph)

@@ -2997,8 +2998,7 @@ def forward(self, x):
m = torch.jit.script(M())
m = quantize_dynamic_jit(m, {'': float16_dynamic_qconfig})

FileCheck().check("quantized::linear_prepack_fp16") \
.check_next("quantized::linear_dynamic_fp16") \
FileCheck().check("quantized::linear_dynamic_fp16") \
.check_not("aten::linear") \
.check_not("aten::dequantize") \
.check_not("aten::quantize") \
@@ -3076,7 +3076,7 @@ def forward(self, indices1, offsets1, indices2, offsets2):
m = prepare_jit(m, {'embedding1' : int4_qconfig, 'embedding2' : int8_qconfig})
m = convert_jit(m)
FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
.check_next("quantized::embedding_bag_byte_rowwise_offsets") \
.check("quantized::embedding_bag_byte_rowwise_offsets") \
.run(m.graph)
m(*dummy_inputs)

4 changes: 2 additions & 2 deletions test/test_mobile_optimizer.py
@@ -340,7 +340,7 @@ def _quant_script_and_optimize(model):

m, m_optim = _quant_script_and_optimize(Standalone())
FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
.check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
.check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
.run(m_optim.graph)
self.assertFalse(hasattr(m_optim, "conv1"))
self.assertFalse(hasattr(m_optim, "conv2"))
@@ -354,7 +354,7 @@ def _quant_script_and_optimize(model):

m, m_optim = _quant_script_and_optimize(Parent())
FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
.check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
.check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
.run(m_optim.graph)
self.assertFalse(hasattr(m_optim, "conv1"))
self.assertFalse(hasattr(m_optim, "child"))
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
@@ -179,6 +179,7 @@ def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_inline(Graph) -> None: ...
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
10 changes: 6 additions & 4 deletions torch/csrc/jit/ir/constants.cpp
@@ -124,10 +124,9 @@ c10::optional<Value*> tryInsertConstant(
n->destroy();
return c10::nullopt;
};
} else if (val.isGenericDict() && insertableIValue(val)) {
n->ival_(attr::value, val);
n->output()->setType(val.type());
} else if (val.isEnum()) {
} else if (
(val.isGenericDict() && insertableIValue(val)) || (val.isEnum()) ||
(val.isObject() && !val.toObjectRef().type()->is_module())) {
n->ival_(attr::value, val);
n->output()->setType(val.type());
} else {
Expand Down Expand Up @@ -191,6 +190,9 @@ c10::optional<IValue> toIValue(const Value* v) {
} else if (type->cast<EnumType>()) {
const auto& enum_val = node->ival(attr::value);
return enum_val;
} else if (type->cast<ClassType>() && !type->is_module()) {
const auto& class_val = node->ival(attr::value);
return class_val;
} else {
std::stringstream ss;
ss << "constant literal not supported for: " << type->str();
3 changes: 3 additions & 0 deletions torch/csrc/jit/ir/ir.cpp
@@ -133,6 +133,9 @@ static void printAttribute(std::ostream& out, const IValue& ival) {
} else if (input.isTensorList()) {
ss << "[<Tensors>]";
return true;
} else if (input.isObject() && !input.type()->is_module()) {
ss << "object(" << &input.toObjectRef() << ")";
return true;
}
return false;
};
3 changes: 3 additions & 0 deletions torch/csrc/jit/ir/node_hashing.cpp
@@ -126,6 +126,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
if (a1.isEnum()) {
return a1.toEnumHolder() == a2.toEnumHolder();
}
if (a1.isObject()) {
return &a1.toObjectRef() == &a2.toObjectRef();
}
TORCH_INTERNAL_ASSERT(false);
}

36 changes: 28 additions & 8 deletions torch/csrc/jit/passes/constant_propagation.cpp
@@ -16,7 +16,9 @@
namespace torch {
namespace jit {

c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
const Node* n,
bool ignore_custom_classes) {
Stack stack;
for (auto input : n->inputs()) {
if (auto ival = toIValue(input)) {
@@ -25,6 +27,7 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
return c10::nullopt;
}
}

switch (n->kind()) {
case prim::ListUnpack: {
if (stack.back().toList().size() != n->outputs().size()) {
@@ -80,6 +83,12 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
return c10::nullopt;
}
}
// Weak form of const propagation
if (ignore_custom_classes) {
if (v.isCustomClass()) {
return c10::nullopt;
}
}
}
return stack;
}
@@ -104,33 +113,40 @@ std::unordered_set<Symbol> skip_list = {
struct ConstantPropagator {
// Runs constant propagation with an aliasing db and checks if inputs or
// outputs might be mutated in the graph
static ConstantPropagator WithAliasDb(std::shared_ptr<Graph> graph) {
return ConstantPropagator(graph, true);
static ConstantPropagator WithAliasDb(
std::shared_ptr<Graph> graph,
bool ignore_custom_classes) {
return ConstantPropagator(std::move(graph), true, ignore_custom_classes);
}

// Runs constant propagation only on ops that clearly do not have aliased
// inputs or outputs without computing aliasing information
static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) {
return ConstantPropagator(graph, false);
return ConstantPropagator(std::move(graph), false, false);
}

void run() {
ConstantPropagation(graph_->block());
}

private:
ConstantPropagator(std::shared_ptr<Graph> graph, bool aliasing_types)
ConstantPropagator(
std::shared_ptr<Graph> graph,
bool aliasing_types,
bool ignore_custom_classes)
: graph_(std::move(graph)) {
if (aliasing_types) {
aliasDb_ = torch::make_unique<AliasDb>(graph_);
} else {
aliasDb_ = nullptr;
}
ignore_custom_classes_ = ignore_custom_classes;
}

void propagateNode(Node* n) {
std::vector<IValue> outputs;
if (auto outputs_opt = runNodeIfInputsAreConstant(n)) {
if (auto outputs_opt =
runNodeIfInputsAreConstant(n, ignore_custom_classes_)) {
outputs = std::move(outputs_opt.value());
} else {
// The op failed to run, so we cannot continue constant-prop for it.
@@ -353,11 +369,15 @@ struct ConstantPropagator {

std::shared_ptr<Graph> graph_;
std::unique_ptr<AliasDb> aliasDb_;
bool ignore_custom_classes_;
};
} // anonymous namespace

void ConstantPropagation(std::shared_ptr<Graph>& graph) {
ConstantPropagator cp = ConstantPropagator::WithAliasDb(graph);
void ConstantPropagation(
std::shared_ptr<Graph>& graph,
bool ignore_custom_classes) {
ConstantPropagator cp =
ConstantPropagator::WithAliasDb(graph, ignore_custom_classes);
cp.run();
EliminateDeadCode(graph);
GRAPH_DUMP("After ConstantPropagation: ", graph);
16 changes: 13 additions & 3 deletions torch/csrc/jit/passes/constant_propagation.h
@@ -5,15 +5,25 @@
namespace torch {
namespace jit {

TORCH_API void ConstantPropagation(std::shared_ptr<Graph>& graph);
// Runs constant propagation on all objects unless ignore_custom_classes is
// specified as true, in which case user defined classes are skipped. This is
// useful to prevent early fusion of packing operations, which end up lowering
// away information about their constructors (e.g. packed::linear_clamp_prepack
// and prepacked::conv2d_clamp_prepack)
TORCH_API void ConstantPropagation(
std::shared_ptr<Graph>& graph,
bool ignore_custom_classes = false);

// runs constant propagation only on ops that have non-aliasing inputs & outputs
TORCH_API void ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph);

// Runs the node if its inputs are constants. Callers of this function must
// make their own determination if constant prop is appropriate - for example
// non-deterministic ops or ops with side effects
TORCH_API c10::optional<Stack> runNodeIfInputsAreConstant(const Node* node);
// non-deterministic ops or ops with side effects. If ignore_custom_classes is
// specified, nodes that output user defined classes are not run.
TORCH_API c10::optional<Stack> runNodeIfInputsAreConstant(
const Node* node,
bool ignore_custom_classes = false);

} // namespace jit
} // namespace torch
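
For reference, a minimal Python-level sketch of driving this pass through the _jit_pass_constant_propagation binding (declared in the .pyi stub above and registered in python/init.cpp below); the ignore_custom_classes flag is not exposed there and keeps its default of false:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    y = 2 + 3  # foldable arithmetic on constants
    return x + y

# Runs ConstantPropagation(graph) with ignore_custom_classes=false; nodes
# producing non-module user-defined class objects are now foldable too,
# unless a C++ caller passes ignore_custom_classes=true.
torch._C._jit_pass_constant_propagation(f.graph)
print(f.graph)  # 2 + 3 collapses into a single prim::Constant
```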
12 changes: 11 additions & 1 deletion torch/csrc/jit/passes/freeze_module.cpp
@@ -82,7 +82,8 @@ class AttributePropagator {
ClearProfilingInformation(subgraph);
};
auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
runOptimization(subgraph, /* unroll? */ false);
runOptimization(
subgraph, /* unroll? */ false, /* const_prop_user_classes? */ false);
};

for (auto function : preservedMethods_) {
@@ -315,6 +316,15 @@ class AttributePropagator {
val = overrideGradient(val);
}
attr = std::move(dict);
} else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
auto obj_type = attr.type()->expect<ClassType>();
auto obj_value = std::move(attr).toObject();
auto sub_attributes = obj_type->getAttributes();
for (const auto& sub_attr : sub_attributes) {
auto sub_attr_val = obj_value->getAttr(sub_attr.getName());
sub_attr_val = overrideGradient(sub_attr_val);
}
return obj_value;
}

return attr;
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
@@ -334,6 +335,9 @@ void fusePrePackedLinearConvWithClamp(script::Module& module) {
auto graph = module.get_method("forward").graph();
fuseReluWithPackedOps(graph);
fuseHardtanhWithPackedOps(graph);

// Ignore user defined classes for later passes
ConstantPropagation(graph, true);
}

void FoldPrePackingOps(script::Module& m) {
@@ -348,6 +352,9 @@ void FoldPrePackingOps(script::Module& m) {
"prepacked::conv2d_transpose_clamp_prepack"));
};
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
auto graph = m.get_method("forward").graph();
// Folding requires a const propagation through user defined classes
ConstantPropagation(graph, false);
}

script::Module optimizeForMobile(
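
The mobile optimizer is the main consumer of this two-stage scheme (prepack fusion skips custom-class constants, then FoldPrePackingOps const-props them). A hedged end-to-end sketch, assuming an XNNPACK-enabled build:

```python
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))

scripted = torch.jit.script(Net().eval())
# optimize_for_mobile runs the rewrite passes above; packed parameter
# objects end up in the graph as prim::Constant values, which is what
# test_mobile_optimizer.py now checks for.
optimized = optimize_for_mobile(scripted)
print(optimized.graph)
```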
3 changes: 2 additions & 1 deletion torch/csrc/jit/python/init.cpp
@@ -436,7 +436,8 @@ void initJITBindings(PyObject* module) {
})
.def(
"_jit_pass_constant_propagation",
[](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
[](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
py::arg("graph"))
.def("_jit_pass_erase_shape_information", EraseShapeInformation)
.def(
"_jit_pass_create_autodiff_subgraphs",
13 changes: 11 additions & 2 deletions torch/csrc/jit/runtime/graph_executor.cpp
@@ -865,7 +865,10 @@ void runNondiffOptimization(
"After customPostPassses (end of runNondiffOptimization)\n", *graph);
}

void runOptimization(std::shared_ptr<Graph>& graph, bool unroll) {
void runOptimization(
std::shared_ptr<Graph>& graph,
bool unroll,
bool const_prop_user_classes) {
// Basic graph preprocessing to eliminate noise.
GRAPH_DEBUG(
"Before EliminateDeadCode (beginning of runOptimization)\n", *graph);
@@ -878,8 +881,14 @@ void runOptimization(std::shared_ptr<Graph>& graph, bool unroll) {

PeepholeOptimize(graph);
GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
ConstantPropagation(graph);

if (const_prop_user_classes) {
ConstantPropagation(graph);
} else {
ConstantPropagation(graph, true);
}
GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);

ConstantPooling(graph);
GRAPH_DEBUG("After ConstantPooling\n", *graph);

5 changes: 4 additions & 1 deletion torch/csrc/jit/runtime/graph_executor_impl.h
@@ -31,7 +31,10 @@ namespace jit {

void packGradient(const Gradient& gradient, Node* dnode);
bool needsGradient(const std::shared_ptr<const Graph>& graph);
void runOptimization(std::shared_ptr<Graph>& graph, bool unroll = true);
void runOptimization(
std::shared_ptr<Graph>& graph,
bool unroll = true,
bool const_prop_user_classes = true);
void runNondiffOptimization(
std::shared_ptr<Graph>& graph,
bool strict_fuser_check = false);
