From cb81fe41ab899ad591741f4b5b1e603e46d9a8b8 Mon Sep 17 00:00:00 2001 From: Marc Vertes Date: Mon, 8 Nov 2021 20:46:12 +0100 Subject: [PATCH] interp: fix type processing to support multiple recursive fields Fixes #1304 --- _test/issue-1304.go | 16 ++++ _test/struct46.go | 1 + _test/struct61.go | 22 ++++++ _test/struct62.go | 11 +++ internal/unsafe2/unsafe.go | 4 +- internal/unsafe2/unsafe_test.go | 2 +- interp/cfg.go | 11 ++- interp/run.go | 10 +-- interp/type.go | 130 ++++++++++++++++++++++++++++---- interp/typecheck.go | 2 +- 10 files changed, 182 insertions(+), 27 deletions(-) create mode 100644 _test/issue-1304.go create mode 100644 _test/struct61.go create mode 100644 _test/struct62.go diff --git a/_test/issue-1304.go b/_test/issue-1304.go new file mode 100644 index 000000000..a44bd8239 --- /dev/null +++ b/_test/issue-1304.go @@ -0,0 +1,16 @@ +package main + +type Node struct { + Name string + Alias *Node + Child []*Node +} + +func main() { + n := &Node{Name: "parent"} + n.Child = append(n.Child, &Node{Name: "child"}) + println(n.Name, n.Child[0].Name) +} + +// Output: +// parent child diff --git a/_test/struct46.go b/_test/struct46.go index a041d0bd4..62b5db4c8 100644 --- a/_test/struct46.go +++ b/_test/struct46.go @@ -8,6 +8,7 @@ type A struct { } type D struct { + F *A E *A } diff --git a/_test/struct61.go b/_test/struct61.go new file mode 100644 index 000000000..9d6071f07 --- /dev/null +++ b/_test/struct61.go @@ -0,0 +1,22 @@ +package main + +import "fmt" + +type A struct { + B string + D +} + +type D struct { + *A + E *A +} + +func main() { + a := &A{B: "b"} + a.D = D{E: a} + fmt.Println(a.D.E.B) +} + +// Output: +// b diff --git a/_test/struct62.go b/_test/struct62.go new file mode 100644 index 000000000..176b650ac --- /dev/null +++ b/_test/struct62.go @@ -0,0 +1,11 @@ +package main + +func main() { + type A struct{ *A } + v := &A{} + v.A = v + println("v.A.A = v", v.A.A == v) +} + +// Output: +// v.A.A = v true diff --git a/internal/unsafe2/unsafe.go b/internal/unsafe2/unsafe.go index 28da8b62d..47f96ad18 100644 --- a/internal/unsafe2/unsafe.go +++ b/internal/unsafe2/unsafe.go @@ -40,10 +40,10 @@ type structType struct { fields []structField } -// SwapFieldType swaps the type of the struct field with the given type. +// SetFieldType sets the type of the struct field at the given index, to the given type. // // The struct type must have been created at runtime. This is very unsafe. -func SwapFieldType(s reflect.Type, idx int, t reflect.Type) { +func SetFieldType(s reflect.Type, idx int, t reflect.Type) { if s.Kind() != reflect.Struct || idx >= s.NumField() { return } diff --git a/internal/unsafe2/unsafe_test.go b/internal/unsafe2/unsafe_test.go index 64e303b33..10fae470f 100644 --- a/internal/unsafe2/unsafe_test.go +++ b/internal/unsafe2/unsafe_test.go @@ -25,7 +25,7 @@ func TestSwapFieldType(t *testing.T) { typ := reflect.StructOf(f) ntyp := reflect.PtrTo(typ) - unsafe2.SwapFieldType(typ, 1, ntyp) + unsafe2.SetFieldType(typ, 1, ntyp) if typ.Field(1).Type != ntyp { t.Fatalf("unexpected field type: want %s; got %s", ntyp, typ.Field(1).Type) diff --git a/interp/cfg.go b/interp/cfg.go index 18ac0d34c..5398c3860 100644 --- a/interp/cfg.go +++ b/interp/cfg.go @@ -412,11 +412,11 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string sc.loop = n case importSpec: - // already all done in gta + // Already all done in GTA. return false case typeSpec: - // processing already done in GTA pass for global types, only parses inlined types + // Processing already done in GTA pass for global types, only parses inlined types. if sc.def == nil { return false } @@ -426,8 +426,11 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string return false } if typ.incomplete { - err = n.cfgErrorf("invalid type declaration") - return false + // Type may still be incomplete in case of a local recursive struct declaration. + if typ, err = typ.finalize(); err != nil { + err = n.cfgErrorf("invalid type declaration") + return false + } } switch n.child[1].kind { diff --git a/interp/run.go b/interp/run.go index e99843964..92eeb9916 100644 --- a/interp/run.go +++ b/interp/run.go @@ -2981,13 +2981,13 @@ func _append(n *node) { l := len(args) values := make([]func(*frame) reflect.Value, l) for i, arg := range args { - switch { - case isEmptyInterface(n.typ.val): + switch elem := n.typ.elem(); { + case isEmptyInterface(elem): values[i] = genValue(arg) - case isInterfaceSrc(n.typ.val): + case isInterfaceSrc(elem): values[i] = genValueInterface(arg) - case isInterfaceBin(n.typ.val): - values[i] = genInterfaceWrapper(arg, n.typ.val.rtype) + case isInterfaceBin(elem): + values[i] = genInterfaceWrapper(arg, elem.rtype) case arg.typ.untyped: values[i] = genValueAs(arg, n.child[1].typ.TypeOf().Elem()) default: diff --git a/interp/type.go b/interp/type.go index 0cc0593b5..7fd4a5f80 100644 --- a/interp/type.go +++ b/interp/type.go @@ -6,6 +6,7 @@ import ( "path/filepath" "reflect" "strconv" + "strings" "sync" "github.com/traefik/yaegi/internal/unsafe2" @@ -1528,22 +1529,33 @@ func (t *itype) getMethod(name string) *node { // LookupMethod returns a pointer to method definition associated to type t // and the list of indices to access the right struct field, in case of an embedded method. func (t *itype) lookupMethod(name string) (*node, []int) { + return t.lookupMethod2(name, nil) +} + +func (t *itype) lookupMethod2(name string, seen map[*itype]bool) (*node, []int) { + if seen == nil { + seen = map[*itype]bool{} + } + if seen[t] { + return nil, nil + } + seen[t] = true if t.cat == ptrT { - return t.val.lookupMethod(name) + return t.val.lookupMethod2(name, seen) } var index []int m := t.getMethod(name) if m == nil { for i, f := range t.field { if f.embed { - if n, index2 := f.typ.lookupMethod(name); n != nil { + if n, index2 := f.typ.lookupMethod2(name, seen); n != nil { index = append([]int{i}, index2...) return n, index } } } if t.cat == aliasT || isInterfaceSrc(t) && t.val != nil { - return t.val.lookupMethod(name) + return t.val.lookupMethod2(name, seen) } } return m, index @@ -1562,12 +1574,23 @@ func (t *itype) methodDepth(name string) int { // LookupBinMethod returns a method and a path to access a field in a struct object (the receiver). func (t *itype) lookupBinMethod(name string) (m reflect.Method, index []int, isPtr, ok bool) { + return t.lookupBinMethod2(name, nil) +} + +func (t *itype) lookupBinMethod2(name string, seen map[*itype]bool) (m reflect.Method, index []int, isPtr, ok bool) { + if seen == nil { + seen = map[*itype]bool{} + } + if seen[t] { + return + } + seen[t] = true if t.cat == ptrT { - return t.val.lookupBinMethod(name) + return t.val.lookupBinMethod2(name, seen) } for i, f := range t.field { if f.embed { - if m2, index2, isPtr2, ok2 := f.typ.lookupBinMethod(name); ok2 { + if m2, index2, isPtr2, ok2 := f.typ.lookupBinMethod2(name, seen); ok2 { index = append([]int{i}, index2...) return m2, index, isPtr2, ok2 } @@ -1630,8 +1653,19 @@ type fieldRebuild struct { } type refTypeContext struct { - defined map[string]*itype - refs map[string][]fieldRebuild + defined map[string]*itype + + // refs keeps track of all the places (in the same type recursion) where the + // type name (as key) is used as a field of another (or possibly the same) struct + // type. Each of these fields will then live as an unsafe2.dummy type until the + // whole recursion is fully resolved, and the type is fixed. + refs map[string][]fieldRebuild + + // When we detect for the first time that we are in a recursive type (thanks to + // defined), we keep track of the first occurrence of the type where the recursion + // started, so we can restart the last step that fixes all the types from the same + // "top-level" point. + rect *itype rebuilding bool } @@ -1640,12 +1674,57 @@ func (c *refTypeContext) Clone() *refTypeContext { return &refTypeContext{defined: c.defined, refs: c.refs, rebuilding: c.rebuilding} } +func (c *refTypeContext) isComplete() bool { + for _, t := range c.defined { + if t.rtype == nil { + return false + } + } + return true +} + +func (t *itype) fixDummy(typ reflect.Type) reflect.Type { + if typ == unsafe2.DummyType { + return t.rtype + } + switch typ.Kind() { + case reflect.Array: + return reflect.ArrayOf(typ.Len(), t.fixDummy(typ.Elem())) + case reflect.Chan: + return reflect.ChanOf(typ.ChanDir(), t.fixDummy(typ.Elem())) + case reflect.Func: + in := make([]reflect.Type, typ.NumIn()) + for i := range in { + in[i] = t.fixDummy(typ.In(i)) + } + out := make([]reflect.Type, typ.NumOut()) + for i := range out { + out[i] = t.fixDummy(typ.Out(i)) + } + return reflect.FuncOf(in, out, typ.IsVariadic()) + case reflect.Map: + return reflect.MapOf(t.fixDummy(typ.Key()), t.fixDummy(typ.Elem())) + case reflect.Ptr: + return reflect.PtrTo(t.fixDummy(typ.Elem())) + case reflect.Slice: + return reflect.SliceOf(t.fixDummy(typ.Elem())) + case reflect.Struct: + fields := make([]reflect.StructField, typ.NumField()) + for i := range fields { + fields[i] = typ.Field(i) + fields[i].Type = t.fixDummy(fields[i].Type) + } + return reflect.StructOf(fields) + } + return typ +} + // RefType returns a reflect.Type representation from an interpreter type. // In simple cases, reflect types are directly mapped from the interpreter // counterpart. // For recursive named struct or interfaces, as reflect does not permit to -// create a recursive named struct, a nil type is set temporarily for each recursive -// field. When done, the nil type fields are updated with the original reflect type +// create a recursive named struct, a dummy type is set temporarily for each recursive +// field. When done, the dummy type fields are updated with the original reflect type // pointer using unsafe. We thus obtain a usable recursive type definition, except // for string representation, as created reflect types are still unnamed. func (t *itype) refType(ctx *refTypeContext) reflect.Type { @@ -1667,14 +1746,23 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { return t.rtype } if dt := ctx.defined[name]; dt != nil { + // We get here when we are a struct field, and our type name has already been + // seen at least once in one of our englobing structs. i.e. there's at least one + // level of type recursion. if dt.rtype != nil { t.rtype = dt.rtype return dt.rtype } - // To indicate that a rebuild is needed on the nearest struct - // field, create an entry with a nil type. + // The recursion has not been fully resolved yet. + // To indicate that a rebuild is needed on the englobing struct, + // return a dummy field type and create an entry with an empty fieldRebuild. flds := ctx.refs[name] + ctx.rect = dt + + // We know we are used as a field by someone, but we don't know by who + // at this point in the code, so we just mark it as an empty fieldRebuild for now. + // We'll complete the fieldRebuild in the caller. ctx.refs[name] = append(flds, fieldRebuild{}) return unsafe2.DummyType } @@ -1717,9 +1805,8 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { } var fields []reflect.StructField for i, f := range t.field { - fctx := ctx.Clone() field := reflect.StructField{ - Name: exportName(f.name), Type: f.typ.refType(fctx), + Name: exportName(f.name), Type: f.typ.refType(ctx), Tag: reflect.StructTag(f.tag), Anonymous: f.embed, } fields = append(fields, field) @@ -1733,12 +1820,27 @@ func (t *itype) refType(ctx *refTypeContext) reflect.Type { } } t.rtype = reflect.StructOf(fields) + if ctx.isComplete() { + for _, s := range ctx.defined { + for i := 0; i < s.rtype.NumField(); i++ { + f := s.rtype.Field(i) + if strings.HasSuffix(f.Type.String(), "unsafe2.dummy") { + unsafe2.SetFieldType(s.rtype, i, ctx.rect.fixDummy(s.rtype.Field(i).Type)) + } + } + } + } // The rtype has now been built, we can go back and rebuild // all the recursive types that relied on this type. + // However, as we are keyed by type name, if two or more (recursive) fields at + // the same depth level are of the same type, they "mask" each other, and only one + // of them is in ctx.refs, which means this pass below does not fully do the job. + // Which is why we have the pass above that is done one last time, for all fields, + // one the recursion has been fully resolved. for _, f := range ctx.refs[name] { ftyp := f.typ.field[f.idx].typ.refType(&refTypeContext{defined: ctx.defined, rebuilding: true}) - unsafe2.SwapFieldType(f.typ.rtype, f.idx, ftyp) + unsafe2.SetFieldType(f.typ.rtype, f.idx, ftyp) } default: if z, _ := t.zero(); z.IsValid() { diff --git a/interp/typecheck.go b/interp/typecheck.go index 0620a6851..3d2705caa 100644 --- a/interp/typecheck.go +++ b/interp/typecheck.go @@ -53,7 +53,7 @@ func (check typecheck) assignment(n *node, typ *itype, context string) error { return nil } - if !n.typ.assignableTo(typ) { + if !n.typ.assignableTo(typ) && typ.str != "*unsafe2.dummy" { if context == "" { return n.cfgErrorf("cannot use type %s as type %s", n.typ.id(), typ.id()) }