Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 101 additions & 21 deletions crates/ide-assists/src/handlers/add_return_type.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use either::Either;
use hir::HirDisplay;
use syntax::{AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize, ast, match_ast};

Expand All @@ -16,7 +17,8 @@ use crate::{AssistContext, AssistId, Assists};
// fn foo() -> i32 { 42i32 }
// ```
pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
let (fn_type, tail_expr, builder_edit_pos) =
extract_ret(ctx.find_node_at_offset()?, Either::Right(ctx.selection_trimmed()))?;
let module = ctx.sema.scope(tail_expr.syntax())?.module();
let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original();
if ty.is_unit() {
Expand Down Expand Up @@ -132,9 +134,16 @@ fn peel_blocks(mut expr: ast::Expr) -> ast::Expr {
expr
}

fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
let (fn_type, tail_expr, return_type_range, action) =
if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
fn extract_ret(
node: Either<ast::ReturnExpr, Either<ast::ClosureExpr, ast::Fn>>,
ret: Either<ast::Expr, TextRange>,
) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
let (fn_type, tail_expr, return_type_range, action) = match node {
Either::Left(ret_expr) => {
let node = ret_expr.syntax().ancestors().skip(1).find_map(AstNode::cast)?;
return extract_ret(node, Either::Left(ret_expr.expr()?));
}
Either::Right(Either::Left(closure)) => {
let rpipe = closure.param_list()?.syntax().last_token()?;
let rpipe_pos = rpipe.text_range().end();

Expand All @@ -143,38 +152,41 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR
let body = closure.body()?;
let body_start = body.syntax().first_token()?.text_range().start();
let (tail_expr, wrap_expr) = match body {
ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
body => (body, true),
ast::Expr::BlockExpr(block) => (block.tail_expr(), false),
body => (Some(body), true),
};

let ret_range = TextRange::new(rpipe_pos, body_start);
(FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
} else {
let func = ctx.find_node_at_offset::<ast::Fn>()?;

}
Either::Right(Either::Right(func)) => {
let rparen = func.param_list()?.r_paren_token()?;
let rparen_pos = rparen.text_range().end();
let action = ret_ty_to_action(func.ret_type(), rparen)?;

let body = func.body()?;
let stmt_list = body.stmt_list()?;
let tail_expr = stmt_list.tail_expr()?;
let tail_expr = stmt_list.tail_expr();

let ret_range_end = stmt_list.l_curly_token()?.text_range().start();
let ret_range = TextRange::new(rparen_pos, ret_range_end);
(FnType::Function, tail_expr, ret_range, action)
};
let range = ctx.selection_trimmed();
if return_type_range.contains_range(range) {
cov_mark::hit!(cursor_in_ret_position);
cov_mark::hit!(cursor_in_ret_position_closure);
} else if tail_expr.syntax().text_range().contains_range(range) {
cov_mark::hit!(cursor_on_tail);
cov_mark::hit!(cursor_on_tail_closure);
} else {
return None;
}
};
if let Either::Right(&range) = ret.as_ref() {
if return_type_range.contains_range(range) {
cov_mark::hit!(cursor_in_ret_position);
cov_mark::hit!(cursor_in_ret_position_closure);
} else if let Some(tail_expr) = &tail_expr
&& tail_expr.syntax().text_range().contains_range(range)
{
cov_mark::hit!(cursor_on_tail);
cov_mark::hit!(cursor_on_tail_closure);
} else {
return None;
}
}
Some((fn_type, tail_expr, action))
Some((fn_type, ret.left().or(tail_expr)?, action))
}

#[cfg(test)]
Expand Down Expand Up @@ -266,6 +278,41 @@ mod tests {
);
}

#[test]
fn infer_return_type_return_expr() {
check_assist(
add_return_type,
r#"fn foo() {
return 45$0
}"#,
r#"fn foo() -> i32 {
return 45
}"#,
);

check_assist(
add_return_type,
r#"fn foo() {
return 45$0;
}"#,
r#"fn foo() -> i32 {
return 45;
}"#,
);

check_assist(
add_return_type,
r#"fn foo() {
return 45$0;
todo!()
}"#,
r#"fn foo() -> i32 {
return 45;
todo!()
}"#,
);
}

#[test]
fn infer_return_type_nested() {
check_assist(
Expand Down Expand Up @@ -369,6 +416,39 @@ mod tests {
);
}

#[test]
fn infer_return_type_closure_return_expr() {
check_assist(
add_return_type,
r#"fn foo() {
|x: i32| { return x$0 };
}"#,
r#"fn foo() {
|x: i32| -> i32 { return x };
}"#,
);

check_assist(
add_return_type,
r#"fn foo() {
|x: i32| { return x$0; };
}"#,
r#"fn foo() {
|x: i32| -> i32 { return x; };
}"#,
);

check_assist(
add_return_type,
r#"fn foo() {
|x: i32| { return x$0; todo!() };
}"#,
r#"fn foo() {
|x: i32| -> i32 { return x; todo!() };
}"#,
);
}

#[test]
fn infer_return_type_closure_no_whitespace() {
check_assist(
Expand Down
Loading