Skip to content

Commit 1a11056

Browse files
authored
checker: fix validation of lambda params and returns in generic function calls (#22387)
1 parent f24d239 commit 1a11056

File tree

5 files changed

+25
-5
lines changed

5 files changed

+25
-5
lines changed

vlib/v/ast/table.v

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ pub mut:
3737
panic_handler FnPanicHandler = default_table_panic_handler
3838
panic_userdata voidptr = unsafe { nil } // can be used to pass arbitrary data to panic_handler;
3939
panic_npanics int
40-
cur_fn &FnDecl = unsafe { nil } // previously stored in Checker.cur_fn and Gen.cur_fn
40+
cur_fn &FnDecl = unsafe { nil } // previously stored in Checker.cur_fn and Gen.cur_fn
41+
cur_lambda &LambdaExpr = unsafe { nil } // current lambda node
4142
cur_concrete_types []Type // current concrete types, e.g. <int, string>
4243
gostmts int // how many `go` statements there were in the parsed files.
4344
// When table.gostmts > 0, __VTHREADS__ is defined, which can be checked with `$if threads {`

vlib/v/checker/checker.v

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,6 +2735,13 @@ fn (mut c Checker) unwrap_generic(typ ast.Type) ast.Type {
27352735
{
27362736
return t_typ
27372737
}
2738+
if c.inside_lambda && c.table.cur_lambda.call_ctx != unsafe { nil } {
2739+
if t_typ := c.table.resolve_generic_to_concrete(typ, c.table.cur_lambda.func.decl.generic_names,
2740+
c.table.cur_lambda.call_ctx.concrete_types)
2741+
{
2742+
return t_typ
2743+
}
2744+
}
27382745
}
27392746
}
27402747
return typ
@@ -2971,8 +2978,10 @@ pub fn (mut c Checker) expr(mut node ast.Expr) ast.Type {
29712978
}
29722979
ast.LambdaExpr {
29732980
c.inside_lambda = true
2981+
c.table.cur_lambda = unsafe { &node }
29742982
defer {
29752983
c.inside_lambda = false
2984+
c.table.cur_lambda = unsafe { nil }
29762985
}
29772986
return c.lambda_expr(mut node, c.expected_type)
29782987
}

vlib/v/checker/fn.v

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,10 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
12341234
for i, mut call_arg in node.args {
12351235
if call_arg.expr is ast.CallExpr {
12361236
node.args[i].typ = c.expr(mut call_arg.expr)
1237+
} else if mut call_arg.expr is ast.LambdaExpr {
1238+
if node.concrete_types.len > 0 {
1239+
call_arg.expr.call_ctx = unsafe { node }
1240+
}
12371241
}
12381242
}
12391243
c.check_expected_arg_count(mut node, func) or { return func.return_type }
@@ -1608,7 +1612,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
16081612
}
16091613
if mut call_arg.expr is ast.LambdaExpr {
16101614
// Calling fn is generic and lambda arg also is generic
1611-
if node.concrete_types.len > 0
1615+
if node.concrete_types.len > 0 && call_arg.expr.func != unsafe { nil }
16121616
&& call_arg.expr.func.decl.generic_names.len > 0 {
16131617
call_arg.expr.call_ctx = unsafe { node }
16141618
if c.table.register_fn_concrete_types(call_arg.expr.func.decl.fkey(),

vlib/v/checker/return.v

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,8 @@ fn (mut c Checker) return_stmt(mut node ast.Return) {
261261
} else {
262262
got_type_sym.name
263263
}
264-
// ignore generic casting expr on lambda in this phase
265-
if c.inside_lambda && exp_type.has_flag(.generic)
266-
&& node.exprs[expr_idxs[i]] is ast.CastExpr {
264+
// ignore generic lambda return in this phase
265+
if c.inside_lambda && exp_type.has_flag(.generic) {
267266
continue
268267
}
269268
c.error('cannot use `${got_type_name}` as ${c.error_type_name(exp_type)} in return argument',
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import arrays
2+
3+
fn test_main() {
4+
items := ['item1', 'item2', 'item3']
5+
list := arrays.map_indexed[string, string](items, |i, item| '${i}. ${item}')
6+
assert list == ['0. item1', '1. item2', '2. item3']
7+
}

0 commit comments

Comments
 (0)