2 changes: 1 addition & 1 deletion aten/src/ATen/core/CMakeLists.txt
@@ -9,7 +9,6 @@ EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
 # Add files needed from jit folders
 LIST(APPEND ATen_CORE_HEADERS
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.h
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_location.h
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/function_schema_parser.h
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.h
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.h
@@ -23,6 +22,7 @@ LIST(APPEND ATen_CORE_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/lexer.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/strtod.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/script/schema_type_parser.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../torch/csrc/jit/source_range.cpp
 )
 
 # Pass to parent
2 changes: 1 addition & 1 deletion torch/csrc/jit/constants.cpp
@@ -85,7 +85,7 @@ c10::optional<Value*> tryInsertConstant(
     return c10::nullopt;
   }
   if (loc)
-    n->setSourceLocation(std::make_shared<SourceRange>(*loc));
+    n->setSourceRange(*loc);
   if (scope)
     n->setScope(*scope);
   if (result_type) {
14 changes: 3 additions & 11 deletions torch/csrc/jit/export.cpp
@@ -38,13 +38,7 @@ namespace onnx = ::ONNX_NAMESPACE;
 class ScriptModuleSerializer;
 
 std::string getNodeStackTraceString(const Node* n) {
-  std::stringstream ss;
-  if (n->getSourceLocation()) {
-    n->getSourceLocation()->highlight(ss);
-  } else {
-    ss << "<unknown location>";
-  }
-  return ss.str();
+  return n->sourceRange().str();
 }
 
 void validateBlock(
@@ -258,10 +252,8 @@ void EncoderBase::EncodeBlock(
       continue;
     }
     auto p_n = graph_proto->add_node();
-    if (node->getSourceLocation() && !strip_doc_) {
-      std::stringstream ss;
-      node->getSourceLocation()->highlight(ss);
-      p_n->set_doc_string(ss.str());
+    if (!strip_doc_) {
+      p_n->set_doc_string(node->sourceRange().str());
     }
     for (auto input : node->inputs()) {
       if (input->node()->mustBeNone() && !is_raw_export) {
22 changes: 9 additions & 13 deletions torch/csrc/jit/interpreter.cpp
@@ -330,7 +330,7 @@ struct Instruction {
   UseList inputs;
   ListHandle<int> outputs;
   Symbol debug_name; // used in dump to understand the generated code
-  std::shared_ptr<SourceLocation> debug_location; // for error reporting
+  c10::optional<SourceRange> debug_location; // for error reporting
 };
 
 int relativeJump(int from_inst, int to_inst) {
@@ -377,7 +377,7 @@ struct CodeImpl {
 
   void insertNodesFromBlock(Block* block) {
     for (auto node : block->nodes()) {
-      const auto& source_location = node->getSourceLocation();
+      SourceRange source_location = node->sourceRange();
       switch (node->kind()) {
         case prim::If: {
           // x = if c:
@@ -481,7 +481,7 @@
   size_t insertInstruction(Node* n) {
     auto inst = insertInstruction(
         n->kind(),
-        n->getSourceLocation(),
+        n->sourceRange(),
         n->inputs(),
         moveFlags(n),
         n->outputs());
@@ -490,7 +490,7 @@
   }
   size_t insertInstruction(
       Symbol sym,
-      std::shared_ptr<SourceLocation> debug_location,
+      const SourceRange& debug_location,
       ArrayRef<Value*> inputs,
       ArrayRef<uint8_t> move_flags,
       ArrayRef<Value*> outputs) {
@@ -520,7 +520,7 @@
   }
 
   size_t insertAssign(
-      std::shared_ptr<SourceLocation> debug_location,
+      const SourceRange& debug_location,
       ArrayRef<Value*> inputs,
       ArrayRef<uint8_t> move_flags,
       ArrayRef<Value*> outputs) {
@@ -713,14 +713,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
       } catch (std::exception& e) {
         // Error from the current thread
        bool is_jit_exception = dynamic_cast<JITException*>(&e);
-        if (instructions[pc].debug_location) {
-          handleError(
-              instructions[pc].debug_location->wrapException(
-                  e, "operation failed in interpreter"),
-              is_jit_exception);
-        } else {
-          handleError(e.what(), is_jit_exception);
-        }
+        handleError(
+            instructions[pc].debug_location->wrapException(
+                e, "operation failed in interpreter"),
+            is_jit_exception);
         return false;
       }
     }
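Note on the catch block above: debug_location is now a c10::optional<SourceRange> that is dereferenced unconditionally, which is only safe because every instruction is emitted with a range. A minimal self-contained sketch of that pattern, with std::optional standing in for c10::optional and an assumed wrapException that simply prepends the highlighted source text to the message (both stand-ins, not the real torch API):

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

struct SourceRange {
  std::string text;
  // Stand-in: wrap an exception message with the offending source text.
  std::runtime_error wrapException(const std::exception& e,
                                   const std::string& additional) const {
    return std::runtime_error(additional + "\n" + text + "\n" + e.what());
  }
};

struct Instruction {
  std::optional<SourceRange> debug_location;  // always set at emission time
};

int main() {
  Instruction inst{SourceRange{"x + y  <--- HERE"}};
  try {
    throw std::runtime_error("expected Tensor, got int");
  } catch (std::exception& e) {
    // operator-> on a disengaged optional is undefined behavior, so this
    // relies on the emitter populating debug_location for every instruction.
    std::cerr << inst.debug_location
                     ->wrapException(e, "operation failed in interpreter")
                     .what()
              << "\n";
  }
}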
14 changes: 13 additions & 1 deletion torch/csrc/jit/ir.cpp
@@ -200,6 +200,14 @@ void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
   out << "]";
 }
 
+SourceRange Node::sourceRange() const {
+  if (source_range_) {
+    return *source_range_;
+  }
+  std::stringstream ss;
+  return SourceRange(ss.str());
+}
+
 static std::ostream& indent(std::ostream& out, size_t level) {
   for (size_t i = 0; i < level; ++i) {
     out << "  ";
@@ -224,8 +232,10 @@ std::ostream& Node::print(
     if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
       printAttributes(out, /*ignore_subgraph=*/true);
     }
+
     groups->push_back(this);
   } else {
+
     out << kind().toQualString();
     if (hasAttributes()) {
       printAttributes(out);
@@ -241,6 +251,7 @@
       out << ", ";
     out << "scope: " << scName << "\n";
   }
+
   for (size_t i = 0; i < blocks().size(); ++i) {
     auto b = blocks()[i];
     indent(out, level + 1) << "block" << i << "("
@@ -251,6 +262,7 @@
     }
     indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
   }
+
   return out;
 }
 
@@ -973,7 +985,7 @@ void Node::destroy() {
 }
 
 void Node::cloneFrom(Node* s) {
-  setSourceLocation(s->getSourceLocation());
+  source_range_ = s->source_range_;
   if (s->scope_ && !s->scope_->isBlank()) {
     scope_ = s->scope_;
   }
11 changes: 5 additions & 6 deletions torch/csrc/jit/ir.h
@@ -249,7 +249,7 @@ struct TORCH_API Node {
   std::vector<Block*> blocks_;
   Graph* graph_;
   Block* owning_block_;
-  std::shared_ptr<SourceLocation> source_location_;
+  c10::optional<SourceRange> source_range_;
   ScopePtr scope_;
   // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
   // This field is effective a cache that's populated on attribute lookups and
@@ -287,13 +287,12 @@
   NodeKind kind() const {
     return kind_;
   }
-  Node* setSourceLocation(std::shared_ptr<SourceLocation> sl) {
-    source_location_ = std::move(sl);
+  Node* setSourceRange(SourceRange r) {
+    source_range_ = std::move(r);
     return this;
   }
-  std::shared_ptr<SourceLocation> getSourceLocation() const {
-    return source_location_;
-  }
+  SourceRange sourceRange() const;
 
   Graph* owningGraph() {
     return graph_;
   }
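The header change above is the core of the migration: Node used to hand callers a nullable std::shared_ptr<SourceLocation>, and now holds a c10::optional<SourceRange> by value behind a getter that never returns null. A minimal self-contained sketch of that accessor pattern (std::optional standing in for c10::optional, and SourceRange reduced to a plain string holder; both are illustrative assumptions):

#include <iostream>
#include <optional>
#include <string>
#include <utility>

struct SourceRange {
  explicit SourceRange(std::string s) : str_(std::move(s)) {}
  const std::string& str() const { return str_; }
 private:
  std::string str_;
};

struct Node {
  // Takes the range by value and moves it into the optional; returns this
  // so call sites can chain, e.g. create(...)->setSourceRange(loc).
  Node* setSourceRange(SourceRange r) {
    source_range_ = std::move(r);
    return this;
  }
  // Absence collapses to an empty range, so callers need no null check.
  SourceRange sourceRange() const {
    return source_range_ ? *source_range_ : SourceRange("");
  }
 private:
  std::optional<SourceRange> source_range_;
};

int main() {
  Node n;
  std::cout << "[" << n.sourceRange().str() << "]\n";  // [] before a range is set
  n.setSourceRange(SourceRange("def f(x):\n    return x + 1"));
  std::cout << n.sourceRange().str() << "\n";
}

Call sites that previously wrapped a range in std::make_shared<SourceRange>(loc) now just pass it by value, as the compiler.cpp hunks below show.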
2 changes: 1 addition & 1 deletion torch/csrc/jit/operator.cpp
@@ -237,7 +237,7 @@ const Operator& getOperatorFor(const Node* node) {
   if (op)
     return *op;
 
-  auto er = script::ErrorReport(node->getSourceLocation());
+  auto er = script::ErrorReport(node->sourceRange());
   er << "Schema not found for node. File a bug report.\n";
   er << "Node: " << *node << "\n";
   er << "Input types:";
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/alias_analysis.cpp
@@ -464,7 +464,7 @@ void AliasDb::analyzeImpl(Node* node) {
   // We don't have alias info for this node. Either schematize it, or
   // add it an analyze* method for it.
   if (hasMutableOutputs) {
-    throw script::ErrorReport(node->getSourceLocation())
+    throw script::ErrorReport(node->sourceRange())
         << "Alias information not found for node. File a bug report.\n"
        << "Node: " << *node << "\n";
   }
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/lower_tuples.cpp
@@ -44,7 +44,7 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) {
   auto maybe_int = constant_as<int64_t>(idx);
   if (!maybe_int) {
     if (must_remove_tuples) {
-      AT_ERROR(n->getSourceLocation(), "tuple index with non-constant index");
+      AT_ERROR(n->sourceRange(), "tuple index with non-constant index");
     }
     return;
   }
2 changes: 1 addition & 1 deletion torch/csrc/jit/passes/onnx.cpp
@@ -211,7 +211,7 @@ void BlockToONNX(
     outputs[i]->setType(old->type());
     // Copy over source location and scope information to all nodes
     // created by the symbolic
-    outputs[i]->node()->setSourceLocation(node->getSourceLocation());
+    outputs[i]->node()->setSourceRange(node->sourceRange());
     outputs[i]->node()->setScope(node->scope());
     env[old] = outputs[i];
   } else {
8 changes: 4 additions & 4 deletions torch/csrc/jit/passes/python_print.cpp
@@ -531,7 +531,7 @@ struct PythonPrintPass {
       // this must be a while loop, but check that there isn't _also_ a trip
       // count
       if (trip_count_is_specified) {
-        throw script::ErrorReport(stmt.node()->getSourceLocation())
+        throw script::ErrorReport(stmt.node()->sourceRange())
            << "loop cannot be printed as python "
            << "because it has gone through an optimization "
            << "that combined while and for loops. File a bug.";
@@ -678,7 +678,7 @@
     switch (node->kind()) {
       case prim::Return:
         if (enforce_importable_ && node->inputs().size() != 1) {
-          throw script::ErrorReport(node->getSourceLocation())
+          throw script::ErrorReport(node->sourceRange())
              << "Exportable methods must have a single return value. "
              << "Normal use of ScriptMethods should enforce this.";
        }
@@ -733,7 +733,7 @@
      } break;
      case prim::Function: {
        if (enforce_importable_) {
-          throw script::ErrorReport(node->getSourceLocation())
+          throw script::ErrorReport(node->sourceRange())
              << "closures are not exportable";
        }
        assignValuesToTheirUniqueNames(node->outputs());
@@ -850,7 +850,7 @@
      case prim::PythonOp: {
        auto value = static_cast<const PythonOp*>(node);
        if (enforce_importable_) {
-          throw script::ErrorReport(node->getSourceLocation())
+          throw script::ErrorReport(node->sourceRange())
              << "could not export python function call " << value->name()
              << ". Remove calls to Python functions before export. "
              << "Did you forget add @script or @script_method annotation? "
6 changes: 1 addition & 5 deletions torch/csrc/jit/passes/shape_analysis.cpp
@@ -73,11 +73,7 @@ class ShapePropagator {
     } catch (propagation_error& e) {
       setUnshapedType(node);
     } catch (std::exception& e) {
-      if (auto sl = node->getSourceLocation()) {
-        sl->wrapAndRethrowException(e, "operation failed shape propagation");
-      } else {
-        throw;
-      }
+      node->sourceRange().wrapAndRethrowException(e, "operation failed shape propagation");
     }
   }
 }
12 changes: 3 additions & 9 deletions torch/csrc/jit/python_ir.cpp
@@ -439,15 +439,9 @@ void initPythonIRBindings(PyObject* module_) {
            return ss.str();
          })
      .def(
-          "getSourceLocation",
-          [](Node& n) -> py::object {
-            std::stringstream ss;
-            if (auto sl = n.getSourceLocation()) {
-              sl->highlight(ss);
-              return py::str(ss.str());
-            } else {
-              return py::none();
-            }
+          "sourceRange",
+          [](Node& n) {
+            return n.sourceRange().str();
          })
      .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
      .def("outputsSize", [](Node& n) { return n.outputs().size(); })
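In the binding above, the Python-facing method drops its py::object return (a str or None) in favor of always returning a string. A hedged sketch of the same pybind11 shape, with toy Node/SourceRange types standing in for the real ones and an illustrative module name:

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct SourceRange {
  std::string text;
  std::string str() const { return text; }
};

struct Node {
  SourceRange sourceRange() const { return {"x + y  <--- HERE"}; }
};

PYBIND11_MODULE(example, m) {
  // The lambda's deduced std::string return converts to a Python str;
  // an empty range now comes back as "" rather than None.
  py::class_<Node>(m, "Node")
      .def(py::init<>())
      .def("sourceRange", [](Node& n) { return n.sourceRange().str(); });
}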
4 changes: 1 addition & 3 deletions torch/csrc/jit/python_tracer.cpp
@@ -101,9 +101,7 @@ Node* preRecordPythonTrace(
 }
 
 void pythonRecordSourceLocation(Node* n) {
-  auto sl =
-      std::make_shared<StringSourceLocation>(getPythonInterpreterStackTrace());
-  n->setSourceLocation(sl);
+  n->setSourceRange(SourceRange(getPythonInterpreterStackTrace()));
 }
 
 void pythonWarn(const std::string& reason) {
2 changes: 1 addition & 1 deletion torch/csrc/jit/register_prim_ops.cpp
@@ -2257,7 +2257,7 @@ void checkSortSchema(const Node* node, const c10::TypePtr& list_element_type) {
         << ", got list of " << list_element_type->python_str() << "\n";
   }
 
-  auto error_msg = script::ErrorReport(node->getSourceLocation());
+  auto error_msg = script::ErrorReport(node->sourceRange());
   error_msg << error_str.str();
   throw error_msg;
 }
2 changes: 1 addition & 1 deletion torch/csrc/jit/register_special_ops.cpp
@@ -22,7 +22,7 @@ namespace {
 void checkListInputType(const c10::TypePtr& elem_type, const Node* node) {
   if (!elem_type->isSubtypeOf(NumberType::get()) &&
       elem_type != BoolType::get()) {
-    auto error = script::ErrorReport(node->getSourceLocation());
+    auto error = script::ErrorReport(node->sourceRange());
     error << "Input list to torch.tensor must be of ints, floats, or bools, "
           << "got " << elem_type->str();
     // special case empty list torch.tensor([])
4 changes: 2 additions & 2 deletions torch/csrc/jit/script/compiler.cpp
@@ -953,7 +953,7 @@ struct to_ir {
 
   Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
     return graph->create(kind, n_outputs)
-        ->setSourceLocation(std::make_shared<SourceRange>(loc));
+        ->setSourceRange(loc);
   }
 
   Value* emitTernaryIf(const TernaryIf& expr) {
@@ -2379,7 +2379,7 @@
     auto fork_node =
         method.graph()
            ->insertNode(method.graph()->create(prim::fork, 1))
-            ->setSourceLocation(std::make_shared<SourceRange>(loc));
+            ->setSourceRange(loc);
     auto body_block = fork_node->addBlock();
 
     // Build a template of the graph to be executed
13 changes: 6 additions & 7 deletions torch/csrc/jit/script/error_report.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/jit/script/tree.h>
+#include <c10/util/Optional.h>
 
 namespace torch {
 namespace jit {
@@ -10,17 +11,15 @@ struct ErrorReport : public std::exception {
   ErrorReport(const ErrorReport& e)
       : ss(e.ss.str()), context(e.context), the_message(e.the_message) {}
 
-  ErrorReport() : context(nullptr) {}
-  explicit ErrorReport(const SourceRange& r)
-      : context(std::make_shared<SourceRange>(r)) {}
-  explicit ErrorReport(std::shared_ptr<SourceLocation> loc)
-      : context(std::move(loc)) {}
+  ErrorReport() : context(c10::nullopt) {}
+  explicit ErrorReport(SourceRange r)
+      : context(std::move(r)) {}
   explicit ErrorReport(const TreeRef& tree) : ErrorReport(tree->range()) {}
   explicit ErrorReport(const Token& tok) : ErrorReport(tok.range) {}
   const char* what() const noexcept override {
     std::stringstream msg;
     msg << "\n" << ss.str();
-    if (context != nullptr) {
+    if (context) {
       msg << ":\n";
       context->highlight(msg);
     } else {
@@ -35,7 +34,7 @@
   friend const ErrorReport& operator<<(const ErrorReport& e, const T& t);
 
   mutable std::stringstream ss;
-  std::shared_ptr<SourceLocation> context;
+  c10::optional<SourceRange> context;
   mutable std::string the_message;
 };
 
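error_report.h keeps its lazy what() formatting; only the context slot changes from a nullable pointer to an optional value, so the null check becomes an optional truth test. A self-contained sketch of that exception pattern (std::optional in place of c10::optional, and highlight() reduced to printing the stored text, both assumptions):

#include <exception>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <utility>

struct SourceRange {
  std::string text;
  void highlight(std::ostream& out) const { out << text << "\n"; }
};

struct ErrorReport : std::exception {
  ErrorReport() = default;
  explicit ErrorReport(SourceRange r) : context(std::move(r)) {}

  const char* what() const noexcept override {
    std::stringstream msg;
    msg << "\n" << ss.str();
    if (context) {  // optional<> truth test replaces the old nullptr check
      msg << ":\n";
      context->highlight(msg);
    }
    the_message = msg.str();  // cached so the returned pointer stays valid
    return the_message.c_str();
  }

  mutable std::stringstream ss;
  std::optional<SourceRange> context;
  mutable std::string the_message;
};

int main() {
  ErrorReport e{SourceRange{"x = y + 1  <--- HERE"}};
  e.ss << "Schema not found for node.";
  std::cout << e.what() << "\n";
}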
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/module.cpp
@@ -194,7 +194,7 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
       continue;
     }
     if (e.n->kind() != prim::GetAttr) {
-      throw ErrorReport(e.n->getSourceLocation())
+      throw ErrorReport(e.n->sourceRange())
          << "temporary: the only valid use of a module is looking up an attribute";
     }
     Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));