Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 193 additions & 15 deletions crates/core_arch/src/x86_64/amx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -271,7 +271,7 @@ pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -290,7 +290,7 @@ pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -309,7 +309,7 @@ pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -329,7 +329,7 @@ pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tileloaddrs, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -349,7 +349,7 @@ pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tileloaddrst1, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -372,7 +372,7 @@ pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usiz
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-tf32")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -389,7 +389,7 @@ pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowd2ps, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -404,7 +404,7 @@ pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -421,7 +421,7 @@ pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowps2phh, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -437,7 +437,7 @@ pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -454,7 +454,7 @@ pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowps2phl, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -470,7 +470,7 @@ pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -480,12 +480,78 @@ pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}

/// Converts one row of a tile register from packed single-precision (32-bit) floating-point
/// elements to packed BF16 (16-bit) floating-point elements and returns it in a zmm register.
/// Each converted 16-bit value is stored in the high 16 bits of its 32-bit element of the result.
///
/// `TILE` selects the source tile and is statically asserted to fit in 3 unsigned bits; the row
/// index is supplied at run time through `row`.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, not(target_vendor = "apple")),
    assert_instr(tcvtrowps2bf16h, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh {
    static_assert_uimm_bits!(TILE, 3);
    let converted = tcvtrowps2bf16h(TILE as i8, row);
    converted.as_m512bh()
}

/// Converts one row of a tile register from packed single-precision (32-bit) floating-point
/// elements to packed BF16 (16-bit) floating-point elements and returns it in a zmm register.
/// Each converted 16-bit value is stored in the high 16 bits of its 32-bit element of the result.
///
/// Both operands are compile-time constants: `TILE` selects the source tile and is statically
/// asserted to fit in 3 unsigned bits, and `ROW` selects the row and is statically asserted to
/// fit in 6 unsigned bits.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, not(target_vendor = "apple")),
    assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh {
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    let converted = tcvtrowps2bf16hi(TILE as i8, ROW as u32);
    converted.as_m512bh()
}

/// Converts one row of a tile register from packed single-precision (32-bit) floating-point
/// elements to packed BF16 (16-bit) floating-point elements and returns it in a zmm register.
/// Each converted 16-bit value is stored in the low 16 bits of its 32-bit element of the result.
///
/// `TILE` selects the source tile and is statically asserted to fit in 3 unsigned bits; the row
/// index is supplied at run time through `row`.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, not(target_vendor = "apple")),
    assert_instr(tcvtrowps2bf16l, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh {
    static_assert_uimm_bits!(TILE, 3);
    let converted = tcvtrowps2bf16l(TILE as i8, row);
    converted.as_m512bh()
}

/// Converts one row of a tile register from packed single-precision (32-bit) floating-point
/// elements to packed BF16 (16-bit) floating-point elements and returns it in a zmm register.
/// Each converted 16-bit value is stored in the low 16 bits of its 32-bit element of the result.
///
/// Both operands are compile-time constants: `TILE` selects the source tile and is statically
/// asserted to fit in 3 unsigned bits, and `ROW` selects the row and is statically asserted to
/// fit in 6 unsigned bits.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, not(target_vendor = "apple")),
    assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh {
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    let converted = tcvtrowps2bf16li(TILE as i8, ROW as u32);
    converted.as_m512bh()
}

/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tilemovrow, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand All @@ -499,7 +565,7 @@ pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
all(test, not(target_vendor = "apple")),
assert_instr(tilemovrow, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
Expand Down Expand Up @@ -567,6 +633,14 @@ unsafe extern "C" {
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2phli"]
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16h"]
fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16hi"]
fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16l"]
fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16li"]
fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tilemovrow"]
fn tilemovrow(tile: i8, row: u32) -> i32x16;
#[link_name = "llvm.x86.tilemovrowi"]
Expand Down Expand Up @@ -1276,6 +1350,110 @@ mod tests {
}
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16h() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix holds sixteen copies of `r as f32`.
        let matrix: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);

        // Tile 0: 16 rows of 64 bytes each, matching the matrix above.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.rows[0] = 16;
        cfg.colsb[0] = 64;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(matrix.as_ptr().cast(), 64);

        for r in 0..16 {
            let converted = _tile_cvtrowps2bf16h::<0>(r);
            // Every f32 in row `r` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(r as f32).to_bits();
            // High half: odd 16-bit lanes carry the BF16 value, even lanes are zero.
            let expected: [u16; 32] =
                array::from_fn(|lane| if lane % 2 == 1 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16hi() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix holds sixteen copies of `r as f32`.
        let matrix: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);

        // Tile 0: 16 rows of 64 bytes each, matching the matrix above.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.rows[0] = 16;
        cfg.colsb[0] = 64;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(matrix.as_ptr().cast(), 64);

        for r in 0..16 {
            // The row is a const generic here, so the macro dispatches on the run-time index.
            let converted = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, r);
            // Every f32 in row `r` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(r as f32).to_bits();
            // High half: odd 16-bit lanes carry the BF16 value, even lanes are zero.
            let expected: [u16; 32] =
                array::from_fn(|lane| if lane % 2 == 1 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16l() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix holds sixteen copies of `r as f32`.
        let matrix: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);

        // Tile 0: 16 rows of 64 bytes each, matching the matrix above.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.rows[0] = 16;
        cfg.colsb[0] = 64;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(matrix.as_ptr().cast(), 64);

        for r in 0..16 {
            let converted = _tile_cvtrowps2bf16l::<0>(r);
            // Every f32 in row `r` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(r as f32).to_bits();
            // Low half: even 16-bit lanes carry the BF16 value, odd lanes are zero.
            let expected: [u16; 32] =
                array::from_fn(|lane| if lane % 2 == 0 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16li() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix holds sixteen copies of `r as f32`.
        let matrix: [[f32; 16]; 16] = array::from_fn(|r| [r as f32; 16]);

        // Tile 0: 16 rows of 64 bytes each, matching the matrix above.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.rows[0] = 16;
        cfg.colsb[0] = 64;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(matrix.as_ptr().cast(), 64);

        for r in 0..16 {
            // The row is a const generic here, so the macro dispatches on the run-time index.
            let converted = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, r);
            // Every f32 in row `r` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(r as f32).to_bits();
            // Low half: even 16-bit lanes carry the BF16 value, odd lanes are zero.
            let expected: [u16; 32] =
                array::from_fn(|lane| if lane % 2 == 0 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

#[simd_test(enable = "amx-tf32")]
fn test_tile_mmultf32ps() {
unsafe {
Expand Down