Skip to content

Commit

Permalink
feat: add type assertion expression type checking
Browse files Browse the repository at this point in the history
This adds type checking to TypeAssertExpr. In order to allow for this, method types now have a receiver type in both reflect and native cases.
  • Loading branch information
nrwiersma committed Aug 20, 2020
1 parent 3faa47c commit 3640f2f
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 33 deletions.
14 changes: 14 additions & 0 deletions _test/method35.go
@@ -0,0 +1,14 @@
package main

import "strconv"

func main() {
var err error
_, err = strconv.Atoi("erwer")
if _, ok := err.(*strconv.NumError); ok {
println("here")
}
}

// Output:
// here
56 changes: 35 additions & 21 deletions interp/cfg.go
Expand Up @@ -334,6 +334,8 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
return false
}
recvTypeNode.typ = typ
n.child[2].typ.recv = typ
n.typ.recv = typ
index := sc.add(typ)
if len(fr.child) > 1 {
sc.sym[fr.child[0].ident] = &symbol{index: index, kind: varSym, typ: typ}
Expand Down Expand Up @@ -1334,11 +1336,15 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
// Search for field must then be performed on type T only (not *T)
switch method, ok := n.typ.rtype.MethodByName(n.child[1].ident); {
case ok:
hasRecvType := n.typ.rtype.Kind() != reflect.Interface
n.val = method.Index
n.gen = getIndexBinMethod
n.action = aGetMethod
n.recv = &receiver{node: n.child[0]}
n.typ = &itype{cat: valueT, rtype: method.Type, isBinMethod: true}
if hasRecvType {
n.typ.recv = n.typ
}
case n.typ.rtype.Kind() == reflect.Ptr:
if field, ok := n.typ.rtype.Elem().FieldByName(n.child[1].ident); ok {
n.typ = &itype{cat: valueT, rtype: field.Type}
Expand All @@ -1358,7 +1364,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
if m2, ok2 := pt.MethodByName(n.child[1].ident); ok2 {
n.val = m2.Index
n.gen = getIndexBinPtrMethod
n.typ = &itype{cat: valueT, rtype: m2.Type}
n.typ = &itype{cat: valueT, rtype: m2.Type, recv: &itype{cat: valueT, rtype: pt}}
n.recv = &receiver{node: n.child[0]}
n.action = aGetMethod
} else {
Expand All @@ -1372,14 +1378,14 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
// Handle pointer on object defined in runtime
if method, ok := n.typ.val.rtype.MethodByName(n.child[1].ident); ok {
n.val = method.Index
n.typ = &itype{cat: valueT, rtype: method.Type}
n.typ = &itype{cat: valueT, rtype: method.Type, recv: n.typ}
n.recv = &receiver{node: n.child[0]}
n.gen = getIndexBinMethod
n.action = aGetMethod
} else if method, ok := reflect.PtrTo(n.typ.val.rtype).MethodByName(n.child[1].ident); ok {
n.val = method.Index
n.gen = getIndexBinMethod
n.typ = &itype{cat: valueT, rtype: method.Type}
n.typ = &itype{cat: valueT, rtype: method.Type, recv: &itype{cat: valueT, rtype: reflect.PtrTo(n.typ.val.rtype)}}
n.recv = &receiver{node: n.child[0]}
n.action = aGetMethod
} else if field, ok := n.typ.val.rtype.FieldByName(n.child[1].ident); ok {
Expand Down Expand Up @@ -1445,7 +1451,7 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
}
n.recv = &receiver{node: n.child[0], index: lind}
n.val = append([]int{m.Index}, lind...)
n.typ = &itype{cat: valueT, rtype: m.Type}
n.typ = &itype{cat: valueT, rtype: m.Type, recv: n.child[0].typ}
} else if ti := n.typ.lookupField(n.child[1].ident); len(ti) > 0 {
// Handle struct field
n.val = ti
Expand Down Expand Up @@ -1658,25 +1664,33 @@ func (interp *Interpreter) cfg(root *node, importPath string) ([]*node, error) {
n.child[0].tnext = sbn.start

case typeAssertExpr:
if len(n.child) > 1 {
wireChild(n)
c1 := n.child[1]
if c1.typ == nil {
if c1.typ, err = nodeType(interp, sc, c1); err != nil {
return
}
if len(n.child) == 1 {
// The "o.(type)" is handled by typeSwitch.
n.gen = nop
break
}

wireChild(n)
c0, c1 := n.child[0], n.child[1]
if c1.typ == nil {
if c1.typ, err = nodeType(interp, sc, c1); err != nil {
return
}
if n.anc.action != aAssignX {
if n.child[0].typ.cat == valueT && isFunc(c1.typ) {
// Avoid special wrapping of interfaces and func types.
n.typ = &itype{cat: valueT, rtype: c1.typ.TypeOf()}
} else {
n.typ = c1.typ
}
n.findex = sc.add(n.typ)
}

err = check.typeAssertionExpr(c0, c1.typ)
if err != nil {
break
}

if n.anc.action != aAssignX {
if c0.typ.cat == valueT && isFunc(c1.typ) {
// Avoid special wrapping of interfaces and func types.
n.typ = &itype{cat: valueT, rtype: c1.typ.TypeOf()}
} else {
n.typ = c1.typ
}
} else {
n.gen = nop
n.findex = sc.add(n.typ)
}

case sliceExpr:
Expand Down
6 changes: 3 additions & 3 deletions interp/run.go
Expand Up @@ -152,7 +152,7 @@ func typeAssertStatus(n *node) {
value1(f).SetBool(ok)
return next
}
case c0.typ.cat == valueT:
case c0.typ.cat == valueT || c0.typ.cat == errorT:
n.exec = func(f *frame) bltn {
v := value(f)
ok := v.IsValid() && canAssertTypes(v.Elem().Type(), rtype)
Expand Down Expand Up @@ -205,7 +205,7 @@ func typeAssert(n *node) {
value0(f).Set(v)
return next
}
case c0.typ.cat == valueT:
case c0.typ.cat == valueT || c0.typ.cat == errorT:
n.exec = func(f *frame) bltn {
v := value(f).Elem()
typ := value0(f).Type()
Expand Down Expand Up @@ -272,7 +272,7 @@ func typeAssert2(n *node) {
}
return next
}
case n.child[0].typ.cat == valueT:
case n.child[0].typ.cat == valueT || n.child[0].typ.cat == errorT:
n.exec = func(f *frame) bltn {
v := value(f).Elem()
ok := v.IsValid() && canAssertTypes(v.Type(), rtype)
Expand Down
55 changes: 51 additions & 4 deletions interp/type.go
Expand Up @@ -108,6 +108,7 @@ type itype struct {
field []structField // Array of struct fields if structT or interfaceT
key *itype // Type of key element if MapT or nil
val *itype // Type of value element if chanT,chanSendT, chanRecvT, mapT, ptrT, aliasT, arrayT or variadicT
recv *itype // Receiver type for funcT or nil
arg []*itype // Argument types if funcT or nil
ret []*itype // Return types if funcT or nil
method []*node // Associated methods or nil
Expand Down Expand Up @@ -559,7 +560,7 @@ func nodeType(interp *Interpreter, sc *scope, n *node) (*itype, error) {
if m, _ := lt.lookupMethod(name); m != nil {
t, err = nodeType(interp, sc, m.child[2])
} else if bm, _, _, ok := lt.lookupBinMethod(name); ok {
t = &itype{cat: valueT, rtype: bm.Type, isBinMethod: true, scope: sc}
t = &itype{cat: valueT, rtype: bm.Type, recv: lt, isBinMethod: true, scope: sc}
} else if ti := lt.lookupField(name); len(ti) > 0 {
t = lt.fieldSeq(ti)
} else if bs, _, ok := lt.lookupBinField(name); ok {
Expand Down Expand Up @@ -761,11 +762,16 @@ func (t *itype) numIn() int {
case funcT:
return len(t.arg)
case valueT:
if t.rtype.Kind() == reflect.Func {
return t.rtype.NumIn()
if t.rtype.Kind() != reflect.Func {
return 0
}
in := t.rtype.NumIn()
if t.recv != nil {
in--
}
return in
}
return 1
return 0
}

func (t *itype) in(i int) *itype {
Expand All @@ -774,6 +780,9 @@ func (t *itype) in(i int) *itype {
return t.arg[i]
case valueT:
if t.rtype.Kind() == reflect.Func {
if t.recv != nil {
i++
}
if t.rtype.IsVariadic() && i == t.rtype.NumIn()-1 {
return &itype{cat: variadicT, val: &itype{cat: valueT, rtype: t.rtype.In(i).Elem()}}
}
Expand Down Expand Up @@ -995,6 +1004,14 @@ func (t *itype) methods() methodSet {
res[m.Name] = m.Type.String()
}
case ptrT:
if typ.val.cat == valueT {
// Ptr receiver methods need to be found with the ptr type.
typ.TypeOf() // Ensure the rtype exists.
for i := typ.rtype.NumMethod() - 1; i >= 0; i-- {
m := typ.rtype.Method(i)
res[m.Name] = m.Type.String()
}
}
for k, v := range getMethods(typ.val) {
res[k] = v
}
Expand Down Expand Up @@ -1244,6 +1261,36 @@ func (t *itype) lookupBinMethod(name string) (m reflect.Method, index []int, isP
return m, index, isPtr, ok
}

func lookupFieldOrMethod(t *itype, name string) *itype {
switch {
case t.cat == valueT || t.cat == ptrT && t.val.cat == valueT:
m, _, isPtr, ok := t.lookupBinMethod(name)
if !ok {
return nil
}
var recv *itype
if t.rtype.Kind() != reflect.Interface {
recv = t
if isPtr && t.cat != ptrT && t.rtype.Kind() != reflect.Ptr {
recv = &itype{cat: ptrT, val: t}
}
}
return &itype{cat: valueT, rtype: m.Type, recv: recv}
case t.cat == interfaceT:
seq := t.lookupField(name)
if seq == nil {
return nil
}
return t.fieldSeq(seq)
default:
n, _ := t.lookupMethod(name)
if n == nil {
return nil
}
return n.typ
}
}

func exportName(s string) string {
if canExport(s) {
return s
Expand Down
58 changes: 53 additions & 5 deletions interp/typecheck.go
Expand Up @@ -527,6 +527,59 @@ func (check typecheck) sliceExpr(n *node) error {
return nil
}

// typeAssertionExpr type checks a type assert expression.
func (check typecheck) typeAssertionExpr(n *node, typ *itype) error {
// TODO(nick): This type check is not complete and should be revisited once
// https://github.com/golang/go/issues/39717 lands. It is currently impractical to
// type check Named types as they cannot be asserted.

if n.typ.TypeOf().Kind() != reflect.Interface {
return n.cfgErrorf("invalid type assertion: non-interface type %s on left", n.typ.id())
}
ims := n.typ.methods()
if len(ims) == 0 {
// Empty interface must be a dynamic check.
return nil
}

if isInterface(typ) {
// Asserting to an interface is a dynamic check as we must look to the
// underlying struct.
return nil
}

for name := range ims {
im := lookupFieldOrMethod(n.typ, name)
tm := lookupFieldOrMethod(typ, name)
if im == nil {
// This should not be possible.
continue
}
if tm == nil {
return n.cfgErrorf("impossible type assertion: %s does not implement %s (missing %v method)", typ.id(), n.typ.id(), name)
}
if tm.recv != nil && tm.recv.TypeOf().Kind() == reflect.Ptr && typ.TypeOf().Kind() != reflect.Ptr {
return n.cfgErrorf("impossible type assertion: %s does not implement %s as %q method has a pointer receiver", typ.id(), n.typ.id(), name)
}

err := n.cfgErrorf("impossible type assertion: %s does not implement %s", typ.id(), n.typ.id())
if im.numIn() != tm.numIn() || im.numOut() != tm.numOut() {
return err
}
for i := 0; i < im.numIn(); i++ {
if !im.in(i).equals(tm.in(i)) {
return err
}
}
for i := 0; i < im.numOut(); i++ {
if !im.out(i).equals(tm.out(i)) {
return err
}
}
}
return nil
}

// conversion type checks the conversion of n to typ.
func (check typecheck) conversion(n *node, typ *itype) error {
var c constant.Value
Expand Down Expand Up @@ -601,11 +654,6 @@ func (check typecheck) arguments(n *node, child []*node, fun *node, ellipsis boo
}

var cnt int
if fun.kind == selectorExpr && fun.typ.cat == valueT && fun.recv != nil && !isInterface(fun.recv.node.typ) {
// If this is a bin call, and we have a receiver and the receiver is
// not an interface, then the first input is the receiver, so skip it.
cnt++
}
for i, arg := range child {
ellip := i == l-1 && ellipsis
if err := check.argument(arg, fun.typ, cnt, ellip); err != nil {
Expand Down

0 comments on commit 3640f2f

Please sign in to comment.