Skip to content

Commit

Permalink
AST: Refactor type alias where clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
fmease committed Feb 29, 2024
1 parent 384d26f commit 2b80605
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 85 deletions.
32 changes: 22 additions & 10 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,10 @@ impl Default for Generics {
/// A where-clause in a definition.
#[derive(Clone, Encodable, Decodable, Debug)]
pub struct WhereClause {
/// `true` if we ate a `where` token: this can happen
/// if we parsed no predicates (e.g. `struct Foo where {}`).
/// This allows us to pretty-print accurately.
/// `true` if we ate a `where` token.
///
/// This can happen if we parsed no predicates, e.g., `struct Foo where {}`.
/// This allows us to pretty-print accurately and provide correct suggestion diagnostics.
pub has_where_token: bool,
pub predicates: ThinVec<WherePredicate>,
pub span: Span,
Expand Down Expand Up @@ -3007,18 +3008,29 @@ pub struct Trait {
///
/// If there is no where clause, then this is `false` with `DUMMY_SP`.
#[derive(Copy, Clone, Encodable, Decodable, Debug, Default)]
pub struct TyAliasWhereClause(pub bool, pub Span);
pub struct TyAliasWhereClause {
pub has_where_token: bool,
pub span: Span,
}

/// The span information for the two where clauses on a `TyAlias`.
#[derive(Copy, Clone, Encodable, Decodable, Debug, Default)]
pub struct TyAliasWhereClauses {
/// Before the equals sign.
pub before: TyAliasWhereClause,
/// After the equals sign.
pub after: TyAliasWhereClause,
/// The index in `TyAlias.generics.where_clause.predicates` that would split
/// into predicates from the where clause before the equals sign and the ones
/// from the where clause after the equals sign.
pub split: usize,
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub struct TyAlias {
pub defaultness: Defaultness,
pub generics: Generics,
/// The span information for the two where clauses (before equals, after equals)
pub where_clauses: (TyAliasWhereClause, TyAliasWhereClause),
/// The index in `generics.where_clause.predicates` that would split into
/// predicates from the where clause before the equals and the predicates
/// from the where clause after the equals
pub where_predicates_split: usize,
pub where_clauses: TyAliasWhereClauses,
pub bounds: GenericBounds,
pub ty: Option<P<Ty>>,
}
Expand Down
12 changes: 6 additions & 6 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1079,8 +1079,8 @@ pub fn noop_visit_item_kind<T: MutVisitor>(kind: &mut ItemKind, vis: &mut T) {
}) => {
visit_defaultness(defaultness, vis);
vis.visit_generics(generics);
vis.visit_span(&mut where_clauses.0.1);
vis.visit_span(&mut where_clauses.1.1);
vis.visit_span(&mut where_clauses.before.span);
vis.visit_span(&mut where_clauses.after.span);
visit_bounds(bounds, vis);
visit_opt(ty, |ty| vis.visit_ty(ty));
}
Expand Down Expand Up @@ -1163,8 +1163,8 @@ pub fn noop_flat_map_assoc_item<T: MutVisitor>(
}) => {
visit_defaultness(defaultness, visitor);
visitor.visit_generics(generics);
visitor.visit_span(&mut where_clauses.0.1);
visitor.visit_span(&mut where_clauses.1.1);
visitor.visit_span(&mut where_clauses.before.span);
visitor.visit_span(&mut where_clauses.after.span);
visit_bounds(bounds, visitor);
visit_opt(ty, |ty| visitor.visit_ty(ty));
}
Expand Down Expand Up @@ -1257,8 +1257,8 @@ pub fn noop_flat_map_foreign_item<T: MutVisitor>(
}) => {
visit_defaultness(defaultness, visitor);
visitor.visit_generics(generics);
visitor.visit_span(&mut where_clauses.0.1);
visitor.visit_span(&mut where_clauses.1.1);
visitor.visit_span(&mut where_clauses.before.span);
visitor.visit_span(&mut where_clauses.after.span);
visit_bounds(bounds, visitor);
visit_opt(ty, |ty| visitor.visit_ty(ty));
}
Expand Down
19 changes: 10 additions & 9 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,20 @@ pub(super) struct ItemLowerer<'a, 'hir> {
/// clause if it exists.
fn add_ty_alias_where_clause(
generics: &mut ast::Generics,
mut where_clauses: (TyAliasWhereClause, TyAliasWhereClause),
mut where_clauses: TyAliasWhereClauses,
prefer_first: bool,
) {
if !prefer_first {
where_clauses = (where_clauses.1, where_clauses.0);
}
if where_clauses.0.0 || !where_clauses.1.0 {
generics.where_clause.has_where_token = where_clauses.0.0;
generics.where_clause.span = where_clauses.0.1;
} else {
generics.where_clause.has_where_token = where_clauses.1.0;
generics.where_clause.span = where_clauses.1.1;
(where_clauses.before, where_clauses.after) = (where_clauses.after, where_clauses.before);
}
let where_clause =
if where_clauses.before.has_where_token || !where_clauses.after.has_where_token {
where_clauses.before
} else {
where_clauses.after
};
generics.where_clause.has_where_token = where_clause.has_where_token;
generics.where_clause.span = where_clause.span;
}

impl<'a, 'hir> ItemLowerer<'a, 'hir> {
Expand Down
27 changes: 13 additions & 14 deletions compiler/rustc_ast_passes/src/ast_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ impl<'a> AstValidator<'a> {
ty_alias: &TyAlias,
) -> Result<(), errors::WhereClauseBeforeTypeAlias> {
let before_predicates =
ty_alias.generics.where_clause.predicates.split_at(ty_alias.where_predicates_split).0;
ty_alias.generics.where_clause.predicates.split_at(ty_alias.where_clauses.split).0;

if ty_alias.ty.is_none() || before_predicates.is_empty() {
return Ok(());
}

let mut state = State::new();
if !ty_alias.where_clauses.1.0 {
if !ty_alias.where_clauses.after.has_where_token {
state.space();
state.word_space("where");
} else {
Expand All @@ -161,13 +161,13 @@ impl<'a> AstValidator<'a> {
state.print_where_predicate(p);
}

let span = ty_alias.where_clauses.0.1;
let span = ty_alias.where_clauses.before.span;
Err(errors::WhereClauseBeforeTypeAlias {
span,
sugg: errors::WhereClauseBeforeTypeAliasSugg {
left: span,
snippet: state.s.eof(),
right: ty_alias.where_clauses.1.1.shrink_to_hi(),
right: ty_alias.where_clauses.after.span.shrink_to_hi(),
},
})
}
Expand Down Expand Up @@ -457,8 +457,7 @@ impl<'a> AstValidator<'a> {
fn check_foreign_ty_genericless(
&self,
generics: &Generics,
before_where_clause: &TyAliasWhereClause,
after_where_clause: &TyAliasWhereClause,
where_clauses: &TyAliasWhereClauses,
) {
let cannot_have = |span, descr, remove_descr| {
self.dcx().emit_err(errors::ExternTypesCannotHave {
Expand All @@ -473,14 +472,14 @@ impl<'a> AstValidator<'a> {
cannot_have(generics.span, "generic parameters", "generic parameters");
}

let check_where_clause = |where_clause: &TyAliasWhereClause| {
if let TyAliasWhereClause(true, where_clause_span) = where_clause {
cannot_have(*where_clause_span, "`where` clauses", "`where` clause");
let check_where_clause = |where_clause: TyAliasWhereClause| {
if where_clause.has_where_token {
cannot_have(where_clause.span, "`where` clauses", "`where` clause");
}
};

check_where_clause(before_where_clause);
check_where_clause(after_where_clause);
check_where_clause(where_clauses.before);
check_where_clause(where_clauses.after);
}

fn check_foreign_kind_bodyless(&self, ident: Ident, kind: &str, body: Option<Span>) {
Expand Down Expand Up @@ -1122,9 +1121,9 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
if let Err(err) = self.check_type_alias_where_clause_location(ty_alias) {
self.dcx().emit_err(err);
}
} else if where_clauses.1.0 {
} else if where_clauses.after.has_where_token {
self.dcx().emit_err(errors::WhereClauseAfterTypeAlias {
span: where_clauses.1.1,
span: where_clauses.after.span,
help: self.session.is_nightly_build().then_some(()),
});
}
Expand Down Expand Up @@ -1154,7 +1153,7 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
self.check_defaultness(fi.span, *defaultness);
self.check_foreign_kind_bodyless(fi.ident, "type", ty.as_ref().map(|b| b.span));
self.check_type_no_bounds(bounds, "`extern` blocks");
self.check_foreign_ty_genericless(generics, &where_clauses.0, &where_clauses.1);
self.check_foreign_ty_genericless(generics, where_clauses);
self.check_foreign_item_ascii_only(fi.ident);
}
ForeignItemKind::Static(_, _, body) => {
Expand Down
15 changes: 4 additions & 11 deletions compiler/rustc_ast_pretty/src/pprust/state/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@ impl<'a> State<'a> {
defaultness,
generics,
where_clauses,
where_predicates_split,
bounds,
ty,
}) => {
self.print_associated_type(
ident,
generics,
*where_clauses,
*where_predicates_split,
bounds,
ty.as_deref(),
vis,
Expand Down Expand Up @@ -108,15 +106,14 @@ impl<'a> State<'a> {
&mut self,
ident: Ident,
generics: &ast::Generics,
where_clauses: (ast::TyAliasWhereClause, ast::TyAliasWhereClause),
where_predicates_split: usize,
where_clauses: ast::TyAliasWhereClauses,
bounds: &ast::GenericBounds,
ty: Option<&ast::Ty>,
vis: &ast::Visibility,
defaultness: ast::Defaultness,
) {
let (before_predicates, after_predicates) =
generics.where_clause.predicates.split_at(where_predicates_split);
generics.where_clause.predicates.split_at(where_clauses.split);
self.head("");
self.print_visibility(vis);
self.print_defaultness(defaultness);
Expand All @@ -127,13 +124,13 @@ impl<'a> State<'a> {
self.word_nbsp(":");
self.print_type_bounds(bounds);
}
self.print_where_clause_parts(where_clauses.0.0, before_predicates);
self.print_where_clause_parts(where_clauses.before.has_where_token, before_predicates);
if let Some(ty) = ty {
self.space();
self.word_space("=");
self.print_type(ty);
}
self.print_where_clause_parts(where_clauses.1.0, after_predicates);
self.print_where_clause_parts(where_clauses.after.has_where_token, after_predicates);
self.word(";");
self.end(); // end inner head-block
self.end(); // end outer head-block
Expand Down Expand Up @@ -249,15 +246,13 @@ impl<'a> State<'a> {
defaultness,
generics,
where_clauses,
where_predicates_split,
bounds,
ty,
}) => {
self.print_associated_type(
item.ident,
generics,
*where_clauses,
*where_predicates_split,
bounds,
ty.as_deref(),
&item.vis,
Expand Down Expand Up @@ -536,15 +531,13 @@ impl<'a> State<'a> {
defaultness,
generics,
where_clauses,
where_predicates_split,
bounds,
ty,
}) => {
self.print_associated_type(
ident,
generics,
*where_clauses,
*where_predicates_split,
bounds,
ty.as_deref(),
vis,
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,7 @@ impl<'a> TraitDef<'a> {
kind: ast::AssocItemKind::Type(Box::new(ast::TyAlias {
defaultness: ast::Defaultness::Final,
generics: Generics::default(),
where_clauses: (
ast::TyAliasWhereClause::default(),
ast::TyAliasWhereClause::default(),
),
where_predicates_split: 0,
where_clauses: ast::TyAliasWhereClauses::default(),
bounds: Vec::new(),
ty: Some(type_def.to_ty(cx, self.span, type_ident, generics)),
})),
Expand Down
17 changes: 11 additions & 6 deletions compiler/rustc_parse/src/parser/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,11 +971,17 @@ impl<'a> Parser<'a> {

let after_where_clause = self.parse_where_clause()?;

let where_clauses = (
TyAliasWhereClause(before_where_clause.has_where_token, before_where_clause.span),
TyAliasWhereClause(after_where_clause.has_where_token, after_where_clause.span),
);
let where_predicates_split = before_where_clause.predicates.len();
let where_clauses = TyAliasWhereClauses {
before: TyAliasWhereClause {
has_where_token: before_where_clause.has_where_token,
span: before_where_clause.span,
},
after: TyAliasWhereClause {
has_where_token: after_where_clause.has_where_token,
span: after_where_clause.span,
},
split: before_where_clause.predicates.len(),
};
let mut predicates = before_where_clause.predicates;
predicates.extend(after_where_clause.predicates);
let where_clause = WhereClause {
Expand All @@ -994,7 +1000,6 @@ impl<'a> Parser<'a> {
defaultness,
generics,
where_clauses,
where_predicates_split,
bounds,
ty,
})),
Expand Down
30 changes: 6 additions & 24 deletions src/tools/rustfmt/src/items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1651,8 +1651,7 @@ struct TyAliasRewriteInfo<'c, 'g>(
&'c RewriteContext<'c>,
Indent,
&'g ast::Generics,
(ast::TyAliasWhereClause, ast::TyAliasWhereClause),
usize,
ast::TyAliasWhereClauses,
symbol::Ident,
Span,
);
Expand All @@ -1672,23 +1671,14 @@ pub(crate) fn rewrite_type_alias<'a, 'b>(
ref bounds,
ref ty,
where_clauses,
where_predicates_split,
} = *ty_alias_kind;
let ty_opt = ty.as_ref();
let (ident, vis) = match visitor_kind {
Item(i) => (i.ident, &i.vis),
AssocTraitItem(i) | AssocImplItem(i) => (i.ident, &i.vis),
ForeignItem(i) => (i.ident, &i.vis),
};
let rw_info = &TyAliasRewriteInfo(
context,
indent,
generics,
where_clauses,
where_predicates_split,
ident,
span,
);
let rw_info = &TyAliasRewriteInfo(context, indent, generics, where_clauses, ident, span);
let op_ty = opaque_ty(ty);
// Type Aliases are formatted slightly differently depending on the context
// in which they appear, whether they are opaque, and whether they are associated.
Expand Down Expand Up @@ -1724,19 +1714,11 @@ fn rewrite_ty<R: Rewrite>(
vis: &ast::Visibility,
) -> Option<String> {
let mut result = String::with_capacity(128);
let TyAliasRewriteInfo(
context,
indent,
generics,
where_clauses,
where_predicates_split,
ident,
span,
) = *rw_info;
let TyAliasRewriteInfo(context, indent, generics, where_clauses, ident, span) = *rw_info;
let (before_where_predicates, after_where_predicates) = generics
.where_clause
.predicates
.split_at(where_predicates_split);
.split_at(where_clauses.split);
if !after_where_predicates.is_empty() {
return None;
}
Expand Down Expand Up @@ -1771,7 +1753,7 @@ fn rewrite_ty<R: Rewrite>(
let where_clause_str = rewrite_where_clause(
context,
before_where_predicates,
where_clauses.0.1,
where_clauses.before.span,
context.config.brace_style(),
Shape::legacy(where_budget, indent),
false,
Expand All @@ -1795,7 +1777,7 @@ fn rewrite_ty<R: Rewrite>(
let comment_span = context
.snippet_provider
.opt_span_before(span, "=")
.map(|op_lo| mk_sp(where_clauses.0.1.hi(), op_lo));
.map(|op_lo| mk_sp(where_clauses.before.span.hi(), op_lo));

let lhs = match comment_span {
Some(comment_span)
Expand Down

0 comments on commit 2b80605

Please sign in to comment.