Skip to content

Commit eeb4682

Browse files
Auto merge of #149114 - BoxyUwU:mgca_adt_exprs, r=<try>
MGCA: Support struct expressions without intermediary anon consts
2 parents d2f8873 + 734737a commit eeb4682

File tree

116 files changed

+1278
-435
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+1278
-435
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,11 @@ impl Path {
141141
/// Check if this path is potentially a trivial const arg, i.e., one that can _potentially_
142142
/// be represented without an anon const in the HIR.
143143
///
144-
/// If `allow_mgca_arg` is true (as should be the case in most situations when
145-
/// `#![feature(min_generic_const_args)]` is enabled), then this always returns true
146-
/// because all paths are valid.
147-
///
148-
/// Otherwise, it returns true iff the path has exactly one segment, and it has no generic args
144+
/// Returns true iff the path has exactly one segment, and it has no generic args
149145
/// (i.e., it is _potentially_ a const parameter).
150146
#[tracing::instrument(level = "debug", ret)]
151-
pub fn is_potential_trivial_const_arg(&self, allow_mgca_arg: bool) -> bool {
152-
allow_mgca_arg
153-
|| self.segments.len() == 1 && self.segments.iter().all(|seg| seg.args.is_none())
147+
pub fn is_potential_trivial_const_arg(&self) -> bool {
148+
self.segments.len() == 1 && self.segments.iter().all(|seg| seg.args.is_none())
154149
}
155150
}
156151

@@ -1372,6 +1367,15 @@ pub enum UnsafeSource {
13721367
UserProvided,
13731368
}
13741369

1370+
/// Track whether under `feature(min_generic_const_args)` this anon const
1371+
/// was explicitly disambiguated as an anon const or not through the use of
1372+
/// `const { ... }` syntax.
1373+
#[derive(Clone, PartialEq, Encodable, Decodable, Debug, Copy, Walkable)]
1374+
pub enum MgcaDisambiguation {
1375+
AnonConst,
1376+
Direct,
1377+
}
1378+
13751379
/// A constant (expression) that's not an item or associated item,
13761380
/// but needs its own `DefId` for type-checking, const-eval, etc.
13771381
/// These are usually found nested inside types (e.g., array lengths)
@@ -1381,6 +1385,7 @@ pub enum UnsafeSource {
13811385
pub struct AnonConst {
13821386
pub id: NodeId,
13831387
pub value: Box<Expr>,
1388+
pub mgca_disambiguation: MgcaDisambiguation,
13841389
}
13851390

13861391
/// An expression.
@@ -1399,26 +1404,20 @@ impl Expr {
13991404
///
14001405
/// This will unwrap at most one block level (curly braces). After that, if the expression
14011406
/// is a path, it mostly dispatches to [`Path::is_potential_trivial_const_arg`].
1402-
/// See there for more info about `allow_mgca_arg`.
14031407
///
1404-
/// The only additional thing to note is that when `allow_mgca_arg` is false, this function
1405-
/// will only allow paths with no qself, before dispatching to the `Path` function of
1406-
/// the same name.
1408+
/// This function will only allow paths with no qself, before dispatching to the `Path`
1409+
/// function of the same name.
14071410
///
14081411
/// Does not ensure that the path resolves to a const param/item, the caller should check this.
14091412
/// This also does not consider macros, so it's only correct after macro-expansion.
1410-
pub fn is_potential_trivial_const_arg(&self, allow_mgca_arg: bool) -> bool {
1413+
pub fn is_potential_trivial_const_arg(&self) -> bool {
14111414
let this = self.maybe_unwrap_block();
1412-
if allow_mgca_arg {
1413-
matches!(this.kind, ExprKind::Path(..))
1415+
if let ExprKind::Path(None, path) = &this.kind
1416+
&& path.is_potential_trivial_const_arg()
1417+
{
1418+
true
14141419
} else {
1415-
if let ExprKind::Path(None, path) = &this.kind
1416-
&& path.is_potential_trivial_const_arg(allow_mgca_arg)
1417-
{
1418-
true
1419-
} else {
1420-
false
1421-
}
1420+
false
14221421
}
14231422
}
14241423

compiler/rustc_ast/src/visit.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ macro_rules! common_visitor_and_walkers {
415415
UnsafeBinderCastKind,
416416
BinOpKind,
417417
BlockCheckMode,
418+
MgcaDisambiguation,
418419
BorrowKind,
419420
BoundAsyncness,
420421
BoundConstness,

compiler/rustc_ast_lowering/src/expr.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
484484
arg
485485
};
486486

487-
let anon_const = AnonConst { id: node_id, value: const_value };
487+
let anon_const = AnonConst {
488+
id: node_id,
489+
value: const_value,
490+
mgca_disambiguation: MgcaDisambiguation::AnonConst,
491+
};
488492
generic_args.push(AngleBracketedArg::Arg(GenericArg::Const(anon_const)));
489493
} else {
490494
real_args.push(arg);

compiler/rustc_ast_lowering/src/index.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
281281
});
282282
}
283283

284+
fn visit_const_arg_expr_field(&mut self, field: &'hir ConstArgExprField<'hir>) {
285+
self.insert(field.span, field.hir_id, Node::ConstArgExprField(field));
286+
self.with_parent(field.hir_id, |this| {
287+
intravisit::walk_const_arg_expr_field(this, field);
288+
})
289+
}
290+
284291
fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) {
285292
self.insert(stmt.span, stmt.hir_id, Node::Stmt(stmt));
286293

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,7 +1206,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
12061206
.and_then(|partial_res| partial_res.full_res())
12071207
{
12081208
if !res.matches_ns(Namespace::TypeNS)
1209-
&& path.is_potential_trivial_const_arg(false)
1209+
&& path.is_potential_trivial_const_arg()
12101210
{
12111211
debug!(
12121212
"lower_generic_arg: Lowering type argument as const argument: {:?}",
@@ -2276,11 +2276,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
22762276
) -> &'hir hir::ConstArg<'hir> {
22772277
let tcx = self.tcx;
22782278

2279-
let ct_kind = if path
2280-
.is_potential_trivial_const_arg(tcx.features().min_generic_const_args())
2281-
&& (tcx.features().min_generic_const_args()
2282-
|| matches!(res, Res::Def(DefKind::ConstParam, _)))
2283-
{
2279+
let is_trivial_path = path.is_potential_trivial_const_arg()
2280+
&& matches!(res, Res::Def(DefKind::ConstParam, _));
2281+
let ct_kind = if is_trivial_path || tcx.features().min_generic_const_args() {
22842282
let qpath = self.lower_qpath(
22852283
ty_id,
22862284
&None,
@@ -2359,6 +2357,81 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
23592357
}
23602358
}
23612359

2360+
#[instrument(level = "debug", skip(self), ret)]
2361+
fn lower_expr_to_const_arg_direct(&mut self, expr: &Expr) -> hir::ConstArg<'hir> {
2362+
let overly_complex_const = |this: &mut Self| {
2363+
let e = this.dcx().struct_span_err(
2364+
expr.span,
2365+
"complex const arguments must be placed inside of a `const` block",
2366+
);
2367+
2368+
ConstArg { hir_id: this.next_id(), kind: hir::ConstArgKind::Error(expr.span, e.emit()) }
2369+
};
2370+
2371+
match &expr.kind {
2372+
ExprKind::Path(qself, path) => {
2373+
let qpath = self.lower_qpath(
2374+
expr.id,
2375+
qself,
2376+
path,
2377+
ParamMode::Explicit,
2378+
AllowReturnTypeNotation::No,
2379+
// FIXME(mgca): update for `fn foo() -> Bar<FOO<impl Trait>>` support
2380+
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
2381+
None,
2382+
);
2383+
2384+
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Path(qpath) }
2385+
}
2386+
ExprKind::Struct(se) => {
2387+
let path = self.lower_qpath(
2388+
expr.id,
2389+
&se.qself,
2390+
&se.path,
2391+
ParamMode::Explicit,
2392+
AllowReturnTypeNotation::No,
2393+
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
2394+
None,
2395+
);
2396+
2397+
let fields = self.arena.alloc_from_iter(se.fields.iter().map(|f| {
2398+
let hir_id = self.lower_node_id(f.id);
2399+
self.lower_attrs(hir_id, &f.attrs, f.span, Target::ExprField);
2400+
2401+
let expr = if let ExprKind::ConstBlock(anon_const) = &f.expr.kind {
2402+
self.lower_anon_const_to_const_arg_direct(anon_const)
2403+
} else {
2404+
self.lower_expr_to_const_arg_direct(&f.expr)
2405+
};
2406+
2407+
&*self.arena.alloc(hir::ConstArgExprField {
2408+
hir_id,
2409+
field: self.lower_ident(f.ident),
2410+
expr: self.arena.alloc(expr),
2411+
span: self.lower_span(f.span),
2412+
})
2413+
}));
2414+
2415+
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Struct(path, fields) }
2416+
}
2417+
ExprKind::Underscore => ConstArg {
2418+
hir_id: self.lower_node_id(expr.id),
2419+
kind: hir::ConstArgKind::Infer(expr.span, ()),
2420+
},
2421+
ExprKind::Block(block, _) => {
2422+
if let [stmt] = block.stmts.as_slice()
2423+
&& let StmtKind::Expr(expr) = &stmt.kind
2424+
&& matches!(expr.kind, ExprKind::Path(..) | ExprKind::Struct(..))
2425+
{
2426+
return self.lower_expr_to_const_arg_direct(expr);
2427+
}
2428+
2429+
overly_complex_const(self)
2430+
}
2431+
_ => overly_complex_const(self),
2432+
}
2433+
}
2434+
23622435
/// See [`hir::ConstArg`] for when to use this function vs
23632436
/// [`Self::lower_anon_const_to_anon_const`].
23642437
fn lower_anon_const_to_const_arg(&mut self, anon: &AnonConst) -> &'hir hir::ConstArg<'hir> {
@@ -2379,20 +2452,32 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
23792452
} else {
23802453
&anon.value
23812454
};
2455+
2456+
if tcx.features().min_generic_const_args() {
2457+
match anon.mgca_disambiguation {
2458+
MgcaDisambiguation::AnonConst => {
2459+
let lowered_anon = self.lower_anon_const_to_anon_const(anon);
2460+
return ConstArg {
2461+
hir_id: self.next_id(),
2462+
kind: hir::ConstArgKind::Anon(lowered_anon),
2463+
};
2464+
}
2465+
MgcaDisambiguation::Direct => return self.lower_expr_to_const_arg_direct(expr),
2466+
}
2467+
}
2468+
23822469
let maybe_res =
23832470
self.resolver.get_partial_res(expr.id).and_then(|partial_res| partial_res.full_res());
23842471
if let ExprKind::Path(qself, path) = &expr.kind
2385-
&& path.is_potential_trivial_const_arg(tcx.features().min_generic_const_args())
2386-
&& (tcx.features().min_generic_const_args()
2387-
|| matches!(maybe_res, Some(Res::Def(DefKind::ConstParam, _))))
2472+
&& path.is_potential_trivial_const_arg()
2473+
&& matches!(maybe_res, Some(Res::Def(DefKind::ConstParam, _)))
23882474
{
23892475
let qpath = self.lower_qpath(
23902476
expr.id,
23912477
qself,
23922478
path,
23932479
ParamMode::Explicit,
23942480
AllowReturnTypeNotation::No,
2395-
// FIXME(mgca): update for `fn foo() -> Bar<FOO<impl Trait>>` support
23962481
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
23972482
None,
23982483
);

compiler/rustc_ast_passes/src/feature_gate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,7 @@ pub fn check_crate(krate: &ast::Crate, sess: &Session, features: &Features) {
514514
gate_all!(fn_delegation, "functions delegation is not yet fully implemented");
515515
gate_all!(postfix_match, "postfix match is experimental");
516516
gate_all!(mut_ref, "mutable by-reference bindings are experimental");
517+
gate_all!(min_generic_const_args, "unbraced const blocks as const args are experimental");
517518
gate_all!(global_registration, "global registration is experimental");
518519
gate_all!(return_type_notation, "return type notation is experimental");
519520
gate_all!(pin_ergonomics, "pinned reference syntax is experimental");

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod llvm_enzyme {
1717
use rustc_ast::{
1818
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
1919
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
20-
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
20+
MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
2121
};
2222
use rustc_expand::base::{Annotatable, ExtCtxt};
2323
use rustc_span::{Ident, Span, Symbol, sym};
@@ -558,7 +558,11 @@ mod llvm_enzyme {
558558
}
559559
GenericParamKind::Const { .. } => {
560560
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
561-
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
561+
let anon_const = AnonConst {
562+
id: ast::DUMMY_NODE_ID,
563+
value: expr,
564+
mgca_disambiguation: MgcaDisambiguation::Direct,
565+
};
562566
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
563567
}
564568
GenericParamKind::Lifetime { .. } => None,
@@ -813,6 +817,7 @@ mod llvm_enzyme {
813817
let anon_const = rustc_ast::AnonConst {
814818
id: ast::DUMMY_NODE_ID,
815819
value: ecx.expr_usize(span, 1 + x.width as usize),
820+
mgca_disambiguation: MgcaDisambiguation::Direct,
816821
};
817822
TyKind::Array(ty.clone(), anon_const)
818823
};
@@ -827,6 +832,7 @@ mod llvm_enzyme {
827832
let anon_const = rustc_ast::AnonConst {
828833
id: ast::DUMMY_NODE_ID,
829834
value: ecx.expr_usize(span, x.width as usize),
835+
mgca_disambiguation: MgcaDisambiguation::Direct,
830836
};
831837
let kind = TyKind::Array(ty.clone(), anon_const);
832838
let ty =

compiler/rustc_builtin_macros/src/pattern_type.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rustc_ast::tokenstream::TokenStream;
2-
use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token};
2+
use rustc_ast::{AnonConst, DUMMY_NODE_ID, MgcaDisambiguation, Ty, TyPat, TyPatKind, ast, token};
33
use rustc_errors::PResult;
44
use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult};
55
use rustc_parse::exp;
@@ -60,8 +60,20 @@ fn ty_pat(kind: TyPatKind, span: Span) -> TyPat {
6060
fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> TyPat {
6161
let kind = match pat.kind {
6262
ast::PatKind::Range(start, end, include_end) => TyPatKind::Range(
63-
start.map(|value| Box::new(AnonConst { id: DUMMY_NODE_ID, value })),
64-
end.map(|value| Box::new(AnonConst { id: DUMMY_NODE_ID, value })),
63+
start.map(|value| {
64+
Box::new(AnonConst {
65+
id: DUMMY_NODE_ID,
66+
value,
67+
mgca_disambiguation: MgcaDisambiguation::Direct,
68+
})
69+
}),
70+
end.map(|value| {
71+
Box::new(AnonConst {
72+
id: DUMMY_NODE_ID,
73+
value,
74+
mgca_disambiguation: MgcaDisambiguation::Direct,
75+
})
76+
}),
6577
include_end,
6678
),
6779
ast::PatKind::Or(variants) => {

compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use cranelift_codegen::ir::immediates::Offset32;
44
use rustc_abi::Endian;
5-
use rustc_middle::ty::SimdAlign;
5+
use rustc_middle::ty::{SimdAlign, ValTreeKindExt};
66

77
use super::*;
88
use crate::prelude::*;
@@ -143,7 +143,10 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
143143

144144
let total_len = lane_count * 2;
145145

146-
let indexes = idx.iter().map(|idx| idx.unwrap_leaf().to_u32()).collect::<Vec<u32>>();
146+
let indexes = idx
147+
.iter()
148+
.map(|idx| idx.to_value().valtree.unwrap_leaf().to_u32())
149+
.collect::<Vec<u32>>();
147150

148151
for &idx in &indexes {
149152
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
@@ -962,6 +965,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
962965
let ptr_val = ptr.load_scalar(fx);
963966

964967
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
968+
.to_value()
969+
.valtree
965970
.unwrap_leaf()
966971
.to_simd_alignment();
967972

@@ -1007,6 +1012,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10071012
let ret_lane_layout = fx.layout_of(ret_lane_ty);
10081013

10091014
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1015+
.to_value()
1016+
.valtree
10101017
.unwrap_leaf()
10111018
.to_simd_alignment();
10121019

@@ -1060,6 +1067,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10601067
let ptr_val = ptr.load_scalar(fx);
10611068

10621069
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1070+
.to_value()
1071+
.valtree
10631072
.unwrap_leaf()
10641073
.to_simd_alignment();
10651074

0 commit comments

Comments
 (0)