From 53d61585f46d65f217e9f4bc8b8027bdea82ea11 Mon Sep 17 00:00:00 2001 From: Felicitas Pojtinger Date: Mon, 16 Oct 2023 16:18:22 +0200 Subject: [PATCH] feat: Introduce two separate request/response Frisbee operations --- cmd/dudirekta-example-frisbee-client/main.go | 90 ++++++------ cmd/dudirekta-example-frisbee-server/main.go | 122 ++++++++++------ pkg/rpc/registry.go | 146 ++++++++++++------- 3 files changed, 218 insertions(+), 140 deletions(-) diff --git a/cmd/dudirekta-example-frisbee-client/main.go b/cmd/dudirekta-example-frisbee-client/main.go index b632ff2..f2743bf 100644 --- a/cmd/dudirekta-example-frisbee-client/main.go +++ b/cmd/dudirekta-example-frisbee-client/main.go @@ -1,13 +1,11 @@ package main import ( - "bufio" "context" "flag" "fmt" "log" "net" - "os" "time" "github.com/loopholelabs/frisbee-go" @@ -17,7 +15,8 @@ import ( ) const ( - DUDIREKTA = uint16(10) + DUDIREKTA_REQUESTS = uint16(10) + DUDIREKTA_RESPONSES = uint16(11) ) type local struct{} @@ -66,57 +65,45 @@ func main() { ) go func() { - log.Println(`Enter one of the following letters followed by to run a function on the remote(s): - -- a: Increment remote counter by one -- b: Decrement remote counter by one`) - - stdin := bufio.NewReader(os.Stdin) - for { - line, err := stdin.ReadString('\n') - if err != nil { - panic(err) - } + for _, peer := range registry.Peers() { + new, err := peer.Increment(ctx, 1) + if err != nil { + log.Println("Got error for Increment func:", err) - for peerID, peer := range registry.Peers() { - log.Println("Calling functions for peer with ID", peerID) - - switch line { - case "a\n": - new, err := peer.Increment(ctx, 1) - if err != nil { - log.Println("Got error for Increment func:", err) - - continue - } - - log.Println(new) - case "b\n": - new, err := peer.Increment(ctx, -1) - if err != nil { - log.Println("Got error for Increment func:", err) + continue + } - continue - } + log.Println(new) - log.Println(new) - default: - log.Printf("Unknown letter %v, ignoring input", line) + new, err = peer.Increment(ctx, -1) + if err != nil { + log.Println("Got error for Increment func:", err) continue } + + log.Println(new) } } }() handlers := make(frisbee.HandlerTable) - packets := make(chan []byte) - handlers[DUDIREKTA] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { + requestPackets := make(chan []byte) + handlers[DUDIREKTA_REQUESTS] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { + b := make([]byte, incoming.Metadata.ContentLength) + copy(b, incoming.Content.Bytes()) + requestPackets <- b + + return + } + + responsePackets := make(chan []byte) + handlers[DUDIREKTA_RESPONSES] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { b := make([]byte, incoming.Metadata.ContentLength) copy(b, incoming.Content.Bytes()) - packets <- b + responsePackets <- b return } @@ -138,14 +125,32 @@ func main() { func(b []byte) error { pkg := packet.Get() - pkg.Metadata.Operation = DUDIREKTA + pkg.Metadata.Operation = DUDIREKTA_REQUESTS pkg.Content.Write(b) pkg.Metadata.ContentLength = uint32(pkg.Content.Len()) return client.WritePacket(pkg) }, + func(b []byte) error { + pkg := packet.Get() + + pkg.Metadata.Operation = DUDIREKTA_RESPONSES + pkg.Content.Write(b) + pkg.Metadata.ContentLength = uint32(pkg.Content.Len()) + + return client.WritePacket(pkg) + }, + + func() ([]byte, error) { + b, ok := <-requestPackets + if !ok { + return []byte{}, net.ErrClosed + } + + return b, nil + }, func() ([]byte, error) { - b, ok := <-packets + b, ok := <-responsePackets if !ok { return []byte{}, net.ErrClosed } @@ -159,5 +164,6 @@ func main() { <-client.CloseChannel() - close(packets) + close(requestPackets) + close(responsePackets) } diff --git a/cmd/dudirekta-example-frisbee-server/main.go b/cmd/dudirekta-example-frisbee-server/main.go index 95a0813..e79b470 100644 --- a/cmd/dudirekta-example-frisbee-server/main.go +++ b/cmd/dudirekta-example-frisbee-server/main.go @@ -1,12 +1,10 @@ package main import ( - "bufio" "context" "flag" "log" "net" - "os" "sync" "sync/atomic" "time" @@ -20,7 +18,8 @@ import ( type Key int const ( - DUDIREKTA = uint16(10) + DUDIREKTA_REQUESTS = uint16(10) + DUDIREKTA_RESPONSES = uint16(11) ConnIDKey Key = iota ) @@ -71,30 +70,10 @@ func main() { ) go func() { - log.Println(`Enter one of the following letters followed by to run a function on the remote(s): - -- a: Print "Hello, world!"`) - - stdin := bufio.NewReader(os.Stdin) - for { - line, err := stdin.ReadString('\n') - if err != nil { - panic(err) - } - - for peerID, peer := range registry.Peers() { - log.Println("Calling functions for peer with ID", peerID) - - switch line { - case "a\n": - if err := peer.Println(ctx, "Hello, world!"); err != nil { - log.Println("Got error for Println func:", err) - - continue - } - default: - log.Printf("Unknown letter %v, ignoring input", line) + for _, peer := range registry.Peers() { + if err := peer.Println(ctx, "Hello, world!"); err != nil { + log.Println("Got error for Println func:", err) continue } @@ -104,20 +83,42 @@ func main() { handlers := make(frisbee.HandlerTable) - var packetsLock sync.Mutex - packets := map[string]chan []byte{} + var requestPacketsLock sync.Mutex + requestPackets := map[string]chan []byte{} + + handlers[DUDIREKTA_REQUESTS] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { + connID := ctx.Value(ConnIDKey).(string) + + requestPacketsLock.Lock() + p, ok := requestPackets[connID] + if !ok { + p = make(chan []byte) + + requestPackets[connID] = p + } + requestPacketsLock.Unlock() + + b := make([]byte, incoming.Metadata.ContentLength) + copy(b, incoming.Content.Bytes()) + p <- b - handlers[DUDIREKTA] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { + return + } + + var responsePacketsLock sync.Mutex + responsePackets := map[string]chan []byte{} + + handlers[DUDIREKTA_RESPONSES] = func(ctx context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action frisbee.Action) { connID := ctx.Value(ConnIDKey).(string) - packetsLock.Lock() - p, ok := packets[connID] + responsePacketsLock.Lock() + p, ok := responsePackets[connID] if !ok { p = make(chan []byte) - packets[connID] = p + responsePackets[connID] = p } - packetsLock.Unlock() + responsePacketsLock.Unlock() b := make([]byte, incoming.Metadata.ContentLength) copy(b, incoming.Content.Bytes()) @@ -135,15 +136,25 @@ func main() { server.SetOnClosed(func(a *frisbee.Async, err error) { connID := a.RemoteAddr().String() - packetsLock.Lock() - defer packetsLock.Unlock() + requestPacketsLock.Lock() + defer requestPacketsLock.Unlock() + + rqp, ok := requestPackets[connID] + if !ok { + return + } + + close(rqp) + + responsePacketsLock.Lock() + defer responsePacketsLock.Unlock() - p, ok := packets[connID] + rsp, ok := responsePackets[connID] if !ok { return } - close(p) + close(rsp) }) server.ConnContext = func(ctx context.Context, a *frisbee.Async) context.Context { @@ -154,21 +165,48 @@ func main() { func(b []byte) error { pkg := packet.Get() - pkg.Metadata.Operation = DUDIREKTA + pkg.Metadata.Operation = DUDIREKTA_REQUESTS pkg.Content.Write(b) pkg.Metadata.ContentLength = uint32(pkg.Content.Len()) return a.WritePacket(pkg) }, + func(b []byte) error { + pkg := packet.Get() + + pkg.Metadata.Operation = DUDIREKTA_RESPONSES + pkg.Content.Write(b) + pkg.Metadata.ContentLength = uint32(pkg.Content.Len()) + + return a.WritePacket(pkg) + }, + + func() ([]byte, error) { + requestPacketsLock.Lock() + p, ok := requestPackets[connID] + if !ok { + p = make(chan []byte) + + requestPackets[connID] = p + } + requestPacketsLock.Unlock() + + b, ok := <-p + if !ok { + return []byte{}, net.ErrClosed + } + + return b, nil + }, func() ([]byte, error) { - packetsLock.Lock() - p, ok := packets[connID] + responsePacketsLock.Lock() + p, ok := responsePackets[connID] if !ok { p = make(chan []byte) - packets[connID] = p + responsePackets[connID] = p } - packetsLock.Unlock() + responsePacketsLock.Unlock() b, ok := <-p if !ok { diff --git a/pkg/rpc/registry.go b/pkg/rpc/registry.go index 45a504d..e700fd8 100644 --- a/pkg/rpc/registry.go +++ b/pkg/rpc/registry.go @@ -98,12 +98,12 @@ func (r Registry[R]) makeRPC( errs chan error, responseResolver *broadcast.Relay[response], - write func(b []byte) error, + writeRequest func(b []byte) error, ) reflect.Value { return reflect.MakeFunc(functionType, func(args []reflect.Value) (results []reflect.Value) { callID := uuid.NewString() - cmd := []any{true, callID, name} + cmd := []any{callID, name} cmdArgs := []any{} for i, arg := range args { @@ -166,7 +166,7 @@ func (r Registry[R]) makeRPC( } }() - if err := write(b); err != nil { + if err := writeRequest(b); err != nil { errs <- err return @@ -218,8 +218,10 @@ func (r Registry[R]) makeRPC( } func (r Registry[R]) Link( - write func(b []byte) error, - read func() ([]byte, error), + writeRequest, + writeResponse func(b []byte) error, + readRequest, + readResponse func() ([]byte, error), ) error { responseResolver := broadcast.NewRelay[response]() @@ -267,7 +269,7 @@ func (r Registry[R]) Link( functionType, errs, responseResolver, - write, + writeRequest, )) } @@ -291,52 +293,50 @@ func (r Registry[R]) Link( r.remotesLock.Unlock() }() - for { - b, err := read() - if err != nil { - errs <- err + var wg sync.WaitGroup - return - } + wg.Add(1) + go func() { + defer wg.Done() - var res []json.RawMessage - if err := json.Unmarshal(b, &res); err != nil { - errs <- err + for { + b, err := readRequest() + if err != nil { + errs <- err - return - } + return + } - if len(res) != 4 { - errs <- ErrInvalidRequest + var res []json.RawMessage + if err := json.Unmarshal(b, &res); err != nil { + errs <- err - return - } + return + } - var isCall bool - if err := json.Unmarshal(res[0], &isCall); err != nil { - errs <- err + if len(res) != 3 { + errs <- ErrInvalidRequest - return - } + return + } - var callID string - if err := json.Unmarshal(res[1], &callID); err != nil { - errs <- err + var callID string + if err := json.Unmarshal(res[0], &callID); err != nil { + errs <- err - return - } + return + } - if isCall { go func() { var functionName string - if err := json.Unmarshal(res[2], &functionName); err != nil { + if err := json.Unmarshal(res[1], &functionName); err != nil { errs <- err return } var functionArgs []json.RawMessage - if err := json.Unmarshal(res[3], &functionArgs); err != nil { + if err := json.Unmarshal(res[2], &functionArgs); err != nil { errs <- err return @@ -388,7 +388,7 @@ func (r Registry[R]) Link( reflect.TypeOf(callClosureType(nil)), errs, responseResolver, - write, + writeRequest, ) rpcArgs := []interface{}{} @@ -439,28 +439,28 @@ func (r Registry[R]) Link( switch len(res) { case 0: - b, err := json.Marshal([]interface{}{false, callID, nil, ""}) + b, err := json.Marshal([]interface{}{callID, nil, ""}) if err != nil { errs <- err return } - if err := write(b); err != nil { + if err := writeResponse(b); err != nil { errs <- err return } case 1: if res[0].Type().Implements(errorType) && !res[0].IsNil() { - b, err := json.Marshal([]interface{}{false, callID, nil, res[0].Interface().(error).Error()}) + b, err := json.Marshal([]interface{}{callID, nil, res[0].Interface().(error).Error()}) if err != nil { errs <- err return } - if err := write(b); err != nil { + if err := writeResponse(b); err != nil { errs <- err return @@ -473,14 +473,14 @@ func (r Registry[R]) Link( return } - b, err := json.Marshal([]interface{}{false, callID, json.RawMessage(string(v)), ""}) + b, err := json.Marshal([]interface{}{callID, json.RawMessage(string(v)), ""}) if err != nil { errs <- err return } - if err := write(b); err != nil { + if err := writeResponse(b); err != nil { errs <- err return @@ -495,27 +495,27 @@ func (r Registry[R]) Link( } if res[1].Interface() == nil { - b, err := json.Marshal([]interface{}{false, callID, json.RawMessage(string(v)), ""}) + b, err := json.Marshal([]interface{}{callID, json.RawMessage(string(v)), ""}) if err != nil { errs <- err return } - if err := write(b); err != nil { + if err := writeResponse(b); err != nil { errs <- err return } } else { - b, err := json.Marshal([]interface{}{false, callID, json.RawMessage(string(v)), res[1].Interface().(error).Error()}) + b, err := json.Marshal([]interface{}{callID, json.RawMessage(string(v)), res[1].Interface().(error).Error()}) if err != nil { errs <- err return } - if err := write(b); err != nil { + if err := writeResponse(b); err != nil { errs <- err return @@ -524,23 +524,57 @@ func (r Registry[R]) Link( } }() }() - - continue } + }() - var errMsg string - if err := json.Unmarshal(res[3], &errMsg); err != nil { - errs <- err + wg.Add(1) + go func() { + defer wg.Done() - return - } + for { + b, err := readResponse() + if err != nil { + errs <- err - if strings.TrimSpace(errMsg) != "" { - err = errors.New(errMsg) + return + } + + var res []json.RawMessage + if err := json.Unmarshal(b, &res); err != nil { + errs <- err + + return + } + + if len(res) != 3 { + errs <- ErrInvalidRequest + + return + } + + var callID string + if err := json.Unmarshal(res[0], &callID); err != nil { + errs <- err + + return + } + + var errMsg string + if err := json.Unmarshal(res[2], &errMsg); err != nil { + errs <- err + + return + } + + if strings.TrimSpace(errMsg) != "" { + err = errors.New(errMsg) + } + + responseResolver.Broadcast(response{callID, res[1], err, false}) } + }() - responseResolver.Broadcast(response{callID, res[2], err, false}) - } + wg.Wait() }() return <-errs