Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/hir_ty/src/chalk_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! about the code that Chalk needs.
use std::sync::Arc;

use cov_mark::hit;
use log::debug;

use chalk_ir::{cast::Cast, fold::shift::Shift, CanonicalVarKinds};
Expand Down Expand Up @@ -106,7 +107,9 @@ impl<'a> chalk_solve::RustIrDatabase<Interner> for ChalkContext<'a> {
};

fn local_impls(db: &dyn HirDatabase, module: ModuleId) -> Option<Arc<TraitImpls>> {
db.trait_impls_in_block(module.containing_block()?)
let block = module.containing_block()?;
hit!(block_local_impls);
db.trait_impls_in_block(block)
}

// Note: Since we're using impls_for_trait, only impls where the trait
Expand Down
10 changes: 7 additions & 3 deletions crates/hir_ty/src/test_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,20 @@ impl FileLoader for TestDB {
}

impl TestDB {
pub(crate) fn module_for_file(&self, file_id: FileId) -> ModuleId {
pub(crate) fn module_for_file_opt(&self, file_id: FileId) -> Option<ModuleId> {
for &krate in self.relevant_crates(file_id).iter() {
let crate_def_map = self.crate_def_map(krate);
for (local_id, data) in crate_def_map.modules() {
if data.origin.file_id() == Some(file_id) {
return crate_def_map.module_id(local_id);
return Some(crate_def_map.module_id(local_id));
}
}
}
panic!("Can't find module for file")
None
}

pub(crate) fn module_for_file(&self, file_id: FileId) -> ModuleId {
self.module_for_file_opt(file_id).unwrap()
}

pub(crate) fn extract_annotations(&self) -> FxHashMap<FileId, Vec<(TextRange, String)>> {
Expand Down
235 changes: 142 additions & 93 deletions crates/hir_ty/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,21 @@ mod incremental;

use std::{collections::HashMap, env, sync::Arc};

use base_db::{fixture::WithFixture, FileRange, SourceDatabase, SourceDatabaseExt};
use base_db::{fixture::WithFixture, FileRange, SourceDatabaseExt};
use expect_test::Expect;
use hir_def::{
body::{Body, BodySourceMap, SyntheticSyntax},
child_by_source::ChildBySource,
db::DefDatabase,
expr::{ExprId, PatId},
item_scope::ItemScope,
keys,
nameres::DefMap,
src::HasSource,
AssocItemId, DefWithBodyId, LocalModuleId, Lookup, ModuleDefId,
AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId,
};
use hir_expand::{db::AstDatabase, InFile};
use once_cell::race::OnceBool;
use stdx::format_to;
use syntax::{
algo,
ast::{self, AstNode, NameOwner},
SyntaxNode,
};
Expand Down Expand Up @@ -59,51 +57,55 @@ fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
}

fn check_types(ra_fixture: &str) {
check_types_impl(ra_fixture, false)
check_impl(ra_fixture, false, true, false)
}

fn check_types_source_code(ra_fixture: &str) {
check_types_impl(ra_fixture, true)
}

fn check_types_impl(ra_fixture: &str, display_source: bool) {
let _tracing = setup_tracing();
let db = TestDB::with_files(ra_fixture);
let mut checked_one = false;
for (file_id, annotations) in db.extract_annotations() {
for (range, expected) in annotations {
let ty = type_at_range(&db, FileRange { file_id, range });
let actual = if display_source {
let module = db.module_for_file(file_id);
ty.display_source_code(&db, module).unwrap()
} else {
ty.display_test(&db).to_string()
};
assert_eq!(expected, actual);
checked_one = true;
}
}

assert!(checked_one, "no `//^` annotations found");
check_impl(ra_fixture, false, true, true)
}

fn check_no_mismatches(ra_fixture: &str) {
check_mismatches_impl(ra_fixture, true)
check_impl(ra_fixture, true, false, false)
}

#[allow(unused)]
fn check_mismatches(ra_fixture: &str) {
check_mismatches_impl(ra_fixture, false)
fn check(ra_fixture: &str) {
check_impl(ra_fixture, false, false, false)
}

fn check_mismatches_impl(ra_fixture: &str, allow_none: bool) {
fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_source: bool) {
let _tracing = setup_tracing();
let (db, file_id) = TestDB::with_single_file(ra_fixture);
let module = db.module_for_file(file_id);
let def_map = module.def_map(&db);
let (db, files) = TestDB::with_many_files(ra_fixture);

let mut had_annotations = false;
let mut mismatches = HashMap::new();
let mut types = HashMap::new();
for (file_id, annotations) in db.extract_annotations() {
for (range, expected) in annotations {
let file_range = FileRange { file_id, range };
if only_types {
types.insert(file_range, expected);
} else if expected.starts_with("type: ") {
types.insert(file_range, expected.trim_start_matches("type: ").to_string());
} else if expected.starts_with("expected") {
mismatches.insert(file_range, expected);
} else {
panic!("unexpected annotation: {}", expected);
}
had_annotations = true;
}
}
assert!(had_annotations || allow_none, "no `//^` annotations found");

let mut defs: Vec<DefWithBodyId> = Vec::new();
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
for file_id in files {
let module = db.module_for_file_opt(file_id);
let module = match module {
Some(m) => m,
None => continue,
};
let def_map = module.def_map(&db);
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
}
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
let loc = it.lookup(&db);
Expand All @@ -118,37 +120,59 @@ fn check_mismatches_impl(ra_fixture: &str, allow_none: bool) {
loc.source(&db).value.syntax().text_range().start()
}
});
let mut mismatches = HashMap::new();
let mut push_mismatch = |src_ptr: InFile<SyntaxNode>, mismatch: TypeMismatch| {
let range = src_ptr.value.text_range();
if src_ptr.file_id.call_node(&db).is_some() {
panic!("type mismatch in macro expansion");
}
let file_range = FileRange { file_id: src_ptr.file_id.original_file(&db), range };
let actual = format!(
"expected {}, got {}",
mismatch.expected.display_test(&db),
mismatch.actual.display_test(&db)
);
mismatches.insert(file_range, actual);
};
let mut unexpected_type_mismatches = String::new();
for def in defs {
let (_body, body_source_map) = db.body_with_source_map(def);
let inference_result = db.infer(def);

for (pat, ty) in inference_result.type_of_pat.iter() {
let node = match pat_node(&body_source_map, pat, &db) {
Some(value) => value,
None => continue,
};
let range = node.as_ref().original_file_range(&db);
if let Some(expected) = types.remove(&range) {
let actual = if display_source {
ty.display_source_code(&db, def.module(&db)).unwrap()
} else {
ty.display_test(&db).to_string()
};
assert_eq!(actual, expected);
}
}

for (expr, ty) in inference_result.type_of_expr.iter() {
let node = match expr_node(&body_source_map, expr, &db) {
Some(value) => value,
None => continue,
};
let range = node.as_ref().original_file_range(&db);
if let Some(expected) = types.remove(&range) {
let actual = if display_source {
ty.display_source_code(&db, def.module(&db)).unwrap()
} else {
ty.display_test(&db).to_string()
};
assert_eq!(actual, expected);
}
}

for (pat, mismatch) in inference_result.pat_type_mismatches() {
let syntax_ptr = match body_source_map.pat_syntax(pat) {
Ok(sp) => {
let root = db.parse_or_expand(sp.file_id).unwrap();
sp.map(|ptr| {
ptr.either(
|it| it.to_node(&root).syntax().clone(),
|it| it.to_node(&root).syntax().clone(),
)
})
}
Err(SyntheticSyntax) => continue,
let node = match pat_node(&body_source_map, pat, &db) {
Some(value) => value,
None => continue,
};
push_mismatch(syntax_ptr, mismatch.clone());
let range = node.as_ref().original_file_range(&db);
let actual = format!(
"expected {}, got {}",
mismatch.expected.display_test(&db),
mismatch.actual.display_test(&db)
);
if let Some(annotation) = mismatches.remove(&range) {
assert_eq!(actual, annotation);
} else {
format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual);
}
}
for (expr, mismatch) in inference_result.expr_type_mismatches() {
let node = match body_source_map.expr_syntax(expr) {
Expand All @@ -158,45 +182,70 @@ fn check_mismatches_impl(ra_fixture: &str, allow_none: bool) {
}
Err(SyntheticSyntax) => continue,
};
push_mismatch(node, mismatch.clone());
}
}
let mut checked_one = false;
for (file_id, annotations) in db.extract_annotations() {
for (range, expected) in annotations {
let file_range = FileRange { file_id, range };
if let Some(mismatch) = mismatches.remove(&file_range) {
assert_eq!(mismatch, expected);
let range = node.as_ref().original_file_range(&db);
let actual = format!(
"expected {}, got {}",
mismatch.expected.display_test(&db),
mismatch.actual.display_test(&db)
);
if let Some(annotation) = mismatches.remove(&range) {
assert_eq!(actual, annotation);
} else {
assert!(false, "Expected mismatch not encountered: {}\n", expected);
format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual);
}
checked_one = true;
}
}

let mut buf = String::new();
for (range, mismatch) in mismatches {
format_to!(buf, "{:?}: {}\n", range.range, mismatch,);
if !unexpected_type_mismatches.is_empty() {
format_to!(buf, "Unexpected type mismatches:\n{}", unexpected_type_mismatches);
}
if !mismatches.is_empty() {
format_to!(buf, "Unchecked mismatch annotations:\n");
for m in mismatches {
format_to!(buf, "{:?}: {}\n", m.0.range, m.1);
}
}
assert!(buf.is_empty(), "Unexpected type mismatches:\n{}", buf);
if !types.is_empty() {
format_to!(buf, "Unchecked type annotations:\n");
for t in types {
format_to!(buf, "{:?}: type {}\n", t.0.range, t.1);
}
}
assert!(buf.is_empty(), "{}", buf);
}

assert!(checked_one || allow_none, "no `//^` annotations found");
fn expr_node(
body_source_map: &BodySourceMap,
expr: ExprId,
db: &TestDB,
) -> Option<InFile<SyntaxNode>> {
Some(match body_source_map.expr_syntax(expr) {
Ok(sp) => {
let root = db.parse_or_expand(sp.file_id).unwrap();
sp.map(|ptr| ptr.to_node(&root).syntax().clone())
}
Err(SyntheticSyntax) => return None,
})
}

fn type_at_range(db: &TestDB, pos: FileRange) -> Ty {
let file = db.parse(pos.file_id).ok().unwrap();
let expr = algo::find_node_at_range::<ast::Expr>(file.syntax(), pos.range).unwrap();
let fn_def = expr.syntax().ancestors().find_map(ast::Fn::cast).unwrap();
let module = db.module_for_file(pos.file_id);
let func = *module.child_by_source(db)[keys::FUNCTION]
.get(&InFile::new(pos.file_id.into(), fn_def))
.unwrap();

let (_body, source_map) = db.body_with_source_map(func.into());
if let Some(expr_id) = source_map.node_expr(InFile::new(pos.file_id.into(), &expr)) {
let infer = db.infer(func.into());
return infer[expr_id].clone();
}
panic!("Can't find expression")
fn pat_node(
body_source_map: &BodySourceMap,
pat: PatId,
db: &TestDB,
) -> Option<InFile<SyntaxNode>> {
Some(match body_source_map.pat_syntax(pat) {
Ok(sp) => {
let root = db.parse_or_expand(sp.file_id).unwrap();
sp.map(|ptr| {
ptr.either(
|it| it.to_node(&root).syntax().clone(),
|it| it.to_node(&root).syntax().clone(),
)
})
}
Err(SyntheticSyntax) => return None,
})
}

fn infer(ra_fixture: &str) -> String {
Expand Down
Loading