
[Feature Request] Add amx detection in cpuinfo #224

Open
mingfeima opened this issue Feb 20, 2024 · 7 comments

Comments

@mingfeima
Contributor

This proposal is to add AMX detection to cpuinfo. AMX refers to Intel® Advanced Matrix Extensions (Intel® AMX): https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html

Something like:

cpuinfo_has_x86_amxbf16()
cpuinfo_has_x86_amxint8()

Once this is settled, we can also switch the AMX check in convolution from torch/aten to cpuinfo. Right now the check is done inside oneDNN via: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mkldnn/Utils.h#L99
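For reference, a minimal sketch of the CPU-side half of such a check, assuming GCC or Clang (`<cpuid.h>` and its `__get_cpuid_count`): AMX-BF16, AMX-TILE, and AMX-INT8 are enumerated in CPUID leaf 7, subleaf 0, EDX bits 22, 24, and 25. The `cpu_has_amx_*` helpers are hypothetical stand-ins, not cpuinfo's actual implementation, and they only report what the CPU implements, not whether the OS has enabled AMX state.

```c
#include <stdbool.h>
#include <stdint.h>

#if defined(__i386__) || defined(__x86_64__)
#include <cpuid.h>
#endif

/* Bit positions in CPUID.(EAX=07H, ECX=0):EDX (Intel SDM). */
#define AMX_BF16_BIT 22
#define AMX_TILE_BIT 24
#define AMX_INT8_BIT 25

/* Returns true if bit `bit` is set in `reg`. */
static bool reg_has_bit(uint32_t reg, unsigned bit) {
  return ((reg >> bit) & 1u) != 0;
}

/* Reads EDX of CPUID leaf 7, subleaf 0; returns 0 if unavailable. */
static uint32_t cpuid7_edx(void) {
#if defined(__i386__) || defined(__x86_64__)
  unsigned int eax, ebx, ecx, edx;
  /* __get_cpuid_count returns 0 if leaf 7 exceeds the max supported leaf. */
  if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
    return edx;
  }
#endif
  return 0; /* non-x86 or leaf 7 not supported */
}

/* Hypothetical helpers mirroring the proposed cpuinfo_has_x86_amx*() API.
 * CPU support only: OS support (XCR0) must be checked separately. */
static bool cpu_has_amx_tile(void) { return reg_has_bit(cpuid7_edx(), AMX_TILE_BIT); }
static bool cpu_has_amx_bf16(void) { return reg_has_bit(cpuid7_edx(), AMX_BF16_BIT); }
static bool cpu_has_amx_int8(void) { return reg_has_bit(cpuid7_edx(), AMX_INT8_BIT); }
```

A real cpuinfo query would presumably combine these CPUID bits with an OS-support check on XCR0 before reporting AMX as usable.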

@fbarchard
Contributor

fbarchard commented Feb 29, 2024

Note that the following code, which checks for Linux kernel support, does not work inside the Chromium sandbox:

#if (defined(__i386__) || defined(__x86_64__)) && defined(__linux__)
#include <stdbool.h>
#include <sys/syscall.h>
#include <unistd.h>

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA 18

/* SetTileDataUse() - Request permission to use AMX tile data (XTILEDATA)
 * via arch_prctl(ARCH_REQ_XCOMP_PERM). Returns true on success. */
static bool SetTileDataUse(void) {
  return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}
#endif

Is there another way to test for OS support?
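One possible answer, sketched under assumptions: the OS reports which XSAVE state components it manages in XCR0, which user code can read with XGETBV (valid only when CPUID.1:ECX.OSXSAVE, bit 27, is set). Bits 17 (XTILECFG) and 18 (XTILEDATA) cover the AMX state. XGETBV is an instruction rather than a syscall, so a seccomp-style filter should not block it, although that has not been verified against the Chromium sandbox here. On Linux this only shows that the kernel supports AMX; tile data stays off for the process until the arch_prctl permission request succeeds. `read_xcr0()` and `os_supports_amx_state()` are hypothetical names, and the sketch assumes GCC/Clang inline assembly.

```c
#include <stdbool.h>
#include <stdint.h>

#if defined(__i386__) || defined(__x86_64__)
#include <cpuid.h>
#endif

/* XCR0 bits for the AMX state components (Intel SDM, XSAVE feature set). */
#define XCR0_XTILECFG_BIT 17
#define XCR0_XTILEDATA_BIT 18

/* True if both AMX state components are enabled in the given XCR0 value. */
static bool xcr0_has_amx(uint64_t xcr0) {
  const uint64_t mask = (1ull << XCR0_XTILECFG_BIT) | (1ull << XCR0_XTILEDATA_BIT);
  return (xcr0 & mask) == mask;
}

/* Reads XCR0 with XGETBV; returns 0 if the OS has not enabled XSAVE. */
static uint64_t read_xcr0(void) {
#if defined(__i386__) || defined(__x86_64__)
  unsigned int eax, ebx, ecx, edx;
  /* XGETBV raises #UD unless CPUID.1:ECX.OSXSAVE (bit 27) is set. */
  if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx) || !(ecx & (1u << 27))) {
    return 0;
  }
  unsigned int lo, hi;
  /* XGETBV with ECX = 0 returns XCR0 in EDX:EAX. */
  __asm__ volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));
  return ((uint64_t)hi << 32) | lo;
#else
  return 0; /* not x86 */
#endif
}

/* OS-side check only: no syscall, and it does not grant any permission. */
static bool os_supports_amx_state(void) { return xcr0_has_amx(read_xcr0()); }
```

Whether Windows and other OSes expose AMX support the same way through XCR0 is an assumption worth checking separately.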

@mingfeima
Contributor Author


@malfet Does this repo have CI to test PRs?

@mingfeima
Contributor Author

#231

@fbarchard
Contributor

There are a few more issues with AMX detection, but I'm not sure they are in the scope of pytorch/cpuinfo:

- Detect OS support for AMX on Windows, Linux, etc.
- Detect/enable AMX within the Chromium sandbox (syscalls are not allowed there).
- Enable AMX: the syscall is currently required before AMX can be used.

I assume the reason AMX is disabled by default is its high cost on thread switches, so it would be good to enable AMX only once we actually know we'll be using it. That is, AMX detection should be decoupled from enabling it.
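A minimal sketch of that decoupling, assuming Linux: detection stays a pure query, and the permission request happens once, lazily, right before the first AMX kernel runs. `amx_enable_once()` and `amx_request_permission()` are hypothetical names; `ARCH_GET_XCOMP_PERM` (0x1022) is the companion query from the same Linux arch_prctl interface and is used here only to confirm the grant.

```c
#ifndef _GNU_SOURCE
#define _GNU_SOURCE /* for syscall() in <unistd.h> */
#endif
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>

#if defined(__linux__) && (defined(__i386__) || defined(__x86_64__))
#include <sys/syscall.h>
#include <unistd.h>
#endif

/* Values from the Linux x86 uapi <asm/prctl.h> and XSAVE feature numbering. */
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA 18

/* 0 = not attempted yet, 1 = enabled, -1 = unavailable. */
static atomic_int amx_state = 0;

/* Asks the kernel for permission to use AMX tile data in this process. */
static bool amx_request_permission(void) {
#if defined(__linux__) && (defined(__i386__) || defined(__x86_64__))
  uint64_t permitted = 0;
  if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) != 0) {
    return false; /* no AMX, old kernel, or syscall blocked (e.g. by seccomp) */
  }
  /* Confirm XTILEDATA is now in the process's permitted feature mask. */
  return syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &permitted) == 0 &&
         (permitted & (UINT64_C(1) << XFEATURE_XTILEDATA)) != 0;
#else
  return false;
#endif
}

/* Enables AMX on first use and caches the result. Call it only after
 * detection reports AMX, right before the first AMX kernel runs. */
static bool amx_enable_once(void) {
  int state = atomic_load(&amx_state);
  if (state == 0) {
    /* Racing threads may both issue the request; that is harmless because
     * the permission is process-wide and re-requesting it succeeds. */
    state = amx_request_permission() ? 1 : -1;
    atomic_store(&amx_state, state);
  }
  return state == 1;
}
```

Returning `false` instead of aborting lets a caller fall back to a non-AMX kernel when the request is refused, for example inside a sandbox that blocks the syscall.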

It's unclear whether this is intentional, but the AMX intrinsics header is only enabled for 64-bit x86, not 32-bit.
From what I can tell, though, the CPU and the compiler (via assembly) can use AMX on 32-bit OSes.

I think these may be beyond the scope of cpuinfo and/or not entirely solvable, so this issue can be closed.

@mingfeima
Contributor Author

@fbarchard I think enabling AMX is not in the scope of pytorch/cpuinfo. The issue is that PyTorch uses a detection method from oneDNN, which is inconsistent with other features (which use cpuinfo). We have received requests to remove the oneDNN AMX detection and replace it with cpuinfo AMX detection.

As you mentioned, enabling AMX currently requires a syscall, i.e. syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA).

Detection and enabling should be decoupled, as you said. Currently PyTorch uses AMX ONLY inside oneDNN, so we don't have to worry about initialization. But we are starting to use AMX intrinsics in some specific CPU kernels; a good example is int4packed_gemm, which is used for weight-only quantization in LLMs. It will look something like:

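// cpuinfo_initialize() must succeed before any feature query.
if (cpuinfo_initialize() && cpuinfo_has_x86_amx_bf16()) {
  // Request AMX tile-data permission before the first AMX kernel runs.
  // TORCH_CHECK throws on failure, so no explicit return is needed.
  TORCH_CHECK(
      syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0,
      "Failed to enable AMX on CPU.");
  int4packed_gemm_amx(...);
} else {
  // fall back to AVX-512F
  int4packed_gemm_avx512(...);
}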

@bbharti

bbharti commented Apr 11, 2024

@mingfeima I would suggest adding AMX enabling to pytorch/cpuinfo. It would help all the apps (and their users) that don't use oneDNN; tying it only to oneDNN is not the right approach. Maybe I didn't understand your response.

@mingfeima
Contributor Author


Sure, that's exactly our original plan :) We will replace all the platform checks currently implemented through oneDNN with cpuinfo.
