Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SIMD type element count validation #80652

Merged
merged 5 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 0 additions & 117 deletions compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,25 +498,6 @@ impl CodegenCx<'b, 'tcx> {
let t_f32 = self.type_f32();
let t_f64 = self.type_f64();

macro_rules! vector_types {
($id_out:ident: $elem_ty:ident, $len:expr) => {
let $id_out = self.type_vector($elem_ty, $len);
};
($($id_out:ident: $elem_ty:ident, $len:expr;)*) => {
$(vector_types!($id_out: $elem_ty, $len);)*
}
}
vector_types! {
t_v2f32: t_f32, 2;
t_v4f32: t_f32, 4;
t_v8f32: t_f32, 8;
t_v16f32: t_f32, 16;

t_v2f64: t_f64, 2;
t_v4f64: t_f64, 4;
t_v8f64: t_f64, 8;
}

ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f32", fn(t_f32) -> t_i32);
ifn!("llvm.wasm.trunc.saturate.unsigned.i32.f64", fn(t_f64) -> t_i32);
ifn!("llvm.wasm.trunc.saturate.unsigned.i64.f32", fn(t_f32) -> t_i64);
Expand All @@ -540,149 +521,51 @@ impl CodegenCx<'b, 'tcx> {
ifn!("llvm.sideeffect", fn() -> void);

ifn!("llvm.powi.f32", fn(t_f32, t_i32) -> t_f32);
ifn!("llvm.powi.v2f32", fn(t_v2f32, t_i32) -> t_v2f32);
ifn!("llvm.powi.v4f32", fn(t_v4f32, t_i32) -> t_v4f32);
ifn!("llvm.powi.v8f32", fn(t_v8f32, t_i32) -> t_v8f32);
ifn!("llvm.powi.v16f32", fn(t_v16f32, t_i32) -> t_v16f32);
ifn!("llvm.powi.f64", fn(t_f64, t_i32) -> t_f64);
ifn!("llvm.powi.v2f64", fn(t_v2f64, t_i32) -> t_v2f64);
ifn!("llvm.powi.v4f64", fn(t_v4f64, t_i32) -> t_v4f64);
ifn!("llvm.powi.v8f64", fn(t_v8f64, t_i32) -> t_v8f64);

ifn!("llvm.pow.f32", fn(t_f32, t_f32) -> t_f32);
ifn!("llvm.pow.v2f32", fn(t_v2f32, t_v2f32) -> t_v2f32);
ifn!("llvm.pow.v4f32", fn(t_v4f32, t_v4f32) -> t_v4f32);
ifn!("llvm.pow.v8f32", fn(t_v8f32, t_v8f32) -> t_v8f32);
ifn!("llvm.pow.v16f32", fn(t_v16f32, t_v16f32) -> t_v16f32);
ifn!("llvm.pow.f64", fn(t_f64, t_f64) -> t_f64);
ifn!("llvm.pow.v2f64", fn(t_v2f64, t_v2f64) -> t_v2f64);
ifn!("llvm.pow.v4f64", fn(t_v4f64, t_v4f64) -> t_v4f64);
ifn!("llvm.pow.v8f64", fn(t_v8f64, t_v8f64) -> t_v8f64);

ifn!("llvm.sqrt.f32", fn(t_f32) -> t_f32);
ifn!("llvm.sqrt.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.sqrt.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.sqrt.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.sqrt.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.sqrt.f64", fn(t_f64) -> t_f64);
ifn!("llvm.sqrt.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.sqrt.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.sqrt.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.sin.f32", fn(t_f32) -> t_f32);
ifn!("llvm.sin.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.sin.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.sin.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.sin.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.sin.f64", fn(t_f64) -> t_f64);
ifn!("llvm.sin.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.sin.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.sin.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.cos.f32", fn(t_f32) -> t_f32);
ifn!("llvm.cos.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.cos.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.cos.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.cos.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.cos.f64", fn(t_f64) -> t_f64);
ifn!("llvm.cos.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.cos.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.cos.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.exp.f32", fn(t_f32) -> t_f32);
ifn!("llvm.exp.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.exp.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.exp.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.exp.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.exp.f64", fn(t_f64) -> t_f64);
ifn!("llvm.exp.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.exp.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.exp.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.exp2.f32", fn(t_f32) -> t_f32);
ifn!("llvm.exp2.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.exp2.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.exp2.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.exp2.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.exp2.f64", fn(t_f64) -> t_f64);
ifn!("llvm.exp2.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.exp2.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.exp2.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.log.f32", fn(t_f32) -> t_f32);
ifn!("llvm.log.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.log.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.log.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.log.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.log.f64", fn(t_f64) -> t_f64);
ifn!("llvm.log.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.log.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.log.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.log10.f32", fn(t_f32) -> t_f32);
ifn!("llvm.log10.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.log10.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.log10.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.log10.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.log10.f64", fn(t_f64) -> t_f64);
ifn!("llvm.log10.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.log10.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.log10.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.log2.f32", fn(t_f32) -> t_f32);
ifn!("llvm.log2.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.log2.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.log2.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.log2.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.log2.f64", fn(t_f64) -> t_f64);
ifn!("llvm.log2.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.log2.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.log2.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.fma.f32", fn(t_f32, t_f32, t_f32) -> t_f32);
ifn!("llvm.fma.v2f32", fn(t_v2f32, t_v2f32, t_v2f32) -> t_v2f32);
ifn!("llvm.fma.v4f32", fn(t_v4f32, t_v4f32, t_v4f32) -> t_v4f32);
ifn!("llvm.fma.v8f32", fn(t_v8f32, t_v8f32, t_v8f32) -> t_v8f32);
ifn!("llvm.fma.v16f32", fn(t_v16f32, t_v16f32, t_v16f32) -> t_v16f32);
ifn!("llvm.fma.f64", fn(t_f64, t_f64, t_f64) -> t_f64);
ifn!("llvm.fma.v2f64", fn(t_v2f64, t_v2f64, t_v2f64) -> t_v2f64);
ifn!("llvm.fma.v4f64", fn(t_v4f64, t_v4f64, t_v4f64) -> t_v4f64);
ifn!("llvm.fma.v8f64", fn(t_v8f64, t_v8f64, t_v8f64) -> t_v8f64);

ifn!("llvm.fabs.f32", fn(t_f32) -> t_f32);
ifn!("llvm.fabs.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.fabs.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.fabs.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.fabs.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.fabs.f64", fn(t_f64) -> t_f64);
ifn!("llvm.fabs.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.fabs.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.fabs.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.minnum.f32", fn(t_f32, t_f32) -> t_f32);
ifn!("llvm.minnum.f64", fn(t_f64, t_f64) -> t_f64);
ifn!("llvm.maxnum.f32", fn(t_f32, t_f32) -> t_f32);
ifn!("llvm.maxnum.f64", fn(t_f64, t_f64) -> t_f64);

ifn!("llvm.floor.f32", fn(t_f32) -> t_f32);
ifn!("llvm.floor.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.floor.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.floor.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.floor.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.floor.f64", fn(t_f64) -> t_f64);
ifn!("llvm.floor.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.floor.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.floor.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.ceil.f32", fn(t_f32) -> t_f32);
ifn!("llvm.ceil.v2f32", fn(t_v2f32) -> t_v2f32);
ifn!("llvm.ceil.v4f32", fn(t_v4f32) -> t_v4f32);
ifn!("llvm.ceil.v8f32", fn(t_v8f32) -> t_v8f32);
ifn!("llvm.ceil.v16f32", fn(t_v16f32) -> t_v16f32);
ifn!("llvm.ceil.f64", fn(t_f64) -> t_f64);
ifn!("llvm.ceil.v2f64", fn(t_v2f64) -> t_v2f64);
ifn!("llvm.ceil.v4f64", fn(t_v4f64) -> t_v4f64);
ifn!("llvm.ceil.v8f64", fn(t_v8f64) -> t_v8f64);

ifn!("llvm.trunc.f32", fn(t_f32) -> t_f32);
ifn!("llvm.trunc.f64", fn(t_f64) -> t_f64);
Expand Down
133 changes: 55 additions & 78 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ fn generic_simd_intrinsic(
}

fn simd_simple_float_intrinsic(
name: &str,
name: Symbol,
in_elem: &::rustc_middle::ty::TyS<'_>,
in_ty: &::rustc_middle::ty::TyS<'_>,
in_len: u64,
Expand All @@ -1036,93 +1036,70 @@ fn generic_simd_intrinsic(
}
}
}
let ety = match in_elem.kind() {
ty::Float(f) if f.bit_width() == 32 => {
if in_len < 2 || in_len > 16 {
return_error!(
"unsupported floating-point vector `{}` with length `{}` \
out-of-range [2, 16]",
in_ty,
in_len
);
}
"f32"
}
ty::Float(f) if f.bit_width() == 64 => {
if in_len < 2 || in_len > 8 {

let (elem_ty_str, elem_ty) = if let ty::Float(f) = in_elem.kind() {
let elem_ty = bx.cx.type_float_from_ty(*f);
match f.bit_width() {
32 => ("f32", elem_ty),
64 => ("f64", elem_ty),
_ => {
return_error!(
"unsupported floating-point vector `{}` with length `{}` \
out-of-range [2, 8]",
in_ty,
in_len
"unsupported element type `{}` of floating-point vector `{}`",
f.name_str(),
in_ty
);
}
"f64"
}
ty::Float(f) => {
return_error!(
"unsupported element type `{}` of floating-point vector `{}`",
f.name_str(),
in_ty
);
}
_ => {
return_error!("`{}` is not a floating-point type", in_ty);
}
} else {
return_error!("`{}` is not a floating-point type", in_ty);
};

let llvm_name = &format!("llvm.{0}.v{1}{2}", name, in_len, ety);
let intrinsic = bx.get_intrinsic(&llvm_name);
let c =
bx.call(intrinsic, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None);
let vec_ty = bx.type_vector(elem_ty, in_len);

let (intr_name, fn_ty) = match name {
sym::simd_fsqrt => ("sqrt", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fsin => ("sin", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fcos => ("cos", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fabs => ("fabs", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_floor => ("floor", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_ceil => ("ceil", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fexp => ("exp", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fexp2 => ("exp2", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_flog10 => ("log10", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_flog2 => ("log2", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_flog => ("log", bx.type_func(&[vec_ty], vec_ty)),
sym::simd_fpowi => ("powi", bx.type_func(&[vec_ty, bx.type_i32()], vec_ty)),
sym::simd_fpow => ("pow", bx.type_func(&[vec_ty, vec_ty], vec_ty)),
sym::simd_fma => ("fma", bx.type_func(&[vec_ty, vec_ty, vec_ty], vec_ty)),
_ => return_error!("unrecognized intrinsic `{}`", name),
};

let llvm_name = &format!("llvm.{0}.v{1}{2}", intr_name, in_len, elem_ty_str);
let f = bx.declare_cfn(&llvm_name, fn_ty);
llvm::SetUnnamedAddress(f, llvm::UnnamedAddr::No);
calebzulawski marked this conversation as resolved.
Show resolved Hide resolved
let c = bx.call(f, &args.iter().map(|arg| arg.immediate()).collect::<Vec<_>>(), None);
unsafe { llvm::LLVMRustSetHasUnsafeAlgebra(c) };
Ok(c)
}

match name {
sym::simd_fsqrt => {
return simd_simple_float_intrinsic("sqrt", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fsin => {
return simd_simple_float_intrinsic("sin", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fcos => {
return simd_simple_float_intrinsic("cos", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fabs => {
return simd_simple_float_intrinsic("fabs", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_floor => {
return simd_simple_float_intrinsic("floor", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_ceil => {
return simd_simple_float_intrinsic("ceil", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fexp => {
return simd_simple_float_intrinsic("exp", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fexp2 => {
return simd_simple_float_intrinsic("exp2", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_flog10 => {
return simd_simple_float_intrinsic("log10", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_flog2 => {
return simd_simple_float_intrinsic("log2", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_flog => {
return simd_simple_float_intrinsic("log", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fpowi => {
return simd_simple_float_intrinsic("powi", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fpow => {
return simd_simple_float_intrinsic("pow", in_elem, in_ty, in_len, bx, span, args);
}
sym::simd_fma => {
return simd_simple_float_intrinsic("fma", in_elem, in_ty, in_len, bx, span, args);
}
_ => { /* fallthrough */ }
if std::matches!(
name,
sym::simd_fsqrt
| sym::simd_fsin
| sym::simd_fcos
| sym::simd_fabs
| sym::simd_floor
| sym::simd_ceil
| sym::simd_fexp
| sym::simd_fexp2
| sym::simd_flog10
| sym::simd_flog2
| sym::simd_flog
| sym::simd_fpowi
| sym::simd_fpow
| sym::simd_fma
) {
return simd_simple_float_intrinsic(name, in_elem, in_ty, in_len, bx, span, args);
}

// FIXME: use:
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,10 +694,17 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> {
};

// SIMD vectors of zero length are not supported.
// Additionally, lengths are capped at 2^16 as a fixed maximum backends must
// support.
//
// Can't be caught in typeck if the array length is generic.
if e_len == 0 {
tcx.sess.fatal(&format!("monomorphising SIMD type `{}` of zero length", ty));
} else if e_len > 65536 {
calebzulawski marked this conversation as resolved.
Show resolved Hide resolved
tcx.sess.fatal(&format!(
"monomorphising SIMD type `{}` of length greater than 65536",
ty,
));
}

// Compute the ABI of the element type:
Expand Down
22 changes: 22 additions & 0 deletions compiler/rustc_typeck/src/check/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,28 @@ pub fn check_simd(tcx: TyCtxt<'_>, sp: Span, def_id: LocalDefId) {
.emit();
return;
}

let len = if let ty::Array(_ty, c) = e.kind() {
c.try_eval_usize(tcx, tcx.param_env(def.did))
} else {
Some(fields.len() as u64)
};
if let Some(len) = len {
if len == 0 {
struct_span_err!(tcx.sess, sp, E0075, "SIMD vector cannot be empty").emit();
return;
} else if len > 65536 {
calebzulawski marked this conversation as resolved.
Show resolved Hide resolved
struct_span_err!(
tcx.sess,
sp,
E0075,
"SIMD vector cannot have more than 65536 elements"
)
.emit();
return;
}
}

match e.kind() {
ty::Param(_) => { /* struct<T>(T, T, T, T) is ok */ }
_ if e.is_machine() => { /* struct(u8, u8, u8, u8) is ok */ }
Expand Down
12 changes: 12 additions & 0 deletions src/test/ui/simd-type-generic-monomorphisation-empty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// build-fail

#![feature(repr_simd, platform_intrinsics)]

// error-pattern:monomorphising SIMD type `Simd<0_usize>` of zero length

#[repr(simd)]
struct Simd<const N: usize>([f32; N]);

fn main() {
let _ = Simd::<0>([]);
}
4 changes: 4 additions & 0 deletions src/test/ui/simd-type-generic-monomorphisation-empty.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
error: monomorphising SIMD type `Simd<0_usize>` of zero length

error: aborting due to previous error

Loading