Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 3D+ input support for fp8 rowwise GEMM #2845

Closed
wants to merge 1 commit into from

Conversation

jianyuh
Copy link
Member

@jianyuh jianyuh commented Jul 15, 2024

Summary:
The input activation can be 3D+, with first dimension as batch dimension.

    1. If the input tensor is {M, K}, the output tensor is {M, N}.
    1. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.

This Diff adds the support, to match nn.Linear function / matmul supports.

Differential Revision: D59671644

Copy link

netlify bot commented Jul 15, 2024

Deploy Preview for pytorch-fbgemm-docs ready!

Name Link
🔨 Latest commit 7f8ea83
🔍 Latest deploy log https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/6696271974f6cd0008e92c91
😎 Deploy Preview https://deploy-preview-2845--pytorch-fbgemm-docs.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59671644

jianyuh added a commit to jianyuh/FBGEMM that referenced this pull request Jul 15, 2024
Summary:
Pull Request resolved: pytorch#2845

The input activation can be 3D+, with first dimension as batch dimension.

- 1. If the input tensor is {M, K}, the output tensor is {M, N}.
- 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.

This Diff adds the support, to match nn.Linear function / matmul supports.

Differential Revision: D59671644
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59671644

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59671644

jianyuh added a commit to jianyuh/FBGEMM that referenced this pull request Jul 15, 2024
Summary:
Pull Request resolved: pytorch#2845

The input activation can be 3D+, with first dimension as batch dimension.

- 1. If the input tensor is {M, K}, the output tensor is {M, N}.
- 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.

This Diff adds the support, to match nn.Linear function / matmul supports.

Differential Revision: D59671644
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59671644

jianyuh added a commit to jianyuh/FBGEMM that referenced this pull request Jul 16, 2024
Summary:
Pull Request resolved: pytorch#2845

The input activation can be 3D+, with first dimension as batch dimension.

- 1. If the input tensor is {M, K}, the output tensor is {M, N}.
- 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.

This Diff adds the support, to match nn.Linear function / matmul supports.

Reviewed By: sijiac, jiawenliu64, jwfromm

Differential Revision: D59671644
Summary:
Pull Request resolved: pytorch#2845

The input activation can be 3D+, with first dimension as batch dimension.

- 1. If the input tensor is {M, K}, the output tensor is {M, N}.
- 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.

This Diff adds the support, to match nn.Linear function / matmul supports.

Additionally, this Diff adds a FP8 architecture check:

```

functools.lru_cache
def check_fp8_arch() -> None:
    arch_major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
    arch_minor = torch.cuda.get_device_properties(torch.cuda.current_device()).minor
    if torch.version.cuda and arch_major < 9:
        raise Exception("FP8 can only work on Nvidia H100+ GPUs with sm90+ support!")
    if torch.version.hip and (arch_major < 9 or arch_minor < 4):
        raise Exception(
            "FP8 can only work on Nvidia MI300x+ GPUs with gfx942+ support!"
        )
```

Reviewed By: sijiac, jiawenliu64, jwfromm

Differential Revision: D59671644
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D59671644

@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 903e928.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants