Make sure each warnings.warn only executes once inside TorchScript. #45382
Test file (new, +165 lines):

@@ -0,0 +1,165 @@
import os
import sys
import io

import torch
import warnings
from contextlib import redirect_stderr
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")


class TestWarn(JitTestCase):
    def test_warn(self):
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=1,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_only_once(self):
        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=1,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_only_once_in_loop_func(self):
        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=1,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_once_per_func(self):
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=2,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_once_per_func_in_loop(self):
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w1()
                w2()

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=2,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_multiple_calls_multiple_warnings(self):
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        f = io.StringIO()
        with redirect_stderr(f):
            fn()
            fn()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=2,
                exactly=True) \
            .run(f.getvalue())

    def test_warn_multiple_calls_same_func_diff_stack(self):
        def warn(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn("foo")

        @torch.jit.script
        def bar():
            warn("bar")

        f = io.StringIO()
        with redirect_stderr(f):
            foo()
            bar()

        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you from foo",
                count=1,
                exactly=True) \
            .check_count(
                str="UserWarning: I am warning you from bar",
                count=1,
                exactly=True) \
            .run(f.getvalue())
torch/csrc/jit/passes/annotate_warns.cpp (new, +29 lines):

@@ -0,0 +1,29 @@
#include <torch/csrc/jit/passes/annotate_warns.h>

#include <atomic>

namespace torch {
namespace jit {

void AnnotateWarns(Block* b) {
  static std::atomic<int64_t> idx(0);
  for (Node* n : b->nodes()) {
    for (Block* child_b : n->blocks()) {
      AnnotateWarns(child_b);
    }

    if (n->kind() != aten::warn) {
      continue;
    }

    n->i_(attr::warn_id, idx);
    idx++;
  }
}

void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
  AnnotateWarns(graph->block());
}

} // namespace jit
} // namespace torch
torch/csrc/jit/passes/annotate_warns.h (new, +11 lines):

@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

TORCH_API void AnnotateWarns(const std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
torch/csrc/jit/runtime/interpreter.cpp:

@@ -428,6 +428,16 @@ struct TLSCurrentInterpreterGuard {
   InterpreterStateImpl* prev_state_;
 };

+template <class Ttarget, class Tsource>
+Ttarget safe_narrow_cast(Tsource v) {
+  Ttarget res = static_cast<Ttarget>(v);
+  // Casting it back to check whether it overflowed.
+  if (static_cast<Tsource>(res) != v) {
+    throw std::runtime_error("safe_narrow_cast<>() failed due to overflow");
+  }
+  return res;
+}
+
 struct CodeImpl {
   friend struct InterpreterState;
   std::vector<Instruction> instructions_;
@@ -535,7 +545,10 @@
   }

   void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
-    instructions_.emplace_back(op, X, N);
+    instructions_.emplace_back(
+        op,
+        safe_narrow_cast<int32_t, int64_t>(X),
+        safe_narrow_cast<int16_t, int64_t>(N));
     instructions_source_.emplace_back(current_node_);

     // check that we didn't accidentally emit nodes out of topological order

Reviewer: Wait, won't it throw if there are more distinct warnings than an int32 can hold? It actually can happen, because you keep increasing the same counter across all models, so it might overflow eventually. You probably want an explicit static_cast in emitWarn in this case, as the truncation is OK.

Author: Removed this logic since the index is no longer needed.

Author: Added.
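For reference, here is a standalone sketch (not part of this PR) of why the round-trip check in safe_narrow_cast catches the overflow the reviewer is worried about, where a bare static_cast would silently truncate:

#include <cstdint>
#include <iostream>
#include <stdexcept>

template <class Ttarget, class Tsource>
Ttarget safe_narrow_cast(Tsource v) {
  Ttarget res = static_cast<Ttarget>(v);
  // Cast back and compare: if the value changed, it did not fit.
  if (static_cast<Tsource>(res) != v) {
    throw std::runtime_error("safe_narrow_cast<>() failed due to overflow");
  }
  return res;
}

int main() {
  // 42 fits in int32_t, so it round-trips exactly and nothing is thrown.
  std::cout << safe_narrow_cast<int32_t, int64_t>(42) << "\n";

  // 2^31 does not fit in int32_t, so the round-trip comparison fails and
  // safe_narrow_cast throws -- this is what would happen here if the global
  // warning counter ever exceeded INT32_MAX, whereas a plain static_cast
  // would silently wrap the value.
  try {
    safe_narrow_cast<int32_t, int64_t>(int64_t{1} << 31);
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}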
@@ -873,7 +886,11 @@

   void emitWarn(Node* node) {
     emitLoadInputs(node->inputs());
-    insertInstruction(WARN);
+    int32_t idx = -1;
+    if (node->hasAttribute(attr::warn_id)) {
+      idx = static_cast<int32_t>(node->i(attr::warn_id));
+    }
+    insertInstruction(WARN, idx);
   }

   void emitEnter(Node* node) {

Reviewer: Just curious - why do we need to have separate indices? Would using just [...] work?

Author: Good point. Changed to using [...].

Author: New turn of events. Using [...]
@@ -1017,6 +1034,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
   }

+ private:
+  struct WarnedNodes {
+   public:
+    // Inserts idx into warned_nodes_; returns a boolean indicating whether
+    // the insertion actually happened (idx wasn't originally in the set).
+    bool insert(int32_t idx) {
+      std::unique_lock<std::mutex> lock(mutex_);
+      return warned_nodes_.insert(idx).second;
+    }
+
+   private:
+    std::mutex mutex_;
+    std::unordered_set<int32_t> warned_nodes_;
+  };
+
+  WarnedNodes warned_nodes_;
+
   // if we need to suspend, where do we reset the stack?
   // answer: to where it was when we were called, not
   // including any inputs to this function

Reviewer: This is actually not what we want - multiple invocations of the same function should still warn only once (like Python does). Sorry for missing it in the review. You probably need to move it to the Code level or something like that.

Reviewer: Another suggestion - instead of an unordered_set, just create a vector<atomic> inside Code and use that to modify the index directly. I think this way you don't need locking. Or you can still have an unordered_map, but pre-populate it with all possible indices beforehand so that you don't need the lock.

Author: Thanks for the suggestion. I wonder if what Python does is actually desired. Think about the following use case: [...] It may be called from various callsites, and they are all useful. IMHO, it is slightly better to emit the warning for each unique callsite. The number of warnings would still not be too spammy, because there likely won't be so many callsites that they feel spammy to users. What do you think? I am still investigating the issue reported internally, which I feel is a multi-threading issue similar to #46684. Will update here once I find out more.

Reviewer: Well, ideally we should respect [...]. Fixing the spammy warning is important though, because the typical case is inference, where the model is called many times in a loop and each call creates an independent InterpreterState. So the current implementation sadly will log on each invocation, spamming the logs. (That's pretty much the internal issue you're referring to.)

Author: I thought about this problem a lot but couldn't think of a clean way to implement the "warn once" behavior, which requires maintaining global state above the ScriptModule. That sort of goes against TorchScript's philosophy of keeping the state of a ScriptModule local to itself. Here is an alternative solution: given what we talked about, it sounds like it is OK (or even preferred) to have zero warnings emitted during inference. If that's the case, we can expose a control knob that predictor can toggle to disable warnings entirely inside TorchScript modules. What do you think?

Reviewer: Yeah, we can have a global knob :) In general, I don't think that having this state across ScriptModule invocations is necessarily that bad. It's a very narrow kind of state, and it matches Python's behavior. I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what Python does and would also keep the state "local" to the ScriptModule that owns the original graph.

Author: Based on my rough understanding, that bit per Warning node would have to live in predictor, which runs inference calls in a for loop. In that case, an implementation detail of TorchScript kind of leaks out. Let me know what you think.
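To make the reviewer's vector<atomic> suggestion concrete, here is a minimal standalone sketch (the type name, its placement, and its interface are illustrative assumptions, not actual PyTorch code): one atomic flag per WARN instruction, sized up front when the code object is built, so the hot path needs no mutex:

#include <atomic>
#include <cstddef>
#include <vector>

// Hypothetical per-Code state: one flag per WARN instruction. Because the
// vector is fully sized at construction time (value-initialized to false),
// no locking or rehashing is needed at run time.
struct WarnOnceFlags {
  explicit WarnOnceFlags(size_t num_warn_instructions)
      : flags_(num_warn_instructions) {}

  // Returns true exactly once per index, even with concurrent callers:
  // exchange() atomically sets the flag and reports its previous value.
  bool shouldWarn(size_t idx) {
    return !flags_[idx].exchange(true, std::memory_order_relaxed);
  }

 private:
  std::vector<std::atomic<bool>> flags_;
};

Note this sketch only removes the mutex; whether the flags should live per InterpreterState, per Code, or globally is exactly the question debated in the thread above.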
@@ -1487,21 +1520,35 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
         ++frame.pc;
       } break;
       case WARN: {
-        Node* node = frame.function->instructions_source_.at(frame.pc);
+        // Keeps track of which WARN instructions have been executed before;
+        // we only want to execute each WARN once, to match default Python
+        // warning behavior.
+        bool need_warn = true;
+        if (inst.X != -1) {
+          need_warn = warned_nodes_.insert(inst.X);
+        }
+
+        Node* node =
+            frames.back().function->instructions_source_.at(frame.pc);
         auto range = node->sourceRange().source();
         if (range->filename()) {
-          auto line = range->starting_line_no() +
-              range->lineno_for_offset(node->sourceRange().start());
           drop(stack, 1);
-          c10::SourceLocation location{
-              "", range->filename()->c_str(), uint32_t(line)};
-          // Sends the warning to the warning handler with the
-          // "verbatim" flag. This flag ensures the warning handler
-          // will print the exception as configured.
-          c10::Warning::warn(
-              location, pop(stack).toStringRef(), /*verbatim=*/true);
+          const auto msg = pop(stack).toStringRef();
+          if (need_warn) {
+            auto line = range->starting_line_no() +
+                range->lineno_for_offset(node->sourceRange().start());
+            c10::SourceLocation location{
+                "", range->filename()->c_str(), uint32_t(line)};
+            // Sends the warning to the warning handler with the
+            // "verbatim" flag. This flag ensures the warning handler
+            // will print the exception as configured.
+            c10::Warning::warn(location, msg, /*verbatim=*/true);
+          }
         } else {
-          TORCH_WARN(pop(stack).toStringRef());
+          const auto msg = pop(stack).toStringRef();
+          if (need_warn) {
+            TORCH_WARN(msg);
+          }
         }
         ++frame.pc;
       } break;

Reviewer: It was the case with this code even earlier, but I find it suspicious that this branch doesn't have a drop(stack, 1).

Author: I think adding an extra stack drop causes an interpreter error. It kind of makes sense, since the lack of file info means [...]

Reviewer: Yeah, I'm just asking whether we actually ever have a codepath exercising this branch (as usually we do have the file info).
Reviewer: Yeah, so in this case it should be 1, not 2 - that's what Python is doing.