Skip to content

Commit

Permalink
Auto merge of #15810 - rmehri01:assist_panics_in_macros, r=Veykril
Browse files Browse the repository at this point in the history
fix: assists panic when trying to edit usage inside macro

When we try to make a syntax node mutable inside a macro to edit it, it seems like the edits aren't properly reflected and will cause a panic when trying to make another syntax node mutable.

This PR changes `bool_to_enum` and `promote_local_to_const` to use the original syntax range instead to edit the original file instead of the macro file. I'm not sure how to do it for `inline_call` with the example I mentioned in the issue, so I've left it out for now.

Fixes #15807
  • Loading branch information
bors committed Jan 2, 2024
2 parents a8d935e + b105e9b commit e461efb
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 70 deletions.
227 changes: 166 additions & 61 deletions crates/ide-assists/src/handlers/bool_to_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName,
},
ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
};
use text_edit::TextRange;

use crate::assist_context::{AssistContext, Assists};
use crate::{
assist_context::{AssistContext, Assists},
utils,
};

// Assist: bool_to_enum
//
Expand Down Expand Up @@ -73,7 +76,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option

let usages = definition.usages(&ctx.sema).all();
add_enum_def(edit, ctx, &usages, target_node, &target_module);
replace_usages(edit, ctx, &usages, definition, &target_module);
replace_usages(edit, ctx, usages, definition, &target_module);
},
)
}
Expand Down Expand Up @@ -169,8 +172,8 @@ fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {

/// Converts an expression of type `bool` to one of the new enum type.
fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
let true_expr = make::expr_path(make::path_from_text("Bool::True")).clone_for_update();
let false_expr = make::expr_path(make::path_from_text("Bool::False")).clone_for_update();
let true_expr = make::expr_path(make::path_from_text("Bool::True"));
let false_expr = make::expr_path(make::path_from_text("Bool::False"));

if let ast::Expr::Literal(literal) = &expr {
match literal.kind() {
Expand All @@ -184,66 +187,62 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
make::tail_only_block_expr(true_expr),
Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))),
)
.clone_for_update()
}
}

/// Replaces all usages of the target identifier, both when read and written to.
fn replace_usages(
edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>,
usages: &UsageSearchResult,
usages: UsageSearchResult,
target_definition: Definition,
target_module: &hir::Module,
) {
for (file_id, references) in usages.iter() {
edit.edit_file(*file_id);
for (file_id, references) in usages {
edit.edit_file(file_id);

let refs_with_imports =
augment_references_with_imports(edit, ctx, references, target_module);
let refs_with_imports = augment_references_with_imports(ctx, references, target_module);

refs_with_imports.into_iter().rev().for_each(
|FileReferenceWithImport { range, old_name, new_name, import_data }| {
|FileReferenceWithImport { range, name, import_data }| {
// replace the usages in patterns and expressions
if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast)
{
if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
cov_mark::hit!(replaces_record_pat_shorthand);

let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
if let Some(def) = definition {
replace_usages(
edit,
ctx,
&def.usages(&ctx.sema).all(),
def.usages(&ctx.sema).all(),
target_definition,
target_module,
)
}
} else if let Some(initializer) = find_assignment_usage(&new_name) {
} else if let Some(initializer) = find_assignment_usage(&name) {
cov_mark::hit!(replaces_assignment);

replace_bool_expr(edit, initializer);
} else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) {
} else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
cov_mark::hit!(replaces_negation);

edit.replace(
prefix_expr.syntax().text_range(),
format!("{} == Bool::False", inner_expr),
);
} else if let Some((record_field, initializer)) = old_name
} else if let Some((record_field, initializer)) = name
.as_name_ref()
.and_then(ast::RecordExprField::for_field_name)
.and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
.and_then(|(got_field, _, _)| {
find_record_expr_usage(&new_name, got_field, target_definition)
find_record_expr_usage(&name, got_field, target_definition)
})
{
cov_mark::hit!(replaces_record_expr);

let record_field = edit.make_mut(record_field);
let enum_expr = bool_expr_to_enum_expr(initializer);
record_field.replace_expr(enum_expr);
} else if let Some(pat) = find_record_pat_field_usage(&old_name) {
utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
} else if let Some(pat) = find_record_pat_field_usage(&name) {
match pat {
ast::Pat::IdentPat(ident_pat) => {
cov_mark::hit!(replaces_record_pat);
Expand All @@ -253,7 +252,7 @@ fn replace_usages(
replace_usages(
edit,
ctx,
&def.usages(&ctx.sema).all(),
def.usages(&ctx.sema).all(),
target_definition,
target_module,
)
Expand All @@ -270,40 +269,44 @@ fn replace_usages(
}
_ => (),
}
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
{
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
edit.replace(ty_annotation.syntax().text_range(), "Bool");
replace_bool_expr(edit, initializer);
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) {
} else if let Some(receiver) = find_method_call_expr_usage(&name) {
edit.replace(
receiver.syntax().text_range(),
format!("({} == Bool::True)", receiver),
);
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
} else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name
.as_name_ref()
.and_then(ast::RecordExprField::for_field_name)
.and_then(|record_field| {
record_field.expr().map(|expr| (record_field, expr))
})
if let Some((record_field, expr)) =
name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
|record_field| record_field.expr().map(|expr| (record_field, expr)),
)
{
record_field.replace_expr(
utils::replace_record_field_expr(
ctx,
edit,
record_field,
make::expr_bin_op(
expr,
ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
make::expr_path(make::path_from_text("Bool::True")),
)
.clone_for_update(),
),
);
} else {
edit.replace(range, format!("{} == Bool::True", new_name.text()));
edit.replace(range, format!("{} == Bool::True", name.text()));
}
}

// add imports across modules where needed
if let Some((import_scope, path)) = import_data {
insert_use(&import_scope, path, &ctx.config.insert_use);
let scope = match import_scope.clone() {
ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
};
insert_use(&scope, path, &ctx.config.insert_use);
}
},
)
Expand All @@ -312,37 +315,31 @@ fn replace_usages(

struct FileReferenceWithImport {
range: TextRange,
old_name: ast::NameLike,
new_name: ast::NameLike,
name: ast::NameLike,
import_data: Option<(ImportScope, ast::Path)>,
}

fn augment_references_with_imports(
edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>,
references: &[FileReference],
references: Vec<FileReference>,
target_module: &hir::Module,
) -> Vec<FileReferenceWithImport> {
let mut visited_modules = FxHashSet::default();

references
.iter()
.into_iter()
.filter_map(|FileReference { range, name, .. }| {
let name = name.clone().into_name_like()?;
ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module()))
ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
})
.map(|(range, name, ref_module)| {
let old_name = name.clone();
let new_name = edit.make_mut(name.clone());

// if the referenced module is not the same as the target one and has not been seen before, add an import
let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
&& !visited_modules.contains(&ref_module)
{
visited_modules.insert(ref_module);

let import_scope =
ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema);
let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
let path = ref_module
.find_use_path_prefixed(
ctx.sema.db,
Expand All @@ -360,7 +357,7 @@ fn augment_references_with_imports(
None
};

FileReferenceWithImport { range, old_name, new_name, import_data }
FileReferenceWithImport { range, name, import_data }
})
.collect()
}
Expand Down Expand Up @@ -405,13 +402,12 @@ fn find_record_expr_usage(
let record_field = ast::RecordExprField::for_field_name(name_ref)?;
let initializer = record_field.expr()?;

if let Definition::Field(expected_field) = target_definition {
if got_field != expected_field {
return None;
match target_definition {
Definition::Field(expected_field) if got_field == expected_field => {
Some((record_field, initializer))
}
_ => None,
}

Some((record_field, initializer))
}

fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
Expand Down Expand Up @@ -466,12 +462,9 @@ fn add_enum_def(
let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent);

ted::insert_all(
ted::Position::before(&edit.make_syntax_mut(insert_before)),
vec![
enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
],
edit.insert(
insert_before.text_range().start(),
format!("{}\n\n{indent}", enum_def.syntax().text()),
);
}

Expand Down Expand Up @@ -800,6 +793,78 @@ fn main() {
)
}

#[test]
fn local_var_init_struct_usage() {
check_assist(
bool_to_enum,
r#"
struct Foo {
foo: bool,
}
fn main() {
let $0foo = true;
let s = Foo { foo };
}
"#,
r#"
struct Foo {
foo: bool,
}
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() {
let foo = Bool::True;
let s = Foo { foo: foo == Bool::True };
}
"#,
)
}

#[test]
fn local_var_init_struct_usage_in_macro() {
check_assist(
bool_to_enum,
r#"
struct Struct {
boolean: bool,
}
macro_rules! identity {
($body:expr) => {
$body
}
}
fn new() -> Struct {
let $0boolean = true;
identity![Struct { boolean }]
}
"#,
r#"
struct Struct {
boolean: bool,
}
macro_rules! identity {
($body:expr) => {
$body
}
}
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn new() -> Struct {
let boolean = Bool::True;
identity![Struct { boolean: boolean == Bool::True }]
}
"#,
)
}

#[test]
fn field_struct_basic() {
cov_mark::check!(replaces_record_expr);
Expand Down Expand Up @@ -1321,6 +1386,46 @@ fn main() {
)
}

#[test]
fn field_in_macro() {
check_assist(
bool_to_enum,
r#"
struct Struct {
$0boolean: bool,
}
fn boolean(x: Struct) {
let Struct { boolean } = x;
}
macro_rules! identity { ($body:expr) => { $body } }
fn new() -> Struct {
identity!(Struct { boolean: true })
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Struct {
boolean: Bool,
}
fn boolean(x: Struct) {
let Struct { boolean } = x;
}
macro_rules! identity { ($body:expr) => { $body } }
fn new() -> Struct {
identity!(Struct { boolean: Bool::True })
}
"#,
)
}

#[test]
fn field_non_bool() {
cov_mark::check!(not_applicable_non_bool_field);
Expand Down

0 comments on commit e461efb

Please sign in to comment.