Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 96 additions & 66 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_hir::{CoroutineDesugaring, CoroutineKind};
use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::util::Discr;
use rustc_middle::ty::{
Expand Down Expand Up @@ -110,6 +110,8 @@ impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
if *local == self.from {
*local = self.to;
} else if *local == self.to {
*local = self.from;
}
}

Expand Down Expand Up @@ -159,13 +161,15 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
}
}

#[tracing::instrument(level = "trace", skip(tcx))]
fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
place.local = new_base.local;

let mut new_projection = new_base.projection.to_vec();
new_projection.append(&mut place.projection.to_vec());

place.projection = tcx.mk_place_elems(&new_projection);
tracing::trace!(?place);
}

const SELF_ARG: Local = Local::from_u32(1);
Expand Down Expand Up @@ -204,8 +208,8 @@ struct TransformVisitor<'tcx> {
// The set of locals that have no `StorageLive`/`StorageDead` annotations.
always_live_locals: DenseBitSet<Local>,

// The original RETURN_PLACE local
old_ret_local: Local,
// New local we just create to hold the `CoroutineState` value.
new_ret_local: Local,

old_yield_ty: Ty<'tcx>,

Expand Down Expand Up @@ -270,6 +274,7 @@ impl<'tcx> TransformVisitor<'tcx> {
// `core::ops::CoroutineState` only has single element tuple variants,
// so we can just write to the downcasted first field and then set the
// discriminant to the appropriate variant.
#[tracing::instrument(level = "trace", skip(self, statements))]
fn make_state(
&self,
val: Operand<'tcx>,
Expand Down Expand Up @@ -341,13 +346,15 @@ impl<'tcx> TransformVisitor<'tcx> {
}
};

// Assign to `new_ret_local`, which will be replaced by `RETURN_PLACE` later.
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
));
}

// Create a Place referencing a coroutine struct field
#[tracing::instrument(level = "trace", skip(self), ret)]
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
let self_place = Place::from(SELF_ARG);
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
Expand All @@ -358,6 +365,7 @@ impl<'tcx> TransformVisitor<'tcx> {
}

// Create a statement which changes the discriminant
#[tracing::instrument(level = "trace", skip(self))]
fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
let self_place = Place::from(SELF_ARG);
Statement::new(
Expand All @@ -370,6 +378,7 @@ impl<'tcx> TransformVisitor<'tcx> {
}

// Create a statement which reads the discriminant into a temporary
#[tracing::instrument(level = "trace", skip(self, body))]
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
let temp_decl = LocalDecl::new(self.discr_ty, body.span);
let local_decls_len = body.local_decls.push(temp_decl);
Expand All @@ -382,55 +391,83 @@ impl<'tcx> TransformVisitor<'tcx> {
);
(assign, temp)
}

/// Swaps all references of `old_local` and `new_local`.
#[tracing::instrument(level = "trace", skip(self, body))]
fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
body.local_decls.swap(old_local, new_local);

let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
visitor.visit_body(body);
for suspension in &mut self.suspension_points {
let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
let location = Location { block: START_BLOCK, statement_index: 0 };
visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
}
}
}

impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
#[tracing::instrument(level = "trace", skip(self), ret)]
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
assert!(!self.remap.contains(*local));
}

fn visit_place(
&mut self,
place: &mut Place<'tcx>,
_context: PlaceContext,
_location: Location,
) {
#[tracing::instrument(level = "trace", skip(self), ret)]
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
// Replace an Local in the remap with a coroutine struct access
if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
}
}

fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
#[tracing::instrument(level = "trace", skip(self, stmt), ret)]
fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
// Remove StorageLive and StorageDead statements for remapped locals
for s in &mut data.statements {
if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = s.kind
&& self.remap.contains(l)
{
s.make_nop(true);
}
if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
&& self.remap.contains(l)
{
stmt.make_nop(true);
}
self.super_statement(stmt, location);
}

let ret_val = match data.terminator().kind {
#[tracing::instrument(level = "trace", skip(self, term), ret)]
fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
if let TerminatorKind::Return = term.kind {
// `visit_basic_block_data` introduces `Return` terminators which read `RETURN_PLACE`.
// But this `RETURN_PLACE` is already remapped, so we should not touch it again.
return;
}
self.super_terminator(term, location);
}

#[tracing::instrument(level = "trace", skip(self, data), ret)]
fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
match data.terminator().kind {
TerminatorKind::Return => {
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
}
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
Some((false, Some((resume, resume_arg)), value.clone(), drop))
let source_info = data.terminator().source_info;
// We must assign the value first in case it gets declared dead below
self.make_state(
Operand::Move(Place::return_place()),
source_info,
true,
&mut data.statements,
);
// Return state.
let state = VariantIdx::new(CoroutineArgs::RETURNED);
data.statements.push(self.set_discr(state, source_info));
data.terminator_mut().kind = TerminatorKind::Return;
}
_ => None,
};

if let Some((is_return, resume, v, drop)) = ret_val {
let source_info = data.terminator().source_info;
// We must assign the value first in case it gets declared dead below
self.make_state(v, source_info, is_return, &mut data.statements);
let state = if let Some((resume, mut resume_arg)) = resume {
// Yield
TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
let source_info = data.terminator().source_info;
// We must assign the value first in case it gets declared dead below
self.make_state(value.clone(), source_info, false, &mut data.statements);
// Yield state.
let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();

// The resume arg target location might itself be remapped if its base local is
Expand Down Expand Up @@ -461,13 +498,11 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
storage_liveness,
});

VariantIdx::new(state)
} else {
// Return
VariantIdx::new(CoroutineArgs::RETURNED) // state for returned
};
data.statements.push(self.set_discr(state, source_info));
data.terminator_mut().kind = TerminatorKind::Return;
let state = VariantIdx::new(state);
data.statements.push(self.set_discr(state, source_info));
data.terminator_mut().kind = TerminatorKind::Return;
}
_ => {}
}

self.super_basic_block_data(block, data);
Expand All @@ -483,6 +518,7 @@ fn make_aggregate_adt<'tcx>(
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
}

#[tracing::instrument(level = "trace", skip(tcx, body))]
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let coroutine_ty = body.local_decls.raw[1].ty;

Expand All @@ -495,6 +531,7 @@ fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Bo
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
}

#[tracing::instrument(level = "trace", skip(tcx, body))]
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let ref_coroutine_ty = body.local_decls.raw[1].ty;

Expand All @@ -511,27 +548,6 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
.visit_body(body);
}

/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
///
/// `local` will be changed to a new local decl with type `ty`.
///
/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
/// valid value to it before its first use.
fn replace_local<'tcx>(
local: Local,
ty: Ty<'tcx>,
body: &mut Body<'tcx>,
tcx: TyCtxt<'tcx>,
) -> Local {
let new_decl = LocalDecl::new(ty, body.span);
let new_local = body.local_decls.push(new_decl);
body.local_decls.swap(local, new_local);

RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);

new_local
}

/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
Expand All @@ -553,6 +569,7 @@ fn replace_local<'tcx>(
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
let context_mut_ref = Ty::new_task_context(tcx);

Expand Down Expand Up @@ -606,6 +623,7 @@ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local
}

#[cfg_attr(not(debug_assertions), allow(unused))]
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
Expand Down Expand Up @@ -670,6 +688,7 @@ struct LivenessInfo {
/// case none exist, the local is considered to be always live.
/// - a local has to be stored if it is either directly used after the
/// the suspend point, or if it is live and has been previously borrowed.
#[tracing::instrument(level = "trace", skip(tcx, body))]
fn locals_live_across_suspend_points<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
Expand Down Expand Up @@ -945,6 +964,7 @@ impl StorageConflictVisitor<'_, '_> {
}
}

#[tracing::instrument(level = "trace", skip(liveness, body))]
fn compute_layout<'tcx>(
liveness: LivenessInfo,
body: &Body<'tcx>,
Expand Down Expand Up @@ -1049,7 +1069,9 @@ fn compute_layout<'tcx>(
variant_source_info,
storage_conflicts,
};
debug!(?remap);
debug!(?layout);
debug!(?storage_liveness);

(remap, layout, storage_liveness)
}
Expand Down Expand Up @@ -1221,6 +1243,7 @@ fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
}
}

#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
fn create_coroutine_resume_function<'tcx>(
tcx: TyCtxt<'tcx>,
transform: TransformVisitor<'tcx>,
Expand Down Expand Up @@ -1299,7 +1322,7 @@ fn create_coroutine_resume_function<'tcx>(
}

/// An operation that can be performed on a coroutine.
#[derive(PartialEq, Copy, Clone)]
#[derive(PartialEq, Copy, Clone, Debug)]
enum Operation {
Resume,
Drop,
Expand All @@ -1314,6 +1337,7 @@ impl Operation {
}
}

#[tracing::instrument(level = "trace", skip(transform, body))]
fn create_cases<'tcx>(
body: &mut Body<'tcx>,
transform: &TransformVisitor<'tcx>,
Expand Down Expand Up @@ -1445,6 +1469,8 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// This only applies to coroutines
return;
};
tracing::trace!(def_id = ?body.source.def_id());

let old_ret_ty = body.return_ty();

assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
Expand Down Expand Up @@ -1491,10 +1517,6 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
}
};

// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);

// We need to insert clean drop for unresumed state and perform drop elaboration
// (finally in open_drop_for_tuple) before async drop expansion.
// Async drops, produced by this drop elaboration, will be expanded,
Expand Down Expand Up @@ -1541,6 +1563,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {

let can_return = can_return(tcx, body, body.typing_env(tcx));

// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
tracing::trace!(?new_ret_local);

// Run the transformation which converts Places from Local to coroutine struct
// accesses for locals in `remap`.
// It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
Expand All @@ -1553,13 +1580,16 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
storage_liveness,
always_live_locals,
suspension_points: Vec::new(),
old_ret_local,
discr_ty,
new_ret_local,
old_ret_ty,
old_yield_ty,
};
transform.visit_body(body);

// Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
transform.replace_local(RETURN_PLACE, new_ret_local, body);

// MIR parameters are not explicitly assigned-to when entering the MIR body.
// If we want to save their values inside the coroutine state, we need to do so explicitly.
let source_info = SourceInfo::outermost(body.span);
Expand Down
Loading
Loading