Skip to content

Commit cc14272

Browse files
authored
cgen, checker: fix generic variable resolution on generic func return assignment (#21712)
1 parent 53d7a55 commit cc14272

File tree

10 files changed

+222
-46
lines changed

10 files changed

+222
-46
lines changed

vlib/v/ast/ast.v

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ pub mut:
774774
receiver_type Type // User / T, if receiver is generic, then cgen requires receiver_type to be T
775775
receiver_concrete_type Type // if receiver_type is T, then receiver_concrete_type is concrete type, otherwise it is the same as receiver_type
776776
return_type Type
777+
return_type_generic Type // the original generic return type from fn def
777778
fn_var_type Type // the fn type, when `is_fn_a_const` or `is_fn_var` is true
778779
const_name string // the fully qualified name of the const, i.e. `main.c`, given `const c = abc`, and callexpr: `c()`
779780
should_be_skipped bool // true for calls to `[if someflag?]` functions, when there is no `-d someflag`
@@ -825,6 +826,7 @@ pub enum ComptimeVarKind {
825826
value_var // map value from `for k,v in t.$(field.name)`
826827
field_var // comptime field var `a := t.$(field.name)`
827828
generic_param // generic fn parameter
829+
generic_var // generic var
828830
smartcast // smart cast when used in `is v` (when `v` is from $for .variants)
829831
}
830832

vlib/v/checker/assign.v

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,16 @@ fn (mut c Checker) assign_stmt(mut node ast.AssignStmt) {
397397
&& right.expr is ast.ComptimeSelector {
398398
left.obj.ct_type_var = .field_var
399399
left.obj.typ = c.comptime.comptime_for_field_type
400+
} else if mut right is ast.CallExpr {
401+
if left.obj.ct_type_var == .no_comptime
402+
&& c.table.cur_fn != unsafe { nil }
403+
&& c.table.cur_fn.generic_names.len != 0
404+
&& !right.comptime_ret_val
405+
&& right.return_type_generic.has_flag(.generic)
406+
&& c.is_generic_expr(right) {
407+
// mark variable as generic var because its type changes according to fn return generic resolution type
408+
left.obj.ct_type_var = .generic_var
409+
}
400410
}
401411
}
402412
ast.GlobalField {

vlib/v/checker/fn.v

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,9 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
15091509
} else {
15101510
node.return_type = func.return_type
15111511
}
1512+
if func.return_type.has_flag(.generic) {
1513+
node.return_type_generic = func.return_type
1514+
}
15121515
if node.concrete_types.len > 0 && func.return_type != 0 && c.table.cur_fn != unsafe { nil }
15131516
&& c.table.cur_fn.generic_names.len == 0 {
15141517
if typ := c.table.resolve_generic_to_concrete(func.return_type, func.generic_names,
@@ -1582,6 +1585,27 @@ fn (mut c Checker) register_trace_call(node ast.CallExpr, func ast.Fn) {
15821585
}
15831586
}
15841587

1588+
// is_generic_expr checks if the expr relies on fn generic argument
1589+
fn (mut c Checker) is_generic_expr(node ast.Expr) bool {
1590+
return match node {
1591+
ast.Ident {
1592+
c.comptime.is_generic_param_var(node)
1593+
}
1594+
ast.IndexExpr {
1595+
c.comptime.is_generic_param_var(node.left)
1596+
}
1597+
ast.CallExpr {
1598+
node.args.any(c.comptime.is_generic_param_var(it.expr))
1599+
}
1600+
ast.SelectorExpr {
1601+
c.comptime.is_generic_param_var(node.expr)
1602+
}
1603+
else {
1604+
false
1605+
}
1606+
}
1607+
}
1608+
15851609
fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concrete_types []ast.Type) map[int]ast.Type {
15861610
mut comptime_args := map[int]ast.Type{}
15871611
has_dynamic_vars := (c.table.cur_fn != unsafe { nil } && c.table.cur_fn.generic_names.len > 0)
@@ -1602,7 +1626,7 @@ fn (mut c Checker) resolve_comptime_args(func ast.Fn, node_ ast.CallExpr, concre
16021626
param_typ := param.typ
16031627
if call_arg.expr is ast.Ident {
16041628
if call_arg.expr.obj is ast.Var {
1605-
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
1629+
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
16061630
mut ctyp := c.comptime.get_comptime_var_type(call_arg.expr)
16071631
if ctyp != ast.void_type {
16081632
arg_sym := c.table.sym(ctyp)
@@ -2159,6 +2183,9 @@ fn (mut c Checker) method_call(mut node ast.CallExpr) ast.Type {
21592183
node.is_noreturn = method.is_noreturn
21602184
node.is_ctor_new = method.is_ctor_new
21612185
node.return_type = method.return_type
2186+
if method.return_type.has_flag(.generic) {
2187+
node.return_type_generic = method.return_type
2188+
}
21622189
if !method.is_pub && method.mod != c.mod {
21632190
// If a private method is called outside of the module
21642191
// its receiver type is defined in, show an error.

vlib/v/comptime/comptimeinfo.v

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,33 +44,49 @@ pub fn (mut ct ComptimeInfo) get_ct_type_var(node ast.Expr) ast.ComptimeVarKind
4444
}
4545
}
4646

47+
@[inline]
48+
pub fn (mut ct ComptimeInfo) is_generic_param_var(node ast.Expr) bool {
49+
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
50+
&& (node.obj as ast.Var).ct_type_var == .generic_param
51+
}
52+
4753
// get_comptime_var_type retrieves the actual type from a comptime related ast node
4854
@[inline]
4955
pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type {
50-
if node is ast.Ident && node.obj is ast.Var {
51-
return match (node.obj as ast.Var).ct_type_var {
52-
.generic_param {
53-
// generic parameter from current function
54-
node.obj.typ
55-
}
56-
.smartcast {
57-
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
58-
return if (node.obj as ast.Var).is_unwrapped {
59-
ctyp.clear_flag(.option)
60-
} else {
61-
ctyp
56+
if node is ast.Ident {
57+
if node.obj is ast.Var {
58+
return match node.obj.ct_type_var {
59+
.generic_param {
60+
// generic parameter from current function
61+
node.obj.typ
62+
}
63+
.generic_var {
64+
// generic var used on fn call assignment
65+
if node.obj.smartcasts.len > 0 {
66+
node.obj.smartcasts.last()
67+
} else {
68+
ct.type_map['g.${node.name}.${node.obj.pos.pos}'] or { node.obj.typ }
69+
}
70+
}
71+
.smartcast {
72+
ctyp := ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
73+
return if (node.obj as ast.Var).is_unwrapped {
74+
ctyp.clear_flag(.option)
75+
} else {
76+
ctyp
77+
}
78+
}
79+
.key_var, .value_var {
80+
// key and value variables from normal for stmt
81+
ct.type_map[node.name] or { ast.void_type }
82+
}
83+
.field_var {
84+
// field var from $for loop
85+
ct.comptime_for_field_type
86+
}
87+
else {
88+
ast.void_type
6289
}
63-
}
64-
.key_var, .value_var {
65-
// key and value variables from normal for stmt
66-
ct.type_map[node.name] or { ast.void_type }
67-
}
68-
.field_var {
69-
// field var from $for loop
70-
ct.comptime_for_field_type
71-
}
72-
else {
73-
ast.void_type
7490
}
7591
}
7692
} else if node is ast.ComptimeSelector {

vlib/v/gen/c/assign.v

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
284284
}
285285
g.assign_ct_type = var_type
286286
} else if val is ast.IndexExpr {
287-
if val.left is ast.Ident && g.is_generic_param_var(val.left) {
287+
if val.left is ast.Ident && g.comptime.is_generic_param_var(val.left) {
288288
ctyp := g.unwrap_generic(g.get_gn_var_type(val.left))
289289
if ctyp != ast.void_type {
290290
var_type = ctyp
@@ -293,6 +293,17 @@ fn (mut g Gen) assign_stmt(node_ ast.AssignStmt) {
293293
g.assign_ct_type = var_type
294294
}
295295
}
296+
} else if left.obj.ct_type_var == .generic_var && val is ast.CallExpr {
297+
if val.return_type_generic != 0 && val.return_type_generic.has_flag(.generic) {
298+
fn_ret_type := g.resolve_fn_return_type(val)
299+
if fn_ret_type != ast.void_type {
300+
var_type = fn_ret_type
301+
val_type = var_type
302+
left.obj.typ = var_type
303+
g.comptime.type_map['g.${left.name}.${left.obj.pos.pos}'] = var_type
304+
// eprintln('>> ${func.name} > resolve ${left.name}.${left.obj.pos.pos}.generic to ${g.table.type_to_str(var_type)}')
305+
}
306+
}
296307
}
297308
is_auto_heap = left.obj.is_auto_heap
298309
}

vlib/v/gen/c/cgen.v

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4721,12 +4721,6 @@ fn (mut g Gen) select_expr(node ast.SelectExpr) {
47214721
}
47224722
}
47234723

4724-
@[inline]
4725-
pub fn (mut g Gen) is_generic_param_var(node ast.Expr) bool {
4726-
return node is ast.Ident && node.info is ast.IdentVar && node.obj is ast.Var
4727-
&& (node.obj as ast.Var).ct_type_var == .generic_param
4728-
}
4729-
47304724
fn (mut g Gen) get_const_name(node ast.Ident) string {
47314725
if g.pref.translated && !g.is_builtin_mod
47324726
&& !util.module_is_builtin(node.name.all_before_last('.')) {

vlib/v/gen/c/fn.v

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,10 +1105,6 @@ fn (mut g Gen) gen_to_str_method_call(node ast.CallExpr) bool {
11051105
rec_type = g.comptime.get_comptime_var_type(left_node)
11061106
g.gen_expr_to_string(left_node, rec_type)
11071107
return true
1108-
} else if g.comptime.type_map.len > 0 {
1109-
rec_type = left_node.obj.typ
1110-
g.gen_expr_to_string(left_node, rec_type)
1111-
return true
11121108
} else if left_node.obj.smartcasts.len > 0 {
11131109
rec_type = g.unwrap_generic(left_node.obj.smartcasts.last())
11141110
cast_sym := g.table.sym(rec_type)
@@ -1154,6 +1150,79 @@ fn (mut g Gen) get_gn_var_type(var ast.Ident) ast.Type {
11541150
return ast.void_type
11551151
}
11561152

1153+
// resolve_fn_return_type resolves the generic return type of fn
1154+
fn (mut g Gen) resolve_fn_return_type(node ast.CallExpr) ast.Type {
1155+
if node.is_method {
1156+
if func := g.table.find_method(g.table.sym(node.left_type), node.name) {
1157+
if func.generic_names.len > 0 {
1158+
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
1159+
mut rec_len := 0
1160+
if node.left_type.has_flag(.generic) {
1161+
rec_sym := g.table.final_sym(g.unwrap_generic(node.left_type))
1162+
match rec_sym.info {
1163+
ast.Struct, ast.Interface, ast.SumType {
1164+
rec_len += rec_sym.info.generic_types.len
1165+
}
1166+
else {}
1167+
}
1168+
}
1169+
1170+
mut call_ := unsafe { node }
1171+
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
1172+
if concrete_types.len > 0 {
1173+
for k, v in comptime_args {
1174+
if (rec_len + k) < concrete_types.len {
1175+
if !node.concrete_types[k].has_flag(.generic) {
1176+
concrete_types[rec_len + k] = g.unwrap_generic(v)
1177+
}
1178+
}
1179+
}
1180+
}
1181+
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
1182+
func.generic_names, concrete_types)
1183+
{
1184+
if !gen_type.has_flag(.generic) {
1185+
return if node.or_block.kind == .absent {
1186+
gen_type
1187+
} else {
1188+
gen_type.clear_option_and_result()
1189+
}
1190+
}
1191+
}
1192+
}
1193+
}
1194+
} else {
1195+
if func := g.table.find_fn(node.name) {
1196+
if func.generic_names.len > 0 {
1197+
mut concrete_types := node.concrete_types.map(g.unwrap_generic(it))
1198+
mut call_ := unsafe { node }
1199+
comptime_args := g.resolve_comptime_args(func, mut call_, concrete_types)
1200+
if concrete_types.len > 0 {
1201+
for k, v in comptime_args {
1202+
if k < concrete_types.len {
1203+
if !node.concrete_types[k].has_flag(.generic) {
1204+
concrete_types[k] = g.unwrap_generic(v)
1205+
}
1206+
}
1207+
}
1208+
}
1209+
if gen_type := g.table.resolve_generic_to_concrete(node.return_type_generic,
1210+
func.generic_names, concrete_types)
1211+
{
1212+
if !gen_type.has_flag(.generic) {
1213+
return if node.or_block.kind == .absent {
1214+
gen_type
1215+
} else {
1216+
gen_type.clear_option_and_result()
1217+
}
1218+
}
1219+
}
1220+
}
1221+
}
1222+
}
1223+
return ast.void_type
1224+
}
1225+
11571226
fn (g Gen) get_generic_array_element_type(array ast.Array) ast.Type {
11581227
mut cparam_elem_info := array as ast.Array
11591228
mut cparam_elem_sym := g.table.sym(cparam_elem_info.elem_type)
@@ -1194,7 +1263,7 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
11941263
if mut call_arg.expr is ast.Ident {
11951264
if mut call_arg.expr.obj is ast.Var {
11961265
node_.args[i].typ = call_arg.expr.obj.typ
1197-
if call_arg.expr.obj.ct_type_var !in [.generic_param, .no_comptime] {
1266+
if call_arg.expr.obj.ct_type_var !in [.generic_var, .generic_param, .no_comptime] {
11981267
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr)
11991268
if ctyp != ast.void_type {
12001269
arg_sym := g.table.sym(ctyp)
@@ -1293,11 +1362,13 @@ fn (mut g Gen) resolve_comptime_args(func ast.Fn, mut node_ ast.CallExpr, concre
12931362
comptime_args[k] = comptime_args[k].set_nr_muls(0)
12941363
}
12951364
} else if mut call_arg.expr.right is ast.Ident {
1296-
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
1297-
if ctyp != ast.void_type {
1298-
comptime_args[k] = ctyp
1299-
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
1300-
comptime_args[k] = comptime_args[k].set_nr_muls(0)
1365+
if g.comptime.get_ct_type_var(call_arg.expr.right) != .generic_var {
1366+
mut ctyp := g.comptime.get_comptime_var_type(call_arg.expr.right)
1367+
if ctyp != ast.void_type {
1368+
comptime_args[k] = ctyp
1369+
if param_typ.nr_muls() > 0 && comptime_args[k].nr_muls() > 0 {
1370+
comptime_args[k] = comptime_args[k].set_nr_muls(0)
1371+
}
13011372
}
13021373
}
13031374
}

vlib/v/gen/c/for.v

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,13 @@ fn (mut g Gen) for_in_stmt(node_ ast.ForInStmt) {
450450
g.writeln('${t_expr});')
451451
g.writeln('\tif (${t_var}.state != 0) break;')
452452
val := if node.val_var in ['', '_'] { g.new_tmp_var() } else { node.val_var }
453-
val_styp := g.typ(node.val_type)
453+
val_styp := g.typ(ret_typ.clear_option_and_result())
454454
if node.val_is_mut {
455-
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
455+
if ret_typ.has_flag(.option) {
456+
g.writeln('\t${val_styp}* ${val} = (${val_styp}*)${t_var}.data;')
457+
} else {
458+
g.writeln('\t${val_styp} ${val} = (${val_styp})${t_var}.data;')
459+
}
456460
} else {
457461
g.writeln('\t${val_styp} ${val} = *(${val_styp}*)${t_var}.data;')
458462
}

vlib/v/gen/c/str_intp.v

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,7 @@ fn (mut g Gen) str_val(node ast.StringInterLiteral, i int, fmts []u8) {
192192
mut exp_typ := typ
193193
if expr is ast.Ident {
194194
if expr.obj is ast.Var {
195-
if g.comptime.type_map.len > 0 || g.comptime.comptime_for_method.len > 0 {
196-
exp_typ = expr.obj.typ
197-
} else if expr.obj.smartcasts.len > 0 {
195+
if expr.obj.smartcasts.len > 0 {
198196
exp_typ = g.unwrap_generic(expr.obj.smartcasts.last())
199197
cast_sym := g.table.sym(exp_typ)
200198
if cast_sym.info is ast.Aggregate {

vlib/v/tests/generic_return_test.v

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
fn mkey[K, V](m map[K]V) K {
2+
return K{}
3+
}
4+
5+
fn mvalue[K, V](m map[K]V) V {
6+
return V{}
7+
}
8+
9+
fn aelem[E](a []E) E {
10+
return E{}
11+
}
12+
13+
fn g[T](x T) {
14+
$if T is $map {
15+
dk := mkey(x)
16+
dv := mvalue(x)
17+
eprintln('default k: `${dk}` | typeof dk: ${typeof(dk).name}')
18+
eprintln('default v: `${dv}` | typeof dv: ${typeof(dv).name}')
19+
for k, v in x {
20+
eprintln('> k: ${k} | v: ${v}')
21+
}
22+
}
23+
$if T is $array {
24+
de := aelem(x)
25+
eprintln('default e: `${de}` | typeof de: ${typeof(de).name}')
26+
for idx, e in x {
27+
eprintln('> idx: ${idx} | e: ${e}')
28+
}
29+
}
30+
}
31+
32+
fn test_main() {
33+
g({
34+
'abc': 123
35+
'def': 456
36+
})
37+
g([1, 2, 3])
38+
g({
39+
123: 'ggg'
40+
456: 'hhh'
41+
})
42+
g(['xyz', 'zzz'])
43+
}

0 commit comments

Comments
 (0)