diff --git a/caffe2/proto/torch.proto b/caffe2/proto/torch.proto index a140bb82cb5c9..dd5524287974b 100644 --- a/caffe2/proto/torch.proto +++ b/caffe2/proto/torch.proto @@ -72,6 +72,8 @@ message ModuleDef { // Used for retrieving module state from the pickled IValues table optional int64 get_state_attribute_id = 10; + + optional RecordRef torchscript_debug_arena = 11; } // Represents all non-module code that the model depends on. diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 578ef729b93c1..63fef413efcfe 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -659,8 +660,11 @@ void ScriptModuleSerializer::convertClass( std::vector class_deps; std::ostringstream class_stream; + // TODO: serialize for classes + SourceRangeRecords source_ranges; PythonPrint( class_stream, + source_ranges, class_type, tensor_table_, class_deps, @@ -905,9 +909,11 @@ void ScriptModuleSerializer::convertModule( if (module.class_compilation_unit()->get_functions().size() > 0) { std::ostringstream methods; + SourceRangeRecords source_ranges; methods << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n"; PythonPrint( methods, + source_ranges, *module.class_compilation_unit(), /*is_method=*/true, tensor_table_, @@ -921,6 +927,26 @@ void ScriptModuleSerializer::convertModule( writer_.writeRecord( filename.str(), methods_str.c_str(), methods_str.size()); record->set_key(filename.str()); + + // Write out debug records + torch::RecordRef* debug_record = + module_def->mutable_torchscript_debug_arena(); + Pickler p; + SourceRangeSerializer srs; + p.start(); + p.startTuple(); + for (const auto& range : source_ranges) { + std::vector row_elems{(int64_t)range.bytes, + srs.serialize(range.range)}; + p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems))); + } + p.endTuple(); + p.finish(); + std::stringstream debug_filename; + debug_filename << "debug/" << module_name.str() << ".pkl"; + writer_.writeRecord( + debug_filename.str(), p.stack().data(), p.stack().size()); + debug_record->set_key(debug_filename.str()); } for (script::Slot s : module.get_module_slots()) { diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 288c03ed73f71..000e32ef67a07 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -167,7 +167,79 @@ const static std::unordered_set reserved_names = { }; struct PythonPrintPass { - std::ostringstream body_; + using SourceRangeStack = std::vector; + SourceRangeStack source_range_stack_ = {SourceRange("")}; + + struct WithSourceRange { + explicit WithSourceRange(SourceRangeStack* stack, SourceRange sr) + : stack(stack) { + TORCH_INTERNAL_ASSERT(stack); + stack->push_back(std::move(sr)); + } + + ~WithSourceRange() { + stack->pop_back(); + } + + SourceRangeStack* stack; + }; + + class TaggedStringStream { + public: + TaggedStringStream(const SourceRangeStack* srs) : srs_(srs) {} + TaggedStringStream(TaggedStringStream&& rhs) = default; + + TaggedStringStream& operator<<(const std::string& s) { + if (!ranges_.size() || ranges_.back().range != srs_->back()) { + ranges_.emplace_back((size_t)oss_.tellp(), srs_->back()); + } + oss_ << s; + return *this; + } + + TaggedStringStream& operator<<(const TaggedStringStream& rhs) { + for (const auto& range : rhs.ranges_) { + if (!ranges_.size() || ranges_.back().range != range.range) { + ranges_.emplace_back((size_t)oss_.tellp() + range.bytes, range.range); + } + } + oss_ << rhs.oss_.str(); + return *this; + } + + // This overload is here to prevent people from shooting themselves in the + // foot. I would be highly surprised if someone actually wanted to write out + // the address of a TaggedStringStream in the pretty print. + TaggedStringStream& operator<<( + const std::shared_ptr& rhs) { + (*this) << *rhs; + return *this; + } + + template + TaggedStringStream& operator<<(const T& t) { + if (!ranges_.size() || ranges_.back().range != srs_->back()) { + ranges_.emplace_back((size_t)oss_.tellp(), srs_->back()); + } + oss_ << t; + return *this; + } + + std::string str() const { + return oss_.str(); + } + + const std::vector& ranges() const { + return ranges_; + } + + private: + std::ostringstream oss_; + std::vector ranges_; + const SourceRangeStack* srs_; + }; + + TaggedStringStream body_; // constants are written to this table, and given then named CONSTANTS.cN // where N is the index into this table. @@ -402,13 +474,35 @@ struct PythonPrintPass { } // map from Value to how it should be printed at each use - std::unordered_map value_names_; - - std::string useOf(Value* v) const { - return value_names_.at(v); + std::unordered_map> expr_table_; + std::unordered_map ident_refs_; + + // NB: we MUST pass around the shared pointers to these streams by value. + // There is an interaction in splitLongInlines where the string value for + // both the RHS and the LHS of an expression are live at the same time, + // however the value for the RHS is overwritten in the table. + std::shared_ptr useOf(Value* v) const { + // Ident refs take precedent over expression refs, since presence in + // the ident ref table indicates we have already emitted a statement + // assigning the given value. + if (ident_refs_.count(v)) { + auto rv = std::make_shared(&source_range_stack_); + (*rv) << ident_refs_.at(v); + return rv; + } + if (expr_table_.count(v)) { + return expr_table_.at(v); + } + TORCH_INTERNAL_ASSERT( + false, + "Value was not present in either expressions" + " table or ident refs table"); } void assignValue(Value* v, const std::string& s) { - value_names_[v] = s; + ident_refs_[v] = s; + } + void assignValue(Value* v, std::shared_ptr s) { + expr_table_[v] = std::move(s); } void assignValue(Value* v, Value* w) { assignValue(v, useOf(w)); @@ -421,7 +515,7 @@ struct PythonPrintPass { size_t level = 0; // indent to the current indent level - std::ostream& indent() { + TaggedStringStream& indent() { for (size_t i = 0; i < level; ++i) { body_ << " "; } @@ -449,7 +543,7 @@ struct PythonPrintPass { } void printValueList( - std::ostream& stmt, + TaggedStringStream& stmt, at::ArrayRef list, const char* begin = "", const char* end = "") { @@ -464,7 +558,7 @@ struct PythonPrintPass { } void printDict( - std::ostream& stmt, + TaggedStringStream& stmt, at::ArrayRef key_value_pairs, const char* begin = "{", const char* end = "}") { @@ -595,7 +689,8 @@ struct PythonPrintPass { } bool isLongInline(Node* node) { - return output_inline_.count(node) && isLongLine(useOf(node->output())); + return output_inline_.count(node) && + isLongLine(useOf(node->output())->str()); } bool isNonConstantInline(Value* input) { @@ -629,12 +724,13 @@ struct PythonPrintPass { // first place for (size_t i = 0; i < long_inline_slice; ++i) { if (isNonConstantInline(inputs[i])) { - printOutputDefinition(inputs[i]->node(), useOf(inputs[i])); + printOutputDefinition(inputs[i]->node(), *useOf(inputs[i])); } } } - void printOutputDefinition(Node* node, const std::string& str) { + template + void printOutputDefinition(Node* node, const T& expr) { assignValuesToTheirUniqueNames(node->outputs()); indent(); // Print outputs @@ -642,7 +738,7 @@ struct PythonPrintPass { printValueList(body_, node->outputs()); body_ << " = "; } - body_ << str << "\n"; + body_ << expr << "\n"; } // Recursively check contained types for any class dependencies @@ -660,6 +756,7 @@ struct PythonPrintPass { } void printNode(Node* node, bool print_const) { + WithSourceRange guard(&source_range_stack_, node->sourceRange()); // Check for class dependencies. If this node inputs or outputs a class // type, we need to add it to our table of dependencies. for (const auto input : node->inputs()) { @@ -734,7 +831,7 @@ struct PythonPrintPass { << "closures are not exportable"; } assignValuesToTheirUniqueNames(node->outputs()); - auto name = useOf(node->output()); + auto name = useOf(node->output())->str(); std::shared_ptr graph = node->g(attr::Subgraph); indent(); body_ << "def " << name << "("; @@ -750,19 +847,19 @@ struct PythonPrintPass { printBody(graph->block()); } break; default: - std::stringstream ss; - printRHS(ss, node); + auto ss = std::make_shared(&source_range_stack_); + printRHS(*ss, node); // we prevent long constants from inlining here. // it is not safe to do the same thing for non-constants here // because of [reordering of inlines] if (output_inline_.count(node) == 0 || - (node->kind() == prim::Constant && isLongLine(ss.str()))) { - printOutputDefinition(node, ss.str()); + (node->kind() == prim::Constant && isLongLine(ss->str()))) { + printOutputDefinition(node, *ss); } else { // this node is safe to inline, so assign the output value // to that expression directly - assignValue(node->output(), ss.str()); + assignValue(node->output(), ss); } } } @@ -779,39 +876,40 @@ struct PythonPrintPass { } } - void printConstant(std::ostream& stmt, const IValue& v) { + void printConstant(TaggedStringStream& stmt, const IValue& v) { + std::stringstream ss; if (v.isTensor()) { - stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor()); + ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor()); } else if (v.isString()) { - printQuotedString(stmt, v.toStringRef()); + printQuotedString(ss, v.toStringRef()); } else if (v.isDevice()) { - std::stringstream ss; - ss << v.toDevice(); - stmt << "torch.device("; - printQuotedString(stmt, ss.str()); - stmt << ")"; + std::stringstream device_stream; + device_stream << v.toDevice(); + ss << "torch.device("; + printQuotedString(ss, device_stream.str()); + ss << ")"; } else if (v.isTensorList()) { - stmt << "["; + ss << "["; const char* delim = ""; for (const at::Tensor& t : v.toTensorListRef()) { - stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t); + ss << delim << "CONSTANTS.c" << getOrAddTensorConstant(t); delim = ", "; } - stmt << "]"; + ss << "]"; } else if (v.isBoolList()) { - printMaybeAnnotatedConstantList( - stmt, "bool", v.toBoolList().size(), v); + printMaybeAnnotatedConstantList(ss, "bool", v.toBoolList().size(), v); } else if (v.isIntList()) { - printMaybeAnnotatedConstantList(stmt, "int", v.toIntListRef().size(), v); + printMaybeAnnotatedConstantList(ss, "int", v.toIntListRef().size(), v); } else if (v.isDoubleList()) { printMaybeAnnotatedConstantList( - stmt, "float", v.toDoubleListRef().size(), v); + ss, "float", v.toDoubleListRef().size(), v); } else { - stmt << v; + ss << v; } + stmt << ss.str(); } - void printNone(std::ostream& stmt, const Node* node) { + void printNone(TaggedStringStream& stmt, const Node* node) { if (node->output()->type()->isSubtypeOf(NoneType::get())) { stmt << "None"; return; @@ -842,7 +940,7 @@ struct PythonPrintPass { } // Prints the RHS value of a Node, e.g. `aten.add(x, y)` - void printRHS(std::ostream& stmt, Node* node) { + void printRHS(TaggedStringStream& stmt, Node* node) { switch (node->kind()) { case prim::PythonOp: { auto value = static_cast(node); @@ -857,8 +955,10 @@ struct PythonPrintPass { if (value->ignore_on_export) { stmt << "ops.prim.IgnoredPythonOp"; } else { + std::stringstream scalars_stream; stmt << "^" << value->name(); - value->writeScalars(stmt); + value->writeScalars(scalars_stream); + stmt << scalars_stream.str(); } printValueList(stmt, node->inputs(), "(", ")"); } break; @@ -951,8 +1051,9 @@ struct PythonPrintPass { stmt << useOf(obj) << "." << field; } else { stmt << "getattr(" << useOf(obj) << ", "; - printQuotedString(stmt, field); - stmt << ")"; + std::stringstream field_stream; + printQuotedString(field_stream, field); + stmt << field_stream.str() << ")"; } } break; default: { @@ -982,14 +1083,14 @@ struct PythonPrintPass { // vararg functions like format can have extra arguments AT_ASSERT(schema.is_vararg()); } - stmt << v; + stmt << *v; } stmt << ")"; } break; } } - std::ostream& printBlock(Block* root, bool block_has_other_statements) { + TaggedStringStream& printBlock(Block* root, bool block_has_other_statements) { // pythons weird 'pass' syntax creates a bunch of places where we have to // check if this block would be empty. But not everything in a block is a // node. Sometimes if, loop, and return statements will follow this block @@ -1007,7 +1108,7 @@ struct PythonPrintPass { void printDefaultValue( const TypePtr& typ, - std::ostream& stmt, + TaggedStringStream& stmt, const IValue& value) { // xxx - many weak script modules store default values for broadcasting // lists that are not actually the same type as the argument. We can only @@ -1048,6 +1149,9 @@ struct PythonPrintPass { Graph& graph = *func.graph(); used_names_.clear(); // each graph can reuse local names + WithSourceRange guard( + &source_range_stack_, graph.param_node()->sourceRange()); + indent(); body_ << "def " << func.name() << "("; auto param_it = graph.inputs().begin(); @@ -1091,7 +1195,8 @@ struct PythonPrintPass { std::vector& class_table, bool enforce_importable, bool is_method) - : tensor_table_(tensor_table), + : body_(&source_range_stack_), + tensor_table_(tensor_table), class_table_(class_table), enforce_importable_(enforce_importable), is_method_(is_method) {} @@ -1144,13 +1249,15 @@ struct PythonPrintPass { class_deps_.end()); } - void print(std::ostream& out) { + void print(std::ostream& out, SourceRangeRecords& source_ranges_out) { out << getImports() << body_.str(); + source_ranges_out = body_.ranges(); } }; void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const Function& func, bool is_method, std::vector& tensor_table, @@ -1158,11 +1265,12 @@ void PythonPrint( bool enforce_importable) { PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method); pp.printFunction(func); - pp.print(out); + pp.print(out, source_ranges_out); } void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const script::CompilationUnit& cu, bool is_method, std::vector& tensor_table, @@ -1170,18 +1278,19 @@ void PythonPrint( bool enforce_importable) { PythonPrintPass pp(tensor_table, class_table, enforce_importable, is_method); pp.printCompilationUnit(cu); - pp.print(out); + pp.print(out, source_ranges_out); } void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const c10::NamedTypePtr& classType, std::vector& tensor_table, std::vector& class_table, bool enforce_importable) { PythonPrintPass pp(tensor_table, class_table, enforce_importable, true); pp.printClass(classType); - pp.print(out); + pp.print(out, source_ranges_out); } bool printerHasSpecialCaseFor(Symbol sym) { diff --git a/torch/csrc/jit/passes/python_print.h b/torch/csrc/jit/passes/python_print.h index 4d13760acf101..9ee0d9fffc94f 100644 --- a/torch/csrc/jit/passes/python_print.h +++ b/torch/csrc/jit/passes/python_print.h @@ -12,8 +12,19 @@ struct Method; struct Module; } // namespace script +// A pair of (byte offset, SourceRange) describing a specific segment +// of the output stream +struct TaggedRange { + TaggedRange(size_t bytes, SourceRange range) + : bytes(bytes), range(std::move(range)) {} + size_t bytes; + SourceRange range; +}; +using SourceRangeRecords = std::vector; + TORCH_API void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const Function& callee, bool is_method, std::vector& tensor_table, @@ -22,6 +33,7 @@ TORCH_API void PythonPrint( TORCH_API void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const script::CompilationUnit& cu, bool is_method, std::vector& tensor_table, @@ -30,6 +42,7 @@ TORCH_API void PythonPrint( TORCH_API void PythonPrint( std::ostream& out, + SourceRangeRecords& source_ranges_out, const c10::NamedTypePtr& classType, std::vector& tensor_table, std::vector& class_table, diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 1c7e4c634d444..2c5af74ab93f1 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -530,8 +530,10 @@ void initJitScriptBindings(PyObject* module) { std::ostringstream ss; std::vector tensors; std::vector classes; + SourceRangeRecords source_ranges; PythonPrint( ss, + source_ranges, *self.class_compilation_unit(), true, tensors, @@ -606,7 +608,9 @@ void initJitScriptBindings(PyObject* module) { std::ostringstream ss; std::vector tensors; std::vector classes; - PythonPrint(ss, self, false, tensors, classes, false); + SourceRangeRecords source_ranges; + PythonPrint( + ss, source_ranges, self, false, tensors, classes, false); return ss.str(); }) .def( @@ -632,7 +636,9 @@ void initJitScriptBindings(PyObject* module) { std::ostringstream ss; std::vector tensors; std::vector classes; - PythonPrint(ss, self.function(), true, tensors, classes, false); + SourceRangeRecords source_ranges; + PythonPrint( + ss, source_ranges, self.function(), true, tensors, classes, false); return ss.str(); }); m.def( @@ -759,14 +765,22 @@ void initJitScriptBindings(PyObject* module) { std::ostringstream ss; std::vector constants; std::vector classes; + SourceRangeRecords source_ranges; if (auto self = as_module(obj)) { PythonPrint( - ss, *self->class_compilation_unit(), true, constants, classes, true); + ss, + source_ranges, + *self->class_compilation_unit(), + true, + constants, + classes, + true); } else if (auto self = as_function(obj)) { - PythonPrint(ss, *self, false, constants, classes, true); + PythonPrint(ss, source_ranges, *self, false, constants, classes, true); } else { auto& m = py::cast(obj); - PythonPrint(ss, m.function(), true, constants, classes, true); + PythonPrint( + ss, source_ranges, m.function(), true, constants, classes, true); } return std::make_pair(ss.str(), std::move(constants)); }); diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h index 8786cd70629b4..5362457df6286 100644 --- a/torch/csrc/jit/source_range.h +++ b/torch/csrc/jit/source_range.h @@ -131,6 +131,15 @@ struct CAFFE2_API SourceRange { (size_t)col_offset); } + bool operator==(const SourceRange& rhs) const { + return start() == rhs.start() && end() == rhs.end() && + source() == rhs.source(); + } + + bool operator!=(const SourceRange& rhs) const { + return !(*this == rhs); + } + private: std::shared_ptr source_; size_t start_;