diff --git a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java index 9532e23..3745ca2 100644 --- a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java +++ b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java @@ -962,4 +962,44 @@ int proxyGetMetric(int metricId, int returnValuePtr) { return e.result().getValue(); } } + + @WasmExport + int proxyGetSharedData( + int keyDataPtr, int keySize, int returnValueData, int returnValueSize, int returnCas) { + try { + // Get key from memory + String key = string(readMemory(keyDataPtr, keySize)); + + // Get shared data value using handler + Handler.SharedData value = handler.getSharedData(key); + if (value == null) { + return WasmResult.NOT_FOUND.getValue(); + } + + copyIntoInstance(value.data, returnValueData, returnValueSize); + putUint32(returnCas, value.cas); + return WasmResult.OK.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } + + @WasmExport + int proxySetSharedData(int keyDataPtr, int keySize, int valueDataPtr, int valueSize, int cas) { + try { + // Get key from memory + String key = string(readMemory(keyDataPtr, keySize)); + + // Get value from memory + byte[] value = readMemory(valueDataPtr, valueSize); + + // Set shared data value using handler + WasmResult result = handler.setSharedData(key, value, cas); + return result.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java index 7216a8e..b85aa95 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java @@ -296,4 +296,14 @@ public WasmResult incrementMetric(int metricId, long value) { public long getMetric(int metricId) throws WasmException { return next().getMetric(metricId); } + + @Override + public SharedData getSharedData(String key) throws WasmException { + return next().getSharedData(key); + } + + @Override + public WasmResult setSharedData(String key, byte[] value, int cas) { + return next().setSharedData(key, value, cas); + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java index 168f6e1..cd3f1fb 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java @@ -183,11 +183,11 @@ default int getCurrentTimeNanoseconds() throws WasmException { /** * Send an HTTP response. * - * @param responseCode The HTTP response code + * @param responseCode The HTTP response code * @param responseCodeDetails The response code details - * @param responseBody The response body - * @param additionalHeaders Additional headers to include - * @param grpcStatus The gRPC status code (-1 for non-gRPC responses) + * @param responseBody The response body + * @param additionalHeaders Additional headers to include + * @param grpcStatus The gRPC status code (-1 for non-gRPC responses) * @return The result of sending the response */ default WasmResult sendHttpResponse( @@ -273,7 +273,7 @@ default WasmResult setFuncCallData(byte[] data) { * Set a custom buffer. * * @param bufferType The buffer type - * @param buffer The custom buffer as a byte[] + * @param buffer The custom buffer as a byte[] * @return WasmResult indicating success or failure */ default WasmResult setCustomBuffer(int bufferType, byte[] buffer) { @@ -284,7 +284,7 @@ default WasmResult setCustomBuffer(int bufferType, byte[] buffer) { * Set a custom header map. * * @param mapType The type of map to set - * @param map The header map to set + * @param map The header map to set * @return WasmResult indicating success or failure */ default WasmResult setCustomHeaders(int mapType, Map map) { @@ -422,4 +422,22 @@ default WasmResult incrementMetric(int metricId, long value) { default long getMetric(int metricId) throws WasmException { throw new WasmException(WasmResult.UNIMPLEMENTED); } + + class SharedData { + public byte[] data; + public int cas; + + public SharedData(byte[] data, int cas) { + this.data = data; + this.cas = cas; + } + } + + default SharedData getSharedData(String key) throws WasmException { + throw new WasmException(WasmResult.UNIMPLEMENTED); + } + + default WasmResult setSharedData(String key, byte[] value, int cas) { + return WasmResult.UNIMPLEMENTED; + } } diff --git a/src/test/go-examples/shared_data/README.md b/src/test/go-examples/shared_data/README.md new file mode 100644 index 0000000..90d2373 --- /dev/null +++ b/src/test/go-examples/shared_data/README.md @@ -0,0 +1,4 @@ +## Attribution + +This example originally came from: +https://github.com/proxy-wasm/proxy-wasm-go-sdk/blob/ab4161dcf9246a828008b539a82a1556cf0f2e24/examples/shared_data diff --git a/src/test/go-examples/shared_data/go.mod b/src/test/go-examples/shared_data/go.mod new file mode 100644 index 0000000..7a51ab9 --- /dev/null +++ b/src/test/go-examples/shared_data/go.mod @@ -0,0 +1,5 @@ +module github.com/proxy-wasm/proxy-wasm-go-sdk/examples/shared_data + +go 1.24 + +require github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 diff --git a/src/test/go-examples/shared_data/go.sum b/src/test/go-examples/shared_data/go.sum new file mode 100644 index 0000000..3ddb896 --- /dev/null +++ b/src/test/go-examples/shared_data/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 h1:wTcK6gcyTKJMeDka69AMjZYvisdI8CBXzTEfZ+2pOxI= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924/go.mod h1:9mBRvh8I6Td6sg3CwEY+zGFE4DKaIoieCaca1kQnDBE= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/test/go-examples/shared_data/main.go b/src/test/go-examples/shared_data/main.go new file mode 100644 index 0000000..d56f3c2 --- /dev/null +++ b/src/test/go-examples/shared_data/main.go @@ -0,0 +1,106 @@ +// Copyright 2020-2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/binary" + "errors" + + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm" + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm/types" +) + +const ( + sharedDataKey = "shared_data_key" +) + +func main() {} +func init() { + proxywasm.SetVMContext(&vmContext{}) +} + +type ( + // vmContext implements types.VMContext. + vmContext struct{} + // pluginContext implements types.PluginContext. + pluginContext struct { + // Embed the default plugin context here, + // so that we don't need to reimplement all the methods. + types.DefaultPluginContext + } + // httpContext implements types.HttpContext. + httpContext struct { + // Embed the default http context here, + // so that we don't need to reimplement all the methods. + types.DefaultHttpContext + } +) + +// OnVMStart implements types.VMContext. +func (*vmContext) OnVMStart(vmConfigurationSize int) types.OnVMStartStatus { + initialValueBuf := make([]byte, 0) // Empty data to indicate that the data is not initialized. + if err := proxywasm.SetSharedData(sharedDataKey, initialValueBuf, 0); err != nil { + proxywasm.LogWarnf("error setting shared data on OnVMStart: %v", err) + } + return types.OnVMStartStatusOK +} + +// NewPluginContext implements types.VMContext. +func (*vmContext) NewPluginContext(contextID uint32) types.PluginContext { + return &pluginContext{} +} + +// NewHttpContext implements types.PluginContext. +func (*pluginContext) NewHttpContext(contextID uint32) types.HttpContext { + return &httpContext{} +} + +// OnHttpRequestHeaders implements types.HttpContext. +func (ctx *httpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) types.Action { + for { + value, err := ctx.incrementData() + if err == nil { + proxywasm.LogInfof("shared value: %d", value) + } else if errors.Is(err, types.ErrorStatusCasMismatch) { + continue + } + break + } + return types.ActionContinue +} + +// incrementData increments the shared data value by 1. +func (ctx *httpContext) incrementData() (uint64, error) { + data, cas, err := proxywasm.GetSharedData(sharedDataKey) + if err != nil { + proxywasm.LogWarnf("error getting shared data on OnHttpRequestHeaders: %v", err) + return 0, err + } + + var nextValue uint64 + if len(data) > 0 { + nextValue = binary.LittleEndian.Uint64(data) + 1 + } else { + nextValue = 1 + } + + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, nextValue) + if err := proxywasm.SetSharedData(sharedDataKey, buf, cas); err != nil { + proxywasm.LogWarnf("error setting shared data on OnHttpRequestHeaders: %v", err) + return 0, err + } + return nextValue, err +} diff --git a/src/test/go-examples/shared_data/main.wasm b/src/test/go-examples/shared_data/main.wasm new file mode 100644 index 0000000..d794230 Binary files /dev/null and b/src/test/go-examples/shared_data/main.wasm differ diff --git a/src/test/java/io/roastedroot/proxywasm/MockHandler.java b/src/test/java/io/roastedroot/proxywasm/MockHandler.java index 4d64a6a..a63e4c8 100644 --- a/src/test/java/io/roastedroot/proxywasm/MockHandler.java +++ b/src/test/java/io/roastedroot/proxywasm/MockHandler.java @@ -435,4 +435,31 @@ public WasmResult setAction(StreamType streamType, Action action) { public Action getAction() { return action; } + + private final HashMap sharedData = new HashMap<>(); + + @Override + public SharedData getSharedData(String key) throws WasmException { + return sharedData.get(key); + } + + @Override + public WasmResult setSharedData(String key, byte[] value, int cas) { + SharedData prev = sharedData.get(key); + if (prev == null) { + if (cas == 0) { + sharedData.put(key, new SharedData(value, 0)); + return WasmResult.OK; + } else { + return WasmResult.CAS_MISMATCH; + } + } else { + if (cas == 0 || prev.cas == cas) { + sharedData.put(key, new SharedData(value, prev.cas + 1)); + return WasmResult.OK; + } else { + return WasmResult.CAS_MISMATCH; + } + } + } } diff --git a/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java b/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java new file mode 100644 index 0000000..b8f9f38 --- /dev/null +++ b/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java @@ -0,0 +1,48 @@ +package io.roastedroot.proxywasm; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.dylibso.chicory.wasm.Parser; +import io.roastedroot.proxywasm.v1.Action; +import io.roastedroot.proxywasm.v1.ProxyWasm; +import io.roastedroot.proxywasm.v1.StartException; +import java.nio.file.Path; +import org.junit.jupiter.api.Test; + +/** + * Java port of https://github.com/proxy-wasm/proxy-wasm-go-sdk/blob/master/examples/shared_data/main_test.go + */ +public class SharedDataTest { + + @Test + public void testSetEffectiveContext() throws StartException { + + var handler = new MockHandler(); + // Load the WASM module + var module = Parser.parse(Path.of("./src/test/go-examples/shared_data/main.wasm")); + + // Create and configure the ProxyWasm instance + try (var host = ProxyWasm.builder().withPluginHandler(handler).build(module)) { + + // Initialize context + try (var context = host.createHttpContext(handler)) { + + // Call OnHttpRequestHeaders. + Action action = context.callOnRequestHeaders(false); + assertEquals(Action.CONTINUE, action); + + // Check Envoy logs. + handler.assertLogsContain("shared value: 1"); + + // Call OnHttpRequestHeaders again. + action = context.callOnRequestHeaders(false); + assertEquals(Action.CONTINUE, action); + action = context.callOnRequestHeaders(false); + assertEquals(Action.CONTINUE, action); + + // Check Envoy logs. + handler.assertLogsContain("shared value: 3"); + } + } + } +}