
keep output type after calling SubgraphRewriter #65453

Closed
wants to merge 2 commits

Conversation

XiaobingSuper
Collaborator

The JIT SubgraphRewriter does not preserve the output type after overwriting the old graph. For example, in profiling mode the old graph carries the old operator's output shapes, but after the old operator is replaced with a new operator via SubgraphRewriter, the tensor shape info is eliminated.

The motivation is that I want to replace the PyTorch convolution with a custom convolution. I first register aten::_convolution as a profiled node so that the input and output shapes are recorded, and then use the graph rewriter to replace it with aten::conv2d, at which point the tensors' shape info is eliminated. I want to use the input sizes to do some pre-processing before replacing aten::conv2d with the custom convolution.
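
For context, the rewrite itself can be set up roughly as in the sketch below (the pattern and replacement strings are illustrative; a real pass would also check that the constant arguments describe a plain, non-transposed 2-D convolution before substituting aten::conv2d):

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// Replace aten::_convolution with aten::conv2d via SubgraphRewriter.
// The replacement graph declares the same input list as the pattern so the
// two graphs line up positionally; conv2d simply ignores the extra arguments.
void rewriteConvolutionToConv2d(std::shared_ptr<torch::jit::Graph>& graph) {
  const std::string convolution_pattern = R"IR(
    graph(%input, %weight, %bias, %stride, %padding, %dilation, %transposed,
          %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32):
      %r = aten::_convolution(%input, %weight, %bias, %stride, %padding, %dilation,
                              %transposed, %output_padding, %groups, %benchmark,
                              %deterministic, %cudnn_enabled, %allow_tf32)
      return (%r))IR";
  const std::string conv2d_replacement = R"IR(
    graph(%input, %weight, %bias, %stride, %padding, %dilation, %transposed,
          %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32):
      %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
      return (%r))IR";

  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(convolution_pattern, conv2d_replacement);
  rewriter.runOnGraph(graph);
}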

Before rewrite:

graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %x : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::_convolution(%x.1, %weight, %4, %3, %2, %3, %6, %2, %7, %6, %6, %5, %5), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%x, %z, %7) # jit_test.py:24:0
  return (%16)

After rewrite using aten::conv2d:

graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Tensor = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)

Expected result after replacing aten::_convolution with aten::conv2d:

graph(%self.1 : __torch__.MyModule,
      %x.1 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu)):
  %7 : int = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %6 : bool = prim::Constant[value=0](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %5 : bool = prim::Constant[value=1](), scope: __module.conv # /home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.6/site-packages/torch/nn/modules/conv.py:443:0
  %4 : NoneType = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %conv : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv"](%self.1)
  %z : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::clone(%x.1, %4) # jit_test.py:22:0
  %weight : Float(3, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=0, device=cpu) = prim::GetAttr[name="weight"](%conv)
  %18 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %weight, %4, %3, %2, %3, %7)
  %16 : Float(2, 3, 20, 20, strides=[1200, 400, 20, 1], requires_grad=0, device=cpu) = aten::add(%18, %z, %7) # jit_test.py:24:0
  return (%16)

@facebook-github-bot added the oncall: jit label Sep 22, 2021
@facebook-github-bot
Contributor

facebook-github-bot commented Sep 22, 2021


💊 CI failures summary and remediations

As of commit d53c4be (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@codecov

codecov bot commented Sep 22, 2021

Codecov Report

Merging #65453 (80b9ebc) into master (64d3c73) will decrease coverage by 0.00%.
The diff coverage is n/a.

❗ Current head 80b9ebc differs from pull request most recent head d53c4be. Consider uploading reports for the commit d53c4be to get more accurate results

@@            Coverage Diff             @@
##           master   #65453      +/-   ##
==========================================
- Coverage   66.38%   66.37%   -0.01%     
==========================================
  Files         739      739              
  Lines       94295    94299       +4     
==========================================
- Hits        62594    62592       -2     
- Misses      31701    31707       +6     

@ZolotukhinM

That change looks reasonable, thank you for making it! Do you mind adding a test for this scenario as well? You can put it here: https://github.com/pytorch/pytorch/blob/master/test/cpp/jit/test_subgraph_rewriter.cpp

@XiaobingSuper
Collaborator Author

@ZolotukhinM, do you know how to check whether a graph node has shape info? I can't find an existing test case that checks it.

@ZolotukhinM

We can access the type info from a Value* in the JIT graph. This code can be used as an example:

// Read the dtype and the concrete sizes from a JIT Value* `v`; bail out
// (return c10::nullopt from the enclosing function) when the tensor type
// is missing or incomplete.
auto const& it = v->type()->cast<TensorType>();
c10::ScalarType dtype = c10::ScalarType::Float;
if (!it) {
  return c10::nullopt;
}
if (!it->isComplete()) {
  return c10::nullopt;
}
if (it->scalarType()) {
  // TODO: ideally we should be strict here and return nullopt if the dtype is
  // absent in the JIT IR. We're assuming a default Float dtype for now, until
  // dtype propagation is implemented.
  dtype = *it->scalarType();
}
auto concrete_sizes = it->sizes().concrete_sizes();
if (!concrete_sizes) {
  return c10::nullopt;
}

Alternatively, and I think this can be a better way for testing, we could simply print the IR after the rewrite and see whether the new IR has the shape info. We could then use FileCheck statements to scan the output and search for the shape info in it (when a value has no shape info, it is printed as "%x : Tensor"; when it does have shape info, it is printed like "%x : Float(10, 20)").
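
A minimal sketch of such a check, assuming graph has already been run through the rewriter (the check string is illustrative, taken from the expected IR above rather than from the test that was actually added):

#include <torch/csrc/jit/testing/file_check.h>

// Passes only if the rewritten conv2d output still carries its full tensor type,
// i.e. the line reads "%18 : Float(2, 3, 20, 20, ..., device=cpu) = aten::conv2d(...)"
// rather than "%18 : Tensor = aten::conv2d(...)".
torch::jit::testing::FileCheck()
    .check("device=cpu) = aten::conv2d")
    ->run(*graph);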

@XiaobingSuper
Collaborator Author

@ZolotukhinM, a test case has been added.


@ZolotukhinM left a comment


Awesome, thanks! Do you need help with merging the PR?

@XiaobingSuper
Collaborator Author

@ZolotukhinM, yes, please help merge it, thanks!

@facebook-github-bot
Contributor

@ZolotukhinM has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ZolotukhinM merged this pull request in 1682722.

@malfet added this to the 1.10.1 milestone Nov 2, 2021
@seemethere
Member

@malfet I see we marked this for inclusion in the 1.10.1 release. Can we link to the issue that this resolves, to verify it fixes a regression?

@malfet
Contributor

malfet commented Dec 8, 2021

@seemethere this PR adds a test for the regression

@XiaobingSuper deleted the jit-rewrite branch December 20, 2021 04:10
Labels: cla signed, Merged, oncall: jit, open source