vmap support for torch.tril, torch.triu #91403

Closed
2 tasks
Tracked by #1088
zou3519 opened this issue Dec 27, 2022 · 3 comments
Labels
good first issue module: functorch Pertaining to torch.func or pytorch/functorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


zou3519 commented Dec 27, 2022

Tasks

Add vmap support for the following PyTorch operations. That is, each one needs a batching rule.

  • torch.tril
  • torch.triu

These already have batching rules implemented, but the batching rules are incorrect.

Expected behavior

Currently, applying vmap to either of them with a 2D input succeeds:

import torch
from functorch import vmap

x = torch.randn(32, 3)
y = vmap(torch.triu)(x)  # maps torch.triu over 32 slices, each of which is 1D

This should actually raise an exception. The mental model here is that vmap should be equivalent to running a for-loop over the batch dimension and stacking the results. The following loop does raise an exception (torch.triu requires an input with at least 2 dimensions, and each xi here is 1D), so vmap should as well:

results = []
for xi in x:
    y = torch.triu(xi)  # raises: xi is 1D, but triu requires at least 2 dimensions
    results.append(y)
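
For contrast, a minimal sketch of the case that should continue to work: when the input is a batch of 2D matrices, every vmapped slice is a valid input to torch.triu, so the for-loop and vmap should agree.

import torch
from functorch import vmap

# A batch of 32 matrices, each 3x3: every slice x[i] is 2D,
# so torch.triu applies cleanly both in a loop and under vmap.
x = torch.randn(32, 3, 3)

looped = torch.stack([torch.triu(xi) for xi in x])
batched = vmap(torch.triu)(x)
assert torch.allclose(looped, batched)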

Read this!

See this note for more context: https://github.com/pytorch/pytorch/blob/master/functorch/writing_batching_rules.md

If you're new to developing PyTorch and/or function transforms, we would recommend reading through https://github.com/pytorch/pytorch/wiki/Core-Frontend-Onboarding#unit-8-function-transforms-optional

For this issue, you'll likely need to implement something similar to the existing vmap support for torch.clone.

Testing

pytest test/functorch/test_vmap.py -v -k "test_op_has_batch_rule_tril" should pass (currently, it is an expected failure). Ditto for triu.
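
In addition to the pytest run, a quick manual check of the fix might look like the sketch below; the exact exception type and message depend on how the batching rule reports the error (a RuntimeError is assumed here).

import torch
from functorch import vmap

x = torch.randn(32, 3)  # each vmapped slice is 1D
try:
    vmap(torch.triu)(x)
    print("no exception -- batching rule still permits 1D slices")
except RuntimeError as e:
    print(f"raised as expected: {e}")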

cc @Chillee @samdow @soumith

@zou3519 zou3519 added good first issue, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module), and module: functorch (Pertaining to torch.func or pytorch/functorch) labels Dec 27, 2022
@benjamintli

I'll give this a shot

@sunmorgan

I'll take this on

sunmorgan commented Jan 6, 2023

@zou3519 It seems like torch.tril and torch.triu use the VARIADIC_BDIMS_BATCH_RULE definition from BatchRulesHelper.h, shown below:

#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\

VMAP_SUPPORT(tril, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(tril)));
VMAP_SUPPORT(triu, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(triu)));

Should that definition be removed, then? It doesn't look like it is used anywhere else.

Also, when running the tests in test/functorch/test_vmap.py, I get a module-not-found error on "from torch.testing._internal.autograd_function_db import autograd_function_db".

OS: macOS
Environment: Anaconda
PyTorch built from source
