Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype using const generic for simd_shuffle IDX array #115933

Merged
merged 2 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 49 additions & 1 deletion compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn report_simd_type_validation_error(
pub(super) fn codegen_simd_intrinsic_call<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
intrinsic: Symbol,
_args: GenericArgsRef<'tcx>,
generic_args: GenericArgsRef<'tcx>,
args: &[mir::Operand<'tcx>],
ret: CPlace<'tcx>,
target: BasicBlock,
Expand Down Expand Up @@ -117,6 +117,54 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
});
}

// simd_shuffle_generic<T, U, const I: &[u32]>(x: T, y: T) -> U
sym::simd_shuffle_generic => {
let [x, y] = args else {
bug!("wrong number of args for intrinsic {intrinsic}");
};
let x = codegen_operand(fx, x);
let y = codegen_operand(fx, y);

if !x.layout().ty.is_simd() {
report_simd_type_validation_error(fx, intrinsic, span, x.layout().ty);
return;
}

let idx = generic_args[2]
.expect_const()
.eval(fx.tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();

assert_eq!(x.layout(), y.layout());
let layout = x.layout();

let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);

assert_eq!(lane_ty, ret_lane_ty);
assert_eq!(idx.len() as u64, ret_lane_count);

let total_len = lane_count * 2;

let indexes =
idx.iter().map(|idx| idx.unwrap_leaf().try_to_u16().unwrap()).collect::<Vec<u16>>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have used u32. Will fix in https://github.com/bjorn3/rustc_codegen_cranelift


for &idx in &indexes {
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
}

for (out_idx, in_idx) in indexes.into_iter().enumerate() {
let in_lane = if u64::from(in_idx) < lane_count {
x.value_lane(fx, in_idx.into())
} else {
y.value_lane(fx, u64::from(in_idx) - lane_count)
};
let out_lane = ret.place_lane(fx, u64::try_from(out_idx).unwrap());
out_lane.write_cvalue(fx, in_lane);
}
}

// simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U
sym::simd_shuffle => {
let (x, y, idx) = match args {
Expand Down
57 changes: 55 additions & 2 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::{bug, span_bug};
use rustc_span::{sym, symbol::kw, Span, Symbol};
use rustc_target::abi::{self, Align, HasDataLayout, Primitive};
Expand Down Expand Up @@ -376,7 +376,9 @@ impl<'ll, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}

_ if name.as_str().starts_with("simd_") => {
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
match generic_simd_intrinsic(
self, name, callee_ty, fn_args, args, ret_ty, llret_ty, span,
) {
Ok(llval) => llval,
Err(()) => return,
}
Expand Down Expand Up @@ -911,6 +913,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,
callee_ty: Ty<'tcx>,
fn_args: GenericArgsRef<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
ret_ty: Ty<'tcx>,
llret_ty: &'ll Type,
Expand Down Expand Up @@ -1030,6 +1033,56 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
));
}

if name == sym::simd_shuffle_generic {
let idx = fn_args[2]
.expect_const()
.eval(tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();
let n = idx.len() as u64;

require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(ret_len, ret_elem) would be more consistent with the in/out variables, I think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's copy paste from the simd_shuffle impl, but yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are 8 occurences now (there were 7) of this same code. We should probably deduplicate something here and reconsider the names, but not in this PR ^^

require!(
out_len == n,
InvalidMonomorphization::ReturnLength { span, name, in_len: n, ret_ty, out_len }
);
require!(
in_elem == out_ty,
InvalidMonomorphization::ReturnElement { span, name, in_elem, in_ty, ret_ty, out_ty }
);

let total_len = in_len * 2;

let indices: Option<Vec<_>> = idx
.iter()
.enumerate()
.map(|(arg_idx, val)| {
let idx = val.unwrap_leaf().try_to_i32().unwrap();
if idx >= i32::try_from(total_len).unwrap() {
bx.sess().emit_err(InvalidMonomorphization::ShuffleIndexOutOfBounds {
span,
name,
arg_idx: arg_idx as u64,
total_len: total_len.into(),
});
None
} else {
Some(bx.const_i32(idx))
}
})
.collect();
let Some(indices) = indices else {
return Ok(bx.const_null(llret_ty));
};

return Ok(bx.shuffle_vector(
args[0].immediate(),
args[1].immediate(),
bx.const_vector(&indices),
));
}

if name == sym::simd_shuffle {
// Make sure this is actually an array, since typeck only checks the length-suffixed
// version of this intrinsic.
Expand Down
44 changes: 23 additions & 21 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fn equate_intrinsic_type<'tcx>(
it: &hir::ForeignItem<'_>,
n_tps: usize,
n_lts: usize,
n_cts: usize,
sig: ty::PolyFnSig<'tcx>,
) {
let (own_counts, span) = match &it.kind {
Expand Down Expand Up @@ -51,7 +52,7 @@ fn equate_intrinsic_type<'tcx>(

if gen_count_ok(own_counts.lifetimes, n_lts, "lifetime")
&& gen_count_ok(own_counts.types, n_tps, "type")
&& gen_count_ok(own_counts.consts, 0, "const")
&& gen_count_ok(own_counts.consts, n_cts, "const")
{
let fty = Ty::new_fn_ptr(tcx, sig);
let it_def_id = it.owner_id.def_id;
Expand Down Expand Up @@ -492,7 +493,7 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {
};
let sig = tcx.mk_fn_sig(inputs, output, false, unsafety, Abi::RustIntrinsic);
let sig = ty::Binder::bind_with_vars(sig, bound_vars);
equate_intrinsic_type(tcx, it, n_tps, n_lts, sig)
equate_intrinsic_type(tcx, it, n_tps, n_lts, 0, sig)
}

/// Type-check `extern "platform-intrinsic" { ... }` functions.
Expand All @@ -504,9 +505,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let name = it.ident.name;

let (n_tps, inputs, output) = match name {
let (n_tps, n_cts, inputs, output) = match name {
sym::simd_eq | sym::simd_ne | sym::simd_lt | sym::simd_le | sym::simd_gt | sym::simd_ge => {
(2, vec![param(0), param(0)], param(1))
(2, 0, vec![param(0), param(0)], param(1))
}
sym::simd_add
| sym::simd_sub
Expand All @@ -522,8 +523,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_fmax
| sym::simd_fpow
| sym::simd_saturating_add
| sym::simd_saturating_sub => (1, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, vec![param(0), param(1)], param(0)),
| sym::simd_saturating_sub => (1, 0, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::simd_neg
| sym::simd_bswap
| sym::simd_bitreverse
Expand All @@ -541,25 +542,25 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_ceil
| sym::simd_floor
| sym::simd_round
| sym::simd_trunc => (1, vec![param(0)], param(0)),
sym::simd_fpowi => (1, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, vec![param(0), tcx.types.u32], param(1)),
| sym::simd_trunc => (1, 0, vec![param(0)], param(0)),
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
sym::simd_cast
| sym::simd_as
| sym::simd_cast_ptr
| sym::simd_expose_addr
| sym::simd_from_exposed_addr => (2, vec![param(0)], param(1)),
sym::simd_bitmask => (2, vec![param(0)], param(1)),
| sym::simd_from_exposed_addr => (2, 0, vec![param(0)], param(1)),
sym::simd_bitmask => (2, 0, vec![param(0)], param(1)),
sym::simd_select | sym::simd_select_bitmask => {
(2, vec![param(0), param(1), param(1)], param(1))
(2, 0, vec![param(0), param(1), param(1)], param(1))
}
sym::simd_reduce_all | sym::simd_reduce_any => (1, vec![param(0)], tcx.types.bool),
sym::simd_reduce_all | sym::simd_reduce_any => (1, 0, vec![param(0)], tcx.types.bool),
sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => {
(2, vec![param(0), param(1)], param(1))
(2, 0, vec![param(0), param(1)], param(1))
}
sym::simd_reduce_add_unordered
| sym::simd_reduce_mul_unordered
Expand All @@ -569,8 +570,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_reduce_min
| sym::simd_reduce_max
| sym::simd_reduce_min_nanless
| sym::simd_reduce_max_nanless => (2, vec![param(0)], param(1)),
sym::simd_shuffle => (3, vec![param(0), param(0), param(1)], param(2)),
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
_ => {
let msg = format!("unrecognized platform-specific intrinsic function: `{name}`");
tcx.sess.struct_span_err(it.span, msg).emit();
Expand All @@ -580,5 +582,5 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let sig = tcx.mk_fn_sig(inputs, output, false, hir::Unsafety::Unsafe, Abi::PlatformIntrinsic);
let sig = ty::Binder::dummy(sig);
equate_intrinsic_type(tcx, it, n_tps, 0, sig)
equate_intrinsic_type(tcx, it, n_tps, 0, n_cts, sig)
}
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ symbols! {
simd_shl,
simd_shr,
simd_shuffle,
simd_shuffle_generic,
simd_sub,
simd_trunc,
simd_xor,
Expand Down
5 changes: 3 additions & 2 deletions src/tools/miri/src/shims/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}

// The rest jumps to `ret` immediately.
this.emulate_intrinsic_by_name(intrinsic_name, args, dest)?;
this.emulate_intrinsic_by_name(intrinsic_name, instance.args, args, dest)?;

trace!("{:?}", this.dump_place(dest));
this.go_to_block(ret);
Expand All @@ -71,6 +71,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_intrinsic_by_name(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand All @@ -80,7 +81,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
return this.emulate_atomic_intrinsic(name, args, dest);
}
if let Some(name) = intrinsic_name.strip_prefix("simd_") {
return this.emulate_simd_intrinsic(name, args, dest);
return this.emulate_simd_intrinsic(name, generic_args, args, dest);
}

match intrinsic_name {
Expand Down
33 changes: 33 additions & 0 deletions src/tools/miri/src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_simd_intrinsic(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand Down Expand Up @@ -490,6 +491,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.write_immediate(val, &dest)?;
}
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

let index = generic_args[2].expect_const().eval(*this.tcx, this.param_env(), Some(this.tcx.span)).unwrap().unwrap_branch();
let index_len = index.len();

assert_eq!(left_len, right_len);
assert_eq!(index_len as u64, dest_len);

for i in 0..dest_len {
let src_index: u64 = index[i as usize].unwrap_leaf()
.try_to_u32().unwrap()
.into();
let dest = this.project_index(&dest, i)?;

let val = if src_index < left_len {
this.read_immediate(&this.project_index(&left, src_index)?)?
} else if src_index < left_len.checked_add(right_len).unwrap() {
let right_idx = src_index.checked_sub(left_len).unwrap();
this.read_immediate(&this.project_index(&right, right_idx)?)?
} else {
span_bug!(
this.cur_span(),
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
);
};
this.write_immediate(*val, &dest)?;
}
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
Expand Down
28 changes: 27 additions & 1 deletion tests/ui/simd/intrinsic/generic-elements.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// build-fail

#![feature(repr_simd, platform_intrinsics, rustc_attrs)]
#![feature(repr_simd, platform_intrinsics, rustc_attrs, adt_const_params)]
#![allow(incomplete_features)]

#[repr(simd)]
#[derive(Copy, Clone)]
Expand Down Expand Up @@ -35,6 +36,7 @@ extern "platform-intrinsic" {
fn simd_extract<T, E>(x: T, idx: u32) -> E;

fn simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U;
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
}

fn main() {
Expand Down Expand Up @@ -71,5 +73,29 @@ fn main() {
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
simd_shuffle::<_, _, i32x2>(x, x, IDX8);
//~^ ERROR expected return type of length 8, found `i32x2` with length 2

const I2: &[u32] = &[0; 2];
simd_shuffle_generic::<i32, i32, I2>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should really say "scalar" in this case, but... probably not a real concern.

const I4: &[u32] = &[0; 4];
simd_shuffle_generic::<i32, i32, I4>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
const I8: &[u32] = &[0; 8];
simd_shuffle_generic::<i32, i32, I8>(0, 0);
//~^ ERROR expected SIMD input type, found non-SIMD `i32`

simd_shuffle_generic::<_, f32x2, I2>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x2` with element type `f32`
simd_shuffle_generic::<_, f32x4, I4>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x4` with element type `f32`
simd_shuffle_generic::<_, f32x8, I8>(x, x);
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x8` with element type `f32`

simd_shuffle_generic::<_, i32x8, I2>(x, x);
//~^ ERROR expected return type of length 2, found `i32x8` with length 8
simd_shuffle_generic::<_, i32x8, I4>(x, x);
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
simd_shuffle_generic::<_, i32x2, I8>(x, x);
//~^ ERROR expected return type of length 8, found `i32x2` with length 2
Comment on lines +98 to +99
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong, it shouldn't error? Given:

  • lhs TxN vector
  • rhs TxN vector
  • i1xM mask vector

the shufflevector instruction produces a vector of type T but size M, and M > N is legitimate (especially in the case of "unifying two vectors where M == N*2").

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I8 has length 8, the return type has length 2, so M here would be both 2 and 8 -- and hence it should error, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I see. apologies, yes, that's correct. Hmm that's... annoying, but anything more fancy is probably too painful to implement at this level of the type system.

}
}