Skip to content

Commit

Permalink
feat(plugins): Use wazero instead of wasmtime (#3042)
Browse files Browse the repository at this point in the history
* feat(plugins): Use wazero instead of wasmtime

* Remove wasm build tags

* Update internal/ext/wasm/wasm.go

Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com>

* Update internal/ext/wasm/wasm.go

Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com>

* Fix build

* Suggestions for PR #3042 (#3082)

* Only compile wasm once per process

* Remove unused

* Store runtime in flightgroup as well

* Handle error from instantiate

---------

Co-authored-by: Anuraag Agrawal <anuraaga@gmail.com>
  • Loading branch information
kyleconroy and anuraaga committed Jan 2, 2024
1 parent 4f7fca7 commit 4188d23
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 184 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.21

require (
github.com/antlr4-go/antlr/v4 v4.13.0
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0
github.com/cubicdaiya/gonp v1.0.4
github.com/davecgh/go-spew v1.1.1
github.com/fatih/structtag v1.2.0
Expand All @@ -20,6 +19,7 @@ require (
github.com/riza-io/grpc-go v0.2.0
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/tetratelabs/wazero v1.5.0
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99
github.com/xeipuuv/gojsonschema v1.2.0
golang.org/x/sync v0.5.0
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0
github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI=
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM=
github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
Expand Down Expand Up @@ -183,6 +181,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0=
github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk=
github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE=
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU=
github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6/go.mod h1:IQNVyA4d1hWIe23mlMMuqXjyWMdndgSlNx6FqBkwPsM=
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99 h1:HFee1ByN4FrqNVd53Mo28ccGO+g5gxqUV/gdvKMe4b8=
github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99/go.mod h1:f2JMhFocVxY3VKMd9ykUxMnX4EVew9WOgjnfaNBB6C8=
github.com/wasilibs/wazerox v0.0.0-20231208014050-e6b725634531 h1:zVJ4SZgaEE9sEH2L9k1+eAvCNa/WAAnT9UiMa3/tQrI=
Expand Down
1 change: 0 additions & 1 deletion internal/endtoend/case_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type Exec struct {
Contexts []string `json:"contexts"`
Process string `json:"process"`
OS []string `json:"os"`
WASM bool `json:"wasm"`
Env map[string]string `json:"env"`
}

Expand Down
5 changes: 0 additions & 5 deletions internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (

"github.com/sqlc-dev/sqlc/internal/cmd"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/ext/wasm"
"github.com/sqlc-dev/sqlc/internal/opts"
)

Expand Down Expand Up @@ -177,10 +176,6 @@ func TestReplay(t *testing.T) {
}
}

if args.WASM && !wasm.Enabled() {
t.Skipf("wasm support not enabled")
}

if len(args.OS) > 0 {
if !slices.Contains(args.OS, runtime.GOOS) {
t.Skipf("unsupported os: %s", runtime.GOOS)
Expand Down

This file was deleted.

This file was deleted.

23 changes: 0 additions & 23 deletions internal/ext/wasm/nowasm.go

This file was deleted.

207 changes: 61 additions & 146 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))

// The above build constraint is based of the cgo directives in this file:
// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
package wasm

import (
"bytes"
"context"
"crypto/sha256"
"errors"
Expand All @@ -15,10 +12,11 @@ import (
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
"golang.org/x/sync/singleflight"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -31,31 +29,13 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func Enabled() bool {
return true
}

// This version must be updated whenever the wasmtime-go dependency is updated
const wasmtimeVersion = `v14.0.0`
var flight singleflight.Group

func cacheDir() (string, error) {
cache := os.Getenv("SQLCCACHE")
if cache != "" {
return cache, nil
}
cacheHome := os.Getenv("XDG_CACHE_HOME")
if cacheHome == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
cacheHome = filepath.Join(home, ".cache")
}
return filepath.Join(cacheHome, "sqlc"), nil
type runtimeAndCode struct {
rt wazero.Runtime
code wazero.CompiledModule
}

var flight singleflight.Group

// Verify the provided sha256 is valid.
func (r *Runner) getChecksum(ctx context.Context) (string, error) {
if r.SHA256 != "" {
Expand All @@ -70,67 +50,26 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) {
return sum, nil
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) {
expected, err := r.getChecksum(ctx)
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine, expected)
})
if err != nil {
return nil, err
}
data, ok := value.([]byte)
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
cacheDir, err := cache.PluginsDir()
if err != nil {
return nil, err
}

pluginDir := filepath.Join(cacheDir, expectedSha)
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
modPath := filepath.Join(pluginDir, modName)
_, staterr := os.Stat(modPath)
if staterr == nil {
data, err := os.ReadFile(modPath)
if err != nil {
return nil, err
}
return data, nil
}

wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadAndCompileWASM(ctx, cacheDir, expected)
})
if err != nil {
return nil, err
}

moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
module, err := wasmtime.NewModule(engine, wmod)
moduRegion.End()
if err != nil {
return nil, fmt.Errorf("define wasi: %w", err)
}

err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
data, ok := value.(*runtimeAndCode)
if !ok {
return nil, fmt.Errorf("returned value was not a compiled module")
}

return out, nil
return data, nil
}

func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
Expand Down Expand Up @@ -174,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error)
return wmod, actual, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) {
pluginDir := filepath.Join(cache, expected)
pluginPath := filepath.Join(pluginDir, "plugin.wasm")
_, staterr := os.Stat(pluginPath)
Expand Down Expand Up @@ -203,7 +142,26 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([
}
}

return wmod, nil
wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero"))
if err != nil {
return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err)
}

config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache)
rt := wazero.NewRuntimeWithConfig(ctx, config)

if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil {
return nil, fmt.Errorf("wasi_snapshot_preview1 instantiate: %w", err)
}

// Compile the Wasm binary once so that we can skip the entire compilation
// time during instantiation.
code, err := rt.CompileModule(ctx, wmod)
if err != nil {
return nil, fmt.Errorf("compile module: %w", err)
}

return &runtimeAndCode{rt: rt, code: code}, nil
}

// removePGCatalog removes the pg_catalog schema from the request. There is a
Expand Down Expand Up @@ -245,75 +203,34 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any,
return fmt.Errorf("failed to encode codegen request: %w", err)
}

engine := wasmtime.NewEngine()
module, err := r.loadModule(ctx, engine)
if err != nil {
return fmt.Errorf("loadModule: %w", err)
}

linker := wasmtime.NewLinker(engine)
if err := linker.DefineWasi(); err != nil {
return err
}

dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
runtimeAndCode, err := r.loadAndCompile(ctx)
if err != nil {
return fmt.Errorf("temp dir: %w", err)
}

defer os.RemoveAll(dir)
stdinPath := filepath.Join(dir, "stdin")
stderrPath := filepath.Join(dir, "stderr")
stdoutPath := filepath.Join(dir, "stdout")

if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
return fmt.Errorf("write file: %w", err)
return fmt.Errorf("loadBytes: %w", err)
}

// Configure WASI imports to write stdout into a file.
wasiConfig := wasmtime.NewWasiConfig()
wasiConfig.SetArgv([]string{"plugin.wasm", method})
wasiConfig.SetStdinFile(stdinPath)
wasiConfig.SetStdoutFile(stdoutPath)
wasiConfig.SetStderrFile(stderrPath)
var stderr, stdout bytes.Buffer

keys := []string{"SQLC_VERSION"}
vals := []string{info.Version}
conf := wazero.NewModuleConfig().
WithName("").
WithArgs("plugin.wasm", method).
WithStdin(bytes.NewReader(stdinBlob)).
WithStdout(&stdout).
WithStderr(&stderr).
WithEnv("SQLC_VERSION", info.Version)
for _, key := range r.Env {
keys = append(keys, key)
vals = append(vals, os.Getenv(key))
}
wasiConfig.SetEnv(keys, vals)

store := wasmtime.NewStore(engine)
store.SetWasi(wasiConfig)

linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
err = linker.DefineModule(store, "", module)
linkRegion.End()
if err != nil {
return fmt.Errorf("define wasi: %w", err)
conf = conf.WithEnv(key, os.Getenv(key))
}

// Run the function
fn, err := linker.GetDefault(store, "")
if err != nil {
return fmt.Errorf("wasi: get default: %w", err)
result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf)
if result != nil {
defer result.Close(ctx)
}

callRegion := trace.StartRegion(ctx, "call _start")
_, err = fn.Call(store)
callRegion.End()

if cerr := checkError(err, stderrPath); cerr != nil {
if cerr := checkError(err, stderr); cerr != nil {
return cerr
}

// Print WASM stdout
stdoutBlob, err := os.ReadFile(stdoutPath)
if err != nil {
return fmt.Errorf("read file: %w", err)
}
stdoutBlob := stdout.Bytes()

resp, ok := reply.(protoreflect.ProtoMessage)
if !ok {
Expand All @@ -331,23 +248,21 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st
return nil, status.Error(codes.Unimplemented, "")
}

func checkError(err error, stderrPath string) error {
func checkError(err error, stderr bytes.Buffer) error {
if err == nil {
return err
}

var wtError *wasmtime.Error
if errors.As(err, &wtError) {
if code, ok := wtError.ExitStatus(); ok {
if code == 0 {
return nil
}
if exitErr, ok := err.(*sys.ExitError); ok {
if exitErr.ExitCode() == 0 {
return nil
}
}

// Print WASM stdout
stderrBlob, rferr := os.ReadFile(stderrPath)
if rferr == nil && len(stderrBlob) > 0 {
return errors.New(string(stderrBlob))
stderrBlob := stderr.String()
if len(stderrBlob) > 0 {
return errors.New(stderrBlob)
}
return fmt.Errorf("call: %w", err)
}

0 comments on commit 4188d23

Please sign in to comment.