From af7472ecbc40693680669c7667c81811a3eb949a Mon Sep 17 00:00:00 2001 From: "Celina G. Val" Date: Mon, 30 Oct 2023 13:07:02 -0700 Subject: [PATCH] Add a stable MIR visitor Add a few utility functions as well and extend most `mir` and `ty` ADTs to implement `PartialEq` and `Eq`. --- compiler/rustc_smir/src/rustc_smir/mod.rs | 6 + compiler/stable_mir/src/lib.rs | 7 +- compiler/stable_mir/src/mir.rs | 2 + compiler/stable_mir/src/mir/body.rs | 74 ++-- compiler/stable_mir/src/mir/mono.rs | 4 + compiler/stable_mir/src/mir/visit.rs | 414 +++++++++++++++++++ compiler/stable_mir/src/ty.rs | 92 ++--- tests/ui-fulldeps/stable-mir/smir_visitor.rs | 148 +++++++ 8 files changed, 662 insertions(+), 85 deletions(-) create mode 100644 compiler/stable_mir/src/mir/visit.rs create mode 100644 tests/ui-fulldeps/stable-mir/smir_visitor.rs diff --git a/compiler/rustc_smir/src/rustc_smir/mod.rs b/compiler/rustc_smir/src/rustc_smir/mod.rs index 5ab5a048ffafa..b619ce6e35f99 100644 --- a/compiler/rustc_smir/src/rustc_smir/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/mod.rs @@ -216,6 +216,12 @@ impl<'tcx> Context for TablesWrapper<'tcx> { tables.create_def_id(def_id) } + fn instance_mangled_name(&self, def: InstanceDef) -> String { + let tables = self.0.borrow_mut(); + let instance = tables.instances[def]; + tables.tcx.symbol_name(instance).name.to_string() + } + fn mono_instance(&self, item: stable_mir::CrateItem) -> stable_mir::mir::mono::Instance { let mut tables = self.0.borrow_mut(); let def_id = tables[item.0]; diff --git a/compiler/stable_mir/src/lib.rs b/compiler/stable_mir/src/lib.rs index 38915afaa0c84..f316671b278ea 100644 --- a/compiler/stable_mir/src/lib.rs +++ b/compiler/stable_mir/src/lib.rs @@ -103,8 +103,6 @@ pub type DefKind = Opaque; pub type Filename = Opaque; /// Holds information about an item in the crate. -/// For now, it only stores the item DefId. Use functions inside `rustc_internal` module to -/// use this item. #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct CrateItem(pub DefId); @@ -224,6 +222,9 @@ pub trait Context { /// Get the instance. fn instance_def_id(&self, instance: InstanceDef) -> DefId; + /// Get the instance mangled name. + fn instance_mangled_name(&self, instance: InstanceDef) -> String; + /// Convert a non-generic crate item into an instance. /// This function will panic if the item is generic. fn mono_instance(&self, item: CrateItem) -> Instance; @@ -259,7 +260,7 @@ pub fn with(f: impl FnOnce(&dyn Context) -> R) -> R { } /// A type that provides internal information but that can still be used for debug purpose. -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq)] pub struct Opaque(String); impl std::fmt::Display for Opaque { diff --git a/compiler/stable_mir/src/mir.rs b/compiler/stable_mir/src/mir.rs index 3138bb1ec832c..2e1714b49c184 100644 --- a/compiler/stable_mir/src/mir.rs +++ b/compiler/stable_mir/src/mir.rs @@ -1,4 +1,6 @@ mod body; pub mod mono; +pub mod visit; pub use body::*; +pub use visit::MirVisitor; diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index 9f69e61d6fe7b..804f5e3d9d0c4 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -1,6 +1,6 @@ -use crate::ty::{AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability, Region}; +use crate::ty::{AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability, Region, Ty}; use crate::Opaque; -use crate::{ty::Ty, Span}; +use crate::Span; /// The SMIR representation of a single function. #[derive(Clone, Debug)] @@ -12,10 +12,10 @@ pub struct Body { // The first local is the return value pointer, followed by `arg_count` // locals for the function arguments, followed by any user-declared // variables and temporaries. - locals: LocalDecls, + pub(super) locals: LocalDecls, // The number of arguments this function takes. - arg_count: usize, + pub(super) arg_count: usize, } impl Body { @@ -35,7 +35,7 @@ impl Body { /// Return local that holds this function's return value. pub fn ret_local(&self) -> &LocalDecl { - &self.locals[0] + &self.locals[RETURN_LOCAL] } /// Locals in `self` that correspond to this function's arguments. @@ -60,7 +60,7 @@ impl Body { type LocalDecls = Vec; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct LocalDecl { pub ty: Ty, pub span: Span, @@ -72,13 +72,13 @@ pub struct BasicBlock { pub terminator: Terminator, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Terminator { pub kind: TerminatorKind, pub span: Span, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum TerminatorKind { Goto { target: usize, @@ -122,7 +122,7 @@ pub enum TerminatorKind { }, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct InlineAsmOperand { pub in_value: Option, pub out_place: Option, @@ -131,7 +131,7 @@ pub struct InlineAsmOperand { pub raw_rpr: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum UnwindAction { Continue, Unreachable, @@ -139,7 +139,7 @@ pub enum UnwindAction { Cleanup(usize), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AssertMessage { BoundsCheck { len: Operand, index: Operand }, Overflow(BinOp, Operand, Operand), @@ -151,7 +151,7 @@ pub enum AssertMessage { MisalignedPointerDereference { required: Operand, found: Operand }, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum BinOp { Add, AddUnchecked, @@ -177,20 +177,20 @@ pub enum BinOp { Offset, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum UnOp { Not, Neg, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum CoroutineKind { Async(CoroutineSource), Coroutine, Gen(CoroutineSource), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum CoroutineSource { Block, Closure, @@ -204,7 +204,7 @@ pub(crate) type LocalDefId = Opaque; pub(crate) type Coverage = Opaque; /// The FakeReadCause describes the type of pattern why a FakeRead statement exists. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum FakeReadCause { ForMatchGuard, ForMatchedPlace(LocalDefId), @@ -214,7 +214,7 @@ pub enum FakeReadCause { } /// Describes what kind of retag is to be performed -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum RetagKind { FnEntry, TwoPhase, @@ -222,7 +222,7 @@ pub enum RetagKind { Default, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub enum Variance { Covariant, Invariant, @@ -230,26 +230,26 @@ pub enum Variance { Bivariant, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct CopyNonOverlapping { pub src: Operand, pub dst: Operand, pub count: Operand, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum NonDivergingIntrinsic { Assume(Operand), CopyNonOverlapping(CopyNonOverlapping), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Statement { pub kind: StatementKind, pub span: Span, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum StatementKind { Assign(Place, Rvalue), FakeRead(FakeReadCause, Place), @@ -266,7 +266,7 @@ pub enum StatementKind { Nop, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Rvalue { /// Creates a pointer with the indicated mutability to the place. /// @@ -378,7 +378,7 @@ pub enum Rvalue { Use(Operand), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AggregateKind { Array(Ty), Tuple, @@ -387,21 +387,21 @@ pub enum AggregateKind { Coroutine(CoroutineDef, GenericArgs, Movability), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum Operand { Copy(Place), Move(Place), Constant(Constant), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Place { pub local: Local, /// projection out of a place (access a field, deref a pointer, etc) pub projection: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct UserTypeProjection { pub base: UserTypeAnnotationIndex, pub projection: String, @@ -409,6 +409,8 @@ pub struct UserTypeProjection { pub type Local = usize; +pub const RETURN_LOCAL: Local = 0; + type FieldIdx = usize; /// The source-order index of a variant in a type. @@ -416,20 +418,20 @@ pub type VariantIdx = usize; type UserTypeAnnotationIndex = usize; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Constant { pub span: Span, pub user_ty: Option, pub literal: Const, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct SwitchTarget { pub value: u128, pub target: usize, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum BorrowKind { /// Data must be immutable and is aliasable. Shared, @@ -446,26 +448,26 @@ pub enum BorrowKind { }, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum MutBorrowKind { Default, TwoPhaseBorrow, ClosureCapture, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Mutability { Not, Mut, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Safety { Unsafe, Normal, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum PointerCoercion { /// Go from a fn-item type to a fn-pointer type. ReifyFnPointer, @@ -492,7 +494,7 @@ pub enum PointerCoercion { Unsize, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum CastKind { PointerExposeAddress, PointerFromExposedAddress, @@ -507,7 +509,7 @@ pub enum CastKind { Transmute, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum NullOp { /// Returns the size of a value of that type. SizeOf, diff --git a/compiler/stable_mir/src/mir/mono.rs b/compiler/stable_mir/src/mir/mono.rs index 997576fc7cbf0..8f533349848e2 100644 --- a/compiler/stable_mir/src/mir/mono.rs +++ b/compiler/stable_mir/src/mir/mono.rs @@ -42,6 +42,10 @@ impl Instance { with(|context| context.instance_ty(self.def)) } + pub fn mangled_name(&self) -> String { + with(|context| context.instance_mangled_name(self.def)) + } + /// Resolve an instance starting from a function definition and generic arguments. pub fn resolve(def: FnDef, args: &GenericArgs) -> Result { with(|context| { diff --git a/compiler/stable_mir/src/mir/visit.rs b/compiler/stable_mir/src/mir/visit.rs new file mode 100644 index 0000000000000..806dced71ff3e --- /dev/null +++ b/compiler/stable_mir/src/mir/visit.rs @@ -0,0 +1,414 @@ +//! # The Stable MIR Visitor +//! +//! ## Overview +//! +//! We currently only support an immutable visitor. +//! The structure of this visitor is similar to the ones internal to `rustc`, +//! and it follows the following conventions: +//! +//! For every mir item, the trait has a `visit_` and a `super_` method. +//! - `visit_`, by default, calls `super_` +//! - `super_`, by default, destructures the `` and calls `visit_` for +//! all sub-items that compose the original item. +//! +//! In order to implement a visitor, override the `visit_*` methods for the types you are +//! interested in analyzing, and invoke (within that method call) +//! `self.super_*` to continue to the traverse. +//! Avoid calling `super` methods in other circumstances. +//! +//! For the most part, we do not destructure things external to the +//! MIR, e.g., types, spans, etc, but simply visit them and stop. +//! This avoids duplication with other visitors like `TypeFoldable`. +//! +//! ## Updating +//! +//! The code is written in a very deliberate style intended to minimize +//! the chance of things being overlooked. +//! +//! Use pattern matching to reference fields and ensure that all +//! matches are exhaustive. +//! +//! For this to work, ALL MATCHES MUST BE EXHAUSTIVE IN FIELDS AND VARIANTS. +//! That means you never write `..` to skip over fields, nor do you write `_` +//! to skip over variants in a `match`. +//! +//! The only place that `_` is acceptable is to match a field (or +//! variant argument) that does not require visiting. + +use crate::mir::*; +use crate::ty::{Const, GenericArgs, Region, Ty}; +use crate::{Opaque, Span}; + +pub trait MirVisitor { + fn visit_body(&mut self, body: &Body) { + self.super_body(body) + } + + fn visit_basic_block(&mut self, bb: &BasicBlock) { + self.super_basic_block(bb) + } + + fn visit_ret_decl(&mut self, local: Local, decl: &LocalDecl) { + self.super_ret_decl(local, decl) + } + + fn visit_arg_decl(&mut self, local: Local, decl: &LocalDecl) { + self.super_arg_decl(local, decl) + } + + fn visit_local_decl(&mut self, local: Local, decl: &LocalDecl) { + self.super_local_decl(local, decl) + } + + fn visit_statement(&mut self, stmt: &Statement, location: Location) { + self.super_statement(stmt, location) + } + + fn visit_terminator(&mut self, term: &Terminator, location: Location) { + self.super_terminator(term, location) + } + + fn visit_span(&mut self, span: &Span) { + self.super_span(span) + } + + fn visit_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { + self.super_place(place, ptx, location) + } + + fn visit_local(&mut self, local: &Local, ptx: PlaceContext, location: Location) { + let _ = (local, ptx, location); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue, location: Location) { + self.super_rvalue(rvalue, location) + } + + fn visit_operand(&mut self, operand: &Operand, location: Location) { + self.super_operand(operand, location) + } + + fn visit_user_type_projection(&mut self, projection: &UserTypeProjection) { + self.super_user_type_projection(projection) + } + + fn visit_ty(&mut self, ty: &Ty, location: Location) { + let _ = location; + self.super_ty(ty) + } + + fn visit_constant(&mut self, constant: &Constant, location: Location) { + self.super_constant(constant, location) + } + + fn visit_const(&mut self, constant: &Const, location: Location) { + self.super_const(constant, location) + } + + fn visit_region(&mut self, region: &Region, location: Location) { + let _ = location; + self.super_region(region) + } + + fn visit_args(&mut self, args: &GenericArgs, location: Location) { + let _ = location; + self.super_args(args) + } + + fn visit_assert_msg(&mut self, msg: &AssertMessage, location: Location) { + self.super_assert_msg(msg, location) + } + + fn super_body(&mut self, body: &Body) { + let Body { blocks, locals: _, arg_count } = body; + + for bb in blocks { + self.visit_basic_block(bb); + } + + self.visit_ret_decl(RETURN_LOCAL, body.ret_local()); + + for (idx, arg) in body.arg_locals().iter().enumerate() { + self.visit_arg_decl(idx + 1, arg) + } + + let local_start = arg_count + 1; + for (idx, arg) in body.arg_locals().iter().enumerate() { + self.visit_local_decl(idx + local_start, arg) + } + } + + fn super_basic_block(&mut self, bb: &BasicBlock) { + let BasicBlock { statements, terminator } = bb; + for stmt in statements { + self.visit_statement(stmt, Location(stmt.span)); + } + self.visit_terminator(terminator, Location(terminator.span)); + } + + fn super_local_decl(&mut self, local: Local, decl: &LocalDecl) { + let _ = local; + let LocalDecl { ty, span } = decl; + self.visit_ty(ty, Location(*span)); + } + + fn super_ret_decl(&mut self, local: Local, decl: &LocalDecl) { + self.super_local_decl(local, decl) + } + + fn super_arg_decl(&mut self, local: Local, decl: &LocalDecl) { + self.super_local_decl(local, decl) + } + + fn super_statement(&mut self, stmt: &Statement, location: Location) { + let Statement { kind, span } = stmt; + self.visit_span(span); + match kind { + StatementKind::Assign(place, rvalue) => { + self.visit_place(place, PlaceContext::MUTATING, location); + self.visit_rvalue(rvalue, location); + } + StatementKind::FakeRead(_, place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location); + } + StatementKind::SetDiscriminant { place, .. } => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + StatementKind::Deinit(place) => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + StatementKind::StorageLive(local) => { + self.visit_local(local, PlaceContext::NON_USE, location); + } + StatementKind::StorageDead(local) => { + self.visit_local(local, PlaceContext::NON_USE, location); + } + StatementKind::Retag(_, place) => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + StatementKind::PlaceMention(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location); + } + StatementKind::AscribeUserType { place, projections, variance: _ } => { + self.visit_place(place, PlaceContext::NON_USE, location); + self.visit_user_type_projection(projections); + } + StatementKind::Coverage(coverage) => visit_opaque(coverage), + StatementKind::Intrinsic(intrisic) => match intrisic { + NonDivergingIntrinsic::Assume(operand) => { + self.visit_operand(operand, location); + } + NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + src, + dst, + count, + }) => { + self.visit_operand(src, location); + self.visit_operand(dst, location); + self.visit_operand(count, location); + } + }, + StatementKind::ConstEvalCounter => {} + StatementKind::Nop => {} + } + } + + fn super_terminator(&mut self, term: &Terminator, location: Location) { + let Terminator { kind, span } = term; + self.visit_span(&span); + match kind { + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Unreachable + | TerminatorKind::CoroutineDrop => {} + TerminatorKind::Assert { cond, expected: _, msg, target: _, unwind: _ } => { + self.visit_operand(cond, location); + self.visit_assert_msg(msg, location); + } + TerminatorKind::Drop { place, target: _, unwind: _ } => { + self.visit_place(place, PlaceContext::MUTATING, location); + } + TerminatorKind::Call { func, args, destination, target: _, unwind: _ } => { + self.visit_operand(func, location); + for arg in args { + self.visit_operand(arg, location); + } + self.visit_place(destination, PlaceContext::MUTATING, location); + } + TerminatorKind::InlineAsm { operands, .. } => { + for op in operands { + let InlineAsmOperand { in_value, out_place, raw_rpr: _ } = op; + if let Some(input) = in_value { + self.visit_operand(input, location); + } + if let Some(output) = out_place { + self.visit_place(output, PlaceContext::MUTATING, location); + } + } + } + TerminatorKind::Return => { + let local = RETURN_LOCAL; + self.visit_local(&local, PlaceContext::NON_MUTATING, location); + } + TerminatorKind::SwitchInt { discr, targets: _, otherwise: _ } => { + self.visit_operand(discr, location); + } + } + } + + fn super_span(&mut self, span: &Span) { + let _ = span; + } + + fn super_place(&mut self, place: &Place, ptx: PlaceContext, location: Location) { + let _ = location; + let _ = ptx; + visit_opaque(&Opaque(place.projection.clone())); + } + + fn super_rvalue(&mut self, rvalue: &Rvalue, location: Location) { + match rvalue { + Rvalue::AddressOf(mutability, place) => { + let pcx = PlaceContext { is_mut: *mutability == Mutability::Mut }; + self.visit_place(place, pcx, location); + } + Rvalue::Aggregate(_, operands) => { + for op in operands { + self.visit_operand(op, location); + } + } + Rvalue::BinaryOp(_, lhs, rhs) | Rvalue::CheckedBinaryOp(_, lhs, rhs) => { + self.visit_operand(lhs, location); + self.visit_operand(rhs, location); + } + Rvalue::Cast(_, op, ty) => { + self.visit_operand(op, location); + self.visit_ty(ty, location); + } + Rvalue::CopyForDeref(place) | Rvalue::Discriminant(place) | Rvalue::Len(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location); + } + Rvalue::Ref(region, kind, place) => { + self.visit_region(region, location); + let pcx = PlaceContext { is_mut: matches!(kind, BorrowKind::Mut { .. }) }; + self.visit_place(place, pcx, location); + } + Rvalue::Repeat(op, constant) => { + self.visit_operand(op, location); + self.visit_const(constant, location); + } + Rvalue::ShallowInitBox(op, ty) => { + self.visit_ty(ty, location); + self.visit_operand(op, location) + } + Rvalue::ThreadLocalRef(_) => {} + Rvalue::NullaryOp(_, ty) => { + self.visit_ty(ty, location); + } + Rvalue::UnaryOp(_, op) | Rvalue::Use(op) => { + self.visit_operand(op, location); + } + } + } + + fn super_operand(&mut self, operand: &Operand, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + self.visit_place(place, PlaceContext::NON_MUTATING, location) + } + Operand::Constant(constant) => { + self.visit_constant(constant, location); + } + } + } + + fn super_user_type_projection(&mut self, projection: &UserTypeProjection) { + // This is a no-op on mir::Visitor. + let _ = projection; + } + + fn super_ty(&mut self, ty: &Ty) { + let _ = ty; + } + + fn super_constant(&mut self, constant: &Constant, location: Location) { + let Constant { span, user_ty: _, literal } = constant; + self.visit_span(span); + self.visit_const(literal, location); + } + + fn super_const(&mut self, constant: &Const, location: Location) { + let Const { kind: _, ty, id: _ } = constant; + self.visit_ty(ty, location); + } + + fn super_region(&mut self, region: &Region) { + let _ = region; + } + + fn super_args(&mut self, args: &GenericArgs) { + let _ = args; + } + + fn super_assert_msg(&mut self, msg: &AssertMessage, location: Location) { + match msg { + AssertMessage::BoundsCheck { len, index } => { + self.visit_operand(len, location); + self.visit_operand(index, location); + } + AssertMessage::Overflow(_, left, right) => { + self.visit_operand(left, location); + self.visit_operand(right, location); + } + AssertMessage::OverflowNeg(op) + | AssertMessage::DivisionByZero(op) + | AssertMessage::RemainderByZero(op) => { + self.visit_operand(op, location); + } + AssertMessage::ResumedAfterReturn(_) | AssertMessage::ResumedAfterPanic(_) => { //nothing to visit + } + AssertMessage::MisalignedPointerDereference { required, found } => { + self.visit_operand(required, location); + self.visit_operand(found, location); + } + } + } +} + +/// This function is a no-op that gets used to ensure this visitor is kept up-to-date. +/// +/// The idea is that whenever we replace an Opaque type by a real type, the compiler will fail +/// when trying to invoke `visit_opaque`. +/// +/// If you are here because your compilation is broken, replace the failing call to `visit_opaque()` +/// by a `visit_` for your construct. +fn visit_opaque(_: &Opaque) {} + +/// The location of a statement / terminator in the code and the CFG. +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct Location(Span); + +impl Location { + pub fn span(&self) -> Span { + self.0 + } +} + +/// Information about a place's usage. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct PlaceContext { + /// Whether the access is mutable or not. Keep this private so we can increment the type in a + /// backward compatible manner. + is_mut: bool, +} + +impl PlaceContext { + const MUTATING: Self = PlaceContext { is_mut: true }; + const NON_MUTATING: Self = PlaceContext { is_mut: false }; + const NON_USE: Self = PlaceContext { is_mut: false }; + + pub fn is_mutating(&self) -> bool { + self.is_mut + } +} diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs index e7440cc439b16..5dfaa0fd89158 100644 --- a/compiler/stable_mir/src/ty.rs +++ b/compiler/stable_mir/src/ty.rs @@ -22,12 +22,12 @@ impl Ty { } /// Represents a constant in MIR or from the Type system. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Const { /// The constant kind. - kind: ConstantKind, + pub(crate) kind: ConstantKind, /// The constant type. - ty: Ty, + pub(crate) ty: Ty, /// Used for internal tracking of the internal constant. pub id: ConstId, } @@ -54,12 +54,12 @@ pub struct ConstId(pub usize); type Ident = Opaque; -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Region { pub kind: RegionKind, } -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum RegionKind { ReEarlyBound(EarlyBoundRegion), ReLateBound(DebruijnIndex, BoundRegion), @@ -70,7 +70,7 @@ pub enum RegionKind { pub(crate) type DebruijnIndex = u32; -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct EarlyBoundRegion { pub def_id: RegionDef, pub index: u32, @@ -79,7 +79,7 @@ pub struct EarlyBoundRegion { pub(crate) type BoundVar = u32; -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct BoundRegion { pub var: BoundVar, pub kind: BoundRegionKind, @@ -87,7 +87,7 @@ pub struct BoundRegion { pub(crate) type UniverseIndex = u32; -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Placeholder { pub universe: UniverseIndex, pub bound: T, @@ -127,7 +127,7 @@ pub struct LineInfo { pub end_col: usize, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum TyKind { RigidTy(RigidTy), Alias(AliasKind, AliasTy), @@ -135,7 +135,7 @@ pub enum TyKind { Bound(usize, BoundTy), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum RigidTy { Bool, Char, @@ -236,7 +236,7 @@ pub struct ImplDef(pub DefId); pub struct RegionDef(pub DefId); /// A list of generic arguments. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct GenericArgs(pub Vec); impl std::ops::Index for GenericArgs { @@ -255,7 +255,7 @@ impl std::ops::Index for GenericArgs { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum GenericArgKind { Lifetime(Region), Type(Ty), @@ -284,13 +284,13 @@ impl GenericArgKind { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum TermKind { Type(Ty), Const(Const), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AliasKind { Projection, Inherent, @@ -298,7 +298,7 @@ pub enum AliasKind { Weak, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct AliasTy { pub def_id: AliasDef, pub args: GenericArgs, @@ -306,7 +306,7 @@ pub struct AliasTy { pub type PolyFnSig = Binder; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct FnSig { pub inputs_and_output: Vec, pub c_variadic: bool, @@ -345,18 +345,18 @@ pub enum Abi { RiscvInterruptS, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Binder { pub value: T, pub bound_vars: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct EarlyBinder { pub value: T, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum BoundVariableKind { Ty(BoundTyKind), Region(BoundRegionKind), @@ -369,46 +369,46 @@ pub enum BoundTyKind { Param(ParamDef, String), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum BoundRegionKind { BrAnon, BrNamed(BrNamedDef, String), BrEnv, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum DynKind { Dyn, DynStar, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum ExistentialPredicate { Trait(ExistentialTraitRef), Projection(ExistentialProjection), AutoTrait(TraitDef), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ExistentialTraitRef { pub def_id: TraitDef, pub generic_args: GenericArgs, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ExistentialProjection { pub def_id: TraitDef, pub generic_args: GenericArgs, pub term: TermKind, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ParamTy { pub index: u32, pub name: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct BoundTy { pub var: usize, pub kind: BoundTyKind, @@ -424,14 +424,14 @@ pub type Promoted = u32; pub type InitMaskMaterialized = Vec; /// Stores the provenance information of pointers stored in memory. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ProvenanceMap { /// Provenance in this map applies from the given offset for an entire pointer-size worth of /// bytes. Two entries in this map are always at least a pointer size apart. pub ptrs: Vec<(Size, Prov)>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Allocation { pub bytes: Bytes, pub provenance: ProvenanceMap, @@ -439,7 +439,7 @@ pub struct Allocation { pub mutability: Mutability, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum ConstantKind { Allocated(Allocation), Unevaluated(UnevaluatedConst), @@ -449,13 +449,13 @@ pub enum ConstantKind { ZeroSized, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ParamConst { pub index: u32, pub name: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct UnevaluatedConst { pub def: ConstDef, pub args: GenericArgs, @@ -469,7 +469,7 @@ pub enum TraitSpecializationKind { AlwaysApplicable, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct TraitDecl { pub def_id: TraitDef, pub unsafety: Safety, @@ -500,13 +500,13 @@ impl TraitDecl { pub type ImplTrait = EarlyBinder; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct TraitRef { pub def_id: TraitDef, pub args: GenericArgs, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Generics { pub parent: Option, pub parent_count: usize, @@ -517,14 +517,14 @@ pub struct Generics { pub host_effect_index: Option, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum GenericParamDefKind { Lifetime, Type { has_default: bool, synthetic: bool }, Const { has_default: bool }, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct GenericParamDef { pub name: super::Symbol, pub def_id: GenericDef, @@ -538,7 +538,7 @@ pub struct GenericPredicates { pub predicates: Vec<(PredicateKind, Span)>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum PredicateKind { Clause(ClauseKind), ObjectSafe(TraitDef), @@ -550,7 +550,7 @@ pub enum PredicateKind { AliasRelate(TermKind, TermKind, AliasRelationDirection), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum ClauseKind { Trait(TraitPredicate), RegionOutlives(RegionOutlivesPredicate), @@ -561,50 +561,50 @@ pub enum ClauseKind { ConstEvaluatable(Const), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum ClosureKind { Fn, FnMut, FnOnce, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct SubtypePredicate { pub a: Ty, pub b: Ty, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct CoercePredicate { pub a: Ty, pub b: Ty, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum AliasRelationDirection { Equate, Subtype, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct TraitPredicate { pub trait_ref: TraitRef, pub polarity: ImplPolarity, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct OutlivesPredicate(pub A, pub B); pub type RegionOutlivesPredicate = OutlivesPredicate; pub type TypeOutlivesPredicate = OutlivesPredicate; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct ProjectionPredicate { pub projection_ty: AliasTy, pub term: TermKind, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum ImplPolarity { Positive, Negative, diff --git a/tests/ui-fulldeps/stable-mir/smir_visitor.rs b/tests/ui-fulldeps/stable-mir/smir_visitor.rs new file mode 100644 index 0000000000000..de5148bb5f420 --- /dev/null +++ b/tests/ui-fulldeps/stable-mir/smir_visitor.rs @@ -0,0 +1,148 @@ +// run-pass +//! Sanity check Stable MIR Visitor + +// ignore-stage1 +// ignore-cross-compile +// ignore-remote +// ignore-windows-gnu mingw has troubles with linking https://github.com/rust-lang/rust/pull/116837 +// edition: 2021 + +#![feature(rustc_private)] +#![feature(assert_matches)] +#![feature(control_flow_enum)] + +extern crate rustc_middle; +#[macro_use] +extern crate rustc_smir; +extern crate rustc_driver; +extern crate rustc_interface; +extern crate stable_mir; + +use std::collections::HashSet; +use rustc_middle::ty::TyCtxt; +use rustc_smir::rustc_internal; +use stable_mir::*; +use stable_mir::mir::MirVisitor; +use std::io::Write; +use std::ops::ControlFlow; + +const CRATE_NAME: &str = "input"; + +fn test_visitor(_tcx: TyCtxt<'_>) -> ControlFlow<()> { + let main_fn = stable_mir::entry_fn(); + let main_body = main_fn.unwrap().body(); + let main_visitor = TestVisitor::collect(&main_body); + assert!(main_visitor.ret_val.is_some()); + assert!(main_visitor.args.is_empty()); + assert!(main_visitor.tys.contains(&main_visitor.ret_val.unwrap().ty)); + assert!(!main_visitor.calls.is_empty()); + + let exit_fn = main_visitor.calls.last().unwrap(); + assert!(exit_fn.mangled_name().contains("exit_fn"), "Unexpected last function: {exit_fn:?}"); + + let exit_body = exit_fn.body(); + let exit_visitor = TestVisitor::collect(&exit_body); + assert!(exit_visitor.ret_val.is_some()); + assert_eq!(exit_visitor.args.len(), 1); + assert!(exit_visitor.tys.contains(&exit_visitor.ret_val.unwrap().ty)); + assert!(exit_visitor.tys.contains(&exit_visitor.args[0].ty)); + ControlFlow::Continue(()) +} + +struct TestVisitor<'a> { + pub body: &'a mir::Body, + pub tys: HashSet, + pub ret_val: Option, + pub args: Vec, + pub calls: Vec +} + +impl<'a> TestVisitor<'a> { + fn collect(body: &'a mir::Body) -> TestVisitor<'a> { + let mut visitor = TestVisitor { + body: &body, + tys: Default::default(), + ret_val: None, + args: vec![], + calls: vec![], + }; + visitor.visit_body(&body); + visitor + } +} + +impl<'a> mir::MirVisitor for TestVisitor<'a> { + fn visit_ty(&mut self, ty: &ty::Ty, _location: mir::visit::Location) { + self.tys.insert(*ty); + self.super_ty(ty) + } + + fn visit_ret_decl(&mut self, local: mir::Local, decl: &mir::LocalDecl) { + assert!(local == mir::RETURN_LOCAL); + assert!(self.ret_val.is_none()); + self.ret_val = Some(decl.clone()); + self.super_ret_decl(local, decl); + } + + fn visit_arg_decl(&mut self, local: mir::Local, decl: &mir::LocalDecl) { + self.args.push(decl.clone()); + assert_eq!(local, self.args.len()); + self.super_arg_decl(local, decl); + } + + fn visit_terminator(&mut self, term: &mir::Terminator, location: mir::visit::Location) { + if let mir::TerminatorKind::Call { func, .. } = &term.kind { + let ty::TyKind::RigidTy(ty) = func.ty(self.body.locals()).kind() else { unreachable! + () }; + let ty::RigidTy::FnDef(def, args) = ty else { unreachable!() }; + self.calls.push(mir::mono::Instance::resolve(def, &args).unwrap()); + } + self.super_terminator(term, location); + } +} + +/// This test will generate and analyze a dummy crate using the stable mir. +/// For that, it will first write the dummy crate into a file. +/// Then it will create a `StableMir` using custom arguments and then +/// it will run the compiler. +fn main() { + let path = "sim_visitor_input.rs"; + generate_input(&path).unwrap(); + let args = vec![ + "rustc".to_string(), + "-Cpanic=abort".to_string(), + "--crate-name".to_string(), + CRATE_NAME.to_string(), + path.to_string(), + ]; + run!(args, tcx, test_visitor(tcx)).unwrap(); +} + +fn generate_input(path: &str) -> std::io::Result<()> { + let mut file = std::fs::File::create(path)?; + write!( + file, + r#" + fn main() -> std::process::ExitCode {{ + let inputs = Inputs::new(); + let total = inputs.values.iter().sum(); + exit_fn(total) + }} + + fn exit_fn(code: u8) -> std::process::ExitCode {{ + std::process::ExitCode::from(code) + }} + + struct Inputs {{ + values: [u8; 3], + }} + + impl Inputs {{ + fn new() -> Inputs {{ + Inputs {{ values: [0, 1, 2] }} + }} + }} + "# + )?; + Ok(()) +}