Skip to content

Commit

Permalink
interp: fix type processing to support multiple recursive fields
Browse files Browse the repository at this point in the history
Fixes #1304
  • Loading branch information
mvertes committed Nov 8, 2021
1 parent a876bb3 commit cb81fe4
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 27 deletions.
16 changes: 16 additions & 0 deletions _test/issue-1304.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package main

type Node struct {
Name string
Alias *Node
Child []*Node
}

func main() {
n := &Node{Name: "parent"}
n.Child = append(n.Child, &Node{Name: "child"})
println(n.Name, n.Child[0].Name)
}

// Output:
// parent child
1 change: 1 addition & 0 deletions _test/struct46.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type A struct {
}

type D struct {
F *A
E *A
}

Expand Down
22 changes: 22 additions & 0 deletions _test/struct61.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

import "fmt"

type A struct {
B string
D
}

type D struct {
*A
E *A
}

func main() {
a := &A{B: "b"}
a.D = D{E: a}
fmt.Println(a.D.E.B)
}

// Output:
// b
11 changes: 11 additions & 0 deletions _test/struct62.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package main

func main() {
type A struct{ *A }
v := &A{}
v.A = v
println("v.A.A = v", v.A.A == v)
}

// Output:
// v.A.A = v true
4 changes: 2 additions & 2 deletions internal/unsafe2/unsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ type structType struct {
fields []structField
}

// SwapFieldType swaps the type of the struct field with the given type.
// SetFieldType sets the type of the struct field at the given index, to the given type.
//
// The struct type must have been created at runtime. This is very unsafe.
func SwapFieldType(s reflect.Type, idx int, t reflect.Type) {
func SetFieldType(s reflect.Type, idx int, t reflect.Type) {
if s.Kind() != reflect.Struct || idx >= s.NumField() {
return
}
Expand Down
2 changes: 1 addition & 1 deletion internal/unsafe2/unsafe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestSwapFieldType(t *testing.T) {
typ := reflect.StructOf(f)
ntyp := reflect.PtrTo(typ)

unsafe2.SwapFieldType(typ, 1, ntyp)
unsafe2.SetFieldType(typ, 1, ntyp)

if typ.Field(1).Type != ntyp {
t.Fatalf("unexpected field type: want %s; got %s", ntyp, typ.Field(1).Type)
Expand Down
11 changes: 7 additions & 4 deletions interp/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,11 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
sc.loop = n

case importSpec:
// already all done in gta
// Already all done in GTA.
return false

case typeSpec:
// processing already done in GTA pass for global types, only parses inlined types
// Processing already done in GTA pass for global types, only parses inlined types.
if sc.def == nil {
return false
}
Expand All @@ -426,8 +426,11 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
return false
}
if typ.incomplete {
err = n.cfgErrorf("invalid type declaration")
return false
// Type may still be incomplete in case of a local recursive struct declaration.
if typ, err = typ.finalize(); err != nil {
err = n.cfgErrorf("invalid type declaration")
return false
}
}

switch n.child[1].kind {
Expand Down
10 changes: 5 additions & 5 deletions interp/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -2981,13 +2981,13 @@ func _append(n *node) {
l := len(args)
values := make([]func(*frame) reflect.Value, l)
for i, arg := range args {
switch {
case isEmptyInterface(n.typ.val):
switch elem := n.typ.elem(); {
case isEmptyInterface(elem):
values[i] = genValue(arg)
case isInterfaceSrc(n.typ.val):
case isInterfaceSrc(elem):
values[i] = genValueInterface(arg)
case isInterfaceBin(n.typ.val):
values[i] = genInterfaceWrapper(arg, n.typ.val.rtype)
case isInterfaceBin(elem):
values[i] = genInterfaceWrapper(arg, elem.rtype)
case arg.typ.untyped:
values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem())
default:
Expand Down
130 changes: 116 additions & 14 deletions interp/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"

"github.com/traefik/yaegi/internal/unsafe2"
Expand Down Expand Up @@ -1528,22 +1529,33 @@ func (t *itype) getMethod(name string) *node {
// LookupMethod returns a pointer to method definition associated to type t
// and the list of indices to access the right struct field, in case of an embedded method.
func (t *itype) lookupMethod(name string) (*node, []int) {
return t.lookupMethod2(name, nil)
}

func (t *itype) lookupMethod2(name string, seen map[*itype]bool) (*node, []int) {
if seen == nil {
seen = map[*itype]bool{}
}
if seen[t] {
return nil, nil
}
seen[t] = true
if t.cat == ptrT {
return t.val.lookupMethod(name)
return t.val.lookupMethod2(name, seen)
}
var index []int
m := t.getMethod(name)
if m == nil {
for i, f := range t.field {
if f.embed {
if n, index2 := f.typ.lookupMethod(name); n != nil {
if n, index2 := f.typ.lookupMethod2(name, seen); n != nil {
index = append([]int{i}, index2...)
return n, index
}
}
}
if t.cat == aliasT || isInterfaceSrc(t) && t.val != nil {
return t.val.lookupMethod(name)
return t.val.lookupMethod2(name, seen)
}
}
return m, index
Expand All @@ -1562,12 +1574,23 @@ func (t *itype) methodDepth(name string) int {

// LookupBinMethod returns a method and a path to access a field in a struct object (the receiver).
func (t *itype) lookupBinMethod(name string) (m reflect.Method, index []int, isPtr, ok bool) {
return t.lookupBinMethod2(name, nil)
}

func (t *itype) lookupBinMethod2(name string, seen map[*itype]bool) (m reflect.Method, index []int, isPtr, ok bool) {
if seen == nil {
seen = map[*itype]bool{}
}
if seen[t] {
return
}
seen[t] = true
if t.cat == ptrT {
return t.val.lookupBinMethod(name)
return t.val.lookupBinMethod2(name, seen)
}
for i, f := range t.field {
if f.embed {
if m2, index2, isPtr2, ok2 := f.typ.lookupBinMethod(name); ok2 {
if m2, index2, isPtr2, ok2 := f.typ.lookupBinMethod2(name, seen); ok2 {
index = append([]int{i}, index2...)
return m2, index, isPtr2, ok2
}
Expand Down Expand Up @@ -1630,8 +1653,19 @@ type fieldRebuild struct {
}

type refTypeContext struct {
defined map[string]*itype
refs map[string][]fieldRebuild
defined map[string]*itype

// refs keeps track of all the places (in the same type recursion) where the
// type name (as key) is used as a field of another (or possibly the same) struct
// type. Each of these fields will then live as an unsafe2.dummy type until the
// whole recursion is fully resolved, and the type is fixed.
refs map[string][]fieldRebuild

// When we detect for the first time that we are in a recursive type (thanks to
// defined), we keep track of the first occurrence of the type where the recursion
// started, so we can restart the last step that fixes all the types from the same
// "top-level" point.
rect *itype
rebuilding bool
}

Expand All @@ -1640,12 +1674,57 @@ func (c *refTypeContext) Clone() *refTypeContext {
return &refTypeContext{defined: c.defined, refs: c.refs, rebuilding: c.rebuilding}
}

func (c *refTypeContext) isComplete() bool {
for _, t := range c.defined {
if t.rtype == nil {
return false
}
}
return true
}

func (t *itype) fixDummy(typ reflect.Type) reflect.Type {
if typ == unsafe2.DummyType {
return t.rtype
}
switch typ.Kind() {
case reflect.Array:
return reflect.ArrayOf(typ.Len(), t.fixDummy(typ.Elem()))
case reflect.Chan:
return reflect.ChanOf(typ.ChanDir(), t.fixDummy(typ.Elem()))
case reflect.Func:
in := make([]reflect.Type, typ.NumIn())
for i := range in {
in[i] = t.fixDummy(typ.In(i))
}
out := make([]reflect.Type, typ.NumOut())
for i := range out {
out[i] = t.fixDummy(typ.Out(i))
}
return reflect.FuncOf(in, out, typ.IsVariadic())
case reflect.Map:
return reflect.MapOf(t.fixDummy(typ.Key()), t.fixDummy(typ.Elem()))
case reflect.Ptr:
return reflect.PtrTo(t.fixDummy(typ.Elem()))
case reflect.Slice:
return reflect.SliceOf(t.fixDummy(typ.Elem()))
case reflect.Struct:
fields := make([]reflect.StructField, typ.NumField())
for i := range fields {
fields[i] = typ.Field(i)
fields[i].Type = t.fixDummy(fields[i].Type)
}
return reflect.StructOf(fields)
}
return typ
}

// RefType returns a reflect.Type representation from an interpreter type.
// In simple cases, reflect types are directly mapped from the interpreter
// counterpart.
// For recursive named struct or interfaces, as reflect does not permit to
// create a recursive named struct, a nil type is set temporarily for each recursive
// field. When done, the nil type fields are updated with the original reflect type
// create a recursive named struct, a dummy type is set temporarily for each recursive
// field. When done, the dummy type fields are updated with the original reflect type
// pointer using unsafe. We thus obtain a usable recursive type definition, except
// for string representation, as created reflect types are still unnamed.
func (t *itype) refType(ctx *refTypeContext) reflect.Type {
Expand All @@ -1667,14 +1746,23 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type {
return t.rtype
}
if dt := ctx.defined[name]; dt != nil {
// We get here when we are a struct field, and our type name has already been
// seen at least once in one of our englobing structs. i.e. there's at least one
// level of type recursion.
if dt.rtype != nil {
t.rtype = dt.rtype
return dt.rtype
}

// To indicate that a rebuild is needed on the nearest struct
// field, create an entry with a nil type.
// The recursion has not been fully resolved yet.
// To indicate that a rebuild is needed on the englobing struct,
// return a dummy field type and create an entry with an empty fieldRebuild.
flds := ctx.refs[name]
ctx.rect = dt

// We know we are used as a field by someone, but we don't know by who
// at this point in the code, so we just mark it as an empty fieldRebuild for now.
// We'll complete the fieldRebuild in the caller.
ctx.refs[name] = append(flds, fieldRebuild{})
return unsafe2.DummyType
}
Expand Down Expand Up @@ -1717,9 +1805,8 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type {
}
var fields []reflect.StructField
for i, f := range t.field {
fctx := ctx.Clone()
field := reflect.StructField{
Name: exportName(f.name), Type: f.typ.refType(fctx),
Name: exportName(f.name), Type: f.typ.refType(ctx),
Tag: reflect.StructTag(f.tag), Anonymous: f.embed,
}
fields = append(fields, field)
Expand All @@ -1733,12 +1820,27 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type {
}
}
t.rtype = reflect.StructOf(fields)
if ctx.isComplete() {
for _, s := range ctx.defined {
for i := 0; i < s.rtype.NumField(); i++ {
f := s.rtype.Field(i)
if strings.HasSuffix(f.Type.String(), "unsafe2.dummy") {
unsafe2.SetFieldType(s.rtype, i, ctx.rect.fixDummy(s.rtype.Field(i).Type))
}
}
}
}

// The rtype has now been built, we can go back and rebuild
// all the recursive types that relied on this type.
// However, as we are keyed by type name, if two or more (recursive) fields at
// the same depth level are of the same type, they "mask" each other, and only one
// of them is in ctx.refs, which means this pass below does not fully do the job.
// Which is why we have the pass above that is done one last time, for all fields,
// one the recursion has been fully resolved.
for _, f := range ctx.refs[name] {
ftyp := f.typ.field[f.idx].typ.refType(&refTypeContext{defined: ctx.defined, rebuilding: true})
unsafe2.SwapFieldType(f.typ.rtype, f.idx, ftyp)
unsafe2.SetFieldType(f.typ.rtype, f.idx, ftyp)
}
default:
if z, _ := t.zero(); z.IsValid() {
Expand Down
2 changes: 1 addition & 1 deletion interp/typecheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (check typecheck) assignment(n *node, typ *itype, context string) error {
return nil
}

if !n.typ.assignableTo(typ) {
if !n.typ.assignableTo(typ) && typ.str != "*unsafe2.dummy" {
if context == "" {
return n.cfgErrorf("cannot use type %s as type %s", n.typ.id(), typ.id())
}
Expand Down

0 comments on commit cb81fe4

Please sign in to comment.