Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change enum->int casts to not go through MIR casts. #96862

Merged
merged 4 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,29 +635,6 @@ fn codegen_stmt<'tcx>(
let (ptr, _extra) = operand.load_scalar_pair(fx);
lval.write_cvalue(fx, CValue::by_val(ptr, dest_layout))
}
} else if let ty::Adt(adt_def, _substs) = from_ty.kind() {
// enum -> discriminant value
assert!(adt_def.is_enum());
match to_ty.kind() {
ty::Uint(_) | ty::Int(_) => {}
_ => unreachable!("cast adt {} -> {}", from_ty, to_ty),
}
let to_clif_ty = fx.clif_type(to_ty).unwrap();

let discriminant = crate::discriminant::codegen_get_discriminant(
fx,
operand,
fx.layout_of(operand.layout().ty.discriminant_ty(fx.tcx)),
)
.load_scalar(fx);

let res = crate::cast::clif_intcast(
fx,
discriminant,
to_clif_ty,
to_ty.is_signed(),
);
lval.write_cvalue(fx, CValue::by_val(res, dest_layout));
} else {
let to_clif_ty = fx.clif_type(to_ty).unwrap();
let from = operand.load_scalar(fx);
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::ffi::CStr;
use std::iter;
use std::ops::Deref;
use std::ptr;
use tracing::debug;
use tracing::{debug, instrument};

// All Builders must have an llfn associated with them
#[must_use]
Expand Down Expand Up @@ -464,15 +464,15 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
}
}

#[instrument(level = "trace", skip(self))]
fn load_operand(&mut self, place: PlaceRef<'tcx, &'ll Value>) -> OperandRef<'tcx, &'ll Value> {
debug!("PlaceRef::load: {:?}", place);

assert_eq!(place.llextra.is_some(), place.layout.is_unsized());

if place.layout.is_zst() {
return OperandRef::new_zst(self, place.layout);
}

#[instrument(level = "trace", skip(bx))]
fn scalar_load_metadata<'a, 'll, 'tcx>(
bx: &mut Builder<'a, 'll, 'tcx>,
load: &'ll Value,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ extern "C" {
B: &Builder<'a>,
Val: &'a Value,
DestTy: &'a Type,
IsSized: bool,
IsSigned: bool,
) -> &'a Value;

// Comparisons
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}

/// Obtain the actual discriminant of a value.
#[instrument(level = "trace", skip(bx))]
pub fn codegen_get_discr<Bx: BuilderMethods<'a, 'tcx, Value = V>>(
self,
bx: &mut Bx,
Expand Down Expand Up @@ -420,12 +421,12 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "trace", skip(self, bx))]
pub fn codegen_place(
&mut self,
bx: &mut Bx,
place_ref: mir::PlaceRef<'tcx>,
) -> PlaceRef<'tcx, Bx::Value> {
debug!("codegen_place(place_ref={:?})", place_ref);
let cx = self.cx;
let tcx = self.cx.tcx();

Expand Down
81 changes: 9 additions & 72 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ use rustc_middle::ty::cast::{CastTy, IntTy};
use rustc_middle::ty::layout::{HasTyCtxt, LayoutOf};
use rustc_middle::ty::{self, adjustment::PointerCast, Instance, Ty, TyCtxt};
use rustc_span::source_map::{Span, DUMMY_SP};
use rustc_target::abi::{Abi, Int, Variants};

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "trace", skip(self, bx))]
pub fn codegen_rvalue(
&mut self,
mut bx: Bx,
dest: PlaceRef<'tcx, Bx::Value>,
rvalue: &mir::Rvalue<'tcx>,
) -> Bx {
debug!("codegen_rvalue(dest.llval={:?}, rvalue={:?})", dest.llval, rvalue);

match *rvalue {
mir::Rvalue::Use(ref operand) => {
let cg_operand = self.codegen_operand(&mut bx, operand);
Expand Down Expand Up @@ -284,74 +282,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
CastTy::from_ty(operand.layout.ty).expect("bad input type for cast");
let r_t_out = CastTy::from_ty(cast.ty).expect("bad output type for cast");
let ll_t_in = bx.cx().immediate_backend_type(operand.layout);
match operand.layout.variants {
Variants::Single { index } => {
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
if let Some(discr) =
operand.layout.ty.discriminant_for_variant(bx.tcx(), index)
{
let discr_layout = bx.cx().layout_of(discr.ty);
let discr_t = bx.cx().immediate_backend_type(discr_layout);
let discr_val = bx.cx().const_uint_big(discr_t, discr.val);
let discr_val =
bx.intcast(discr_val, ll_t_out, discr.ty.is_signed());

return (
bx,
OperandRef {
val: OperandValue::Immediate(discr_val),
layout: cast,
},
);
}
}
Variants::Multiple { .. } => {}
}
let llval = operand.immediate();

let mut signed = false;
if let Abi::Scalar(scalar) = operand.layout.abi {
if let Int(_, s) = scalar.primitive() {
// We use `i1` for bytes that are always `0` or `1`,
// e.g., `#[repr(i8)] enum E { A, B }`, but we can't
// let LLVM interpret the `i1` as signed, because
// then `i1 1` (i.e., E::B) is effectively `i8 -1`.
signed = !scalar.is_bool() && s;

if !scalar.is_always_valid(bx.cx())
&& scalar.valid_range(bx.cx()).end
>= scalar.valid_range(bx.cx()).start
{
// We want `table[e as usize ± k]` to not
// have bound checks, and this is the most
// convenient place to put the `assume`s.
if scalar.valid_range(bx.cx()).start > 0 {
let enum_value_lower_bound = bx.cx().const_uint_big(
ll_t_in,
scalar.valid_range(bx.cx()).start,
);
let cmp_start = bx.icmp(
IntPredicate::IntUGE,
llval,
enum_value_lower_bound,
);
bx.assume(cmp_start);
}

let enum_value_upper_bound = bx
.cx()
.const_uint_big(ll_t_in, scalar.valid_range(bx.cx()).end);
let cmp_end = bx.icmp(
IntPredicate::IntULE,
llval,
enum_value_upper_bound,
);
bx.assume(cmp_end);
}
}
}

let newval = match (r_t_in, r_t_out) {
(CastTy::Int(_), CastTy::Int(_)) => bx.intcast(llval, ll_t_out, signed),
(CastTy::Int(i), CastTy::Int(_)) => {
bx.intcast(llval, ll_t_out, i.is_signed())
}
(CastTy::Float, CastTy::Float) => {
let srcsz = bx.cx().float_width(ll_t_in);
let dstsz = bx.cx().float_width(ll_t_out);
Expand All @@ -363,8 +299,8 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
llval
}
}
(CastTy::Int(_), CastTy::Float) => {
if signed {
(CastTy::Int(i), CastTy::Float) => {
if i.is_signed() {
bx.sitofp(llval, ll_t_out)
} else {
bx.uitofp(llval, ll_t_out)
Expand All @@ -373,8 +309,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
(CastTy::Ptr(_) | CastTy::FnPtr, CastTy::Ptr(_)) => {
bx.pointercast(llval, ll_t_out)
}
(CastTy::Int(_), CastTy::Ptr(_)) => {
let usize_llval = bx.intcast(llval, bx.cx().type_isize(), signed);
(CastTy::Int(i), CastTy::Ptr(_)) => {
let usize_llval =
bx.intcast(llval, bx.cx().type_isize(), i.is_signed());
bx.inttoptr(usize_llval, ll_t_out)
}
(CastTy::Float, CastTy::Int(IntTy::I)) => {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ use crate::traits::BuilderMethods;
use crate::traits::*;

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_statement(&mut self, mut bx: Bx, statement: &mir::Statement<'tcx>) -> Bx {
debug!("codegen_statement(statement={:?})", statement);

self.set_debug_loc(&mut bx, statement.source_info);
match statement.kind {
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
Expand Down
25 changes: 2 additions & 23 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rustc_middle::mir::CastKind;
use rustc_middle::ty::adjustment::PointerCast;
use rustc_middle::ty::layout::{IntegerExt, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, FloatTy, Ty, TypeAndMut};
use rustc_target::abi::{Integer, Variants};
use rustc_target::abi::Integer;
use rustc_type_ir::sty::TyKind::*;

use super::{
Expand Down Expand Up @@ -127,12 +127,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Float(FloatTy::F64) => {
return Ok(self.cast_from_float(src.to_scalar()?.to_f64()?, cast_ty).into());
}
// The rest is integer/pointer-"like", including fn ptr casts and casts from enums that
// are represented as integers.
// The rest is integer/pointer-"like", including fn ptr casts
_ => assert!(
src.layout.ty.is_bool()
|| src.layout.ty.is_char()
|| src.layout.ty.is_enum()
|| src.layout.ty.is_integral()
|| src.layout.ty.is_any_ptr(),
"Unexpected cast from type {:?}",
Expand All @@ -142,25 +140,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

// # First handle non-scalar source values.

// Handle cast from a ZST enum (0 or 1 variants).
match src.layout.variants {
Variants::Single { index } => {
if src.layout.abi.is_uninhabited() {
// This is dead code, because an uninhabited enum is UB to
// instantiate.
throw_ub!(Unreachable);
}
if let Some(discr) = src.layout.ty.discriminant_for_variant(*self.tcx, index) {
assert!(src.layout.is_zst());
let discr_layout = self.layout_of(discr.ty)?;

let scalar = Scalar::from_uint(discr.val, discr_layout.layout.size());
return Ok(self.cast_from_int_like(scalar, discr_layout, cast_ty)?.into());
}
}
Variants::Multiple { .. } => {}
}

// Handle casting any ptr to raw ptr (might be a fat ptr).
if src.layout.ty.is_any_ptr() && cast_ty.is_unsafe_ptr() {
let dest_layout = self.layout_of(cast_ty)?;
Expand Down
32 changes: 28 additions & 4 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
use rustc_middle::mir::{
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, Local, Location, MirPass,
MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope, Statement,
StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, Local, Location,
MirPass, MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, Rvalue, SourceScope,
Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::{self, InstanceDef, ParamEnv, Ty, TyCtxt, TypeFoldable};
Expand Down Expand Up @@ -361,6 +361,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
);
}
}
Rvalue::Ref(..) => {}
Rvalue::Len(p) => {
let pty = p.ty(&self.body.local_decls, self.tcx).ty;
check_kinds!(
Expand Down Expand Up @@ -503,7 +504,30 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
let a = operand.ty(&self.body.local_decls, self.tcx);
check_kinds!(a, "Cannot shallow init type {:?}", ty::RawPtr(..));
}
_ => {}
Rvalue::Cast(kind, operand, target_type) => {
match kind {
CastKind::Misc => {
let op_ty = operand.ty(self.body, self.tcx);
if op_ty.is_enum() {
self.fail(
location,
format!(
"enum -> int casts should go through `Rvalue::Discriminant`: {operand:?}:{op_ty} as {target_type}",
),
);
}
}
// Nothing to check here
CastKind::PointerFromExposedAddress
| CastKind::PointerExposeAddress
| CastKind::Pointer(_) => {}
}
}
Rvalue::Repeat(_, _)
| Rvalue::ThreadLocalRef(_)
| Rvalue::AddressOf(_, _)
| Rvalue::NullaryOp(_, _)
| Rvalue::Discriminant(_) => {}
}
self.super_rvalue(rvalue, location);
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_middle/src/ty/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ pub enum IntTy {
Char,
}

impl IntTy {
pub fn is_signed(self) -> bool {
matches!(self, Self::I)
}
}

// Valid types for the result of a non-coercion cast
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CastTy<'tcx> {
Expand Down
29 changes: 25 additions & 4 deletions compiler/rustc_mir_build/src/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! See docs in `build/expr/mod.rs`.

use rustc_index::vec::Idx;
use rustc_middle::ty::util::IntTypeExt;

use crate::build::expr::as_place::PlaceBase;
use crate::build::expr::category::{Category, RvalueFunc};
Expand Down Expand Up @@ -190,7 +191,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
ExprKind::Cast { source } => {
let source = &this.thir[source];
let from_ty = CastTy::from_ty(source.ty);

// Casting an enum to an integer is equivalent to computing the discriminant and casting the
// discriminant. Previously every backend had to repeat the logic for this operation. Now we
// create all the steps directly in MIR with operations all backends need to support anyway.
let (source, ty) = if let ty::Adt(adt_def, ..) = source.ty.kind() && adt_def.is_enum() {
let discr_ty = adt_def.repr().discr_type().to_ty(this.tcx);
let place = unpack!(block = this.as_place(block, source));
let discr = this.temp(discr_ty, source.span);
this.cfg.push_assign(
block,
source_info,
discr,
Rvalue::Discriminant(place),
);

(Operand::Move(discr), discr_ty)
} else {
let ty = source.ty;
let source = unpack!(
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
);
(source, ty)
};
let from_ty = CastTy::from_ty(ty);
let cast_ty = CastTy::from_ty(expr.ty);
let cast_kind = match (from_ty, cast_ty) {
(Some(CastTy::Ptr(_) | CastTy::FnPtr), Some(CastTy::Int(_))) => {
Expand All @@ -201,9 +225,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
(_, _) => CastKind::Misc,
};
let source = unpack!(
block = this.as_operand(block, scope, source, None, NeedsTemporary::No)
);
block.and(Rvalue::Cast(cast_kind, source, expr.ty))
}
ExprKind::Pointer { cast, source } => {
Expand Down
Loading