Skip to content

Commit

Permalink
fix: composite literal type check
Browse files Browse the repository at this point in the history
In the case of a pointer or alias composite literal expression, `compositeGenerator` changes the type to remove the pointer or alias. This causes a nested composite literal to have the wrong type.

Instead of changing the node type, the removal of the pointer or alias is moved to the runtime, allowing the node type to remain unchanged. This fixes potential issues in the type checking.
  • Loading branch information
nrwiersma committed Aug 14, 2020
1 parent 913680d commit da9e6a0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 27 deletions.
22 changes: 22 additions & 0 deletions _test/struct55.go
@@ -0,0 +1,22 @@
package main

import (
"log"
"os"
)

type Logger struct {
m []*log.Logger
}

func (l *Logger) Infof(format string, args ...interface{}) {
l.m[0].Printf(format, args...)
}

func main() {
l := &Logger{m: []*log.Logger{log.New(os.Stdout, "", log.Lmsgprefix)}}
l.Infof("test %s", "test")
}

// Output:
// test test
25 changes: 6 additions & 19 deletions interp/cfg.go
Expand Up @@ -952,28 +952,16 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {
case compositeLitExpr:
wireChild(n)

underlying := func(t *itype) *itype {
for {
switch t.cat {
case ptrT, aliasT:
t = t.val
continue
default:
return t
}
}
}

child := n.child
if n.nleft > 0 {
child = child[1:]
}

switch n.typ.cat {
case arrayT:
err = check.arrayLitExpr(child, underlying(n.typ.val), n.typ.size)
err = check.arrayLitExpr(child, n.typ.val, n.typ.size)
case mapT:
err = check.mapLitExpr(child, n.typ.key, underlying(n.typ.val))
err = check.mapLitExpr(child, n.typ.key, n.typ.val)
case structT:
err = check.structLitExpr(child, n.typ)
case valueT:
Expand All @@ -993,7 +981,7 @@ func (interp *Interpreter) cfg(root *node, pkgID string) ([]*node, error) {

n.findex = sc.add(n.typ)
// TODO: Check that composite literal expr matches corresponding type
n.gen = compositeGenerator(n)
n.gen = compositeGenerator(n, n.typ)

case fallthroughtStmt:
if n.anc.kind != caseBody {
Expand Down Expand Up @@ -2348,11 +2336,10 @@ func gotoLabel(s *symbol) {
}
}

func compositeGenerator(n *node) (gen bltnGenerator) {
switch n.typ.cat {
func compositeGenerator(n *node, typ *itype) (gen bltnGenerator) {
switch typ.cat {
case aliasT, ptrT:
n.typ = n.typ.val
gen = compositeGenerator(n)
gen = compositeGenerator(n, n.typ.val)
case arrayT:
gen = arrayLit
case mapT:
Expand Down
26 changes: 18 additions & 8 deletions interp/run.go
Expand Up @@ -962,9 +962,11 @@ func call(n *node) {
}

// Init variadic argument vector
varIndex := variadic
if variadic >= 0 {
if method {
vararg = nf.data[numRet+variadic+1]
varIndex++
} else {
vararg = nf.data[numRet+variadic]
}
Expand Down Expand Up @@ -1000,7 +1002,7 @@ func call(n *node) {
} else {
d.Set(src)
}
case variadic >= 0 && i >= variadic:
case variadic >= 0 && i >= varIndex:
if v(f).Type() == vararg.Type() {
vararg.Set(v(f))
} else {
Expand Down Expand Up @@ -2057,6 +2059,10 @@ func destType(n *node) *itype {
func doCompositeLit(n *node, hasType bool) {
value := valueGenerator(n, n.findex)
next := getExec(n.tnext)
typ := n.typ
if typ.cat == ptrT || typ.cat == aliasT {
typ = typ.val
}
child := n.child
if hasType {
child = n.child[1:]
Expand All @@ -2065,7 +2071,7 @@ func doCompositeLit(n *node, hasType bool) {

values := make([]func(*frame) reflect.Value, len(child))
for i, c := range child {
convertLiteralValue(c, n.typ.field[i].typ.TypeOf())
convertLiteralValue(c, typ.field[i].typ.TypeOf())
if c.typ.cat == funcT {
values[i] = genFunctionWrapper(c)
} else {
Expand All @@ -2076,7 +2082,7 @@ func doCompositeLit(n *node, hasType bool) {
i := n.findex
l := n.level
n.exec = func(f *frame) bltn {
a := reflect.New(n.typ.TypeOf()).Elem()
a := reflect.New(typ.TypeOf()).Elem()
for i, v := range values {
a.Field(i).Set(v(f))
}
Expand All @@ -2099,25 +2105,29 @@ func compositeLitNotype(n *node) { doCompositeLit(n, false) }
func doCompositeSparse(n *node, hasType bool) {
value := valueGenerator(n, n.findex)
next := getExec(n.tnext)
typ := n.typ
if typ.cat == ptrT || typ.cat == aliasT {
typ = typ.val
}
child := n.child
if hasType {
child = n.child[1:]
}
destInterface := destType(n).cat == interfaceT

values := make(map[int]func(*frame) reflect.Value)
a, _ := n.typ.zero()
a, _ := typ.zero()
for _, c := range child {
c1 := c.child[1]
field := n.typ.fieldIndex(c.child[0].ident)
convertLiteralValue(c1, n.typ.field[field].typ.TypeOf())
field := typ.fieldIndex(c.child[0].ident)
convertLiteralValue(c1, typ.field[field].typ.TypeOf())
switch {
case c1.typ.cat == funcT:
values[field] = genFunctionWrapper(c1)
case isArray(c1.typ) && c1.typ.val != nil && c1.typ.val.cat == interfaceT:
values[field] = genValueInterfaceArray(c1)
case isRecursiveType(n.typ.field[field].typ, n.typ.field[field].typ.rtype):
values[field] = genValueRecursiveInterface(c1, n.typ.field[field].typ.rtype)
case isRecursiveType(typ.field[field].typ, typ.field[field].typ.rtype):
values[field] = genValueRecursiveInterface(c1, typ.field[field].typ.rtype)
default:
values[field] = genValue(c1)
}
Expand Down

0 comments on commit da9e6a0

Please sign in to comment.