
torch.func.jvp does not support flash_attention calculation #165530

@pnotp

Description


🚀 The feature, motivation and pitch

When calling torch.func.jvp, the following error is raised:
[rank0]: NotImplementedError: Trying to use forward AD with _scaled_dot_product_flash_attention that does not support it because it has not been implemented yet.
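Until a forward-AD formula exists for the fused kernel, one possible workaround is to compute attention from explicit matmul and softmax ops, which do have forward-AD support. The sketch below is not the fused kernel and not an official fix. It assumes CPU float64 tensors shaped (batch, heads, seq_len, head_dim), no mask, no dropout, and the default 1/sqrt(head_dim) scale. The name `explicit_attention` is hypothetical. The sketch checks the JVP against a central finite difference.

```python
import math

import torch
from torch.func import jvp


def explicit_attention(q, k, v):
    # Hypothetical helper: softmax(Q K^T / sqrt(d)) V built only from ops that
    # have forward-AD formulas, instead of the fused flash-attention kernel.
    # Assumes no mask, no dropout, default scale.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)


torch.manual_seed(0)
# Shapes: (batch, heads, seq_len, head_dim); float64 for a tight numeric check.
q, k, v = (torch.randn(2, 4, 8, 16, dtype=torch.float64) for _ in range(3))
dq, dk, dv = (torch.randn_like(t) for t in (q, k, v))

# Forward-mode JVP: primal output and directional derivative along (dq, dk, dv).
out, out_tangent = jvp(explicit_attention, (q, k, v), (dq, dk, dv))

# Sanity check: central finite difference along the same direction.
eps = 1e-6
fd = (
    explicit_attention(q + eps * dq, k + eps * dk, v + eps * dv)
    - explicit_attention(q - eps * dq, k - eps * dk, v - eps * dv)
) / (2 * eps)
print(torch.allclose(out_tangent, fd, atol=1e-6))
```

The trade-off is memory. The explicit version materializes the full seq_len × seq_len score matrix, which is exactly what flash attention avoids. It is therefore practical mainly for short sequences or for checking JVP results.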

Alternatives

No response

Additional context

No response
