From aca3da282d7ac64edea253a21dc7673c41cf6e4d Mon Sep 17 00:00:00 2001 From: Seth Wiesman Date: Thu, 20 Aug 2020 22:19:02 -0500 Subject: [PATCH] Use go context --- examples/greeter/greeter.go | 10 ++-- pkg/flink/statefun/context.go | 23 +++++++++ pkg/flink/statefun/function.go | 3 +- pkg/flink/statefun/registry.go | 44 +++++++++------- pkg/flink/statefun/runtime.go | 50 ++----------------- pkg/flink/statefun/stateful_functions_test.go | 9 ++-- test/verification.go | 14 +++--- 7 files changed, 76 insertions(+), 77 deletions(-) create mode 100644 pkg/flink/statefun/context.go diff --git a/examples/greeter/greeter.go b/examples/greeter/greeter.go index 3f9637e..e742e68 100644 --- a/examples/greeter/greeter.go +++ b/examples/greeter/greeter.go @@ -1,8 +1,9 @@ package main import ( + "context" "fmt" - "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/types/known/anypb" "net/http" "github.com/sjwiesman/statefun-go/pkg/flink/statefun" @@ -16,7 +17,7 @@ var egressId = io.EgressIdentifier{ type Greeter struct{} -func (greeter Greeter) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.Any) error { +func (greeter Greeter) Invoke(ctx context.Context, runtime statefun.StatefulFunctionRuntime, _ *anypb.Any) error { var seen SeenCount if err := runtime.Get("seen_count", &seen); err != nil { return err @@ -28,11 +29,12 @@ func (greeter Greeter) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.A return err } - response := computeGreeting(runtime.Self().Id, seen.Seen) + self := statefun.Self(ctx) + response := computeGreeting(self.Id, seen.Seen) record := io.KafkaRecord{ Topic: "greetings", - Key: runtime.Self().Id, + Key: statefun.Self(ctx).Id, Value: response, } diff --git a/pkg/flink/statefun/context.go b/pkg/flink/statefun/context.go new file mode 100644 index 0000000..23e1c6e --- /dev/null +++ b/pkg/flink/statefun/context.go @@ -0,0 +1,23 @@ +package statefun + +import "context" + +type contextKey string + +const ( + selfKey = contextKey("self") + callerKey = contextKey("caller") +) + +// Self returns the address of the current +// function instance under evaluation +func Self(ctx context.Context) *Address { + return ctx.Value(selfKey).(*Address) +} + +// Caller returns the address of the caller function. +// The caller may be nil if the message +// was sent directly from an ingress +func Caller(ctx context.Context) *Address { + return ctx.Value(callerKey).(*Address) +} diff --git a/pkg/flink/statefun/function.go b/pkg/flink/statefun/function.go index f6ea00c..3f9631c 100644 --- a/pkg/flink/statefun/function.go +++ b/pkg/flink/statefun/function.go @@ -1,6 +1,7 @@ package statefun import ( + "context" any "google.golang.org/protobuf/types/known/anypb" ) @@ -25,5 +26,5 @@ import ( type StatefulFunction interface { // Invoke this function with the given input. - Invoke(runtime StatefulFunctionRuntime, message *any.Any) error + Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *any.Any) error } diff --git a/pkg/flink/statefun/registry.go b/pkg/flink/statefun/registry.go index 9cea712..4404805 100644 --- a/pkg/flink/statefun/registry.go +++ b/pkg/flink/statefun/registry.go @@ -8,11 +8,12 @@ import ( "github.com/golang/protobuf/ptypes/any" "github.com/sjwiesman/statefun-go/pkg/flink/statefun/internal/messages" "github.com/valyala/bytebufferpool" + "google.golang.org/protobuf/types/known/anypb" "log" "net/http" ) -type StatefulFunctionPointer func(ctx StatefulFunctionRuntime, message *any.Any) error +type StatefulFunctionPointer func(ctx context.Context, runtime StatefulFunctionRuntime, message *any.Any) error // Keeps a mapping from FunctionType to stateful functions. // Use this together with an http endpoint to serve @@ -28,11 +29,11 @@ type FunctionRegistry interface { } type pointer struct { - f func(ctx StatefulFunctionRuntime, message *any.Any) error + f func(ctx context.Context, runtime StatefulFunctionRuntime, message *any.Any) error } -func (pointer *pointer) Invoke(runtime StatefulFunctionRuntime, message *any.Any) error { - return pointer.f(runtime, message) +func (pointer *pointer) Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *anypb.Any) error { + return pointer.f(ctx, runtime, msg) } type functions struct { @@ -126,6 +127,20 @@ func validRequest(w http.ResponseWriter, req *http.Request) bool { return true } +func fromInternal(address *messages.Address) *Address { + if address == nil { + return nil + } + + return &Address{ + FunctionType: FunctionType{ + Namespace: address.Namespace, + Type: address.Type, + }, + Id: address.Id, + } +} + func executeBatch(functions functions, ctx context.Context, request *messages.ToFunction) (*messages.FromFunction, error) { invocations := request.GetInvocation() if invocations == nil { @@ -142,27 +157,20 @@ func executeBatch(functions functions, ctx context.Context, request *messages.To return nil, errors.New(funcType.String() + " does not exist") } - runtime := newStateFunIO(invocations.Target, invocations.State) + runtime := newStateFunIO(invocations.State) + self := fromInternal(invocations.Target) + ctx = context.WithValue(ctx, selfKey, self) for _, invocation := range invocations.Invocations { select { case <-ctx.Done(): return nil, ctx.Err() default: - if invocation.Caller == nil { - runtime.caller = nil - } else { - runtime.caller = &Address{ - FunctionType: FunctionType{ - Namespace: invocation.Caller.Namespace, - Type: invocation.Caller.Type, - }, - Id: invocation.Caller.Id, - } - } - err := function.Invoke(runtime, (*invocation).Argument) + caller := fromInternal(invocation.Caller) + ctx = context.WithValue(ctx, callerKey, caller) + err := function.Invoke(ctx, runtime, (*invocation).Argument) if err != nil { - return nil, fmt.Errorf("failed to execute function %s %w", runtime.self.String(), err) + return nil, fmt.Errorf("failed to execute function %s %w", self.String(), err) } } } diff --git a/pkg/flink/statefun/runtime.go b/pkg/flink/statefun/runtime.go index 87c5946..2b843d7 100644 --- a/pkg/flink/statefun/runtime.go +++ b/pkg/flink/statefun/runtime.go @@ -18,14 +18,6 @@ import ( // reading and writing persisted state values with exactly-once guarantees provided // by the runtime. type StatefulFunctionRuntime interface { - // Self returns the address of the current - // function instance under evaluation - Self() Address - - // Caller returns the address of the caller function. - // The caller may be nil if the message - // was sent directly from an ingress - Caller() *Address // Get retrieves the state for the given name and // unmarshalls the encoded value contained into the provided message state. @@ -43,20 +35,14 @@ type StatefulFunctionRuntime interface { // Invokes another function with an input, identified by the target function's Address // and marshals the given message into an any.Any. - Send(target Address, message proto.Message) error - - // Invokes the calling function of the current invocation under execution. This has the same effect - // as calling Send with the address obtained from Caller, and - // will not work if the current function was not invoked by another function. - // This method marshals the given message into an any.Any. - Reply(message proto.Message) error + Send(target *Address, message proto.Message) error // Invokes another function with an input, identified by the target function's // FunctionType and unique id after a specified delay. This method is durable // and as such, the message will not be lost if the system experiences // downtime between when the message is sent and the end of the duration. // This method marshals the given message into an any.Any. - SendAfter(target Address, duration time.Duration, message proto.Message) error + SendAfter(target *Address, duration time.Duration, message proto.Message) error // Sends an output to an EgressIdentifier. // This method marshals the given message into an any.Any. @@ -74,8 +60,6 @@ type state struct { // It tracks all responses that will be sent back to the // Flink runtime after the full batch has been executed. type runtime struct { - self Address - caller *Address states map[string]*state invocations []*messages.FromFunction_Invocation delayedInvocation []*messages.FromFunction_DelayedInvocation @@ -84,16 +68,8 @@ type runtime struct { // Create a new runtime based on the target function // and set of initial states. -func newStateFunIO(self *messages.Address, persistedValues []*messages.ToFunction_PersistedValue) *runtime { +func newStateFunIO(persistedValues []*messages.ToFunction_PersistedValue) *runtime { ctx := &runtime{ - self: Address{ - FunctionType: FunctionType{ - Namespace: self.Namespace, - Type: self.Type, - }, - Id: self.Id, - }, - caller: nil, states: map[string]*state{}, invocations: []*messages.FromFunction_Invocation{}, delayedInvocation: []*messages.FromFunction_DelayedInvocation{}, @@ -116,14 +92,6 @@ func newStateFunIO(self *messages.Address, persistedValues []*messages.ToFunctio return ctx } -func (tracker *runtime) Self() Address { - return tracker.self -} - -func (tracker *runtime) Caller() *Address { - return tracker.caller -} - func (tracker *runtime) Get(name string, state proto.Message) error { packedState := tracker.states[name] if packedState == nil { @@ -159,7 +127,7 @@ func (tracker *runtime) Clear(name string) { _ = tracker.Set(name, nil) } -func (tracker *runtime) Send(target Address, message proto.Message) error { +func (tracker *runtime) Send(target *Address, message proto.Message) error { if message == nil { return errors.New("cannot send nil message to function") } @@ -182,15 +150,7 @@ func (tracker *runtime) Send(target Address, message proto.Message) error { return nil } -func (tracker *runtime) Reply(message proto.Message) error { - if tracker.caller == nil { - return errors.New("cannot reply to nil caller") - } - - return tracker.Send(*tracker.caller, message) -} - -func (tracker *runtime) SendAfter(target Address, duration time.Duration, message proto.Message) error { +func (tracker *runtime) SendAfter(target *Address, duration time.Duration, message proto.Message) error { if message == nil { return errors.New("cannot send nil message to function") } diff --git a/pkg/flink/statefun/stateful_functions_test.go b/pkg/flink/statefun/stateful_functions_test.go index 77da197..bb36b1d 100644 --- a/pkg/flink/statefun/stateful_functions_test.go +++ b/pkg/flink/statefun/stateful_functions_test.go @@ -2,6 +2,7 @@ package statefun import ( "bytes" + "context" "errors" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" @@ -10,6 +11,7 @@ import ( "github.com/sjwiesman/statefun-go/pkg/flink/statefun/internal/test" "github.com/sjwiesman/statefun-go/pkg/flink/statefun/io" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/anypb" "io/ioutil" "net/http" "net/http/httptest" @@ -166,7 +168,7 @@ func TestValidation(t *testing.T) { type Greeter struct{} -func (f Greeter) Invoke(runtime StatefulFunctionRuntime, msg *any.Any) error { +func (f Greeter) Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *anypb.Any) error { if err := ptypes.UnmarshalAny(msg, &test.Invoke{}); err != nil { return err } @@ -182,11 +184,12 @@ func (f Greeter) Invoke(runtime StatefulFunctionRuntime, msg *any.Any) error { Greeting: "Hello", } - if err := runtime.Reply(greeting); err != nil { + caller := Caller(ctx) + if err := runtime.Send(caller, greeting); err != nil { return err } - if err := runtime.SendAfter(*runtime.Caller(), time.Duration(6e+10), greeting); err != nil { + if err := runtime.SendAfter(caller, time.Duration(6e+10), greeting); err != nil { return err } diff --git a/test/verification.go b/test/verification.go index a626595..00b9fdc 100644 --- a/test/verification.go +++ b/test/verification.go @@ -1,13 +1,15 @@ package main import ( + "context" "crypto/rand" "encoding/hex" "fmt" "github.com/golang/protobuf/ptypes/any" - "net/http" "github.com/sjwiesman/statefun-go/pkg/flink/statefun" "github.com/sjwiesman/statefun-go/pkg/flink/statefun/io" + "google.golang.org/protobuf/types/known/anypb" + "net/http" ) func randToken(n int) string { @@ -20,7 +22,7 @@ func randToken(n int) string { type CounterFunction struct{} -func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.Any) error { +func (c CounterFunction) Invoke(ctx context.Context, runtime statefun.StatefulFunctionRuntime, msg *anypb.Any) error { var count InvokeCount if err := runtime.Get("invoke_count", &count); err != nil { return fmt.Errorf("unable to deserialize invoke_count %w", err) @@ -34,10 +36,10 @@ func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any response := &InvokeResult{ InvokeCount: count.Count, - Id: runtime.Self().Id, + Id: statefun.Self(ctx).Id, } - target := statefun.Address{ + target := &statefun.Address{ FunctionType: statefun.FunctionType{ Namespace: "org.apache.flink.statefun.e2e.remote", Type: "forward-function", @@ -48,7 +50,7 @@ func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any return runtime.Send(target, response) } -func ForwardFunction(runtime statefun.StatefulFunctionRuntime, msg *any.Any) error { +func ForwardFunction(ctx context.Context, runtime statefun.StatefulFunctionRuntime, msg *any.Any) error { egress := io.EgressIdentifier{ EgressNamespace: "org.apache.flink.statefun.e2e.remote", EgressType: "invoke-results", @@ -56,7 +58,7 @@ func ForwardFunction(runtime statefun.StatefulFunctionRuntime, msg *any.Any) err record := io.KafkaRecord{ Topic: "invoke-results", - Key: runtime.Self().Id, + Key: statefun.Self(ctx).Id, Value: msg, }