Skip to content

Commit ebf629d

Browse files
authored
ast,checker: improve type checking for sumtypes with generics (fix #25690) (#25699)
1 parent 8ef3db5 commit ebf629d

File tree

7 files changed

+208
-21
lines changed

7 files changed

+208
-21
lines changed

vlib/v/ast/table.v

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,9 @@ pub fn (mut t Table) register_sym(sym TypeSymbol) int {
937937
...sym
938938
}
939939
t.type_symbols[idx].idx = idx
940+
if t.type_symbols[idx].ngname == '' {
941+
t.type_symbols[idx].ngname = strip_generic_params(sym.name)
942+
}
940943
t.type_idxs[sym_name] = idx
941944
return idx
942945
}
@@ -960,6 +963,11 @@ pub fn (t &Table) known_type(name string) bool {
960963
return t.type_idxs[name] != 0 || t.parsing_type == name || name in ['i32', 'byte']
961964
}
962965

966+
@[inline]
967+
pub fn strip_generic_params(name string) string {
968+
return name.all_before('[')
969+
}
970+
963971
// start_parsing_type open the scope during the parsing of a type
964972
// where the type name must include the module prefix
965973
pub fn (mut t Table) start_parsing_type(type_name string) {
@@ -1178,6 +1186,7 @@ pub fn (mut t Table) find_or_register_chan(elem_type Type, is_mut bool) int {
11781186
kind: .chan
11791187
name: name
11801188
cname: cname
1189+
ngname: strip_generic_params(name)
11811190
info: Chan{
11821191
elem_type: elem_type
11831192
is_mut: is_mut
@@ -1199,6 +1208,7 @@ pub fn (mut t Table) find_or_register_map(key_type Type, value_type Type) int {
11991208
kind: .map
12001209
name: name
12011210
cname: cname
1211+
ngname: strip_generic_params(name)
12021212
info: Map{
12031213
key_type: key_type
12041214
value_type: value_type
@@ -1220,6 +1230,7 @@ pub fn (mut t Table) find_or_register_thread(return_type Type) int {
12201230
kind: .thread
12211231
name: name
12221232
cname: cname
1233+
ngname: strip_generic_params(name)
12231234
info: Thread{
12241235
return_type: return_type
12251236
}
@@ -1242,6 +1253,7 @@ pub fn (mut t Table) find_or_register_promise(return_type Type) int {
12421253
kind: .struct
12431254
name: name
12441255
cname: cname
1256+
ngname: strip_generic_params(name)
12451257
info: Struct{
12461258
concrete_types: [return_type, t.type_idxs['JS.Any']]
12471259
}
@@ -1265,6 +1277,7 @@ pub fn (mut t Table) find_or_register_array(elem_type Type) int {
12651277
kind: .array
12661278
name: name
12671279
cname: cname
1280+
ngname: strip_generic_params(name)
12681281
info: Array{
12691282
nr_dims: 1
12701283
elem_type: elem_type
@@ -1292,10 +1305,11 @@ pub fn (mut t Table) find_or_register_array_fixed(elem_type Type, size int, size
12921305
cname := prefix + t.array_fixed_cname(elem_type, size)
12931306
// register
12941307
array_fixed_type := TypeSymbol{
1295-
kind: .array_fixed
1296-
name: name
1297-
cname: cname
1298-
info: ArrayFixed{
1308+
kind: .array_fixed
1309+
name: name
1310+
cname: cname
1311+
ngname: strip_generic_params(name)
1312+
info: ArrayFixed{
12991313
elem_type: elem_type
13001314
size: size
13011315
size_expr: size_expr
@@ -1328,10 +1342,11 @@ pub fn (mut t Table) find_or_register_multi_return(mr_typs []Type) int {
13281342
return existing_idx
13291343
}
13301344
multireg_sym := TypeSymbol{
1331-
kind: .multi_return
1332-
name: name
1333-
cname: cname
1334-
info: MultiReturn{
1345+
kind: .multi_return
1346+
name: name
1347+
cname: cname
1348+
ngname: strip_generic_params(name)
1349+
info: MultiReturn{
13351350
types: mr_typs
13361351
}
13371352
}
@@ -1354,18 +1369,57 @@ pub fn (mut t Table) find_or_register_fn_type(f Fn, is_anon bool, has_decl bool)
13541369
return existing_idx
13551370
}
13561371
return t.register_sym(
1357-
kind: .function
1358-
name: name
1359-
cname: cname
1360-
mod: f.mod
1361-
info: FnType{
1372+
kind: .function
1373+
name: name
1374+
cname: cname
1375+
ngname: strip_generic_params(name)
1376+
mod: f.mod
1377+
info: FnType{
13621378
is_anon: anon
13631379
has_decl: has_decl
13641380
func: f
13651381
}
13661382
)
13671383
}
13681384

1385+
pub fn (mut t Table) find_or_register_generic_inst(parent_typ Type, concrete_types []Type) int {
1386+
parent_sym := t.sym(parent_typ)
1387+
if parent_sym.info !is Struct {
1388+
return 0
1389+
}
1390+
struct_info := parent_sym.info as Struct
1391+
if struct_info.generic_types.len == 0 || concrete_types.len != struct_info.generic_types.len {
1392+
return 0
1393+
}
1394+
mut inst_name := parent_sym.ngname + '['
1395+
mut inst_cname := parent_sym.cname + '_T_'
1396+
for i, ct in concrete_types {
1397+
ct_sym := t.sym(ct)
1398+
inst_name += ct_sym.name
1399+
inst_cname += ct_sym.cname
1400+
if i < concrete_types.len - 1 {
1401+
inst_name += ', '
1402+
inst_cname += '_T_'
1403+
}
1404+
}
1405+
inst_name += ']'
1406+
existing_idx := t.type_idxs[inst_name]
1407+
if existing_idx > 0 {
1408+
return existing_idx
1409+
}
1410+
return t.register_sym(
1411+
kind: .generic_inst
1412+
name: inst_name
1413+
cname: inst_cname
1414+
ngname: parent_sym.ngname
1415+
mod: parent_sym.mod
1416+
info: GenericInst{
1417+
parent_idx: parent_typ.idx()
1418+
concrete_types: concrete_types
1419+
}
1420+
)
1421+
}
1422+
13691423
pub fn (mut t Table) add_placeholder_type(name string, cname string, language Language) int {
13701424
mut modname := ''
13711425
if name.contains('.') {
@@ -1375,6 +1429,7 @@ pub fn (mut t Table) add_placeholder_type(name string, cname string, language La
13751429
kind: .placeholder
13761430
name: name
13771431
cname: util.no_dots(cname).replace_each(['&', ''])
1432+
ngname: strip_generic_params(name)
13781433
language: language
13791434
mod: modname
13801435
is_pub: true
@@ -1893,7 +1948,7 @@ pub fn (mut t Table) convert_generic_type(generic_type Type, generic_names []str
18931948
if sym.info.is_generic {
18941949
mut nrt := '${sym.name}['
18951950
mut rnrt := '${sym.rname}['
1896-
mut cnrt := '${sym.cname}['
1951+
mut cnrt := '${sym.cname}_T_'
18971952
mut t_generic_names := generic_names.clone()
18981953
mut t_to_types := to_types.clone()
18991954
if sym.generic_types.len > 0 && sym.generic_types.len == sym.info.generic_types.len
@@ -1930,15 +1985,14 @@ pub fn (mut t Table) convert_generic_type(generic_type Type, generic_names []str
19301985
if i != sym.info.generic_types.len - 1 {
19311986
nrt += ', '
19321987
rnrt += ', '
1933-
cnrt += ', '
1988+
cnrt += '_'
19341989
}
19351990
} else {
19361991
return none
19371992
}
19381993
}
19391994
nrt += ']'
19401995
rnrt += ']'
1941-
cnrt += ']'
19421996
mut idx := t.type_idxs[nrt]
19431997
if idx == 0 {
19441998
idx = t.type_idxs[rnrt]

vlib/v/ast/types.v

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ pub mut:
115115
name string // the internal & source name of the type, i.e. `[5]int`.
116116
cname string // the name with no dots for use in the generated C code
117117
rname string // the raw name
118+
ngname string // the name without generic parameters
118119
methods []Fn
119120
generic_types []Type
120121
mod string

vlib/v/checker/check_types.v

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,21 @@ fn (mut c Checker) check_basic(got ast.Type, expected ast.Type) bool {
460460
if c.table.sumtype_has_variant(expected, ast.mktyp(got), false) {
461461
return true
462462
}
463+
if exp_sym.kind == .placeholder && c.expected_type != ast.void_type {
464+
base_type := c.table.find_type(exp_sym.ngname)
465+
if base_type != 0 {
466+
base_sym := c.table.sym(base_type)
467+
if base_sym.kind == .sum_type && base_sym.info is ast.SumType {
468+
base_info := base_sym.info as ast.SumType
469+
for variant in base_info.variants {
470+
variant_sym := c.table.sym(variant)
471+
if variant_sym.ngname == got_sym.ngname {
472+
return true
473+
}
474+
}
475+
}
476+
}
477+
}
463478
// struct
464479
if exp_sym.kind == .struct && got_sym.kind == .struct {
465480
if c.table.type_to_str(expected) == c.table.type_to_str(got) {
@@ -950,6 +965,37 @@ fn (mut c Checker) infer_struct_generic_types(typ ast.Type, node ast.StructInit)
950965
}
951966
}
952967
}
968+
} else if field_sym.info is ast.SumType {
969+
for t in node.init_fields {
970+
if ft.name == t.name && t.typ != 0 {
971+
init_sym := c.table.sym(t.typ)
972+
for variant in field_sym.info.variants {
973+
variant_sym := c.table.sym(variant)
974+
if variant_sym.name == init_sym.name {
975+
if variant_sym.info is ast.Struct
976+
&& variant_sym.info.generic_types.len > 0 {
977+
if init_sym.info is ast.Struct
978+
&& init_sym.info.concrete_types.len > 0 {
979+
concrete_types << ast.mktyp(init_sym.info.concrete_types[0])
980+
continue gname
981+
}
982+
} else {
983+
for init_field in node.init_fields {
984+
if init_field.name != t.name && init_field.typ != 0 {
985+
field := sym.info.fields.filter(it.name == init_field.name)
986+
if field.len > 0 {
987+
if c.table.sym(field[0].typ).name == gt_name {
988+
concrete_types << ast.mktyp(init_field.typ)
989+
continue gname
990+
}
991+
}
992+
}
993+
}
994+
}
995+
}
996+
}
997+
}
998+
}
953999
}
9541000
}
9551001
c.error('could not infer generic type `${gt_name}` in generic struct `${sym.name}[${generic_names.join(', ')}]`',

vlib/v/checker/fn.v

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,6 +1569,33 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
15691569
node.args[i].typ = call_arg.expr.obj.typ
15701570
}
15711571
}
1572+
// sumtype coercion
1573+
param_type_sym := c.table.sym(param.typ)
1574+
if param_type_sym.kind == .placeholder {
1575+
base_type := c.table.find_type(param_type_sym.ngname)
1576+
if base_type != 0 {
1577+
base_sym := c.table.sym(base_type)
1578+
if base_sym.kind == .sum_type && base_sym.info is ast.SumType {
1579+
base_info := base_sym.info as ast.SumType
1580+
arg_typ_sym := c.table.sym(arg_typ)
1581+
for variant in base_info.variants {
1582+
variant_sym := c.table.sym(variant)
1583+
variant_base_name := variant_sym.ngname
1584+
if variant_base_name == arg_typ_sym.ngname {
1585+
node.args[i].expr = ast.CastExpr{
1586+
expr: call_arg.expr
1587+
typ: param.typ
1588+
typname: c.table.type_to_str(param.typ)
1589+
pos: call_arg.expr.pos()
1590+
}
1591+
node.args[i].typ = param.typ
1592+
arg_typ = param.typ
1593+
break
1594+
}
1595+
}
1596+
}
1597+
}
1598+
}
15721599
arg_typ_sym := c.table.sym(arg_typ)
15731600
if param.typ.has_flag(.generic) {
15741601
if arg_typ_sym.kind == .none && !param.typ.has_flag(.option) {

vlib/v/checker/struct.v

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,14 @@ fn (mut c Checker) struct_init(mut node ast.StructInit, is_field_zero_struct_ini
700700
mut exp_type := ast.no_type
701701
inited_fields << field_name
702702
exp_type = field_info.typ
703+
if c.inside_generic_struct_init && exp_type.has_flag(.generic) {
704+
generic_names := c.cur_struct_generic_types.map(c.table.sym(it).name)
705+
if unwrapped := c.table.convert_generic_type(exp_type, generic_names,
706+
c.cur_struct_concrete_types)
707+
{
708+
exp_type = unwrapped
709+
}
710+
}
703711
exp_type_sym := c.table.sym(exp_type)
704712
c.expected_type = exp_type
705713
got_type = c.expr(mut init_field.expr)
@@ -783,9 +791,35 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.',
783791
}
784792
} else if got_type != ast.void_type && got_type_sym.kind != .placeholder
785793
&& !exp_type.has_flag(.generic) {
786-
c.check_expected(c.unwrap_generic(got_type), c.unwrap_generic(exp_type)) or {
787-
c.error('cannot assign to field `${field_info.name}`: ${err.msg()}',
788-
init_field.pos)
794+
mut needs_sum_type_cast := false
795+
if exp_type_sym.kind == .placeholder {
796+
base_type := c.table.find_type(exp_type_sym.ngname)
797+
if base_type != 0 {
798+
base_sym := c.table.sym(base_type)
799+
if base_sym.kind == .sum_type && base_sym.info is ast.SumType {
800+
base_info := base_sym.info as ast.SumType
801+
for variant in base_info.variants {
802+
if c.table.sym(variant).ngname == got_type_sym.ngname {
803+
needs_sum_type_cast = true
804+
break
805+
}
806+
}
807+
}
808+
}
809+
}
810+
if needs_sum_type_cast {
811+
init_field.expr = ast.CastExpr{
812+
expr: init_field.expr
813+
typ: exp_type
814+
typname: c.table.type_to_str(exp_type)
815+
pos: init_field.expr.pos()
816+
}
817+
init_field.typ = exp_type
818+
} else {
819+
c.check_expected(c.unwrap_generic(got_type), c.unwrap_generic(exp_type)) or {
820+
c.error('cannot assign to field `${field_info.name}`: ${err.msg()}',
821+
init_field.pos)
822+
}
789823
}
790824
}
791825
if exp_type.has_flag(.shared_f) {
@@ -936,8 +970,11 @@ or use an explicit `unsafe{ a[..] }`, if you do not want a copy of the slice.',
936970
if struct_sym.info.concrete_types.len == 0 {
937971
concrete_types := c.infer_struct_generic_types(node.typ, node)
938972
if concrete_types.len > 0 {
939-
generic_names := struct_sym.info.generic_types.map(c.table.sym(it).name)
940-
node.typ = c.table.unwrap_generic_type(node.typ, generic_names, concrete_types)
973+
idx := c.table.find_or_register_generic_inst(node.typ, concrete_types)
974+
if idx > 0 {
975+
node.typ = ast.new_type(idx)
976+
c.table.generic_insts_to_concrete()
977+
}
941978
}
942979
} else if struct_sym.info.generic_types.len == struct_sym.info.concrete_types.len {
943980
parent_type := struct_sym.info.parent_type

vlib/v/parser/struct.v

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ fn (mut p Parser) struct_decl(is_anon bool) ast.StructDecl {
433433
language: language
434434
name: name
435435
cname: util.no_dots(name)
436+
ngname: ast.strip_generic_params(name)
436437
mod: p.mod
437438
info: ast.Struct{
438439
scoped_name: scoped_name
@@ -727,6 +728,7 @@ fn (mut p Parser) interface_decl() ast.InterfaceDecl {
727728
kind: .interface
728729
name: interface_name
729730
cname: util.no_dots(interface_name)
731+
ngname: ast.strip_generic_params(interface_name)
730732
mod: p.mod
731733
info: ast.Interface{
732734
types: []
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
struct Empty {}
2+
3+
struct Node[T] {
4+
value T
5+
next Chain[T]
6+
}
7+
8+
type Chain[T] = Empty | Node[T]
9+
10+
fn get[T](chain Chain[T]) T {
11+
return match chain {
12+
Empty { 0 }
13+
Node[T] { chain.value }
14+
}
15+
}
16+
17+
fn test_main() {
18+
chain := Node{0.2, Empty{}}
19+
assert get(chain) == 0.2
20+
}

0 commit comments

Comments
 (0)