Skip to content

Commit

Permalink
interp: improve the behaviour of interface{} function parameters
Browse files Browse the repository at this point in the history
We finally address a long standing limitation of the interpreter:
the capacity to generate the correct interface wrapper for an
anonymous interface{} function parameter of a binary function.

It allows for example fmt.Printf to invoke the String method
of an object defined within the interpreter, or json.Marshal
to invoke a textMarshaler method if it exists and if there is
no Marshaler method already defined for the passed interpreter
object.

To achieve that, we add a new mapType part of the "Used" symbols
to describe what not empty interfaces are expected and in which
priority order. This information can not be guessed and is found
in the related package documentation, then captured in stdlib/maptypes.go.

Then, at compile time and/or during execution, a lookup on mapTypes
is performed to allow the correct wrapper to be generated.

This change adds a new MapType type to the stdlib package.

Fixes #435.
  • Loading branch information
mvertes committed Jun 14, 2022
1 parent eaeb445 commit 236a0ef
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 19 deletions.
25 changes: 25 additions & 0 deletions _test/issue-435.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package main

import (
"fmt"
"strconv"
)

type Foo int

func (f Foo) String() string {
return "foo-" + strconv.Itoa(int(f))
}

func print1(arg interface{}) {
fmt.Println(arg)
}

func main() {
var arg Foo = 3
var f = print1
f(arg)
}

// Output:
// foo-3
36 changes: 31 additions & 5 deletions interp/interp.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,12 @@ type Interpreter struct {

name string // name of the input source file (or main)

opt // user settable options
cancelChan bool // enables cancellable chan operations
fset *token.FileSet // fileset to locate node in source code
binPkg Exports // binary packages used in interpreter, indexed by path
rdir map[string]bool // for src import cycle detection
opt // user settable options
cancelChan bool // enables cancellable chan operations
fset *token.FileSet // fileset to locate node in source code
binPkg Exports // binary packages used in interpreter, indexed by path
rdir map[string]bool // for src import cycle detection
mapTypes map[reflect.Value][]reflect.Type // special interfaces mapping for wrappers

mutex sync.RWMutex
frame *frame // program data storage during execution
Expand Down Expand Up @@ -333,6 +334,7 @@ func New(options Options) *Interpreter {
universe: initUniverse(),
scopes: map[string]*scope{},
binPkg: Exports{"": map[string]reflect.Value{"_error": reflect.ValueOf((*_error)(nil))}},
mapTypes: map[reflect.Value][]reflect.Type{},
srcPkg: imports{},
pkgNames: map[string]string{},
rdir: map[string]bool{},
Expand Down Expand Up @@ -675,6 +677,14 @@ func (interp *Interpreter) Use(values Exports) error {
importPath := path.Dir(k)
packageName := path.Base(k)

if k == "." && v["MapTypes"].IsValid() {
// Use mapping for special interface wrappers.
for kk, vv := range v["MapTypes"].Interface().(map[reflect.Value][]reflect.Type) {
interp.mapTypes[kk] = vv
}
continue
}

if importPath == "." {
return fmt.Errorf("export path %[1]q is missing a package name; did you mean '%[1]s/%[1]s'?", k)
}
Expand Down Expand Up @@ -726,6 +736,14 @@ func fixStdlib(interp *Interpreter) {
p["Scanf"] = reflect.ValueOf(func(f string, a ...interface{}) (n int, err error) { return fmt.Fscanf(stdin, f, a...) })
p["Scanln"] = reflect.ValueOf(func(a ...interface{}) (n int, err error) { return fmt.Fscanln(stdin, a...) })

// Update mapTypes to virtualized symbols as well.
interp.mapTypes[p["Print"]] = interp.mapTypes[reflect.ValueOf(fmt.Print)]
interp.mapTypes[p["Printf"]] = interp.mapTypes[reflect.ValueOf(fmt.Printf)]
interp.mapTypes[p["Println"]] = interp.mapTypes[reflect.ValueOf(fmt.Println)]
interp.mapTypes[p["Scan"]] = interp.mapTypes[reflect.ValueOf(fmt.Scan)]
interp.mapTypes[p["Scanf"]] = interp.mapTypes[reflect.ValueOf(fmt.Scanf)]
interp.mapTypes[p["Scanln"]] = interp.mapTypes[reflect.ValueOf(fmt.Scanln)]

if p = interp.binPkg["flag"]; p != nil {
c := flag.NewFlagSet(os.Args[0], flag.PanicOnError)
c.SetOutput(stderr)
Expand All @@ -752,6 +770,14 @@ func fixStdlib(interp *Interpreter) {
p["SetOutput"] = reflect.ValueOf(l.SetOutput)
p["SetPrefix"] = reflect.ValueOf(l.SetPrefix)
p["Writer"] = reflect.ValueOf(l.Writer)

// Update mapTypes to virtualized symbols as well.
interp.mapTypes[p["Print"]] = interp.mapTypes[reflect.ValueOf(log.Print)]
interp.mapTypes[p["Printf"]] = interp.mapTypes[reflect.ValueOf(log.Printf)]
interp.mapTypes[p["Println"]] = interp.mapTypes[reflect.ValueOf(log.Println)]
interp.mapTypes[p["Panic"]] = interp.mapTypes[reflect.ValueOf(log.Panic)]
interp.mapTypes[p["Panicf"]] = interp.mapTypes[reflect.ValueOf(log.Panicf)]
interp.mapTypes[p["Panicln"]] = interp.mapTypes[reflect.ValueOf(log.Panicln)]
}

if p = interp.binPkg["os"]; p != nil {
Expand Down
46 changes: 33 additions & 13 deletions interp/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -1479,20 +1479,27 @@ func callBin(n *node) {
}
}

// getMapType returns a reflect type suitable for interface wrapper for functions
// with some special processing in case of interface{} argument, i.e. fmt.Printf.
var getMapType func(*itype) reflect.Type
if lr, ok := n.interp.mapTypes[c0.rval]; ok {
getMapType = func(typ *itype) reflect.Type {
for _, rt := range lr {
if typ.implements(&itype{cat: valueT, rtype: rt}) {
return rt
}
}
return nil
}
}

// Determine if we should use `Call` or `CallSlice` on the function Value.
callFn := func(v reflect.Value, in []reflect.Value) []reflect.Value { return v.Call(in) }
if n.action == aCallSlice {
callFn = func(v reflect.Value, in []reflect.Value) []reflect.Value { return v.CallSlice(in) }
}

for i, c := range child {
var defType reflect.Type
if variadic >= 0 && i+rcvrOffset >= variadic {
defType = funcType.In(variadic)
} else {
defType = funcType.In(rcvrOffset + i)
}

switch {
case isBinCall(c, c.scope):
// Handle nested function calls: pass returned values as arguments
Expand Down Expand Up @@ -1527,6 +1534,19 @@ func callBin(n *node) {
break
}

// defType is the target type for a potential interface wrapper.
var defType reflect.Type
if variadic >= 0 && i+rcvrOffset >= variadic {
defType = funcType.In(variadic)
} else {
defType = funcType.In(rcvrOffset + i)
}
if getMapType != nil {
if rt := getMapType(c.typ); rt != nil {
defType = rt
}
}

switch {
case isFuncSrc(c.typ):
values = append(values, genFunctionWrapper(c))
Expand Down Expand Up @@ -1565,7 +1585,7 @@ func callBin(n *node) {
val := make([]reflect.Value, l+1)
val[0] = value(f)
for i, v := range values {
val[i+1] = v(f)
val[i+1] = getBinValue(getMapType, v, f)
}
f.deferred = append([][]reflect.Value{val}, f.deferred...)
return tnext
Expand All @@ -1575,7 +1595,7 @@ func callBin(n *node) {
n.exec = func(f *frame) bltn {
in := make([]reflect.Value, l)
for i, v := range values {
in[i] = v(f)
in[i] = getBinValue(getMapType, v, f)
}
go callFn(value(f), in)
return tnext
Expand All @@ -1587,7 +1607,7 @@ func callBin(n *node) {
n.exec = func(f *frame) bltn {
in := make([]reflect.Value, l)
for i, v := range values {
in[i] = v(f)
in[i] = getBinValue(getMapType, v, f)
}
res := callFn(value(f), in)
b := res[0].Bool()
Expand Down Expand Up @@ -1619,7 +1639,7 @@ func callBin(n *node) {
n.exec = func(f *frame) bltn {
in := make([]reflect.Value, l)
for i, v := range values {
in[i] = v(f)
in[i] = getBinValue(getMapType, v, f)
}
out := callFn(value(f), in)
for i, v := range rvalues {
Expand All @@ -1636,7 +1656,7 @@ func callBin(n *node) {
n.exec = func(f *frame) bltn {
in := make([]reflect.Value, l)
for i, v := range values {
in[i] = v(f)
in[i] = getBinValue(getMapType, v, f)
}
out := callFn(value(f), in)
for i, v := range out {
Expand All @@ -1652,7 +1672,7 @@ func callBin(n *node) {
n.exec = func(f *frame) bltn {
in := make([]reflect.Value, l)
for i, v := range values {
in[i] = v(f)
in[i] = getBinValue(getMapType, v, f)
}
out := callFn(value(f), in)
for i := 0; i < len(out); i++ {
Expand Down
15 changes: 15 additions & 0 deletions interp/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,21 @@ func genValueOutput(n *node, t reflect.Type) func(*frame) reflect.Value {
return value
}

func getBinValue(getMapType func(*itype) reflect.Type, value func(*frame) reflect.Value, f *frame) reflect.Value {
v := value(f)
if getMapType == nil {
return v
}
val, ok := v.Interface().(valueInterface)
if !ok || val.node == nil {
return v
}
if rt := getMapType(val.node.typ); rt != nil {
return genInterfaceWrapper(val.node, rt)(f)
}
return v
}

func valueInterfaceValue(v reflect.Value) reflect.Value {
for {
vv, ok := v.Interface().(valueInterface)
Expand Down
58 changes: 58 additions & 0 deletions stdlib/maptypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package stdlib

import (
"encoding"
"encoding/json"
"encoding/xml"
"fmt"
"log"
"reflect"
)

func init() {
mt := []reflect.Type{
reflect.TypeOf((*fmt.Formatter)(nil)).Elem(),
reflect.TypeOf((*fmt.Stringer)(nil)).Elem(),
}

MapTypes[reflect.ValueOf(fmt.Errorf)] = mt
MapTypes[reflect.ValueOf(fmt.Fprint)] = mt
MapTypes[reflect.ValueOf(fmt.Fprintf)] = mt
MapTypes[reflect.ValueOf(fmt.Fprintln)] = mt
MapTypes[reflect.ValueOf(fmt.Print)] = mt
MapTypes[reflect.ValueOf(fmt.Printf)] = mt
MapTypes[reflect.ValueOf(fmt.Println)] = mt
MapTypes[reflect.ValueOf(fmt.Sprint)] = mt
MapTypes[reflect.ValueOf(fmt.Sprintf)] = mt
MapTypes[reflect.ValueOf(fmt.Sprintln)] = mt

MapTypes[reflect.ValueOf(log.Fatal)] = mt
MapTypes[reflect.ValueOf(log.Fatalf)] = mt
MapTypes[reflect.ValueOf(log.Fatalln)] = mt
MapTypes[reflect.ValueOf(log.Panic)] = mt
MapTypes[reflect.ValueOf(log.Panicf)] = mt
MapTypes[reflect.ValueOf(log.Panicln)] = mt

mt = []reflect.Type{reflect.TypeOf((*fmt.Scanner)(nil)).Elem()}

MapTypes[reflect.ValueOf(fmt.Scan)] = mt
MapTypes[reflect.ValueOf(fmt.Scanf)] = mt
MapTypes[reflect.ValueOf(fmt.Scanln)] = mt

MapTypes[reflect.ValueOf(json.Marshal)] = []reflect.Type{
reflect.TypeOf((*json.Marshaler)(nil)).Elem(),
reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem(),
}
MapTypes[reflect.ValueOf(json.Unmarshal)] = []reflect.Type{
reflect.TypeOf((*json.Unmarshaler)(nil)).Elem(),
reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem(),
}
MapTypes[reflect.ValueOf(xml.Marshal)] = []reflect.Type{
reflect.TypeOf((*xml.Marshaler)(nil)).Elem(),
reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem(),
}
MapTypes[reflect.ValueOf(xml.Unmarshal)] = []reflect.Type{
reflect.TypeOf((*xml.Unmarshaler)(nil)).Elem(),
reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem(),
}
}
9 changes: 8 additions & 1 deletion stdlib/stdlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ import "reflect"
// Symbols variable stores the map of stdlib symbols per package.
var Symbols = map[string]map[string]reflect.Value{}

// MapTypes variable contains a map of functions which have an interface{} as parameter but
// do something special if the parameter implements a given interface.
var MapTypes = map[reflect.Value][]reflect.Type{}

func init() {
Symbols["github.com/traefik/yaegi/stdlib"] = map[string]reflect.Value{
Symbols["github.com/traefik/yaegi/stdlib/stdlib"] = map[string]reflect.Value{
"Symbols": reflect.ValueOf(Symbols),
}
Symbols["."] = map[string]reflect.Value{
"MapTypes": reflect.ValueOf(MapTypes),
}
}

// Provide access to go standard library (http://golang.org/pkg/)
Expand Down

0 comments on commit 236a0ef

Please sign in to comment.