
Expose matrix multiplication operations with conjugate transposes of the inputs #51750

@ngimel

🚀 Feature

The BLAS libraries that PyTorch uses have an option to apply an implicit conjugate transpose to an argument (the 'h' option), but PyTorch has no bindings for it and does not expose a way to call it. This option could speed up the backward pass through matrix multiplication ops, because we could potentially avoid materializing the conjugate.
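
To make the benefit concrete: for `C = A @ B` with upstream gradient `G`, the backward needs `grad_A = G @ B^H` and `grad_B = A^H @ G`. The pure-Python sketch below is a hypothetical illustration, not PyTorch or BLAS code. `gemm`, `op_get` and `conj_transpose` are made-up helpers, and the flag values follow the reference BLAS `trans` argument, where `'c'` is the conjugate transpose that this issue calls 'h'. The operand is read through the flag on the fly, so `B^H` is never built.

```python
# Hypothetical BLAS-style GEMM: each operand is read through an "op" flag,
# 'n' (as is), 't' (transpose) or 'c' (conjugate transpose), so transposed or
# conjugated copies of the inputs are never materialized.
# Matrices are plain lists of rows of Python complex numbers.

def op_shape(m, op):
    """Shape of op(m) without building it."""
    rows, cols = len(m), len(m[0])
    return (rows, cols) if op == "n" else (cols, rows)

def op_get(m, op, i, j):
    """Read element (i, j) of op(m) directly from m."""
    if op == "n":
        return m[i][j]
    if op == "t":
        return m[j][i]
    if op == "c":
        return m[j][i].conjugate()  # conjugate applied at read time
    raise ValueError(f"unknown op {op!r}")

def gemm(a, b, op_a="n", op_b="n"):
    """Return op(a) @ op(b)."""
    m, k = op_shape(a, op_a)
    k2, n = op_shape(b, op_b)
    if k != k2:
        raise ValueError("inner dimensions do not match")
    return [
        [sum(op_get(a, op_a, i, p) * op_get(b, op_b, p, j) for p in range(k))
         for j in range(n)]
        for i in range(m)
    ]

def conj_transpose(m):
    """Explicitly materialize m^H (the copy we would like to avoid)."""
    return [[m[i][j].conjugate() for i in range(len(m))] for j in range(len(m[0]))]

# Backward of C = A @ B (A: 2x2, B: 2x3, upstream gradient G: 2x3).
A = [[1 + 2j, 3 - 1j], [1j, 2 + 0j]]
B = [[2 - 1j, 1j, 1 + 0j], [1 + 1j, 0j, 3 - 2j]]
G = [[1 + 0j, 1j, 2 + 0j], [0j, 1 - 1j, 1 + 0j]]
grad_A = gemm(G, B, op_b="c")  # G @ B^H, without forming B^H
grad_B = gemm(A, G, op_a="c")  # A^H @ G, without forming A^H
```

Both gradients match what you get by materializing the conjugate transpose first; the difference is only the extra copy.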
It's not fully clear how best to expose this to users, or whether it should be exposed at all rather than being called in the backward pass when necessary.
PyTorch sets the 't' argument of BLAS calls based on the strides of the input matrices. 'h' cannot be set independently: it can be set only if the physical memory layout corresponds to a transposed matrix, so the UX here is not very clear. This issue is to discuss what exposure we want.
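
The stride constraint can be sketched for a column-major BLAS as follows. `choose_blas_op` and the `is_conj` flag are hypothetical and are only loosely modeled on stride-based dispatch, not copied from PyTorch. `is_conj` stands in for a lazily conjugated view. A row-major matrix already looks like a transposed column-major one, so it can be passed with 't'. A conjugated matrix can use the conjugate-transpose op only when its memory is row-major, because the standard BLAS `gemm` interface offers only 'N', 'T' and 'C' and has no "conjugate without transpose".

```python
from typing import NamedTuple, Optional, Tuple

class BlasArg(NamedTuple):
    op: Optional[str]  # 'n', 't', 'c', or None when a copy is required
    ld: Optional[int]  # leading dimension to pass to BLAS

def choose_blas_op(strides: Tuple[int, int], is_conj: bool = False) -> BlasArg:
    """Pick a column-major BLAS op for a 2-D matrix with element strides
    (row_stride, col_stride). is_conj=True means the logical matrix is the
    conjugate of what is stored (hypothetical lazy-conjugate view)."""
    row_stride, col_stride = strides
    if is_conj:
        # BLAS reads a row-major buffer as M = stored^T, and
        # conj(stored) == M^H, so the conjugate-transpose op applies.
        if col_stride == 1:
            return BlasArg("c", row_stride)
        # Column-major + conjugate would need "conjugate, no transpose",
        # which gemm does not offer: the conjugate must be materialized.
        return BlasArg(None, None)
    if row_stride == 1:
        return BlasArg("n", col_stride)  # already column-major
    if col_stride == 1:
        return BlasArg("t", row_stride)  # row-major == transposed column-major
    return BlasArg(None, None)  # neither dimension contiguous: copy needed

# A contiguous 3x4 matrix has strides (4, 1); its .t() view is 4x3 with (1, 4).
print(choose_blas_op((4, 1)))                # BlasArg(op='t', ld=4)
print(choose_blas_op((1, 4)))                # BlasArg(op='n', ld=4)
print(choose_blas_op((4, 1), is_conj=True))  # BlasArg(op='c', ld=4)
print(choose_blas_op((1, 4), is_conj=True))  # BlasArg(op=None, ld=None)
```

Returning `op=None` for the conjugated column-major case reflects the point above: the 'h' op cannot be chosen independently of the memory layout.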

Related: we have the `dot` and `vdot` functions, where `vdot` implicitly conjugates its first argument.
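
For 1-D complex inputs, the two differ only in that implicit conjugation. The following is a pure-Python sketch of what `torch.dot` and `torch.vdot` compute; these are illustrative reimplementations, not the library functions.

```python
def dot(a, b):
    """sum(a_i * b_i): no conjugation."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum(x * y for x, y in zip(a, b))

def vdot(a, b):
    """sum(conj(a_i) * b_i): the first argument is conjugated implicitly."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum(x.conjugate() * y for x, y in zip(a, b))

a = [1 + 2j, 3 - 1j]
b = [2 + 0j, 1j]
print(dot(a, b))   # (3+7j)
print(vdot(a, b))  # (1-1j)
```

`vdot(a, a)` yields the squared norm of `a` (here 15), which is why the conjugate is folded into the op instead of being materialized by the caller.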

Also related: #45063, where some comments discuss the possibility of adding conjugate views.

cc @ezyang @anjali411 @dylanbespalko @mruberry @jianyuh @nikitaved @pearu @heitorschueroff @walterddr @IvanYashchuk

Labels: enhancement, module: complex, module: linear algebra, triaged
