diff --git a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs index 7d5740b748be..36df4af31d5e 100644 --- a/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs +++ b/crates/ide-assists/src/handlers/wrap_unwrap_cfg_attr.rs @@ -2,7 +2,7 @@ use ide_db::source_change::SourceChangeBuilder; use itertools::Itertools; use syntax::{ NodeOrToken, SyntaxToken, T, TextRange, algo, - ast::{self, AstNode, make, syntax_factory::SyntaxFactory}, + ast::{self, AstNode, edit::AstNodeEdit, make, syntax_factory::SyntaxFactory}, }; use crate::{AssistContext, AssistId, Assists}; @@ -27,7 +27,7 @@ use crate::{AssistContext, AssistId, Assists}; enum WrapUnwrapOption { WrapDerive { derive: TextRange, attr: ast::Attr }, - WrapAttr(ast::Attr), + WrapAttr(Vec), } /// Attempts to get the derive attribute from a derive attribute list @@ -102,9 +102,9 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption { if ident.parent().and_then(ast::TokenTree::cast).is_none() || !attr.simple_name().map(|v| v.eq("derive")).unwrap_or_default() { - WrapUnwrapOption::WrapAttr(attr) + WrapUnwrapOption::WrapAttr(vec![attr]) } else { - attempt_attr().unwrap_or(WrapUnwrapOption::WrapAttr(attr)) + attempt_attr().unwrap_or_else(|| WrapUnwrapOption::WrapAttr(vec![attr])) } } pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { @@ -118,13 +118,27 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - Some(attempt_get_derive(attr, ident)) } - (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(attr)), + (Some(attr), _) => Some(WrapUnwrapOption::WrapAttr(vec![attr])), _ => None, } } else { let covering_element = ctx.covering_element(); match covering_element { - NodeOrToken::Node(node) => ast::Attr::cast(node).map(WrapUnwrapOption::WrapAttr), + NodeOrToken::Node(node) => { + if let Some(attr) = ast::Attr::cast(node.clone()) { + Some(WrapUnwrapOption::WrapAttr(vec![attr])) + } else { + let attrs = node + .children() + .filter(|it| it.text_range().intersect(ctx.selection_trimmed()).is_some()) + .map(ast::Attr::cast) + .collect::>>()?; + if attrs.is_empty() { + return None; + } + Some(WrapUnwrapOption::WrapAttr(attrs)) + } + } NodeOrToken::Token(ident) if ident.kind() == syntax::T![ident] => { let attr = ident.parent_ancestors().find_map(ast::Attr::cast)?; Some(attempt_get_derive(attr, ident)) @@ -133,10 +147,12 @@ pub(crate) fn wrap_unwrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>) - } }?; match option { - WrapUnwrapOption::WrapAttr(attr) if attr.simple_name().as_deref() == Some("cfg_attr") => { - unwrap_cfg_attr(acc, attr) - } - WrapUnwrapOption::WrapAttr(attr) => wrap_cfg_attr(acc, ctx, attr), + WrapUnwrapOption::WrapAttr(attrs) => match &attrs[..] { + [attr] if attr.simple_name().as_deref() == Some("cfg_attr") => { + unwrap_cfg_attr(acc, attrs.into_iter().next().unwrap()) + } + _ => wrap_cfg_attrs(acc, ctx, attrs), + }, WrapUnwrapOption::WrapDerive { derive, attr } => wrap_derive(acc, ctx, attr, derive), } } @@ -220,40 +236,51 @@ fn wrap_derive( ); Some(()) } -fn wrap_cfg_attr(acc: &mut Assists, ctx: &AssistContext<'_>, attr: ast::Attr) -> Option<()> { - let range = attr.syntax().text_range(); - let path = attr.path()?; +fn wrap_cfg_attrs(acc: &mut Assists, ctx: &AssistContext<'_>, attrs: Vec) -> Option<()> { + let (first_attr, last_attr) = (attrs.first()?, attrs.last()?); + let range = first_attr.syntax().text_range().cover(last_attr.syntax().text_range()); + let path_attrs = + attrs.iter().map(|attr| Some((attr.path()?, attr.clone()))).collect::>>()?; let handle_source_change = |edit: &mut SourceChangeBuilder| { let make = SyntaxFactory::with_mappings(); - let mut editor = edit.make_editor(attr.syntax()); - let mut raw_tokens = - vec![NodeOrToken::Token(make.token(T![,])), NodeOrToken::Token(make.whitespace(" "))]; - path.syntax().descendants_with_tokens().for_each(|it| { - if let NodeOrToken::Token(token) = it { - raw_tokens.push(NodeOrToken::Token(token)); - } - }); - if let Some(meta) = attr.meta() { - if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) { - raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); - raw_tokens.push(NodeOrToken::Token(eq)); - raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); + let mut editor = edit.make_editor(first_attr.syntax()); + let mut raw_tokens = vec![]; + for (path, attr) in path_attrs { + raw_tokens.extend([ + NodeOrToken::Token(make.token(T![,])), + NodeOrToken::Token(make.whitespace(" ")), + ]); + path.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + if let Some(meta) = attr.meta() { + if let (Some(eq), Some(expr)) = (meta.eq_token(), meta.expr()) { + raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); + raw_tokens.push(NodeOrToken::Token(eq)); + raw_tokens.push(NodeOrToken::Token(make.whitespace(" "))); - expr.syntax().descendants_with_tokens().for_each(|it| { - if let NodeOrToken::Token(token) = it { - raw_tokens.push(NodeOrToken::Token(token)); - } - }); - } else if let Some(tt) = meta.token_tree() { - raw_tokens.extend(tt.token_trees_and_tokens()); + expr.syntax().descendants_with_tokens().for_each(|it| { + if let NodeOrToken::Token(token) = it { + raw_tokens.push(NodeOrToken::Token(token)); + } + }); + } else if let Some(tt) = meta.token_tree() { + raw_tokens.extend(tt.token_trees_and_tokens()); + } } } let meta = make.meta_token_tree(make.ident_path("cfg_attr"), make.token_tree(T!['('], raw_tokens)); - let cfg_attr = - if attr.excl_token().is_some() { make.attr_inner(meta) } else { make.attr_outer(meta) }; + let cfg_attr = if first_attr.excl_token().is_some() { + make.attr_inner(meta) + } else { + make.attr_outer(meta) + }; - editor.replace(attr.syntax(), cfg_attr.syntax()); + let syntax_range = first_attr.syntax().clone().into()..=last_attr.syntax().clone().into(); + editor.replace_all(syntax_range, vec![cfg_attr.syntax().clone().into()]); if let Some(snippet_cap) = ctx.config.snippet_cap && let Some(first_meta) = @@ -332,7 +359,8 @@ fn unwrap_cfg_attr(acc: &mut Assists, attr: ast::Attr) -> Option<()> { return None; } let handle_source_change = |f: &mut SourceChangeBuilder| { - let inner_attrs = inner_attrs.iter().map(|it| it.to_string()).join("\n"); + let inner_attrs = + inner_attrs.iter().map(|it| it.to_string()).join(&format!("\n{}", attr.indent_level())); f.replace(range, inner_attrs); }; acc.add( @@ -414,6 +442,42 @@ mod tests { } "#, ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[other_attr] + $0#[foo] + #[bar]$0 + #[other_attr] + test: u32, + } + "#, + r#" + pub struct Test { + #[other_attr] + #[cfg_attr($0, foo, bar)] + #[other_attr] + test: u32, + } + "#, + ); + check_assist( + wrap_unwrap_cfg_attr, + r#" + pub struct Test { + #[cfg_attr(debug_assertions$0, foo, bar)] + test: u32, + } + "#, + r#" + pub struct Test { + #[foo] + #[bar] + test: u32, + } + "#, + ); } #[test] fn to_from_eq_attr() {