Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add named tuple's error message and workaround for RET failure #46347

Closed
wants to merge 4 commits into from

Conversation

jwpark1985
Copy link
Contributor

@jwpark1985 jwpark1985 commented Oct 14, 2020

Stack from ghstack:

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:

  1. ins.op == RET (for returing)
  2. type->kind() == TypeKind::TupleType (for pruning non-tuple types)
  3. type->cast().name() (for pruning Tuple type)
  • I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.

[Information of Tuple and NamedTuple types]

  1. Tuple
    type->str(): (int, int)
    type->repr_str(): Tuple[int, int]
    type->kind(): TypeKind::TupleType # different with other types
    type()->cast(): True
    type()->cast()>name(): False # different with NamedTuple

  2. NamedTuple
    type->str(): torch.myNamedTuple
    type->repr_str(): torch.myNamedTuple
    type->kind(): TypeKind::TupleType # different with other types
    type()->cast(): True
    type->cast().name() = True # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List, Dict and 2) accessing Module class's member functions)

Differential Revision: D24291962

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Oct 14, 2020
jwpark1985 pushed a commit that referenced this pull request Oct 14, 2020
Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)

ghstack-source-id: 114324445
Pull Request resolved: #46347
@dr-ci
Copy link

dr-ci bot commented Oct 14, 2020

💊 CI failures summary and remediations

As of commit 8f12b8b (more details on the Dr. CI page):



1 failure confirmed as flaky and can be ignored:

  • docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c

ci.pytorch.org: 2 failed


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 18 times.

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Oct 14, 2020

💊 CI failures summary and remediations

As of commit 8f12b8b (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR---

1 failure not recognized by patterns:

Job Step Action
CircleCI docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c Check if image should be built 🔁 rerun
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 13 times.

…lure"

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)

[ghstack-poisoned]
jwpark1985 pushed a commit that referenced this pull request Oct 14, 2020
Pull Request resolved: #46347

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)
ghstack-source-id: 114335601

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)
…lure"

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)

[ghstack-poisoned]
jwpark1985 pushed a commit that referenced this pull request Oct 14, 2020
Pull Request resolved: #46347

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)
ghstack-source-id: 114339694

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)
@codecov
Copy link

codecov bot commented Oct 15, 2020

Codecov Report

Merging #46347 into gh/jwpark1985/1/base will decrease coverage by 0.00%.
The diff coverage is n/a.

Impacted file tree graph

@@                   Coverage Diff                    @@
##           gh/jwpark1985/1/base   #46347      +/-   ##
========================================================
- Coverage                 68.33%   68.33%   -0.01%     
========================================================
  Files                       408      408              
  Lines                     53758    53758              
========================================================
- Hits                      36735    36734       -1     
- Misses                    17023    17024       +1     
Impacted Files Coverage Δ
torch/testing/_internal/expecttest.py 77.55% <0.00%> (-1.03%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 635aebd...8f12b8b. Read the comment docs.

…lure"

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)

[ghstack-poisoned]
jwpark1985 pushed a commit that referenced this pull request Oct 15, 2020
Pull Request resolved: #46347

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)
ghstack-source-id: 114361762

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)
@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 92921c8.

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 92921c8.

@facebook-github-bot facebook-github-bot deleted the gh/jwpark1985/1/head branch October 19, 2020 14:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged oncall: jit Add this issue/PR to JIT oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants