Skip to content

Commit

Permalink
filecache: add trailing crc to serialized executable code (#2091)
Browse files Browse the repository at this point in the history
  • Loading branch information
evacchi committed Feb 25, 2024
1 parent 52212ad commit 1b2fd85
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 45 deletions.
13 changes: 13 additions & 0 deletions internal/engine/compiler/engine_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"runtime"

Expand All @@ -14,6 +15,8 @@ import (
"github.com/tetratelabs/wazero/internal/wasm"
)

var crc = crc32.MakeTable(crc32.Castagnoli)

func (e *engine) deleteCompiledModule(module *wasm.Module) {
e.mux.Lock()
defer e.mux.Unlock()
Expand Down Expand Up @@ -130,6 +133,9 @@ func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader
buf.Write(u64.LeBytes(uint64(cm.executable.Len())))
// Append the native code.
buf.Write(cm.executable.Bytes())
// Append checksum.
checksum := crc32.Checksum(cm.executable.Bytes(), crc)
buf.Write(u32.LeBytes(checksum))
return bytes.NewReader(buf.Bytes())
}

Expand Down Expand Up @@ -209,6 +215,13 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser, modul
return
}

expected := crc32.Checksum(cm.executable.Bytes(), crc)
if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil {
return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err)
} else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum {
return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum)
}

if runtime.GOARCH == "arm64" {
// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
if err = platform.MprotectRX(cm.executable.Bytes()); err != nil {
Expand Down
99 changes: 74 additions & 25 deletions internal/engine/compiler/engine_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/binary"
"errors"
"hash/crc32"
"io"
"math"
"testing"
Expand All @@ -20,6 +21,11 @@ import (

var testVersion = ""

func crcf(b []byte) []byte {
c := crc32.Checksum(b, crc)
return u32.LeBytes(c)
}

func concat(ins ...[]byte) (ret []byte) {
for _, in := range ins {
ret = append(ret, in...)
Expand Down Expand Up @@ -49,12 +55,13 @@ func TestSerializeCompiledModule(t *testing.T) {
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
[]byte{0}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
[]byte{0}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
crcf([]byte{1, 2, 3, 4, 5}), // crc of code.
),
},
{
Expand All @@ -71,12 +78,13 @@ func TestSerializeCompiledModule(t *testing.T) {
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
[]byte{1}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
[]byte{1}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
crcf([]byte{1, 2, 3, 4, 5}), // crc of code.
),
},
{
Expand All @@ -103,8 +111,9 @@ func TestSerializeCompiledModule(t *testing.T) {
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(5), // offset.
// Executable.
u64.LeBytes(8), // length of code.
[]byte{1, 2, 3, 4, 5, 1, 2, 3}, // code.
u64.LeBytes(8), // length of code.
[]byte{1, 2, 3, 4, 5, 1, 2, 3}, // code.
crcf([]byte{1, 2, 3, 4, 5, 1, 2, 3}), // crc of code.
),
},
}
Expand Down Expand Up @@ -151,7 +160,25 @@ func TestDeserializeCompiledModule(t *testing.T) {
expStaleCache: true,
},
{
name: "one function",
name: "invalid crc",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
[]byte{0}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
// Executable.
u64.LeBytes(5), // size.
[]byte{1, 2, 3, 4, 5}, // machine code.
crcf([]byte{1, 2, 3, 4}), // crc of code.
),
expStaleCache: false,
expErr: "compilationcache: checksum mismatch (expected 1397854123, got 691047668)",
},
{
name: "missing crc",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
Expand All @@ -164,6 +191,24 @@ func TestDeserializeCompiledModule(t *testing.T) {
u64.LeBytes(5), // size.
[]byte{1, 2, 3, 4, 5}, // machine code.
),
expStaleCache: false,
expErr: "compilationcache: could not read checksum: EOF",
},
{
name: "one function",
in: concat(
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
[]byte{0}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
// Executable.
u64.LeBytes(5), // size.
[]byte{1, 2, 3, 4, 5}, // machine code.
crcf([]byte{1, 2, 3, 4, 5}), // crc of code.
),
expCompiledModule: &compiledModule{
compiledCode: &compiledCode{
executable: makeCodeSegment(1, 2, 3, 4, 5),
Expand All @@ -181,12 +226,13 @@ func TestDeserializeCompiledModule(t *testing.T) {
[]byte(wazeroMagic),
[]byte{byte(len(testVersion))},
[]byte(testVersion),
[]byte{1}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
[]byte{1}, // ensure termination.
u32.LeBytes(1), // number of functions.
u64.LeBytes(12345), // stack pointer ceil.
u64.LeBytes(0), // offset.
u64.LeBytes(5), // length of code.
[]byte{1, 2, 3, 4, 5}, // code.
crcf([]byte{1, 2, 3, 4, 5}), // crc of code.
),
expCompiledModule: &compiledModule{
compiledCode: &compiledCode{
Expand All @@ -213,8 +259,9 @@ func TestDeserializeCompiledModule(t *testing.T) {
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(7), // offset.
// Executable.
u64.LeBytes(10), // size.
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code.
u64.LeBytes(10), // size.
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code.
crcf([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), // crc of code.
),
importedFunctionCount: 1,
expCompiledModule: &compiledModule{
Expand Down Expand Up @@ -322,8 +369,9 @@ func TestEngine_getCompiledModuleFromCache(t *testing.T) {
u64.LeBytes(0xffffffff), // stack pointer ceil.
u64.LeBytes(5), // offset.
// executables.
u64.LeBytes(10), // length of code.
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code.
u64.LeBytes(10), // length of code.
[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code.
crcf([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), // code.
)

tests := []struct {
Expand Down Expand Up @@ -475,6 +523,7 @@ func TestEngine_addCompiledModuleToCache(t *testing.T) {
u64.LeBytes(0), // offset.
u64.LeBytes(3), // size of executable.
[]byte{1, 2, 3},
crcf([]byte{1, 2, 3}), // code.
), actual)
require.NoError(t, content.Close())
})
Expand Down
13 changes: 13 additions & 0 deletions internal/engine/wazevo/engine_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"runtime"
"unsafe"
Expand All @@ -21,6 +22,8 @@ import (
"github.com/tetratelabs/wazero/internal/wasm"
)

var crc = crc32.MakeTable(crc32.Castagnoli)

// fileCacheKey returns a key for the file cache.
// In order to avoid collisions with the existing compiler, we do not use m.ID directly,
// but instead we rehash it with magic.
Expand Down Expand Up @@ -145,6 +148,9 @@ func serializeCompiledModule(wazeroVersion string, cm *compiledModule) io.Reader
buf.Write(u64.LeBytes(uint64(len(cm.executable))))
// Append the native code.
buf.Write(cm.executable)
// Append checksum.
checksum := crc32.Checksum(cm.executable, crc)
buf.Write(u32.LeBytes(checksum))
if sm := cm.sourceMap; len(sm.executableOffsets) > 0 {
buf.WriteByte(1) // indicates that source map is present.
l := len(sm.wasmBinaryOffsets)
Expand Down Expand Up @@ -226,6 +232,13 @@ func deserializeCompiledModule(wazeroVersion string, reader io.ReadCloser) (cm *
return nil, false, err
}

expected := crc32.Checksum(executable, crc)
if _, err = io.ReadFull(reader, eightBytes[:4]); err != nil {
return nil, false, fmt.Errorf("compilationcache: could not read checksum: %v", err)
} else if checksum := binary.LittleEndian.Uint32(eightBytes[:4]); expected != checksum {
return nil, false, fmt.Errorf("compilationcache: checksum mismatch (expected %d, got %d)", expected, checksum)
}

if runtime.GOARCH == "arm64" {
// On arm64, we cannot give all of rwx at the same time, so we change it to exec.
if err = platform.MprotectRX(executable); err != nil {
Expand Down
Loading

0 comments on commit 1b2fd85

Please sign in to comment.