Make sure each warnings.warn only executes once inside TorchScript. #45382
Conversation
There is still a doctest failing, but I wanted to get some early feedback on the approach, so I am sending this out for review now.
```cpp
namespace jit {

void AnnotateWarns(Block* b) {
  static int64_t idx = 0;
```
you probably need to make it atomic as compilation might happen concurrently from different threads
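The concern can be sketched in Python (the real pass is C++; `fresh_warn_id` and `annotate_warns` are hypothetical names): a lock-protected counter stands in for an atomic fetch-and-add, so concurrent compilations never hand out the same warn id.

```python
import threading

# Hypothetical sketch: a lock-protected counter standing in for an atomic
# fetch-and-add, so concurrent compilations never assign duplicate warn ids.
_counter_lock = threading.Lock()
_next_warn_id = 0

def fresh_warn_id() -> int:
    """Return a process-wide unique id for a warn node."""
    global _next_warn_id
    with _counter_lock:
        wid = _next_warn_id
        _next_warn_id += 1
        return wid

def annotate_warns(warn_nodes: list) -> list:
    """Attach a unique warn_id attribute to each warn node (dicts here)."""
    for node in warn_nodes:
        node["warn_id"] = fresh_warn_id()
    return warn_nodes
```

In C++ the equivalent one-liner would be `static std::atomic<int64_t> idx{0};` with `idx.fetch_add(1)`.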
Ended up removing the idx per your suggestion.
```cpp
// @@ -857,7 +870,11 @@ struct CodeImpl {

void emitWarn(Node* node) {
  emitLoadInputs(node->inputs());
  insertInstruction(WARN);
  int64_t idx = -1;
  if (node->hasAttribute(attr::warn_id)) {
```
Just curious - why do we need separate indices? Would using just `Node*` as a key here be sufficient? Or are there some considerations with inlining that might change that? (I'm not a TS expert, so @suo probably has a better idea.)
Good point. Changed to using `Node*` as the key.
New turn of events. Using `Node*` as the key doesn't work for the ProfilingExecutor, which unrolls loops and thus creates many copies of the same node, resulting in many warning calls. Switched back to using a unique ID attached to `aten::warn`.
```cpp
// TODO TODO this set should be graph specific, rather than global.
bool need_warn = true;
if (inst.X != -1) {
  auto inserted = warned_indices_.insert(inst.X);
```
You probably need a lock for it because, AFAIU, the same `Code` can run concurrently from several threads.
Added.
```cpp
const auto msg = pop(stack).toStringRef();
if (need_warn) {
  TORCH_WARN(msg);
}
```
It was the case with this code even earlier, but I find it suspicious that this branch didn't have `drop(stack, 1)`. Do we actually have unit tests that verify it? It might be a bug, because I think the additional argument (i.e. stacklevel) shouldn't depend on the presence or absence of range information. A safer fix might be to record the number of node inputs in `inst.N` and use it to drop from the stack.
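The safer fix suggested above can be sketched in Python (names are hypothetical; the real interpreter is C++): record the number of inputs on the instruction and always drop exactly that many from the stack, so the pop count never depends on which branch is taken.

```python
def run_warn(stack: list, num_inputs: int, need_warn: bool, sink: list) -> None:
    # Illustrative stand-in for the WARN instruction handler. The message is
    # the first of the node's inputs; any extra inputs (e.g. a stacklevel)
    # are dropped along with it, unconditionally.
    args = stack[-num_inputs:]
    del stack[-num_inputs:]  # drop(stack, N): branch-independent
    if need_warn:
        sink.append(args[0])
```

Because the drop happens before the `need_warn` check, the stack depth stays consistent whether or not the warning is actually emitted.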
I think adding the extra drop causes an interpreter error. It kind of makes sense, since a lack of file info means `stack_level` is not meaningful. Anyway, I will do some more investigation and address this issue in a later PR.
Yeah, I'm just asking whether we actually ever have a codepath exercising this branch (as usually we have the file info).
```diff
-  instructions_.emplace_back(op, X, N);
+  instructions_.emplace_back(
+      op,
+      safe_narrow_cast<int32_t, int64_t>(X),
```
Wait, won't it throw if there are more than int32 distinct warnings? It actually can happen, because you keep increasing the same counter across all models, so it might overflow eventually. You probably want an explicit `static_cast` in `emitWarn` in this case, as the truncation is OK.
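The difference can be sketched in Python (illustrative only): a checked narrowing cast raises once the id leaves int32 range, while an explicit truncating cast wraps around, which is acceptable for a warn id where a rare collision is harmless.

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def safe_narrow_cast(x: int) -> int:
    """Checked narrowing: raises instead of silently losing bits."""
    if not (INT32_MIN <= x <= INT32_MAX):
        raise OverflowError(f"{x} does not fit in int32")
    return x

def truncating_cast(x: int) -> int:
    """Two's-complement wrap-around, like an explicit static_cast<int32_t>."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x > INT32_MAX else x
```

With the checked cast, the 2^31-th emitted warning would throw at emit time; with the truncating cast, ids simply wrap and, at worst, two warnings share an id.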
Removed this logic since the index is no longer needed.
Added `static_cast` at the time of emitting the warn op.
Looks good in its current form; the minor comments are optional.
```cpp
// ensure each WARN instruction only executes once to mimic Python behavior.
struct WarnedNodes {
 public:
  bool contains(Node* n) {
```
Just to be a bit paranoid (not that it matters in this case): there can be a race between two threads between the calls to `contains` and `insert` that causes two logs to be emitted. You can get by with a single method `should_log_once` that both tries to insert and returns a bool indicating whether the insert succeeded. Given it's a cheap operation, a simple (not read-write) mutex is fine too.
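In Python the suggestion looks like this (names hypothetical; the real structure is C++): fold the contains/insert pair into one method that checks and inserts under a single lock, so two threads can never both observe "not yet warned".

```python
import threading

# Hypothetical sketch of the reviewer's suggestion: one locked method that
# inserts and reports first-time status atomically, closing the
# contains-then-insert race window.
class WarnedNodes:
    def __init__(self):
        self._lock = threading.Lock()
        self._seen = set()

    def should_log_once(self, node_id) -> bool:
        """True exactly once per node_id, even under concurrent callers."""
        with self._lock:
            if node_id in self._seen:
                return False
            self._seen.add(node_id)
            return True
```

The caller then writes `if warned.should_log_once(nid): warn(msg)` with no separate membership test.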
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@gmagogsfm merged this pull request in d150d3e.
#46369) Summary: This diff restores the previous behavior of silently allowing overflow when inserting instructions. The behavior was changed recently in #45382, but that started to break some existing use cases that have overflow problems. This restores the original behavior but throws a warning, to unblock existing use cases where overflow happens. Pull Request resolved: #46369 Reviewed By: kwanmacher, wanchaol, fbhuba Differential Revision: D24324345 Pulled By: gmagogsfm fbshipit-source-id: 1c0fac421d4de38f070e21059bbdc1b788575bdf
```cpp
  std::unordered_set<int32_t> warned_nodes_;
};

WarnedNodes warned_nodes_;
```
This is actually not what we want - multiple invocations of the same function should still warn only once (like Python does). Sorry for missing it in the review. You probably need to move it to the Code level or something like that.
Another suggestion - instead of an unordered_set, just create a `vector<atomic>` inside Code and use that to modify the index directly. I think this way you don't need locking. Or you can still have the unordered_map, but pre-populate it with all possible indices beforehand so that you don't need the lock.
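A Python sketch of the pre-populated variant (illustrative; in C++ this would be a `std::vector<std::atomic<bool>>` sized at Code construction): every warn id is known when the Code is built, so one flag per id can be flipped in place with no set mutation and hence no lock.

```python
# Hypothetical sketch: pre-allocate one flag per warn id at construction.
# In C++ the flag flip would be an atomic exchange; plain assignment here
# only illustrates the shape of the idea, not the memory-ordering details.
class CodeWarnFlags:
    def __init__(self, num_warn_ids: int):
        self._warned = [False] * num_warn_ids

    def first_time(self, warn_id: int) -> bool:
        """True on the first execution of the given warn id, False after."""
        already = self._warned[warn_id]
        self._warned[warn_id] = True
        return not already
```

Because the container never grows or rehashes after construction, concurrent readers and writers touch disjoint, fixed slots.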
Thanks for the suggestion. I wonder if what Python does is actually desired, though. Think about the following use case:

```python
@torch.jit.script
def issue_warning_wrong_dtype(dtype: str):
    warnings.warn("Incorrect data type " + dtype)
```

It may be called from various callsites, and they are all useful. IMHO, it is slightly better to emit the warning for each unique callsite. The number of warnings would still not be too spammy, because there likely won't be so many callsites that they feel spammy to users. What do you think?
I am still investigating the issue reported internally, which I feel is a multi-threading issue similar to #46684. Will update here once I find out more.
Well, ideally we should respect the `stacklevel` argument, like Python does. I.e. in your `issue_warning_wrong_dtype` example, it'd be `stacklevel=2` or something. However, that's much harder to implement, and frankly I don't think it's worth the effort, as warnings in TorchScript are kind of a niche use case in general.
Fixing the spammy warning is important, though, because the typical case is inference, where the model is called many times in a loop and each call creates an independent InterpreterState. So the current implementation will sadly log on each invocation, spamming the logs. (That's pretty much the internal issue you're referring to.)
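Where the state lives makes the whole difference, as this Python sketch shows (all names hypothetical): a warned-set on the per-call InterpreterState resets every invocation, while a warned-set on the shared Code warns once across the inference loop.

```python
# Hypothetical sketch of the inference scenario: one shared Code, a fresh
# InterpreterState per call. Keeping the warned-set on the Code makes the
# warning fire once across the loop instead of once per call.
class Code:
    def __init__(self):
        self.warned = set()  # survives across InterpreterStates

class InterpreterState:
    def __init__(self, code):
        self.code = code
        self.warnings_emitted = 0

    def run_warn(self, warn_id, msg):
        if warn_id not in self.code.warned:
            self.code.warned.add(warn_id)
            self.warnings_emitted += 1

code = Code()
total = 0
for _ in range(5):  # model called in a loop, one state per call
    state = InterpreterState(code)
    state.run_warn(0, "deprecated input dtype")
    total += state.warnings_emitted
```

Had `warned` lived on `InterpreterState` instead, the loop above would have emitted five warnings rather than one.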
I thought about this problem a lot but couldn't think of a clean way to implement the "warn once" behavior, which requires maintaining global state above the ScriptModule. That sort of goes against TorchScript's philosophy of keeping a ScriptModule's state local to itself.
Here is an alternative solution: given what we talked about, it sounds like it is OK (or even preferred) to have zero warnings emitted during inference. If that's the case, we can expose a control knob that the predictor can toggle to disable warnings entirely inside a TorchScript module. What do you think?
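The control-knob idea can be sketched in Python (all names here are hypothetical, not an actual PyTorch API): a process-wide flag the predictor flips before its inference loop, checked at the point where the warning would be emitted.

```python
# Hypothetical sketch of a global knob for suppressing TorchScript warnings.
# set_warn_enabled and torch_warn are illustrative names only.
_warn_enabled = True

def set_warn_enabled(enabled: bool) -> None:
    global _warn_enabled
    _warn_enabled = enabled

emitted = []  # stands in for the log sink

def torch_warn(msg: str) -> None:
    """Stand-in for TORCH_WARN: emit only while the knob is on."""
    if _warn_enabled:
        emitted.append(msg)
```

A predictor would call `set_warn_enabled(False)` once at startup and get zero warnings for the whole inference run, with no per-node bookkeeping needed.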
Yeah, we can have a global knob :)
In general, I don't think that having this state across ScriptModule invocations is necessarily that bad. It's a very narrow kind of state, and it matches the behavior of Python's `warnings` module.

> requires maintaining a global state that is above ScriptModule

I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what Python does, and it would also keep the state "local" to the ScriptModule that owns the original graph.
> I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what python does and also keep state "local" to ScriptModule that owns the original graph.

Based on my rough understanding, that bit per Warning node would have to live in the predictor, which runs inference calls in a for loop. In that case, an implementation detail of TorchScript is kind of leaked. Let me know what you think.
```python
FileCheck() \
    .check_count(
        str="UserWarning: I am warning you",
        count=2,
```
Yeah, so in this case it should be `1`, not `2` - that's what Python does.
- Annotate `aten::warn` so that each has its unique id
- Check whether an `aten::warn` has been executed before and skip them

Fixes #45108