From a8f56112eab74521dfc75ef811f1a9ea23bcf43c Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Wed, 6 Mar 2024 21:16:41 +0900 Subject: [PATCH] fix: Function argument type inference with associated type impl trait --- crates/hir-ty/src/lower.rs | 96 +++++++++++++++++++++++++++---- crates/hir-ty/src/tests/traits.rs | 47 +++++++++++++++ 2 files changed, 132 insertions(+), 11 deletions(-) diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index 75ac3b0d66b8..dac20f225971 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -995,12 +995,12 @@ impl<'a> TyLoweringContext<'a> { pub(crate) fn lower_type_bound( &'a self, - bound: &'a TypeBound, + bound: &'a Interned, self_ty: Ty, ignore_bindings: bool, ) -> impl Iterator + 'a { let mut bindings = None; - let trait_ref = match bound { + let trait_ref = match bound.as_ref() { TypeBound::Path(path, TraitBoundModifier::None) => { bindings = self.lower_trait_ref_from_path(path, Some(self_ty)); bindings @@ -1055,10 +1055,10 @@ impl<'a> TyLoweringContext<'a> { fn assoc_type_bindings_from_type_bound( &'a self, - bound: &'a TypeBound, + bound: &'a Interned, trait_ref: TraitRef, ) -> impl Iterator + 'a { - let last_segment = match bound { + let last_segment = match bound.as_ref() { TypeBound::Path(path, TraitBoundModifier::None) | TypeBound::ForLifetime(_, path) => { path.segments().last() } @@ -1121,7 +1121,63 @@ impl<'a> TyLoweringContext<'a> { ); } } else { - let ty = self.lower_ty(type_ref); + let ty = 'ty: { + if matches!( + self.impl_trait_mode, + ImplTraitLoweringState::Param(_) + | ImplTraitLoweringState::Variable(_) + ) { + // Find the generic index for the target of our `bound` + let target_param_idx = self + .resolver + .where_predicates_in_scope() + .find_map(|p| match p { + WherePredicate::TypeBound { + target: WherePredicateTypeTarget::TypeOrConstParam(idx), + bound: b, + } if b == bound => Some(idx), + _ => None, + }); + if let Some(target_param_idx) = target_param_idx { + let mut counter = 0; + for (idx, data) in self.generics().params.type_or_consts.iter() + { + // Count the number of `impl Trait` things that appear before + // the target of our `bound`. + // Our counter within `impl_trait_mode` should be that number + // to properly lower each types within `type_ref` + if data.type_param().is_some_and(|p| { + p.provenance == TypeParamProvenance::ArgumentImplTrait + }) { + counter += 1; + } + if idx == *target_param_idx { + break; + } + } + let mut ext = TyLoweringContext::new_maybe_unowned( + self.db, + self.resolver, + self.owner, + ) + .with_type_param_mode(self.type_param_mode); + match &self.impl_trait_mode { + ImplTraitLoweringState::Param(_) => { + ext.impl_trait_mode = + ImplTraitLoweringState::Param(Cell::new(counter)); + } + ImplTraitLoweringState::Variable(_) => { + ext.impl_trait_mode = ImplTraitLoweringState::Variable( + Cell::new(counter), + ); + } + _ => unreachable!(), + } + break 'ty ext.lower_ty(type_ref); + } + } + self.lower_ty(type_ref) + }; let alias_eq = AliasEq { alias: AliasTy::Projection(projection_ty.clone()), ty }; predicates.push(crate::wrap_empty_binders(WhereClause::AliasEq(alias_eq))); @@ -1403,8 +1459,14 @@ pub(crate) fn generic_predicates_for_param_query( assoc_name: Option, ) -> Arc<[Binders]> { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Variable); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Variable) + .with_type_param_mode(ParamLoweringMode::Variable) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Variable) + }; let generics = generics(db.upcast(), def); // we have to filter out all other predicates *first*, before attempting to lower them @@ -1490,8 +1552,14 @@ pub(crate) fn trait_environment_query( def: GenericDefId, ) -> Arc { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Placeholder); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Param) + .with_type_param_mode(ParamLoweringMode::Placeholder) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Placeholder) + }; let mut traits_in_scope = Vec::new(); let mut clauses = Vec::new(); for pred in resolver.where_predicates_in_scope() { @@ -1549,8 +1617,14 @@ pub(crate) fn generic_predicates_query( def: GenericDefId, ) -> Arc<[Binders]> { let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, def.into()) - .with_type_param_mode(ParamLoweringMode::Variable); + let ctx = if let GenericDefId::FunctionId(_) = def { + TyLoweringContext::new(db, &resolver, def.into()) + .with_impl_trait_mode(ImplTraitLoweringMode::Variable) + .with_type_param_mode(ParamLoweringMode::Variable) + } else { + TyLoweringContext::new(db, &resolver, def.into()) + .with_type_param_mode(ParamLoweringMode::Variable) + }; let generics = generics(db.upcast(), def); let mut predicates = resolver diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 39c5547b8d0e..b80cfe18e4cf 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -1231,6 +1231,53 @@ fn test(x: impl Trait, y: &impl Trait) { ); } +#[test] +fn argument_impl_trait_with_projection() { + check_infer( + r#" +trait X { + type Item; +} + +impl X for [T; 2] { + type Item = T; +} + +trait Y {} + +impl Y for T {} + +enum R { + A(T), + B(U), +} + +fn foo(x: impl X>) -> T { loop {} } + +fn bar() { + let a = foo([R::A(()), R::B(7)]); +} +"#, + expect![[r#" + 153..154 'x': impl X> + ?Sized + 190..201 '{ loop {} }': T + 192..199 'loop {}': ! + 197..199 '{}': () + 212..253 '{ ...)]); }': () + 222..223 'a': i32 + 226..229 'foo': fn foo([R<(), i32>; 2]) -> i32 + 226..250 'foo([R...B(7)])': i32 + 230..249 '[R::A(...:B(7)]': [R<(), i32>; 2] + 231..235 'R::A': extern "rust-call" A<(), i32>(()) -> R<(), i32> + 231..239 'R::A(())': R<(), i32> + 236..238 '()': () + 241..245 'R::B': extern "rust-call" B<(), i32>(i32) -> R<(), i32> + 241..248 'R::B(7)': R<(), i32> + 246..247 '7': i32 + "#]], + ); +} + #[test] fn simple_return_pos_impl_trait() { cov_mark::check!(lower_rpit);