Skip to content

Commit

Permalink
fix: preserve where clause when builtin derive
Browse files Browse the repository at this point in the history
  • Loading branch information
Austaras committed Feb 8, 2024
1 parent e9d3565 commit dad0fdb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
Expand Up @@ -157,7 +157,7 @@ where
generic: Vec<T::InGenericArg>,
}
impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
impl <T: $crate::clone::Clone, > $crate::clone::Clone for Foo<T, > where <T as Trait>::InWc: Marker, T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {
Expand Down
22 changes: 18 additions & 4 deletions crates/hir-expand/src/builtin_derive_macro.rs
Expand Up @@ -194,6 +194,7 @@ struct BasicAdtInfo {
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
where_clause: Vec<tt::Subtree>,
associated_types: Vec<tt::Subtree>,
}

Expand All @@ -202,10 +203,11 @@ fn parse_adt(
adt: &ast::Adt,
call_site: Span,
) -> Result<BasicAdtInfo, ExpandError> {
let (name, generic_param_list, shape) = match adt {
let (name, generic_param_list, where_clause, shape) = match adt {
ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
it.where_clause(),
AdtShape::Struct(VariantShape::from(tm, it.field_list())?),
),
ast::Adt::Enum(it) => {
Expand All @@ -217,6 +219,7 @@ fn parse_adt(
(
it.name(),
it.generic_param_list(),
it.where_clause(),
AdtShape::Enum {
default_variant,
variants: it
Expand All @@ -233,7 +236,9 @@ fn parse_adt(
},
)
}
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
ast::Adt::Union(it) => {
(it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union)
}
};

let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
Expand Down Expand Up @@ -274,6 +279,14 @@ fn parse_adt(
})
.collect();

let where_clause = if let Some(w) = where_clause {
w.predicates()
.map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
.collect()
} else {
vec![]
};

// For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
// types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
// also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
Expand Down Expand Up @@ -301,7 +314,7 @@ fn parse_adt(
.map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site))
.collect();
let name_token = name_to_token(tm, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
}

fn name_to_token(
Expand Down Expand Up @@ -366,7 +379,8 @@ fn expand_simple_derive(
}
};
let trait_body = make_trait_body(&info);
let mut where_block = vec![];
let mut where_block: Vec<_> =
info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
let (params, args): (Vec<_>, Vec<_>) = info
.param_types
.into_iter()
Expand Down
31 changes: 31 additions & 0 deletions crates/hir-ty/src/tests/macros.rs
Expand Up @@ -1373,3 +1373,34 @@ pub fn attr_macro() {}
"#,
);
}

#[test]
fn clone_with_type_bound() {
check_types(
r#"
//- minicore: derive, clone, builtin_impls
#[derive(Clone)]
struct Float;
trait TensorKind: Clone {
/// The primitive type of the tensor.
type Primitive: Clone;
}
impl TensorKind for Float {
type Primitive = f64;
}
#[derive(Clone)]
struct Tensor<K = Float> where K: TensorKind
{
primitive: K::Primitive,
}
fn foo(t: Tensor) {
let x = t.clone();
//^ Tensor<Float>
}
"#,
);
}

0 comments on commit dad0fdb

Please sign in to comment.