Skip to content

Commit

Permalink
Use go context
Browse files Browse the repository at this point in the history
  • Loading branch information
sjwiesman committed Aug 21, 2020
1 parent 5ad28b8 commit aca3da2
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 77 deletions.
10 changes: 6 additions & 4 deletions examples/greeter/greeter.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package main

import (
"context"
"fmt"
"github.com/golang/protobuf/ptypes/any"
"google.golang.org/protobuf/types/known/anypb"
"net/http"

"github.com/sjwiesman/statefun-go/pkg/flink/statefun"
Expand All @@ -16,7 +17,7 @@ var egressId = io.EgressIdentifier{

type Greeter struct{}

func (greeter Greeter) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.Any) error {
func (greeter Greeter) Invoke(ctx context.Context, runtime statefun.StatefulFunctionRuntime, _ *anypb.Any) error {
var seen SeenCount
if err := runtime.Get("seen_count", &seen); err != nil {
return err
Expand All @@ -28,11 +29,12 @@ func (greeter Greeter) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.A
return err
}

response := computeGreeting(runtime.Self().Id, seen.Seen)
self := statefun.Self(ctx)
response := computeGreeting(self.Id, seen.Seen)

record := io.KafkaRecord{
Topic: "greetings",
Key: runtime.Self().Id,
Key: statefun.Self(ctx).Id,
Value: response,
}

Expand Down
23 changes: 23 additions & 0 deletions pkg/flink/statefun/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package statefun

import "context"

type contextKey string

const (
selfKey = contextKey("self")
callerKey = contextKey("caller")
)

// Self returns the address of the current
// function instance under evaluation
func Self(ctx context.Context) *Address {
return ctx.Value(selfKey).(*Address)
}

// Caller returns the address of the caller function.
// The caller may be nil if the message
// was sent directly from an ingress
func Caller(ctx context.Context) *Address {
return ctx.Value(callerKey).(*Address)
}
3 changes: 2 additions & 1 deletion pkg/flink/statefun/function.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package statefun

import (
"context"
any "google.golang.org/protobuf/types/known/anypb"
)

Expand All @@ -25,5 +26,5 @@ import (
type StatefulFunction interface {

// Invoke this function with the given input.
Invoke(runtime StatefulFunctionRuntime, message *any.Any) error
Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *any.Any) error
}
44 changes: 26 additions & 18 deletions pkg/flink/statefun/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"github.com/golang/protobuf/ptypes/any"
"github.com/sjwiesman/statefun-go/pkg/flink/statefun/internal/messages"
"github.com/valyala/bytebufferpool"
"google.golang.org/protobuf/types/known/anypb"
"log"
"net/http"
)

type StatefulFunctionPointer func(ctx StatefulFunctionRuntime, message *any.Any) error
type StatefulFunctionPointer func(ctx context.Context, runtime StatefulFunctionRuntime, message *any.Any) error

// Keeps a mapping from FunctionType to stateful functions.
// Use this together with an http endpoint to serve
Expand All @@ -28,11 +29,11 @@ type FunctionRegistry interface {
}

type pointer struct {
f func(ctx StatefulFunctionRuntime, message *any.Any) error
f func(ctx context.Context, runtime StatefulFunctionRuntime, message *any.Any) error
}

func (pointer *pointer) Invoke(runtime StatefulFunctionRuntime, message *any.Any) error {
return pointer.f(runtime, message)
func (pointer *pointer) Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *anypb.Any) error {
return pointer.f(ctx, runtime, msg)
}

type functions struct {
Expand Down Expand Up @@ -126,6 +127,20 @@ func validRequest(w http.ResponseWriter, req *http.Request) bool {
return true
}

func fromInternal(address *messages.Address) *Address {
if address == nil {
return nil
}

return &Address{
FunctionType: FunctionType{
Namespace: address.Namespace,
Type: address.Type,
},
Id: address.Id,
}
}

func executeBatch(functions functions, ctx context.Context, request *messages.ToFunction) (*messages.FromFunction, error) {
invocations := request.GetInvocation()
if invocations == nil {
Expand All @@ -142,27 +157,20 @@ func executeBatch(functions functions, ctx context.Context, request *messages.To
return nil, errors.New(funcType.String() + " does not exist")
}

runtime := newStateFunIO(invocations.Target, invocations.State)
runtime := newStateFunIO(invocations.State)

self := fromInternal(invocations.Target)
ctx = context.WithValue(ctx, selfKey, self)
for _, invocation := range invocations.Invocations {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
if invocation.Caller == nil {
runtime.caller = nil
} else {
runtime.caller = &Address{
FunctionType: FunctionType{
Namespace: invocation.Caller.Namespace,
Type: invocation.Caller.Type,
},
Id: invocation.Caller.Id,
}
}
err := function.Invoke(runtime, (*invocation).Argument)
caller := fromInternal(invocation.Caller)
ctx = context.WithValue(ctx, callerKey, caller)
err := function.Invoke(ctx, runtime, (*invocation).Argument)
if err != nil {
return nil, fmt.Errorf("failed to execute function %s %w", runtime.self.String(), err)
return nil, fmt.Errorf("failed to execute function %s %w", self.String(), err)
}
}
}
Expand Down
50 changes: 5 additions & 45 deletions pkg/flink/statefun/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ import (
// reading and writing persisted state values with exactly-once guarantees provided
// by the runtime.
type StatefulFunctionRuntime interface {
// Self returns the address of the current
// function instance under evaluation
Self() Address

// Caller returns the address of the caller function.
// The caller may be nil if the message
// was sent directly from an ingress
Caller() *Address

// Get retrieves the state for the given name and
// unmarshalls the encoded value contained into the provided message state.
Expand All @@ -43,20 +35,14 @@ type StatefulFunctionRuntime interface {

// Invokes another function with an input, identified by the target function's Address
// and marshals the given message into an any.Any.
Send(target Address, message proto.Message) error

// Invokes the calling function of the current invocation under execution. This has the same effect
// as calling Send with the address obtained from Caller, and
// will not work if the current function was not invoked by another function.
// This method marshals the given message into an any.Any.
Reply(message proto.Message) error
Send(target *Address, message proto.Message) error

// Invokes another function with an input, identified by the target function's
// FunctionType and unique id after a specified delay. This method is durable
// and as such, the message will not be lost if the system experiences
// downtime between when the message is sent and the end of the duration.
// This method marshals the given message into an any.Any.
SendAfter(target Address, duration time.Duration, message proto.Message) error
SendAfter(target *Address, duration time.Duration, message proto.Message) error

// Sends an output to an EgressIdentifier.
// This method marshals the given message into an any.Any.
Expand All @@ -74,8 +60,6 @@ type state struct {
// It tracks all responses that will be sent back to the
// Flink runtime after the full batch has been executed.
type runtime struct {
self Address
caller *Address
states map[string]*state
invocations []*messages.FromFunction_Invocation
delayedInvocation []*messages.FromFunction_DelayedInvocation
Expand All @@ -84,16 +68,8 @@ type runtime struct {

// Create a new runtime based on the target function
// and set of initial states.
func newStateFunIO(self *messages.Address, persistedValues []*messages.ToFunction_PersistedValue) *runtime {
func newStateFunIO(persistedValues []*messages.ToFunction_PersistedValue) *runtime {
ctx := &runtime{
self: Address{
FunctionType: FunctionType{
Namespace: self.Namespace,
Type: self.Type,
},
Id: self.Id,
},
caller: nil,
states: map[string]*state{},
invocations: []*messages.FromFunction_Invocation{},
delayedInvocation: []*messages.FromFunction_DelayedInvocation{},
Expand All @@ -116,14 +92,6 @@ func newStateFunIO(self *messages.Address, persistedValues []*messages.ToFunctio
return ctx
}

func (tracker *runtime) Self() Address {
return tracker.self
}

func (tracker *runtime) Caller() *Address {
return tracker.caller
}

func (tracker *runtime) Get(name string, state proto.Message) error {
packedState := tracker.states[name]
if packedState == nil {
Expand Down Expand Up @@ -159,7 +127,7 @@ func (tracker *runtime) Clear(name string) {
_ = tracker.Set(name, nil)
}

func (tracker *runtime) Send(target Address, message proto.Message) error {
func (tracker *runtime) Send(target *Address, message proto.Message) error {
if message == nil {
return errors.New("cannot send nil message to function")
}
Expand All @@ -182,15 +150,7 @@ func (tracker *runtime) Send(target Address, message proto.Message) error {
return nil
}

func (tracker *runtime) Reply(message proto.Message) error {
if tracker.caller == nil {
return errors.New("cannot reply to nil caller")
}

return tracker.Send(*tracker.caller, message)
}

func (tracker *runtime) SendAfter(target Address, duration time.Duration, message proto.Message) error {
func (tracker *runtime) SendAfter(target *Address, duration time.Duration, message proto.Message) error {
if message == nil {
return errors.New("cannot send nil message to function")
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/flink/statefun/stateful_functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package statefun

import (
"bytes"
"context"
"errors"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/sjwiesman/statefun-go/pkg/flink/statefun/internal/test"
"github.com/sjwiesman/statefun-go/pkg/flink/statefun/io"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -166,7 +168,7 @@ func TestValidation(t *testing.T) {

type Greeter struct{}

func (f Greeter) Invoke(runtime StatefulFunctionRuntime, msg *any.Any) error {
func (f Greeter) Invoke(ctx context.Context, runtime StatefulFunctionRuntime, msg *anypb.Any) error {
if err := ptypes.UnmarshalAny(msg, &test.Invoke{}); err != nil {
return err
}
Expand All @@ -182,11 +184,12 @@ func (f Greeter) Invoke(runtime StatefulFunctionRuntime, msg *any.Any) error {
Greeting: "Hello",
}

if err := runtime.Reply(greeting); err != nil {
caller := Caller(ctx)
if err := runtime.Send(caller, greeting); err != nil {
return err
}

if err := runtime.SendAfter(*runtime.Caller(), time.Duration(6e+10), greeting); err != nil {
if err := runtime.SendAfter(caller, time.Duration(6e+10), greeting); err != nil {
return err
}

Expand Down
14 changes: 8 additions & 6 deletions test/verification.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package main

import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"github.com/golang/protobuf/ptypes/any"
"net/http"
"github.com/sjwiesman/statefun-go/pkg/flink/statefun"
"github.com/sjwiesman/statefun-go/pkg/flink/statefun/io"
"google.golang.org/protobuf/types/known/anypb"
"net/http"
)

func randToken(n int) string {
Expand All @@ -20,7 +22,7 @@ func randToken(n int) string {

type CounterFunction struct{}

func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any.Any) error {
func (c CounterFunction) Invoke(ctx context.Context, runtime statefun.StatefulFunctionRuntime, msg *anypb.Any) error {
var count InvokeCount
if err := runtime.Get("invoke_count", &count); err != nil {
return fmt.Errorf("unable to deserialize invoke_count %w", err)
Expand All @@ -34,10 +36,10 @@ func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any

response := &InvokeResult{
InvokeCount: count.Count,
Id: runtime.Self().Id,
Id: statefun.Self(ctx).Id,
}

target := statefun.Address{
target := &statefun.Address{
FunctionType: statefun.FunctionType{
Namespace: "org.apache.flink.statefun.e2e.remote",
Type: "forward-function",
Expand All @@ -48,15 +50,15 @@ func (c CounterFunction) Invoke(runtime statefun.StatefulFunctionRuntime, _ *any
return runtime.Send(target, response)
}

func ForwardFunction(runtime statefun.StatefulFunctionRuntime, msg *any.Any) error {
func ForwardFunction(ctx context.Context, runtime statefun.StatefulFunctionRuntime, msg *any.Any) error {
egress := io.EgressIdentifier{
EgressNamespace: "org.apache.flink.statefun.e2e.remote",
EgressType: "invoke-results",
}

record := io.KafkaRecord{
Topic: "invoke-results",
Key: runtime.Self().Id,
Key: statefun.Self(ctx).Id,
Value: msg,
}

Expand Down

0 comments on commit aca3da2

Please sign in to comment.