From 2b7cf74aaacb0ccfcc3d7645d6f4ca7c3abed785 Mon Sep 17 00:00:00 2001 From: A4-Tacks Date: Fri, 29 Aug 2025 14:41:22 +0800 Subject: [PATCH] Add ReturnExpr support for add_return_type Example --- ```rust fn foo() { return 45$0; todo!() } ``` -> ```rust fn foo() -> i32 { return 45; todo!() } ``` --- .../src/handlers/add_return_type.rs | 122 +++++++++++++++--- 1 file changed, 101 insertions(+), 21 deletions(-) diff --git a/crates/ide-assists/src/handlers/add_return_type.rs b/crates/ide-assists/src/handlers/add_return_type.rs index a7104ce068da..996f52922e3e 100644 --- a/crates/ide-assists/src/handlers/add_return_type.rs +++ b/crates/ide-assists/src/handlers/add_return_type.rs @@ -1,3 +1,4 @@ +use either::Either; use hir::HirDisplay; use syntax::{AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize, ast, match_ast}; @@ -16,7 +17,8 @@ use crate::{AssistContext, AssistId, Assists}; // fn foo() -> i32 { 42i32 } // ``` pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?; + let (fn_type, tail_expr, builder_edit_pos) = + extract_ret(ctx.find_node_at_offset()?, Either::Right(ctx.selection_trimmed()))?; let module = ctx.sema.scope(tail_expr.syntax())?.module(); let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original(); if ty.is_unit() { @@ -132,9 +134,16 @@ fn peel_blocks(mut expr: ast::Expr) -> ast::Expr { expr } -fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> { - let (fn_type, tail_expr, return_type_range, action) = - if let Some(closure) = ctx.find_node_at_offset::() { +fn extract_ret( + node: Either>, + ret: Either, +) -> Option<(FnType, ast::Expr, InsertOrReplace)> { + let (fn_type, tail_expr, return_type_range, action) = match node { + Either::Left(ret_expr) => { + let node = ret_expr.syntax().ancestors().skip(1).find_map(AstNode::cast)?; + return extract_ret(node, Either::Left(ret_expr.expr()?)); + } + Either::Right(Either::Left(closure)) => { let rpipe = closure.param_list()?.syntax().last_token()?; let rpipe_pos = rpipe.text_range().end(); @@ -143,38 +152,41 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR let body = closure.body()?; let body_start = body.syntax().first_token()?.text_range().start(); let (tail_expr, wrap_expr) = match body { - ast::Expr::BlockExpr(block) => (block.tail_expr()?, false), - body => (body, true), + ast::Expr::BlockExpr(block) => (block.tail_expr(), false), + body => (Some(body), true), }; let ret_range = TextRange::new(rpipe_pos, body_start); (FnType::Closure { wrap_expr }, tail_expr, ret_range, action) - } else { - let func = ctx.find_node_at_offset::()?; - + } + Either::Right(Either::Right(func)) => { let rparen = func.param_list()?.r_paren_token()?; let rparen_pos = rparen.text_range().end(); let action = ret_ty_to_action(func.ret_type(), rparen)?; let body = func.body()?; let stmt_list = body.stmt_list()?; - let tail_expr = stmt_list.tail_expr()?; + let tail_expr = stmt_list.tail_expr(); let ret_range_end = stmt_list.l_curly_token()?.text_range().start(); let ret_range = TextRange::new(rparen_pos, ret_range_end); (FnType::Function, tail_expr, ret_range, action) - }; - let range = ctx.selection_trimmed(); - if return_type_range.contains_range(range) { - cov_mark::hit!(cursor_in_ret_position); - cov_mark::hit!(cursor_in_ret_position_closure); - } else if tail_expr.syntax().text_range().contains_range(range) { - cov_mark::hit!(cursor_on_tail); - cov_mark::hit!(cursor_on_tail_closure); - } else { - return None; + } + }; + if let Either::Right(&range) = ret.as_ref() { + if return_type_range.contains_range(range) { + cov_mark::hit!(cursor_in_ret_position); + cov_mark::hit!(cursor_in_ret_position_closure); + } else if let Some(tail_expr) = &tail_expr + && tail_expr.syntax().text_range().contains_range(range) + { + cov_mark::hit!(cursor_on_tail); + cov_mark::hit!(cursor_on_tail_closure); + } else { + return None; + } } - Some((fn_type, tail_expr, action)) + Some((fn_type, ret.left().or(tail_expr)?, action)) } #[cfg(test)] @@ -266,6 +278,41 @@ mod tests { ); } + #[test] + fn infer_return_type_return_expr() { + check_assist( + add_return_type, + r#"fn foo() { + return 45$0 +}"#, + r#"fn foo() -> i32 { + return 45 +}"#, + ); + + check_assist( + add_return_type, + r#"fn foo() { + return 45$0; +}"#, + r#"fn foo() -> i32 { + return 45; +}"#, + ); + + check_assist( + add_return_type, + r#"fn foo() { + return 45$0; + todo!() +}"#, + r#"fn foo() -> i32 { + return 45; + todo!() +}"#, + ); + } + #[test] fn infer_return_type_nested() { check_assist( @@ -369,6 +416,39 @@ mod tests { ); } + #[test] + fn infer_return_type_closure_return_expr() { + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| { return x$0 }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { return x }; +}"#, + ); + + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| { return x$0; }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { return x; }; +}"#, + ); + + check_assist( + add_return_type, + r#"fn foo() { + |x: i32| { return x$0; todo!() }; +}"#, + r#"fn foo() { + |x: i32| -> i32 { return x; todo!() }; +}"#, + ); + } + #[test] fn infer_return_type_closure_no_whitespace() { check_assist(