
Conversation

@ansley commented May 25, 2021

@facebook-github-bot (Contributor) commented May 25, 2021


💊 CI failures summary and remediations

As of commit 7a9b4f4 (more details on the Dr. CI page):


  • 4/4 failures introduced in this PR

🕵️ 4 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-py3.8-gcc9-coverage / test (distributed, 1, 1, linux.2xlarge) (1/4)

Step: "Unknown"

2021-09-08T21:41:10.3204821Z frame #15: <unknown function> + 0x486ea (0x7f52761476ea in /opt/conda/lib/python3.8/site-packages/torch/lib/libc10.so)
2021-09-08T21:41:10.3206427Z frame #16: <unknown function> + 0xc9039 (0x7f5276053039 in /opt/conda/lib/libstdc++.so.6)
2021-09-08T21:41:10.3208193Z frame #17: <unknown function> + 0x76db (0x7f5299b406db in /lib/x86_64-linux-gnu/libpthread.so.0)
2021-09-08T21:41:10.3209810Z frame #18: clone + 0x3f (0x7f529986971f in /lib/x86_64-linux-gnu/libc.so.6)
2021-09-08T21:41:10.3210514Z 
2021-09-08T21:41:10.7360311Z ok (3.724s)
2021-09-08T21:41:25.9887815Z   test_rpc_builtin_timeout (__main__.FaultyFaultyAgentRpcTest) ... ok (15.253s)
2021-09-08T21:41:35.3320247Z   test_rpc_script_timeout (__main__.FaultyFaultyAgentRpcTest) ... ok (9.343s)
2021-09-08T21:41:39.0574855Z   test_rref_to_here_timeout (__main__.FaultyFaultyAgentRpcTest) ... ok (3.725s)
2021-09-08T21:41:46.7904960Z   test_udf_remote_message_delay_timeout (__main__.FaultyFaultyAgentRpcTest) ... ok (7.733s)
2021-09-08T21:41:51.0754934Z   test_udf_remote_message_delay_timeout_to_self (__main__.FaultyFaultyAgentRpcTest) ... [E request_callback_no_python.cpp:559] Received error while processing request type 261: falseINTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/distributed/rpc/rref_context.cpp":385, please report a bug to PyTorch. Expected OwnerRRef with id GloballyUniqueId(created_on=0, local_id=0) to be created.
2021-09-08T21:41:51.0756775Z Exception raised from getOwnerRRef at /var/lib/jenkins/workspace/torch/csrc/distributed/rpc/rref_context.cpp:385 (most recent call first):
2021-09-08T21:41:51.0758672Z frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x59 (0x7fc6341e6f59 in /opt/conda/lib/python3.8/site-packages/torch/lib/libc10.so)
2021-09-08T21:41:51.0760308Z frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xa3 (0x7fc6341bdb34 in /opt/conda/lib/python3.8/site-packages/torch/lib/libc10.so)
2021-09-08T21:41:51.0762301Z frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x61 (0x7fc6341e4341 in /opt/conda/lib/python3.8/site-packages/torch/lib/libc10.so)
2021-09-08T21:41:51.0764097Z frame #3: torch::distributed::rpc::RRefContext::getOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, bool) + 0x628 (0x7fc63d702a08 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
2021-09-08T21:41:51.0766398Z frame #4: torch::distributed::rpc::RequestCallbackNoPython::assignOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, torch::distributed::rpc::GloballyUniqueId const&, c10::intrusive_ptr<c10::ivalue::Future, c10::detail::intrusive_target_default_null_type<c10::ivalue::Future> >) const + 0x8c (0x7fc63d6e925c in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
2021-09-08T21:41:51.0773488Z frame #5: torch::distributed::rpc::RequestCallbackImpl::processPythonRemoteCall(torch::distributed::rpc::RpcCommandBase&, std::vector<c10::Stream, std::allocator<c10::Stream> >) const + 0xf5 (0x7fc64df8aa95 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
2021-09-08T21:41:51.0776136Z frame #6: torch::distributed::rpc::RequestCallbackNoPython::processRpc(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::vector<c10::Stream, std::allocator<c10::Stream> >) const + 0x1f0 (0x7fc63d6efdf0 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
2021-09-08T21:41:51.0778846Z frame #7: torch::distributed::rpc::RequestCallbackImpl::processRpcWithErrors(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::vector<c10::Stream, std::allocator<c10::Stream> >) const + 0x60 (0x7fc64df8a360 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
2021-09-08T21:41:51.0780634Z frame #8: <unknown function> + 0x9290da0 (0x7fc63d6e4da0 in /opt/conda/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)

See GitHub Actions build linux-bionic-py3.8-gcc9-coverage / test (default, 1, 2, linux.2xlarge) (2/4)

Step: "Unknown"

2021-09-08T21:33:15.6424643Z   PR_LABELS: [
  "oncall: jit",
  "cla signed"
]
2021-09-08T21:33:15.6425707Z   DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-py3.8-gcc9:3fb4365799993abcfc83e51d42c137e89cb2459a
2021-09-08T21:33:15.6426907Z   JOB_BASE_NAME: linux-bionic-py3.8-gcc9-coverage-test
2021-09-08T21:33:15.6427494Z   TEST_CONFIG: default
2021-09-08T21:33:15.6427802Z   SHARD_NUMBER: 1
2021-09-08T21:33:15.6428108Z   NUM_TEST_SHARDS: 2
2021-09-08T21:33:15.6428460Z   PYTORCH_IGNORE_DISABLED_ISSUES: 
2021-09-08T21:33:15.6428869Z   CONTINUE_THROUGH_ERROR: false
2021-09-08T21:33:15.6429193Z   SHM_SIZE: 1g
2021-09-08T21:33:15.6429479Z   PR_NUMBER: 58911
2021-09-08T21:33:15.6429776Z ##[endgroup]
2021-09-08T21:33:30.4705948Z Processing ./dist/torch-1.10.0a0+git2e479f6-cp38-cp38-linux_x86_64.whl
2021-09-08T21:33:30.4969686Z Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch==1.10.0a0+git2e479f6) (3.10.0.0)
2021-09-08T21:33:30.7419854Z Installing collected packages: torch
2021-09-08T21:33:37.0293786Z Successfully installed torch-1.10.0a0+git2e479f6
2021-09-08T21:33:37.3302716Z ++++ dirname .jenkins/pytorch/common.sh
2021-09-08T21:33:37.3309578Z +++ cd .jenkins/pytorch
2021-09-08T21:33:37.3311314Z +++ pwd -P

See GitHub Actions build linux-xenial-py3.6-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (3/4)

Step: "Test PyTorch"

2021-09-08T21:24:52.7413927Z processing existing schema:  alltoall_base(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor _1, Tensor _2, int[] _3, int[] _4) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7415355Z processing existing schema:  alltoall(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, Tensor[] _2) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7416747Z processing existing schema:  send(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7418100Z processing existing schema:  recv(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2, int _3) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7419558Z processing existing schema:  recv_anysource(__torch__.torch.classes.dist_c10d.ProcessGroup _0, Tensor[] _1, int _2) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7420951Z processing existing schema:  barrier(__torch__.torch.classes.dist_c10d.ProcessGroup _0) -> (__torch__.torch.classes.dist_c10d.Work _0)
2021-09-08T21:24:52.7422083Z processing existing schema:  __init__(__torch__.torch.classes.dist_c10d.frontend _0) -> (NoneType _0)
2021-09-08T21:24:52.7423579Z processing existing schema:  new_process_group_helper(__torch__.torch.classes.dist_c10d.frontend _0, int _1, int _2, int[] _3, str _4, __torch__.torch.classes.dist_c10d.Store _5, str? _6, int _7) -> (__torch__.torch.classes.dist_c10d.ProcessGroup _0)
2021-09-08T21:24:52.7425218Z processing existing schema:  get_process_group_by_name(__torch__.torch.classes.dist_c10d.frontend _0, str _1) -> (__torch__.torch.classes.dist_c10d.ProcessGroup _0)
2021-09-08T21:24:52.7426659Z processing existing schema:  get_name_of_process_group(__torch__.torch.classes.dist_c10d.frontend _0, __torch__.torch.classes.dist_c10d.ProcessGroup _1) -> (str _0)
2021-09-08T21:24:52.7427895Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2021-09-08T21:24:52.7428519Z 
2021-09-08T21:24:52.7428796Z Broken ops: [
2021-09-08T21:24:52.7429410Z 	aten::concat(Tensor[] tensors, int dim=0) -> (Tensor)
2021-09-08T21:24:52.7430171Z 	aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> (Tensor(a!))
2021-09-08T21:24:52.7430936Z 	aten::concat.names(Tensor[] tensors, str dim) -> (Tensor)
2021-09-08T21:24:52.7431749Z 	aten::concat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> (Tensor(a!))
2021-09-08T21:24:52.7432201Z ]
2021-09-08T21:24:52.7432462Z + cleanup
2021-09-08T21:24:52.7432875Z + retcode=1
2021-09-08T21:24:52.7433152Z + set +x

See GitHub Actions build win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (4/4)

Step: "Run test scripts"

2021-09-08T22:20:38.1013274Z   File "distributions/test_distributions.py", line 812, in _check_sampler_discrete
2021-09-08T22:20:38.1014337Z     chisq, p = scipy.stats.chisquare(counts[msk], pmf[msk] * num_samples)
2021-09-08T22:20:38.1015451Z   File "c:\jenkins\miniconda3\lib\site-packages\scipy\stats\stats.py", line 6852, in chisquare
2021-09-08T22:20:38.1016429Z     return power_divergence(f_obs, f_exp=f_exp, ddof=ddof, axis=axis,
2021-09-08T22:20:38.1017528Z   File "c:\jenkins\miniconda3\lib\site-packages\scipy\stats\stats.py", line 6694, in power_divergence
2021-09-08T22:20:38.1018432Z     raise ValueError(msg)
2021-09-08T22:20:38.1019657Z ValueError: For each axis slice, the sum of the observed frequencies must agree with the sum of the expected frequencies to a relative tolerance of 1e-08, but the percent differences are:
2021-09-08T22:20:38.1020775Z 0.008265582255680495
2021-09-08T22:20:38.1021072Z 
2021-09-08T22:20:38.1021448Z ======================================================================
2021-09-08T22:20:38.1022112Z ERROR [0.000s]: test_poisson_sample (__main__.TestDistributions)
2021-09-08T22:20:38.1022909Z ----------------------------------------------------------------------
2021-09-08T22:20:38.1023559Z Traceback (most recent call last):
2021-09-08T22:20:38.1024442Z   File "distributions/test_distributions.py", line 1352, in test_poisson_sample
2021-09-08T22:20:38.1025317Z     self._check_sampler_discrete(Poisson(rate),
2021-09-08T22:20:38.1026287Z   File "distributions/test_distributions.py", line 812, in _check_sampler_discrete
2021-09-08T22:20:38.1027350Z     chisq, p = scipy.stats.chisquare(counts[msk], pmf[msk] * num_samples)
2021-09-08T22:20:38.1028448Z   File "c:\jenkins\miniconda3\lib\site-packages\scipy\stats\stats.py", line 6852, in chisquare
2021-09-08T22:20:38.1029429Z     return power_divergence(f_obs, f_exp=f_exp, ddof=ddof, axis=axis,
2021-09-08T22:20:38.1030530Z   File "c:\jenkins\miniconda3\lib\site-packages\scipy\stats\stats.py", line 6694, in power_divergence
2021-09-08T22:20:38.1031465Z     raise ValueError(msg)

This comment was automatically generated by Dr. CI.

ansley pushed a commit that referenced this pull request May 25, 2021
ghstack-source-id: 3114191
Pull Request resolved: #58911
ansley pushed a commit that referenced this pull request May 27, 2021
ghstack-source-id: ff171ae
Pull Request resolved: #58911
ansley pushed a commit that referenced this pull request Jun 7, 2021
ghstack-source-id: 390877c
Pull Request resolved: #58911
ansley pushed a commit that referenced this pull request Jun 26, 2021
ghstack-source-id: 8c9e04c
Pull Request resolved: #58911
ansley pushed a commit that referenced this pull request Jul 28, 2021
ghstack-source-id: 0539263
Pull Request resolved: #58911
@ansley requested review from gmagogsfm and eellison on July 28, 2021 02:04
@eellison (Contributor) commented Aug 3, 2021

Deferring to @gmagogsfm unless there is a specific thing you want comments on.

@@ -1127,15 +1127,6 @@ void AliasDb::makePointerTo(const Value* from, const Value* to) {
// immutable. `Any` is mutable but can point to an immutable type
// through refinement
if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
Contributor

I didn't look at any of the other code in here, but we should still be making a pointer if there is any type overlap between the sets. Please tag me for changes related to alias analysis. This should actually have been changed as part of the Union changes... maybe you could factor the code out of refinement into a "has type overlap" function or something similar? It would be nice if we could reuse that logic to peephole-optimize away prim::isinstance.
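The "has type overlap" idea suggested here can be sketched outside the C++ sources. Below is an illustrative plain-Python model (the names contained_types and have_type_overlap are hypothetical, not PyTorch APIs); a Union is modeled as the set of its member type names, and two types overlap when those sets intersect:

```python
# Toy model of the "has type overlap" check suggested above.
# A Union type is modeled as a frozenset of member type names; any
# other type is a singleton set. Two types "overlap" when some runtime
# value could inhabit both, i.e. their member sets intersect.
# (Hypothetical helpers, not PyTorch APIs.)

def contained_types(t):
    """Flatten a type into its set of member type names."""
    if isinstance(t, tuple) and t and t[0] == "Union":
        return frozenset(t[1:])
    return frozenset([t])

def have_type_overlap(a, b):
    """True if the member sets of the two types intersect."""
    return bool(contained_types(a) & contained_types(b))

# A Union[int, Tensor] value may alias a Tensor...
print(have_type_overlap(("Union", "int", "Tensor"), "Tensor"))  # True
# ...but an int can never alias a str.
print(have_type_overlap("int", "str"))  # False
```

In this model, alias analysis would make a pointer whenever have_type_overlap is true, and a prim::isinstance check against a non-overlapping type could be folded to False.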

Author

We're still making a pointer if there's type overlap; the only thing I deleted was the check leading up to the TORCH_INTERNAL_ASSERT on line 1137. We need to do this because, if we don't, the two branches of the unchecked_cast will falsely throw an error.

Example traceback from the TestUnion.test_union_with_dictliteral test:

Traceback (most recent call last):
  File "/data/users/ansley/pytorch/test/jit/test_union.py", line 705, in test_union_with_dictliteral
    self.checkScript(fn, ())
  File "/home/ansley/local/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/testing/_internal/jit_utils.py", line 454, in checkScript
    self.checkScript(
  File "/home/ansley/local/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/testing/_internal/jit_utils.py", line 485, in checkScript
    script_outputs = scripted_fn(*recording_inputs)
  File "/home/ansley/local/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 139, in prof_func_call
    return prof_callable(func_call, *args, **kwargs)
  File "/home/ansley/local/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 136, in prof_callable
    return callable(*args, **kwargs)
RuntimeError: expected_kindINTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":1138, please report a bug to PyTorch. intDict(str, Tensor)

Contributor

In (isMutableTypeInternal(from) != isMutableTypeInternal(to)) we're checking equality, not overlap.

Author

Sorry for the word choice; I did mean equality.

@eellison (Contributor) Aug 11, 2021

I think we want to be checking overlap; e.g., Optional[List[int]] should have a pointer to List[int]. If the sometimes/never/always subtyping logic from ir_emitter were factored out, we could just check that the relation is not "never", and then make a pointer.
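The sometimes/never/always relation mentioned here can be illustrated with a small three-valued check over the same set-of-members model. This is a toy sketch (subtype_relation and members are hypothetical names, not the actual ir_emitter implementation):

```python
# Toy three-valued subtype check ("always" / "sometimes" / "never"),
# mirroring the ir_emitter logic described above. Types are modeled as
# frozensets of member type names; Optional[T] is Union[T, None].
# (Hypothetical helpers, not the actual PyTorch implementation.)

def members(t):
    if isinstance(t, tuple) and t and t[0] == "Union":
        return frozenset(t[1:])
    return frozenset([t])

def subtype_relation(sub, sup):
    s, p = members(sub), members(sup)
    if s <= p:
        return "always"   # every member of sub is a member of sup
    if s & p:
        return "sometimes" # some runtime values inhabit both
    return "never"         # no value can inhabit both

# Optional[List[int]] is *sometimes* a List[int], so alias analysis
# should still make a pointer between the two.
opt_list = ("Union", "List[int]", "None")
print(subtype_relation(opt_list, "List[int]"))  # sometimes
print(subtype_relation("List[int]", opt_list))  # always
print(subtype_relation("int", "str"))           # never
```

Under this model, the rule proposed above is simply: make a pointer whenever the relation is not "never".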

ansley pushed a commit that referenced this pull request Aug 11, 2021
ghstack-source-id: 348572a
Pull Request resolved: #58911
union_type_hint->containedTypes().end(),
std::back_inserter(list_types),
[&](TypePtr type_ptr) {
return type_ptr->isSubtypeOf(AnyListType::get());
Contributor

Check for type_ptr->kind == ListType::Kind

Node* result = graph->insertNode(graph->createList(elem_type, values));
if (annotated_union_type) {
Node* n = graph->insertNode(
graph->create(prim::unchecked_cast, {result->output()}));
Contributor

Discussed offline:
This node is inserted because the interpreter isn't happy about the type mismatch in the assignment.

TODO: Find out why the interpreter isn't happy about it.
TODO: Find out the performance implications (if any) of inserting this unchecked_cast node into the graph.
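The mismatch discussed here can be modeled abstractly: in a strictly typed IR, a value typed List[T] cannot be assigned directly to a slot annotated Union[List[T], ...] unless something first retypes the value. The following toy Python sketch (assign and unchecked_cast are hypothetical stand-ins, not PyTorch APIs) illustrates the role the inserted prim::unchecked_cast plays:

```python
# Toy model of why the widening cast above is needed: an exact-type
# assignment check fails when a List[int] value flows into a slot
# declared Union[List[int], int], unless an explicit (unchecked) cast
# first retypes the value. (Hypothetical stand-ins, not PyTorch APIs.)

def assign(slot_type, value_type):
    """Strict assignment: requires the value's static type to match."""
    if value_type != slot_type:
        raise TypeError(f"type mismatch: {value_type} vs {slot_type}")

def unchecked_cast(target_type, value_type):
    """Retype the value without a runtime check, like prim::unchecked_cast."""
    return target_type

union_t = "Union[List[int], int]"
try:
    assign(union_t, "List[int]")  # exact-type check fails
except TypeError as e:
    print(e)
assign(union_t, unchecked_cast(union_t, "List[int]"))  # ok after cast
print("assignment ok after cast")
```

This only models the typing question; the performance TODO above (whether the extra node costs anything at runtime) remains open.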

c10::optional<TypePtr> unified = unifyTypeList(
-    types, nowhere, /*default_to_union=*/true, element_type_hint);
+    types, nowhere, /*default_to_union=*/true, known_elem_type);

if (!type_hint && *unified == AnyType::get()) {
Contributor

In theory, this should never trigger when default_to_union=True.
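The reason this branch should be dead can be sketched with a toy unification function (unify_type_list here is a hypothetical Python stand-in, not the C++ unifyTypeList signature): with default_to_union=True, disagreeing element types unify to a Union of the precise members rather than widening all the way to Any.

```python
# Toy sketch of list-element unification with a default_to_union flag,
# mirroring the behavior discussed above. (Hypothetical stand-in for
# the C++ unifyTypeList, not its real signature.)

def unify_type_list(types, default_to_union=False):
    distinct = sorted(set(types))
    if len(distinct) == 1:
        return distinct[0]
    if default_to_union:
        # Keeps the precise members instead of widening to Any.
        return "Union[" + ", ".join(distinct) + "]"
    return "Any"

print(unify_type_list(["int", "int"]))                         # int
print(unify_type_list(["int", "str"], default_to_union=True))  # Union[int, str]
print(unify_type_list(["int", "str"]))                         # Any
```

In this model, the unified result can only be Any when default_to_union is false, which is why the check above should never trigger on the default_to_union=true path.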

// We don't want to use `elem_type` as the final argument to
// `unifyTypeList` because there's a chance that `elem_type` is
// the Tensor default
auto known_elem_type =
Contributor

Suggested change
auto known_elem_type =
auto elem_type_hint =

@@ -3789,13 +3837,48 @@ struct to_ir {
// `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
TypePtr elem_type = TensorType::get();
Contributor

Suggested change
TypePtr elem_type = TensorType::get();
TypePtr inferred_elem_type = TensorType::get();

x["bar"] = torch.tensor(3)
return x

self.checkScript(fn, ())
Contributor

Nit:
Use FileCheck on the generated IR to make sure the right nodes/types are inserted into the graph.
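The FileCheck pattern suggested here amounts to asserting that expected substrings appear in order in the textual IR. The real tests would use PyTorch's FileCheck on the scripted graph; below is a pure-Python mimic of the idea (check_in_order and the fake_ir text are illustrative, not PyTorch APIs):

```python
# Minimal stand-in for the FileCheck pattern suggested above: assert
# that expected substrings appear, in order, in the textual IR.
# (Pure-Python mimic; real tests would use PyTorch's FileCheck.)

def check_in_order(ir_text, patterns):
    pos = 0
    for pat in patterns:
        found = ir_text.find(pat, pos)
        assert found != -1, f"pattern not found in order: {pat!r}"
        pos = found + len(pat)

# Hand-written IR snippet, illustrative only.
fake_ir = """
graph(%x : int):
  %d : Dict(str, Tensor) = prim::DictConstruct()
  %u : Union(Dict(str, Tensor), int) = prim::unchecked_cast(%d)
  return (%u)
"""
check_in_order(fake_ir, ["prim::DictConstruct", "prim::unchecked_cast"])
print("IR checks passed")
```

Checking for the prim::unchecked_cast node (and the Union type annotation on its output) would catch regressions in the lowering discussed earlier in this review.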

ansley pushed a commit that referenced this pull request Aug 17, 2021
ghstack-source-id: 422805b
Pull Request resolved: #58911
ansley pushed a commit that referenced this pull request Sep 7, 2021
ghstack-source-id: ef937cc
Pull Request resolved: #58911
@ansley (Author) commented Sep 7, 2021

@ansley has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

// If the type hint was not `List[T]`, throw an error
throw ErrorReport(ll) << "Expected a List type hint but instead got "
<< type_hint->repr_str();
// If necessary/possible, make `type_hint` a ListType
Contributor

The added logic is complex enough to warrant a helper function, I think. What do you think?
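The helper suggested here would take a type hint that may be a Union and refine it to the List member when possible, erroring otherwise. A toy Python sketch of that shape (refine_hint_to_list is a hypothetical name; the real code operates on C++ TypePtr values, not strings):

```python
# Toy sketch of the suggested helper: given a type hint that may be a
# Union, pick out the unique List member so the hint can be refined to
# a ListType; otherwise raise, matching the error path above.
# (Hypothetical helper over string-encoded types, not PyTorch code.)

def refine_hint_to_list(type_hint):
    if isinstance(type_hint, tuple) and type_hint and type_hint[0] == "Union":
        list_members = [t for t in type_hint[1:] if t.startswith("List[")]
        if len(list_members) == 1:
            return list_members[0]
        raise TypeError(f"hint does not name a unique List type: {type_hint}")
    if isinstance(type_hint, str) and type_hint.startswith("List["):
        return type_hint
    raise TypeError(f"Expected a List type hint but instead got {type_hint}")

print(refine_hint_to_list(("Union", "List[int]", "None")))  # List[int]
print(refine_hint_to_list("List[str]"))                     # List[str]
```

Factoring the branch logic into one function like this keeps the error message in a single place and makes the "if necessary/possible" cases explicit.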

@ansley (Author) commented Sep 7, 2021

@ansley has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ansley pushed a commit that referenced this pull request Sep 8, 2021
ghstack-source-id: 6f6ee5d
Pull Request resolved: #58911
@ansley (Author) commented Sep 8, 2021

@ansley has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

2 similar comments

@ansley (Author) commented Sep 8, 2021

@ansley has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ansley pushed a commit that referenced this pull request Sep 8, 2021
ghstack-source-id: f3f63d2
Pull Request resolved: #58911
@ansley (Author) commented Sep 8, 2021

@ansley has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

4 similar comments

@facebook-github-bot (Contributor)

@ansley merged this pull request in c60075d.

Labels
cla signed, Merged, oncall: jit

Projects
None yet

4 participants