Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FX] Changes done internally at Facebook #1603

Merged
merged 5 commits into from
Jan 22, 2023
Merged

[FX] Changes done internally at Facebook #1603

merged 5 commits into from
Jan 22, 2023

Conversation

frank-wei
Copy link
Contributor

2345bec2d694df51ddfa583b9a51fcec0b0ff656 Wei Wei wwei6@meta.com [fx2trt] add pass elapsed time in log a2b867b1dfa9bbc4c93cfefc44535d8e6db699b2 Shirong Wu shirong@meta.com Add missing kwarg 383f43da4a3ece0457637a945c762b9eb5ac9e73 Janet Yang qxy11@meta.com [fx][acc_tracer] Fix "list indices must be integers or slices, not Node" issue during lowering 7dcb956d4297787614b9c2f31cbe7c7e9f4720db Andrew Or andrewor@meta.com [Quant][fx][bc-breaking] Add simpler BackendConfig pattern format 140b3f0c9474ba7d677dc121f739395cf2d87db3 Huamin Li huaminli@meta.com Update customized_fuse_pass in lower_setting.py 5cef3146e886f315d881af64fd4117054acebe8b Shirong Wu shirong@meta.com Add TRT aten converter 6db5870f3c9b2e8c8c83854f7276979a027f78fd Shirong Wu shirong@meta.com Enable skipped test 5f53e0776add2bd4ea810a2aa8f16bd426e38895 Shirong Wu shirong@meta.com Enable explicit batch dim 8c250cafde9bc7bd18b59c91bb038673f7de4a10 Wei Wei wwei6@meta.com [aten2trt] add PT2.0 tracer

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

2345bec2d694df51ddfa583b9a51fcec0b0ff656 Wei Wei <wwei6@meta.com> [fx2trt] add pass elapsed time in log
a2b867b1dfa9bbc4c93cfefc44535d8e6db699b2 Shirong Wu <shirong@meta.com> Add missing kwarg
383f43da4a3ece0457637a945c762b9eb5ac9e73 Janet Yang <qxy11@meta.com> [fx][acc_tracer] Fix "list indices must be integers or slices, not Node" issue during lowering
7dcb956d4297787614b9c2f31cbe7c7e9f4720db Andrew Or <andrewor@meta.com> [Quant][fx][bc-breaking] Add simpler BackendConfig pattern format
140b3f0c9474ba7d677dc121f739395cf2d87db3 Huamin Li <huaminli@meta.com> Update customized_fuse_pass in lower_setting.py
5cef3146e886f315d881af64fd4117054acebe8b Shirong Wu <shirong@meta.com> Add TRT aten converter
6db5870f3c9b2e8c8c83854f7276979a027f78fd Shirong Wu <shirong@meta.com> Enable skipped test
5f53e0776add2bd4ea810a2aa8f16bd426e38895 Shirong Wu <shirong@meta.com> Enable explicit batch dim
8c250cafde9bc7bd18b59c91bb038673f7de4a10 Wei Wei <wwei6@meta.com> [aten2trt] add PT2.0 tracer
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@narendasan
Copy link
Collaborator

Any idea what AttributeError: '_OpNamespace' '_caffe2' object has no attribute 'RoIAlign' is in reference to?

@narendasan
Copy link
Collaborator

Also don't consider this gating for this PR getting merged but @apbose can you make sure to take a look at these so you know how the FX converter library is changing?

@narendasan narendasan requested review from apbose and removed request for narendasan January 22, 2023 00:46
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@yinghai yinghai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, the ROI custom op error seems legit.

@frank-wei
Copy link
Contributor Author

frank-wei commented Jan 22, 2023

This operator seems not existed in external version. I need to comment it out.

>>> torch.__version__
'2.0.0.dev20230121+cu116'
>>> torch.ops._caffe2.RoIAlign
Traceback (most recent call last):
  File "/data/users/wwei6/miniconda3/envs/ait2/lib/python3.8/site-packages/torch/_ops.py", line 562, in __getattr__
    op, overload_names = torch._C._jit_get_operation(qualified_op_name)
RuntimeError: No such operator _caffe2::RoIAlign

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py	2023-01-22 19:34:32.383167 +0000
+++ py/torch_tensorrt/fx/test/converters/aten_op/test_cat_aten.py	2023-01-22 19:34:50.014348 +0000
@@ -7,11 +7,11 @@

class TestCatConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ("pos", 1),
-            #("neg", -2), #Dynamo tracer issue
+            # ("neg", -2), #Dynamo tracer issue
        ]
    )
    def test_cat(self, _, dim):
        class Cat(nn.Module):
            def forward(self, x, y, z):
@@ -25,11 +25,11 @@
        )

    @parameterized.expand(
        [
            ("pos", 1),
-            #("neg", -2),  #Dynamo tracer issue
+            # ("neg", -2),  #Dynamo tracer issue
        ]
    )
    def test_cat_dynamic_shape(self, _, dim):
        class Cat(nn.Module):
            def forward(self, x, y):

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@yinghai yinghai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right

@frank-wei frank-wei merged commit 2e21ce6 into main Jan 22, 2023
@frank-wei frank-wei deleted the fb-sync-wwei6 branch January 22, 2023 22:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants