Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest impl Trait return type when incorrectly using a generic return type #89892

Merged
merged 1 commit into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use crate::infer::canonical::Canonical;
use crate::ty::fold::ValidateBoundVars;
use crate::ty::subst::{GenericArg, InternalSubsts, Subst, SubstsRef};
use crate::ty::InferTy::{self, *};
use crate::ty::{self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable};
use crate::ty::{
self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable, TypeVisitor,
};
use crate::ty::{DelaySpanBugEmitted, List, ParamEnv};
use polonius_engine::Atom;
use rustc_data_structures::captures::Captures;
Expand All @@ -24,7 +26,7 @@ use std::borrow::Cow;
use std::cmp::Ordering;
use std::fmt;
use std::marker::PhantomData;
use std::ops::{Deref, Range};
use std::ops::{ControlFlow, Deref, Range};
use ty::util::IntTypeExt;

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, TyEncodable, TyDecodable)]
Expand Down Expand Up @@ -2072,6 +2074,24 @@ impl<'tcx> Ty<'tcx> {
!matches!(self.kind(), Param(_) | Infer(_) | Error(_))
}

/// Checks whether a type recursively contains another type
///
/// Example: `Option<()>` contains `()`
pub fn contains(self, other: Ty<'tcx>) -> bool {
struct ContainsTyVisitor<'tcx>(Ty<'tcx>);

impl<'tcx> TypeVisitor<'tcx> for ContainsTyVisitor<'tcx> {
type BreakTy = ();

fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow<Self::BreakTy> {
if self.0 == t { ControlFlow::BREAK } else { t.super_visit_with(self) }
}
}

let cf = self.visit_with(&mut ContainsTyVisitor(other));
cf.is_break()
}

/// Returns the type and mutability of `*ty`.
///
/// The parameter `explicit` indicates if this is an *explicit* dereference.
Expand Down
116 changes: 115 additions & 1 deletion compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ use rustc_errors::{Applicability, DiagnosticBuilder};
use rustc_hir as hir;
use rustc_hir::def::{CtorOf, DefKind};
use rustc_hir::lang_items::LangItem;
use rustc_hir::{Expr, ExprKind, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind};
use rustc_hir::{
Expr, ExprKind, GenericBound, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind,
WherePredicate,
};
use rustc_infer::infer::{self, TyCtxtInferExt};

use rustc_middle::lint::in_external_macro;
use rustc_middle::ty::{self, Binder, Ty};
use rustc_span::symbol::{kw, sym};
Expand Down Expand Up @@ -559,13 +563,123 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let ty = self.tcx.erase_late_bound_regions(ty);
if self.can_coerce(expected, ty) {
err.span_label(sp, format!("expected `{}` because of return type", expected));
self.try_suggest_return_impl_trait(err, expected, ty, fn_id);
return true;
}
false
}
}
}

/// check whether the return type is a generic type with a trait bound
/// only suggest this if the generic param is not present in the arguments
/// if this is true, hint them towards changing the return type to `impl Trait`
/// ```
/// fn cant_name_it<T: Fn() -> u32>() -> T {
/// || 3
/// }
/// ```
fn try_suggest_return_impl_trait(
&self,
err: &mut DiagnosticBuilder<'_>,
expected: Ty<'tcx>,
found: Ty<'tcx>,
fn_id: hir::HirId,
) {
// Only apply the suggestion if:
// - the return type is a generic parameter
// - the generic param is not used as a fn param
// - the generic param has at least one bound
// - the generic param doesn't appear in any other bounds where it's not the Self type
// Suggest:
// - Changing the return type to be `impl <all bounds>`

debug!("try_suggest_return_impl_trait, expected = {:?}, found = {:?}", expected, found);

let ty::Param(expected_ty_as_param) = expected.kind() else { return };

let fn_node = self.tcx.hir().find(fn_id);

let Some(hir::Node::Item(hir::Item {
kind:
hir::ItemKind::Fn(
hir::FnSig { decl: hir::FnDecl { inputs: fn_parameters, output: fn_return, .. }, .. },
hir::Generics { params, where_clause, .. },
_body_id,
),
..
})) = fn_node else { return };

let Some(expected_generic_param) = params.get(expected_ty_as_param.index as usize) else { return };

// get all where BoundPredicates here, because they are used in to cases below
let where_predicates = where_clause
.predicates
.iter()
.filter_map(|p| match p {
WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
bounds,
bounded_ty,
..
}) => {
// FIXME: Maybe these calls to `ast_ty_to_ty` can be removed (and the ones below)
let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, bounded_ty);
Some((ty, bounds))
}
_ => None,
})
.map(|(ty, bounds)| match ty.kind() {
ty::Param(param_ty) if param_ty == expected_ty_as_param => Ok(Some(bounds)),
// check whether there is any predicate that contains our `T`, like `Option<T>: Send`
_ => match ty.contains(expected) {
true => Err(()),
false => Ok(None),
},
})
.collect::<Result<Vec<_>, _>>();

let Ok(where_predicates) = where_predicates else { return };

// now get all predicates in the same types as the where bounds, so we can chain them
let predicates_from_where =
where_predicates.iter().flatten().map(|bounds| bounds.iter()).flatten();

// extract all bounds from the source code using their spans
let all_matching_bounds_strs = expected_generic_param
.bounds
.iter()
.chain(predicates_from_where)
.filter_map(|bound| match bound {
GenericBound::Trait(_, _) => {
self.tcx.sess.source_map().span_to_snippet(bound.span()).ok()
}
_ => None,
})
.collect::<Vec<String>>();

if all_matching_bounds_strs.len() == 0 {
return;
}

let all_bounds_str = all_matching_bounds_strs.join(" + ");

let ty_param_used_in_fn_params = fn_parameters.iter().any(|param| {
let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, param);
Noratrieb marked this conversation as resolved.
Show resolved Hide resolved
matches!(ty.kind(), ty::Param(fn_param_ty_param) if expected_ty_as_param == fn_param_ty_param)
});

if ty_param_used_in_fn_params {
return;
}

err.span_suggestion(
fn_return.span(),
"consider using an impl return type",
format!("impl {}", all_bounds_str),
Applicability::MaybeIncorrect,
);
}

pub(in super::super) fn suggest_missing_break_or_return_expr(
&self,
err: &mut DiagnosticBuilder<'_>,
Expand Down
31 changes: 31 additions & 0 deletions src/test/ui/return/return-impl-trait-bad.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
trait Trait {}
impl Trait for () {}

fn bad_echo<T>(_t: T) -> T {
"this should not suggest impl Trait" //~ ERROR mismatched types
}

fn bad_echo_2<T: Trait>(_t: T) -> T {
"this will not suggest it, because that would probably be wrong" //~ ERROR mismatched types
}

fn other_bounds_bad<T>() -> T
where
T: Send,
Option<T>: Send,
{
"don't suggest this, because Option<T> places additional constraints" //~ ERROR mismatched types
}

// FIXME: implement this check
trait GenericTrait<T> {}

fn used_in_trait<T>() -> T
where
T: Send,
(): GenericTrait<T>,
{
"don't suggest this, because the generic param is used in the bound." //~ ERROR mismatched types
}

fn main() {}
59 changes: 59 additions & 0 deletions src/test/ui/return/return-impl-trait-bad.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
error[E0308]: mismatched types
--> $DIR/return-impl-trait-bad.rs:5:5
|
LL | fn bad_echo<T>(_t: T) -> T {
| - - expected `T` because of return type
| |
| this type parameter
LL | "this should not suggest impl Trait"
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
|
= note: expected type parameter `T`
found reference `&'static str`

error[E0308]: mismatched types
--> $DIR/return-impl-trait-bad.rs:9:5
|
LL | fn bad_echo_2<T: Trait>(_t: T) -> T {
| - - expected `T` because of return type
| |
| this type parameter
LL | "this will not suggest it, because that would probably be wrong"
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
|
= note: expected type parameter `T`
found reference `&'static str`

error[E0308]: mismatched types
--> $DIR/return-impl-trait-bad.rs:17:5
|
LL | fn other_bounds_bad<T>() -> T
| - - expected `T` because of return type
| |
| this type parameter
...
LL | "don't suggest this, because Option<T> places additional constraints"
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
|
= note: expected type parameter `T`
found reference `&'static str`

error[E0308]: mismatched types
--> $DIR/return-impl-trait-bad.rs:28:5
|
LL | fn used_in_trait<T>() -> T
| - -
| | |
| | expected `T` because of return type
| | help: consider using an impl return type: `impl Send`
| this type parameter
...
LL | "don't suggest this, because the generic param is used in the bound."
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
|
= note: expected type parameter `T`
found reference `&'static str`

error: aborting due to 4 previous errors

For more information about this error, try `rustc --explain E0308`.
30 changes: 30 additions & 0 deletions src/test/ui/return/return-impl-trait.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// run-rustfix

trait Trait {}
impl Trait for () {}

// this works
fn foo() -> impl Trait {
()
}

fn bar<T: Trait + std::marker::Sync>() -> impl Trait + std::marker::Sync + Send
where
T: Send,
{
() //~ ERROR mismatched types
}

fn other_bounds<T>() -> impl Trait
where
T: Trait,
Vec<usize>: Clone,
{
() //~ ERROR mismatched types
}

fn main() {
foo();
bar::<()>();
other_bounds::<()>();
}
30 changes: 30 additions & 0 deletions src/test/ui/return/return-impl-trait.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// run-rustfix

trait Trait {}
impl Trait for () {}

// this works
fn foo() -> impl Trait {
()
}

fn bar<T: Trait + std::marker::Sync>() -> T
Noratrieb marked this conversation as resolved.
Show resolved Hide resolved
where
Noratrieb marked this conversation as resolved.
Show resolved Hide resolved
T: Send,
{
() //~ ERROR mismatched types
}

fn other_bounds<T>() -> T
where
T: Trait,
Vec<usize>: Clone,
{
() //~ ERROR mismatched types
}

fn main() {
foo();
bar::<()>();
other_bounds::<()>();
}
34 changes: 34 additions & 0 deletions src/test/ui/return/return-impl-trait.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
error[E0308]: mismatched types
--> $DIR/return-impl-trait.rs:15:5
|
LL | fn bar<T: Trait + std::marker::Sync>() -> T
| - -
| | |
| | expected `T` because of return type
| this type parameter help: consider using an impl return type: `impl Trait + std::marker::Sync + Send`
...
LL | ()
| ^^ expected type parameter `T`, found `()`
|
= note: expected type parameter `T`
found unit type `()`

error[E0308]: mismatched types
--> $DIR/return-impl-trait.rs:23:5
|
LL | fn other_bounds<T>() -> T
| - -
| | |
| | expected `T` because of return type
| | help: consider using an impl return type: `impl Trait`
| this type parameter
...
LL | ()
| ^^ expected type parameter `T`, found `()`
|
= note: expected type parameter `T`
found unit type `()`

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0308`.