Make sure each warnings.warn only executes once inside TorchScript. #45382

Closed · wants to merge 1 commit
3 changes: 2 additions & 1 deletion aten/src/ATen/core/interned_strings.h
@@ -360,7 +360,8 @@ namespace c10 {
_(attr, scope) \
_(attr, keepdims) \
_(attr, cache_id) \
_(attr, new_axis)
_(attr, new_axis) \
_(attr, warn_id)
#else
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
165 changes: 165 additions & 0 deletions test/jit/test_warn.py
@@ -0,0 +1,165 @@
import os
import sys
import io

import torch
import warnings
from contextlib import redirect_stderr
from torch.testing import FileCheck

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")


class TestWarn(JitTestCase):
def test_warn(self):
@torch.jit.script
def fn():
warnings.warn("I am warning you")

f = io.StringIO()
with redirect_stderr(f):
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())

def test_warn_only_once(self):
@torch.jit.script
def fn():
for _ in range(10):
warnings.warn("I am warning you")

f = io.StringIO()
with redirect_stderr(f):
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())

def test_warn_only_once_in_loop_func(self):
def w():
warnings.warn("I am warning you")

@torch.jit.script
def fn():
for _ in range(10):
w()

f = io.StringIO()
with redirect_stderr(f):
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())

def test_warn_once_per_func(self):
def w1():
warnings.warn("I am warning you")

def w2():
warnings.warn("I am warning you")

@torch.jit.script
def fn():
w1()
w2()

f = io.StringIO()
with redirect_stderr(f):
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
exactly=True) \
.run(f.getvalue())

def test_warn_once_per_func_in_loop(self):
def w1():
warnings.warn("I am warning you")

def w2():
warnings.warn("I am warning you")

@torch.jit.script
def fn():
for _ in range(10):
w1()
w2()

f = io.StringIO()
with redirect_stderr(f):
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
exactly=True) \
.run(f.getvalue())

def test_warn_multiple_calls_multiple_warnings(self):
@torch.jit.script
def fn():
warnings.warn("I am warning you")

f = io.StringIO()
with redirect_stderr(f):
fn()
fn()

FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
Collaborator
Yeah, so in this case it should be 1, not 2 - that's what Python is doing.

exactly=True) \
.run(f.getvalue())

def test_warn_multiple_calls_same_func_diff_stack(self):
def warn(caller: str):
warnings.warn("I am warning you from " + caller)

@torch.jit.script
def foo():
warn("foo")

@torch.jit.script
def bar():
warn("bar")

f = io.StringIO()
with redirect_stderr(f):
foo()
bar()

FileCheck() \
.check_count(
str="UserWarning: I am warning you from foo",
count=1,
exactly=True) \
.check_count(
str="UserWarning: I am warning you from bar",
count=1,
exactly=True) \
.run(f.getvalue())
1 change: 1 addition & 0 deletions test/test_jit.py
@@ -32,6 +32,7 @@
from jit.test_enum import TestEnum # noqa: F401
from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401

# Torch
from torch import Tensor
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -148,6 +148,7 @@ core_sources_full = [
"torch/csrc/jit/ir/scope.cpp",
"torch/csrc/jit/ir/subgraph_matcher.cpp",
"torch/csrc/jit/jit_log.cpp",
"torch/csrc/jit/passes/annotate_warns.cpp",
"torch/csrc/jit/passes/bailout_graph.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",
"torch/csrc/jit/passes/canonicalize.cpp",
5 changes: 5 additions & 0 deletions torch/csrc/jit/frontend/ir_emitter.cpp
@@ -8,6 +8,7 @@
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/annotate_warns.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
@@ -4090,6 +4091,10 @@ void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {

// For jitter
CanonicalizeOutputs(to_clean);

// Annotate aten::warn nodes so that each has a unique ID. This enables us to
// mimic the Python behavior of emitting each warning only once.
AnnotateWarns(to_clean);
}

// we consider _N where N is a number, to be a non-meaningful name
29 changes: 29 additions & 0 deletions torch/csrc/jit/passes/annotate_warns.cpp
@@ -0,0 +1,29 @@
#include <torch/csrc/jit/passes/annotate_warns.h>

#include <atomic>

namespace torch {
namespace jit {

void AnnotateWarns(Block* b) {
static std::atomic<int64_t> idx(0);
for (Node* n : b->nodes()) {
for (Block* child_b : n->blocks()) {
AnnotateWarns(child_b);
}

if (n->kind() != aten::warn) {
continue;
}

n->i_(attr::warn_id, idx);
idx++;
}
}

void AnnotateWarns(const std::shared_ptr<Graph>& graph) {
AnnotateWarns(graph->block());
}

} // namespace jit
} // namespace torch
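
For illustration only (not part of this diff), the warn_id attribute written by this pass can be read back from a graph roughly as follows; the helper name printWarnIds is made up, and the loop only visits top-level nodes, not nested blocks:

#include <torch/csrc/jit/ir/ir.h>

#include <iostream>
#include <memory>

namespace torch {
namespace jit {

// Sketch: after AnnotateWarns(graph) has run, every top-level aten::warn node
// carries an integer warn_id attribute identifying its warn call site.
void printWarnIds(const std::shared_ptr<Graph>& graph) {
  for (Node* n : graph->block()->nodes()) {
    if (n->kind() == aten::warn && n->hasAttribute(attr::warn_id)) {
      std::cout << "aten::warn with warn_id = " << n->i(attr::warn_id) << "\n";
    }
  }
}

} // namespace jit
} // namespace torch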
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/annotate_warns.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

TORCH_API void AnnotateWarns(const std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
2 changes: 1 addition & 1 deletion torch/csrc/jit/runtime/instruction.h
@@ -52,7 +52,7 @@ namespace jit {
_(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */ \
_(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */ \
_(FORK, "CN") /* launch a thread to run code entry x with N inputs */ \
_(WARN, "") /* emit a warning with line information */ \
_(WARN, "I") /* emit a warning with line information */ \
_(ENTER, "EN") /* enter scope of a contextmanager */ \
_(EXIT, "EX") /* exit the last entered contextmanager */

73 changes: 60 additions & 13 deletions torch/csrc/jit/runtime/interpreter.cpp
@@ -428,6 +428,16 @@ struct TLSCurrentInterpreterGuard {
InterpreterStateImpl* prev_state_;
};

template <class Ttarget, class Tsource>
Ttarget safe_narrow_cast(Tsource v) {
Ttarget res = static_cast<Ttarget>(v);
// Casting it back to check whether it overflowed.
if (static_cast<Tsource>(res) != v) {
throw std::runtime_error("safe_narrow_cast<>() failed due to overflow");
}
return res;
}

struct CodeImpl {
friend struct InterpreterState;
std::vector<Instruction> instructions_;
@@ -535,7 +545,10 @@ struct CodeImpl {
}

void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
instructions_.emplace_back(op, X, N);
instructions_.emplace_back(
op,
safe_narrow_cast<int32_t, int64_t>(X),
Collaborator
Wait, won't it throw if there are more than int32 distinct warnings? It actually can happen because you keep increasing the same counter across all models, so it might overflow eventually. You probably want an explicit static_cast in emitWarn in this case, as the truncation is OK.

Contributor Author
Removed this logic since the index is no longer needed.

Contributor Author
Added static_cast at the time of emitting the warn op.

safe_narrow_cast<int16_t, int64_t>(N));
instructions_source_.emplace_back(current_node_);

// check that we didn't accidentally emit nodes out of topological order
@@ -873,7 +886,11 @@

void emitWarn(Node* node) {
emitLoadInputs(node->inputs());
insertInstruction(WARN);
int32_t idx = -1;
if (node->hasAttribute(attr::warn_id)) {
Collaborator
Just curious - why do we need to have separate indices? Would using just Node* as a key here be sufficient? Or are there some considerations with inlining that might change that? (I'm not a TS expert, so @suo probably has a better idea.)

Contributor Author
Good point. Changed to using Node* as the key.

Contributor Author
New turn of events: using Node* as the key doesn't work for the ProfilingExecutor, which unrolls loops and thus creates many copies of the same node, resulting in many warning calls. Switched back to using a unique ID attached to aten::warn.

idx = static_cast<int32_t>(node->i(attr::warn_id));
}
insertInstruction(WARN, idx);
}

void emitEnter(Node* node) {
Expand Down Expand Up @@ -1017,6 +1034,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
}

private:
struct WarnedNodes {
public:
// Inserts idx into warned_nodes_, returns a boolean indicating whether
// insertion actually happened (idx wasn't originally in the set).
bool insert(int32_t idx) {
std::unique_lock<std::mutex> lock(mutex_);
return warned_nodes_.insert(idx).second;
}

private:
std::mutex mutex_;
std::unordered_set<int32_t> warned_nodes_;
};

WarnedNodes warned_nodes_;
Collaborator
This is actually not what we want - multiple invocations of the same function should still warn only once (like Python does). Sorry for missing it in the review. You probably need to move it to the Code level or something like that.

Collaborator
Another suggestion - instead of an unordered_set, just create a vector<atomic> inside Code and use that to modify the index directly. I think this way you don't need locking. Or you can still have an unordered_map, but pre-populate it with all possible indices beforehand so that you don't need the lock.
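
A minimal sketch of that lock-free variant, not part of this PR; the names WarnedFlags and shouldWarn are made up, and it assumes warn IDs have been remapped to a dense range owned by the Code object:

#include <atomic>
#include <cstddef>
#include <vector>

// Illustrative only: one atomic flag per warn site, owned by the Code object,
// assuming warn IDs have been remapped to a dense range [0, num_warn_sites).
struct WarnedFlags {
  explicit WarnedFlags(std::size_t num_warn_sites) : flags_(num_warn_sites) {}

  // Returns true exactly the first time a given warn site fires; no mutex.
  bool shouldWarn(std::size_t idx) {
    return !flags_[idx].exchange(true, std::memory_order_relaxed);
  }

 private:
  std::vector<std::atomic<bool>> flags_;
};

With flags owned by the Code object rather than per-InterpreterState state, repeated invocations of the same function would still warn only once.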

Contributor Author
Thanks for the suggestion. I wonder if what Python does is actually desired. Think about the following use case:

@torch.jit.script
def issue_warning_wrong_dtype(dtype: str):
    warnings.warn("Incorrect data type " + dtype)

It may be called from various callsites, and they are all useful. IMHO, it is slightly better to emit the warning once for each unique callsite. The number of warnings would still not be too spammy, because there likely won't be so many callsites that they feel spammy to users. What do you think?

I am still investigating the issue reported internally, which I feel is a multi-threading issue similar to #46684. Will update here once I find out more.

Collaborator
Well, ideally we should respect the stacklevel argument, like Python does. I.e. in your issue_warning_wrong_dtype example, it'd be stacklevel=2 or something. However, that's much harder to implement, and frankly I don't think it's worth the effort, as warnings in TorchScript are kind of a niche use case in general.

Fixing the spammy warning is important though, because the typical case is inference, where the model is called many times in a loop and each call creates an independent InterpreterState. So the current implementation will sadly log on each invocation, spamming the logs. (That's pretty much the internal issue you're referring to.)

Contributor Author
I thought about this problem a lot but couldn't think of a clean way to implement the "warn once" behavior, which requires maintaining global state above the ScriptModule. This sort of goes against TorchScript's philosophy of keeping a ScriptModule's state local to itself.

Here is an alternative solution: given what we talked about, it sounds like it is OK (or even preferred) to have zero warnings emitted during inference. If that's the case, we can expose a control knob that the predictor can toggle to disable warnings entirely inside a TorchScript module. What do you think?
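
A rough sketch of what such a knob could look like; every name below is hypothetical and not an existing TorchScript API:

#include <atomic>

// Hypothetical global knob; these names are illustrative only.
namespace warn_knob {

std::atomic<bool> warnings_enabled{true};

void setWarningsEnabled(bool enabled) {
  warnings_enabled.store(enabled, std::memory_order_relaxed);
}

bool warningsEnabled() {
  return warnings_enabled.load(std::memory_order_relaxed);
}

} // namespace warn_knob

// The interpreter's WARN handler would then check warn_knob::warningsEnabled()
// before calling c10::Warning::warn(...), letting a predictor disable
// TorchScript warnings entirely during inference.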

Collaborator
Yeah, we can have a global knob :)

In general, I don't think that having this state across ScriptModule invocations is necessarily that bad. It's a very narrow kind of state and it matches Python's warnings module behavior.

Re: "maintaining global state above the ScriptModule" - I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what Python does and would also keep the state "local" to the ScriptModule that owns the original graph.

Contributor Author
Re: adding a bit per original Warning IR Node - based on my rough understanding, that bit per Warning node would have to be handled in the predictor, which runs inference calls in a for loop. In that case, an implementation detail of TorchScript kind of leaks out. Let me know what you think.


// if we need to suspend, where do we reset the stack?
// answer: to where it was when we were called, not
// including any inputs to this function
@@ -1487,21 +1520,35 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
++frame.pc;
} break;
case WARN: {
Node* node = frame.function->instructions_source_.at(frame.pc);
// Keeps track of which WARN instructions have been executed before;
// we only want to execute each WARN once to match the default Python
// warning behavior.
bool need_warn = true;
if (inst.X != -1) {
need_warn = warned_nodes_.insert(inst.X);
}

Node* node =
frames.back().function->instructions_source_.at(frame.pc);
auto range = node->sourceRange().source();
if (range->filename()) {
auto line = range->starting_line_no() +
range->lineno_for_offset(node->sourceRange().start());
drop(stack, 1);
c10::SourceLocation location{
"", range->filename()->c_str(), uint32_t(line)};
// Sends the warning to the warning handler with the
// "verbatim" flag. This flag ensures the warning handler
// will print the exception as configured.
c10::Warning::warn(
location, pop(stack).toStringRef(), /*verbatim=*/true);
const auto msg = pop(stack).toStringRef();
if (need_warn) {
auto line = range->starting_line_no() +
range->lineno_for_offset(node->sourceRange().start());
c10::SourceLocation location{
"", range->filename()->c_str(), uint32_t(line)};
// Sends the warning to the warning handler with the
// "verbatim" flag. This flag ensures the warning handler
// will print the exception as configured.
c10::Warning::warn(location, msg, /*verbatim=*/true);
}
} else {
TORCH_WARN(pop(stack).toStringRef());
const auto msg = pop(stack).toStringRef();
if (need_warn) {
TORCH_WARN(msg);
}
Collaborator
It was the case with this code even earlier, but I find it suspicious that this branch didn't have drop(stack, 1). Do we actually have unit tests that verify it? It might be a bug, because I think that the additional argument (i.e. stacklevel) shouldn't depend on the presence or absence of range information. A safer fix might be to record the number of node inputs in inst.N and use it to drop from the stack.
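
A self-contained toy of that suggestion, assuming N records how many inputs were pushed for the warn; every name below is made up and this is not the interpreter's real code:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy model: the WARN instruction records how many inputs were pushed (N), so
// the handler can drop the extras (e.g. stacklevel) without consulting range
// information. Purely illustrative.
struct ToyInstruction {
  int32_t X;   // warn_id
  uint16_t N;  // number of inputs pushed for this warn
};

int main() {
  // Inputs pushed in order: message first, then stacklevel (top of stack).
  std::vector<std::string> stack = {"I am warning you", "2"};
  ToyInstruction inst{/*X=*/0, /*N=*/2};

  // Drop everything above the message, then pop the message itself.
  stack.resize(stack.size() - (inst.N - 1));
  const std::string msg = stack.back();
  stack.pop_back();

  std::cout << "UserWarning: " << msg << '\n';
  return 0;
}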

Contributor Author
I think adding the extra drop(stack, 1) here causes an interpreter error. It kind of makes sense, since the lack of file info means stack_level is not meaningful. Anyway, I will do some more investigation and address this issue in a later PR.

Collaborator
Yeah, I'm just asking whether we actually ever have a codepath exercising this branch (as we usually have the file info).

}
++frame.pc;
} break;
Expand Down