Skip to content

Conversation

datagero
Copy link
Contributor

@datagero datagero commented Jul 10, 2024

Fixes #130301

Adjusted the call_str method to handle str conversion for UserDefinedObjectVariable.
Attempt in a clean branch for unrelated test errors.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

Copy link

pytorch-bot bot commented Jul 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130506

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 318af0f with merge base d039b14 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@datagero datagero marked this pull request as draft July 11, 2024 06:14
@datagero datagero marked this pull request as ready for review July 11, 2024 07:52
@datagero
Copy link
Contributor Author

datagero commented Jul 11, 2024

@oulgen @anijain2305 second attempt on this PR (previous #130320), it does seem that the changes affect the tests for some setups, which I cannot replicate locally.

PYTORCH_TEST_WITH_DYNAMO=1 python test/test_nn.py -k TestNN.test_ParameterList_meta

In other setups (e.g., screenshot passes for dynamo, 3, 3, linux.2xlarge) I was not able to find this specific test, not sure why.

Due to the proposed changes, there be trouble retireving 'value' and/or '_size' attributes from ParameterList UserDefinedObjectVariable. For '_size' not entirely sure why it would have difficulty since the attribute is defined in the class init function. Current commit is hoping to get some additional logs to understand better.

I'm trying to run some tests on the CI but cannot trigger the tests - is there a way I can achieve this?

image

@datagero
Copy link
Contributor Author

Hi @oulgen @anijain2305 - I think I need an approval to proceed with the CI tests?

@oulgen
Copy link
Contributor

oulgen commented Jul 12, 2024

I clicked on the run tests button, each time you publish a new version, you need someone to run tests for you

@bdhirsh bdhirsh requested a review from oulgen July 13, 2024 01:50
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 13, 2024
@datagero datagero requested review from a team, kulinseth and malfet as code owners July 16, 2024 10:48
@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: releng release notes category labels Jul 16, 2024
@datagero
Copy link
Contributor Author

@oulgen finally found the error, from test/test_nn.py], what I understand a str() function on the test ParameterList may not be not collecting/starting the _size parameter when we need it, by changing this to repr() the solution works.

@datagero
Copy link
Contributor Author

hi @oulgen are you able to help triggering the tests to confirm changes? 🙏

@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch from f3d15d2 to c07d066 Compare July 19, 2024 07:35
@datagero
Copy link
Contributor Author

@oulgen In test_saving_variable_to_disk defining tensors inside the saved_tensors_hooks context caused unintended side effects during assertions, specifically when calling _compare_regular_values_close. This was fixed by moving the tensor definitions (a and y) outside the saved_tensors_hooks context.

Test pass locally now:

image

Comment on lines 9608 to 9609
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldnt need to do this? @bdhirsh

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. Indeed I do not fully understand the effects.
@bdhirsh This is what I was able to infer:
In test_saving_variable_to_disk defining tensors inside the saved_tensors_hooks context caused unintended side effects during assertions, specifically when calling _compare_regular_values_close. This was fixed by moving the tensor definitions (a and y) outside the saved_tensors_hooks context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @bdhirsh would appreciate any guide/comment on this one!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pulled your PR locally and this test fails for me with the change. taking a look

@datagero datagero requested a review from oulgen July 19, 2024 07:37
@ashwanirathee
Copy link
Contributor

Is there more work needed on this or close to done @datagero?

@datagero
Copy link
Contributor Author

@ashwani-rathee waiting on some feedback on a test/test_autograd.py modification; hopefully from @bdhirsh, unless there's someone else that can provide feedback? @oulgen

Comment on lines 1024 to 1025
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check if the arg.value has an overridden __str__ method. If yes, then either graph break right now, or inline the __str__ function.

Copy link
Contributor Author

@datagero datagero Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for the test/test_autograd.py discussed above, the call to str below is causing the issue for me. @bdhirsh were you able to have a look?

            def pack(x):
                name = os.path.join(tmp_dir, str(uuid.uuid4()))
                torch.save(x, name)
                return name

type(obj).__str__ is overriden to <function UUID.__str__ at 0x1026a8040>

if on the above we avoid str() and call uuid.uuid4().hex or uuid.uuid4().__str__(), then the test pass for me. Otherwise (even if changing torch/_dynamo/variables/builtin.py 1025 to be in-line value=arg.value.__str__()), it will throw assertion error:

AssertionError: Tensor-likes are not close!

Mismatched elements: 5 / 5 (100.0%)
Greatest absolute difference: 2.0 at index (0,) (up to 1e-05 allowed)
Greatest relative difference: inf at index (0,) (up to 1.3e-06 allowed)

Copy link
Contributor

@anijain2305 anijain2305 Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can do something like

elif isinstance(arg, (variables.UserDefinedObjectVariable)) and type(arg.value).__str__ is object.__str__:

This will incrementally add support for user defined objects relying on the object str method. And for rest, it will graph break.

If you want to support custom __str__ method then you might need to do more work.

  1. Inline the __str__ method. Search for tx.inline_user_function_return to get an idea on how we inline.
  2. You will have to account for __repr__ functions when __str__ is absent. See the next comment.

Copy link
Contributor

@anijain2305 anijain2305 Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason its failing for you because str(x) does much more than just calling x.__str__()

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @anijain2305 will have a look today, seems good pointers! The requirement of the ticket is to deal with custom __str__ methods so I'lll look at tx.inline_user_function_return and your notes.

From the ticket:

import torch

class C:
    x = 1

    def __str__(self):
        return "ok"

def foo(x):
    a = C()
    return x, str(a)

print(torch.compile(foo, fullgraph=True)(torch.ones(4)))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorcery! That helped a lot

@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch from c07d066 to ed976a8 Compare July 27, 2024 04:58
@datagero
Copy link
Contributor Author

datagero commented Jul 27, 2024

  • Original Ticket: Handles custom objects with custom __str__ functions.
    -TestAutograd.test_saving_variable_to_disk: Requires inline str handling for custom objects due to possible graph breaks. Use tx.inline_user_function_return
    -TestNN.test_ParameterList_meta: Catches issues with ParameterList, which throws an AttributeError with call_str. We return empty (i.e., same behaviour as before this implementation) as this is possible graph break.

For reviewer (@anijain2305 , @oulgen , @bdhirsh ):

  • I'd appreciate if we can trigger the CI tests to check fixes
  • Happy to adjust for any feedback on style / or the comment descriptions for "graph break" need revision for technical accuracy

@datagero datagero requested a review from anijain2305 July 27, 2024 05:10
@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch from ed976a8 to e4273df Compare July 27, 2024 09:35
@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch from bee8e48 to 275f81c Compare July 27, 2024 11:43
return

# Inline the user function
user_func_variable = variables.UserFunctionVariable(bound_method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need of this line. L1064 already does it.

try:
# Only supports certain function types
user_func_variable = variables.UserFunctionVariable(bound_method)
except AssertionError as e:
Copy link
Contributor

@anijain2305 anijain2305 Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do an if condition instead of relying on Exception. We should probably just call VariableBuilder (cc @mlazos, I think we need to add a function in builder.py for identifying C functions and graph break on them). But I think this is not required in this PR.

Copy link
Contributor Author

@datagero datagero Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - I did try/except in case UserFunctionVariable gets expanded to other function types. As per your note, I did not adjusted this on last push.

Copy link
Contributor

@anijain2305 anijain2305 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes.

@anijain2305 anijain2305 changed the title [Dynamo] Fix - adding Handler for UserDefinedObjectVariable [Dynamo] Fix - str handler for UserDefinedObjectVariable Jul 29, 2024
@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch 2 times, most recently from 6ba58bb to 5160181 Compare July 29, 2024 21:59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you revert these changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oulgen done

@datagero datagero force-pushed the fix-dynamo-handler-str-userdefinedobjectvariable-2 branch from 5160181 to 318af0f Compare July 31, 2024 03:34
@anijain2305
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 31, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable)

Details for Dev Infra team Raised by workflow job

@anijain2305
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: dynamo open source release notes: releng release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch._dynamo.exc.Unsupported: builtin: str [<class 'torch._dynamo.variables.user_defined.UserDefinedObjectVariable'>] False

7 participants