
Make sure each warnings.warn only executes once inside TorchScript. #45382

Closed
wants to merge 1 commit into from

Conversation

@gmagogsfm (Contributor) commented Sep 26, 2020

  • Add a pass at the end of runCleanupPasses that annotates each aten::warn node with a unique id (see the sketch below)
  • Enhance the interpreter so that it tracks which aten::warn instructions have already executed and skips repeats
  • Improve insertInstruction so that it correctly checks for overflow

Fixes #45108
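
A minimal sketch of the annotation pass described above, assembled from the fragments quoted later in this thread; the recursive walk and the attr::warn_id attribute appear in the review, but the exact body here is an illustrative assumption, not the merged source:

// Sketch: walk every block and tag each aten::warn node with a unique id.
// The atomic counter reflects the thread-safety concern raised in review.
void AnnotateWarns(Block* block) {
  static std::atomic<int64_t> next_warn_id{0};
  for (Node* node : block->nodes()) {
    for (Block* sub_block : node->blocks()) {
      AnnotateWarns(sub_block); // recurse into if/loop bodies
    }
    if (node->kind() == aten::warn) {
      node->i_(attr::warn_id, next_warn_id.fetch_add(1));
    }
  }
}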

@facebook-github-bot added the "oncall: jit" label (Add this issue/PR to JIT oncall triage queue) Sep 26, 2020
dr-ci bot commented Sep 26, 2020

💊 CI failures summary and remediations

As of commit c214d52 (more details on the Dr. CI page):


Commit c214d52 was recently pushed. Waiting for builds...



@gmagogsfm gmagogsfm marked this pull request as ready for review September 28, 2020 22:14
@gmagogsfm (Contributor Author) commented:

There is still a doctest failing, but I wanted to get some early feedback on the approach, so I'm sending this out for review now.

namespace jit {

void AnnotateWarns(Block* b) {
  static int64_t idx = 0;
Collaborator:

you probably need to make it atomic as compilation might happen concurrently from different threads
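
For illustration, the suggested change is small; a sketch, assuming the counter had stayed (per the reply below, it was ultimately removed):

// Sketch: an atomic counter gives each concurrently-compiled graph a
// distinct id without a lock.
static std::atomic<int64_t> idx{0};
const int64_t warn_id = idx.fetch_add(1); // safe under concurrent compilation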

Contributor Author:

Ended up removing the idx per your suggestion.

@@ -857,7 +870,11 @@ struct CodeImpl {

  void emitWarn(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(WARN);
    int64_t idx = -1;
    if (node->hasAttribute(attr::warn_id)) {
Collaborator:

Just curious: why do we need to have separate indices? Would using just Node* as a key here be sufficient? Or are there some considerations with inlining that might change that? (I'm not a TS expert, so @suo probably has a better idea.)

Contributor Author:

Good point. Changed to using Node* as key.

Contributor Author:

New turn of events.

Using Node* as the key doesn't work for the ProfilingExecutor, which unrolls loops and thus creates many copies of the same node, resulting in many warning calls. Switched back to using a unique ID attached to aten::warn.

// TODO TODO this set should be graph specific, rather than global.
bool need_warn = true;
if (inst.X != -1) {
  auto inserted = warned_indices_.insert(inst.X);
Collaborator:

You probably need a lock for it, because AFAIU the same Code can run concurrently from several threads.

Contributor Author:

Added.

const auto msg = pop(stack).toStringRef();
if (need_warn) {
  TORCH_WARN(msg);
}
Collaborator:

It was the case with this code even earlier, but I find it suspicious that this branch didn't have drop(stack, 1). Do we actually have unit tests that verify it? It might be a bug, because I think the additional argument (i.e. stacklevel) shouldn't depend on the presence or absence of range information. A safer fix might be to record the number of node inputs in inst.N and use that to drop from the stack.
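
A rough sketch of that suggestion; pop and drop are the interpreter's stack helpers, and the three-argument insertInstruction form is assumed from the (op, X, N) snippet quoted later in this thread:

// Sketch: record the warn node's input count in inst.N when emitting...
insertInstruction(WARN, warn_id, node->inputs().size());

// ...so the WARN handler can always pop the message and then drop the
// remaining inputs (e.g. a stacklevel), with or without range info:
const auto msg = pop(stack).toStringRef();
drop(stack, inst.N - 1);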

Contributor Author:

I think adding the extra stack drop causes an interpreter error. It kind of makes sense, since the lack of file info means stack_level is not meaningful. Anyway, I will do some more investigation and address this issue in a later PR.

Collaborator:

Yeah, I'm just asking whether we actually ever have a codepath exercising this branch (as usually we have the file info).

instructions_.emplace_back(op, X, N);
instructions_.emplace_back(
    op,
    safe_narrow_cast<int32_t, int64_t>(X),
Collaborator:

Wait, won't it throw if there are more than INT32_MAX distinct warnings? It actually can happen, because you keep increasing the same counter across all models, so it might overflow eventually. You probably want an explicit static_cast in emitWarn in this case, as the truncation is OK.

Contributor Author:

Removed this logic since the index is no longer needed.

Contributor Author:

Added a static_cast at the time of emitting the warn op.
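
Presumably something along these lines; a sketch of the emit-time truncation, not the exact diff:

// Sketch: truncate the 64-bit warn id explicitly, since wrap-around only
// risks an occasional duplicate warning, unlike safe_narrow_cast's throw.
const int64_t idx =
    node->hasAttribute(attr::warn_id) ? node->i(attr::warn_id) : -1;
insertInstruction(WARN, static_cast<int32_t>(idx));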

@dzhulgakov (Collaborator) left a comment:

Looks good in its current form; the minor comments are optional.

// ensure each WARN instruction only executes once to mimic Python behavior.
struct WarnedNodes {
 public:
  bool contains(Node* n) {
Collaborator:

Just to be a bit paranoid (not that it matters in this case): there can be a race between two threads, between the calls to contains and insert, that causes the warning to be logged twice. You can get by with a single method should_log_once that tries to insert and returns a bool indicating whether the insert succeeded.

Given it's a cheap operation, a simple (not read-write) mutex is fine too.
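
A sketch of that combined helper, using the mutex suggested above and keyed on the int32_t warn id that the final code stores; the method name should_log_once is the reviewer's suggestion, not the merged code's:

#include <cstdint>
#include <mutex>
#include <unordered_set>

struct WarnedNodes {
 public:
  // Returns true exactly once per id: insert().second is false when the
  // id was already present, so check-and-record is a single step under
  // the lock, closing the contains/insert race.
  bool should_log_once(int32_t warn_id) {
    std::lock_guard<std::mutex> guard(mutex_);
    return warned_nodes_.insert(warn_id).second;
  }

 private:
  std::mutex mutex_;
  std::unordered_set<int32_t> warned_nodes_;
};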


@facebook-github-bot (Contributor) left a comment:

@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@gmagogsfm merged this pull request in d150d3e.

facebook-github-bot pushed a commit that referenced this pull request Oct 15, 2020
(#46369)

Summary:
This diff restores the previous behavior of silently allowing overflow when inserting instructions. The behavior was changed recently in #45382, but that started to break some existing use cases that have overflow problems.

This restores the original behavior but throws a warning, to unblock existing use cases where overflow happens.

Pull Request resolved: #46369

Reviewed By: kwanmacher, wanchaol, fbhuba

Differential Revision: D24324345

Pulled By: gmagogsfm

fbshipit-source-id: 1c0fac421d4de38f070e21059bbdc1b788575bdf
  std::unordered_set<int32_t> warned_nodes_;
};

WarnedNodes warned_nodes_;
Collaborator:

This is actually not what we want: multiple invocations of the same function should still warn only once (like Python does). Sorry for missing it in the review. You probably need to move it to the Code level or something like that.

Collaborator:

Another suggestion: instead of an unordered_set, just create a vector<atomic> inside Code and use it to modify the index directly. I think this way you don't need locking. Or you can still have an unordered_map, but pre-populate it with all possible indices beforehand so that you don't need the lock.
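
A sketch of the lock-free variant, assuming Code knows how many WARN instructions it emitted when it is built; the WarnFlags name is illustrative only:

#include <atomic>
#include <vector>

// Sketch: one pre-sized flag per WARN instruction; exchange() returns the
// previous value, so only the very first caller for an index gets true.
struct WarnFlags {
  explicit WarnFlags(size_t num_warns) : warned_(num_warns) {}

  bool should_warn(size_t warn_index) {
    return !warned_[warn_index].exchange(true);
  }

  std::vector<std::atomic<bool>> warned_; // value-initialized to false
};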

Contributor Author:

Thanks for the suggestion. I wonder if what Python does is actually desired. Think about the following use case:

@torch.jit.script
def issue_warning_wrong_dtype(dtype: str):
    warnings.warn("Incorrect data type " + dtype)

It may be called from various callsites, and they are all useful. IMHO, it is slightly better to emit a warning for each unique callsite. The number of warnings would still not be too spammy, because there likely won't be so many callsites that they feel spammy to users. What do you think?

I am still investigating the issue reported internally, which I feel is a multi-threading issue similar to #46684. Will update here once I find out more.

Collaborator:

Well, ideally we should respect the stacklevel argument, like Python does. I.e., in your issue_warning_wrong_dtype example, it'd be stacklevel=2 or something. However, that's much harder to implement, and frankly I don't think it's worth the effort, as warnings in TorchScript are kind of a niche use case in general.

Fixing the spammy warning is important though, because the typical case is inference, where the model is called many times in a loop and each call creates an independent InterpreterState. So the current implementation sadly will log on each invocation, spamming the logs. (That's pretty much the internal issue you're referring to.)

Contributor Author:

I thought about this problem a lot but couldn't think of a clean way to implement the "warn once" behavior, which requires maintaining a global state that is above ScriptModule. This sort of goes against TorchScript's philosophy of keeping a ScriptModule's state local to itself.

Here is an alternative solution: given what we talked about, it sounds like it is OK (or even preferred) to have zero warnings emitted during inference. If that's the case, we can expose a control knob that the predictor can toggle to disable warnings entirely inside a TorchScript module. What do you think?
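
Such a knob could be as small as a process-wide flag consulted by the WARN handler; a sketch with invented names (setTorchScriptWarningsEnabled is hypothetical, not an actual API):

#include <atomic>

// Sketch only: a hypothetical global switch; the names are illustrative.
static std::atomic<bool> torchscript_warnings_enabled{true};

void setTorchScriptWarningsEnabled(bool enabled) { // toggled by the predictor
  torchscript_warnings_enabled.store(enabled);
}

// In the interpreter's WARN handler:
//   if (torchscript_warnings_enabled.load() && need_warn) {
//     TORCH_WARN(msg);
//   }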

Collaborator:

Yeah, we can have a global knob :)

In general, I don't think that having this state across ScriptModule invocations is necessarily that bad. It's a very narrow kind of state, and it matches Python's warnings module behavior.

requires maintaining a global state that is above ScriptModule

I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what Python does and would also keep state "local" to the ScriptModule that owns the original graph.

Contributor Author:

I wonder whether we can just add a bit per original Warning IR Node somewhere. That'd be the closest match to what Python does and would also keep state "local" to the ScriptModule that owns the original graph.

Based on my rough understanding, that bit per Warning node would have to live in the predictor, which runs inference calls in a for loop. In that case, an implementation detail of TorchScript would leak out. Let me know what you think.

FileCheck() \
    .check_count(
        str="UserWarning: I am warning you",
        count=2,
Collaborator:

Yeah, so in this case it should be 1, not 2; that's what Python does.

Labels: Merged, oncall: jit (Add this issue/PR to JIT oncall triage queue)

Successfully merging this pull request may close these issues:

warnings.warn is too spammy in TorchScript

4 participants