Skip to content

Commit

Permalink
Auto merge of rust-lang#98112 - saethlin:mir-alignment-checks, r=oli-obk
Browse files Browse the repository at this point in the history
Insert alignment checks for pointer dereferences when debug assertions are enabled

Closes rust-lang#54915

- [x] Jake tells me this sounds like a place to use `MirPatch`, but I can't figure out how to insert a new basic block with a new terminator in the middle of an existing basic block, using `MirPatch`. (if nobody else backs up this point I'm checking this as "not actually a good idea" because the code looks pretty clean to me after rearranging it a bit)
- [x] Using `CastKind::PointerExposeAddress` is definitely wrong, we don't want to expose. Calling a function to get the pointer address seems quite excessive. ~I'll see if I can add a new `CastKind`.~ `CastKind::Transmute` to the rescue!
- [x] Implement a more helpful panic message like slice bounds checking.

r? `@oli-obk`
  • Loading branch information
bors committed Mar 31, 2023
2 parents ec7bb8d + 7507078 commit 22a7a19
Show file tree
Hide file tree
Showing 35 changed files with 372 additions and 21 deletions.
12 changes: 12 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,18 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
source_info.span,
);
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
    // Lower both assert operands and the caller location to scalars so they can be
    // passed as arguments to the panic entry point.
    let required = codegen_operand(fx, required).load_scalar(fx);
    let found = codegen_operand(fx, found).load_scalar(fx);
    let location = fx.get_caller_location(source_info).load_scalar(fx);

    // Call the dedicated misaligned-pointer lang item. Using `PanicBoundsCheck` here
    // would report the failure as "index out of bounds" with the alignment and address
    // shown as index and length.
    codegen_panic_inner(
        fx,
        rustc_hir::LangItem::PanicMisalignedPointerDereference,
        &[required, found, location],
        source_info.span,
    );
}
_ => {
let msg_str = msg.description();
codegen_panic(fx, msg_str, source_info);
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicBoundsCheck, vec![index, len, location])
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
let required = self.codegen_operand(bx, required).immediate();
let found = self.codegen_operand(bx, found).immediate();
// It's `fn panic_misaligned_pointer_dereference(required: usize, found: usize)`,
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicMisalignedPointerDereference, vec![required, found, location])
}
_ => {
let msg = bx.const_str(msg.description());
// It's `pub fn panic(expr: &str)`, with the wide reference being passed
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,12 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
ResumedAfterReturn(generator_kind) => ResumedAfterReturn(*generator_kind),
ResumedAfterPanic(generator_kind) => ResumedAfterPanic(*generator_kind),
MisalignedPointerDereference { ref required, ref found } => {
MisalignedPointerDereference {
required: eval_to_int(required)?,
found: eval_to_int(found)?,
}
}
};
Err(ConstEvalErrKind::AssertFailure(err).into())
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ language_item_table! {
PanicDisplay, sym::panic_display, panic_display, Target::Fn, GenericRequirement::None;
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);
PanicInfo, sym::panic_info, panic_info, Target::Struct, GenericRequirement::None;
PanicLocation, sym::panic_location, panic_location, Target::Struct, GenericRequirement::None;
PanicImpl, sym::panic_impl, panic_impl, Target::Fn, GenericRequirement::None;
Expand Down
20 changes: 18 additions & 2 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ impl<O> AssertKind<O> {

/// Getting a description does not require `O` to be printable, and does not
/// require allocation.
/// The caller is expected to handle `BoundsCheck` separately.
/// The caller is expected to handle `BoundsCheck` and `MisalignedPointerDereference` separately.
pub fn description(&self) -> &'static str {
use AssertKind::*;
match self {
Expand All @@ -1296,7 +1296,9 @@ impl<O> AssertKind<O> {
ResumedAfterReturn(GeneratorKind::Async(_)) => "`async fn` resumed after completion",
ResumedAfterPanic(GeneratorKind::Gen) => "generator resumed after panicking",
ResumedAfterPanic(GeneratorKind::Async(_)) => "`async fn` resumed after panicking",
BoundsCheck { .. } => bug!("Unexpected AssertKind"),
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
bug!("Unexpected AssertKind")
}
}
}

Expand Down Expand Up @@ -1353,6 +1355,13 @@ impl<O> AssertKind<O> {
Overflow(BinOp::Shl, _, r) => {
write!(f, "\"attempt to shift left by `{{}}`, which would overflow\", {:?}", r)
}
MisalignedPointerDereference { required, found } => {
write!(
f,
"\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\", {:?}, {:?}",
required, found
)
}
_ => write!(f, "\"{}\"", self.description()),
}
}
Expand Down Expand Up @@ -1397,6 +1406,13 @@ impl<O: fmt::Debug> fmt::Debug for AssertKind<O> {
Overflow(BinOp::Shl, _, r) => {
write!(f, "attempt to shift left by `{:#?}`, which would overflow", r)
}
MisalignedPointerDereference { required, found } => {
write!(
f,
"misaligned pointer dereference: address must be a multiple of {:?} but is {:?}",
required, found
)
}
_ => write!(f, "{}", self.description()),
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ pub enum AssertKind<O> {
RemainderByZero(O),
ResumedAfterReturn(GeneratorKind),
ResumedAfterPanic(GeneratorKind),
MisalignedPointerDereference { required: O, found: O },
}

#[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)]
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ macro_rules! make_mir_visitor {
ResumedAfterReturn(_) | ResumedAfterPanic(_) => {
// Nothing to visit
}
MisalignedPointerDereference { required, found } => {
self.visit_operand(required, location);
self.visit_operand(found, location);
}
}
}

Expand Down
227 changes: 227 additions & 0 deletions compiler/rustc_mir_transform/src/check_alignment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
use crate::MirPass;
use rustc_hir::def_id::DefId;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
interpret::{ConstValue, Scalar},
visit::{PlaceContext, Visitor},
};
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;

pub struct CheckAlignment;

impl<'tcx> MirPass<'tcx> for CheckAlignment {
fn is_enabled(&self, sess: &Session) -> bool {
sess.opts.debug_assertions
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;

for block in (0..basic_blocks.len()).rev() {
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
let location = Location { block, statement_index };
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;

let mut finder = PointerFinder {
local_decls,
tcx,
pointers: Vec::new(),
def_id: body.source.def_id(),
};
for (pointer, pointee_ty) in finder.find_pointers(statement) {
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);

let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
pointer,
pointee_ty,
source_info,
new_block,
);
}
}
}
}
}

impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
    /// Visits `statement` and returns every `(pointer, pointee type)` pair collected by
    /// `visit_place`, leaving the internal buffer empty afterwards.
    fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
        // Start from an empty buffer so results of an earlier call cannot leak in.
        self.pointers.clear();
        // `Location::START` is only a placeholder: `visit_place` ignores its location.
        self.visit_statement(statement, Location::START);
        std::mem::take(&mut self.pointers)
    }
}

// MIR visitor state used to collect the raw pointers dereferenced by a statement.
struct PointerFinder<'tcx, 'a> {
// Local declarations of the body; read to look up the type of a place's base local.
local_decls: &'a mut LocalDecls<'tcx>,
// Type context used for type queries.
tcx: TyCtxt<'tcx>,
// Definition whose parameter environment is used for the `is_sized` query.
def_id: DefId,
// Collected `(pointer place, pointee type)` pairs that should get an alignment check.
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}

impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
    /// Records the base local and pointee type of `place` when it is a used, indirect
    /// place whose base local is a raw pointer to a sized pointee that is not in the
    /// trivially-aligned list below.
    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
        // Non-uses and places without a dereference never need a check.
        if matches!(context, PlaceContext::NonUse(_)) || !place.is_indirect() {
            return;
        }

        let tcx = self.tcx;
        let pointer = Place::from(place.local);
        let pointer_ty = pointer.ty(&*self.local_decls, tcx).ty;

        // We only want to check unsafe pointers
        if !pointer_ty.is_unsafe_ptr() {
            trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
            return;
        }

        let pointee_ty = match pointer_ty.builtin_deref(true) {
            Some(pointee) => pointee.ty,
            None => {
                debug!("Indirect but no builtin deref: {:?}", pointer_ty);
                return;
            }
        };

        // For arrays, slices and `str`, the check is done against the element type.
        let is_sequence = pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str();
        let pointee_ty =
            if is_sequence { pointee_ty.sequence_element_type(tcx) } else { pointee_ty };

        let param_env = tcx.param_env_reveal_all_normalized(self.def_id);
        if !pointee_ty.is_sized(tcx, param_env) {
            debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
            return;
        }

        // Pointee types for which no check is emitted at all.
        let trivially_aligned = [tcx.types.bool, tcx.types.i8, tcx.types.u8, tcx.types.str_];
        if trivially_aligned.contains(&pointee_ty) {
            debug!("Trivially aligned pointee type: {:?}", pointer_ty);
            return;
        }

        self.pointers.push((pointer, pointee_ty))
    }
}

fn split_block(
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
location: Location,
) -> BasicBlock {
let block_data = &mut basic_blocks[location.block];

// Drain every statement after this one and move the current terminator to a new basic block
let new_block = BasicBlockData {
statements: block_data.statements.split_off(location.statement_index),
terminator: block_data.terminator.take(),
is_cleanup: block_data.is_cleanup,
};

basic_blocks.push(new_block)
}

fn insert_alignment_check<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut LocalDecls<'tcx>,
block_data: &mut BasicBlockData<'tcx>,
pointer: Place<'tcx>,
pointee_ty: Ty<'tcx>,
source_info: SourceInfo,
new_block: BasicBlock,
) {
// Cast the pointer to a *const ()
let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not });
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
block_data
.statements
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });

// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data
.statements
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });

// Get the alignment of the pointee
let alignment =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((alignment, rvalue))),
});

// Subtract 1 from the alignment to get the alignment mask
let alignment_mask =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
let one = Operand::Constant(Box::new(Constant {
span: source_info.span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)),
tcx.types.usize,
),
}));
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_mask,
Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))),
))),
});

// BitAnd the alignment mask with the pointer
let alignment_bits =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_bits,
Rvalue::BinaryOp(
BinOp::BitAnd,
Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))),
),
))),
});

// Check if the alignment bits are all zero
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
let zero = Operand::Constant(Box::new(Constant {
span: source_info.span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)),
tcx.types.usize,
),
}));
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
is_ok,
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
))),
});

// Set this block's terminator to our assert, continuing to new_block if we pass
block_data.terminator = Some(Terminator {
source_info,
kind: TerminatorKind::Assert {
cond: Operand::Copy(is_ok),
expected: true,
target: new_block,
msg: AssertKind::MisalignedPointerDereference {
required: Operand::Copy(alignment),
found: Operand::Copy(addr),
},
cleanup: None,
},
});
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ mod separate_const_switch;
mod shim;
mod ssa;
// This pass is public to allow external drivers to perform MIR cleanup
mod check_alignment;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
Expand Down Expand Up @@ -545,6 +546,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
tcx,
body,
&[
&check_alignment::CheckAlignment,
&reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode.
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
&unreachable_prop::UnreachablePropagation,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,7 @@ symbols! {
panic_implementation,
panic_info,
panic_location,
panic_misaligned_pointer_dereference,
panic_nounwind,
panic_runtime,
panic_str,
Expand Down
14 changes: 14 additions & 0 deletions library/core/src/panicking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ fn panic_bounds_check(index: usize, len: usize) -> ! {
panic!("index out of bounds: the len is {len} but the index is {index}")
}

/// Panics with a message reporting a misaligned pointer dereference.
///
/// `required` is printed as the value the address must be a multiple of and `found` as
/// the offending address, both in hexadecimal. With the `panic_immediate_abort` feature
/// enabled the function aborts instead of formatting a message.
#[cold]
#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))]
#[track_caller]
#[cfg_attr(not(bootstrap), lang = "panic_misaligned_pointer_dereference")] // needed by codegen for panic on misaligned pointer deref
fn panic_misaligned_pointer_dereference(required: usize, found: usize) -> ! {
// Skip message formatting entirely when the crate is built to abort on panic.
if cfg!(feature = "panic_immediate_abort") {
super::intrinsics::abort()
}

panic!(
"misaligned pointer dereference: address must be a multiple of {required:#x} but is {found:#x}"
)
}

/// Panic because we cannot unwind out of a function.
///
/// This function is called directly by the codegen backend, and must not have
Expand Down
Loading

0 comments on commit 22a7a19

Please sign in to comment.