Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interp: improve handling of generic types #1489

Merged
merged 14 commits into from Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions _test/gen11.go
@@ -0,0 +1,33 @@
package main

import (
"encoding/json"
"fmt"
"net/netip"
)

type Slice[T any] struct {
x []T
}

type IPPrefixSlice struct {
x Slice[netip.Prefix]
}

func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }

// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}

func main() {
t := IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}

// Output:
// {{[]}}
// null <nil>
31 changes: 31 additions & 0 deletions _test/gen12.go
@@ -0,0 +1,31 @@
package main

import (
"fmt"
)

func MapOf[K comparable, V any](m map[K]V) Map[K, V] {
return Map[K, V]{m}
}

type Map[K comparable, V any] struct {
ж map[K]V
}

func (v MapView) Int() Map[string, int] { return MapOf(v.ж.Int) }

type VMap struct {
Int map[string]int
}

type MapView struct {
ж *VMap
}

func main() {
mv := MapView{&VMap{}}
fmt.Println(mv.ж)
}

// Output:
// &{map[]}
18 changes: 18 additions & 0 deletions _test/gen13.go
@@ -0,0 +1,18 @@
package main

type Map[K comparable, V any] struct {
ж map[K]V
}

func (m Map[K, V]) Has(k K) bool {
_, ok := m.ж[k]
return ok
}

func main() {
m := Map[string, float64]{}
println(m.Has("test"))
}

// Output:
// false
13 changes: 9 additions & 4 deletions _test/issue-1460.go
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"net/netip"
"reflect"
)

Expand All @@ -17,6 +18,10 @@ func unmarshalJSON[T any](b []byte, x *[]T) error {
return json.Unmarshal(b, x)
}

func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}

type StructView[T any] interface {
Valid() bool
AsStruct() T
Expand All @@ -31,10 +36,6 @@ type ViewCloner[T any, V StructView[T]] interface {
Clone() T
}

func SliceOfViews[T ViewCloner[T, V], V StructView[T]](x []T) SliceView[T, V] {
return SliceView[T, V]{x}
}

func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }

func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalJSON(b, &v.ж) }
Expand All @@ -51,6 +52,10 @@ func SliceOf[T any](x []T) Slice[T] {
return Slice[T]{x}
}

type IPPrefixSlice struct {
ж Slice[netip.Prefix]
}

type viewStruct struct {
Int int
Strings Slice[string]
Expand Down
23 changes: 23 additions & 0 deletions _test/issue-1488.go
@@ -0,0 +1,23 @@
package main

import "fmt"

type vector interface {
[]int | [3]int
}

func sum[V vector](v V) (out int) {
for i := 0; i < len(v); i++ {
out += v[i]
}
return
}

func main() {
va := [3]int{1, 2, 3}
vs := []int{1, 2, 3}
fmt.Println(sum[[3]int](va), sum[[]int](vs))
}

// Output:
// 6 6
14 changes: 14 additions & 0 deletions _test/p6.go
@@ -0,0 +1,14 @@
package main

import (
"fmt"

"github.com/traefik/yaegi/_test/p6"
)

func main() {
t := p6.IPPrefixSlice{}
fmt.Println(t)
b, e := t.MarshalJSON()
fmt.Println(string(b), e)
}
21 changes: 21 additions & 0 deletions _test/p6/p6.go
@@ -0,0 +1,21 @@
package p6

import (
"encoding/json"
"net/netip"
)

type Slice[T any] struct {
x []T
}

type IPPrefixSlice struct {
x Slice[netip.Prefix]
}

func (v Slice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.x) }

// MarshalJSON implements json.Marshaler.
func (v IPPrefixSlice) MarshalJSON() ([]byte, error) {
return v.x.MarshalJSON()
}
136 changes: 109 additions & 27 deletions interp/cfg.go
Expand Up @@ -322,8 +322,60 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
}
}
if n.typ == nil {
err = n.cfgErrorf("undefined type")
return false
// A nil type indicates either an error or a generic type.
// A child indexExpr or indexListExpr is used for type parameters,
// it indicates an instanciated generic.
if n.child[0].kind != indexExpr && n.child[0].kind != indexListExpr {
err = n.cfgErrorf("undefined type")
return false
}
t0, err1 := nodeType(interp, sc, n.child[0].child[0])
if err1 != nil {
return false
}
if t0.cat != genericT {
err = n.cfgErrorf("undefined type")
return false
}
// We have a composite literal of generic type, instantiate it.
lt := []*itype{}
for _, n1 := range n.child[0].child[1:] {
t1, err1 := nodeType(interp, sc, n1)
if err1 != nil {
return false
}
lt = append(lt, t1)
}
var g *node
g, _, err = genAST(sc, t0.node.anc, lt)
if err != nil {
return false
}
n.child[0] = g.lastChild()
n.typ, err = nodeType(interp, sc, n.child[0])
if err != nil {
return false
}
// Generate methods if any.
for _, nod := range t0.method {
gm, _, err2 := genAST(nod.scope, nod, lt)
if err2 != nil {
err = err2
return false
}
gm.typ, err = nodeType(interp, nod.scope, gm.child[2])
if err != nil {
return false
}
if _, err = interp.cfg(gm, sc, sc.pkgID, sc.pkgName); err != nil {
return false
}
if err = genRun(gm); err != nil {
return false
}
n.typ.addMethod(gm)
}
n.nleft = 1 // Indictate the type of composite literal.
}
}

Expand Down Expand Up @@ -439,6 +491,19 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if typ, err = nodeType(interp, sc, recvTypeNode); err != nil {
return false
}
if typ.cat == nilT {
// This may happen when instantiating generic methods.
s2, _, ok := sc.lookup(typ.id())
if !ok {
err = n.cfgErrorf("type not found: %s", typ.id())
break
}
typ = s2.typ
if typ.cat == nilT {
err = n.cfgErrorf("nil type: %s", typ.id())
break
}
}
recvTypeNode.typ = typ
n.child[2].typ.recv = typ
n.typ.recv = typ
Expand Down Expand Up @@ -871,16 +936,18 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
n.typ = t
return
}
g, err := genAST(sc, t.node.anc, []*node{c1})
g, found, err := genAST(sc, t.node.anc, []*itype{c1.typ})
if err != nil {
return
}
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
if !found {
if _, err = interp.cfg(g, t.node.anc.scope, importPath, pkgName); err != nil {
return
}
// Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
return
}
}
// Replace generic func node by instantiated one.
n.anc.child[childPos(n)] = g
Expand Down Expand Up @@ -1030,17 +1097,23 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
case c0.kind == indexListExpr:
// Instantiate a generic function then call it.
fun := c0.child[0].sym.node
g, err := genAST(sc, fun, c0.child[1:])
if err != nil {
return
lt := []*itype{}
for _, c := range c0.child[1:] {
lt = append(lt, c.typ)
}
_, err = interp.cfg(g, nil, importPath, pkgName)
g, found, err := genAST(sc, fun, lt)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
if !found {
_, err = interp.cfg(g, fun.scope, importPath, pkgName)
if err != nil {
return
}
err = genRun(g.child[3]) // Generate closures for function body.
if err != nil {
return
}
}
n.child[0] = g
c0 = n.child[0]
Expand Down Expand Up @@ -1212,23 +1285,26 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
if isGeneric(c0.typ) {
fun := c0.typ.node.anc
var g *node
var types []*node
var types []*itype
var found bool

// Infer type parameter from function call arguments.
if types, err = inferTypesFromCall(sc, fun, n.child[1:]); err != nil {
break
}
// Generate an instantiated AST from the generic function one.
if g, err = genAST(sc, fun, types); err != nil {
break
}
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, nil, importPath, pkgName); err != nil {
if g, found, err = genAST(sc, fun, types); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
if !found {
// Compile the generated function AST, so it becomes part of the scope.
if _, err = interp.cfg(g, fun.scope, importPath, pkgName); err != nil {
break
}
// AST compilation part 2: Generate closures for function body.
if err = genRun(g.child[3]); err != nil {
break
}
}
n.child[0] = g
c0 = n.child[0]
Expand Down Expand Up @@ -1487,6 +1563,10 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string

sym, level, found := sc.lookup(n.ident)
if !found {
if n.typ != nil {
// Node is a generic instance with an already populated type.
break
}
// retry with the filename, in case ident is a package name.
sym, level, found = sc.lookup(filepath.Join(n.ident, baseName))
if !found {
Expand Down Expand Up @@ -1916,7 +1996,7 @@ func (interp *Interpreter) cfg(root *node, sc *scope, importPath, pkgName string
err = n.cfgErrorf("undefined selector: %s", n.child[1].ident)
}
}
if err == nil && n.findex != -1 {
if err == nil && n.findex != -1 && n.typ.cat != genericT {
n.findex = sc.add(n.typ)
}

Expand Down Expand Up @@ -2375,11 +2455,13 @@ func (n *node) cfgErrorf(format string, a ...interface{}) *cfgError {

func genRun(nod *node) error {
var err error
seen := map[*node]bool{}

nod.Walk(func(n *node) bool {
if err != nil {
if err != nil || seen[n] {
return false
}
seen[n] = true
switch n.kind {
case funcType:
if len(n.anc.child) == 4 {
Expand Down