Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ByMove coroutine-closure shim (for 2021 precise closure capturing behavior) #123518

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 56 additions & 4 deletions compiler/rustc_mir_transform/src/coroutine/by_move_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,16 @@
//! borrowing from the outer closure, and we simply peel off a `deref` projection
//! from them. This second body is stored alongside the first body, and optimized
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
//! we use this "by move" body instead.
//! we use this "by-move" body instead.
//!
//! ## How does this work?
//!
//! This pass essentially remaps the body of the (child) closure of the coroutine-closure
//! to take the set of upvars of the parent closure by value. This at least requires
//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure
//! captures something by value; however, it may also require renumbering field indices
//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine
//! to split one field capture into two.

use rustc_data_structures::unord::UnordMap;
use rustc_hir as hir;
Expand Down Expand Up @@ -117,8 +126,15 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {

let mut field_remapping = UnordMap::default();

// One parent capture may correspond to several child captures if we end up
// refining the set of captures via edition-2021 precise captures. We want to
// match up any number of child captures with one parent capture, so we keep
// peeking off this `Peekable` until the child doesn't match anymore.
let mut parent_captures =
tcx.closure_captures(parent_def_id).iter().copied().enumerate().peekable();
// Make sure we use every field at least once, b/c why are we capturing something
// if it's not used in the inner coroutine.
let mut field_used_at_least_once = false;

for (child_field_idx, child_capture) in tcx
.closure_captures(coroutine_def_id)
Expand All @@ -133,20 +149,36 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
bug!("we ran out of parent captures!")
};

assert!(
child_capture.place.projections.len() >= parent_capture.place.projections.len()
);
// A parent matches a child they share the same prefix of projections.
// The child may have more, if it is capturing sub-fields out of
// something that is captured by-move in the parent closure.
if !std::iter::zip(
&child_capture.place.projections,
&parent_capture.place.projections,
)
.all(|(child, parent)| child.kind == parent.kind)
{
// Make sure the field was used at least once.
assert!(
field_used_at_least_once,
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
);
field_used_at_least_once = false;
// Skip this field.
let _ = parent_captures.next().unwrap();
continue;
}

// Store this set of additional projections (fields and derefs).
// We need to re-apply them later.
let child_precise_captures =
&child_capture.place.projections[parent_capture.place.projections.len()..];

// If the parent captures by-move, and the child captures by-ref, then we
// need to peel an additional `deref` off of the body of the child.
let needs_deref = child_capture.is_by_ref() && !parent_capture.is_by_ref();
if needs_deref {
assert_ne!(
Expand All @@ -157,6 +189,8 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
);
}

// Finally, store the type of the parent's captured place. We need
// this when building the field projection in the MIR body later on.
let mut parent_capture_ty = parent_capture.place.ty();
parent_capture_ty = match parent_capture.info.capture_kind {
ty::UpvarCapture::ByValue => parent_capture_ty,
Expand All @@ -178,6 +212,7 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
),
);

field_used_at_least_once = true;
break;
}
}
Expand Down Expand Up @@ -226,21 +261,34 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
context: mir::visit::PlaceContext,
location: mir::Location,
) {
// Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a
// field projection. If this is in `field_remapping`, then it must not be an
// arg from calling the closure, but instead an upvar.
if place.local == ty::CAPTURE_STRUCT_LOCAL
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
place.projection.split_first()
&& let Some(&(remapped_idx, remapped_ty, needs_deref, additional_projections)) =
self.field_remapping.get(&idx)
{
// As noted before, if the parent closure captures a field by value, and
// the child captures a field by ref, then for the by-move body we're
// generating, we also are taking that field by value. Peel off a deref,
// since a layer of reffing has now become redundant.
let final_deref = if needs_deref {
let Some((mir::ProjectionElem::Deref, rest)) = projection.split_first() else {
bug!();
let [mir::ProjectionElem::Deref] = projection else {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this assumption was way too strong. Fixed in a follow-up.

bug!("There should only be a single deref for an upvar local initialization");
};
rest
&[]
} else {
projection
};

// The only thing that should be left is a deref, if the parent captured
// an upvar by-ref.
std::assert_matches::assert_matches!(final_deref, [] | [mir::ProjectionElem::Deref]);

// For all of the additional projections that come out of precise capturing,
// re-apply these projections.
let additional_projections =
additional_projections.iter().map(|elem| match elem.kind {
ProjectionKind::Deref => mir::ProjectionElem::Deref,
Expand All @@ -250,6 +298,10 @@ impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
_ => unreachable!("precise captures only through fields and derefs"),
});

// We start out with an adjusted field index (and ty), representing the
// upvar that we get from our parent closure. We apply any of the additional
// projections to make sure that to the rest of the body of the closure, the
// place looks the same, and then apply that final deref if necessary.
*place = mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems_from_iter(
Expand Down