Skip to content

Commit

Permalink
v: improve comptime var checking with is operator and smartcasting (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
felipensp committed Jan 1, 2024
1 parent 18954af commit e5cf336
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 20 deletions.
7 changes: 7 additions & 0 deletions vlib/v/ast/scope.v
Expand Up @@ -129,6 +129,13 @@ pub fn (mut s Scope) update_ct_var_kind(name string, kind ComptimeVarKind) {
}
}

pub fn (mut s Scope) update_smartcasts(name string, typ Type) {
mut obj := unsafe { s.objects[name] }
if mut obj is Var {
obj.smartcasts = [typ]
}
}

// selector_expr: name.field_name
pub fn (mut s Scope) register_struct_field(name string, field ScopeStructField) {
if f := s.struct_fields[name] {
Expand Down
12 changes: 9 additions & 3 deletions vlib/v/checker/checker.v
Expand Up @@ -3794,7 +3794,7 @@ fn (mut c Checker) concat_expr(mut node ast.ConcatExpr) ast.Type {
}

// smartcast takes the expression with the current type which should be smartcasted to the target type in the given scope
fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.Type, mut scope ast.Scope) {
fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.Type, mut scope ast.Scope, is_comptime bool) {
sym := c.table.sym(cur_type)
to_type := if sym.kind == .interface_ && c.table.sym(to_type_).kind != .interface_ {
to_type_.ref()
Expand Down Expand Up @@ -3852,7 +3852,7 @@ fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.
orig_type = expr.obj.typ
}
is_inherited = expr.obj.is_inherited
ct_type_var = if expr.obj.ct_type_var == .field_var {
ct_type_var = if is_comptime && expr.obj.ct_type_var != .no_comptime {
.smartcast
} else {
.no_comptime
Expand All @@ -3861,9 +3861,15 @@ fn (mut c Checker) smartcast(mut expr ast.Expr, cur_type ast.Type, to_type_ ast.
// smartcast either if the value is immutable or if the mut argument is explicitly given
if (!is_mut || expr.is_mut) && !is_already_casted {
smartcasts << to_type
if var := scope.find_var(expr.name) {
if is_comptime && var.ct_type_var == .smartcast {
scope.update_smartcasts(expr.name, to_type)
return
}
}
scope.register(ast.Var{
name: expr.name
typ: if ct_type_var == .smartcast { to_type } else { cur_type }
typ: cur_type
pos: expr.pos
is_used: true
is_mut: expr.is_mut
Expand Down
2 changes: 1 addition & 1 deletion vlib/v/checker/comptime.v
Expand Up @@ -217,7 +217,7 @@ fn (mut c Checker) comptime_for(mut node ast.ComptimeFor) {
c.unwrap_generic(node.typ)
} else {
node.typ = c.expr(mut node.expr)
node.typ
c.unwrap_generic(node.typ)
}
sym := c.table.final_sym(typ)
if sym.kind == .placeholder || typ.has_flag(.generic) {
Expand Down
2 changes: 1 addition & 1 deletion vlib/v/checker/for.v
Expand Up @@ -285,7 +285,7 @@ fn (mut c Checker) for_stmt(mut node ast.ForStmt) {
if node.cond.right is ast.TypeNode && node.cond.left in [ast.Ident, ast.SelectorExpr] {
if c.table.type_kind(node.cond.left_type) in [.sum_type, .interface_] {
c.smartcast(mut node.cond.left, node.cond.left_type, node.cond.right_type, mut
node.scope)
node.scope, false)
}
}
}
Expand Down
20 changes: 13 additions & 7 deletions vlib/v/checker/if.v
Expand Up @@ -142,7 +142,6 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type {
}
if left is ast.SelectorExpr {
comptime_field_name = left.expr.str()
c.comptime.type_map[comptime_field_name] = got_type
is_comptime_type_is_expr = true
if comptime_field_name == c.comptime.comptime_for_field_var {
left_type := c.unwrap_generic(c.comptime.comptime_for_field_type)
Expand Down Expand Up @@ -177,10 +176,13 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type {
left_type := c.unwrap_generic(left.typ)
skip_state = c.check_compatible_types(left_type, right as ast.TypeNode)
} else if left is ast.Ident {
is_comptime_type_is_expr = true
mut checked_type := ast.void_type
is_comptime_type_is_expr = true
if var := left.scope.find_var(left.name) {
checked_type = c.unwrap_generic(var.typ)
if var.smartcasts.len > 0 {
checked_type = c.unwrap_generic(var.smartcasts.last())
}
}
skip_state = c.check_compatible_types(checked_type, right as ast.TypeNode)
}
Expand Down Expand Up @@ -344,7 +346,7 @@ fn (mut c Checker) if_expr(mut node ast.IfExpr) ast.Type {
if comptime_field_name.len > 0 {
if comptime_field_name == c.comptime.comptime_for_method_var {
c.comptime.type_map[comptime_field_name] = c.comptime.comptime_for_method_ret_type
} else {
} else if comptime_field_name == c.comptime.comptime_for_field_var {
c.comptime.type_map[comptime_field_name] = c.comptime.comptime_for_field_type
}
}
Expand Down Expand Up @@ -516,11 +518,14 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) {
c.smartcast_if_conds(mut node.right, mut scope)
} else if node.left is ast.Ident && node.op == .ne && node.right is ast.None {
c.smartcast(mut node.left, node.left_type, node.left_type.clear_flag(.option), mut
scope)
scope, false)
} else if node.op == .key_is {
if node.left_type == ast.Type(0) {
if node.left is ast.Ident && c.comptime.is_comptime_var(node.left) {
node.left_type = c.comptime.get_comptime_var_type(node.left)
} else {
node.left_type = c.expr(mut node.left)
}
mut is_comptime := false
right_expr := node.right
right_type := match right_expr {
ast.TypeNode {
Expand All @@ -531,6 +536,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) {
}
ast.Ident {
if right_expr.name == c.comptime.comptime_for_variant_var {
is_comptime = true
c.comptime.type_map['${c.comptime.comptime_for_variant_var}.typ']
} else {
c.error('invalid type `${right_expr}`', right_expr.pos)
Expand All @@ -544,7 +550,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) {
}
if right_type != ast.Type(0) {
right_sym := c.table.sym(right_type)
mut expr_type := c.expr(mut node.left)
mut expr_type := c.unwrap_generic(node.left_type)
left_sym := c.table.sym(expr_type)
if left_sym.kind == .aggregate {
expr_type = (left_sym.info as ast.Aggregate).sum_type
Expand Down Expand Up @@ -581,7 +587,7 @@ fn (mut c Checker) smartcast_if_conds(mut node ast.Expr, mut scope ast.Scope) {
}
if left_sym.kind in [.interface_, .sum_type] {
c.smartcast(mut node.left, node.left_type, right_type, mut
scope)
scope, is_comptime)
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion vlib/v/checker/infix.v
Expand Up @@ -675,7 +675,8 @@ fn (mut c Checker) infix_expr(mut node ast.InfixExpr) ast.Type {
if typ != ast.none_type_idx {
c.error('`${op}` can only be used to test for none in sql', node.pos)
}
} else if left_sym.kind !in [.interface_, .sum_type] {
} else if left_sym.kind !in [.interface_, .sum_type]
&& !c.comptime.is_comptime_var(node.left) {
c.error('`${op}` can only be used with interfaces and sum types',
node.pos) // can be used in sql too, but keep err simple
} else if mut left_sym.info is ast.SumType {
Expand Down
3 changes: 2 additions & 1 deletion vlib/v/checker/match.v
Expand Up @@ -476,7 +476,8 @@ fn (mut c Checker) match_exprs(mut node ast.MatchExpr, cond_type_sym ast.TypeSym
expr_type = expr_types[0].typ
}

c.smartcast(mut node.cond, node.cond_type, expr_type, mut branch.scope)
c.smartcast(mut node.cond, node.cond_type, expr_type, mut branch.scope,
false)
}
}
}
Expand Down
5 changes: 1 addition & 4 deletions vlib/v/comptime/comptimeinfo.v
Expand Up @@ -54,7 +54,7 @@ pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type {
node.obj.typ
}
.smartcast {
ct.type_map['${ct.comptime_for_variant_var}.typ'] or { ast.void_type }
ct.type_map['${ct.comptime_for_variant_var}.typ'] or { node.obj.typ }
}
.key_var, .value_var {
// key and value variables from normal for stmt
Expand All @@ -77,9 +77,6 @@ pub fn (mut ct ComptimeInfo) get_comptime_var_type(node ast.Expr) ast.Type {
ct.comptime_for_variant_var {
return ct.type_map['${ct.comptime_for_variant_var}.typ']
}
ct.comptime_for_enum_var {
return ct.type_map['${ct.comptime_for_enum_var}.typ']
}
else {
// field_var.typ from $for field
return ct.comptime_for_field_type
Expand Down
2 changes: 1 addition & 1 deletion vlib/v/gen/c/cgen.v
Expand Up @@ -4546,7 +4546,7 @@ fn (mut g Gen) ident(node ast.Ident) {
}
}
if node.obj.ct_type_var == .smartcast {
cur_variant_sym := g.table.sym(g.comptime.type_map['${g.comptime.comptime_for_variant_var}.typ'])
cur_variant_sym := g.table.sym(g.unwrap_generic(g.comptime.get_comptime_var_type(node)))
g.write('${dot}_${cur_variant_sym.cname}')
} else if !is_option_unwrap
&& obj_sym.kind in [.sum_type, .interface_] {
Expand Down
5 changes: 5 additions & 0 deletions vlib/v/gen/c/comptime.v
Expand Up @@ -896,7 +896,11 @@ fn (mut g Gen) comptime_for(node ast.ComptimeFor) {
if sym.info.vals.len > 0 {
g.writeln('\tEnumData ${node.val_var} = {0};')
}
g.push_new_comptime_info()
for val in sym.info.vals {
g.comptime.comptime_for_enum_var = node.val_var
g.comptime.type_map['${node.val_var}.typ'] = node.typ

g.writeln('/* enum vals ${i} */ {')
g.writeln('\t${node.val_var}.name = _SLIT("${val}");')
g.write('\t${node.val_var}.value = ')
Expand All @@ -918,6 +922,7 @@ fn (mut g Gen) comptime_for(node ast.ComptimeFor) {
g.writeln('}')
i++
}
g.pop_comptime_info()
}
}
} else if node.kind == .attributes {
Expand Down
6 changes: 5 additions & 1 deletion vlib/v/gen/c/infix.v
Expand Up @@ -673,7 +673,11 @@ fn (mut g Gen) infix_expr_in_optimization(left ast.Expr, right ast.ArrayInit) {

// infix_expr_is_op generates code for `is` and `!is`
fn (mut g Gen) infix_expr_is_op(node ast.InfixExpr) {
mut left_sym := g.table.sym(node.left_type)
mut left_sym := if g.comptime.is_comptime_var(node.left) {
g.table.sym(g.unwrap_generic(g.comptime.get_comptime_var_type(node.left)))
} else {
g.table.sym(node.left_type)
}
is_aggregate := left_sym.kind == .aggregate
if is_aggregate {
parent_left_type := (left_sym.info as ast.Aggregate).sum_type
Expand Down
19 changes: 19 additions & 0 deletions vlib/v/tests/comptime_var_is_check_test.v
@@ -0,0 +1,19 @@
type TestSum = int | string

fn gen[T, R](sum T) R {
$if T is $sumtype {
$for v in sum.variants {
if sum is v {
$if sum is R {
return sum
}
}
}
}
return R{}
}

fn test_main() {
assert dump(gen[TestSum, string](TestSum('foo'))) == 'foo'
assert dump(gen[TestSum, int](TestSum(123))) == 123
}

0 comments on commit e5cf336

Please sign in to comment.