Skip to content

Commit

Permalink
feat: configure stdin, stdout and stderr per interpreter
Browse files Browse the repository at this point in the history
The goal is to provide greater control of input, output and error
streams of the interpreter. It is now possible to specify those
as options when creating a new interpreter. The provided values
are propagated to relevant stdlib symbols (i.e fmt.Print, etc).
Care is taken to not update the global variables os.Stdout, os.Stdin
and os.Stderr, as to not interfere with the host process.

The REPL function is now simplified. The deprecated version is removed.

The tests are updated to take advantage of the simplified access
to the interpreter output and errors.

Fixes #752.
  • Loading branch information
mvertes committed Aug 31, 2020
1 parent f4cc059 commit 341c69d
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 80 deletions.
3 changes: 2 additions & 1 deletion _test/eval0.go
Expand Up @@ -2,13 +2,14 @@ package main

import (
"log"
"os"

"github.com/containous/yaegi/interp"
)

func main() {
log.SetFlags(log.Lshortfile)
i := interp.New(interp.Options{})
i := interp.New(interp.Options{Stdout: os.Stdout})
if _, err := i.Eval(`func f() (int, int) { return 1, 2 }`); err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 0 additions & 3 deletions _test/inception.go
Expand Up @@ -20,6 +20,3 @@ func main() {
log.Fatal(err)
}
}

// Output:
// 42
14 changes: 7 additions & 7 deletions cmd/yaegi/run.go
Expand Up @@ -56,14 +56,14 @@ func run(arg []string) error {
}

if cmd != "" {
i.REPL(strings.NewReader(cmd), os.Stderr)
_, err = i.Eval(cmd)
}

if len(args) == 0 {
if interactive || cmd == "" {
i.REPL(os.Stdin, os.Stdout)
_, err = i.REPL()
}
return nil
return err
}

// Skip first os arg to set command line as expected by interpreted main
Expand All @@ -85,9 +85,9 @@ func run(arg []string) error {
}

if interactive {
i.REPL(os.Stdin, os.Stdout)
_, err = i.REPL()
}
return nil
return err
}

func isPackageName(path string) bool {
Expand Down Expand Up @@ -116,7 +116,7 @@ func runFile(i *interp.Interpreter, path string) error {
if s := string(b); strings.HasPrefix(s, "#!") {
// Allow executable go scripts, Have the same behavior as in interactive mode.
s = strings.Replace(s, "#!", "//", 1)
i.REPL(strings.NewReader(s), os.Stdout)
_, err = i.Eval(s)
} else {
// Files not starting with "#!" are supposed to be pure Go, directly Evaled.
_, err := i.EvalPath(path)
Expand All @@ -127,5 +127,5 @@ func runFile(i *interp.Interpreter, path string) error {
}
}
}
return nil
return err
}
6 changes: 4 additions & 2 deletions cmd/yaegi/yaegi_test.go
Expand Up @@ -2,11 +2,13 @@ package main

import (
"bytes"
"context"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -108,8 +110,8 @@ func TestYaegiCmdCancel(t *testing.T) {
continue
}

if outBuf.String() != "context canceled\n" {
t.Errorf("unexpected output: %q", &outBuf)
if strings.TrimSuffix(errBuf.String(), "\n") != context.Canceled.Error() {
t.Errorf("unexpected error: %q", &errBuf)
}
}
}
30 changes: 3 additions & 27 deletions example/pkg/pkg_test.go
Expand Up @@ -3,7 +3,6 @@ package pkg
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -104,39 +103,16 @@ func TestPackages(t *testing.T) {
t.Fatal(err)
}

// Init go interpreter
i := interp.New(interp.Options{GoPath: goPath})
var stdout, stderr bytes.Buffer
i := interp.New(interp.Options{GoPath: goPath, Stdout: &stdout, Stderr: &stderr})
i.Use(stdlib.Symbols) // Use binary standard library

var msg string
if test.evalFile != "" {
// TODO(mpl): this is brittle if we do concurrent tests and stuff, do better later.
stdout := os.Stdout
defer func() { os.Stdout = stdout }()
pr, pw, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdout = pw

if _, err := i.EvalPath(test.evalFile); err != nil {
fatalStderrf(t, "%v", err)
}

var buf bytes.Buffer
errC := make(chan error)
go func() {
_, err := io.Copy(&buf, pr)
errC <- err
}()

if err := pw.Close(); err != nil {
fatalStderrf(t, "%v", err)
}
if err := <-errC; err != nil {
fatalStderrf(t, "%v", err)
}
msg = buf.String()
msg = stdout.String()
} else {
// Load pkg from sources
topImport := "github.com/foo/pkg"
Expand Down
117 changes: 100 additions & 17 deletions interp/interp.go
Expand Up @@ -4,12 +4,14 @@ import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"go/build"
"go/scanner"
"go/token"
"io"
"io/ioutil"
"log"
"os"
"os/signal"
"reflect"
Expand Down Expand Up @@ -115,6 +117,9 @@ type opt struct {
noRun bool // compile, but do not run
fastChan bool // disable cancellable chan operations
context build.Context // build context: GOPATH, build constraints
stdin io.Reader // standard input
stdout io.Writer // standard output
stderr io.Writer // standard error
}

// Interpreter contains global resources and state.
Expand Down Expand Up @@ -207,10 +212,16 @@ func (n *node) Walk(in func(n *node) bool, out func(n *node)) {

// Options are the interpreter options.
type Options struct {
// GoPath sets GOPATH for the interpreter
// GoPath sets GOPATH for the interpreter.
GoPath string
// BuildTags sets build constraints for the interpreter

// BuildTags sets build constraints for the interpreter.
BuildTags []string

// Standard input, output and error streams.
// They default to os.Stding, os.Stdout and os.Stderr respectively.
Stdin io.Reader
Stdout, Stderr io.Writer
}

// New returns a new interpreter.
Expand All @@ -228,6 +239,18 @@ func New(options Options) *Interpreter {
hooks: &hooks{},
}

if i.opt.stdin = options.Stdin; i.opt.stdin == nil {
i.opt.stdin = os.Stdin
}

if i.opt.stdout = options.Stdout; i.opt.stdout == nil {
i.opt.stdout = os.Stdout
}

if i.opt.stderr = options.Stderr; i.opt.stderr == nil {
i.opt.stderr = os.Stderr
}

i.opt.context.GOPATH = options.GoPath
if len(options.BuildTags) > 0 {
i.opt.context.BuildTags = options.BuildTags
Expand Down Expand Up @@ -546,6 +569,68 @@ func (interp *Interpreter) Use(values Exports) {
interp.binPkg[k][s] = sym
}
}

// Checks if input values correspond to stdlib packages by looking for one
// well knwonw stdlib package path.
if _, ok := values["fmt"]; ok {
fixStdio(interp)
}
}

// fixStdio redefines interpreter stdlib symbols to use the standard input,
// output and errror assigned to the interpreter. The changes are limited to
// the interpreter only. Global values os.Stdin, os.Stdout and os.Stderr are
// not changed. Note that it is possible to escape the virtualized stdio by
// read/write directly to file descriptors 0, 1, 2.
func fixStdio(interp *Interpreter) {
p := interp.binPkg["fmt"]
if p == nil {
return
}

stdin, stdout, stderr := interp.stdin, interp.stdout, interp.stderr

p["Print"] = reflect.ValueOf(func(a ...interface{}) (n int, err error) { return fmt.Fprint(stdout, a...) })
p["Printf"] = reflect.ValueOf(func(f string, a ...interface{}) (n int, err error) { return fmt.Fprintf(stdout, f, a...) })
p["Println"] = reflect.ValueOf(func(a ...interface{}) (n int, err error) { return fmt.Fprintln(stdout, a...) })

p["Scan"] = reflect.ValueOf(func(a ...interface{}) (n int, err error) { return fmt.Fscan(stdin, a...) })
p["Scanf"] = reflect.ValueOf(func(f string, a ...interface{}) (n int, err error) { return fmt.Fscanf(stdin, f, a...) })
p["Scanln"] = reflect.ValueOf(func(a ...interface{}) (n int, err error) { return fmt.Fscanln(stdin, a...) })

if p = interp.binPkg["flag"]; p != nil {
c := flag.NewFlagSet(os.Args[0], flag.PanicOnError)
c.SetOutput(stderr)
p["CommandLine"] = reflect.ValueOf(&c).Elem()
}

if p = interp.binPkg["log"]; p != nil {
l := log.New(stderr, "", log.LstdFlags)
// Restrict Fatal symbols to panic instead of exit.
p["Fatal"] = reflect.ValueOf(l.Panic)
p["Fatalf"] = reflect.ValueOf(l.Panicf)
p["Fatalln"] = reflect.ValueOf(l.Panicln)

p["Flags"] = reflect.ValueOf(l.Flags)
p["Output"] = reflect.ValueOf(l.Output)
p["Panic"] = reflect.ValueOf(l.Panic)
p["Panicf"] = reflect.ValueOf(l.Panicf)
p["Panicln"] = reflect.ValueOf(l.Panicln)
p["Prefix"] = reflect.ValueOf(l.Prefix)
p["Print"] = reflect.ValueOf(l.Print)
p["Printf"] = reflect.ValueOf(l.Printf)
p["Println"] = reflect.ValueOf(l.Println)
p["SetFlags"] = reflect.ValueOf(l.SetFlags)
p["SetOutput"] = reflect.ValueOf(l.SetOutput)
p["SetPrefix"] = reflect.ValueOf(l.SetPrefix)
p["Writer"] = reflect.ValueOf(l.Writer)
}

if p = interp.binPkg["os"]; p != nil {
p["Stdin"] = reflect.ValueOf(&stdin).Elem()
p["Stdout"] = reflect.ValueOf(&stdout).Elem()
p["Stderr"] = reflect.ValueOf(&stderr).Elem()
}
}

// ignoreScannerError returns true if the error from Go scanner can be safely ignored
Expand All @@ -565,8 +650,10 @@ func ignoreScannerError(e *scanner.Error, s string) bool {
}

// REPL performs a Read-Eval-Print-Loop on input reader.
// Results are printed on output writer.
func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
// Results are printed to the output writer of the Interpreter, provided as option
// at creation time. Errors are printed to the similarly defined errors writer.
// The last interpreter result value and error are returned.
func (interp *Interpreter) REPL() (reflect.Value, error) {
// Preimport used bin packages, to avoid having to import these packages manually
// in REPL mode. These packages are already loaded anyway.
sc := interp.universe
Expand All @@ -580,6 +667,7 @@ func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
sc.sym[name] = &symbol{kind: pkgSym, typ: &itype{cat: binPkgT, path: k, scope: sc}}
}

in, out, errs := interp.stdin, interp.stdout, interp.stderr
ctx, cancel := context.WithCancel(context.Background())
end := make(chan struct{}) // channel to terminate the REPL
sig := make(chan os.Signal, 1) // channel to trap interrupt signal (Ctrl-C)
Expand All @@ -599,7 +687,9 @@ func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
for s.Scan() {
lines <- s.Text()
}
// TODO(mpl): log s.Err() if not nil?
if e := s.Err(); e != nil {
fmt.Fprintln(errs, e)
}
}()

go func() {
Expand All @@ -620,7 +710,7 @@ func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
select {
case <-end:
cancel()
return
return v, err
case line = <-lines:
src += line + "\n"
}
Expand All @@ -632,12 +722,12 @@ func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
if len(e) > 0 && ignoreScannerError(e[0], line) {
continue
}
fmt.Fprintln(out, strings.TrimPrefix(e[0].Error(), DefaultSourceName+":"))
fmt.Fprintln(errs, strings.TrimPrefix(e[0].Error(), DefaultSourceName+":"))
case Panic:
fmt.Fprintln(out, e.Value)
fmt.Fprintln(out, string(e.Stack))
fmt.Fprintln(errs, e.Value)
fmt.Fprintln(errs, string(e.Stack))
default:
fmt.Fprintln(out, err)
fmt.Fprintln(errs, err)
}
}
if errors.Is(err, context.Canceled) {
Expand All @@ -648,13 +738,6 @@ func (interp *Interpreter) REPL(in io.Reader, out io.Writer) {
}
}

// Repl performs a Read-Eval-Print-Loop on input file descriptor.
// Results are printed on output.
// Deprecated: use REPL instead.
func (interp *Interpreter) Repl(in, out *os.File) {
interp.REPL(in, out)
}

// getPrompt returns a function which prints a prompt only if input is a terminal.
func getPrompt(in io.Reader, out io.Writer) func(reflect.Value) {
s, ok := in.(interface{ Stat() (os.FileInfo, error) })
Expand Down
20 changes: 18 additions & 2 deletions interp/interp_eval_test.go
Expand Up @@ -166,6 +166,20 @@ func TestEvalImport(t *testing.T) {
})
}

func TestEvalStdout(t *testing.T) {
var out, err bytes.Buffer
i := interp.New(interp.Options{Stdout: &out, Stderr: &err})
i.Use(stdlib.Symbols)
_, e := i.Eval(`import "fmt"; func main() { fmt.Println("hello") }`)
if e != nil {
t.Fatal(e)
}
wanted := "hello\n"
if res := out.String(); res != wanted {
t.Fatalf("got %v, want %v", res, wanted)
}
}

func TestEvalNil(t *testing.T) {
i := interp.New(interp.Options{})
i.Use(stdlib.Symbols)
Expand Down Expand Up @@ -991,10 +1005,12 @@ func TestEvalScanner(t *testing.T) {
}

runREPL := func(t *testing.T, test testCase) {
i := interp.New(interp.Options{})
var stdout bytes.Buffer
safeStdout := &safeBuffer{buf: &stdout}
var stderr bytes.Buffer
safeStderr := &safeBuffer{buf: &stderr}
pin, pout := io.Pipe()
i := interp.New(interp.Options{Stdin: pin, Stdout: safeStdout, Stderr: safeStderr})
defer func() {
// Closing the pipe also takes care of making i.REPL terminate,
// hence freeing its goroutine.
Expand All @@ -1003,7 +1019,7 @@ func TestEvalScanner(t *testing.T) {
}()

go func() {
i.REPL(pin, safeStderr)
_, _ = i.REPL()
}()
for k, v := range test.src {
if _, err := pout.Write([]byte(v + "\n")); err != nil {
Expand Down

0 comments on commit 341c69d

Please sign in to comment.