Skip to content

Commit

Permalink
interp: resolve type for untyped shift expressions
Browse files Browse the repository at this point in the history
A non-constant shift expression can be untyped, requiring to apply a
type from inherited context. This change insures that such context is
propagated during CFG pre-order walk, to be used if necessary.
    
Fixes #927.
  • Loading branch information
mvertes committed Nov 2, 2020
1 parent c817823 commit 9880738
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
60 changes: 57 additions & 3 deletions interp/cfg.go
Expand Up @@ -66,6 +66,41 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
return false
}
switch n.kind {
case binaryExpr, unaryExpr, parenExpr:
if isBoolAction(n) {
break
}
// Gather assigned type if set, to give context for type propagation at post-order.
switch n.anc.kind {
case assignStmt, defineStmt:
a := n.anc
i := childPos(n) - a.nright
if len(a.child) > a.nright+a.nleft {
i--
}
dest := a.child[i]
if dest.typ != nil && !isInterface(dest.typ) {
// Interface type are not propagated, and will be resolved at post-order.
n.typ = dest.typ
}
case binaryExpr, unaryExpr, parenExpr:
n.typ = n.anc.typ
}

case defineStmt:
// Determine type of variables initialized at declaration, so it can be propagated.
if n.nleft+n.nright == len(n.child) {
// No type was specified on the left hand side, it will resolved at post-order.
break
}
n.typ, err = nodeType(interp, sc, n.child[n.nleft])
if err != nil {
break
}
for i := 0; i < n.nleft; i++ {
n.child[i].typ = n.typ
}

case blockStmt:
if n.anc != nil && n.anc.kind == rangeStmt {
// For range block: ensure that array or map type is propagated to iterators
Expand Down Expand Up @@ -447,7 +482,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
var atyp *itype
if n.nleft+n.nright < len(n.child) {
if atyp, err = nodeType(interp, sc, n.child[n.nleft]); err != nil {
return
break
}
}

Expand Down Expand Up @@ -644,7 +679,12 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
}

switch n.action {
case aRem, aShl, aShr:
case aRem:
n.typ = c0.typ
case aShl, aShr:
if c0.typ.untyped {
break
}
n.typ = c0.typ
case aEqual, aNotEqual:
n.typ = sc.getType("bool")
Expand Down Expand Up @@ -860,7 +900,12 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
n.gen = nop
n.findex = -1
n.typ = c0.typ
n.rval = c1.rval.Convert(c0.typ.rtype)
if c, ok := c1.rval.Interface().(constant.Value); ok {
i, _ := constant.Int64Val(constant.ToInt(c))
n.rval = reflect.ValueOf(i).Convert(c0.typ.rtype)
} else {
n.rval = c1.rval.Convert(c0.typ.rtype)
}
default:
n.gen = convert
n.typ = c0.typ
Expand Down Expand Up @@ -2474,3 +2519,12 @@ func isArithmeticAction(n *node) bool {
return false
}
}

func isBoolAction(n *node) bool {
switch n.action {
case aEqual, aGreater, aGreaterEqual, aLand, aLor, aLower, aLowerEqual, aNot, aNotEqual:
return true
default:
return false
}
}
10 changes: 10 additions & 0 deletions interp/interp_eval_test.go
Expand Up @@ -71,6 +71,16 @@ func TestEvalArithmetic(t *testing.T) {
})
}

func TestEvalShift(t *testing.T) {
i := interp.New(interp.Options{})
runTests(t, i, []testCase{
{src: "a, b, m := uint32(1), uint32(2), uint32(0); m = a + (1 << b)", res: "5"},
{src: "c := uint(1); d := uint(+(-(1 << c)))", res: "18446744073709551614"},
{src: "e, f := uint32(0), uint32(0); f = 1 << -(e * 2)", res: "1"},
{pre: func() { eval(t, i, "const k uint = 1 << 17") }, src: "int(k)", res: "131072"},
})
}

func TestEvalStar(t *testing.T) {
i := interp.New(interp.Options{})
runTests(t, i, []testCase{
Expand Down
5 changes: 5 additions & 0 deletions interp/run.go
Expand Up @@ -1766,6 +1766,11 @@ func neg(n *node) {
dest(f).SetInt(-value(f).Int())
return next
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n.exec = func(f *frame) bltn {
dest(f).SetUint(-value(f).Uint())
return next
}
case reflect.Float32, reflect.Float64:
n.exec = func(f *frame) bltn {
dest(f).SetFloat(-value(f).Float())
Expand Down
27 changes: 20 additions & 7 deletions interp/typecheck.go
Expand Up @@ -217,6 +217,7 @@ var binaryOpPredicates = opPredicates{
// binaryExpr type checks a binary expression.
func (check typecheck) binaryExpr(n *node) error {
c0, c1 := n.child[0], n.child[1]

a := n.action
if isAssignAction(a) {
a--
Expand All @@ -226,6 +227,21 @@ func (check typecheck) binaryExpr(n *node) error {
return check.shift(n)
}

switch n.action {
case aRem:
if zeroConst(c1) {
return n.cfgErrorf("invalid operation: division by zero")
}
case aQuo:
if zeroConst(c1) {
return n.cfgErrorf("invalid operation: division by zero")
}
if c0.rval.IsValid() && c1.rval.IsValid() {
// Avoid constant conversions below to ensure correct constant integer quotient.
return nil
}
}

_ = check.convertUntyped(c0, c1.typ)
_ = check.convertUntyped(c1, c0.typ)

Expand All @@ -241,16 +257,13 @@ func (check typecheck) binaryExpr(n *node) error {
if err := check.op(binaryOpPredicates, a, n, c0, t0); err != nil {
return err
}

switch n.action {
case aQuo, aRem:
if (c0.typ.untyped || isInt(t0)) && c1.typ.untyped && constant.Sign(c1.rval.Interface().(constant.Value)) == 0 {
return n.cfgErrorf("invalid operation: division by zero")
}
}
return nil
}

func zeroConst(n *node) bool {
return n.typ.untyped && constant.Sign(n.rval.Interface().(constant.Value)) == 0
}

func (check typecheck) index(n *node, max int) error {
if err := check.convertUntyped(n, &itype{cat: intT, name: "int"}); err != nil {
return err
Expand Down

0 comments on commit 9880738

Please sign in to comment.