Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fix pattern type mismatches for bindings, enable pattern type mismatch diagnostics again #14732

Merged
merged 1 commit into from May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/hir-ty/src/diagnostics/match_check.rs
Expand Up @@ -148,6 +148,7 @@ impl<'a> PatCtxt<'a> {

hir_def::hir::Pat::Bind { id, subpat, .. } => {
let bm = self.infer.pat_binding_modes[&pat];
ty = &self.infer[id];
let name = &self.body.bindings[id].name;
match (bm, ty.kind(Interner)) {
(BindingMode::Ref(_), TyKind::Ref(.., rty)) => ty = rty,
Expand Down
25 changes: 13 additions & 12 deletions crates/hir-ty/src/infer/pat.rs
Expand Up @@ -263,7 +263,7 @@ impl<'a> InferenceContext<'a> {
// Don't emit type mismatches again, the expression lowering already did that.
let ty = self.infer_lit_pat(expr, &expected);
self.write_pat_ty(pat, ty.clone());
return ty;
return self.pat_ty_after_adjustment(pat);
}
Pat::Box { inner } => match self.resolve_boxed_box() {
Some(box_adt) => {
Expand Down Expand Up @@ -298,8 +298,17 @@ impl<'a> InferenceContext<'a> {
.type_mismatches
.insert(pat.into(), TypeMismatch { expected, actual: ty.clone() });
}
self.write_pat_ty(pat, ty.clone());
ty
self.write_pat_ty(pat, ty);
self.pat_ty_after_adjustment(pat)
}

fn pat_ty_after_adjustment(&self, pat: PatId) -> Ty {
self.result
.pat_adjustments
.get(&pat)
.and_then(|x| x.first())
.unwrap_or(&self.result.type_of_pat[pat])
.clone()
}

fn infer_ref_pat(
Expand Down Expand Up @@ -345,7 +354,7 @@ impl<'a> InferenceContext<'a> {
}
BindingMode::Move => inner_ty.clone(),
};
self.write_pat_ty(pat, bound_ty.clone());
self.write_pat_ty(pat, inner_ty.clone());
self.write_binding_ty(binding, bound_ty);
return inner_ty;
}
Expand Down Expand Up @@ -422,14 +431,6 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
Pat::Lit(expr) => {
!matches!(body[*expr], Expr::Literal(Literal::String(..) | Literal::ByteString(..)))
}
Pat::Bind { id, subpat: Some(subpat), .. }
if matches!(
body.bindings[*id].mode,
BindingAnnotation::Mutable | BindingAnnotation::Unannotated
) =>
{
is_non_ref_pat(body, *subpat)
}
Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false,
}
}
Expand Down
19 changes: 13 additions & 6 deletions crates/hir-ty/src/tests.rs
Expand Up @@ -17,7 +17,7 @@ use expect_test::Expect;
use hir_def::{
body::{Body, BodySourceMap, SyntheticSyntax},
db::{DefDatabase, InternDatabase},
hir::{ExprId, PatId},
hir::{ExprId, Pat, PatId},
item_scope::ItemScope,
nameres::DefMap,
src::HasSource,
Expand Down Expand Up @@ -149,10 +149,13 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
});
let mut unexpected_type_mismatches = String::new();
for def in defs {
let (_body, body_source_map) = db.body_with_source_map(def);
let (body, body_source_map) = db.body_with_source_map(def);
let inference_result = db.infer(def);

for (pat, ty) in inference_result.type_of_pat.iter() {
for (pat, mut ty) in inference_result.type_of_pat.iter() {
if let Pat::Bind { id, .. } = body.pats[pat] {
ty = &inference_result.type_of_binding[id];
}
let node = match pat_node(&body_source_map, pat, &db) {
Some(value) => value,
None => continue,
Expand Down Expand Up @@ -284,11 +287,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
let mut buf = String::new();

let mut infer_def = |inference_result: Arc<InferenceResult>,
body: Arc<Body>,
body_source_map: Arc<BodySourceMap>| {
let mut types: Vec<(InFile<SyntaxNode>, &Ty)> = Vec::new();
let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new();

for (pat, ty) in inference_result.type_of_pat.iter() {
for (pat, mut ty) in inference_result.type_of_pat.iter() {
if let Pat::Bind { id, .. } = body.pats[pat] {
ty = &inference_result.type_of_binding[id];
}
let syntax_ptr = match body_source_map.pat_syntax(pat) {
Ok(sp) => {
let root = db.parse_or_expand(sp.file_id);
Expand Down Expand Up @@ -386,9 +393,9 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
}
});
for def in defs {
let (_body, source_map) = db.body_with_source_map(def);
let (body, source_map) = db.body_with_source_map(def);
let infer = db.infer(def);
infer_def(infer, source_map);
infer_def(infer, body, source_map);
}

buf.truncate(buf.trim_end().len());
Expand Down
50 changes: 50 additions & 0 deletions crates/hir-ty/src/tests/simple.rs
Expand Up @@ -2033,6 +2033,56 @@ fn test() {
);
}

#[test]
fn tuple_pattern_nested_match_ergonomics() {
check_no_mismatches(
r#"
fn f(x: (&i32, &i32)) -> i32 {
match x {
(3, 4) => 5,
_ => 12,
}
}
"#,
);
check_types(
r#"
fn f(x: (&&&&i32, &&&i32)) {
let f = match x {
t @ (3, 4) => t,
_ => loop {},
};
f;
//^ (&&&&i32, &&&i32)
}
"#,
);
check_types(
r#"
fn f() {
let x = &&&(&&&2, &&&&&3);
let (y, z) = x;
//^ &&&&i32
let t @ (y, z) = x;
t;
//^ &&&(&&&i32, &&&&&i32)
}
"#,
);
check_types(
r#"
fn f() {
let x = &&&(&&&2, &&&&&3);
let (y, z) = x;
//^ &&&&i32
let t @ (y, z) = x;
t;
//^ &&&(&&&i32, &&&&&i32)
}
"#,
);
}

#[test]
fn fn_pointer_return() {
check_infer(
Expand Down
3 changes: 0 additions & 3 deletions crates/hir/src/lib.rs
Expand Up @@ -1535,9 +1535,6 @@ impl DefWithBody {
for (pat_or_expr, mismatch) in infer.type_mismatches() {
let expr_or_pat = match pat_or_expr {
ExprOrPatId::ExprId(expr) => source_map.expr_syntax(expr).map(Either::Left),
// FIXME: Re-enable these once we have less false positives
ExprOrPatId::PatId(_pat) => continue,
#[allow(unreachable_patterns)]
ExprOrPatId::PatId(pat) => source_map.pat_syntax(pat).map(Either::Right),
};
let expr_or_pat = match expr_or_pat {
Expand Down
11 changes: 11 additions & 0 deletions crates/hir/src/semantics.rs
Expand Up @@ -350,6 +350,13 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
self.imp.type_of_pat(pat)
}

/// It also includes the changes that binding mode makes in the type. For example in
/// `let ref x @ Some(_) = None` the result of `type_of_pat` is `Option<T>` but the result
/// of this function is `&mut Option<T>`
pub fn type_of_binding_in_pat(&self, pat: &ast::IdentPat) -> Option<Type> {
Copy link
Member

@Veykril Veykril May 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting thing to keep in mind that effectively pattern and binding are not always equal in a sense

self.imp.type_of_binding_in_pat(pat)
}

pub fn type_of_self(&self, param: &ast::SelfParam) -> Option<Type> {
self.imp.type_of_self(param)
}
Expand Down Expand Up @@ -1138,6 +1145,10 @@ impl<'db> SemanticsImpl<'db> {
.map(|(ty, coerced)| TypeInfo { original: ty, adjusted: coerced })
}

fn type_of_binding_in_pat(&self, pat: &ast::IdentPat) -> Option<Type> {
self.analyze(pat.syntax())?.type_of_binding_in_pat(self.db, pat)
}

fn type_of_self(&self, param: &ast::SelfParam) -> Option<Type> {
self.analyze(param.syntax())?.type_of_self(self.db, param)
}
Expand Down
23 changes: 22 additions & 1 deletion crates/hir/src/source_analyzer.rs
Expand Up @@ -13,7 +13,7 @@ use hir_def::{
scope::{ExprScopes, ScopeId},
Body, BodySourceMap,
},
hir::{ExprId, Pat, PatId},
hir::{BindingId, ExprId, Pat, PatId},
lang_item::LangItem,
lower::LowerCtx,
macro_id_to_def_id,
Expand Down Expand Up @@ -133,6 +133,15 @@ impl SourceAnalyzer {
self.body_source_map()?.node_pat(src)
}

fn binding_id_of_pat(&self, pat: &ast::IdentPat) -> Option<BindingId> {
let pat_id = self.pat_id(&pat.clone().into())?;
if let Pat::Bind { id, .. } = self.body()?.pats[pat_id] {
Some(id)
} else {
None
}
}

fn expand_expr(
&self,
db: &dyn HirDatabase,
Expand Down Expand Up @@ -198,6 +207,18 @@ impl SourceAnalyzer {
Some((mk_ty(ty), coerced.map(mk_ty)))
}

pub(crate) fn type_of_binding_in_pat(
&self,
db: &dyn HirDatabase,
pat: &ast::IdentPat,
) -> Option<Type> {
let binding_id = self.binding_id_of_pat(pat)?;
let infer = self.infer.as_ref()?;
let ty = infer[binding_id].clone();
let mk_ty = |ty| Type::new_with_resolver(db, &self.resolver, ty);
Some(mk_ty(ty))
}

pub(crate) fn type_of_self(
&self,
db: &dyn HirDatabase,
Expand Down
Expand Up @@ -91,7 +91,7 @@ fn collect_data(ident_pat: IdentPat, ctx: &AssistContext<'_>) -> Option<TupleDat
return None;
}

let ty = ctx.sema.type_of_pat(&ident_pat.clone().into())?.adjusted();
let ty = ctx.sema.type_of_binding_in_pat(&ident_pat)?;
let ref_type = if ty.is_mutable_reference() {
Some(RefType::Mutable)
} else if ty.is_reference() {
Expand Down
28 changes: 14 additions & 14 deletions crates/ide-assists/src/handlers/merge_match_arms.rs
@@ -1,4 +1,4 @@
use hir::TypeInfo;
use hir::Type;
use std::{collections::HashMap, iter::successors};
use syntax::{
algo::neighbor,
Expand Down Expand Up @@ -95,15 +95,15 @@ fn contains_placeholder(a: &ast::MatchArm) -> bool {
}

fn are_same_types(
current_arm_types: &HashMap<String, Option<TypeInfo>>,
current_arm_types: &HashMap<String, Option<Type>>,
arm: &ast::MatchArm,
ctx: &AssistContext<'_>,
) -> bool {
let arm_types = get_arm_types(ctx, arm);
for (other_arm_type_name, other_arm_type) in arm_types {
match (current_arm_types.get(&other_arm_type_name), other_arm_type) {
(Some(Some(current_arm_type)), Some(other_arm_type))
if other_arm_type.original == current_arm_type.original => {}
if other_arm_type == *current_arm_type => {}
_ => return false,
}
}
Expand All @@ -114,44 +114,44 @@ fn are_same_types(
fn get_arm_types(
context: &AssistContext<'_>,
arm: &ast::MatchArm,
) -> HashMap<String, Option<TypeInfo>> {
let mut mapping: HashMap<String, Option<TypeInfo>> = HashMap::new();
) -> HashMap<String, Option<Type>> {
let mut mapping: HashMap<String, Option<Type>> = HashMap::new();

fn recurse(
map: &mut HashMap<String, Option<TypeInfo>>,
map: &mut HashMap<String, Option<Type>>,
ctx: &AssistContext<'_>,
pat: &Option<ast::Pat>,
) {
if let Some(local_pat) = pat {
match pat {
Some(ast::Pat::TupleStructPat(tuple)) => {
match local_pat {
ast::Pat::TupleStructPat(tuple) => {
for field in tuple.fields() {
recurse(map, ctx, &Some(field));
}
}
Some(ast::Pat::TuplePat(tuple)) => {
ast::Pat::TuplePat(tuple) => {
for field in tuple.fields() {
recurse(map, ctx, &Some(field));
}
}
Some(ast::Pat::RecordPat(record)) => {
ast::Pat::RecordPat(record) => {
if let Some(field_list) = record.record_pat_field_list() {
for field in field_list.fields() {
recurse(map, ctx, &field.pat());
}
}
}
Some(ast::Pat::ParenPat(parentheses)) => {
ast::Pat::ParenPat(parentheses) => {
recurse(map, ctx, &parentheses.pat());
}
Some(ast::Pat::SlicePat(slice)) => {
ast::Pat::SlicePat(slice) => {
for slice_pat in slice.pats() {
recurse(map, ctx, &Some(slice_pat));
}
}
Some(ast::Pat::IdentPat(ident_pat)) => {
ast::Pat::IdentPat(ident_pat) => {
if let Some(name) = ident_pat.name() {
let pat_type = ctx.sema.type_of_pat(local_pat);
let pat_type = ctx.sema.type_of_binding_in_pat(ident_pat);
map.insert(name.text().to_string(), pat_type);
}
}
Expand Down