Skip to content

Commit 4cec11b

Browse files
committed
Add alignment parameter to simd_masked_{load,store}
1 parent 99ca0ae commit 4cec11b

File tree

17 files changed

+176
-80
lines changed

17 files changed

+176
-80
lines changed

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,10 +1829,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18291829
}
18301830

18311831
if name == sym::simd_masked_load {
1832-
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
1832+
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>, alignment: u32) -> <N x T>
18331833
// * N: number of elements in the input vectors
18341834
// * T: type of the element to load
18351835
// * M: any integer width is supported, will be truncated to i1
1836+
// * `alignment`: must be a power of two constant
18361837
// Loads contiguous elements from memory behind `pointer`, but only for
18371838
// those lanes whose `mask` bit is enabled.
18381839
// The memory addresses corresponding to the “off” lanes are not accessed.
@@ -1844,10 +1845,18 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18441845
// The second argument must be a pointer matching the element type
18451846
let pointer_ty = args[1].layout.ty;
18461847

1847-
// The last argument is a passthrough vector providing values for disabled lanes
1848+
// The third argument is a passthrough vector providing values for disabled lanes
18481849
let values_ty = args[2].layout.ty;
18491850
let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
18501851

1852+
// The fourth argument is the alignment, must be a power of two integer constant
1853+
let alignment = bx
1854+
.const_to_opt_u128(args[3].immediate(), false)
1855+
.expect("typeck should have ensure that this is a const");
1856+
if !alignment.is_power_of_two() {
1857+
return_error!(InvalidMonomorphization::AlignmentNotPowerOfTwo { span, name });
1858+
}
1859+
18511860
require_simd!(ret_ty, SimdReturn);
18521861

18531862
// Of the same length:
@@ -1893,7 +1902,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18931902
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
18941903

18951904
// Alignment of T, must be a constant integer value:
1896-
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
1905+
let alignment = bx.const_i32(alignment as i32);
18971906

18981907
let llvm_pointer = bx.type_ptr();
18991908

@@ -1908,10 +1917,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19081917
}
19091918

19101919
if name == sym::simd_masked_store {
1911-
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
1920+
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>, alignment: u32) -> ()
19121921
// * N: number of elements in the input vectors
19131922
// * T: type of the element to load
19141923
// * M: any integer width is supported, will be truncated to i1
1924+
// * `alignment`: must be a power of two constant
19151925
// Stores contiguous elements to memory behind `pointer`, but only for
19161926
// those lanes whose `mask` bit is enabled.
19171927
// The memory addresses corresponding to the “off” lanes are not accessed.
@@ -1923,10 +1933,18 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19231933
// The second argument must be a pointer matching the element type
19241934
let pointer_ty = args[1].layout.ty;
19251935

1926-
// The last argument specifies the values to store to memory
1936+
// The third argument specifies the values to store to memory
19271937
let values_ty = args[2].layout.ty;
19281938
let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
19291939

1940+
// The fourth argument is the alignment, must be a power of two integer constant
1941+
let alignment = bx
1942+
.const_to_opt_u128(args[3].immediate(), false)
1943+
.expect("typeck should have ensure that this is a const");
1944+
if !alignment.is_power_of_two() {
1945+
return_error!(InvalidMonomorphization::AlignmentNotPowerOfTwo { span, name });
1946+
}
1947+
19301948
// Of the same length:
19311949
require!(
19321950
values_len == mask_len,
@@ -1965,8 +1983,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19651983

19661984
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
19671985

1968-
// Alignment of T, must be a constant integer value:
1969-
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
1986+
let alignment = bx.const_i32(alignment as i32);
19701987

19711988
let llvm_pointer = bx.type_ptr();
19721989

compiler/rustc_codegen_ssa/messages.ftl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ codegen_ssa_invalid_monomorphization_unsupported_symbol = invalid monomorphizati
169169
170170
codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size = invalid monomorphization of `{$name}` intrinsic: unsupported {$symbol} from `{$in_ty}` with element `{$in_elem}` of size `{$size}` to `{$ret_ty}`
171171
172+
codegen_ssa_invalid_monomorphization_non_power_of_two_alignment = invalid monomorphization of `{$name}` intrinsic: `alignment` is not a power of two
173+
172174
codegen_ssa_invalid_windows_subsystem = invalid windows subsystem `{$subsystem}`, only `windows` and `console` are allowed
173175
174176
codegen_ssa_ld64_unimplemented_modifier = `as-needed` modifier not implemented yet for ld64

compiler/rustc_codegen_ssa/src/errors.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,13 @@ pub enum InvalidMonomorphization<'tcx> {
11001100
expected_element: Ty<'tcx>,
11011101
vector_type: Ty<'tcx>,
11021102
},
1103+
1104+
#[diag(codegen_ssa_invalid_monomorphization_non_power_of_two_alignment, code = E0511)]
1105+
AlignmentNotPowerOfTwo {
1106+
#[primary_span]
1107+
span: Span,
1108+
name: Symbol,
1109+
},
11031110
}
11041111

11051112
pub enum ExpectedPointerMutability {

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,12 @@ pub(crate) fn check_intrinsic_type(
695695
(1, 0, vec![param(0), param(0), param(0)], param(0))
696696
}
697697
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
698-
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
699-
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
698+
sym::simd_masked_load => {
699+
(3, 0, vec![param(0), param(1), param(2), tcx.types.u32], param(2))
700+
}
701+
sym::simd_masked_store => {
702+
(3, 0, vec![param(0), param(1), param(2), tcx.types.u32], tcx.types.unit)
703+
}
700704
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
701705
sym::simd_insert | sym::simd_insert_dyn => {
702706
(2, 0, vec![param(0), tcx.types.u32, param(1)], param(0))

library/core/src/intrinsics/simd.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,14 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
392392
/// `val`.
393393
///
394394
/// # Safety
395-
/// Unmasked values in `T` must be readable as if by `<ptr>::read` (e.g. aligned to the element
396-
/// type).
395+
/// `alignment` must be a **const** power of two, and specifies the alignment of `ptr`.
396+
///
397+
/// Unmasked values in `T` must be readable as if by `<ptr>::read_unaligned` (aligned to `alignment`).
397398
///
398399
/// `mask` must only contain `0` or `!0` values.
399400
#[rustc_intrinsic]
400401
#[rustc_nounwind]
401-
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
402+
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T, alignment: u32) -> T;
402403

403404
/// Writes to a vector of pointers.
404405
///
@@ -414,13 +415,14 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
414415
/// Otherwise if the corresponding value in `mask` is `0`, do nothing.
415416
///
416417
/// # Safety
417-
/// Unmasked values in `T` must be writeable as if by `<ptr>::write` (e.g. aligned to the element
418-
/// type).
418+
/// `alignment` must be a **const** power of two, and specifies the alignment of `ptr`.
419+
///
420+
/// Unmasked values in `T` must be writeable as if by `<ptr>::write_unaligned` (aligned to `alignment`).
419421
///
420422
/// `mask` must only contain `0` or `!0` values.
421423
#[rustc_intrinsic]
422424
#[rustc_nounwind]
423-
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T);
425+
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T, alignment: u32);
424426

425427
/// Adds two simd vectors elementwise, with saturation.
426428
///

library/portable-simd/crates/core_simd/src/vector.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,14 @@ where
474474
or: Self,
475475
) -> Self {
476476
// SAFETY: The safety of reading elements through `ptr` is ensured by the caller.
477-
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
477+
unsafe {
478+
core::intrinsics::simd::simd_masked_load(
479+
enable.to_int(),
480+
ptr,
481+
or,
482+
const { core::mem::align_of::<T>() as u32 },
483+
)
484+
}
478485
}
479486

480487
/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
@@ -723,7 +730,14 @@ where
723730
#[inline]
724731
pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
725732
// SAFETY: The safety of writing elements through `ptr` is ensured by the caller.
726-
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
733+
unsafe {
734+
core::intrinsics::simd::simd_masked_store(
735+
enable.to_int(),
736+
ptr,
737+
self,
738+
const { core::mem::align_of::<T>() as u32 },
739+
)
740+
}
727741
}
728742

729743
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.

src/tools/miri/src/intrinsics/simd.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
733733
}
734734
}
735735
"masked_load" => {
736-
let [mask, ptr, default] = check_intrinsic_arg_count(args)?;
736+
let [mask, ptr, default, _alignment] = check_intrinsic_arg_count(args)?;
737737
let (mask, mask_len) = this.project_to_simd(mask)?;
738738
let ptr = this.read_pointer(ptr)?;
739739
let (default, default_len) = this.project_to_simd(default)?;
@@ -759,7 +759,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
759759
}
760760
}
761761
"masked_store" => {
762-
let [mask, ptr, vals] = check_intrinsic_arg_count(args)?;
762+
let [mask, ptr, vals, _alignment] = check_intrinsic_arg_count(args)?;
763763
let (mask, mask_len) = this.project_to_simd(mask)?;
764764
let ptr = this.read_pointer(ptr)?;
765765
let (vals, vals_len) = this.project_to_simd(vals)?;

src/tools/miri/tests/pass/intrinsics/portable-simd.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -684,21 +684,21 @@ fn simd_masked_loadstore() {
684684
let buf = [3i32; 3];
685685
let default = i32x4::splat(0);
686686
let mask = i32x4::from_array([!0, !0, !0, 0]);
687-
let vals = unsafe { intrinsics::simd_masked_load(mask, buf.as_ptr(), default) };
687+
let vals = unsafe { intrinsics::simd_masked_load(mask, buf.as_ptr(), default, 4) };
688688
assert_eq!(vals, i32x4::from_array([3, 3, 3, 0]));
689689
// Also read in a way that the *first* element is OOB.
690690
let mask2 = i32x4::from_array([0, !0, !0, !0]);
691691
let vals =
692-
unsafe { intrinsics::simd_masked_load(mask2, buf.as_ptr().wrapping_sub(1), default) };
692+
unsafe { intrinsics::simd_masked_load(mask2, buf.as_ptr().wrapping_sub(1), default, 4) };
693693
assert_eq!(vals, i32x4::from_array([0, 3, 3, 3]));
694694

695695
// The buffer is deliberarely too short, so writing the last element would be UB.
696696
let mut buf = [42i32; 3];
697697
let vals = i32x4::from_array([1, 2, 3, 4]);
698-
unsafe { intrinsics::simd_masked_store(mask, buf.as_mut_ptr(), vals) };
698+
unsafe { intrinsics::simd_masked_store(mask, buf.as_mut_ptr(), vals, 4) };
699699
assert_eq!(buf, [1, 2, 3]);
700700
// Also write in a way that the *first* element is OOB.
701-
unsafe { intrinsics::simd_masked_store(mask2, buf.as_mut_ptr().wrapping_sub(1), vals) };
701+
unsafe { intrinsics::simd_masked_store(mask2, buf.as_mut_ptr().wrapping_sub(1), vals, 4) };
702702
assert_eq!(buf, [2, 3, 4]);
703703
}
704704

tests/assembly-llvm/simd-intrinsic-mask-load.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct f64x4([f64; 4]);
3535
pub struct m64x4([i64; 4]);
3636

3737
#[rustc_intrinsic]
38-
unsafe fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
38+
unsafe fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T, alignment: u32) -> T;
3939

4040
// CHECK-LABEL: load_i8x16
4141
#[no_mangle]
@@ -56,7 +56,7 @@ pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
5656
// x86-avx512-NOT: vpsllw
5757
// x86-avx512: vpmovb2m k1, xmm0
5858
// x86-avx512-NEXT: vmovdqu8 xmm0 {k1} {z}, xmmword ptr [rdi]
59-
simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
59+
simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 1)
6060
}
6161

6262
// CHECK-LABEL: load_f32x8
@@ -68,7 +68,12 @@ pub unsafe extern "C" fn load_f32x8(mask: m32x8, pointer: *const f32) -> f32x8 {
6868
// x86-avx512-NOT: vpslld
6969
// x86-avx512: vpmovd2m k1, ymm0
7070
// x86-avx512-NEXT: vmovups ymm0 {k1} {z}, ymmword ptr [rdi]
71-
simd_masked_load(mask, pointer, f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]))
71+
simd_masked_load(
72+
mask,
73+
pointer,
74+
f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]),
75+
4,
76+
)
7277
}
7378

7479
// CHECK-LABEL: load_f64x4
@@ -79,5 +84,5 @@ pub unsafe extern "C" fn load_f64x4(mask: m64x4, pointer: *const f64) -> f64x4 {
7984
//
8085
// x86-avx512-NOT: vpsllq
8186
// x86-avx512: vpmovq2m k1, ymm0
82-
simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]))
87+
simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]), 8)
8388
}

tests/assembly-llvm/simd-intrinsic-mask-store.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub struct f64x4([f64; 4]);
3535
pub struct m64x4([i64; 4]);
3636

3737
#[rustc_intrinsic]
38-
unsafe fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T);
38+
unsafe fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T, alignment: u32);
3939

4040
// CHECK-LABEL: store_i8x16
4141
#[no_mangle]
@@ -54,7 +54,7 @@ pub unsafe extern "C" fn store_i8x16(mask: m8x16, pointer: *mut i8, value: i8x16
5454
// x86-avx512-NOT: vpsllw
5555
// x86-avx512: vpmovb2m k1, xmm0
5656
// x86-avx512-NEXT: vmovdqu8 xmmword ptr [rdi] {k1}, xmm1
57-
simd_masked_store(mask, pointer, value)
57+
simd_masked_store(mask, pointer, value, 1)
5858
}
5959

6060
// CHECK-LABEL: store_f32x8
@@ -66,7 +66,7 @@ pub unsafe extern "C" fn store_f32x8(mask: m32x8, pointer: *mut f32, value: f32x
6666
// x86-avx512-NOT: vpslld
6767
// x86-avx512: vpmovd2m k1, ymm0
6868
// x86-avx512-NEXT: vmovups ymmword ptr [rdi] {k1}, ymm1
69-
simd_masked_store(mask, pointer, value)
69+
simd_masked_store(mask, pointer, value, 4)
7070
}
7171

7272
// CHECK-LABEL: store_f64x4
@@ -78,5 +78,5 @@ pub unsafe extern "C" fn store_f64x4(mask: m64x4, pointer: *mut f64, value: f64x
7878
// x86-avx512-NOT: vpsllq
7979
// x86-avx512: vpmovq2m k1, ymm0
8080
// x86-avx512-NEXT: vmovupd ymmword ptr [rdi] {k1}, ymm1
81-
simd_masked_store(mask, pointer, value)
81+
simd_masked_store(mask, pointer, value, 8)
8282
}

0 commit comments

Comments
 (0)