Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 110 additions & 38 deletions crates/assists/src/handlers/extract_struct_from_enum_variant.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::iter;

use either::Either;
use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
use ide_db::{defs::Definition, search::Reference, RootDatabase};
use rustc_hash::{FxHashMap, FxHashSet};
Expand Down Expand Up @@ -31,40 +34,32 @@ pub(crate) fn extract_struct_from_enum_variant(
ctx: &AssistContext,
) -> Option<()> {
let variant = ctx.find_node_at_offset::<ast::Variant>()?;
let field_list = match variant.kind() {
ast::StructKind::Tuple(field_list) => field_list,
_ => return None,
};

// skip 1-tuple variants
if field_list.fields().count() == 1 {
return None;
}
let field_list = extract_field_list_if_applicable(&variant)?;

let variant_name = variant.name()?;
let variant_hir = ctx.sema.to_def(&variant)?;
if existing_struct_def(ctx.db(), &variant_name, &variant_hir) {
if existing_definition(ctx.db(), &variant_name, &variant_hir) {
return None;
}

let enum_ast = variant.parent_enum();
let visibility = enum_ast.visibility();
let enum_hir = ctx.sema.to_def(&enum_ast)?;
let variant_hir_name = variant_hir.name(ctx.db());
let enum_module_def = ModuleDef::from(enum_hir);
let current_module = enum_hir.module(ctx.db());
let target = variant.syntax().text_range();
acc.add(
AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
"Extract struct from enum variant",
target,
|builder| {
let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir));
let res = definition.usages(&ctx.sema).all();
let variant_hir_name = variant_hir.name(ctx.db());
let enum_module_def = ModuleDef::from(enum_hir);
let usages =
Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)).usages(&ctx.sema).all();

let mut visited_modules_set = FxHashSet::default();
let current_module = enum_hir.module(ctx.db());
visited_modules_set.insert(current_module);
let mut rewriters = FxHashMap::default();
for reference in res {
for reference in usages {
let rewriter = rewriters
.entry(reference.file_range.file_id)
.or_insert_with(SyntaxRewriter::default);
Expand All @@ -86,26 +81,49 @@ pub(crate) fn extract_struct_from_enum_variant(
builder.rewrite(rewriter);
}
builder.edit_file(ctx.frange.file_id);
update_variant(&mut rewriter, &variant_name, &field_list);
update_variant(&mut rewriter, &variant);
extract_struct_def(
&mut rewriter,
&enum_ast,
variant_name.clone(),
&field_list,
&variant.parent_enum().syntax().clone().into(),
visibility,
enum_ast.visibility(),
);
builder.rewrite(rewriter);
},
)
}

fn existing_struct_def(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
fn extract_field_list_if_applicable(
variant: &ast::Variant,
) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
match variant.kind() {
ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
Some(Either::Left(field_list))
}
ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
Some(Either::Right(field_list))
}
_ => None,
}
}

fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
variant
.parent_enum(db)
.module(db)
.scope(db, None)
.into_iter()
.filter(|(_, def)| match def {
// only check type-namespace
hir::ScopeDef::ModuleDef(def) => matches!(def,
ModuleDef::Module(_) | ModuleDef::Adt(_) |
ModuleDef::EnumVariant(_) | ModuleDef::Trait(_) |
ModuleDef::TypeAlias(_) | ModuleDef::BuiltinType(_)
),
_ => false,
})
.any(|(name, _)| name == variant_name.as_name())
}

Expand Down Expand Up @@ -133,19 +151,29 @@ fn extract_struct_def(
rewriter: &mut SyntaxRewriter,
enum_: &ast::Enum,
variant_name: ast::Name,
variant_list: &ast::TupleFieldList,
field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
start_offset: &SyntaxElement,
visibility: Option<ast::Visibility>,
) -> Option<()> {
let variant_list = make::tuple_field_list(
variant_list
.fields()
.flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))),
);
let pub_vis = Some(make::visibility_pub());
let field_list = match field_list {
Either::Left(field_list) => {
make::record_field_list(field_list.fields().flat_map(|field| {
Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
}))
.into()
}
Either::Right(field_list) => make::tuple_field_list(
field_list
.fields()
.flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
)
.into(),
};

rewriter.insert_before(
start_offset,
make::struct_(visibility, variant_name, None, variant_list.into()).syntax(),
make::struct_(visibility, variant_name, None, field_list).syntax(),
);
rewriter.insert_before(start_offset, &make::tokens::blank_line());

Expand All @@ -156,15 +184,14 @@ fn extract_struct_def(
Some(())
}

fn update_variant(
rewriter: &mut SyntaxRewriter,
variant_name: &ast::Name,
field_list: &ast::TupleFieldList,
) -> Option<()> {
let (l, r): (SyntaxElement, SyntaxElement) =
(field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into());
let replacement = vec![l, variant_name.syntax().clone().into(), r];
rewriter.replace_with_many(field_list.syntax(), replacement);
fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
let name = variant.name()?;
let tuple_field = make::tuple_field(None, make::ty(name.text()));
let replacement = make::variant(
name,
Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
);
rewriter.replace(variant.syntax(), replacement.syntax());
Some(())
}

Expand Down Expand Up @@ -211,12 +238,47 @@ mod tests {
use super::*;

#[test]
fn test_extract_struct_several_fields() {
fn test_extract_struct_several_fields_tuple() {
check_assist(
extract_struct_from_enum_variant,
"enum A { <|>One(u32, u32) }",
r#"struct One(pub u32, pub u32);

enum A { One(One) }"#,
);
}

#[test]
fn test_extract_struct_several_fields_named() {
check_assist(
extract_struct_from_enum_variant,
"enum A { <|>One { foo: u32, bar: u32 } }",
r#"struct One{ pub foo: u32, pub bar: u32 }

enum A { One(One) }"#,
);
}

#[test]
fn test_extract_struct_one_field_named() {
check_assist(
extract_struct_from_enum_variant,
"enum A { <|>One { foo: u32 } }",
r#"struct One{ pub foo: u32 }

enum A { One(One) }"#,
);
}

#[test]
fn test_extract_enum_variant_name_value_namespace() {
check_assist(
extract_struct_from_enum_variant,
r#"const One: () = ();
enum A { <|>One(u32, u32) }"#,
r#"const One: () = ();
struct One(pub u32, pub u32);

enum A { One(One) }"#,
);
}
Expand Down Expand Up @@ -298,12 +360,22 @@ fn another_fn() {
fn test_extract_enum_not_applicable_if_struct_exists() {
check_not_applicable(
r#"struct One;
enum A { <|>One(u8) }"#,
enum A { <|>One(u8, u32) }"#,
);
}

#[test]
fn test_extract_not_applicable_one_field() {
check_not_applicable(r"enum A { <|>One(u32) }");
}

#[test]
fn test_extract_not_applicable_no_field_tuple() {
check_not_applicable(r"enum A { <|>None() }");
}

#[test]
fn test_extract_not_applicable_no_field_named() {
check_not_applicable(r"enum A { <|>None {} }");
}
}
3 changes: 2 additions & 1 deletion crates/ide/src/diagnostics/fixes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ fn missing_record_expr_field_fix(
return None;
}
let new_field = make::record_field(
record_expr_field.field_name()?,
None,
make::name(record_expr_field.field_name()?.text()),
make::ty(&new_field_type.display_source_code(sema.db, module.into()).ok()?),
);

Expand Down
27 changes: 25 additions & 2 deletions crates/syntax/src/ast/make.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,16 @@ pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::Re
}
}

pub fn record_field(name: ast::NameRef, ty: ast::Type) -> ast::RecordField {
ast_from_text(&format!("struct S {{ {}: {}, }}", name, ty))
pub fn record_field(
visibility: Option<ast::Visibility>,
name: ast::Name,
ty: ast::Type,
) -> ast::RecordField {
let visibility = match visibility {
None => String::new(),
Some(it) => format!("{} ", it),
};
ast_from_text(&format!("struct S {{ {}{}: {}, }}", visibility, name, ty))
}

pub fn block_expr(
Expand Down Expand Up @@ -360,6 +368,13 @@ pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> as
ast_from_text(&format!("struct f({});", fields))
}

pub fn record_field_list(
fields: impl IntoIterator<Item = ast::RecordField>,
) -> ast::RecordFieldList {
let fields = fields.into_iter().join(", ");
ast_from_text(&format!("struct f {{ {} }}", fields))
}

pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField {
let visibility = match visibility {
None => String::new(),
Expand All @@ -368,6 +383,14 @@ pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::T
ast_from_text(&format!("struct f({}{});", visibility, ty))
}

pub fn variant(name: ast::Name, field_list: Option<ast::FieldList>) -> ast::Variant {
let field_list = match field_list {
None => String::new(),
Some(it) => format!("{}", it),
};
ast_from_text(&format!("enum f {{ {}{} }}", name, field_list))
}

pub fn fn_(
visibility: Option<ast::Visibility>,
fn_name: ast::Name,
Expand Down