Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions crates/hir/src/code_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,12 @@ impl ModuleDef {
_ => return,
};

hir_ty::diagnostics::validate_module_item(db, id, sink)
let module = match self.module(db) {
Some(it) => it,
None => return,
};

hir_ty::diagnostics::validate_module_item(db, module.id.krate, id, sink)
}
}

Expand Down Expand Up @@ -780,8 +785,9 @@ impl Function {
}

pub fn diagnostics(self, db: &dyn HirDatabase, sink: &mut DiagnosticSink) {
let krate = self.module(db).id.krate;
hir_def::diagnostics::validate_body(db.upcast(), self.id.into(), sink);
hir_ty::diagnostics::validate_module_item(db, self.id.into(), sink);
hir_ty::diagnostics::validate_module_item(db, krate, self.id.into(), sink);
hir_ty::diagnostics::validate_body(db, self.id.into(), sink);
}

Expand Down
15 changes: 11 additions & 4 deletions crates/hir_ty/src/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod decl_check;

use std::{any::Any, fmt};

use base_db::CrateId;
use hir_def::{DefWithBodyId, ModuleDefId};
use hir_expand::diagnostics::{Diagnostic, DiagnosticCode, DiagnosticSink};
use hir_expand::{name::Name, HirFileId, InFile};
Expand All @@ -18,12 +19,13 @@ pub use crate::diagnostics::expr::{record_literal_missing_fields, record_pattern

pub fn validate_module_item(
db: &dyn HirDatabase,
krate: CrateId,
owner: ModuleDefId,
sink: &mut DiagnosticSink<'_>,
) {
let _p = profile::span("validate_module_item");
let mut validator = decl_check::DeclValidator::new(owner, sink);
validator.validate_item(db);
let mut validator = decl_check::DeclValidator::new(db, krate, sink);
validator.validate_item(owner);
}

pub fn validate_body(db: &dyn HirDatabase, owner: DefWithBodyId, sink: &mut DiagnosticSink<'_>) {
Expand Down Expand Up @@ -407,7 +409,7 @@ mod tests {
for (module_id, _) in crate_def_map.modules.iter() {
for decl in crate_def_map[module_id].scope.declarations() {
let mut sink = DiagnosticSinkBuilder::new().build(&mut cb);
validate_module_item(self, decl, &mut sink);
validate_module_item(self, krate, decl, &mut sink);

if let ModuleDefId::FunctionId(f) = decl {
fns.push(f)
Expand All @@ -419,7 +421,12 @@ mod tests {
for item in impl_data.items.iter() {
if let AssocItemId::FunctionId(f) = item {
let mut sink = DiagnosticSinkBuilder::new().build(&mut cb);
validate_module_item(self, ModuleDefId::FunctionId(*f), &mut sink);
validate_module_item(
self,
krate,
ModuleDefId::FunctionId(*f),
&mut sink,
);
fns.push(*f)
}
}
Expand Down
103 changes: 49 additions & 54 deletions crates/hir_ty/src/diagnostics/decl_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

mod case_conv;

use base_db::CrateId;
use hir_def::{
adt::VariantData,
expr::{Pat, PatId},
Expand Down Expand Up @@ -40,7 +41,8 @@ mod allow {
}

pub(super) struct DeclValidator<'a, 'b: 'a> {
owner: ModuleDefId,
db: &'a dyn HirDatabase,
krate: CrateId,
sink: &'a mut DiagnosticSink<'b>,
}

Expand All @@ -53,26 +55,27 @@ struct Replacement {

impl<'a, 'b> DeclValidator<'a, 'b> {
pub(super) fn new(
owner: ModuleDefId,
db: &'a dyn HirDatabase,
krate: CrateId,
sink: &'a mut DiagnosticSink<'b>,
) -> DeclValidator<'a, 'b> {
DeclValidator { owner, sink }
DeclValidator { db, krate, sink }
}

pub(super) fn validate_item(&mut self, db: &dyn HirDatabase) {
match self.owner {
ModuleDefId::FunctionId(func) => self.validate_func(db, func),
ModuleDefId::AdtId(adt) => self.validate_adt(db, adt),
ModuleDefId::ConstId(const_id) => self.validate_const(db, const_id),
ModuleDefId::StaticId(static_id) => self.validate_static(db, static_id),
pub(super) fn validate_item(&mut self, item: ModuleDefId) {
match item {
ModuleDefId::FunctionId(func) => self.validate_func(func),
ModuleDefId::AdtId(adt) => self.validate_adt(adt),
ModuleDefId::ConstId(const_id) => self.validate_const(const_id),
ModuleDefId::StaticId(static_id) => self.validate_static(static_id),
_ => return,
}
}

fn validate_adt(&mut self, db: &dyn HirDatabase, adt: AdtId) {
fn validate_adt(&mut self, adt: AdtId) {
match adt {
AdtId::StructId(struct_id) => self.validate_struct(db, struct_id),
AdtId::EnumId(enum_id) => self.validate_enum(db, enum_id),
AdtId::StructId(struct_id) => self.validate_struct(struct_id),
AdtId::EnumId(enum_id) => self.validate_enum(enum_id),
AdtId::UnionId(_) => {
// Unions aren't yet supported by this validator.
}
Expand All @@ -82,27 +85,27 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
/// Checks whether not following the convention is allowed for this item.
///
/// Currently this method doesn't check parent attributes.
fn allowed(&self, db: &dyn HirDatabase, id: AttrDefId, allow_name: &str) -> bool {
db.attrs(id).by_key("allow").tt_values().any(|tt| tt.to_string().contains(allow_name))
fn allowed(&self, id: AttrDefId, allow_name: &str) -> bool {
self.db.attrs(id).by_key("allow").tt_values().any(|tt| tt.to_string().contains(allow_name))
}

fn validate_func(&mut self, db: &dyn HirDatabase, func: FunctionId) {
let data = db.function_data(func);
fn validate_func(&mut self, func: FunctionId) {
let data = self.db.function_data(func);
if data.is_extern {
mark::hit!(extern_func_incorrect_case_ignored);
return;
}

let body = db.body(func.into());
let body = self.db.body(func.into());

// Recursively validate inner scope items, such as static variables and constants.
for (item_id, _) in body.item_scope.values() {
let mut validator = DeclValidator::new(item_id, self.sink);
validator.validate_item(db);
let mut validator = DeclValidator::new(self.db, self.krate, self.sink);
validator.validate_item(item_id);
}

// Check whether non-snake case identifiers are allowed for this function.
if self.allowed(db, func.into(), allow::NON_SNAKE_CASE) {
if self.allowed(func.into(), allow::NON_SNAKE_CASE) {
return;
}

Expand Down Expand Up @@ -169,19 +172,17 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
// If there is at least one element to spawn a warning on, go to the source map and generate a warning.
self.create_incorrect_case_diagnostic_for_func(
func,
db,
fn_name_replacement,
fn_param_replacements,
);
self.create_incorrect_case_diagnostic_for_variables(func, db, pats_replacements);
self.create_incorrect_case_diagnostic_for_variables(func, pats_replacements);
}

/// Given the information about incorrect names in the function declaration, looks up into the source code
/// for exact locations and adds diagnostics into the sink.
fn create_incorrect_case_diagnostic_for_func(
&mut self,
func: FunctionId,
db: &dyn HirDatabase,
fn_name_replacement: Option<Replacement>,
fn_param_replacements: Vec<Replacement>,
) {
Expand All @@ -190,8 +191,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
return;
}

let fn_loc = func.lookup(db.upcast());
let fn_src = fn_loc.source(db.upcast());
let fn_loc = func.lookup(self.db.upcast());
let fn_src = fn_loc.source(self.db.upcast());

// Diagnostic for function name.
if let Some(replacement) = fn_name_replacement {
Expand Down Expand Up @@ -282,20 +283,19 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
fn create_incorrect_case_diagnostic_for_variables(
&mut self,
func: FunctionId,
db: &dyn HirDatabase,
pats_replacements: Vec<(PatId, Replacement)>,
) {
// XXX: only look at source_map if we do have missing fields
if pats_replacements.is_empty() {
return;
}

let (_, source_map) = db.body_with_source_map(func.into());
let (_, source_map) = self.db.body_with_source_map(func.into());

for (id, replacement) in pats_replacements {
if let Ok(source_ptr) = source_map.pat_syntax(id) {
if let Some(expr) = source_ptr.value.as_ref().left() {
let root = source_ptr.file_syntax(db.upcast());
let root = source_ptr.file_syntax(self.db.upcast());
if let ast::Pat::IdentPat(ident_pat) = expr.to_node(&root) {
let parent = match ident_pat.syntax().parent() {
Some(parent) => parent,
Expand Down Expand Up @@ -333,12 +333,11 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
}
}

fn validate_struct(&mut self, db: &dyn HirDatabase, struct_id: StructId) {
let data = db.struct_data(struct_id);
fn validate_struct(&mut self, struct_id: StructId) {
let data = self.db.struct_data(struct_id);

let non_camel_case_allowed =
self.allowed(db, struct_id.into(), allow::NON_CAMEL_CASE_TYPES);
let non_snake_case_allowed = self.allowed(db, struct_id.into(), allow::NON_SNAKE_CASE);
let non_camel_case_allowed = self.allowed(struct_id.into(), allow::NON_CAMEL_CASE_TYPES);
let non_snake_case_allowed = self.allowed(struct_id.into(), allow::NON_SNAKE_CASE);

// Check the structure name.
let struct_name = data.name.to_string();
Expand Down Expand Up @@ -379,7 +378,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
// If there is at least one element to spawn a warning on, go to the source map and generate a warning.
self.create_incorrect_case_diagnostic_for_struct(
struct_id,
db,
struct_name_replacement,
struct_fields_replacements,
);
Expand All @@ -390,7 +388,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
fn create_incorrect_case_diagnostic_for_struct(
&mut self,
struct_id: StructId,
db: &dyn HirDatabase,
struct_name_replacement: Option<Replacement>,
struct_fields_replacements: Vec<Replacement>,
) {
Expand All @@ -399,8 +396,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
return;
}

let struct_loc = struct_id.lookup(db.upcast());
let struct_src = struct_loc.source(db.upcast());
let struct_loc = struct_id.lookup(self.db.upcast());
let struct_src = struct_loc.source(self.db.upcast());

if let Some(replacement) = struct_name_replacement {
let ast_ptr = match struct_src.value.name() {
Expand Down Expand Up @@ -473,11 +470,11 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
}
}

fn validate_enum(&mut self, db: &dyn HirDatabase, enum_id: EnumId) {
let data = db.enum_data(enum_id);
fn validate_enum(&mut self, enum_id: EnumId) {
let data = self.db.enum_data(enum_id);

// Check whether non-camel case names are allowed for this enum.
if self.allowed(db, enum_id.into(), allow::NON_CAMEL_CASE_TYPES) {
if self.allowed(enum_id.into(), allow::NON_CAMEL_CASE_TYPES) {
return;
}

Expand Down Expand Up @@ -512,7 +509,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
// If there is at least one element to spawn a warning on, go to the source map and generate a warning.
self.create_incorrect_case_diagnostic_for_enum(
enum_id,
db,
enum_name_replacement,
enum_fields_replacements,
)
Expand All @@ -523,7 +519,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
fn create_incorrect_case_diagnostic_for_enum(
&mut self,
enum_id: EnumId,
db: &dyn HirDatabase,
enum_name_replacement: Option<Replacement>,
enum_variants_replacements: Vec<Replacement>,
) {
Expand All @@ -532,8 +527,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
return;
}

let enum_loc = enum_id.lookup(db.upcast());
let enum_src = enum_loc.source(db.upcast());
let enum_loc = enum_id.lookup(self.db.upcast());
let enum_src = enum_loc.source(self.db.upcast());

if let Some(replacement) = enum_name_replacement {
let ast_ptr = match enum_src.value.name() {
Expand Down Expand Up @@ -608,10 +603,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
}
}

fn validate_const(&mut self, db: &dyn HirDatabase, const_id: ConstId) {
let data = db.const_data(const_id);
fn validate_const(&mut self, const_id: ConstId) {
let data = self.db.const_data(const_id);

if self.allowed(db, const_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
if self.allowed(const_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
return;
}

Expand All @@ -632,8 +627,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
return;
};

let const_loc = const_id.lookup(db.upcast());
let const_src = const_loc.source(db.upcast());
let const_loc = const_id.lookup(self.db.upcast());
let const_src = const_loc.source(self.db.upcast());

let ast_ptr = match const_src.value.name() {
Some(name) => name,
Expand All @@ -652,14 +647,14 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
self.sink.push(diagnostic);
}

fn validate_static(&mut self, db: &dyn HirDatabase, static_id: StaticId) {
let data = db.static_data(static_id);
fn validate_static(&mut self, static_id: StaticId) {
let data = self.db.static_data(static_id);
if data.is_extern {
mark::hit!(extern_static_incorrect_case_ignored);
return;
}

if self.allowed(db, static_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
if self.allowed(static_id.into(), allow::NON_UPPER_CASE_GLOBAL) {
return;
}

Expand All @@ -680,8 +675,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> {
return;
};

let static_loc = static_id.lookup(db.upcast());
let static_src = static_loc.source(db.upcast());
let static_loc = static_id.lookup(self.db.upcast());
let static_src = static_loc.source(self.db.upcast());

let ast_ptr = match static_src.value.name() {
Some(name) => name,
Expand Down