Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,11 @@ impl Path {
/// Check if this path is potentially a trivial const arg, i.e., one that can _potentially_
/// be represented without an anon const in the HIR.
///
/// If `allow_mgca_arg` is true (as should be the case in most situations when
/// `#![feature(min_generic_const_args)]` is enabled), then this always returns true
/// because all paths are valid.
///
/// Otherwise, it returns true iff the path has exactly one segment, and it has no generic args
/// Returns true iff the path has exactly one segment, and it has no generic args
/// (i.e., it is _potentially_ a const parameter).
#[tracing::instrument(level = "debug", ret)]
pub fn is_potential_trivial_const_arg(&self, allow_mgca_arg: bool) -> bool {
allow_mgca_arg
|| self.segments.len() == 1 && self.segments.iter().all(|seg| seg.args.is_none())
pub fn is_potential_trivial_const_arg(&self) -> bool {
self.segments.len() == 1 && self.segments.iter().all(|seg| seg.args.is_none())
}
}

Expand Down Expand Up @@ -1372,6 +1367,15 @@ pub enum UnsafeSource {
UserProvided,
}

/// Track whether under `feature(min_generic_const_args)` this anon const
/// was explicitly disambiguated as an anon const or not through the use of
/// `const { ... }` syntax.
#[derive(Clone, PartialEq, Encodable, Decodable, Debug, Copy, Walkable)]
pub enum MgcaDisambiguation {
AnonConst,
Direct,
}

/// A constant (expression) that's not an item or associated item,
/// but needs its own `DefId` for type-checking, const-eval, etc.
/// These are usually found nested inside types (e.g., array lengths)
Expand All @@ -1381,6 +1385,7 @@ pub enum UnsafeSource {
pub struct AnonConst {
pub id: NodeId,
pub value: Box<Expr>,
pub mgca_disambiguation: MgcaDisambiguation,
}

/// An expression.
Expand All @@ -1399,26 +1404,20 @@ impl Expr {
///
/// This will unwrap at most one block level (curly braces). After that, if the expression
/// is a path, it mostly dispatches to [`Path::is_potential_trivial_const_arg`].
/// See there for more info about `allow_mgca_arg`.
///
/// The only additional thing to note is that when `allow_mgca_arg` is false, this function
/// will only allow paths with no qself, before dispatching to the `Path` function of
/// the same name.
/// This function will only allow paths with no qself, before dispatching to the `Path`
/// function of the same name.
///
/// Does not ensure that the path resolves to a const param/item, the caller should check this.
/// This also does not consider macros, so it's only correct after macro-expansion.
pub fn is_potential_trivial_const_arg(&self, allow_mgca_arg: bool) -> bool {
pub fn is_potential_trivial_const_arg(&self) -> bool {
let this = self.maybe_unwrap_block();
if allow_mgca_arg {
matches!(this.kind, ExprKind::Path(..))
if let ExprKind::Path(None, path) = &this.kind
&& path.is_potential_trivial_const_arg()
{
true
} else {
if let ExprKind::Path(None, path) = &this.kind
&& path.is_potential_trivial_const_arg(allow_mgca_arg)
{
true
} else {
false
}
false
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ macro_rules! common_visitor_and_walkers {
UnsafeBinderCastKind,
BinOpKind,
BlockCheckMode,
MgcaDisambiguation,
BorrowKind,
BoundAsyncness,
BoundConstness,
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,11 @@ impl<'hir> LoweringContext<'_, 'hir> {
arg
};

let anon_const = AnonConst { id: node_id, value: const_value };
let anon_const = AnonConst {
id: node_id,
value: const_value,
mgca_disambiguation: MgcaDisambiguation::AnonConst,
};
generic_args.push(AngleBracketedArg::Arg(GenericArg::Const(anon_const)));
} else {
real_args.push(arg);
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_ast_lowering/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
});
}

fn visit_const_arg_expr_field(&mut self, field: &'hir ConstArgExprField<'hir>) {
self.insert(field.span, field.hir_id, Node::ConstArgExprField(field));
self.with_parent(field.hir_id, |this| {
intravisit::walk_const_arg_expr_field(this, field);
})
}

fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) {
self.insert(stmt.span, stmt.hir_id, Node::Stmt(stmt));

Expand Down
105 changes: 95 additions & 10 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
.and_then(|partial_res| partial_res.full_res())
{
if !res.matches_ns(Namespace::TypeNS)
&& path.is_potential_trivial_const_arg(false)
&& path.is_potential_trivial_const_arg()
{
debug!(
"lower_generic_arg: Lowering type argument as const argument: {:?}",
Expand Down Expand Up @@ -2276,11 +2276,9 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
) -> &'hir hir::ConstArg<'hir> {
let tcx = self.tcx;

let ct_kind = if path
.is_potential_trivial_const_arg(tcx.features().min_generic_const_args())
&& (tcx.features().min_generic_const_args()
|| matches!(res, Res::Def(DefKind::ConstParam, _)))
{
let is_trivial_path = path.is_potential_trivial_const_arg()
&& matches!(res, Res::Def(DefKind::ConstParam, _));
let ct_kind = if is_trivial_path || tcx.features().min_generic_const_args() {
let qpath = self.lower_qpath(
ty_id,
&None,
Expand Down Expand Up @@ -2359,6 +2357,81 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
}

#[instrument(level = "debug", skip(self), ret)]
fn lower_expr_to_const_arg_direct(&mut self, expr: &Expr) -> hir::ConstArg<'hir> {
let overly_complex_const = |this: &mut Self| {
let e = this.dcx().struct_span_err(
expr.span,
"complex const arguments must be placed inside of a `const` block",
);

ConstArg { hir_id: this.next_id(), kind: hir::ConstArgKind::Error(expr.span, e.emit()) }
};

match &expr.kind {
ExprKind::Path(qself, path) => {
let qpath = self.lower_qpath(
expr.id,
qself,
path,
ParamMode::Explicit,
AllowReturnTypeNotation::No,
// FIXME(mgca): update for `fn foo() -> Bar<FOO<impl Trait>>` support
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
);

ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Path(qpath) }
}
ExprKind::Struct(se) => {
let path = self.lower_qpath(
expr.id,
&se.qself,
&se.path,
ParamMode::Explicit,
AllowReturnTypeNotation::No,
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
);

let fields = self.arena.alloc_from_iter(se.fields.iter().map(|f| {
let hir_id = self.lower_node_id(f.id);
self.lower_attrs(hir_id, &f.attrs, f.span, Target::ExprField);

let expr = if let ExprKind::ConstBlock(anon_const) = &f.expr.kind {
self.lower_anon_const_to_const_arg_direct(anon_const)
} else {
self.lower_expr_to_const_arg_direct(&f.expr)
};

&*self.arena.alloc(hir::ConstArgExprField {
hir_id,
field: self.lower_ident(f.ident),
expr: self.arena.alloc(expr),
span: self.lower_span(f.span),
})
}));

ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Struct(path, fields) }
}
ExprKind::Underscore => ConstArg {
hir_id: self.lower_node_id(expr.id),
kind: hir::ConstArgKind::Infer(expr.span, ()),
},
ExprKind::Block(block, _) => {
if let [stmt] = block.stmts.as_slice()
&& let StmtKind::Expr(expr) = &stmt.kind
&& matches!(expr.kind, ExprKind::Path(..) | ExprKind::Struct(..))
{
return self.lower_expr_to_const_arg_direct(expr);
}

overly_complex_const(self)
}
_ => overly_complex_const(self),
}
}

/// See [`hir::ConstArg`] for when to use this function vs
/// [`Self::lower_anon_const_to_anon_const`].
fn lower_anon_const_to_const_arg(&mut self, anon: &AnonConst) -> &'hir hir::ConstArg<'hir> {
Expand All @@ -2379,20 +2452,32 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
} else {
&anon.value
};

if tcx.features().min_generic_const_args() {
match anon.mgca_disambiguation {
MgcaDisambiguation::AnonConst => {
let lowered_anon = self.lower_anon_const_to_anon_const(anon);
return ConstArg {
hir_id: self.next_id(),
kind: hir::ConstArgKind::Anon(lowered_anon),
};
}
MgcaDisambiguation::Direct => return self.lower_expr_to_const_arg_direct(expr),
}
}

let maybe_res =
self.resolver.get_partial_res(expr.id).and_then(|partial_res| partial_res.full_res());
if let ExprKind::Path(qself, path) = &expr.kind
&& path.is_potential_trivial_const_arg(tcx.features().min_generic_const_args())
&& (tcx.features().min_generic_const_args()
|| matches!(maybe_res, Some(Res::Def(DefKind::ConstParam, _))))
&& path.is_potential_trivial_const_arg()
&& matches!(maybe_res, Some(Res::Def(DefKind::ConstParam, _)))
{
let qpath = self.lower_qpath(
expr.id,
qself,
path,
ParamMode::Explicit,
AllowReturnTypeNotation::No,
// FIXME(mgca): update for `fn foo() -> Bar<FOO<impl Trait>>` support
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
None,
);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ pub fn check_crate(krate: &ast::Crate, sess: &Session, features: &Features) {
gate_all!(fn_delegation, "functions delegation is not yet fully implemented");
gate_all!(postfix_match, "postfix match is experimental");
gate_all!(mut_ref, "mutable by-reference bindings are experimental");
gate_all!(min_generic_const_args, "unbraced const blocks as const args are experimental");
gate_all!(global_registration, "global registration is experimental");
gate_all!(return_type_notation, "return type notation is experimental");
gate_all!(pin_ergonomics, "pinned reference syntax is experimental");
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod llvm_enzyme {
use rustc_ast::{
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::{Ident, Span, Symbol, sym};
Expand Down Expand Up @@ -558,7 +558,11 @@ mod llvm_enzyme {
}
GenericParamKind::Const { .. } => {
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
let anon_const = AnonConst {
id: ast::DUMMY_NODE_ID,
value: expr,
mgca_disambiguation: MgcaDisambiguation::Direct,
};
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
}
GenericParamKind::Lifetime { .. } => None,
Expand Down Expand Up @@ -813,6 +817,7 @@ mod llvm_enzyme {
let anon_const = rustc_ast::AnonConst {
id: ast::DUMMY_NODE_ID,
value: ecx.expr_usize(span, 1 + x.width as usize),
mgca_disambiguation: MgcaDisambiguation::Direct,
};
TyKind::Array(ty.clone(), anon_const)
};
Expand All @@ -827,6 +832,7 @@ mod llvm_enzyme {
let anon_const = rustc_ast::AnonConst {
id: ast::DUMMY_NODE_ID,
value: ecx.expr_usize(span, x.width as usize),
mgca_disambiguation: MgcaDisambiguation::Direct,
};
let kind = TyKind::Array(ty.clone(), anon_const);
let ty =
Expand Down
18 changes: 15 additions & 3 deletions compiler/rustc_builtin_macros/src/pattern_type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use rustc_ast::tokenstream::TokenStream;
use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token};
use rustc_ast::{AnonConst, DUMMY_NODE_ID, MgcaDisambiguation, Ty, TyPat, TyPatKind, ast, token};
use rustc_errors::PResult;
use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult};
use rustc_parse::exp;
Expand Down Expand Up @@ -60,8 +60,20 @@ fn ty_pat(kind: TyPatKind, span: Span) -> TyPat {
fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> TyPat {
let kind = match pat.kind {
ast::PatKind::Range(start, end, include_end) => TyPatKind::Range(
start.map(|value| Box::new(AnonConst { id: DUMMY_NODE_ID, value })),
end.map(|value| Box::new(AnonConst { id: DUMMY_NODE_ID, value })),
start.map(|value| {
Box::new(AnonConst {
id: DUMMY_NODE_ID,
value,
mgca_disambiguation: MgcaDisambiguation::Direct,
})
}),
end.map(|value| {
Box::new(AnonConst {
id: DUMMY_NODE_ID,
value,
mgca_disambiguation: MgcaDisambiguation::Direct,
})
}),
include_end,
),
ast::PatKind::Or(variants) => {
Expand Down
13 changes: 11 additions & 2 deletions compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use cranelift_codegen::ir::immediates::Offset32;
use rustc_abi::Endian;
use rustc_middle::ty::SimdAlign;
use rustc_middle::ty::{SimdAlign, ValTreeKindExt};

use super::*;
use crate::prelude::*;
Expand Down Expand Up @@ -143,7 +143,10 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(

let total_len = lane_count * 2;

let indexes = idx.iter().map(|idx| idx.unwrap_leaf().to_u32()).collect::<Vec<u32>>();
let indexes = idx
.iter()
.map(|idx| idx.to_value().valtree.unwrap_leaf().to_u32())
.collect::<Vec<u32>>();

for &idx in &indexes {
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
Expand Down Expand Up @@ -962,6 +965,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
let ptr_val = ptr.load_scalar(fx);

let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.to_value()
.valtree
.unwrap_leaf()
.to_simd_alignment();

Expand Down Expand Up @@ -1007,6 +1012,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
let ret_lane_layout = fx.layout_of(ret_lane_ty);

let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.to_value()
.valtree
.unwrap_leaf()
.to_simd_alignment();

Expand Down Expand Up @@ -1060,6 +1067,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
let ptr_val = ptr.load_scalar(fx);

let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.to_value()
.valtree
.unwrap_leaf()
.to_simd_alignment();

Expand Down
Loading
Loading