
Conversation

eellison
Contributor

@eellison eellison commented Mar 9, 2020

torch.nn.functional.interpolate was written as a builtin op when we scripted the standard library, because it has four possible overloads. As a result, whenever we make a change to interpolate, we need to make changes in two places, and it also makes it impossible to optimize the interpolate op. The builtin is tech debt.
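For reference, a rough sketch of the kind of overload declarations this implies once interpolate lives in functional.py (illustrative only; the exact decorators and signatures may differ from the PR): size can be an int or a list of ints, and scale_factor a float or a list of floats, which is where the four overloads come from.

# Illustrative sketch only: four @torch.jit._overload stubs covering the
# int/List[int] size and float/List[float] scale_factor combinations,
# followed by a single shared implementation.
import torch
from torch import Tensor
from typing import List, Optional

@torch.jit._overload
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):  # noqa: F811
    # type: (Tensor, Optional[int], Optional[float], str, Optional[bool]) -> Tensor
    pass

@torch.jit._overload
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):  # noqa: F811
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    pass

@torch.jit._overload
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):  # noqa: F811
    # type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool]) -> Tensor
    pass

@torch.jit._overload
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):  # noqa: F811
    # type: (Tensor, Optional[List[int]], Optional[List[float]], str, Optional[bool]) -> Tensor
    pass

def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):  # noqa: F811
    # Single implementation shared by all overloads; the real body lives in
    # torch/nn/functional.py and is elided here.
    ...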

I talked with @ailzhang, and the symbolic script changes are good to remove (I guess that makes a third place where we needed to re-implement interpolate).

I'm trying to get rid of unnecessary builtin operators because we're standardizing the mobile bytecode soon, so we should try to get this landed as soon as possible.

@eellison eellison requested a review from apaszke as a code owner March 9, 2020 23:12
@facebook-github-bot facebook-github-bot added the oncall: jit label Mar 9, 2020
@dr-ci
Copy link

dr-ci bot commented Mar 9, 2020

💊 CircleCI build failures summary and remediations

As of commit fd82d4e (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following build failures do not appear to be due to upstream breakages (reran 2 jobs to discount flakiness):

See CircleCI build pytorch_linux_xenial_py3_clang5_mobile_custom_build_static (1/2)

Step: "Set Up CI Environment After attach_workspace" (full log | pattern match details) <confirmed not flaky by 2 failures>

E: Failed to fetch https://download.docker.com/linux/ubuntu/dists/xenial/stable/binary-amd64/Packages.bz2 Hash Sum mismatch
E: Failed to fetch https://download.docker.com/linux/ubuntu/dists/xenial/stable/binary-amd64/Packages.bz2  Hash Sum mismatch 
E: Some index files failed to download. They have been ignored, or old ones used instead. 

See CircleCI build caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build (2/2)

Step: "Set Up CI Environment After attach_workspace" (full log | pattern match details) <confirmed not flaky by 2 failures>

E: Failed to fetch https://download.docker.com/linux/ubuntu/dists/xenial/stable/binary-amd64/Packages.bz2 Hash Sum mismatch
E: Failed to fetch https://download.docker.com/linux/ubuntu/dists/xenial/stable/binary-amd64/Packages.bz2  Hash Sum mismatch 
E: Some index files failed to download. They have been ignored, or old ones used instead. 


test/test_jit.py Outdated
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (False, 'aten::__interpolate')),
Contributor

What is the value that changed here, and why?

Contributor Author

Per @ailzhang, this can be removed.

if size is None and scale_factor is None:
    raise ValueError('either size or scale_factor should be defined')
if size is not None and scale_factor is not None:
    raise ValueError('only one of size or scale_factor should be defined')
Contributor

What happened to this error message? We should try to preserve the original behavior as much as possible.

Contributor Author

I thought they were saying the same thing, but I guess they're slightly different. Re-added.

    # type: (int, Tuple[Tensor, Optional[List[int]], Optional[float], Optional[bool]]) -> List[int]
    pass

def _interp_output_size(dim, closed_over_args):  # noqa: F811
Contributor

This function is kind of long; can you add an overload for _check_size_scale_factor as well?

Contributor Author

Because of the overload declarations, I would have to add 16 lines to separate out these 5 error-checking lines. I don't think it's worth it, but I can do it if you think it is.
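For reference, a hypothetical sketch of what that would require (names and exact signatures are illustrative, not the PR's actual code): each accepted size type needs its own @torch.jit._overload stub on top of the shared implementation.

import torch
from typing import List, Optional

@torch.jit._overload
def _check_size_scale_factor(dim, size, scale_factor):  # noqa: F811
    # type: (int, Optional[int], Optional[float]) -> None
    pass

@torch.jit._overload
def _check_size_scale_factor(dim, size, scale_factor):  # noqa: F811
    # type: (int, Optional[List[int]], Optional[float]) -> None
    pass

def _check_size_scale_factor(dim, size, scale_factor):  # noqa: F811
    # Shared implementation: the same error-checking lines quoted above.
    if size is None and scale_factor is None:
        raise ValueError('either size or scale_factor should be defined')
    if size is not None and scale_factor is not None:
        raise ValueError('only one of size or scale_factor should be defined')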

test/test_jit.py Outdated
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (False, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale', (False, 'aten::__interpolate')),
Contributor

False is the default value, and we no longer need to check nodes in the differentiated graph. You can safely remove the last tuple, (False, 'aten::__interpolate'), from all of these lines.
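For example, the 'nearest_4d' entry above would then become (a sketch of the suggested edit):

('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),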

@eellison eellison requested review from driazati and jerryzh168 March 10, 2020 22:31
// TODO: sort returns a tuple of Tensors, we have
// to extend the API to support that
// "sort",
"__interpolate",
Contributor

Could you add the corresponding ops to this list? I think it might be better to just put this change in the same PR.

Contributor Author

I don't know the difference between "single_input_call_funcs" and "single_input_aten_funcs".

Contributor Author

Where should I be putting the ops?

@jerryzh168
Contributor

jerryzh168 commented Mar 11, 2020 via email

@jerryzh168
Contributor

jerryzh168 commented Mar 11, 2020 via email

Contributor

@facebook-github-bot facebook-github-bot left a comment

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neginraoof
Contributor

neginraoof commented Mar 11, 2020

Hi @eellison, I'm trying to look into the changes breaking the ONNX tests.
Just testing an interpolate layer in scripting like:

@torch.jit.script_method
def forward(self, x):
    return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size)

is now producing a very large IR graph.
This was previously shown in the graph as a single aten interpolate node.
Do you know if this change is expected? And does this affect performance for model scripting?

@eellison
Contributor Author

eellison commented Mar 11, 2020

Hi @eellison, I'm trying to look into the changes breaking the ONNX tests.
Just testing an interpolate layer in scripting like:

@torch.jit.script_method
def forward(self, x):
    return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size)

is now producing a very large IR graph.
This was previously shown in the graph as a single aten interpolate node.
Do you know if this change is expected? And does this affect performance for model scripting?

Yes, this is expected. We previously hacked it in as a builtin node, and are now representing it as its Python code. In the short term it may be marginally slower, but in the long term it will be faster, as we do a better job of optimizing away the non-Tensor ops and potentially do codegen for the aten ops it invokes. It will also be more maintainable: tracing and scripting now create the same ops, and we do not need 4 different implementations of interpolate (register_prim_ops, functional.py, symbolic_script, onnx_export -> functional.py).

In your example, mode will pretty much always be known at compile time. Treating self.size as a known constant is also more realistic here, because with the profiling executor it will become a constant (see the sketch below).

def forward(self, x):
    return torch.nn.functional.interpolate(x, mode=self.mode, size=self.size)
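
As a rough, hypothetical illustration of that point (not code from this PR), declaring the attributes as TorchScript constants lets the compiler specialize on them, so most of the mode/size branching inside interpolate can be folded away:

import torch
import torch.nn.functional as F

class Upsample2x(torch.nn.Module):
    # Declaring these as constants lets TorchScript bake them into the graph,
    # so the branches in interpolate that switch on mode/scale_factor can be
    # constant-folded. Module and attribute names here are hypothetical.
    __constants__ = ['mode', 'scale_factor']

    def __init__(self):
        super().__init__()
        self.mode = 'nearest'
        self.scale_factor = 2.0

    def forward(self, x):
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)

scripted = torch.jit.script(Upsample2x())
print(scripted.graph)  # most of the mode/size branching should fold to constants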

Contributor

@driazati driazati left a comment

TorchScript and nn changes look fine.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neginraoof
Contributor

neginraoof commented Mar 11, 2020

@eellison
I see now. So even setting the constants in this case only cuts the graph by a few lines.
I can see the whole interpolate functional scripted into an inflated graph of about 900 nodes,
with all the If blocks and branches.
I don't know if it makes sense to export such a large graph for a single op, though, or how much this could be optimized.
Do you have a measure of how this impacts scripting perf?

@eellison
Contributor Author

eellison commented Mar 11, 2020

@eellison
I see now. So even setting the constants in this case only cuts the graph by a few lines.
I can see the whole interpolate functional scripted into an inflated graph of about 900 nodes,
with all the If blocks and branches.
Do you have a measure of how this impacts scripting perf?

It's not just a question of perf, but of maintainability. I've talked with other members of the TorchScript team and vetted this change as it pertains to scripting.

@neginraoof
Contributor

@eellison Thanks.
cc @houseroad
I'm mainly concerned about how this will impact the efficiency of the ONNX model after this change.
The export of the scripted interpolate op will be highly inefficient.
About the optimizations you mentioned: do you know if these optimizations are going to be part of the torch IR graph? And will they be visible to ONNX?

@eellison
Contributor Author

@neginraoof it doesn't affect interpolate tracing, which is the vast majority of ONNX usage. I also said, at the time interpolate script ONNX export was implemented, that we were going to remove "aten::__interpolate", and suggested that removal be a prerequisite. If there are serious concerns about torch.nn.functional.interpolate, someone can try to move the op to be a natively declared aten builtin, but the current duplication is not sustainable.

@neginraoof
Contributor

@eellison Do you know what the future optimization plan is? Are we going to have codegen for aten ops? Or is there another component optimizing the torch IR graph? I'm trying to understand whether this could be used to improve the ONNX graph as well.

@eellison
Contributor Author

@eellison Do you know what the future optimization plan is? Are we going to have codegen for aten ops? Or is there another component optimizing the torch IR graph? I'm trying to understand whether this could be used to improve the ONNX graph as well.

As it stands, the plan is just to more aggressively optimize Python idioms. I have a WIP PR that gets interpolate down to a single executed op. If someone thought it was worth the effort, they could move "interpolate" to be a native aten function, but no one up to this point has thought it worth the cost/benefit tradeoff.

@eellison
Contributor Author

For anyone investigating breakage, I opened a disabled test here: #34658

@facebook-github-bot
Contributor

@eellison merged this pull request in 514cba0.

.check("aten::max") \
.check("aten::min") \
.check("aten::mean") \
.check("aten::__interpolate") \
Contributor

@eellison we need the tests as well; actually, I think this might have broken these tests. Will sync with you tomorrow.

@houseroad houseroad mentioned this pull request Mar 13, 2020
facebook-github-bot pushed a commit that referenced this pull request Apr 15, 2020
…erpolate (#35744)

Summary:
Since aten::__interpolate is removed in #34514, we need a pass to replace the interpolate function with aten::__interpolate for ONNX export.
Pull Request resolved: #35744

Reviewed By: hl475

Differential Revision: D20907041

Pulled By: houseroad

fbshipit-source-id: f2d2cdfec47389245c50f538267124eedf682adf

Labels

Merged, oncall: jit
