Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: only generate trait bound for associated types in field types #15000

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 60 additions & 0 deletions crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
);
}

#[test]
fn test_clone_expand_with_associated_types() {
check(
r#"
//- minicore: derive, clone
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);

#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
"#,
expect![[r#"
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);

#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}

impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {
qualified: qualified, shorthand: shorthand, generic: generic,
}
=>Foo {
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
}
,
}
}
}"#]],
);
}

#[test]
fn test_clone_expand_with_const_generics() {
check(
Expand Down
140 changes: 74 additions & 66 deletions crates/hir-expand/src/builtin_derive_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@ use ::tt::Ident;
use base_db::{CrateOrigin, LangCrateOrigin};
use itertools::izip;
use mbe::TokenMap;
use std::collections::HashSet;
use rustc_hash::FxHashSet;
use stdx::never;
use tracing::debug;

use crate::tt::{self, TokenId};
use syntax::{
ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName,
HasTypeBounds, PathType,
},
match_ast,
use crate::{
name::{AsName, Name},
tt::{self, TokenId},
};
use syntax::ast::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
};

use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
Expand Down Expand Up @@ -201,41 +200,54 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
debug!("no module item parsed");
ExpandError::Other("no item found".into())
})?;
let node = item.syntax();
let (name, params, shape) = match_ast! {
match node {
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)),
ast::Enum(it) => {
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
(
it.name(),
it.generic_param_list(),
AdtShape::Enum {
default_variant,
variants: it.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()?
}
)
},
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
_ => {
debug!("unexpected node is {:?}", node);
return Err(ExpandError::Other("expected struct, enum or union".into()))
},
let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
debug!("expected adt, found: {:?}", item);
ExpandError::Other("expected struct, enum or union".into())
})?;
let (name, generic_param_list, shape) = match &adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
),
ast::Adt::Enum(it) => {
let default_variant = it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
(
it.name(),
it.generic_param_list(),
AdtShape::Enum {
default_variant,
variants: it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.map(|x| {
Ok((
name_to_token(&token_map, x.name())?,
VariantShape::from(x.field_list(), &token_map)?,
))
})
.collect::<Result<_, ExpandError>>()?,
},
)
}
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
};
let mut param_type_set: HashSet<String> = HashSet::new();
let param_types = params

let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
let param_types = generic_param_list
.into_iter()
.flat_map(|param_list| param_list.type_or_const_params())
.map(|param| {
let name = {
let this = param.name();
match this {
Some(x) => {
param_type_set.insert(x.to_string());
param_type_set.insert(x.as_name());
mbe::syntax_node_to_token_tree(x.syntax()).0
}
None => tt::Subtree::empty(),
Expand All @@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
(name, ty, bounds)
})
.collect();
let is_associated_type = |p: &PathType| {
if let Some(p) = p.path() {
if let Some(parent) = p.qualifier() {
if let Some(x) = parent.segment() {
if let Some(x) = x.path_type() {
if let Some(x) = x.path() {
if let Some(pname) = x.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// <T as Trait>::Assoc
return true;
}
}
}
}
}
if let Some(pname) = parent.as_single_name_ref() {
if param_type_set.contains(&pname.to_string()) {
// T::Assoc
return true;
}
}
}
}
false

// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
// does not do that for some unknown reason.
//
// See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
// [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378

// It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
// `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
// we should not inspect `ast::PathType`s in parameter bounds and where clauses.
let field_list = match adt {
ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
};
let associated_types = node
.descendants()
.filter_map(PathType::cast)
.filter(is_associated_type)
let associated_types = field_list
.into_iter()
.flat_map(|it| it.descendants())
.filter_map(ast::PathType::cast)
.filter_map(|p| {
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
param_type_set.contains(&name).then_some(p)
})
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.collect::<Vec<_>>();
.collect();
let name_token = name_to_token(&token_map, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
}
Expand Down Expand Up @@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
/// }
/// ```
///
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and
/// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
/// therefore does not get bound by the derived trait.
fn expand_simple_derive(
tt: &tt::Subtree,
trait_path: tt::Subtree,
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) {
Ok(info) => info,
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
};
let trait_body = trait_body(&info);
let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
Expand Down
10 changes: 5 additions & 5 deletions crates/hir-ty/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
#[derive(Clone)]
struct AssocGeneric<T: Tr>(T::Assoc);

#[derive(Clone)]
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
// Currently rustc does not accept this.
// #[derive(Clone)]
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);

#[derive(Clone)]
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
Expand All @@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
let x = x.clone();
//^ &AssocGeneric<Copy>
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
let x = x.clone();
//^ &AssocGeneric2<Copy>
// let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
// let x = x.clone();
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
let x = x.clone();
//^ &AssocGeneric3<Copy>
Expand Down