diff --git a/.gitignore b/.gitignore index aa6673e..4055cb2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ suppressions/ .idea cmd/rsocket-cli/rsocket-cli + +coverage.out diff --git a/.travis.yml b/.travis.yml index 4af9046..87e0d0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,14 +3,13 @@ language: go go: - 1.x -env: - - GO111MODULE=on - -before_install: - - go get -u golang.org/x/lint/golint +install: + - go get golang.org/x/lint/golint - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.27.0 + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls script: -# - golint ./... - golangci-lint run ./... - - go test -race -count=1 . -v + - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 ./... + - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/README.md b/README.md index 5822ad6..16db495 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,10 @@ ![logo](./logo.jpg) [![Build Status](https://travis-ci.com/rsocket/rsocket-go.svg?branch=master)](https://travis-ci.com/rsocket/rsocket-go) +[![Coverage Status](https://coveralls.io/repos/github/rsocket/rsocket-go/badge.svg?branch=master)](https://coveralls.io/github/rsocket/rsocket-go?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/rsocket/rsocket-go)](https://goreportcard.com/report/github.com/rsocket/rsocket-go) [![Slack](https://img.shields.io/badge/slack-rsocket--go-blue.svg)](https://rsocket.slack.com/messages/C9VGZ5MV3) [![GoDoc](https://godoc.org/github.com/rsocket/rsocket-go?status.svg)](https://godoc.org/github.com/rsocket/rsocket-go) -[![Go Report Card](https://goreportcard.com/badge/github.com/rsocket/rsocket-go)](https://goreportcard.com/report/github.com/rsocket/rsocket-go) [![License](https://img.shields.io/github/license/rsocket/rsocket-go.svg)](https://github.com/rsocket/rsocket-go/blob/master/LICENSE) [![GitHub Release](https://img.shields.io/github/release-pre/rsocket/rsocket-go.svg)](https://github.com/rsocket/rsocket-go/releases) @@ -28,6 +29,7 @@ package main import ( "context" + "log" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/payload" @@ -46,10 +48,11 @@ func main() { }), ), nil }). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). Serve(context.Background()) - panic(err) + log.Fatalln(err) } + ``` > Connect to echo server @@ -71,7 +74,7 @@ func main() { Resume(). Fragment(1024). SetupPayload(payload.NewString("Hello", "World")). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()). Start(context.Background()) if err != nil { panic(err) @@ -131,12 +134,13 @@ func main() { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnSuccess(func(input payload.Payload) { + DoOnSuccess(func(input payload.Payload) error { // Handle and consume payload. // Do something here... fmt.Println("bingo:", input) + return nil }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) <-done @@ -155,9 +159,9 @@ import ( "context" "fmt" - flxx "github.com/jjeffcaii/reactor-go/flux" "github.com/rsocket/rsocket-go/extension" "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" ) @@ -177,24 +181,27 @@ func main() { // payload.NewString("qux", extension.TextPlain.String()), //) - f. - DoOnNext(func(elem payload.Payload) { + // Block + _, _ = f. + DoOnNext(func(elem payload.Payload) error { // Handle and consume elements // Do something here... fmt.Println("bingo:", elem) + return nil }). - Subscribe(context.Background()) + BlockLast(context.Background()) - // Or you can use Raw reactor-go API. :-D - f2 := flux.Raw(flxx.Range(0, 10).Map(func(i interface{}) interface{} { - return payload.NewString(fmt.Sprintf("Hello@%d", i.(int)), extension.TextPlain.String()) + // Subscribe + f.Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) error { + fmt.Println("bingo:", input) + return nil })) - f2. - DoOnNext(func(input payload.Payload) { - fmt.Println("bingo:", input) - }). - BlockLast(context.Background()) + + // Or implement your own subscriber + var s rx.Subscriber + f.SubscribeWith(context.Background(), s) } + ``` #### Backpressure & RequestN @@ -237,65 +244,18 @@ func main() { su = s su.Request(1) }), - rx.OnNext(func(elem payload.Payload) { + rx.OnNext(func(elem payload.Payload) error { // Consume element, do something... fmt.Println("bingo:", elem) // Request for next one manually. su.Request(1) + return nil }), ) } - -``` - -#### Logging - -We do not use a specific log implementation. You can register your own log implementation. For example: - -```go -package main - -import ( - "log" - - "github.com/rsocket/rsocket-go/logger" -) - -func init() { - logger.SetFunc(logger.LevelDebug|logger.LevelInfo|logger.LevelWarn|logger.LevelError, func(template string, args ...interface{}) { - // Implement your own logger here... - log.Printf(template, args...) - }) - logger.SetLevel(logger.LevelInfo) -} - ``` #### Dependencies - [reactor-go](https://github.com/jjeffcaii/reactor-go) - [testify](https://github.com/stretchr/testify) - [websocket](https://github.com/gorilla/websocket) - -### TODO - -#### Transport - - [x] TCP - - [x] Websocket - -#### Duplex Socket - - [x] MetadataPush - - [x] RequestFNF - - [x] RequestResponse - - [x] RequestStream - - [x] RequestChannel - -##### Others - - [x] Resume - - [x] Keepalive - - [x] Fragmentation - - [x] Thin Reactor - - [x] Cancel - - [x] Error - - [x] Flow Control: RequestN - - [x] Flow Control: Lease - - [x] Load Balance diff --git a/balancer/balancer.go b/balancer/balancer.go index ba20b6c..8c9b83b 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -2,6 +2,7 @@ package balancer import ( + "context" "io" "github.com/rsocket/rsocket-go" @@ -11,11 +12,11 @@ import ( type Balancer interface { io.Closer // Put puts a new client. - Put(client rsocket.Client) + Put(client rsocket.Client) error // PutLabel puts a new client with a label. - PutLabel(label string, client rsocket.Client) + PutLabel(label string, client rsocket.Client) error // Next returns next balanced RSocket client. - Next() rsocket.Client + Next(context.Context) (rsocket.Client, bool) // OnLeave handle events when a client exit. OnLeave(fn func(label string)) } diff --git a/balancer/group.go b/balancer/group.go index dcb5ef8..f447dbb 100644 --- a/balancer/group.go +++ b/balancer/group.go @@ -13,7 +13,8 @@ var errGroupClosed = errors.New("balancer group has been closed") // Group can be used to create a simple RSocket Broker. type Group struct { g func() Balancer - m *sync.Map + l sync.Mutex + m map[string]Balancer } // Close close current RSocket group. @@ -33,10 +34,12 @@ func (p *Group) Close() (err error) { } } }(all, done) - p.m.Range(func(key, value interface{}) bool { - all <- value.(Balancer) - return true - }) + + p.l.Lock() + defer p.l.Unlock() + for _, b := range p.m { + all <- b + } p.m = nil close(all) <-done @@ -45,24 +48,23 @@ func (p *Group) Close() (err error) { // Get returns a Balancer with custom id. func (p *Group) Get(id string) Balancer { + p.l.Lock() + defer p.l.Unlock() if p.m == nil { panic(errGroupClosed) } - if actual, ok := p.m.Load(id); ok { - return actual.(Balancer) + if actual, ok := p.m[id]; ok { + return actual } newborn := p.g() - actual, loaded := p.m.LoadOrStore(id, newborn) - if loaded { - _ = newborn.Close() - } - return actual.(Balancer) + p.m[id] = newborn + return newborn } // NewGroup returns a new Group. func NewGroup(gen func() Balancer) *Group { return &Group{ g: gen, - m: &sync.Map{}, + m: make(map[string]Balancer), } } diff --git a/balancer/group_example_test.go b/balancer/group_example_test.go new file mode 100644 index 0000000..01b4950 --- /dev/null +++ b/balancer/group_example_test.go @@ -0,0 +1,46 @@ +package balancer + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx/mono" +) + +func ExampleNewGroup() { + group := NewGroup(func() Balancer { + return NewRoundRobinBalancer() + }) + defer func() { + _ = group.Close() + }() + // Create a broker with resume. + err := rsocket.Receive(). + Resume(rsocket.WithServerResumeSessionDuration(10 * time.Second)). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { + // Register service using Setup Metadata as service ID. + if serviceID, ok := setup.MetadataUTF8(); ok { + _ = group.Get(serviceID).Put(sendingSocket) + } + // Proxy requests by group. + return rsocket.NewAbstractSocket(rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + requestServiceID, ok := msg.MetadataUTF8() + if !ok { + panic(errors.New("missing service ID in metadata")) + } + fmt.Println("[broker] redirect request to service", requestServiceID) + upstream, _ := group.Get(requestServiceID).Next(context.Background()) + fmt.Println("[broker] choose upstream:", upstream) + return upstream.RequestResponse(msg) + })), nil + }). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). + Serve(context.Background()) + if err != nil { + panic(err) + } +} diff --git a/balancer/group_test.go b/balancer/group_test.go index 9099900..6b45f98 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -1,103 +1,23 @@ package balancer_test import ( - "context" - "crypto/md5" - "errors" - "fmt" - "log" "testing" - "time" - . "github.com/rsocket/rsocket-go" - . "github.com/rsocket/rsocket-go/balancer" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx/mono" - "github.com/stretchr/testify/require" + "github.com/rsocket/rsocket-go/balancer" + "github.com/stretchr/testify/assert" ) -const uri = "tcp://127.0.0.1:7878" +var fakeGroupId = "fakeGroupId" -func ExampleNewGroup() { - group := NewGroup(func() Balancer { - return NewRoundRobinBalancer() +func TestGroup_Get(t *testing.T) { + called := 0 + g := balancer.NewGroup(func() balancer.Balancer { + called++ + return balancer.NewRoundRobinBalancer() }) - defer func() { - _ = group.Close() - }() - // Create a broker with resume. - err := Receive(). - Resume(WithServerResumeSessionDuration(10 * time.Second)). - Acceptor(func(setup payload.SetupPayload, sendingSocket CloseableRSocket) (RSocket, error) { - // Register service using Setup Metadata as service ID. - if serviceID, ok := setup.MetadataUTF8(); ok { - group.Get(serviceID).Put(sendingSocket) - } - // Proxy requests by group. - return NewAbstractSocket(RequestResponse(func(msg payload.Payload) mono.Mono { - requestServiceID, ok := msg.MetadataUTF8() - if !ok { - panic(errors.New("missing service ID in metadata")) - } - log.Println("[broker] redirect request to service", requestServiceID) - return group.Get(requestServiceID).Next().RequestResponse(msg) - })), nil - }). - Transport(uri). - Serve(context.Background()) - if err != nil { - panic(err) + for range [2]struct{}{} { + b := g.Get(fakeGroupId) + assert.NotNil(t, b) + assert.Equal(t, 1, called) } } - -func TestServiceSubscribe(t *testing.T) { - // Init broker and service. - go ExampleNewGroup() - - // Waiting broker up by sleeping 200 ms. - time.Sleep(200 * time.Millisecond) - - // Deploy MD5 service. - go func() { - done := make(chan struct{}) - cli, err := Connect(). - OnClose(func(err error) { - close(done) - }). - SetupPayload(payload.NewString("This is a Service Publisher!", "md5")). - Acceptor(func(socket RSocket) RSocket { - return NewAbstractSocket(RequestResponse(func(msg payload.Payload) mono.Mono { - result := payload.NewString(fmt.Sprintf("%02x", md5.Sum(msg.Data())), "MD5 RESULT") - log.Println("[publisher] accept MD5 request:", msg.DataUTF8()) - return mono.Just(result) - })) - }). - Transport(uri). - Start(context.Background()) - if err != nil { - panic(err) - } - defer func() { - _ = cli.Close() - }() - <-done - }() - - // Create a client and request md5 service. - cli, err := Connect(). - SetupPayload(payload.NewString("This is a Subscriber", "")). - Transport(uri). - Start(context.Background()) - require.NoError(t, err, "create client failed") - defer func() { - _ = cli.Close() - time.Sleep(200 * time.Millisecond) - }() - _, err = cli.RequestResponse(payload.NewString("Hello World!", "md5")). - DoOnSuccess(func(elem payload.Payload) { - log.Println("[subscriber] receive MD5 response:", elem.DataUTF8()) - require.Equal(t, "ed076287532e86365e841e92bfc50d8c", elem.DataUTF8(), "bad md5") - }). - Block(context.Background()) - require.NoError(t, err, "request failed") -} diff --git a/balancer/round_robin.go b/balancer/round_robin.go index fd46193..8c9bf55 100644 --- a/balancer/round_robin.go +++ b/balancer/round_robin.go @@ -1,115 +1,121 @@ package balancer import ( + "context" + "runtime" "sync" "github.com/google/uuid" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/logger" + "go.uber.org/atomic" ) -type labelClient struct { - l string - c rsocket.Client -} +var errConflictSocket = errors.New("socket exists already") type balancerRoundRobin struct { - cond *sync.Cond - seq int - clients []*labelClient + seq *atomic.Uint32 + keys []string + sockets []rsocket.Client done chan struct{} once sync.Once onLeave []func(string) + c *common.Cond } -func (p *balancerRoundRobin) OnLeave(fn func(label string)) { +func (b *balancerRoundRobin) OnLeave(fn func(label string)) { if fn != nil { - p.onLeave = append(p.onLeave, fn) + b.onLeave = append(b.onLeave, fn) } } -func (p *balancerRoundRobin) Put(client rsocket.Client) { - label := uuid.New().String() - p.PutLabel(label, client) +func (b *balancerRoundRobin) Put(client rsocket.Client) error { + return b.PutLabel(uuid.New().String(), client) } -func (p *balancerRoundRobin) PutLabel(label string, client rsocket.Client) { - p.cond.L.Lock() - p.clients = append(p.clients, &labelClient{ - l: label, - c: client, - }) - client.OnClose(func(error) { - p.remove(client) - }) - if len(p.clients) == 1 { - p.cond.Broadcast() +func (b *balancerRoundRobin) PutLabel(label string, client rsocket.Client) error { + b.c.L.Lock() + defer b.c.L.Unlock() + for _, k := range b.keys { + if k == label { + return errConflictSocket + } + } + b.keys = append(b.keys, label) + b.sockets = append(b.sockets, client) + if n := len(b.sockets); n == 1 { + b.c.Broadcast() } - p.cond.L.Unlock() + client.OnClose(func(err error) { + b.remove(client) + }) + return nil } -func (p *balancerRoundRobin) Next() (c rsocket.Client) { - p.cond.L.Lock() - for len(p.clients) < 1 { - select { - case <-p.done: - goto L - default: - p.cond.Wait() +func (b *balancerRoundRobin) Next(ctx context.Context) (client rsocket.Client, ok bool) { + b.c.L.Lock() + for { + n := len(b.keys) + if n > 0 { + idx := int(b.seq.Inc() % uint32(n)) + client = b.sockets[idx] + ok = true + break + } + if b.c.Wait(ctx) { + break } + b.c.L.Unlock() + runtime.Gosched() + b.c.L.Lock() } - c = p.choose() -L: - p.cond.L.Unlock() - return -} - -func (p *balancerRoundRobin) choose() (cli rsocket.Client) { - p.seq = (p.seq + 1) % len(p.clients) - cli = p.clients[p.seq].c + b.c.L.Unlock() return } -func (p *balancerRoundRobin) Close() (err error) { - p.once.Do(func() { - if len(p.clients) < 1 { +func (b *balancerRoundRobin) Close() (err error) { + b.once.Do(func() { + if len(b.sockets) < 1 { return } - clone := append([]*labelClient(nil), p.clients...) - close(p.done) + clone := append([]rsocket.Client(nil), b.sockets...) + close(b.done) wg := &sync.WaitGroup{} wg.Add(len(clone)) - for _, value := range clone { + for i := 0; i < len(clone); i++ { go func(c rsocket.Client, wg *sync.WaitGroup) { defer wg.Done() if err := c.Close(); err != nil { logger.Warnf("close client failed: %s\n", err) } - }(value.c, wg) + }(clone[i], wg) } wg.Wait() }) return } -func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok bool) { - p.cond.L.Lock() +func (b *balancerRoundRobin) remove(client rsocket.Client) (label string, ok bool) { + b.c.L.Lock() j := -1 - for i, l := 0, len(p.clients); i < l; i++ { - if p.clients[i].c == client { + for i, l := 0, len(b.sockets); i < l; i++ { + if b.sockets[i] == client { j = i break } } ok = j > -1 if ok { - label = p.clients[j].l - p.clients = append(p.clients[:j], p.clients[j+1:]...) + label = b.keys[j] + b.keys = append(b.keys[:j], b.keys[j+1:]...) + b.sockets = append(b.sockets[:j], b.sockets[j+1:]...) } - p.cond.L.Unlock() - if ok && len(p.onLeave) > 0 { + b.c.L.Unlock() + if ok && len(b.onLeave) > 0 { go func(label string) { - for _, fn := range p.onLeave { + for _, fn := range b.onLeave { fn(label) } }(label) @@ -120,8 +126,8 @@ func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok boo // NewRoundRobinBalancer returns a new Round-Robin Balancer. func NewRoundRobinBalancer() Balancer { return &balancerRoundRobin{ - cond: sync.NewCond(&sync.Mutex{}), - seq: -1, + seq: atomic.NewUint32(0), done: make(chan struct{}), + c: common.NewCond(&sync.Mutex{}), } } diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index 32c9832..f4a2d03 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -9,52 +9,144 @@ import ( "time" "github.com/jjeffcaii/reactor-go/scheduler" - . "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go" . "github.com/rsocket/rsocket-go/balancer" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/mono" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func startServer(ctx context.Context, port int, counter *sync.Map) { + _ = rsocket.Receive(). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { + return rsocket.NewAbstractSocket( + rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + cur, _ := counter.LoadOrStore(port, atomic.NewInt32(0)) + cur.(*atomic.Int32).Inc() + return mono.Just(msg) + }), + ), nil + }). + Transport(rsocket.TcpServer().SetHostAndPort("127.0.0.1", port).Build()). + Serve(ctx) +} + func TestRoundRobin(t *testing.T) { + ports := [3]int{7000, 7001, 7002} + counter := &sync.Map{} + ctx0, cancel0 := context.WithCancel(context.Background()) + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx0, ports[0]) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx1, ports[1]) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx2, ports[2]) + + time.Sleep(1 * time.Second) + b := NewRoundRobinBalancer() b.OnLeave(func(label string) { log.Println("client leave:", label) }) - defer func() { - _ = b.Close() - }() - - const x, y = 3, 1000 - wg := &sync.WaitGroup{} - wg.Add(x * y) - for i := 0; i < x; i++ { - go func(n int) { - for j := 0; j < y; j++ { - b.Next().RequestResponse(payload.NewString(fmt.Sprintf("GO_%04d_%04d", n, j), "go")). - DoOnSuccess(func(elem payload.Payload) { - m, _ := elem.MetadataUTF8() - log.Println("elem:", elem.DataUTF8(), m) - }). - DoFinally(func(st rx.SignalType) { - wg.Done() - }). - SubscribeOn(scheduler.Elastic()). - Subscribe(context.Background()) - time.Sleep(1 * time.Second) - } - }(i) + defer b.Close() + + for i := 0; i < len(ports); i++ { + client, err := rsocket.Connect(). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", ports[i]).Build()). + Start(context.Background()) + assert.NoError(t, err) + _ = b.PutLabel(fmt.Sprintf("test-client-%d", ports[i]), client) } - for _, port := range []int{17878, 8000, 8001, 8002} { - go func(uri string) { - c, err := Connect(). - SetupPayload(payload.NewString(uri, "hello")). - Transport(uri). - Start(context.Background()) - if err == nil { - b.PutLabel(uri, c) - } - }(fmt.Sprintf("tcp://127.0.0.1:%d", port)) + + req := payload.NewString("foo", "bar") + + const n = 3 + wg := sync.WaitGroup{} + wg.Add(n * len(ports)) + for i := 0; i < n*len(ports); i++ { + c, ok := b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + c.RequestResponse(req). + DoFinally(func(s rx.SignalType) { + wg.Done() + }). + DoOnError(func(e error) { + assert.Fail(t, "should never run here") + }). + SubscribeOn(scheduler.Parallel()). + Subscribe(context.Background()) } wg.Wait() + + counter.Range(func(key, value interface{}) bool { + v := int(value.(*atomic.Int32).Load()) + assert.Equal(t, n, v) + return true + }) + + ac0, ok := counter.Load(ports[0]) + assert.True(t, ok) + + amount0 := ac0.(*atomic.Int32).Load() + + // shutdown server 1 + cancel0() + time.Sleep(100 * time.Millisecond) + + // then send a request + c, ok := b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err := c.RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + + var total int + counter.Range(func(key, value interface{}) bool { + total += int(value.(*atomic.Int32).Load()) + return true + }) + assert.Equal(t, n*len(ports)+1, total) + assert.Equal(t, int32(0), ac0.(*atomic.Int32).Load()-amount0) + + c, ok = b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err = c.RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + total++ + + // shutdown server 2 + cancel1() + time.Sleep(100 * time.Millisecond) + + const extra = 10 + + for i := 0; i < extra; i++ { + c, ok = b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err = c.RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + } + total += 10 + + requestedActual := 0 + counter.Range(func(key, value interface{}) bool { + requestedActual += int(value.(*atomic.Int32).Load()) + return true + }) + assert.Equal(t, total, requestedActual) + + cancel2() + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, ok = b.Next(ctx) + assert.False(t, ok) } diff --git a/client.go b/client.go index b9bb5ec..d08866d 100644 --- a/client.go +++ b/client.go @@ -2,14 +2,14 @@ package rsocket import ( "context" - "crypto/tls" "time" "github.com/google/uuid" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/internal/fragmentation" "github.com/rsocket/rsocket-go/internal/socket" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/payload" ) @@ -21,113 +21,78 @@ var ( type ( // ClientResumeOptions represents resume options for client. ClientResumeOptions func(opts *resumeOpts) - - // Client is Client Side of a RSocket socket. Sends Frames to a RSocket Server. - Client interface { - CloseableRSocket - } - - // ClientSocketAcceptor is alias for RSocket handler function. - ClientSocketAcceptor = func(socket RSocket) RSocket - - // ClientStarter can be used to start a client. - ClientStarter interface { - // Start start a client socket. - Start(ctx context.Context) (Client, error) - // Start start a client socket with TLS. - // Here's an example: - // tc := &tls.Config { - // InsecureSkipVerify: true, - // } - StartTLS(ctx context.Context, tc *tls.Config) (Client, error) - } - - // ClientBuilder can be used to build a RSocket client. - ClientBuilder interface { - ClientTransportBuilder - // Fragment set fragmentation size which default is 16_777_215(16MB). - Fragment(mtu int) ClientBuilder - // KeepAlive defines current client keepalive settings. - KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder - // Resume enable the functionality of resume. - Resume(opts ...ClientResumeOptions) ClientBuilder - // Lease enable the functionality of lease. - Lease() ClientBuilder - // DataMimeType is used to set payload data MIME type. - // Default MIME type is `application/binary`. - DataMimeType(mime string) ClientBuilder - // MetadataMimeType is used to set payload metadata MIME type. - // Default MIME type is `application/binary`. - MetadataMimeType(mime string) ClientBuilder - // SetupPayload set the setup payload. - SetupPayload(setup payload.Payload) ClientBuilder - // OnClose register handler when client socket closed. - OnClose(fn func(error)) ClientBuilder - // Acceptor set acceptor for RSocket client. - Acceptor(acceptor ClientSocketAcceptor) ClientTransportBuilder - } - - // ClientTransportBuilder is used to build a RSocket client with custom Transport string. - ClientTransportBuilder interface { - // Transport set Transport for current RSocket client. - // URI is used to create RSocket Transport: - // Example: - // "tcp://127.0.0.1:7878" means a TCP RSocket transport. - // "ws://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport. - // "wss://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport with HTTPS. - Transport(uri string, opts ...TransportOpts) ClientStarter - } - - setupClientSocket interface { - Client - Setup(ctx context.Context, setup *socket.SetupInfo) error - } ) -// Connect create a new RSocket client builder with default settings. -func Connect() ClientBuilder { - return &implClientBuilder{ - fragment: fragmentation.MaxFragment, - setup: &socket.SetupInfo{ - Version: common.DefaultVersion, - KeepaliveInterval: common.DefaultKeepaliveInterval, - KeepaliveLifetime: common.DefaultKeepaliveMaxLifetime, - DataMimeType: _defaultMimeType, - MetadataMimeType: _defaultMimeType, - }, - } -} - -type transportOpts struct { - addr string - headers map[string][]string -} - -// WithWebsocketHeaders attach headers for websocket transport. -func WithWebsocketHeaders(headers map[string][]string) TransportOpts { - return func(opts *transportOpts) { - opts.headers = headers - } -} - -// TransportOpts represents options of transport. -type TransportOpts = func(*transportOpts) - -type implClientBuilder struct { +// Client is Client Side of a RSocket socket. Sends Frames to a RSocket Server. +type Client interface { + CloseableRSocket +} + +// ClientSocketAcceptor is alias for RSocket handler function. +type ClientSocketAcceptor = func(socket RSocket) RSocket + +// ClientStarter can be used to start a client. +type ClientStarter interface { + // Start start a client socket. + Start(ctx context.Context) (Client, error) +} + +// ClientBuilder can be used to build a RSocket client. +type ClientBuilder interface { + ToClientStarter + // Fragment set fragmentation size which default is 16_777_215(16MB). + // Also zero mtu means using default fragmentation size. + Fragment(mtu int) ClientBuilder + // KeepAlive defines current client keepalive settings. + KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder + // Resume enable the functionality of resume. + Resume(opts ...ClientResumeOptions) ClientBuilder + // Lease enable the functionality of lease. + Lease() ClientBuilder + // DataMimeType is used to set payload data MIME type. + // Default MIME type is `application/binary`. + DataMimeType(mime string) ClientBuilder + // MetadataMimeType is used to set payload metadata MIME type. + // Default MIME type is `application/binary`. + MetadataMimeType(mime string) ClientBuilder + // SetupPayload set the setup payload. + SetupPayload(setup payload.Payload) ClientBuilder + // OnClose register handler when client socket closed. + OnClose(fn func(error)) ClientBuilder + // Acceptor set acceptor for RSocket client. + Acceptor(acceptor ClientSocketAcceptor) ToClientStarter +} + +type ToClientStarter interface { + // Transport set generator func for current RSocket client. + // Example: + // "tcp://127.0.0.1:7878" means a TCP RSocket transport. + // "ws://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport. + // "wss://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport with HTTPS. + Transport(transport.ClientTransportFunc) ClientStarter +} + +// ToClientStarter is used to build a RSocket client with custom Transport string. +type setupClientSocket interface { + Client + Setup(ctx context.Context, setup *socket.SetupInfo) error +} + +type clientBuilder struct { resume *resumeOpts fragment int - tpOpts *transportOpts + tpGen transport.ClientTransportFunc setup *socket.SetupInfo acceptor ClientSocketAcceptor onCloses []func(error) } -func (p *implClientBuilder) Lease() ClientBuilder { +func (p *clientBuilder) Lease() ClientBuilder { p.setup.Lease = true return p } -func (p *implClientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { +func (p *clientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { if p.resume == nil { p.resume = newResumeOpts() } @@ -137,33 +102,37 @@ func (p *implClientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { return p } -func (p *implClientBuilder) Fragment(mtu int) ClientBuilder { - p.fragment = mtu +func (p *clientBuilder) Fragment(mtu int) ClientBuilder { + if mtu == 0 { + p.fragment = fragmentation.MaxFragment + } else { + p.fragment = mtu + } return p } -func (p *implClientBuilder) OnClose(fn func(error)) ClientBuilder { +func (p *clientBuilder) OnClose(fn func(error)) ClientBuilder { p.onCloses = append(p.onCloses, fn) return p } -func (p *implClientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { +func (p *clientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { p.setup.KeepaliveInterval = tickPeriod p.setup.KeepaliveLifetime = time.Duration(missedAcks) * ackTimeout return p } -func (p *implClientBuilder) DataMimeType(mime string) ClientBuilder { +func (p *clientBuilder) DataMimeType(mime string) ClientBuilder { p.setup.DataMimeType = []byte(mime) return p } -func (p *implClientBuilder) MetadataMimeType(mime string) ClientBuilder { +func (p *clientBuilder) MetadataMimeType(mime string) ClientBuilder { p.setup.MetadataMimeType = []byte(mime) return p } -func (p *implClientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { +func (p *clientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { p.setup.Data = nil p.setup.Metadata = nil @@ -178,57 +147,34 @@ func (p *implClientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { return p } -func (p *implClientBuilder) Acceptor(acceptor ClientSocketAcceptor) ClientTransportBuilder { +func (p *clientBuilder) Acceptor(acceptor ClientSocketAcceptor) ToClientStarter { p.acceptor = acceptor return p } -func (p *implClientBuilder) Transport(transport string, opts ...TransportOpts) ClientStarter { - p.tpOpts = &transportOpts{ - addr: transport, - } - for i := 0; i < len(opts); i++ { - opts[i](p.tpOpts) - } +func (p *clientBuilder) Transport(t transport.ClientTransportFunc) ClientStarter { + p.tpGen = t return p } -func (p *implClientBuilder) StartTLS(ctx context.Context, tc *tls.Config) (Client, error) { - return p.start(ctx, tc) -} - -func (p *implClientBuilder) Start(ctx context.Context) (client Client, err error) { - return p.start(ctx, nil) -} - -func (p *implClientBuilder) start(ctx context.Context, tc *tls.Config) (client Client, err error) { - var uri *transport.URI - uri, err = transport.ParseURI(p.tpOpts.addr) - if err != nil { - return - } - +func (p *clientBuilder) Start(ctx context.Context) (client Client, err error) { // create a blank socket. err = fragmentation.IsValidFragment(p.fragment) if err != nil { return nil, err } - sk := socket.NewClientDuplexRSocket( + sk := socket.NewClientDuplexConnection( p.fragment, p.setup.KeepaliveInterval, ) - var headers map[string][]string - if uri.IsWebsocket() { - headers = p.tpOpts.headers - } // create a client. var cs setupClientSocket if p.resume != nil { p.setup.Token = p.resume.tokenGen() - cs = socket.NewClientResume(uri, sk, tc, headers) + cs = socket.NewResumableClientSocket(p.tpGen, sk) } else { - cs = socket.NewClient(uri, sk, tc, headers) + cs = socket.NewClient(p.tpGen, sk) } if p.acceptor != nil { sk.SetResponder(p.acceptor(cs)) @@ -272,3 +218,17 @@ func WithClientResumeToken(gen func() []byte) ClientResumeOptions { opts.tokenGen = gen } } + +// Connect create a new RSocket client builder with default settings. +func Connect() ClientBuilder { + return &clientBuilder{ + fragment: fragmentation.MaxFragment, + setup: &socket.SetupInfo{ + Version: core.DefaultVersion, + KeepaliveInterval: common.DefaultKeepaliveInterval, + KeepaliveLifetime: common.DefaultKeepaliveMaxLifetime, + DataMimeType: _defaultMimeType, + MetadataMimeType: _defaultMimeType, + }, + } +} diff --git a/cmd/rsocket-cli/rsocket-cli.go b/cmd/rsocket-cli/rsocket-cli.go index e5fade4..f4fb39c 100644 --- a/cmd/rsocket-cli/rsocket-cli.go +++ b/cmd/rsocket-cli/rsocket-cli.go @@ -11,16 +11,27 @@ import ( "github.com/urfave/cli/v2" ) +type fmtLogger struct { +} + +func (f fmtLogger) Debugf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Infof(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Warnf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Errorf(format string, args ...interface{}) { + _, _ = os.Stderr.WriteString(fmt.Sprintf(format, args...)) +} + func init() { - logger.DisablePrefix() - fn := func(s string, i ...interface{}) { - fmt.Printf(s, i...) - } - logger.SetFunc(logger.LevelDebug, fn) - logger.SetFunc(logger.LevelInfo, fn) - logger.SetFunc(logger.LevelError, func(s string, i ...interface{}) { - _, _ = os.Stderr.WriteString(fmt.Sprintf(s, i...)) - }) + logger.SetLogger(fmtLogger{}) } func main() { @@ -30,7 +41,7 @@ func main() { app.UsageText = "rsocket-cli [global options] [URI]" app.Name = "rsocket-cli" app.Usage = "CLI for RSocket." - app.Version = "v0.5" + app.Version = "v0.6" app.Flags = newFlags(conf) app.ArgsUsage = "[URI]" app.Action = func(c *cli.Context) (err error) { diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index eb73e16..ab1d276 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -4,14 +4,18 @@ import ( "bufio" "context" "encoding/json" - "errors" "fmt" "io/ioutil" + "net/http" + "net/url" "os" + "strconv" "strings" "time" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -43,17 +47,17 @@ type Runner struct { N int Resume bool URI string - wsHeaders map[string][]string + wsHeaders http.Header } -func (p *Runner) preflight() (err error) { - if p.Debug { +func (r *Runner) preflight() (err error) { + if r.Debug { logger.SetLevel(logger.LevelDebug) } - headers := p.Headers.Value() + headers := r.Headers.Value() - if len(headers) > 0 && len(p.Metadata) > 0 { + if len(headers) > 0 && len(r.Metadata) > 0 { return errConflictHeadersAndMetadata } if len(headers) > 0 { @@ -68,55 +72,57 @@ func (p *Runner) preflight() (err error) { headers[strings.TrimSpace(k)] = strings.TrimSpace(v) } bs, _ := json.Marshal(headers) - p.Metadata = string(bs) + r.Metadata = string(bs) } - - tpHeaders := p.TransportHeaders.Value() - if len(tpHeaders) > 0 { - headers := make(map[string][]string) - for _, it := range tpHeaders { + if v := r.TransportHeaders.Value(); len(v) > 0 { + r.wsHeaders = make(http.Header) + for _, it := range v { idx := strings.Index(it, ":") if idx < 0 { return fmt.Errorf("invalid transport header: %s", it) } k := strings.TrimSpace(it[:idx]) v := strings.TrimSpace(it[idx+1:]) - headers[k] = append(headers[k], v) + r.wsHeaders.Add(k, v) } - p.wsHeaders = headers } return } -func (p *Runner) Run() error { - if err := p.preflight(); err != nil { +func (r *Runner) Run() error { + if err := r.preflight(); err != nil { return err } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if p.ServerMode { - return p.runServerMode(ctx) + if r.ServerMode { + return r.runServerMode(ctx) } - return p.runClientMode(ctx) + return r.runClientMode(ctx) } -func (p *Runner) runClientMode(ctx context.Context) (err error) { +func (r *Runner) runClientMode(ctx context.Context) (err error) { cb := rsocket.Connect() - if p.Resume { + if r.Resume { cb = cb.Resume() } - setupData, err := p.readData(p.Setup) + setupData, err := r.readData(r.Setup) if err != nil { return } setupPayload := payload.New(setupData, nil) - sendingPayloads := p.createPayload() + sendingPayloads := r.createPayload() + + tp, err := r.newClientTransport() + if err != nil { + return + } c, err := cb. - DataMimeType(p.DataFormat). - MetadataMimeType(p.MetadataFormat). + DataMimeType(r.DataFormat). + MetadataMimeType(r.MetadataFormat). SetupPayload(setupPayload). - Transport(p.URI, rsocket.WithWebsocketHeaders(p.wsHeaders)). + Transport(tp). Start(ctx) if err != nil { return @@ -125,30 +131,30 @@ func (p *Runner) runClientMode(ctx context.Context) (err error) { _ = c.Close() }() - for i := 0; i < p.Ops; i++ { + for i := 0; i < r.Ops; i++ { if i > 0 { logger.Infof("\n") } var first payload.Payload - if !p.Channel { + if !r.Channel { first, err = sendingPayloads.BlockFirst(ctx) if err != nil { return } } - if p.Request { - err = p.execRequestResponse(ctx, c, first) - } else if p.FNF { - err = p.execFireAndForget(ctx, c, first) - } else if p.Stream { - err = p.execRequestStream(ctx, c, first) - } else if p.Channel { - err = p.execRequestChannel(ctx, c, sendingPayloads) - } else if p.MetadataPush { - err = p.execMetadataPush(ctx, c, first) + if r.Request { + err = r.execRequestResponse(ctx, c, first) + } else if r.FNF { + err = r.execFireAndForget(ctx, c, first) + } else if r.Stream { + err = r.execRequestStream(ctx, c, first) + } else if r.Channel { + err = r.execRequestChannel(ctx, c, sendingPayloads) + } else if r.MetadataPush { + err = r.execMetadataPush(ctx, c, first) } else { - err = p.execRequestResponse(ctx, c, first) + err = r.execRequestResponse(ctx, c, first) } if err != nil { break @@ -157,31 +163,38 @@ func (p *Runner) runClientMode(ctx context.Context) (err error) { return } -func (p *Runner) runServerMode(ctx context.Context) error { +func (r *Runner) runServerMode(ctx context.Context) error { var sb rsocket.ServerBuilder - if p.Resume { + if r.Resume { sb = rsocket.Receive().Resume() } else { sb = rsocket.Receive() } ch := make(chan error) + + tp, err := r.newServerTransport() + if err != nil { + return err + } + go func() { - sendingPayloads := p.createPayload() + sendingPayloads := r.createPayload() ch <- sb. Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { var options []rsocket.OptAbstractSocket options = append(options, rsocket.RequestStream(func(message payload.Payload) flux.Flux { - p.showPayload(message) + r.showPayload(message) return sendingPayloads })) options = append(options, rsocket.RequestChannel(func(messages rx.Publisher) flux.Flux { - messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) { - p.showPayload(input) + messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) error { + r.showPayload(input) + return nil })) return sendingPayloads })) options = append(options, rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - p.showPayload(msg) + r.showPayload(msg) return mono.Create(func(i context.Context, sink mono.Sink) { first, err := sendingPayloads.BlockFirst(i) if err != nil { @@ -192,7 +205,7 @@ func (p *Runner) runServerMode(ctx context.Context) error { }) })) options = append(options, rsocket.FireAndForget(func(msg payload.Payload) { - p.showPayload(msg) + r.showPayload(msg) })) options = append(options, rsocket.MetadataPush(func(msg payload.Payload) { metadata, _ := msg.MetadataUTF8() @@ -200,92 +213,93 @@ func (p *Runner) runServerMode(ctx context.Context) error { })) return rsocket.NewAbstractSocket(options...), nil }). - Transport(p.URI). + Transport(tp). Serve(ctx) close(ch) }() return <-ch } -func (p *Runner) execMetadataPush(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execMetadataPush(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { c.MetadataPush(send) m, _ := send.MetadataUTF8() logger.Infof("%s\n", m) return } -func (p *Runner) execFireAndForget(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execFireAndForget(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { c.FireAndForget(send) return } -func (p *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send payload.Payload) (err error) { res, err := c.RequestResponse(send).Block(ctx) if err != nil { return } - p.showPayload(res) + r.showPayload(res) return } -func (p *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error { +func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error { var f flux.Flux - if p.N < rx.RequestMax { - f = c.RequestChannel(send).Take(p.N) + if r.N < rx.RequestMax { + f = c.RequestChannel(send).Take(r.N) } else { f = c.RequestChannel(send) } - return p.printFlux(ctx, f) + return r.printFlux(ctx, f) } -func (p *Runner) execRequestStream(ctx context.Context, c rsocket.Client, send payload.Payload) error { +func (r *Runner) execRequestStream(ctx context.Context, c rsocket.Client, send payload.Payload) error { var f flux.Flux - if p.N < rx.RequestMax { - f = c.RequestStream(send).Take(p.N) + if r.N < rx.RequestMax { + f = c.RequestStream(send).Take(r.N) } else { f = c.RequestStream(send) } - return p.printFlux(ctx, f) + return r.printFlux(ctx, f) } -func (p *Runner) printFlux(ctx context.Context, f flux.Flux) (err error) { +func (r *Runner) printFlux(ctx context.Context, f flux.Flux) (err error) { _, err = f. - DoOnNext(func(input payload.Payload) { - p.showPayload(input) + DoOnNext(func(input payload.Payload) error { + r.showPayload(input) + return nil }). BlockLast(ctx) return } -func (p *Runner) showPayload(pa payload.Payload) { +func (r *Runner) showPayload(pa payload.Payload) { logger.Infof("%s\n", pa.DataUTF8()) } -func (p *Runner) createPayload() flux.Flux { +func (r *Runner) createPayload() flux.Flux { var md []byte - if strings.HasPrefix(p.Metadata, "@") { + if strings.HasPrefix(r.Metadata, "@") { var err error - md, err = ioutil.ReadFile(p.Metadata[1:]) + md, err = ioutil.ReadFile(r.Metadata[1:]) if err != nil { return flux.Error(err) } } else { - md = []byte(p.Metadata) + md = []byte(r.Metadata) } - if p.Input == "-" { + if r.Input == "-" { fmt.Println("Type commands to send to the server......") reader := bufio.NewReader(os.Stdin) text, _ := reader.ReadString('\n') return flux.Just(payload.New([]byte(strings.Trim(text, "\n")), md)) } - if !strings.HasPrefix(p.Input, "@") { - return flux.Just(payload.New([]byte(p.Input), md)) + if !strings.HasPrefix(r.Input, "@") { + return flux.Just(payload.New([]byte(r.Input), md)) } return flux.Create(func(ctx context.Context, s flux.Sink) { - f, err := os.Open(p.Input[1:]) + f, err := os.Open(r.Input[1:]) if err != nil { fmt.Println("error:", err) s.Error(err) @@ -309,7 +323,7 @@ func (p *Runner) createPayload() flux.Flux { }) } -func (p *Runner) readData(input string) (data []byte, err error) { +func (r *Runner) readData(input string) (data []byte, err error) { switch { case strings.HasPrefix(input, "@"): data, err = ioutil.ReadFile(input[1:]) @@ -318,3 +332,51 @@ func (p *Runner) readData(input string) (data []byte, err error) { } return } + +func (r *Runner) newClientTransport() (transport.ClientTransportFunc, error) { + u, err := url.Parse(r.URI) + if err != nil { + return nil, err + } + switch u.Scheme { + case "tcp": + port, err := strconv.Atoi(u.Port()) + if err != nil { + return nil, err + } + return rsocket.TcpClient().SetHostAndPort(u.Hostname(), port).Build(), nil + case "unix": + return rsocket.UnixClient().SetPath(u.Hostname()).Build(), nil + case "ws", "wss": + return rsocket.WebsocketClient().SetUrl(r.URI).SetHeader(r.wsHeaders).Build(), nil + default: + return nil, fmt.Errorf("invalid transport %s", u.Scheme) + } +} + +func (r *Runner) newServerTransport() (transport.ServerTransportFunc, error) { + u, err := url.Parse(r.URI) + if err != nil { + return nil, err + } + switch u.Scheme { + case "tcp": + port, err := strconv.Atoi(u.Port()) + if err != nil { + return nil, err + } + return rsocket.TcpServer().SetHostAndPort(u.Hostname(), port).Build(), nil + case "unix": + return rsocket.UnixServer().SetPath(u.Hostname()).Build(), nil + case "ws", "wss": + var addr string + if port := u.Port(); port != "" { + addr = fmt.Sprintf("%s:%s", u.Hostname(), port) + } else { + addr = fmt.Sprintf("%s:%d", u.Hostname(), rsocket.DefaultPort) + } + return rsocket.WebsocketServer().SetAddr(addr).SetPath(u.EscapedPath()).Build(), nil + default: + return nil, errors.Errorf("invalid transport %s", u.Scheme) + } +} diff --git a/internal/common/errors.go b/core/errors.go similarity index 97% rename from internal/common/errors.go rename to core/errors.go index e84487b..023435e 100644 --- a/internal/common/errors.go +++ b/core/errors.go @@ -1,21 +1,12 @@ -package common +package core import "errors" // ErrorCode is code for RSocket error. type ErrorCode uint32 -// CustomError provides a method of accessing code and data. -type CustomError interface { - error - // ErrorCode returns error code. - ErrorCode() ErrorCode - // ErrorData returns error data bytes. - ErrorData() []byte -} - -func (p ErrorCode) String() string { - switch p { +func (e ErrorCode) String() string { + switch e { case ErrorCodeInvalidSetup: return "INVALID_SETUP" case ErrorCodeUnsupportedSetup: @@ -64,6 +55,15 @@ const ( ErrorCodeInvalid ErrorCode = 0x00000204 ) +// CustomError provides a method of accessing code and data. +type CustomError interface { + error + // ErrorCode returns error code. + ErrorCode() ErrorCode + // ErrorData returns error data bytes. + ErrorData() []byte +} + // Error defines. var ( ErrFrameLengthExceed = errors.New("rsocket: frame length is greater than 24bits") diff --git a/core/errors_test.go b/core/errors_test.go new file mode 100644 index 0000000..dc8b11e --- /dev/null +++ b/core/errors_test.go @@ -0,0 +1,28 @@ +package core_test + +import ( + "math" + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestErrorCode_String(t *testing.T) { + all := []core.ErrorCode{ + core.ErrorCodeInvalidSetup, + core.ErrorCodeUnsupportedSetup, + core.ErrorCodeRejectedSetup, + core.ErrorCodeRejectedResume, + core.ErrorCodeConnectionError, + core.ErrorCodeConnectionClose, + core.ErrorCodeApplicationError, + core.ErrorCodeRejected, + core.ErrorCodeCanceled, + core.ErrorCodeInvalid, + } + for _, code := range all { + assert.NotEqual(t, "UNKNOWN", code.String()) + } + assert.Equal(t, "UNKNOWN", core.ErrorCode(math.MaxUint32).String()) +} diff --git a/core/framing/frame.go b/core/framing/frame.go new file mode 100644 index 0000000..32b7214 --- /dev/null +++ b/core/framing/frame.go @@ -0,0 +1,145 @@ +package framing + +import ( + "errors" + "fmt" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +var errIncompleteFrame = errors.New("incomplete frame") + +type tinyFrame struct { + header core.FrameHeader + done chan struct{} +} + +func (t *tinyFrame) Header() core.FrameHeader { + return t.header +} + +// Done can be invoked when a frame has been been processed. +func (t *tinyFrame) Done() (closed bool) { + defer func() { + if e := recover(); e != nil { + closed = true + } + }() + close(t.done) + return +} + +// DoneNotify notify when frame has been done. +func (t *tinyFrame) DoneNotify() <-chan struct{} { + return t.done +} + +// RawFrame is basic frame implementation. +type RawFrame struct { + *tinyFrame + body *common.ByteBuff +} + +// Body returns frame body. +func (f *RawFrame) Body() *common.ByteBuff { + return f.body +} + +// Len returns length of frame. +func (f *RawFrame) Len() int { + if f.body == nil { + return core.FrameHeaderLen + } + return core.FrameHeaderLen + f.body.Len() +} + +// WriteTo write frame to writer. +func (f *RawFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = f.header.WriteTo(w) + if err != nil { + return + } + n += wrote + if f.body != nil { + wrote, err = f.body.WriteTo(w) + if err != nil { + return + } + n += wrote + } + return +} + +func (f *RawFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { + raw := f.body.Bytes() + if offset > 0 { + raw = raw[offset:] + } + hasMetadata = f.header.Flag().Check(core.FlagMetadata) + if !hasMetadata { + return + } + if len(raw) < 3 { + n = -1 + } else { + n = common.NewUint24Bytes(raw).AsInt() + } + return +} + +func (f *RawFrame) trySliceMetadata(offset int) ([]byte, bool) { + n, ok := f.trySeekMetadataLen(offset) + if !ok || n < 0 { + return nil, false + } + return f.body.Bytes()[offset+3 : offset+3+n], true +} + +func (f *RawFrame) trySliceData(offset int) []byte { + n, ok := f.trySeekMetadataLen(offset) + if !ok { + return f.body.Bytes()[offset:] + } + if n < 0 { + return nil + } + return f.body.Bytes()[offset+n+3:] +} + +func newTinyFrame(header core.FrameHeader) *tinyFrame { + return &tinyFrame{ + header: header, + done: make(chan struct{}), + } +} + +// NewRawFrame returns a new RawFrame. +func NewRawFrame(header core.FrameHeader, body *common.ByteBuff) *RawFrame { + return &RawFrame{ + tinyFrame: newTinyFrame(header), + body: body, + } +} + +// FromBytes creates frame from a byte slice. +func FromBytes(b []byte) (core.Frame, error) { + if len(b) < core.FrameHeaderLen { + return nil, errIncompleteFrame + } + header := core.ParseFrameHeader(b[:core.FrameHeaderLen]) + bb := common.NewByteBuff() + _, err := bb.Write(b[core.FrameHeaderLen:]) + if err != nil { + return nil, err + } + raw := NewRawFrame(header, bb) + return FromRawFrame(raw) +} + +func PrintFrame(f core.WriteableFrame) string { + // TODO: print frame + return fmt.Sprintf("%+v", f) +} diff --git a/core/framing/frame_cancel.go b/core/framing/frame_cancel.go new file mode 100644 index 0000000..cbef6fa --- /dev/null +++ b/core/framing/frame_cancel.go @@ -0,0 +1,53 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" +) + +// CancelFrame is frame of cancel. +type CancelFrame struct { + *RawFrame +} + +type WriteableCancelFrame struct { + *tinyFrame +} + +func (c WriteableCancelFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = c.header.WriteTo(w) + if err != nil { + return + } + n += wrote + return +} + +func (c WriteableCancelFrame) Len() int { + return core.FrameHeaderLen +} + +// Validate returns error if frame is invalid. +func (f *CancelFrame) Validate() (err error) { + // Cancel frame doesn't need any binary body. + if f.body != nil && f.body.Len() > 0 { + err = errIncompleteFrame + } + return +} + +func NewWriteableCancelFrame(id uint32) *WriteableCancelFrame { + h := core.NewFrameHeader(id, core.FrameTypeCancel, 0) + return &WriteableCancelFrame{ + tinyFrame: newTinyFrame(h), + } +} + +// NewCancelFrame creates cancel frame. +func NewCancelFrame(sid uint32) *CancelFrame { + return &CancelFrame{ + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeCancel, 0), nil), + } +} diff --git a/core/framing/frame_error.go b/core/framing/frame_error.go new file mode 100644 index 0000000..c5f8ae5 --- /dev/null +++ b/core/framing/frame_error.go @@ -0,0 +1,115 @@ +package framing + +import ( + "encoding/binary" + "io" + "strings" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + errCodeLen = 4 + errDataOff = errCodeLen + minErrorFrameLen = errCodeLen +) + +// ErrorFrame is error frame. +type ErrorFrame struct { + *RawFrame +} + +type WriteableErrorFrame struct { + *tinyFrame + code core.ErrorCode + data []byte +} + +func (e WriteableErrorFrame) Error() string { + return makeErrorString(e.code, e.data) +} + +func (e WriteableErrorFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = e.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + err = binary.Write(w, binary.BigEndian, uint32(e.code)) + if err != nil { + return + } + n += 4 + + l, err := w.Write(e.data) + if err != nil { + return + } + n += int64(l) + return +} + +func (e WriteableErrorFrame) Len() int { + return core.FrameHeaderLen + 4 + len(e.data) +} + +// Validate returns error if frame is invalid. +func (p *ErrorFrame) Validate() (err error) { + if p.body.Len() < minErrorFrameLen { + err = errIncompleteFrame + } + return +} + +func (p *ErrorFrame) Error() string { + return makeErrorString(p.ErrorCode(), p.ErrorData()) +} + +// ErrorCode returns error code. +func (p *ErrorFrame) ErrorCode() core.ErrorCode { + v := binary.BigEndian.Uint32(p.body.Bytes()) + return core.ErrorCode(v) +} + +// ErrorData returns error data bytes. +func (p *ErrorFrame) ErrorData() []byte { + return p.body.Bytes()[errDataOff:] +} + +func NewWriteableErrorFrame(id uint32, code core.ErrorCode, data []byte) *WriteableErrorFrame { + h := core.NewFrameHeader(id, core.FrameTypeError, 0) + t := newTinyFrame(h) + return &WriteableErrorFrame{ + tinyFrame: t, + code: code, + data: data, + } +} + +// NewErrorFrame returns a new error frame. +func NewErrorFrame(streamID uint32, code core.ErrorCode, data []byte) *ErrorFrame { + bf := common.NewByteBuff() + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], uint32(code)) + if _, err := bf.Write(b4[:]); err != nil { + panic(err) + } + if _, err := bf.Write(data); err != nil { + panic(err) + } + return &ErrorFrame{ + NewRawFrame(core.NewFrameHeader(streamID, core.FrameTypeError, 0), bf), + } +} + +func makeErrorString(code core.ErrorCode, data []byte) string { + bu := strings.Builder{} + bu.WriteString(code.String()) + bu.WriteByte(':') + bu.WriteByte(' ') + bu.Write(data) + return bu.String() +} diff --git a/core/framing/frame_fnf.go b/core/framing/frame_fnf.go new file mode 100644 index 0000000..f432cd8 --- /dev/null +++ b/core/framing/frame_fnf.go @@ -0,0 +1,104 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// FireAndForgetFrame is fire and forget frame. +type FireAndForgetFrame struct { + *RawFrame +} + +type WriteableFireAndForgetFrame struct { + *tinyFrame + metadata []byte + data []byte +} + +func (f WriteableFireAndForgetFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = f.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + wrote, err = writePayload(w, f.data, f.metadata) + if err != nil { + return + } + n += wrote + return +} + +func (f WriteableFireAndForgetFrame) Len() int { + return CalcPayloadFrameSize(f.data, f.metadata) +} + +// Validate returns error if frame is invalid. +func (f *FireAndForgetFrame) Validate() (err error) { + if f.header.Flag().Check(core.FlagMetadata) && f.body.Len() < 3 { + err = errIncompleteFrame + } + return +} + +// Metadata returns metadata bytes. +func (f *FireAndForgetFrame) Metadata() ([]byte, bool) { + return f.trySliceMetadata(0) +} + +// Data returns data bytes. +func (f *FireAndForgetFrame) Data() []byte { + return f.trySliceData(0) +} + +// MetadataUTF8 returns metadata as UTF8 string. +func (f *FireAndForgetFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := f.Metadata() + if ok { + metadata = string(raw) + } + return +} + +// DataUTF8 returns data as UTF8 string. +func (f *FireAndForgetFrame) DataUTF8() string { + return string(f.Data()) +} + +func NewWriteableFireAndForgetFrame(sid uint32, data, metadata []byte, flag core.FrameFlag) *WriteableFireAndForgetFrame { + if len(metadata) > 0 { + flag |= core.FlagMetadata + } + h := core.NewFrameHeader(sid, core.FrameTypeRequestFNF, flag) + t := newTinyFrame(h) + return &WriteableFireAndForgetFrame{ + tinyFrame: t, + metadata: metadata, + data: data, + } +} + +// NewFireAndForgetFrame returns a new fire and forget frame. +func NewFireAndForgetFrame(sid uint32, data, metadata []byte, flag core.FrameFlag) *FireAndForgetFrame { + bf := common.NewByteBuff() + if len(metadata) > 0 { + flag |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if _, err := bf.Write(data); err != nil { + panic(err) + } + return &FireAndForgetFrame{ + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestFNF, flag), bf), + } +} diff --git a/core/framing/frame_keepalive.go b/core/framing/frame_keepalive.go new file mode 100644 index 0000000..c53ca3d --- /dev/null +++ b/core/framing/frame_keepalive.go @@ -0,0 +1,112 @@ +package framing + +import ( + "encoding/binary" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + lastRecvPosLen = 8 + minKeepaliveFrameLen = lastRecvPosLen +) + +// KeepaliveFrame is keepalive frame. +type KeepaliveFrame struct { + *RawFrame +} + +type WriteableKeepaliveFrame struct { + *tinyFrame + pos [8]byte + data []byte +} + +func (k WriteableKeepaliveFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = k.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(k.pos[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(k.data) + if err != nil { + return + } + n += int64(v) + + return +} + +func (k WriteableKeepaliveFrame) Len() int { + return core.FrameHeaderLen + 8 + len(k.data) +} + +// Validate returns error if frame is invalid. +func (k *KeepaliveFrame) Validate() (err error) { + if k.body.Len() < minKeepaliveFrameLen { + err = errIncompleteFrame + } + return +} + +// LastReceivedPosition returns last received position. +func (k *KeepaliveFrame) LastReceivedPosition() uint64 { + return binary.BigEndian.Uint64(k.body.Bytes()) +} + +// Data returns data bytes. +func (k *KeepaliveFrame) Data() []byte { + return k.body.Bytes()[lastRecvPosLen:] +} + +func NewWriteableKeepaliveFrame(position uint64, data []byte, respond bool) *WriteableKeepaliveFrame { + var flag core.FrameFlag + if respond { + flag |= core.FlagRespond + } + + var b [8]byte + binary.BigEndian.PutUint64(b[:], position) + + h := core.NewFrameHeader(0, core.FrameTypeKeepalive, flag) + t := newTinyFrame(h) + + return &WriteableKeepaliveFrame{ + tinyFrame: t, + pos: b, + data: data, + } +} + +// NewKeepaliveFrame returns a new keepalive frame. +func NewKeepaliveFrame(position uint64, data []byte, respond bool) *KeepaliveFrame { + var fg core.FrameFlag + if respond { + fg |= core.FlagRespond + } + bf := common.NewByteBuff() + var b8 [8]byte + binary.BigEndian.PutUint64(b8[:], position) + if _, err := bf.Write(b8[:]); err != nil { + panic(err) + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &KeepaliveFrame{ + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeKeepalive, fg), bf), + } +} diff --git a/core/framing/frame_lease.go b/core/framing/frame_lease.go new file mode 100644 index 0000000..a9aec95 --- /dev/null +++ b/core/framing/frame_lease.go @@ -0,0 +1,133 @@ +package framing + +import ( + "encoding/binary" + "io" + "time" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + ttlLen = 4 + reqOff = ttlLen + reqLen = 4 + minLeaseFrame = ttlLen + reqLen +) + +// LeaseFrame is lease frame. +type LeaseFrame struct { + *RawFrame +} + +type WriteableLeaseFrame struct { + *tinyFrame + ttl [4]byte + n [4]byte + metadata []byte +} + +// Validate returns error if frame is invalid. +func (l *LeaseFrame) Validate() (err error) { + if l.body.Len() < minLeaseFrame { + err = errIncompleteFrame + } + return +} + +// TimeToLive returns time to live duration. +func (l *LeaseFrame) TimeToLive() time.Duration { + v := binary.BigEndian.Uint32(l.body.Bytes()) + return time.Millisecond * time.Duration(v) +} + +// NumberOfRequests returns number of requests. +func (l *LeaseFrame) NumberOfRequests() uint32 { + return binary.BigEndian.Uint32(l.body.Bytes()[reqOff:]) +} + +// Metadata returns metadata bytes. +func (l *LeaseFrame) Metadata() []byte { + if !l.header.Flag().Check(core.FlagMetadata) { + return nil + } + return l.body.Bytes()[8:] +} + +func (l WriteableLeaseFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = l.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(l.ttl[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(l.n[:]) + if err != nil { + return + } + n += int64(v) + + if l.header.Flag().Check(core.FlagMetadata) { + v, err = w.Write(l.metadata) + if err != nil { + return + } + n += int64(v) + } + + return +} + +func (l WriteableLeaseFrame) Len() int { + n := core.FrameHeaderLen + 8 + if l.header.Flag().Check(core.FlagMetadata) { + n += len(l.metadata) + } + return n +} + +func NewWriteableLeaseFrame(ttl time.Duration, n uint32, metadata []byte) *WriteableLeaseFrame { + var a, b [4]byte + binary.BigEndian.PutUint32(a[:], uint32(ttl.Milliseconds())) + binary.BigEndian.PutUint32(b[:], n) + + var flag core.FrameFlag + if len(metadata) > 0 { + flag |= core.FlagMetadata + } + h := core.NewFrameHeader(0, core.FrameTypeLease, flag) + t := newTinyFrame(h) + return &WriteableLeaseFrame{ + tinyFrame: t, + ttl: a, + n: b, + metadata: metadata, + } +} + +func NewLeaseFrame(ttl time.Duration, n uint32, metadata []byte) *LeaseFrame { + bf := common.NewByteBuff() + if err := binary.Write(bf, binary.BigEndian, uint32(ttl.Milliseconds())); err != nil { + panic(err) + } + if err := binary.Write(bf, binary.BigEndian, n); err != nil { + panic(err) + } + var fg core.FrameFlag + if len(metadata) > 0 { + fg |= core.FlagMetadata + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + return &LeaseFrame{NewRawFrame(core.NewFrameHeader(0, core.FrameTypeLease, fg), bf)} +} diff --git a/core/framing/frame_metadata_push.go b/core/framing/frame_metadata_push.go new file mode 100644 index 0000000..f83f05d --- /dev/null +++ b/core/framing/frame_metadata_push.go @@ -0,0 +1,88 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +var _metadataPushHeader = core.NewFrameHeader(0, core.FrameTypeMetadataPush, core.FlagMetadata) + +// MetadataPushFrame is metadata push frame. +type MetadataPushFrame struct { + *RawFrame +} +type WriteableMetadataPushFrame struct { + *tinyFrame + metadata []byte +} + +// Validate returns error if frame is invalid. +func (m *MetadataPushFrame) Validate() (err error) { + return +} + +// Metadata returns metadata bytes. +func (m *MetadataPushFrame) Metadata() ([]byte, bool) { + return m.body.Bytes(), true +} + +// Data returns data bytes. +func (m *MetadataPushFrame) Data() []byte { + return nil +} + +// MetadataUTF8 returns metadata as UTF8 string. +func (m *MetadataPushFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := m.Metadata() + if ok { + metadata = string(raw) + } + return +} + +func (m WriteableMetadataPushFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = m.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(m.metadata) + if err != nil { + return + } + n += int64(v) + return +} + +func (m WriteableMetadataPushFrame) Len() int { + return core.FrameHeaderLen + len(m.metadata) +} + +// DataUTF8 returns data as UTF8 string. +func (m *MetadataPushFrame) DataUTF8() (data string) { + return +} + +func NewWriteableMetadataPushFrame(metadata []byte) *WriteableMetadataPushFrame { + t := newTinyFrame(_metadataPushHeader) + return &WriteableMetadataPushFrame{ + tinyFrame: t, + metadata: metadata, + } +} + +// NewMetadataPushFrame returns a new metadata push frame. +func NewMetadataPushFrame(metadata []byte) *MetadataPushFrame { + bf := common.NewByteBuff() + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + return &MetadataPushFrame{ + NewRawFrame(_metadataPushHeader, bf), + } +} diff --git a/core/framing/frame_payload.go b/core/framing/frame_payload.go new file mode 100644 index 0000000..2eb83cd --- /dev/null +++ b/core/framing/frame_payload.go @@ -0,0 +1,131 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// PayloadFrame is payload frame. +type PayloadFrame struct { + *RawFrame +} + +type WriteablePayloadFrame struct { + *tinyFrame + metadata []byte + data []byte +} + +// Validate returns error if frame is invalid. +func (p *PayloadFrame) Validate() (err error) { + // Minimal length should be 3 if metadata exists. + if p.header.Flag().Check(core.FlagMetadata) && p.body.Len() < 3 { + err = errIncompleteFrame + } + return +} + +// Metadata returns metadata bytes. +func (p *PayloadFrame) Metadata() ([]byte, bool) { + return p.trySliceMetadata(0) +} + +// Data returns data bytes. +func (p *PayloadFrame) Data() []byte { + return p.trySliceData(0) +} + +func (p *PayloadFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := p.Metadata() + if ok { + metadata = string(raw) + } + return +} + +func (p *PayloadFrame) DataUTF8() string { + return string(p.Data()) +} + +func (p WriteablePayloadFrame) Data() []byte { + return p.data +} + +func (p WriteablePayloadFrame) Metadata() (metadata []byte, ok bool) { + ok = p.header.Flag().Check(core.FlagMetadata) + if ok { + metadata = p.metadata + } + return +} + +func (p WriteablePayloadFrame) DataUTF8() (data string) { + if p.data != nil { + data = string(p.data) + } + return +} + +func (p WriteablePayloadFrame) MetadataUTF8() (metadata string, ok bool) { + ok = p.header.Flag().Check(core.FlagMetadata) + if ok { + metadata = string(p.metadata) + } + return +} + +func (p WriteablePayloadFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = p.header.WriteTo(w) + if err != nil { + return + } + n += wrote + wrote, err = writePayload(w, p.data, p.metadata) + if err == nil { + n += wrote + } + return +} + +func (p WriteablePayloadFrame) Len() int { + return CalcPayloadFrameSize(p.data, p.metadata) +} + +// NewWriteablePayloadFrame returns a new payload frame. +func NewWriteablePayloadFrame(id uint32, data, metadata []byte, flag core.FrameFlag) *WriteablePayloadFrame { + if len(metadata) > 0 { + flag |= core.FlagMetadata + } + h := core.NewFrameHeader(id, core.FrameTypePayload, flag) + t := newTinyFrame(h) + return &WriteablePayloadFrame{ + tinyFrame: t, + metadata: metadata, + data: data, + } +} + +// NewPayloadFrame returns a new payload frame. +func NewPayloadFrame(id uint32, data, metadata []byte, flag core.FrameFlag) *PayloadFrame { + bf := common.NewByteBuff() + if len(metadata) > 0 { + flag |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &PayloadFrame{ + NewRawFrame(core.NewFrameHeader(id, core.FrameTypePayload, flag), bf), + } +} diff --git a/core/framing/frame_request_channel.go b/core/framing/frame_request_channel.go new file mode 100644 index 0000000..f344ee7 --- /dev/null +++ b/core/framing/frame_request_channel.go @@ -0,0 +1,138 @@ +package framing + +import ( + "encoding/binary" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + initReqLen = 4 + minRequestChannelFrameLen = initReqLen +) + +// RequestChannelFrame is frame for RequestChannel. +type RequestChannelFrame struct { + *RawFrame +} + +type WriteableRequestChannelFrame struct { + *tinyFrame + n [4]byte + metadata []byte + data []byte +} + +// Validate returns error if frame is invalid. +func (r *RequestChannelFrame) Validate() error { + l := r.body.Len() + if l < minRequestChannelFrameLen { + return errIncompleteFrame + } + if r.header.Flag().Check(core.FlagMetadata) && l < minRequestChannelFrameLen+3 { + return errIncompleteFrame + } + return nil +} + +// InitialRequestN returns initial N. +func (r *RequestChannelFrame) InitialRequestN() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) +} + +// Metadata returns metadata bytes. +func (r *RequestChannelFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(initReqLen) +} + +// Data returns data bytes. +func (r *RequestChannelFrame) Data() []byte { + return r.trySliceData(initReqLen) +} + +// MetadataUTF8 returns metadata as UTF8 string. +func (r *RequestChannelFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() + if ok { + metadata = string(raw) + } + return +} + +// DataUTF8 returns data as UTF8 string. +func (r *RequestChannelFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r WriteableRequestChannelFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(r.n[:]) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, r.data, r.metadata) + if err != nil { + return + } + n += wrote + + return +} + +func (r WriteableRequestChannelFrame) Len() int { + return CalcPayloadFrameSize(r.data, r.metadata) + 4 +} + +func NewWriteableRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *WriteableRequestChannelFrame { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + if len(metadata) > 0 { + flag |= core.FlagMetadata + } + h := core.NewFrameHeader(sid, core.FrameTypeRequestChannel, flag) + t := newTinyFrame(h) + return &WriteableRequestChannelFrame{ + tinyFrame: t, + n: b, + metadata: metadata, + data: data, + } +} + +// NewRequestChannelFrame returns a new RequestChannel frame. +func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestChannelFrame { + bf := common.NewByteBuff() + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], n) + if _, err := bf.Write(b4[:]); err != nil { + panic(err) + } + if len(metadata) > 0 { + flag |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &RequestChannelFrame{ + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestChannel, flag), bf), + } +} diff --git a/core/framing/frame_request_n.go b/core/framing/frame_request_n.go new file mode 100644 index 0000000..6a54f96 --- /dev/null +++ b/core/framing/frame_request_n.go @@ -0,0 +1,72 @@ +package framing + +import ( + "encoding/binary" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// RequestNFrame is RequestN frame. +type RequestNFrame struct { + *RawFrame +} + +type WriteableRequestNFrame struct { + *tinyFrame + n [4]byte +} + +// Validate returns error if frame is invalid. +func (r *RequestNFrame) Validate() (err error) { + if r.body.Len() != 4 { + err = errIncompleteFrame + } + return +} + +// N returns N in RequestN. +func (r *RequestNFrame) N() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) +} + +func (r WriteableRequestNFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + v, err := w.Write(r.n[:]) + if err == nil { + n += int64(v) + } + return +} + +func (r WriteableRequestNFrame) Len() int { + return core.FrameHeaderLen + 4 +} + +func NewWriteableRequestNFrame(id uint32, n uint32, fg core.FrameFlag) *WriteableRequestNFrame { + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], n) + return &WriteableRequestNFrame{ + tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestN, fg)), + n: b4, + } +} + +// NewRequestNFrame returns a new RequestN frame. +func NewRequestNFrame(sid, n uint32, fg core.FrameFlag) *RequestNFrame { + bf := common.NewByteBuff() + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], n) + if _, err := bf.Write(b4[:]); err != nil { + panic(err) + } + return &RequestNFrame{ + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestN, fg), bf), + } +} diff --git a/core/framing/frame_request_response.go b/core/framing/frame_request_response.go new file mode 100644 index 0000000..2e74e40 --- /dev/null +++ b/core/framing/frame_request_response.go @@ -0,0 +1,103 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// RequestResponseFrame is frame for requesting single response. +type RequestResponseFrame struct { + *RawFrame +} + +type WriteableRequestResponseFrame struct { + *tinyFrame + metadata []byte + data []byte +} + +// Validate returns error if frame is invalid. +func (r *RequestResponseFrame) Validate() (err error) { + if r.header.Flag().Check(core.FlagMetadata) && r.body.Len() < 3 { + err = errIncompleteFrame + } + return +} + +// Metadata returns metadata bytes. +func (r *RequestResponseFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(0) +} + +// Data returns data bytes. +func (r *RequestResponseFrame) Data() []byte { + return r.trySliceData(0) +} + +// MetadataUTF8 returns metadata as UTF8 string. +func (r *RequestResponseFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() + if ok { + metadata = string(raw) + } + return +} + +// DataUTF8 returns data as UTF8 string. +func (r *RequestResponseFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r WriteableRequestResponseFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + wrote, err = writePayload(w, r.data, r.metadata) + if err == nil { + n += wrote + } + return +} + +func (r WriteableRequestResponseFrame) Len() int { + return CalcPayloadFrameSize(r.data, r.metadata) +} + +// NewWriteableRequestResponseFrame returns a new RequestResponse frame support. +func NewWriteableRequestResponseFrame(id uint32, data, metadata []byte, fg core.FrameFlag) core.WriteableFrame { + if len(metadata) > 0 { + fg |= core.FlagMetadata + } + return &WriteableRequestResponseFrame{ + tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestResponse, fg)), + metadata: metadata, + data: data, + } +} + +// NewRequestResponseFrame returns a new RequestResponse frame. +func NewRequestResponseFrame(id uint32, data, metadata []byte, fg core.FrameFlag) *RequestResponseFrame { + bf := common.NewByteBuff() + if len(metadata) > 0 { + fg |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &RequestResponseFrame{ + NewRawFrame(core.NewFrameHeader(id, core.FrameTypeRequestResponse, fg), bf), + } +} diff --git a/core/framing/frame_request_stream.go b/core/framing/frame_request_stream.go new file mode 100644 index 0000000..1221869 --- /dev/null +++ b/core/framing/frame_request_stream.go @@ -0,0 +1,134 @@ +package framing + +import ( + "encoding/binary" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + minRequestStreamFrameLen = initReqLen +) + +// RequestStreamFrame is frame for requesting a completable stream. +type RequestStreamFrame struct { + *RawFrame +} + +type WriteableRequestStreamFrame struct { + *tinyFrame + n [4]byte + metadata []byte + data []byte +} + +// Validate returns error if frame is invalid. +func (r *RequestStreamFrame) Validate() error { + l := r.body.Len() + if l < minRequestStreamFrameLen { + return errIncompleteFrame + } + if r.header.Flag().Check(core.FlagMetadata) && l < minRequestStreamFrameLen+3 { + return errIncompleteFrame + } + return nil +} + +// InitialRequestN returns initial request N. +func (r *RequestStreamFrame) InitialRequestN() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) +} + +// Metadata returns metadata bytes. +func (r *RequestStreamFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(4) +} + +// Data returns data bytes. +func (r *RequestStreamFrame) Data() []byte { + return r.trySliceData(4) +} + +// MetadataUTF8 returns metadata as UTF8 string. +func (r *RequestStreamFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() + if ok { + metadata = string(raw) + } + return +} + +// DataUTF8 returns data as UTF8 string. +func (r *RequestStreamFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r WriteableRequestStreamFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(r.n[:]) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, r.data, r.metadata) + if err != nil { + return + } + n += wrote + return +} + +func (r WriteableRequestStreamFrame) Len() int { + return 4 + CalcPayloadFrameSize(r.data, r.metadata) +} + +func NewWriteableRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) core.WriteableFrame { + if len(metadata) > 0 { + flag |= core.FlagMetadata + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + h := core.NewFrameHeader(id, core.FrameTypeRequestStream, flag) + t := newTinyFrame(h) + return &WriteableRequestStreamFrame{ + tinyFrame: t, + n: b, + metadata: metadata, + data: data, + } +} + +// NewRequestStreamFrame returns a new request stream frame. +func NewRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestStreamFrame { + bf := common.NewByteBuff() + if err := binary.Write(bf, binary.BigEndian, n); err != nil { + panic(err) + } + if len(metadata) > 0 { + flag |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &RequestStreamFrame{ + NewRawFrame(core.NewFrameHeader(id, core.FrameTypeRequestStream, flag), bf), + } +} diff --git a/core/framing/frame_resume.go b/core/framing/frame_resume.go new file mode 100644 index 0000000..67eb4ec --- /dev/null +++ b/core/framing/frame_resume.go @@ -0,0 +1,165 @@ +package framing + +import ( + "encoding/binary" + "errors" + "io" + "math" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +var errResumeTokenTooLarge = errors.New("max length of resume token is 65535") + +const ( + _lenVersion = 4 + _lenTokenLength = 2 + _lenLastRecvPos = 8 + _lenFirstPos = 8 + _minResumeLength = _lenVersion + _lenTokenLength + _lenLastRecvPos + _lenFirstPos +) + +// ResumeFrame represents a frame of Resume. +type ResumeFrame struct { + *RawFrame +} + +type WriteableResumeFrame struct { + *tinyFrame + version core.Version + token []byte + posFirst [8]byte + posLast [8]byte +} + +// Validate validate current frame. +func (r *ResumeFrame) Validate() (err error) { + if r.body.Len() < _minResumeLength { + err = errIncompleteFrame + } + return +} + +// Version returns version. +func (r *ResumeFrame) Version() core.Version { + raw := r.body.Bytes() + major := binary.BigEndian.Uint16(raw) + minor := binary.BigEndian.Uint16(raw[2:]) + return [2]uint16{major, minor} +} + +// Token returns resume token in bytes. +func (r *ResumeFrame) Token() []byte { + raw := r.body.Bytes() + tokenLen := binary.BigEndian.Uint16(raw[4:6]) + return raw[6 : 6+tokenLen] +} + +// LastReceivedServerPosition returns last received server position. +func (r *ResumeFrame) LastReceivedServerPosition() uint64 { + raw := r.body.Bytes() + offset := 6 + binary.BigEndian.Uint16(raw[4:6]) + return binary.BigEndian.Uint64(raw[offset:]) +} + +// FirstAvailableClientPosition returns first available client position. +func (r *ResumeFrame) FirstAvailableClientPosition() uint64 { + raw := r.body.Bytes() + offset := 6 + binary.BigEndian.Uint16(raw[4:6]) + 8 + return binary.BigEndian.Uint64(raw[offset:]) +} + +func (r WriteableResumeFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + + v, err = w.Write(r.version.Bytes()) + if err != nil { + return + } + n += int64(v) + + lenToken := uint16(len(r.token)) + err = binary.Write(w, binary.BigEndian, lenToken) + if err != nil { + return + } + n += 2 + + v, err = w.Write(r.token) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(r.posLast[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(r.posFirst[:]) + if err != nil { + return + } + n += int64(v) + + return +} + +func (r WriteableResumeFrame) Len() int { + return core.FrameHeaderLen + _lenTokenLength + _lenFirstPos + _lenLastRecvPos + _lenVersion + len(r.token) +} + +// NewWriteableResumeFrame creates a new frame support of Resume. +func NewWriteableResumeFrame(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *WriteableResumeFrame { + h := core.NewFrameHeader(0, core.FrameTypeResume, 0) + t := newTinyFrame(h) + var a, b [8]byte + binary.BigEndian.PutUint64(a[:], firstAvailableClientPosition) + binary.BigEndian.PutUint64(b[:], lastReceivedServerPosition) + + return &WriteableResumeFrame{ + tinyFrame: t, + version: version, + token: token, + posFirst: a, + posLast: b, + } +} + +// NewResumeFrame creates a new frame of Resume. +func NewResumeFrame(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrame { + n := len(token) + if n > math.MaxUint16 { + panic(errResumeTokenTooLarge) + } + bf := common.NewByteBuff() + if _, err := bf.Write(version.Bytes()); err != nil { + panic(err) + } + if err := binary.Write(bf, binary.BigEndian, uint16(n)); err != nil { + panic(err) + } + if n > 0 { + if _, err := bf.Write(token); err != nil { + panic(err) + } + } + if err := binary.Write(bf, binary.BigEndian, lastReceivedServerPosition); err != nil { + panic(err) + } + if err := binary.Write(bf, binary.BigEndian, firstAvailableClientPosition); err != nil { + panic(err) + } + return &ResumeFrame{ + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeResume, 0), bf), + } +} diff --git a/core/framing/frame_resume_ok.go b/core/framing/frame_resume_ok.go new file mode 100644 index 0000000..21903f8 --- /dev/null +++ b/core/framing/frame_resume_ok.go @@ -0,0 +1,79 @@ +package framing + +import ( + "encoding/binary" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// ResumeOKFrame represents a frame of ResumeOK. +type ResumeOKFrame struct { + *RawFrame +} + +type WriteableResumeOKFrame struct { + *tinyFrame + pos [8]byte +} + +// Validate validate current frame. +func (r *ResumeOKFrame) Validate() (err error) { + // Length of frame body should be 8 + if r.body.Len() != 8 { + err = errIncompleteFrame + } + return +} + +// LastReceivedClientPosition returns last received client position. +func (r *ResumeOKFrame) LastReceivedClientPosition() uint64 { + raw := r.body.Bytes() + return binary.BigEndian.Uint64(raw) +} + +func (r WriteableResumeOKFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + var v int + v, err = w.Write(r.pos[:]) + if err != nil { + return + } + n += int64(v) + return +} + +func (r WriteableResumeOKFrame) Len() int { + return core.FrameHeaderLen + 8 +} + +func NewWriteableResumeOKFrame(position uint64) *WriteableResumeOKFrame { + h := core.NewFrameHeader(0, core.FrameTypeResumeOK, 0) + t := newTinyFrame(h) + var b [8]byte + binary.BigEndian.PutUint64(b[:], position) + return &WriteableResumeOKFrame{ + tinyFrame: t, + pos: b, + } +} + +// NewResumeOKFrame creates a new frame of ResumeOK. +func NewResumeOKFrame(position uint64) *ResumeOKFrame { + var b8 [8]byte + binary.BigEndian.PutUint64(b8[:], position) + bf := common.NewByteBuff() + _, err := bf.Write(b8[:]) + if err != nil { + panic(err) + } + return &ResumeOKFrame{ + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeResumeOK, 0), bf), + } +} diff --git a/core/framing/frame_setup.go b/core/framing/frame_setup.go new file mode 100644 index 0000000..8f84941 --- /dev/null +++ b/core/framing/frame_setup.go @@ -0,0 +1,331 @@ +package framing + +import ( + "encoding/binary" + "io" + "time" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +const ( + _versionLen = 4 + _timeLen = 4 + _metadataLen = 1 + _dataLen = 1 + _minSetupFrameLen = _versionLen + _timeLen*2 + _metadataLen + _dataLen +) + +// SetupFrame is sent by client to initiate protocol processing. +type SetupFrame struct { + *RawFrame +} + +type WriteableSetupFrame struct { + *tinyFrame + version core.Version + keepalive [4]byte + lifetime [4]byte + token []byte + mimeMetadata []byte + mimeData []byte + metadata []byte + data []byte +} + +// Validate returns error if frame is invalid. +func (p *SetupFrame) Validate() (err error) { + if p.Len() < _minSetupFrameLen { + err = errIncompleteFrame + } + return +} + +// Version returns version. +func (p *SetupFrame) Version() core.Version { + major := binary.BigEndian.Uint16(p.body.Bytes()) + minor := binary.BigEndian.Uint16(p.body.Bytes()[2:]) + return [2]uint16{major, minor} +} + +// TimeBetweenKeepalive returns keepalive interval duration. +func (p *SetupFrame) TimeBetweenKeepalive() time.Duration { + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[4:])) +} + +// MaxLifetime returns keepalive max lifetime. +func (p *SetupFrame) MaxLifetime() time.Duration { + return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[8:])) +} + +// Token returns token of setup. +func (p *SetupFrame) Token() []byte { + if !p.header.Flag().Check(core.FlagResume) { + return nil + } + raw := p.body.Bytes() + tokenLength := binary.BigEndian.Uint16(raw[12:]) + return raw[14 : 14+tokenLength] +} + +// DataMimeType returns MIME of data. +func (p *SetupFrame) DataMimeType() (mime string) { + _, b := p.mime() + return string(b) +} + +// MetadataMimeType returns MIME of metadata. +func (p *SetupFrame) MetadataMimeType() string { + a, _ := p.mime() + return string(a) +} + +// Metadata returns metadata bytes. +func (p *SetupFrame) Metadata() ([]byte, bool) { + if !p.header.Flag().Check(core.FlagMetadata) { + return nil, false + } + offset := p.seekMIME() + m1, m2 := p.mime() + offset += 2 + len(m1) + len(m2) + return p.trySliceMetadata(offset) +} + +// Data returns data bytes. +func (p *SetupFrame) Data() []byte { + offset := p.seekMIME() + m1, m2 := p.mime() + offset += 2 + len(m1) + len(m2) + if !p.header.Flag().Check(core.FlagMetadata) { + return p.Body().Bytes()[offset:] + } + return p.trySliceData(offset) +} + +// MetadataUTF8 returns metadata as UTF8 string +func (p *SetupFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := p.Metadata() + if ok { + metadata = string(raw) + } + return +} + +// DataUTF8 returns data as UTF8 string. +func (p *SetupFrame) DataUTF8() string { + return string(p.Data()) +} + +func (p *SetupFrame) mime() (metadata []byte, data []byte) { + offset := p.seekMIME() + raw := p.body.Bytes() + l1 := int(raw[offset]) + offset++ + m1 := raw[offset : offset+l1] + offset += l1 + l2 := int(raw[offset]) + offset++ + m2 := raw[offset : offset+l2] + return m1, m2 +} + +func (p *SetupFrame) seekMIME() int { + if !p.header.Flag().Check(core.FlagResume) { + return 12 + } + l := binary.BigEndian.Uint16(p.body.Bytes()[12:]) + return 14 + int(l) +} + +func (s WriteableSetupFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = s.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + wrote, err = s.version.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(s.keepalive[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(s.lifetime[:]) + if err != nil { + return + } + n += int64(v) + + if s.header.Flag().Check(core.FlagResume) { + tokenLen := len(s.token) + err = binary.Write(w, binary.BigEndian, uint16(tokenLen)) + if err != nil { + return + } + n += 2 + v, err = w.Write(s.token) + if err != nil { + return + } + n += int64(v) + } + + lenMimeMetadata := len(s.mimeMetadata) + v, err = w.Write([]byte{byte(lenMimeMetadata)}) + if err != nil { + return + } + n += int64(v) + v, err = w.Write(s.mimeMetadata) + if err != nil { + return + } + n += int64(v) + + lenMimeData := len(s.mimeData) + v, err = w.Write([]byte{byte(lenMimeData)}) + if err != nil { + return + } + n += int64(v) + v, err = w.Write(s.mimeData) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, s.data, s.metadata) + if err != nil { + return + } + n += wrote + return +} + +func (s WriteableSetupFrame) Len() int { + n := _minSetupFrameLen + CalcPayloadFrameSize(s.data, s.metadata) + n += len(s.mimeData) + len(s.mimeMetadata) + if l := len(s.token); l > 0 { + n += 2 + len(s.token) + } + return n +} + +func NewWriteableSetupFrame( + version core.Version, + timeBetweenKeepalive, + maxLifetime time.Duration, + token []byte, + mimeMetadata []byte, + mimeData []byte, + data []byte, + metadata []byte, + lease bool, +) *WriteableSetupFrame { + var flag core.FrameFlag + if l := len(token); l > 0 { + flag |= core.FlagResume + } + if lease { + flag |= core.FlagLease + } + if l := len(metadata); l > 0 { + flag |= core.FlagMetadata + } + h := core.NewFrameHeader(0, core.FrameTypeSetup, flag) + t := newTinyFrame(h) + + var a, b [4]byte + binary.BigEndian.PutUint32(a[:], uint32(timeBetweenKeepalive.Nanoseconds()/1e6)) + binary.BigEndian.PutUint32(b[:], uint32(maxLifetime.Nanoseconds()/1e6)) + return &WriteableSetupFrame{ + tinyFrame: t, + version: version, + keepalive: a, + lifetime: b, + token: token, + mimeMetadata: mimeMetadata, + mimeData: mimeData, + metadata: metadata, + data: data, + } +} + +// NewSetupFrame returns a new setup frame. +func NewSetupFrame( + version core.Version, + timeBetweenKeepalive, + maxLifetime time.Duration, + token []byte, + mimeMetadata []byte, + mimeData []byte, + data []byte, + metadata []byte, + lease bool, +) *SetupFrame { + var fg core.FrameFlag + bf := common.NewByteBuff() + if _, err := bf.Write(version.Bytes()); err != nil { + panic(err) + } + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], uint32(timeBetweenKeepalive.Nanoseconds()/1e6)) + if _, err := bf.Write(b4[:]); err != nil { + panic(err) + } + binary.BigEndian.PutUint32(b4[:], uint32(maxLifetime.Nanoseconds()/1e6)) + if _, err := bf.Write(b4[:]); err != nil { + panic(err) + } + if lease { + fg |= core.FlagLease + } + if len(token) > 0 { + fg |= core.FlagResume + binary.BigEndian.PutUint16(b4[:2], uint16(len(token))) + if _, err := bf.Write(b4[:2]); err != nil { + panic(err) + } + if _, err := bf.Write(token); err != nil { + panic(err) + } + } + if err := bf.WriteByte(byte(len(mimeMetadata))); err != nil { + panic(err) + } + if _, err := bf.Write(mimeMetadata); err != nil { + panic(err) + } + if err := bf.WriteByte(byte(len(mimeData))); err != nil { + panic(err) + } + if _, err := bf.Write(mimeData); err != nil { + panic(err) + } + if len(metadata) > 0 { + fg |= core.FlagMetadata + if err := bf.WriteUint24(len(metadata)); err != nil { + panic(err) + } + if _, err := bf.Write(metadata); err != nil { + panic(err) + } + } + if len(data) > 0 { + if _, err := bf.Write(data); err != nil { + panic(err) + } + } + return &SetupFrame{ + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeSetup, fg), bf), + } +} diff --git a/core/framing/frame_test.go b/core/framing/frame_test.go new file mode 100644 index 0000000..a9320d2 --- /dev/null +++ b/core/framing/frame_test.go @@ -0,0 +1,284 @@ +package framing_test + +import ( + "bytes" + "math" + "testing" + "time" + + "github.com/rsocket/rsocket-go/core" + . "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +const _sid uint32 = 1 + +func TestFromBytes(t *testing.T) { + // empty + _, err := FromBytes([]byte{}) + assert.Error(t, err, "should be error") + + b := &bytes.Buffer{} + frame := NewWriteableRequestResponseFrame(42, []byte("fake-data"), []byte("fake-metadata"), 0) + _, _ = frame.WriteTo(b) + frameActual, err := FromBytes(b.Bytes()) + assert.NoError(t, err, "should not be error") + assert.Equal(t, frame.Header(), frameActual.Header(), "header does not match") + assert.Equal(t, frame.Len(), frameActual.Len()) +} + +func TestFrameCancel(t *testing.T) { + f := NewCancelFrame(_sid) + checkBasic(t, f, core.FrameTypeCancel) + f2 := NewWriteableCancelFrame(_sid) + checkBytes(t, f, f2) +} + +func TestFrameError(t *testing.T) { + errData := []byte(common.RandAlphanumeric(10)) + f := NewErrorFrame(_sid, core.ErrorCodeApplicationError, errData) + checkBasic(t, f, core.FrameTypeError) + assert.Equal(t, core.ErrorCodeApplicationError, f.ErrorCode()) + assert.Equal(t, errData, f.ErrorData()) + assert.NotEmpty(t, f.Error()) + f2 := NewWriteableErrorFrame(_sid, core.ErrorCodeApplicationError, errData) + checkBytes(t, f, f2) +} + +func TestFrameFNF(t *testing.T) { + b := []byte(common.RandAlphanumeric(100)) + // Without Metadata + f := NewFireAndForgetFrame(_sid, b, nil, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestFNF) + assert.Equal(t, b, f.Data()) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + metadata, ok := f.Metadata() + assert.False(t, ok) + assert.Nil(t, metadata) + assert.True(t, f.Header().Flag().Check(core.FlagNext)) + assert.False(t, f.Header().Flag().Check(core.FlagMetadata)) + f2 := NewWriteableFireAndForgetFrame(_sid, b, nil, core.FlagNext) + checkBytes(t, f, f2) + + // With Metadata + f = NewFireAndForgetFrame(_sid, nil, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestFNF) + assert.Empty(t, f.Data()) + metadata, ok = f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, metadata) + assert.True(t, f.Header().Flag().Check(core.FlagNext)) + assert.True(t, f.Header().Flag().Check(core.FlagMetadata)) + f2 = NewWriteableFireAndForgetFrame(_sid, nil, b, core.FlagNext) + checkBytes(t, f, f2) +} + +func TestFrameKeepalive(t *testing.T) { + pos := uint64(common.RandIntn(math.MaxInt32)) + d := []byte(common.RandAlphanumeric(100)) + f := NewKeepaliveFrame(pos, d, true) + checkBasic(t, f, core.FrameTypeKeepalive) + assert.Equal(t, d, f.Data()) + assert.Equal(t, pos, f.LastReceivedPosition()) + assert.True(t, f.Header().Flag().Check(core.FlagRespond)) + f2 := NewWriteableKeepaliveFrame(pos, d, true) + checkBytes(t, f, f2) +} + +func TestFrameLease(t *testing.T) { + metadata := []byte("foobar") + n := uint32(4444) + f := NewLeaseFrame(time.Second, n, metadata) + checkBasic(t, f, core.FrameTypeLease) + assert.Equal(t, time.Second, f.TimeToLive()) + assert.Equal(t, n, f.NumberOfRequests()) + assert.Equal(t, metadata, f.Metadata()) + f2 := NewWriteableLeaseFrame(time.Second, n, metadata) + checkBytes(t, f, f2) +} + +func TestFrameMetadataPush(t *testing.T) { + metadata := []byte("foobar") + f := NewMetadataPushFrame(metadata) + assert.Nil(t, f.Data(), "should not be nil") + assert.Equal(t, "", f.DataUTF8(), "should be zero string") + checkBasic(t, f, core.FrameTypeMetadataPush) + metadata2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, metadata, metadata2) + _, _ = f.MetadataUTF8() + + f2 := NewWriteableMetadataPushFrame(metadata) + checkBytes(t, f, f2) +} + +func TestPayloadFrame(t *testing.T) { + b := []byte("foobar") + f := NewPayloadFrame(_sid, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypePayload) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, f.Data()) + assert.Equal(t, b, m) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) + f2 := NewWriteablePayloadFrame(_sid, b, b, core.FlagNext) + checkBytes(t, f, f2) + + assert.Equal(t, b, f2.Data()) + _ = f2.DataUTF8() + m2, ok := f2.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m2) + _, _ = f2.MetadataUTF8() +} + +func TestFrameRequestChannel(t *testing.T) { + b := []byte("foobar") + n := uint32(1) + f := NewRequestChannelFrame(_sid, n, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestChannel) + assert.Equal(t, n, f.InitialRequestN()) + assert.Equal(t, b, f.Data()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) + + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + + f2 := NewWriteableRequestChannelFrame(_sid, n, b, b, core.FlagNext) + checkBytes(t, f, f2) +} + +func TestFrameRequestN(t *testing.T) { + n := uint32(1234) + f := NewRequestNFrame(_sid, n, 0) + checkBasic(t, f, core.FrameTypeRequestN) + assert.Equal(t, n, f.N()) + f2 := NewWriteableRequestNFrame(_sid, n, 0) + checkBytes(t, f, f2) +} + +func TestFrameRequestResponse(t *testing.T) { + b := []byte("foobar") + f := NewRequestResponseFrame(_sid, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestResponse) + assert.Equal(t, b, f.Data()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) + assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + f2 := NewWriteableRequestResponseFrame(_sid, b, b, core.FlagNext) + checkBytes(t, f, f2) +} + +func TestFrameRequestStream(t *testing.T) { + b := []byte("foobar") + n := uint32(1234) + f := NewRequestStreamFrame(_sid, n, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestStream) + assert.Equal(t, b, f.Data()) + assert.Equal(t, n, f.InitialRequestN()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) + _, _ = f.MetadataUTF8() + _ = f.DataUTF8() + f2 := NewWriteableRequestStreamFrame(_sid, n, b, b, core.FlagNext) + checkBytes(t, f, f2) +} + +func TestFrameResume(t *testing.T) { + v := core.NewVersion(3, 1) + token := []byte("hello") + p1 := uint64(333) + p2 := uint64(444) + f := NewResumeFrame(v, token, p1, p2) + checkBasic(t, f, core.FrameTypeResume) + assert.Equal(t, token, f.Token()) + assert.Equal(t, p1, f.FirstAvailableClientPosition()) + assert.Equal(t, p2, f.LastReceivedServerPosition()) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) + f2 := NewWriteableResumeFrame(v, token, p1, p2) + checkBytes(t, f, f2) +} + +func TestFrameResumeOK(t *testing.T) { + pos := uint64(1234) + f := NewResumeOKFrame(pos) + checkBasic(t, f, core.FrameTypeResumeOK) + assert.Equal(t, pos, f.LastReceivedClientPosition()) + f2 := NewWriteableResumeOKFrame(pos) + checkBytes(t, f, f2) +} + +func TestFrameSetup(t *testing.T) { + v := core.NewVersion(3, 1) + timeKeepalive := 20 * time.Second + maxLifetime := time.Minute + 30*time.Second + var token []byte + mimeData := []byte("application/binary") + mimeMetadata := []byte("application/binary") + d := []byte("你好") + m := []byte("世界") + f := NewSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + checkBasic(t, f, core.FrameTypeSetup) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) + assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) + assert.Equal(t, maxLifetime, f.MaxLifetime()) + assert.Equal(t, token, f.Token()) + assert.Equal(t, string(mimeData), f.DataMimeType()) + assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) + assert.Equal(t, d, f.Data()) + m2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, m, m2) + + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + + fs := NewWriteableSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + + checkBytes(t, f, fs) +} + +func checkBasic(t *testing.T, f core.Frame, typ core.FrameType) { + sid := _sid + switch typ { + case core.FrameTypeKeepalive, core.FrameTypeSetup, core.FrameTypeLease, core.FrameTypeResume, core.FrameTypeResumeOK, core.FrameTypeMetadataPush: + sid = 0 + } + assert.Equal(t, sid, f.Header().StreamID(), "wrong frame stream id") + assert.NoError(t, f.Validate(), "validate frame type failed") + assert.Equal(t, typ, f.Header().Type(), "frame type doesn't match") + assert.True(t, f.Header().Type().String() != "UNKNOWN") + go func() { + f.Done() + }() + <-f.DoneNotify() +} + +func checkBytes(t *testing.T, a core.Frame, b core.WriteableFrame) { + assert.NoError(t, a.Validate()) + assert.Equal(t, a.Len(), b.Len()) + bf1, bf2 := &bytes.Buffer{}, &bytes.Buffer{} + _, err := a.WriteTo(bf1) + assert.NoError(t, err, "write failed") + _, err = b.WriteTo(bf2) + assert.NoError(t, err, "write failed") + b1, b2 := bf1.Bytes(), bf2.Bytes() + assert.Equal(t, b1, b2, "bytes doesn't match") + bf := common.NewByteBuff() + _, _ = bf.Write(b1[core.FrameHeaderLen:]) + raw := NewRawFrame(core.ParseFrameHeader(b1[:core.FrameHeaderLen]), bf) + _, err = FromRawFrame(raw) + assert.NoError(t, err, "create from raw failed") +} diff --git a/core/framing/misc.go b/core/framing/misc.go new file mode 100644 index 0000000..81b930f --- /dev/null +++ b/core/framing/misc.go @@ -0,0 +1,83 @@ +package framing + +import ( + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +// CalcPayloadFrameSize returns payload frame size. +func CalcPayloadFrameSize(data, metadata []byte) int { + size := core.FrameHeaderLen + len(data) + if n := len(metadata); n > 0 { + size += 3 + n + } + return size +} + +// FromRawFrame creates a frame from a RawFrame. +func FromRawFrame(f *RawFrame) (frame core.Frame, err error) { + switch f.header.Type() { + case core.FrameTypeSetup: + frame = &SetupFrame{RawFrame: f} + case core.FrameTypeKeepalive: + frame = &KeepaliveFrame{RawFrame: f} + case core.FrameTypeRequestResponse: + frame = &RequestResponseFrame{RawFrame: f} + case core.FrameTypeRequestFNF: + frame = &FireAndForgetFrame{RawFrame: f} + case core.FrameTypeRequestStream: + frame = &RequestStreamFrame{RawFrame: f} + case core.FrameTypeRequestChannel: + frame = &RequestChannelFrame{RawFrame: f} + case core.FrameTypeCancel: + frame = &CancelFrame{RawFrame: f} + case core.FrameTypePayload: + frame = &PayloadFrame{RawFrame: f} + case core.FrameTypeMetadataPush: + frame = &MetadataPushFrame{RawFrame: f} + case core.FrameTypeError: + frame = &ErrorFrame{RawFrame: f} + case core.FrameTypeRequestN: + frame = &RequestNFrame{RawFrame: f} + case core.FrameTypeLease: + frame = &LeaseFrame{RawFrame: f} + case core.FrameTypeResume: + frame = &ResumeFrame{RawFrame: f} + case core.FrameTypeResumeOK: + frame = &ResumeOKFrame{RawFrame: f} + default: + err = core.ErrInvalidFrame + } + return +} + +func writePayload(w io.Writer, data []byte, metadata []byte) (n int64, err error) { + if l := len(metadata); l > 0 { + var wrote int64 + u := common.MustNewUint24(l) + wrote, err = u.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(metadata) + if err != nil { + return + } + n += int64(v) + } + + if l := len(data); l > 0 { + var v int + v, err = w.Write(data) + if err != nil { + return + } + n += int64(v) + } + return +} diff --git a/core/header.go b/core/header.go new file mode 100644 index 0000000..81713b1 --- /dev/null +++ b/core/header.go @@ -0,0 +1,86 @@ +package core + +import ( + "encoding/binary" + "io" + "strconv" + "strings" +) + +const ( + // FrameHeaderLen is len of header. + FrameHeaderLen = 6 +) + +// FrameHeader is the header fo a RSocket frame. +// RSocket frames begin with a RSocket Frame FrameHeader. +// It includes StreamID, FrameType and Flags. +type FrameHeader [FrameHeaderLen]byte + +func (h FrameHeader) String() string { + bu := strings.Builder{} + bu.WriteString("FrameHeader{id=") + bu.WriteString(strconv.FormatUint(uint64(h.StreamID()), 10)) + bu.WriteString(",type=") + bu.WriteString(h.Type().String()) + bu.WriteString(",flag=") + bu.WriteString(h.Flag().String()) + bu.WriteByte('}') + return bu.String() +} + +// Resumable returns true if frame supports resume. +func (h FrameHeader) Resumable() bool { + switch h.Type() { + case FrameTypeRequestChannel, FrameTypeRequestStream, FrameTypeRequestResponse, FrameTypeRequestFNF, FrameTypeRequestN, FrameTypeCancel, FrameTypeError, FrameTypePayload: + return true + default: + return false + } +} + +// WriteTo writes frame header to a writer. +func (h FrameHeader) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(h[:]) + return int64(n), err +} + +// StreamID returns StreamID. +func (h FrameHeader) StreamID() uint32 { + return binary.BigEndian.Uint32(h[:4]) +} + +// Type returns frame type. +func (h FrameHeader) Type() FrameType { + return FrameType((h.n() & 0xFC00) >> 10) +} + +// Flag returns flag of a frame. +func (h FrameHeader) Flag() FrameFlag { + return FrameFlag(h.n() & 0x03FF) +} + +func (h FrameHeader) Bytes() []byte { + return h[:] +} + +func (h FrameHeader) n() uint16 { + return binary.BigEndian.Uint16(h[4:]) +} + +// NewFrameHeader returns a new frame header. +func NewFrameHeader(streamID uint32, frameType FrameType, fg FrameFlag) FrameHeader { + var h [FrameHeaderLen]byte + binary.BigEndian.PutUint32(h[:], streamID) + binary.BigEndian.PutUint16(h[4:], uint16(frameType)<<10|uint16(fg)) + return h + +} + +// ParseFrameHeader parse a header from bytes. +func ParseFrameHeader(bs []byte) FrameHeader { + _ = bs[FrameHeaderLen-1] + var bb [FrameHeaderLen]byte + copy(bb[:], bs[:FrameHeaderLen]) + return bb +} diff --git a/core/header_test.go b/core/header_test.go new file mode 100644 index 0000000..bc438c7 --- /dev/null +++ b/core/header_test.go @@ -0,0 +1,29 @@ +package core_test + +import ( + "bytes" + "math" + "testing" + + . "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestHeader_All(t *testing.T) { + id := uint32(common.RandIntn(math.MaxUint32)) + h1 := NewFrameHeader(id, FrameTypePayload, FlagMetadata|FlagComplete|FlagNext) + assert.NotEmpty(t, h1.String(), "header string is blank") + assert.True(t, h1.Resumable()) + h2 := ParseFrameHeader(h1[:]) + assert.Equal(t, h1[:], h2.Bytes()) + assert.Equal(t, h1.StreamID(), h2.StreamID()) + assert.Equal(t, h1.Type(), h2.Type()) + assert.Equal(t, h1.Flag(), h2.Flag()) + assert.Equal(t, FrameTypePayload, h1.Type()) + assert.Equal(t, FlagMetadata|FlagComplete|FlagNext, h1.Flag()) + bf := &bytes.Buffer{} + n, err := h2.WriteTo(bf) + assert.NoError(t, err) + assert.Equal(t, int64(FrameHeaderLen), n) +} diff --git a/core/traffic_counter.go b/core/traffic_counter.go new file mode 100644 index 0000000..89e7791 --- /dev/null +++ b/core/traffic_counter.go @@ -0,0 +1,36 @@ +package core + +import ( + "go.uber.org/atomic" +) + +// TrafficCounter represents a counter of read/write bytes. +type TrafficCounter struct { + r, w *atomic.Uint64 +} + +// ReadBytes returns the number of bytes that have been read. +func (p TrafficCounter) ReadBytes() uint64 { + return p.r.Load() +} + +// WriteBytes returns the number of bytes that have been written. +func (p TrafficCounter) WriteBytes() uint64 { + return p.w.Load() +} + +func (p TrafficCounter) IncWriteBytes(n int) { + p.w.Add(uint64(n)) +} + +func (p TrafficCounter) IncReadBytes(n int) { + p.r.Add(uint64(n)) +} + +// NewTrafficCounter returns a new counter. +func NewTrafficCounter() *TrafficCounter { + return &TrafficCounter{ + r: atomic.NewUint64(0), + w: atomic.NewUint64(0), + } +} diff --git a/core/traffic_counter_test.go b/core/traffic_counter_test.go new file mode 100644 index 0000000..2cd27b8 --- /dev/null +++ b/core/traffic_counter_test.go @@ -0,0 +1,29 @@ +package core_test + +import ( + "sync" + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestTrafficCounter(t *testing.T) { + const cycle = 1000 + const amount = 1000 + wg := sync.WaitGroup{} + wg.Add(amount) + c := core.NewTrafficCounter() + for range [amount]struct{}{} { + go func() { + for range [cycle]struct{}{} { + c.IncWriteBytes(1) + c.IncReadBytes(1) + } + wg.Done() + }() + } + wg.Wait() + assert.Equal(t, uint64(cycle*amount), c.WriteBytes()) + assert.Equal(t, uint64(cycle*amount), c.ReadBytes()) +} diff --git a/core/transport/base_test.go b/core/transport/base_test.go new file mode 100644 index 0000000..3c1a24c --- /dev/null +++ b/core/transport/base_test.go @@ -0,0 +1,9 @@ +package transport_test + +import "github.com/pkg/errors" + +var ( + fakeErr = errors.New("fake-error") + fakeData = []byte("fake-data") + fakeMetadata = []byte("fake-metadata") +) diff --git a/internal/transport/decoder.go b/core/transport/decoder.go similarity index 86% rename from internal/transport/decoder.go rename to core/transport/decoder.go index 05d0655..8630e7c 100644 --- a/internal/transport/decoder.go +++ b/core/transport/decoder.go @@ -5,8 +5,8 @@ import ( "errors" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) const ( @@ -32,14 +32,14 @@ func (p *LengthBasedFrameDecoder) Read() (raw []byte, err error) { return } raw = scanner.Bytes()[lengthFieldSize:] - if len(raw) < framing.HeaderLen { + if len(raw) < core.FrameHeaderLen { err = ErrIncompleteHeader } return } -func doSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF { +func doSplit(data []byte, eof bool) (advance int, token []byte, err error) { + if eof { return } if len(data) < lengthFieldSize { @@ -47,7 +47,7 @@ func doSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { } frameLength := common.NewUint24Bytes(data).AsInt() if frameLength < 1 { - err = common.ErrInvalidFrameLength + err = core.ErrInvalidFrameLength return } frameSize := frameLength + lengthFieldSize diff --git a/core/transport/decoder_test.go b/core/transport/decoder_test.go new file mode 100644 index 0000000..8203b6d --- /dev/null +++ b/core/transport/decoder_test.go @@ -0,0 +1,90 @@ +package transport_test + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestLengthBasedFrameDecoder_ReadBroken(t *testing.T) { + b := &bytes.Buffer{} + + _, _ = common.MustNewUint24(5).WriteTo(b) + _, _ = b.Write([]byte{'_', 'f', 'a', 'k', 'e'}) + decoder := transport.NewLengthBasedFrameDecoder(b) + _, err := decoder.Read() + assert.Equal(t, transport.ErrIncompleteHeader, err, "should be incomplete header error") + + b.Reset() + _, _ = b.Write([]byte{0, 0, 0, 'f', 'a', 'k', 'e'}) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, core.ErrInvalidFrameLength, err, "should be invalid length error") + + b.Reset() + b.Write([]byte{0, 0}) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, io.EOF, err, "should read nothing") + + b.Reset() + _, _ = common.MustNewUint24(10).WriteTo(b) + var notEnough [7]byte + _, _ = b.Write(notEnough[:]) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, io.EOF, err, "should read nothing") +} + +func TestLengthBasedFrameDecoder_Read(t *testing.T) { + b := &bytes.Buffer{} + frames := []core.Frame{ + framing.NewSetupFrame( + core.DefaultVersion, + 30*time.Second, + 90*time.Second, + nil, + []byte("text/plain"), + []byte("text/plain"), + []byte("fake-data"), + []byte("fake-metadata"), + false, + ), + framing.NewKeepaliveFrame(0, []byte("fake-data"), true), + framing.NewRequestResponseFrame(1, []byte("fake-data"), []byte("fake-metadata"), 0), + } + + for _, it := range frames { + n, err := common.NewUint24(it.Len()) + assert.NoError(t, err, "convert to uint24 failed") + _, err = n.WriteTo(b) + assert.NoError(t, err, "write length failed") + _, err = it.WriteTo(b) + assert.NoError(t, err, "write frame failed") + } + + var results []core.Frame + + decoder := transport.NewLengthBasedFrameDecoder(b) + for { + next, err := decoder.Read() + if err == io.EOF { + break + } + assert.NoError(t, err, "decode next frame failed") + frame, err := framing.FromBytes(next) + assert.NoError(t, err, "read next frame failed") + results = append(results, frame) + } + + for i := 0; i < len(results); i++ { + assert.Equal(t, frames[i].Header(), results[i].Header()) + } +} diff --git a/internal/transport/misc.go b/core/transport/misc.go similarity index 99% rename from internal/transport/misc.go rename to core/transport/misc.go index 1c779aa..cd47eb9 100644 --- a/internal/transport/misc.go +++ b/core/transport/misc.go @@ -17,3 +17,4 @@ func isClosedErr(err error) bool { } return false } + diff --git a/core/transport/misc_test.go b/core/transport/misc_test.go new file mode 100644 index 0000000..b0a1991 --- /dev/null +++ b/core/transport/misc_test.go @@ -0,0 +1,16 @@ +package transport + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsClosedErr(t *testing.T) { + assert.False(t, isClosedErr(nil)) + assert.False(t, isClosedErr(errors.New("fake error"))) + assert.True(t, isClosedErr(http.ErrServerClosed)) + assert.True(t, isClosedErr(errors.New("use of closed network connection"))) +} diff --git a/core/transport/mock_conn_test.go b/core/transport/mock_conn_test.go new file mode 100644 index 0000000..14198f9 --- /dev/null +++ b/core/transport/mock_conn_test.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/types.go + +// Package transport is a generated GoMock package. +package transport_test + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + core "github.com/rsocket/rsocket-go/core" +) + +// MockConn is a mock of Conn interface +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder +} + +// MockConnMockRecorder is the mock recorder for MockConn +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// SetDeadline mocks base method +func (m *MockConn) SetDeadline(deadline time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", deadline) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), deadline) +} + +// SetCounter mocks base method +func (m *MockConn) SetCounter(c *core.TrafficCounter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCounter", c) +} + +// SetCounter indicates an expected call of SetCounter +func (mr *MockConnMockRecorder) SetCounter(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCounter", reflect.TypeOf((*MockConn)(nil).SetCounter), c) +} + +// Read mocks base method +func (m *MockConn) Read() (core.Frame, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read") + ret0, _ := ret[0].(core.Frame) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockConnMockRecorder) Read() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read)) +} + +// Write mocks base method +func (m *MockConn) Write(arg0 core.WriteableFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write +func (mr *MockConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), arg0) +} + +// Flush mocks base method +func (m *MockConn) Flush() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +// Flush indicates an expected call of Flush +func (mr *MockConnMockRecorder) Flush() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockConn)(nil).Flush)) +} diff --git a/core/transport/net_conn_mock_test.go b/core/transport/net_conn_mock_test.go new file mode 100644 index 0000000..b73a6f9 --- /dev/null +++ b/core/transport/net_conn_mock_test.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Conn) + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + net "net" + reflect "reflect" + time "time" +) + +// mockNetConn is a mock of Conn interface +type mockNetConn struct { + ctrl *gomock.Controller + recorder *mockNetConnMockRecorder +} + +// mockNetConnMockRecorder is the mock recorder for mockNetConn +type mockNetConnMockRecorder struct { + mock *mockNetConn +} + +// newMockNetConn creates a new mock instance +func newMockNetConn(ctrl *gomock.Controller) *mockNetConn { + mock := &mockNetConn{ctrl: ctrl} + mock.recorder = &mockNetConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockNetConn) EXPECT() *mockNetConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *mockNetConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockNetConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockNetConn)(nil).Close)) +} + +// LocalAddr mocks base method +func (m *mockNetConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *mockNetConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*mockNetConn)(nil).LocalAddr)) +} + +// Read mocks base method +func (m *mockNetConn) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *mockNetConnMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*mockNetConn)(nil).Read), arg0) +} + +// RemoteAddr mocks base method +func (m *mockNetConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr +func (mr *mockNetConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*mockNetConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method +func (m *mockNetConn) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *mockNetConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*mockNetConn)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method +func (m *mockNetConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *mockNetConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*mockNetConn)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method +func (m *mockNetConn) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *mockNetConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*mockNetConn)(nil).SetWriteDeadline), arg0) +} + +// Write mocks base method +func (m *mockNetConn) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *mockNetConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*mockNetConn)(nil).Write), arg0) +} diff --git a/core/transport/net_listener_mock_test.go b/core/transport/net_listener_mock_test.go new file mode 100644 index 0000000..3ac1eff --- /dev/null +++ b/core/transport/net_listener_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Listener) + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + net "net" + reflect "reflect" +) + +// mockNetListener is a mock of Listener interface +type mockNetListener struct { + ctrl *gomock.Controller + recorder *mockNetListenerMockRecorder +} + +// mockNetListenerMockRecorder is the mock recorder for mockNetListener +type mockNetListenerMockRecorder struct { + mock *mockNetListener +} + +// newMockNetListener creates a new mock instance +func newMockNetListener(ctrl *gomock.Controller) *mockNetListener { + mock := &mockNetListener{ctrl: ctrl} + mock.recorder = &mockNetListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockNetListener) EXPECT() *mockNetListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method +func (m *mockNetListener) Accept() (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept") + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept +func (mr *mockNetListenerMockRecorder) Accept() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*mockNetListener)(nil).Accept)) +} + +// Addr mocks base method +func (m *mockNetListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr +func (mr *mockNetListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*mockNetListener)(nil).Addr)) +} + +// Close mocks base method +func (m *mockNetListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockNetListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockNetListener)(nil).Close)) +} diff --git a/core/transport/tcp_conn.go b/core/transport/tcp_conn.go new file mode 100644 index 0000000..ac4cfda --- /dev/null +++ b/core/transport/tcp_conn.go @@ -0,0 +1,102 @@ +package transport + +import ( + "bufio" + "io" + "net" + "time" + + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/logger" +) + +type TcpConn struct { + conn net.Conn + writer *bufio.Writer + decoder *LengthBasedFrameDecoder + counter *core.TrafficCounter +} + +func (p *TcpConn) SetCounter(c *core.TrafficCounter) { + p.counter = c +} + +func (p *TcpConn) SetDeadline(deadline time.Time) error { + return p.conn.SetReadDeadline(deadline) +} + +func (p *TcpConn) Read() (f core.Frame, err error) { + raw, err := p.decoder.Read() + if err == io.EOF { + return + } + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + f, err = framing.FromBytes(raw) + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + if p.counter != nil && f.Header().Resumable() { + p.counter.IncReadBytes(f.Len()) + } + err = f.Validate() + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + if logger.IsDebugEnabled() { + logger.Debugf("<--- rcv: %s\n", f) + } + return +} + +func (p *TcpConn) Flush() (err error) { + err = p.writer.Flush() + if err != nil { + err = errors.Wrap(err, "flush failed") + } + return +} + +func (p *TcpConn) Write(frame core.WriteableFrame) (err error) { + size := frame.Len() + if p.counter != nil && frame.Header().Resumable() { + p.counter.IncWriteBytes(size) + } + _, err = common.MustNewUint24(size).WriteTo(p.writer) + if err != nil { + err = errors.Wrap(err, "write frame failed") + return + } + var debugStr string + if logger.IsDebugEnabled() { + debugStr = framing.PrintFrame(frame) + } + _, err = frame.WriteTo(p.writer) + if err != nil { + err = errors.Wrap(err, "write frame failed") + return + } + if logger.IsDebugEnabled() { + logger.Debugf("---> snd: %s\n", debugStr) + } + return +} + +func (p *TcpConn) Close() error { + return p.conn.Close() +} + +func NewTcpConn(conn net.Conn) *TcpConn { + return &TcpConn{ + conn: conn, + writer: bufio.NewWriter(conn), + decoder: NewLengthBasedFrameDecoder(conn), + } +} diff --git a/core/transport/tcp_conn_test.go b/core/transport/tcp_conn_test.go new file mode 100644 index 0000000..71eff6a --- /dev/null +++ b/core/transport/tcp_conn_test.go @@ -0,0 +1,158 @@ +package transport_test + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +func InitMockTcpConn(t *testing.T) (*gomock.Controller, *mockNetConn, *transport.TcpConn) { + ctrl := gomock.NewController(t) + nc := newMockNetConn(ctrl) + tc := transport.NewTcpConn(nc) + return ctrl, nc, tc +} + +func TestTcpConn_Read_Empty(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT().Read(gomock.Any()).Return(0, fakeErr).AnyTimes() + _, err := tc.Read() + assert.Error(t, err, "should read failed") +} + +func TestTcpConn_Read(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + bf := &bytes.Buffer{} + c := core.NewTrafficCounter() + tc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var writtenBytes int + + for _, frame := range toBeWritten { + n := frame.Len() + if frame.Header().Resumable() { + writtenBytes += n + } + _, _ = common.MustNewUint24(n).WriteTo(bf) + _, _ = frame.WriteTo(bf) + } + + nc.EXPECT(). + Read(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Read(b) + }). + AnyTimes() + var results []core.Frame + for { + next, err := tc.Read() + if err == io.EOF { + break + } + assert.NoError(t, err, "read next frame failed") + results = append(results, next) + } + assert.Equal(t, len(toBeWritten), len(results), "result amount does not match") + for i := 0; i < len(results); i++ { + assert.Equal(t, toBeWritten[i].Header(), results[i].Header(), "header does not match") + } + assert.Equal(t, writtenBytes, int(c.ReadBytes()), "read bytes doesn't match") +} + +func TestTcpConn_SetDeadline(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + nc.EXPECT().SetReadDeadline(gomock.Any()).Times(1) + err := tc.SetDeadline(time.Now()) + assert.NoError(t, err, "call setDeadline failed") +} + +func TestTcpConn_Flush_Nothing(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + c := core.NewTrafficCounter() + tc.SetCounter(c) + + nc.EXPECT().Write(gomock.Any()).Times(0) + + err := tc.Flush() + assert.NoError(t, err, "flush failed") + assert.Equal(t, 0, int(c.WriteBytes()), "bytes written should be zero") +} + +func TestTcpConn_WriteWithBrokenConn(t *testing.T) { + logger.SetLogger(nil) + logger.SetLevel(logger.LevelDebug) + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT(). + Write(gomock.Any()). + Return(0, fakeErr). + AnyTimes() + _ = tc.Write(framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0)) + err := tc.Flush() + assert.Equal(t, fakeErr, errors.Cause(err), "should be fake error") +} + +func TestTcpConn_WriteAndFlush(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + c := core.NewTrafficCounter() + tc.SetCounter(c) + + nc.EXPECT(). + Write(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return len(b), nil + }). + Times(1) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var bytesWritten uint64 + for _, frame := range toBeWritten { + n := frame.Len() + err := tc.Write(frame) + assert.NoError(t, err, "write failed") + if frame.Header().Resumable() { + bytesWritten += uint64(n) + } + } + err := tc.Flush() + assert.NoError(t, err) + assert.Equal(t, bytesWritten, c.WriteBytes(), "write bytes doesn't match") +} + +func TestTcpConn_Close(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT().Close().Return(fakeErr).Times(1) + err := tc.Close() + assert.Equal(t, fakeErr, err, "should return fake error") +} diff --git a/core/transport/tcp_transport.go b/core/transport/tcp_transport.go new file mode 100644 index 0000000..5ef0a07 --- /dev/null +++ b/core/transport/tcp_transport.go @@ -0,0 +1,128 @@ +package transport + +import ( + "context" + "crypto/tls" + "io" + "net" + "sync" + + "github.com/pkg/errors" +) + +type tcpServerTransport struct { + listenerFn func() (net.Listener, error) + acceptor ServerTransportAcceptor + listener net.Listener + onceClose sync.Once + transports *sync.Map +} + +func (p *tcpServerTransport) Accept(acceptor ServerTransportAcceptor) { + p.acceptor = acceptor +} + +func (p *tcpServerTransport) Close() (err error) { + if p.listener == nil { + return + } + p.onceClose.Do(func() { + err = p.listener.Close() + + p.transports.Range(func(key, value interface{}) bool { + _ = key.(*Transport).Close() + return true + }) + + }) + return +} + +func (p *tcpServerTransport) Listen(ctx context.Context, notifier chan<- struct{}) (err error) { + p.listener, err = p.listenerFn() + if err != nil { + close(notifier) + return + } + notifier <- struct{}{} + close(notifier) + return p.listen(ctx) +} + +func (p *tcpServerTransport) listen(ctx context.Context) (err error) { + done := make(chan struct{}) + + defer func() { + close(done) + _ = p.Close() + }() + + go func() { + for { + select { + case <-ctx.Done(): + _ = p.Close() + return + case <-done: + return + } + } + }() + + // Start loop of accepting connections. + var c net.Conn + for { + c, err = p.listener.Accept() + if err == io.EOF || isClosedErr(err) { + err = nil + break + } + if err != nil { + err = errors.Wrap(err, "accept next conn failed") + break + } + // Dispatch raw conn. + tp := NewTransport(NewTcpConn(c)) + p.transports.Store(tp, struct{}{}) + go p.acceptor(ctx, tp, func(t *Transport) { + p.transports.Delete(t) + }) + } + return +} + +func NewTcpServerTransport(gen func() (net.Listener, error)) ServerTransport { + return &tcpServerTransport{ + listenerFn: gen, + transports: &sync.Map{}, + } +} + +func NewTcpClientTransport(c net.Conn) *Transport { + return NewTransport(NewTcpConn(c)) +} + +func NewTcpServerTransportWithAddr(network, addr string, tlsConfig *tls.Config) ServerTransport { + gen := func() (net.Listener, error) { + if tlsConfig == nil { + return net.Listen(network, addr) + } else { + return tls.Listen(network, addr, tlsConfig) + } + } + return NewTcpServerTransport(gen) +} + +func NewTcpClientTransportWithAddr(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { + var rawConn net.Conn + if tlsConfig == nil { + rawConn, err = net.Dial(network, addr) + } else { + rawConn, err = tls.Dial(network, addr, tlsConfig) + } + if err != nil { + return + } + tp = NewTcpClientTransport(rawConn) + return +} diff --git a/core/transport/tcp_transport_test.go b/core/transport/tcp_transport_test.go new file mode 100644 index 0000000..2b76faa --- /dev/null +++ b/core/transport/tcp_transport_test.go @@ -0,0 +1,165 @@ +package transport_test + +import ( + "context" + "crypto/tls" + "io" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" +) + +func InitTcpServerTransport(t *testing.T) (*gomock.Controller, *mockNetListener, transport.ServerTransport) { + ctrl := gomock.NewController(t) + listener := newMockNetListener(ctrl) + tp := transport.NewTcpServerTransport(func() (net.Listener, error) { + return listener, nil + }) + return ctrl, listener, tp +} + +func TestTcpServerTransport_ListenBroken(t *testing.T) { + tp := transport.NewTcpServerTransport(func() (net.Listener, error) { + return nil, fakeErr + }) + + defer tp.Close() + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fake error") + }() + _, ok := <-notifier + assert.False(t, ok) + + <-done +} + +func TestTcpServerTransport_Listen(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + + listener.EXPECT().Accept().Return(nil, io.EOF).AnyTimes() + listener.EXPECT().Close().Times(1) + + done := make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(ctx, notifier) + assert.True(t, err == nil || err == io.EOF) + }() + _, ok := <-notifier + assert.True(t, ok) + + time.Sleep(100 * time.Millisecond) + cancel() + + <-done +} + +func TestTcpServerTransport_Accept(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + defer tp.Close() + + connChan := make(chan net.Conn, 1) + listener.EXPECT(). + Accept(). + DoAndReturn(func() (net.Conn, error) { + c, ok := <-connChan + if !ok { + return nil, io.EOF + } + return c, nil + }). + AnyTimes() + listener.EXPECT().Close().Times(1) + + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) + err := tp.Start(ctx) + assert.True(t, err == nil || err == io.EOF) + }) + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.True(t, err == nil || err == io.EOF) + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier failed") + + c := newMockNetConn(ctrl) + + c.EXPECT().Read(gomock.Any()).Return(0, io.EOF).AnyTimes() + c.EXPECT().Close().Times(1) + + connChan <- c + + time.Sleep(100 * time.Millisecond) + close(connChan) + + <-done +} + +func TestTcpServerTransport_AcceptBroken(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + + listener.EXPECT(). + Accept(). + Return(nil, fakeErr). + AnyTimes() + listener.EXPECT().Close().Times(1) + + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) + err := tp.Start(ctx) + assert.True(t, err == nil || err == io.EOF) + }) + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Error(t, err, "should be error") + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fake error") + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier failed") + + <-done +} + +func TestNewTcpServerTransportWithAddr(t *testing.T) { + assert.NotPanics(t, func() { + tp := transport.NewTcpServerTransportWithAddr("tcp", ":9999", nil) + assert.NotNil(t, tp) + }) + assert.NotPanics(t, func() { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tp := transport.NewTcpServerTransportWithAddr("tcp", ":9999", tlsConfig) + assert.NotNil(t, tp) + }) +} diff --git a/core/transport/transport.go b/core/transport/transport.go new file mode 100644 index 0000000..36126bb --- /dev/null +++ b/core/transport/transport.go @@ -0,0 +1,252 @@ +package transport + +import ( + "context" + "io" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/logger" +) + +var ( + errTransportClosed = errors.New("transport closed") + errNoHandler = errors.New("you must register a handler") +) + +// FrameHandler is an alias of frame handler. +type FrameHandler = func(frame core.Frame) (err error) + +// ServerTransportAcceptor is an alias of server transport handler. +type ServerTransportAcceptor = func(ctx context.Context, tp *Transport, onClose func(*Transport)) + +// ServerTransport is server-side RSocket transport. +type ServerTransport interface { + io.Closer + // Accept register incoming connection handler. + Accept(acceptor ServerTransportAcceptor) + // Listen listens on the network address addr and handles requests on incoming connections. + // You can specify onReady handler, it'll be invoked when server begin listening. + // It always returns a non-nil error. + Listen(ctx context.Context, notifier chan<- struct{}) error +} + +type EventType int + +const ( + OnSetup EventType = iota + OnResume + OnLease + OnResumeOK + OnFireAndForget + OnMetadataPush + OnRequestResponse + OnRequestStream + OnRequestChannel + OnPayload + OnRequestN + OnError + OnErrorWithZeroStreamID + OnCancel + OnKeepalive + + handlerLen = int(OnKeepalive) + 1 +) + +// Transport is RSocket transport which is used to carry RSocket frames. +type Transport struct { + conn Conn + maxLifetime time.Duration + lastRcvPos uint64 + once sync.Once + handlers [handlerLen]FrameHandler +} + +func (p *Transport) RegisterHandler(event EventType, handler FrameHandler) { + p.handlers[int(event)] = handler +} + +// Connection returns current connection. +func (p *Transport) Connection() Conn { + return p.conn +} + +// SetLifetime set max lifetime for current transport. +func (p *Transport) SetLifetime(lifetime time.Duration) { + if lifetime < 1 { + return + } + p.maxLifetime = lifetime +} + +// Send send a frame. +func (p *Transport) Send(frame core.WriteableFrame, flush bool) (err error) { + defer func() { + // ensure frame done when send success. + if err == nil { + frame.Done() + } + }() + if p == nil || p.conn == nil { + err = errTransportClosed + return + } + err = p.conn.Write(frame) + if err != nil { + return + } + if !flush { + return + } + err = p.conn.Flush() + return +} + +// Flush flush all bytes in current connection. +func (p *Transport) Flush() (err error) { + if p == nil || p.conn == nil { + err = errTransportClosed + return + } + err = p.conn.Flush() + return +} + +// Close close current transport. +func (p *Transport) Close() (err error) { + p.once.Do(func() { + err = p.conn.Close() + }) + return +} + +// ReadFirst reads first frame. +func (p *Transport) ReadFirst(ctx context.Context) (frame core.Frame, err error) { + select { + case <-ctx.Done(): + err = ctx.Err() + default: + frame, err = p.conn.Read() + if err != nil { + err = errors.Wrap(err, "read first frame failed") + } + } + if err != nil { + _ = p.Close() + } + return +} + +// Start start transport. +func (p *Transport) Start(ctx context.Context) error { + defer p.Close() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + f, err := p.conn.Read() + if err == nil { + err = p.DispatchFrame(ctx, f) + } + if err == nil { + continue + } + if errors.Is(err, io.EOF) { + return nil + } + return errors.Wrap(err, "read and delivery frame failed") + } + } +} + +// DispatchFrame delivery incoming frames. +func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err error) { + header := frame.Header() + t := header.Type() + sid := header.StreamID() + + var handler FrameHandler + + switch t { + case core.FrameTypeSetup: + p.maxLifetime = frame.(*framing.SetupFrame).MaxLifetime() + handler = p.handlers[OnSetup] + case core.FrameTypeResume: + handler = p.handlers[OnResume] + case core.FrameTypeResumeOK: + p.lastRcvPos = frame.(*framing.ResumeOKFrame).LastReceivedClientPosition() + handler = p.handlers[OnResumeOK] + case core.FrameTypeRequestFNF: + handler = p.handlers[OnFireAndForget] + case core.FrameTypeMetadataPush: + if sid != 0 { + // skip invalid metadata push + logger.Warnf("rsocket.Transport: omit MetadataPush with non-zero stream id %d\n", sid) + return + } + handler = p.handlers[OnMetadataPush] + case core.FrameTypeRequestResponse: + handler = p.handlers[OnRequestResponse] + case core.FrameTypeRequestStream: + handler = p.handlers[OnRequestStream] + case core.FrameTypeRequestChannel: + handler = p.handlers[OnRequestChannel] + case core.FrameTypePayload: + handler = p.handlers[OnPayload] + case core.FrameTypeRequestN: + handler = p.handlers[OnRequestN] + case core.FrameTypeError: + if sid == 0 { + err = errors.New(frame.(*framing.ErrorFrame).Error()) + if call := p.handlers[OnErrorWithZeroStreamID]; call != nil { + _ = call(frame) + } + return + } + handler = p.handlers[OnError] + case core.FrameTypeCancel: + handler = p.handlers[OnCancel] + case core.FrameTypeKeepalive: + ka := frame.(*framing.KeepaliveFrame) + p.lastRcvPos = ka.LastReceivedPosition() + handler = p.handlers[OnKeepalive] + case core.FrameTypeLease: + handler = p.handlers[OnLease] + } + + // Set deadline. + deadline := time.Now().Add(p.maxLifetime) + err = p.conn.SetDeadline(deadline) + if err != nil { + return + } + + // missing handler + if handler == nil { + err = errNoHandler + return + } + + // trigger handler + err = handler(frame) + if err != nil { + err = errors.Wrap(err, "exec frame handler failed") + } + return +} + +func NewTransport(c Conn) *Transport { + return &Transport{ + conn: c, + maxLifetime: common.DefaultKeepaliveMaxLifetime, + } +} + +func IsNoHandlerError(err error) bool { + return err == errNoHandler +} diff --git a/core/transport/transport_test.go b/core/transport/transport_test.go new file mode 100644 index 0000000..a34e77a --- /dev/null +++ b/core/transport/transport_test.go @@ -0,0 +1,239 @@ +package transport_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func Init(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { + ctrl := gomock.NewController(t) + conn := NewMockConn(ctrl) + tp := transport.NewTransport(conn) + return ctrl, conn, tp +} + +func TestTransport_Start(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().Return(nil).Times(1) + + conn.EXPECT().Read().Return(nil, fakeErr).Times(1) + err := tp.Start(context.Background()) + assert.Error(t, err, "should be an error") + assert.True(t, errors.Cause(err) == fakeErr, "should be the fake error") + + conn.EXPECT().Read().Return(nil, io.EOF).Times(1) + err = tp.Start(context.Background()) + assert.NoError(t, err, "there should be no error here if io.EOF occurred") +} + +func TestTransport_RegisterHandler(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + var cursor int + fakeToken := []byte("fake-token") + fakeDada := []byte("fake-data") + fakeMetadata := []byte("fake-metadata") + fakeMimeType := []byte("fake-mime-type") + fakeFlag := core.FrameFlag(0) + frames := []core.Frame{ + framing.NewSetupFrame( + core.NewVersion(1, 0), + 30*time.Second, + 90*time.Second, + nil, + fakeMimeType, + fakeMimeType, + fakeDada, + fakeMetadata, + false), + framing.NewMetadataPushFrame(fakeDada), + framing.NewFireAndForgetFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestResponseFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestStreamFrame(1, 1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestChannelFrame(1, 1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestNFrame(1, 1, fakeFlag), + framing.NewKeepaliveFrame(1, fakeDada, true), + framing.NewCancelFrame(1), + framing.NewErrorFrame(1, core.ErrorCodeApplicationError, fakeDada), + framing.NewLeaseFrame(30*time.Second, 1, fakeMetadata), + framing.NewPayloadFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewResumeFrame(core.DefaultVersion, fakeToken, 1, 1), + framing.NewResumeOKFrame(1), + framing.NewErrorFrame(0, core.ErrorCodeRejected, fakeDada), + } + conn.EXPECT(). + Read(). + DoAndReturn(func() (core.Frame, error) { + defer func() { + cursor++ + }() + if cursor >= len(frames) { + return nil, io.EOF + } + return frames[cursor], nil + }). + AnyTimes() + + calls := make(map[core.FrameType]int) + fakeHandler := func(frame core.Frame) (err error) { + typ := frame.Header().Type() + calls[typ] = calls[typ] + 1 + return nil + } + + callsErrorWithZeroStreamID := atomic.NewInt32(0) + + tp.RegisterHandler(transport.OnSetup, fakeHandler) + tp.RegisterHandler(transport.OnRequestResponse, fakeHandler) + tp.RegisterHandler(transport.OnFireAndForget, fakeHandler) + tp.RegisterHandler(transport.OnMetadataPush, fakeHandler) + tp.RegisterHandler(transport.OnRequestStream, fakeHandler) + tp.RegisterHandler(transport.OnRequestChannel, fakeHandler) + tp.RegisterHandler(transport.OnKeepalive, fakeHandler) + tp.RegisterHandler(transport.OnRequestN, fakeHandler) + tp.RegisterHandler(transport.OnPayload, fakeHandler) + tp.RegisterHandler(transport.OnError, fakeHandler) + tp.RegisterHandler(transport.OnCancel, fakeHandler) + tp.RegisterHandler(transport.OnResumeOK, fakeHandler) + tp.RegisterHandler(transport.OnResume, fakeHandler) + tp.RegisterHandler(transport.OnLease, fakeHandler) + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { + callsErrorWithZeroStreamID.Inc() + return + }) + + toHaveBeenCalled := func(typ core.FrameType, n int) { + called, ok := calls[typ] + assert.True(t, ok, "%s have not been called", typ) + assert.Equal(t, n, called, "%s have not been called %d times", n) + } + + err := tp.Start(context.Background()) + assert.Error(t, err, "should be no error") + + for _, typ := range []core.FrameType{ + core.FrameTypeSetup, + core.FrameTypeRequestResponse, + core.FrameTypeRequestFNF, + core.FrameTypeMetadataPush, + core.FrameTypeRequestStream, + core.FrameTypeRequestChannel, + core.FrameTypeKeepalive, + core.FrameTypeRequestN, + core.FrameTypePayload, + core.FrameTypeError, + core.FrameTypeCancel, + core.FrameTypeResume, + core.FrameTypeResumeOK, + core.FrameTypeLease, + } { + toHaveBeenCalled(typ, 1) + } + assert.Equal(t, int32(1), callsErrorWithZeroStreamID.Load(), "error frame with zero stream id has not been called") +} + +func TestTransport_ReadFirst(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().AnyTimes() + + conn.EXPECT().Read().Return(nil, fakeErr).Times(1) + _, err := tp.ReadFirst(context.Background()) + assert.Error(t, err, "should be error") + + expect := framing.NewCancelFrame(1) + conn.EXPECT().Read().Return(expect, nil).Times(1) + actual, err := tp.ReadFirst(context.Background()) + assert.NoError(t, err, "should not be error") + assert.Equal(t, expect, actual, "not match") +} + +func TestTransport_Send(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Write(gomock.Any()).Times(2) + conn.EXPECT().Flush().Times(1) + + var err error + + err = tp.Send(framing.NewCancelFrame(1), false) + assert.NoError(t, err, "send failed") + + err = tp.Send(framing.NewCancelFrame(1), true) + assert.NoError(t, err, "send failed") +} + +func TestTransport_Connection(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + c := tp.Connection() + assert.Equal(t, conn, c) +} + +func TestTransport_Flush(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Flush().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + + err := tp.Flush() + assert.NoError(t, err, "flush failed") + conn.SetCounter(core.NewTrafficCounter()) +} + +func TestTransport_Close(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().Times(1) + + err := tp.Close() + assert.NoError(t, err, "close transport failed") +} + +func TestTransport_HandlerReturnsError(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + conn.EXPECT().Close().Times(1) + conn.EXPECT().Read().Return(framing.NewCancelFrame(1), nil).Times(1) + + tp.RegisterHandler(transport.OnCancel, func(_ core.Frame) error { + return fakeErr + }) + err := tp.Start(context.Background()) + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fakeError") +} + +func TestTransport_EmptyHandler(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + conn.EXPECT().Close().Times(1) + conn.EXPECT().Read().Return(framing.NewCancelFrame(1), nil).Times(1) + + err := tp.Start(context.Background()) + assert.True(t, transport.IsNoHandlerError(errors.Cause(err)), "should be no handler error") +} diff --git a/internal/transport/connection.go b/core/transport/types.go similarity index 58% rename from internal/transport/connection.go rename to core/transport/types.go index 767cda0..4d76c43 100644 --- a/internal/transport/connection.go +++ b/core/transport/types.go @@ -1,10 +1,16 @@ package transport import ( + "context" "io" "time" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" +) + +type ( + ClientTransportFunc func(context.Context) (*Transport, error) + ServerTransportFunc func(context.Context) (ServerTransport, error) ) // Conn is connection for RSocket. @@ -14,11 +20,11 @@ type Conn interface { // After this deadline, connection will be closed. SetDeadline(deadline time.Time) error // SetCounter bind a counter which can count r/w bytes. - SetCounter(c *Counter) + SetCounter(c *core.TrafficCounter) // Read reads next frame from Conn. - Read() (framing.Frame, error) + Read() (core.Frame, error) // Write writes a frame to Conn. - Write(frames framing.Frame) error + Write(core.WriteableFrame) error // Flush. Flush() error } diff --git a/core/transport/websocket_conn.go b/core/transport/websocket_conn.go new file mode 100644 index 0000000..e0d28db --- /dev/null +++ b/core/transport/websocket_conn.go @@ -0,0 +1,123 @@ +package transport + +import ( + "bytes" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/logger" +) + +var _buffPool = sync.Pool{ + New: func() interface{} { return &bytes.Buffer{} }, +} + +type RawWsConn interface { + io.Closer + SetReadDeadline(time.Time) error + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error +} + +type WsConn struct { + c RawWsConn + counter *core.TrafficCounter +} + +func (p *WsConn) SetCounter(c *core.TrafficCounter) { + p.counter = c +} + +func (p *WsConn) SetDeadline(deadline time.Time) error { + return p.c.SetReadDeadline(deadline) +} + +func (p *WsConn) Read() (f core.Frame, err error) { + t, raw, err := p.c.ReadMessage() + + if err == io.EOF { + return + } + + if websocket.IsCloseError(err) || websocket.IsUnexpectedCloseError(err) || isClosedErr(err) { + err = io.EOF + return + } + + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + + // Skip non-binary message + if t != websocket.BinaryMessage { + return p.Read() + } + + f, err = framing.FromBytes(raw) + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + + if p.counter != nil && f.Header().Resumable() { + p.counter.IncReadBytes(f.Len()) + } + + err = f.Validate() + if err != nil { + err = errors.Wrap(err, "read frame failed") + return + } + if logger.IsDebugEnabled() { + logger.Debugf("<--- rcv: %s\n", f) + } + return +} + +func (p *WsConn) Flush() (err error) { + return +} + +func (p *WsConn) Write(frame core.WriteableFrame) (err error) { + size := frame.Len() + bf := _buffPool.Get().(*bytes.Buffer) + defer func() { + bf.Reset() + _buffPool.Put(bf) + }() + _, err = frame.WriteTo(bf) + if err != nil { + return + } + err = p.c.WriteMessage(websocket.BinaryMessage, bf.Bytes()) + if err == io.EOF { + return + } + if err != nil { + err = errors.Wrap(err, "write frame failed") + return + } + if p.counter != nil && frame.Header().Resumable() { + p.counter.IncWriteBytes(size) + } + if logger.IsDebugEnabled() { + logger.Debugf("---> snd: %s\n", frame) + } + return +} + +func (p *WsConn) Close() error { + return p.c.Close() +} + +func NewWebsocketConnection(rawConn RawWsConn) *WsConn { + return &WsConn{ + c: rawConn, + } +} diff --git a/core/transport/websocket_conn_mock_test.go b/core/transport/websocket_conn_mock_test.go new file mode 100644 index 0000000..7a73349 --- /dev/null +++ b/core/transport/websocket_conn_mock_test.go @@ -0,0 +1,92 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/websocket_conn.go + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" + time "time" +) + +// mockRawWsConn is a mock of RawWsConn interface +type mockRawWsConn struct { + ctrl *gomock.Controller + recorder *mockRawWsConnMockRecorder +} + +// mockRawWsConnMockRecorder is the mock recorder for mockRawWsConn +type mockRawWsConnMockRecorder struct { + mock *mockRawWsConn +} + +// newMockRawWsConn creates a new mock instance +func newMockRawWsConn(ctrl *gomock.Controller) *mockRawWsConn { + mock := &mockRawWsConn{ctrl: ctrl} + mock.recorder = &mockRawWsConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockRawWsConn) EXPECT() *mockRawWsConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *mockRawWsConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockRawWsConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockRawWsConn)(nil).Close)) +} + +// SetReadDeadline mocks base method +func (m *mockRawWsConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *mockRawWsConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*mockRawWsConn)(nil).SetReadDeadline), arg0) +} + +// ReadMessage mocks base method +func (m *mockRawWsConn) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage +func (mr *mockRawWsConnMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*mockRawWsConn)(nil).ReadMessage)) +} + +// WriteMessage mocks base method +func (m *mockRawWsConn) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage +func (mr *mockRawWsConnMockRecorder) WriteMessage(messageType, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*mockRawWsConn)(nil).WriteMessage), messageType, data) +} diff --git a/core/transport/websocket_conn_test.go b/core/transport/websocket_conn_test.go new file mode 100644 index 0000000..703aa6b --- /dev/null +++ b/core/transport/websocket_conn_test.go @@ -0,0 +1,169 @@ +package transport_test + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +func InitMockWsConn(t *testing.T) (*gomock.Controller, *mockRawWsConn, *transport.WsConn) { + ctrl := gomock.NewController(t) + rawConn := newMockRawWsConn(ctrl) + conn := transport.NewWebsocketConnection(rawConn) + return ctrl, rawConn, conn +} + +func TestWsConn_Read_Empty(t *testing.T) { + ctrl, rawConn, conn := InitMockWsConn(t) + defer ctrl.Finish() + + rawConn.EXPECT().ReadMessage().Return(0, nil, fakeErr).AnyTimes() + _, err := conn.Read() + assert.Error(t, err, "should read failed") +} + +func TestWsConn_Read(t *testing.T) { + ctrl, rc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewTrafficCounter() + wc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var ( + writtenBytesSlice [][]byte + writtenBytes int + ) + + for _, frame := range toBeWritten { + n := frame.Len() + if frame.Header().Resumable() { + writtenBytes += n + } + b := &bytes.Buffer{} + _, err := frame.WriteTo(b) + assert.NoError(t, err, "write frame failed") + writtenBytesSlice = append(writtenBytesSlice, b.Bytes()) + } + + cursor := 0 + rc.EXPECT(). + ReadMessage(). + DoAndReturn(func() (int, []byte, error) { + defer func() { + cursor++ + }() + if cursor >= len(writtenBytesSlice) { + return 0, nil, io.EOF + } + return websocket.BinaryMessage, writtenBytesSlice[cursor], nil + }). + AnyTimes() + + var results []core.Frame + for { + next, err := wc.Read() + if err == io.EOF { + break + } + assert.NoError(t, err, "read next frame failed") + results = append(results, next) + } + + assert.Equal(t, len(toBeWritten), len(results), "result amount does not match") + for i := 0; i < len(results); i++ { + assert.Equal(t, toBeWritten[i].Header(), results[i].Header(), "header does not match") + } + assert.Equal(t, writtenBytes, int(c.ReadBytes()), "read bytes doesn't match") +} + +func TestWsConn_SetDeadline(t *testing.T) { + ctrl, mc, c := InitMockWsConn(t) + defer ctrl.Finish() + + mc.EXPECT().SetReadDeadline(gomock.Any()).Times(1) + err := c.SetDeadline(time.Now()) + assert.NoError(t, err, "call setDeadline failed") +} + +func TestWsConn_Flush_Nothing(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewTrafficCounter() + wc.SetCounter(c) + + mc.EXPECT().WriteMessage(websocket.BinaryMessage, gomock.Any()).Times(0) + + err := wc.Flush() + assert.NoError(t, err, "flush failed") + assert.Equal(t, 0, int(c.WriteBytes()), "bytes written should be zero") +} + +func TestWsConn_WriteWithBrokenConn(t *testing.T) { + logger.SetLogger(nil) + logger.SetLevel(logger.LevelDebug) + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + mc.EXPECT(). + WriteMessage(websocket.BinaryMessage, gomock.Any()). + Return(fakeErr). + AnyTimes() + err := wc.Write(framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0)) + assert.Equal(t, fakeErr, errors.Cause(err), "should be fake error") +} + +func TestWsConn_Write(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewTrafficCounter() + wc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + mc.EXPECT(). + WriteMessage(websocket.BinaryMessage, gomock.Any()). + Return(nil). + Times(len(toBeWritten)) + + var bytesWritten uint64 + for _, frame := range toBeWritten { + n := frame.Len() + err := wc.Write(frame) + assert.NoError(t, err, "write failed") + if frame.Header().Resumable() { + bytesWritten += uint64(n) + } + } + err := wc.Flush() + assert.NoError(t, err) + assert.Equal(t, bytesWritten, c.WriteBytes(), "write bytes doesn't match") +} + +func TestWsConn_Close(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + mc.EXPECT().Close().Return(fakeErr).Times(1) + err := wc.Close() + assert.Equal(t, fakeErr, err, "should return fake error") +} diff --git a/internal/transport/transport_ws.go b/core/transport/websocket_transport.go similarity index 66% rename from internal/transport/transport_ws.go rename to core/transport/websocket_transport.go index f22a168..681f1d4 100644 --- a/internal/transport/transport_ws.go +++ b/core/transport/websocket_transport.go @@ -6,8 +6,6 @@ import ( "io" "net" "net/http" - "os" - "strings" "sync" "time" @@ -21,30 +19,22 @@ const defaultWebsocketPath = "/" var upgrader websocket.Upgrader func init() { - // Default allow CORS. - cors := true - if v, ok := os.LookupEnv("RSOCKET_WS_CORS"); ok { - v = strings.TrimSpace(strings.ToLower(v)) - cors = v == "yes" || v == "on" || v == "1" || v == "true" - } upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - } - if cors { - upgrader.CheckOrigin = func(r *http.Request) bool { + CheckOrigin: func(r *http.Request) bool { return true - } + }, } } type wsServerTransport struct { - addr string - path string - acceptor ServerTransportAcceptor - onceClose sync.Once - listener net.Listener - tls *tls.Config + path string + acceptor ServerTransportAcceptor + onceClose sync.Once + listenerFn func() (net.Listener, error) + listener net.Listener + transports *sync.Map } func (p *wsServerTransport) Close() (err error) { @@ -68,21 +58,18 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} logger.Errorf("create websocket conn failed: %s\n", err.Error()) return } - go func(c *websocket.Conn, ctx context.Context) { - conn := newWebsocketConnection(c) - tp := newTransportClient(conn) - p.acceptor(ctx, tp) - }(c, ctx) + tp := NewTransport(NewWebsocketConnection(c)) + p.transports.Store(tp, struct{}{}) + go p.acceptor(ctx, tp, func(tp *Transport) { + p.transports.Delete(tp) + }) }) - if p.tls == nil { - p.listener, err = net.Listen("tcp", p.addr) - } else { - p.listener, err = tls.Listen("tcp", p.addr, p.tls) - } + p.listener, err = p.listenerFn() if err != nil { err = errors.Wrap(err, "server listen failed") + close(notifier) return } @@ -110,18 +97,27 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} return } -func newWebsocketServerTransport(addr string, path string, c *tls.Config) *wsServerTransport { +func NewWebsocketServerTransport(gen func() (net.Listener, error), path string) ServerTransport { if path == "" { path = defaultWebsocketPath } return &wsServerTransport{ - addr: addr, - path: path, - tls: c, + path: path, + listenerFn: gen, + transports: &sync.Map{}, } } -func newWebsocketClientTransport(url string, tc *tls.Config, header http.Header) (*Transport, error) { +func NewWebsocketServerTransportWithAddr(addr string, path string, c *tls.Config) ServerTransport { + return NewWebsocketServerTransport(func() (net.Listener, error) { + if c == nil { + return net.Listen("tcp", addr) + } + return tls.Listen("tcp", addr, c) + }, path) +} + +func NewWebsocketClientTransport(url string, tc *tls.Config, header http.Header) (*Transport, error) { var d *websocket.Dialer if tc == nil { d = websocket.DefaultDialer @@ -136,5 +132,5 @@ func newWebsocketClientTransport(url string, tc *tls.Config, header http.Header) if err != nil { return nil, errors.Wrap(err, "dial websocket failed") } - return newTransportClient(newWebsocketConnection(wsConn)), nil + return NewTransport(NewWebsocketConnection(wsConn)), nil } diff --git a/core/transport/websocket_transport_test.go b/core/transport/websocket_transport_test.go new file mode 100644 index 0000000..6c3711c --- /dev/null +++ b/core/transport/websocket_transport_test.go @@ -0,0 +1,121 @@ +package transport_test + +import ( + "bytes" + "context" + "io" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketServerTransport(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + listener := newMockNetListener(ctrl) + + fakeConnChan := make(chan net.Conn, 1) + + conn := newMockNetConn(ctrl) + + conn.EXPECT(). + RemoteAddr(). + DoAndReturn(func() net.Addr { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") + assert.NoError(t, err, "bad addr") + return addr + }). + AnyTimes() + conn.EXPECT(). + LocalAddr(). + DoAndReturn(func() net.Addr { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:18080") + assert.NoError(t, err, "bad addr") + return addr + }). + AnyTimes() + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Close().Return(nil).Times(1) + bf := &bytes.Buffer{} + conn.EXPECT(). + Read(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Read(b) + }). + AnyTimes() + conn.EXPECT(). + Write(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Write(b) + }). + AnyTimes() + + fakeConnChan <- conn + + listener.EXPECT(). + Accept(). + DoAndReturn(func() (net.Conn, error) { + c, ok := <-fakeConnChan + if !ok { + return nil, io.EOF + } + return c, nil + }). + AnyTimes() + listener.EXPECT().Close().AnyTimes() + + tp := transport.NewWebsocketServerTransport(func() (net.Listener, error) { + return listener, nil + }, "") + + notifier := make(chan struct{}) + + done := make(chan struct{}) + + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.True(t, err == nil || err == io.EOF) + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier should return ok=true") + + time.Sleep(100 * time.Millisecond) + + close(fakeConnChan) + + <-done +} + +func TestNewWebsocketServerTransport_Broken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tp := transport.NewWebsocketServerTransport(func() (net.Listener, error) { + return nil, fakeErr + }, "") + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + }) + + notifier := make(chan struct{}) + + done := make(chan struct{}) + + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Equal(t, fakeErr, errors.Cause(err)) + }() + + _, ok := <-notifier + assert.False(t, ok, "notifier should return ok=false") + + <-done +} diff --git a/core/types.go b/core/types.go new file mode 100644 index 0000000..a646d8b --- /dev/null +++ b/core/types.go @@ -0,0 +1,128 @@ +package core + +import ( + "io" + "strings" +) + +// FrameType is type of frame. +type FrameType uint8 + +// All frame types +const ( + FrameTypeReserved FrameType = 0x00 + FrameTypeSetup FrameType = 0x01 + FrameTypeLease FrameType = 0x02 + FrameTypeKeepalive FrameType = 0x03 + FrameTypeRequestResponse FrameType = 0x04 + FrameTypeRequestFNF FrameType = 0x05 + FrameTypeRequestStream FrameType = 0x06 + FrameTypeRequestChannel FrameType = 0x07 + FrameTypeRequestN FrameType = 0x08 + FrameTypeCancel FrameType = 0x09 + FrameTypePayload FrameType = 0x0A + FrameTypeError FrameType = 0x0B + FrameTypeMetadataPush FrameType = 0x0C + FrameTypeResume FrameType = 0x0D + FrameTypeResumeOK FrameType = 0x0E + FrameTypeExt FrameType = 0x3F +) + +func (f FrameType) String() string { + switch f { + case FrameTypeReserved: + return "RESERVED" + case FrameTypeSetup: + return "SETUP" + case FrameTypeLease: + return "LEASE" + case FrameTypeKeepalive: + return "KEEPALIVE" + case FrameTypeRequestResponse: + return "REQUEST_RESPONSE" + case FrameTypeRequestFNF: + return "REQUEST_FNF" + case FrameTypeRequestStream: + return "REQUEST_STREAM" + case FrameTypeRequestChannel: + return "REQUEST_CHANNEL" + case FrameTypeRequestN: + return "REQUEST_N" + case FrameTypeCancel: + return "CANCEL" + case FrameTypePayload: + return "PAYLOAD" + case FrameTypeError: + return "ERROR" + case FrameTypeMetadataPush: + return "METADATA_PUSH" + case FrameTypeResume: + return "RESUME" + case FrameTypeResumeOK: + return "RESUME_OK" + case FrameTypeExt: + return "EXT" + default: + return "UNKNOWN" + } +} + +// FrameFlag is flag of frame. +type FrameFlag uint16 + +func (f FrameFlag) String() string { + foo := make([]string, 0) + if f.Check(FlagNext) { + foo = append(foo, "N") + } + if f.Check(FlagComplete) { + foo = append(foo, "CL") + } + if f.Check(FlagFollow) { + foo = append(foo, "FRS") + } + if f.Check(FlagMetadata) { + foo = append(foo, "M") + } + if f.Check(FlagIgnore) { + foo = append(foo, "I") + } + return strings.Join(foo, "|") +} + +// All frame flags +const ( + FlagNext FrameFlag = 1 << (5 + iota) + FlagComplete + FlagFollow + FlagMetadata + FlagIgnore + + FlagResume = FlagFollow + FlagLease = FlagComplete + FlagRespond = FlagFollow +) + +// Check returns true if mask exists. +func (f FrameFlag) Check(flag FrameFlag) bool { + return flag&f == flag +} + +type WriteableFrame interface { + io.WriterTo + // FrameHeader returns frame FrameHeader. + Header() FrameHeader + // Len returns length of frame. + Len() int + // Done marks current frame has been sent. + Done() (closed bool) + // DoneNotify notifies when frame done. + DoneNotify() <-chan struct{} +} + +// Frame is a single message containing a request, response, or protocol processing. +type Frame interface { + WriteableFrame + // Validate returns error if frame is invalid. + Validate() error +} diff --git a/core/types_test.go b/core/types_test.go new file mode 100644 index 0000000..f40db3d --- /dev/null +++ b/core/types_test.go @@ -0,0 +1,13 @@ +package core_test + +import ( + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestFrameFlag_String(t *testing.T) { + f := core.FlagNext | core.FlagComplete | core.FlagFollow | core.FlagMetadata | core.FlagIgnore + assert.True(t, f.String() != "") +} diff --git a/core/version.go b/core/version.go new file mode 100644 index 0000000..a847d56 --- /dev/null +++ b/core/version.go @@ -0,0 +1,80 @@ +package core + +import ( + "encoding/binary" + "io" + "strconv" + "strings" +) + +// DefaultVersion is default protocol version. +var DefaultVersion Version = [2]uint16{1, 0} + +// Version define the version of protocol. +// It includes major and minor version. +type Version [2]uint16 + +// Bytes returns raw bytes of current version. +func (p Version) Bytes() []byte { + bs := make([]byte, 4) + binary.BigEndian.PutUint16(bs, p[0]) + binary.BigEndian.PutUint16(bs[2:], p[1]) + return bs +} + +// Major returns major version. +func (p Version) Major() uint16 { + return p[0] +} + +// Minor returns minor version. +func (p Version) Minor() uint16 { + return p[1] +} + +func (p Version) Equals(version Version) bool { + return p.Major() == version.Major() && p.Minor() == version.Minor() +} + +func (p Version) GreaterThan(version Version) bool { + if p[0] == version[0] { + return p[1] > version[1] + } + return p[0] > version[0] +} + +func (p Version) LessThan(version Version) bool { + if p[0] == version[0] { + return p[1] < version[1] + } + return p[0] < version[0] +} + +// WriteTo write raw version bytes to a writer. +func (p Version) WriteTo(w io.Writer) (n int64, err error) { + err = binary.Write(w, binary.BigEndian, p[0]) + if err != nil { + return + } + err = binary.Write(w, binary.BigEndian, p[1]) + if err != nil { + return + } + n = 4 + return +} + +func (p Version) String() string { + b := strings.Builder{} + b.WriteString(strconv.Itoa(int(p[0]))) + b.WriteByte('.') + b.WriteString(strconv.Itoa(int(p[1]))) + return b.String() +} + +// NewVersion creates a new Version from major and minor. +func NewVersion(major, minor uint16) Version { + return Version{ + major, minor, + } +} diff --git a/core/version_test.go b/core/version_test.go new file mode 100644 index 0000000..0cceb8b --- /dev/null +++ b/core/version_test.go @@ -0,0 +1,75 @@ +package core_test + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func BenchmarkVersion_String(b *testing.B) { + v := core.NewVersion(2, 3) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = v.String() + } +} + +func TestVersion(t *testing.T) { + var ( + major uint16 = 2 + minor uint16 = 1 + ) + v := core.NewVersion(major, minor) + assert.Equal(t, "2.1", v.String()) + assert.Equal(t, uint16(2), v.Major(), "wrong major version") + assert.Equal(t, uint16(1), v.Minor(), "wrong minor version") + checkBytes(t, v.Bytes(), 2, 1) + w := bytes.Buffer{} + n, err := v.WriteTo(&w) + assert.NoError(t, err, "write version failed") + assert.Equal(t, int64(4), n, "wrong wrote bytes length") + checkBytes(t, v.Bytes(), 2, 1) +} + +func TestVersion_Equals(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + + assert.True(t, v1.Equals(v2)) + assert.True(t, v2.Equals(v1)) + assert.False(t, v1.Equals(v3)) + assert.False(t, v1.Equals(v4)) +} + +func TestVersion_GreaterThan(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + assert.False(t, v1.GreaterThan(v2)) + assert.False(t, v1.GreaterThan(v3)) + assert.True(t, v1.GreaterThan(v4)) +} + +func TestVersion_LessThan(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + assert.False(t, v1.LessThan(v2)) + assert.True(t, v1.LessThan(v3)) + assert.False(t, v1.LessThan(v4)) +} + +func checkBytes(t *testing.T, b []byte, expectMajor, expectMinor uint16) { + assert.Equal(t, 4, len(b), "wrong version bytes") + major := binary.BigEndian.Uint16(b[:2]) + minor := binary.BigEndian.Uint16(b[2:]) + assert.Equal(t, expectMajor, major, "wrong major version") + assert.Equal(t, expectMinor, minor, "wrong minor version") +} diff --git a/examples/echo/echo.go b/examples/echo/echo.go index 1d8ed6f..823af98 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -11,20 +11,24 @@ import ( "strings" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" ) -const ListenAt = "tcp://127.0.0.1:7878" +var tp transport.ServerTransportFunc -//const ListenAt = "unix:///tmp/rsocket.echo.sock" -//const ListenAt = "ws://127.0.0.1:7878/echo" +func init() { + tp = rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).Build() +} func main() { go func() { + http.Handle("/metrics", promhttp.Handler()) log.Println(http.ListenAndServe(":4444", nil)) }() //logger.SetLevel(logger.LevelDebug) @@ -32,7 +36,7 @@ func main() { //Fragment(65535). //Resume(). OnStart(func() { - log.Println("server is listening:", ListenAt) + log.Println("server start success!") }). Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { //log.Println("SETUP BEGIN:----------------") @@ -61,7 +65,7 @@ func main() { } return responder(), nil }). - Transport(ListenAt). + Transport(tp). Serve(context.Background()) if err != nil { panic(err) @@ -139,9 +143,10 @@ func responder() rsocket.RSocket { //return payloads.(flux.Flux) payloads.(flux.Flux). //LimitRate(1). - SubscribeOn(scheduler.Elastic()). - DoOnNext(func(elem payload.Payload) { + SubscribeOn(scheduler.Parallel()). + DoOnNext(func(elem payload.Payload) error { log.Println("receiving:", elem) + return nil }). Subscribe(context.Background()) return flux.Create(func(i context.Context, sink flux.Sink) { diff --git a/examples/echo/echo_benchmark_test.go b/examples/echo_bench/echo_bench.go similarity index 55% rename from examples/echo/echo_benchmark_test.go rename to examples/echo_bench/echo_bench.go index 3507701..647f7a1 100644 --- a/examples/echo/echo_benchmark_test.go +++ b/examples/echo_bench/echo_bench.go @@ -3,61 +3,74 @@ package main import ( "bytes" "context" - "fmt" + "flag" "log" - _ "net/http/pprof" + "math/rand" "sync" - "testing" "time" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/mono" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestClient_RequestResponse(t *testing.T) { - client, err := createClient(ListenAt) - require.NoError(t, err, "bad client") - defer func() { - _ = client.Close() - }() +var tp transport.ClientTransportFunc + +func init() { + flag.Parse() + rand.Seed(time.Now().UnixNano()) + tp = rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() +} + +func main() { + var ( + n int + payloadSize int + mtu int + ) + flag.IntVar(&n, "n", 100*10000, "request amount.") + flag.IntVar(&payloadSize, "size", 1024, "payload data size.") + flag.IntVar(&mtu, "mtu", 0, "mut size, zero means disabled.") + + client, err := createClient(mtu) + if err != nil { + panic(err) + } + defer client.Close() wg := &sync.WaitGroup{} - n := 100 * 10000 + wg.Add(n) - data := []byte(common.RandAlphanumeric(1024)) + data := make([]byte, payloadSize) + rand.Read(data) now := time.Now() ctx := context.Background() sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { - assert.Equal(t, data, input.Data(), "data doesn't match") + rx.OnNext(func(input payload.Payload) error { //m2, _ := elem.MetadataUTF8() //assert.Equal(t, m1, m2, "metadata doesn't match") wg.Done() + return nil }), ) for i := 0; i < n; i++ { - m1 := []byte(fmt.Sprintf("benchmark_test_%d", i)) - client.RequestResponse(payload.New(data, m1)).SubscribeOn(scheduler.Elastic()).SubscribeWith(ctx, sub) + client.RequestResponse(payload.New(data, nil)).SubscribeOn(scheduler.Parallel()).SubscribeWith(ctx, sub) } wg.Wait() cost := time.Since(now) log.Println(n, "COST:", cost) log.Println(n, "QPS:", float64(n)/cost.Seconds()) - } -func createClient(uri string) (rsocket.Client, error) { +func createClient(mtu int) (rsocket.Client, error) { + return rsocket.Connect(). - //Fragment(1024). - //Resume(). + Fragment(mtu). SetupPayload(payload.NewString("你好", "世界")). Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { return rsocket.NewAbstractSocket( @@ -70,6 +83,6 @@ func createClient(uri string) (rsocket.Client, error) { }), ) }). - Transport(uri). + Transport(tp). Start(context.Background()) } diff --git a/examples/fibonacci/main.go b/examples/fibonacci/main.go index 400e90d..5b32465 100644 --- a/examples/fibonacci/main.go +++ b/examples/fibonacci/main.go @@ -14,7 +14,6 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -const transportString = "tcp://127.0.0.1:7878" const number = 13 func main() { @@ -77,7 +76,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestStreamHandler), nil }). // specify transport - Transport(transportString). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). // serve will block execution unless an error occurred Serve(context.Background()) @@ -86,7 +85,8 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection - client, err := rsocket.Connect().Transport(transportString).Start(context.Background()) + tp := rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() + client, err := rsocket.Connect().Transport(tp).Start(context.Background()) if err != nil { panic(err) } @@ -103,11 +103,12 @@ func client() { wg := sync.WaitGroup{} wg.Add(1) - f.DoOnNext(func(input payload.Payload) { + f.DoOnNext(func(input payload.Payload) error { // print each number in a stream fmt.Println(input.DataUTF8()) + return nil }).DoOnComplete(func() { - // will be called on successfull completion of the stream + // will be called on successful completion of the stream fmt.Println("Fibonacci sequence done") }).DoOnError(func(err error) { // will be called if a error occurs diff --git a/examples/lease/main.go b/examples/lease/main.go deleted file mode 100644 index b7023b6..0000000 --- a/examples/lease/main.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "context" - "log" - "time" - - "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/lease" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx/mono" -) - -func main() { - les, _ := lease.NewSimpleLease(10*time.Second, 7*time.Second, 1*time.Second, 5) - err := rsocket.Receive(). - Lease(les). - Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (socket rsocket.RSocket, e error) { - socket = rsocket.NewAbstractSocket( - rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - return mono.Just(msg) - }), - ) - return - }). - Transport("tcp://127.0.0.1:7878"). - Serve(context.Background()) - - if err != nil { - log.Fatal(err) - } -} diff --git a/examples/lease/main_test.go b/examples/lease/main_test.go deleted file mode 100644 index ec6c228..0000000 --- a/examples/lease/main_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package main_test - -import ( - "context" - "log" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.uber.org/atomic" - - "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/payload" -) - -func TestClientWithLease(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), 20*time.Second) - cli, err := rsocket.Connect(). - Lease(). - Transport("tcp://127.0.0.1:7878"). - Start(ctx) - if err != nil { - panic(err) - } - defer func() { - _ = cli.Close() - }() - - success := atomic.NewUint32(0) - -Loop: - for { - select { - case <-ctx.Done(): - break Loop - default: - time.Sleep(1 * time.Second) - v, err := cli.RequestResponse(payload.NewString("hello world", "go")).Block(context.Background()) - if err != nil { - log.Println("request failed:", err) - } else { - success.Inc() - log.Println("request success:", v) - } - } - } - assert.Equal(t, uint32(10), success.Load(), "bad requests") -} diff --git a/examples/word_counter/main.go b/examples/word_counter/main.go index 1767745..f3c421c 100644 --- a/examples/word_counter/main.go +++ b/examples/word_counter/main.go @@ -13,7 +13,6 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -const transportString = "tcp://127.0.0.1:7878" const number = 13 func main() { @@ -34,9 +33,10 @@ func server(readyCh chan struct{}) { // create a handler that will be called when the server receives the RequestChannel frame (FrameTypeRequestChannel - 0x07) requestChannelHandler := rsocket.RequestChannel(func(msgs rx.Publisher) flux.Flux { return flux.Create(func(ctx context.Context, s flux.Sink) { - msgs.(flux.Flux).DoOnNext(func(elem payload.Payload) { + msgs.(flux.Flux).DoOnNext(func(elem payload.Payload) error { // for each payload in a flux stream respond with a word count s.Next(payload.NewString(fmt.Sprintf("%d", wordCount(elem.DataUTF8())), "")) + return nil }).DoOnComplete(func() { // signal completion of the response stream s.Complete() @@ -54,7 +54,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestChannelHandler), nil }). // specify transport - Transport(transportString). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). // serve will block execution unless an error occurred Serve(context.Background()) @@ -63,21 +63,21 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection - client, err := rsocket.Connect().Transport(transportString).Start(context.Background()) + client, err := rsocket.Connect().Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()).Start(context.Background()) if err != nil { panic(err) } defer client.Close() // strings to count the words - strings := []payload.Payload{ + sentences := []payload.Payload{ payload.NewString("", extension.TextPlain.String()), payload.NewString("qux", extension.TextPlain.String()), payload.NewString("The quick brown fox jumps over the lazy dog", extension.TextPlain.String()), payload.NewString("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", extension.TextPlain.String()), } - f := flux.FromSlice(strings) + f := flux.FromSlice(sentences) // create a wait group so that the function does not return until the stream completes wg := sync.WaitGroup{} @@ -86,12 +86,13 @@ func client() { counter := 0 // register handler for RequestChannel - client.RequestChannel(f).DoOnNext(func(input payload.Payload) { + client.RequestChannel(f).DoOnNext(func(input payload.Payload) error { // print word count - fmt.Println(strings[counter].DataUTF8(), ":", input.DataUTF8()) + fmt.Println(sentences[counter].DataUTF8(), ":", input.DataUTF8()) counter = counter + 1 + return nil }).DoOnComplete(func() { - // will be called on successfull completion of the stream + // will be called on successful completion of the stream fmt.Println("Word counter ended.") }).DoOnError(func(err error) { // will be called if a error occurs diff --git a/extension/authentication.go b/extension/authentication.go index e8dfa0a..98fae6a 100644 --- a/extension/authentication.go +++ b/extension/authentication.go @@ -15,7 +15,10 @@ const ( _authenticationBearer wellKnownAuthenticationType = 0x01 ) -var errInvalidAuthBytes = errors.New("invalid authentication bytes") +var ( + _errInvalidAuthBytes = errors.New("invalid authentication bytes") + _errAuthTypeLengthExceed = errors.New("invalid authType length: exceed 127 bytes") +) type wellKnownAuthenticationType uint8 @@ -55,14 +58,23 @@ func (a Authentication) IsWellKnown() (ok bool) { } // NewAuthentication creates a new Authentication -func NewAuthentication(authType string, payload []byte) Authentication { +func NewAuthentication(authType string, payload []byte) (*Authentication, error) { if len(authType) > 0x7F { - panic("illegal authType length: exceed 127 bytes") + return nil, _errAuthTypeLengthExceed } - return Authentication{ + return &Authentication{ typ: authType, payload: payload, + }, nil +} + +// MustNewAuthentication creates a new Authentication +func MustNewAuthentication(authType string, payload []byte) *Authentication { + auth, err := NewAuthentication(authType, payload) + if err != nil { + panic(err) } + return auth } // Bytes encodes current Authentication to byte slice. @@ -78,28 +90,42 @@ func (a Authentication) Bytes() (raw []byte) { } // ParseAuthentication parse Authentication from raw bytes. -func ParseAuthentication(raw []byte) (auth Authentication, err error) { +func ParseAuthentication(raw []byte) (auth *Authentication, err error) { totals := len(raw) if totals < 1 { - err = errInvalidAuthBytes + err = _errInvalidAuthBytes return } first := raw[0] n := 0x7F & first if first&0x80 != 0 { - auth.typ = wellKnownAuthenticationType(n).String() - auth.payload = raw[1:] + auth = &Authentication{ + typ: wellKnownAuthenticationType(n).String(), + payload: raw[1:], + } return } if totals < int(n+1) { - err = errInvalidAuthBytes + err = _errInvalidAuthBytes return } - auth.typ = string(raw[1 : 1+n]) - auth.payload = raw[n+1:] + auth = &Authentication{ + typ: string(raw[1 : 1+n]), + payload: raw[n+1:], + } return } +// IsInvalidAuthenticationBytes returns true if input error is for invalid bytes. +func IsInvalidAuthenticationBytes(err error) bool { + return err == _errInvalidAuthBytes +} + +// IsAuthTypeLengthExceed returns true if input error is for AuthType length exceed. +func IsAuthTypeLengthExceed(err error) bool { + return err == _errAuthTypeLengthExceed +} + func parseWellKnownAuthenticateType(typ string) (au wellKnownAuthenticationType, ok bool) { switch typ { case _simpleAuth: diff --git a/extension/authentication_test.go b/extension/authentication_test.go index 34d65c4..99a77c6 100644 --- a/extension/authentication_test.go +++ b/extension/authentication_test.go @@ -1,14 +1,46 @@ package extension_test import ( + "math/rand" + "strings" "testing" + "time" "github.com/rsocket/rsocket-go/extension" "github.com/stretchr/testify/assert" ) +func TestNewAuthentication(t *testing.T) { + payload := []byte("foobar") + + for _, authType := range []string{"bearer", "simple"} { + au, err := extension.NewAuthentication(authType, payload) + assert.NoError(t, err, "create Authentication failed!") + assert.Equal(t, authType, au.Type(), "wrong type") + assert.Equal(t, payload, au.Payload(), "wrong payload") + assert.Equal(t, true, au.IsWellKnown(), "well-known should be true") + b := au.Bytes() + au2, err := extension.ParseAuthentication(b) + assert.NoError(t, err, "parse Authentication failed!") + assert.Equal(t, au.Type(), au2.Type(), "authType doesn't match") + assert.Equal(t, au.Payload(), au2.Payload(), "payload doesn't match") + } + + _, err := extension.NewAuthentication(strings.Repeat("0", 128), payload) + assert.True(t, extension.IsAuthTypeLengthExceed(err), "should error") +} + +func TestParseAuthentication(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + input := make([]byte, 2) + rand.Read(input) + input[0] &= ^uint8(0x80) + _, err := extension.ParseAuthentication(input) + assert.True(t, extension.IsInvalidAuthenticationBytes(err), "should error") +} + func TestAuthentication(t *testing.T) { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) raw := au.Bytes() au2, err := extension.ParseAuthentication(raw) assert.NoError(t, err, "bad authentication bytes") @@ -17,13 +49,13 @@ func TestAuthentication(t *testing.T) { func BenchmarkAuthentication_Bytes(b *testing.B) { for i := 0; i < b.N; i++ { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) _ = au.Bytes() } } func BenchmarkParseAuthentication(b *testing.B) { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) raw := au.Bytes() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/extension/composite_metadata.go b/extension/composite_metadata.go index acbbd7c..99cb059 100644 --- a/extension/composite_metadata.go +++ b/extension/composite_metadata.go @@ -120,7 +120,7 @@ func (c *CompositeMetadataBuilder) Build() (CompositeMetadata, error) { } metadata := c.v[i] metadataLen := len(metadata) - bf.Write(common.NewUint24(metadataLen).Bytes()) + bf.Write(common.MustNewUint24(metadataLen).Bytes()) if metadataLen > 0 { bf.Write(metadata) } diff --git a/fuzz.go b/fuzz.go index 9e70722..d04aa17 100644 --- a/fuzz.go +++ b/fuzz.go @@ -6,11 +6,11 @@ package rsocket import ( "bytes" "errors" - "fmt" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" ) func Fuzz(data []byte) int { @@ -20,17 +20,21 @@ func Fuzz(data []byte) int { if err == nil { err = handleRaw(raw) } - if err == nil || err == common.ErrInvalidFrame || err == transport.ErrIncompleteHeader { + if err == nil || isExpectedError(err) { return 0 } return 1 } +func isExpectedError(err error) bool { + return err == core.ErrInvalidFrame || err == transport.ErrIncompleteHeader +} + func handleRaw(raw []byte) (err error) { - h := framing.ParseFrameHeader(raw) + h := core.ParseFrameHeader(raw) bf := common.NewByteBuff() - var frame framing.Frame - frame, err = framing.NewFromBase(framing.NewBaseFrame(h, bf)) + var frame core.Frame + frame, err = framing.FromRawFrame(framing.NewRawFrame(h, bf)) if err != nil { return } @@ -38,21 +42,9 @@ func handleRaw(raw []byte) (err error) { if err != nil { return } - - switch f := frame.(type) { - case fmt.Stringer: - s := f.String() - if len(s) > 0 { - return - } - case error: - e := f.Error() - if len(e) > 0 { - return - } - default: - panic("unreachable") + if frame.Len() >= core.FrameHeaderLen { + return } - - return errors.New("???") + err = errors.New("broken frame") + return } diff --git a/go.mod b/go.mod index 8a1bf19..6446eae 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,12 @@ module github.com/rsocket/rsocket-go go 1.12 require ( + github.com/golang/mock v1.4.3 github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 - github.com/jjeffcaii/reactor-go v0.1.3 + github.com/jjeffcaii/reactor-go v0.2.1 github.com/pkg/errors v0.9.1 + github.com/prometheus/client_golang v1.7.1 github.com/stretchr/testify v1.4.0 github.com/urfave/cli/v2 v2.1.1 go.uber.org/atomic v1.5.1 diff --git a/go.sum b/go.sum index 7a23336..8538a77 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,97 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jjeffcaii/reactor-go v0.1.3 h1:HPvOkeoH1Z11t0TlWIyYuQkbSG/9/e3LgTN4QuLvPFs= -github.com/jjeffcaii/reactor-go v0.1.3/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/jjeffcaii/reactor-go v0.2.1 h1:Wb33QstbwZdgFFS3BqWeBweGp2KQJigfts7NQLHTxqE= +github.com/jjeffcaii/reactor-go v0.2.1/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/panjf2000/ants/v2 v2.4.1 h1:7RtUqj5lGOw0WnZhSKDZ2zzJhaX5490ZW1sUolRXCxY= github.com/panjf2000/ants/v2 v2.4.1/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1 h1:NTGy1Ja9pByO+xAeH/qiWnLrKtr3hJPNjaVUwnjpdpA= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= +github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -28,20 +99,50 @@ github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/common/bytebuffer.go b/internal/common/bytebuffer.go index dcf9500..a328f18 100644 --- a/internal/common/bytebuffer.go +++ b/internal/common/bytebuffer.go @@ -8,40 +8,49 @@ import ( // ByteBuff provides byte buffer, which can be used for minimizing. type ByteBuff bytes.Buffer -func (p *ByteBuff) pp() *bytes.Buffer { - return (*bytes.Buffer)(p) +func (b *ByteBuff) pp() *bytes.Buffer { + return (*bytes.Buffer)(b) } // Len returns size of ByteBuff. -func (p *ByteBuff) Len() (n int) { - return p.pp().Len() +func (b *ByteBuff) Len() (n int) { + return b.pp().Len() } // WriteTo write bytes to writer. -func (p *ByteBuff) WriteTo(w io.Writer) (int64, error) { - return p.pp().WriteTo(w) +func (b *ByteBuff) WriteTo(w io.Writer) (int64, error) { + return b.pp().WriteTo(w) } // Writer write bytes to current ByteBuff. -func (p *ByteBuff) Write(bs []byte) (int, error) { - return p.pp().Write(bs) +func (b *ByteBuff) Write(bs []byte) (int, error) { + return b.pp().Write(bs) } // WriteUint24 encode and write Uint24 to current ByteBuff. -func (p *ByteBuff) WriteUint24(n int) (err error) { - v := NewUint24(n) - _, err = p.Write(v[:]) +func (b *ByteBuff) WriteUint24(n int) (err error) { + if n > MaxUint24 { + return errExceedMaxUint24 + } + v := MustNewUint24(n) + _, err = b.Write(v[:]) return } -// WriteByte write a byte to current ByteBuff. -func (p *ByteBuff) WriteByte(b byte) error { - return p.pp().WriteByte(b) +// WriteByte writes a byte to current ByteBuff. +func (b *ByteBuff) WriteByte(c byte) error { + return b.pp().WriteByte(c) +} + +// WriteString writes a string to current ByteBuff. +func (b *ByteBuff) WriteString(s string) (err error) { + _, err = b.pp().Write([]byte(s)) + return } // Bytes returns all bytes in ByteBuff. -func (p *ByteBuff) Bytes() []byte { - return p.pp().Bytes() +func (b *ByteBuff) Bytes() []byte { + return b.pp().Bytes() } // NewByteBuff creates a new ByteBuff. diff --git a/internal/common/bytebuffer_test.go b/internal/common/bytebuffer_test.go new file mode 100644 index 0000000..ea0ce94 --- /dev/null +++ b/internal/common/bytebuffer_test.go @@ -0,0 +1,52 @@ +package common_test + +import ( + "os" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestByteBuff_Bytes(t *testing.T) { + data := []byte("foobar") + b := common.NewByteBuff() + wrote, err := b.Write(data) + assert.NoError(t, err, "write failed") + assert.Equal(t, len(data), wrote, "wrong wrote size") + assert.Equal(t, data, b.Bytes(), "wrong data") +} + +func TestByteBuff_WriteUint24(t *testing.T) { + b := common.NewByteBuff() + var err error + err = b.WriteUint24(0) + assert.NoError(t, err, "write uint24 failed") + err = b.WriteUint24(common.MaxUint24) + assert.NoError(t, err, "write maximum uint24 failed") + err = b.WriteUint24(0x01FFFFFF) + assert.Error(t, err, "should write failed") +} + +func TestByteBuff_Len(t *testing.T) { + b := common.NewByteBuff() + // 3+1+6 + _ = b.WriteUint24(1) + _ = b.WriteByte('c') + _, _ = b.Write([]byte("foobar")) + assert.Equal(t, 10, b.Len(), "wrong length") +} + +func TestByteBuff_WriteTo(t *testing.T) { + b := common.NewByteBuff() + f, err := os.OpenFile("/dev/null", os.O_WRONLY, os.ModeAppend) + assert.NoError(t, err, "open /dev/null failed") + defer f.Close() + // 16MB + s := common.RandAlphanumeric(16 * 1024 * 1024) + err = b.WriteString(s) + assert.NoError(t, err) + n, err := b.WriteTo(f) + assert.NoError(t, err, "WriteTo failed") + assert.Equal(t, len(s), int(n), "wrong length") +} diff --git a/internal/common/misc.go b/internal/common/common.go similarity index 100% rename from internal/common/misc.go rename to internal/common/common.go diff --git a/internal/common/cond.go b/internal/common/cond.go new file mode 100644 index 0000000..9c8bc73 --- /dev/null +++ b/internal/common/cond.go @@ -0,0 +1,47 @@ +package common + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" +) + +// Cond +// see https://gist.github.com/zviadm/c234426882bfc8acba88f3503edaaa36#file-cond2-go +type Cond struct { + L sync.Locker + n unsafe.Pointer +} + +func NewCond(l sync.Locker) *Cond { + c := &Cond{ + L: l, + } + n := make(chan struct{}) + c.n = unsafe.Pointer(&n) + return c +} + +func (c *Cond) NotifyChan() <-chan struct{} { + ptr := atomic.LoadPointer(&c.n) + return *((*chan struct{})(ptr)) +} + +func (c *Cond) Broadcast() { + n := make(chan struct{}) + ptrOld := atomic.SwapPointer(&c.n, unsafe.Pointer(&n)) + close(*(*chan struct{})(ptrOld)) +} + +func (c *Cond) Wait(ctx context.Context) (isCtx bool) { + n := c.NotifyChan() + c.L.Unlock() + select { + case <-n: + case <-ctx.Done(): + isCtx = true + } + c.L.Lock() + return +} diff --git a/internal/common/cond_test.go b/internal/common/cond_test.go new file mode 100644 index 0000000..84422a3 --- /dev/null +++ b/internal/common/cond_test.go @@ -0,0 +1,67 @@ +package common_test + +import ( + "context" + "log" + "runtime" + "sync" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" +) + +func TestNewCond(t *testing.T) { + x := 0 + c := common.NewCond(&sync.Mutex{}) + done := make(chan bool) + go func() { + c.L.Lock() + x = 1 + c.Wait(context.Background()) + if x != 2 { + log.Fatal("want 2") + } + x = 3 + c.Broadcast() + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 1 { + x = 2 + c.Broadcast() + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 2 { + c.Wait(context.Background()) + if x != 3 { + log.Fatal("want 3") + } + break + } + if x == 3 { + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + <-done + <-done + <-done +} diff --git a/internal/common/rand_test.go b/internal/common/rand_test.go new file mode 100644 index 0000000..402ea60 --- /dev/null +++ b/internal/common/rand_test.go @@ -0,0 +1,35 @@ +package common_test + +import ( + "regexp" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestRandAlphanumeric(t *testing.T) { + s := common.RandAlphanumeric(10) + r := regexp.MustCompile("^[a-zA-Z0-9]{10}$") + assert.True(t, r.MatchString(s)) + s = common.RandAlphanumeric(0) + assert.Empty(t, s) +} + +func TestRandAlphabetic(t *testing.T) { + s := common.RandAlphabetic(10) + r := regexp.MustCompile("^[a-zA-Z]{10}$") + assert.True(t, r.MatchString(s)) + s = common.RandAlphabetic(0) + assert.Empty(t, s) +} + +func TestRandFloat64(t *testing.T) { + f := common.RandFloat64() + assert.True(t, f < 1 && f > 0) +} + +func TestRandIntn(t *testing.T) { + n := common.RandIntn(10) + assert.True(t, n >= 0 && n < 10) +} diff --git a/internal/common/u24.go b/internal/common/u24.go index f2ca3a3..daf2860 100644 --- a/internal/common/u24.go +++ b/internal/common/u24.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "io" ) @@ -8,7 +9,20 @@ import ( // MaxUint24 is the max value of Uint24. const MaxUint24 = 16777215 -var errMaxUint24 = fmt.Errorf("uint24 exceed max value: %d", MaxUint24) +var ( + errExceedMaxUint24 = fmt.Errorf("uint24 exceed max value: %d", MaxUint24) + errNegativeNumber = errors.New("negative number is illegal") +) + +// IsExceedMaximumUint24Error returns true if exceed maximum Uint24. (16777215) +func IsExceedMaximumUint24Error(err error) bool { + return err == errExceedMaxUint24 +} + +// IsNegativeUint24Error returns true if number is negative. +func IsNegativeUint24Error(err error) bool { + return err == errNegativeNumber +} // Uint24 is 3 bytes unsigned integer. type Uint24 [3]byte @@ -29,14 +43,27 @@ func (p Uint24) AsInt() int { return int(p[0])<<16 + int(p[1])<<8 + int(p[2]) } +// MustNewUint24 returns a new uint24. +func MustNewUint24(n int) Uint24 { + v, err := NewUint24(n) + if err != nil { + panic(err) + } + return v +} + // NewUint24 returns a new uint24. -func NewUint24(n int) (v Uint24) { - if n > MaxUint24 { - panic(errMaxUint24) +func NewUint24(v int) (n Uint24, err error) { + if v < 0 { + err = errNegativeNumber + return + } + if v > MaxUint24 { + err = errExceedMaxUint24 } - v[0] = byte(n >> 16) - v[1] = byte(n >> 8) - v[2] = byte(n) + n[0] = byte(v >> 16) + n[1] = byte(v >> 8) + n[2] = byte(v) return } diff --git a/internal/common/u24_benchmark_test.go b/internal/common/u24_benchmark_test.go deleted file mode 100644 index 39df2fd..0000000 --- a/internal/common/u24_benchmark_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package common - -import "testing" - -func BenchmarkNewUint24(b *testing.B) { - n := RandIntn(MaxUint24) - for i := 0; i < b.N; i++ { - _ = NewUint24(n) - } -} - -func BenchmarkUint24_AsInt(b *testing.B) { - n := NewUint24(RandIntn(MaxUint24)) - for i := 0; i < b.N; i++ { - _ = n.AsInt() - } -} diff --git a/internal/common/u24_test.go b/internal/common/u24_test.go index ff59106..fdfb9fb 100644 --- a/internal/common/u24_test.go +++ b/internal/common/u24_test.go @@ -1,14 +1,71 @@ -package common +package common_test import ( "testing" + . "github.com/rsocket/rsocket-go/internal/common" "github.com/stretchr/testify/assert" ) -func TestUint24(t *testing.T) { +func BenchmarkNewUint24(b *testing.B) { n := RandIntn(MaxUint24) - x := NewUint24(n) + for i := 0; i < b.N; i++ { + _ = MustNewUint24(n) + } +} + +func BenchmarkUint24_AsInt(b *testing.B) { + n := MustNewUint24(RandIntn(MaxUint24)) + for i := 0; i < b.N; i++ { + _ = n.AsInt() + } +} + +func TestMustNewUint24(t *testing.T) { + func() { + defer func() { + e := recover() + assert.True(t, IsExceedMaximumUint24Error(e.(error)), "should failed") + }() + _ = MustNewUint24(MaxUint24 + 1) + }() + func() { + defer func() { + e := recover() + assert.True(t, IsNegativeUint24Error(e.(error)), "should failed") + }() + _ = MustNewUint24(-1) + }() +} + +func TestUint24(t *testing.T) { + testSingle(t, 0) + for range [1_000_000]struct{}{} { + testSingle(t, RandIntn(MaxUint24)) + } + testSingle(t, MaxUint24) + // negative + _, err := NewUint24(-1) + assert.Error(t, err, "negative number should failed") + + // over maximum number + _, err = NewUint24(MaxUint24 + 1) + assert.Error(t, err, "over maximum number should failed") +} + +func TestUint24_WriteTo(t *testing.T) { + for _, n := range []int{0, 1, RandIntn(MaxUint24), MaxUint24} { + v := MustNewUint24(n) + b := NewByteBuff() + wrote, err := v.WriteTo(b) + assert.NoError(t, err, "write uint24 failed") + assert.Equal(t, int64(3), wrote, "wrote bytes length should be 3") + assert.Equal(t, n, NewUint24Bytes(b.Bytes()).AsInt(), "bad uint24 result") + } +} + +func testSingle(t *testing.T, n int) { + x := MustNewUint24(n) assert.Equal(t, n, x.AsInt(), "bad new from int") y := NewUint24Bytes(x.Bytes()) assert.Equal(t, n, y.AsInt(), "bad new from bytes") diff --git a/internal/common/version.go b/internal/common/version.go deleted file mode 100644 index 9f83d81..0000000 --- a/internal/common/version.go +++ /dev/null @@ -1,46 +0,0 @@ -package common - -import ( - "encoding/binary" - "fmt" - "io" -) - -// DefaultVersion is default protocol version. -var DefaultVersion Version = [2]uint16{1, 0} - -// Version define the version of protocol. -// It includes major and minor version. -type Version [2]uint16 - -// Bytes returns raw bytes of current version. -func (p Version) Bytes() []byte { - bs := make([]byte, 4) - binary.BigEndian.PutUint16(bs, p[0]) - binary.BigEndian.PutUint16(bs[2:], p[1]) - return bs -} - -// Major returns major version. -func (p Version) Major() uint16 { - return p[0] -} - -// Minor returns minor version. -func (p Version) Minor() uint16 { - return p[1] -} - -// WriteTo write raw version bytes to a writer. -func (p Version) WriteTo(w io.Writer) (n int64, err error) { - var wrote int - wrote, err = w.Write(p.Bytes()) - if err == nil { - n += int64(wrote) - } - return -} - -func (p Version) String() string { - return fmt.Sprintf("%d.%d", p[0], p[1]) -} diff --git a/internal/fragmentation/joiner.go b/internal/fragmentation/joiner.go index 92cf778..0043ca2 100644 --- a/internal/fragmentation/joiner.go +++ b/internal/fragmentation/joiner.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" ) var errNoFrameInJoiner = errors.New("no frames in current joiner") @@ -14,15 +14,15 @@ type implJoiner struct { root *list.List // list of HeaderAndPayload } -func (p *implJoiner) First() framing.Frame { +func (p *implJoiner) First() core.Frame { first := p.root.Front() if first == nil { panic(errNoFrameInJoiner) } - return first.Value.(framing.Frame) + return first.Value.(core.Frame) } -func (p *implJoiner) Header() framing.FrameHeader { +func (p *implJoiner) Header() core.FrameHeader { return p.First().Header() } @@ -34,7 +34,7 @@ func (p *implJoiner) String() string { func (p *implJoiner) Metadata() (metadata []byte, ok bool) { for cur := p.root.Front(); cur != nil; cur = cur.Next() { f := cur.Value.(HeaderAndPayload) - if !f.Header().Flag().Check(framing.FlagMetadata) { + if !f.Header().Flag().Check(core.FlagMetadata) { break } if m, has := f.Metadata(); has { @@ -74,6 +74,6 @@ func (p *implJoiner) DataUTF8() (data string) { func (p *implJoiner) Push(elem HeaderAndPayload) (end bool) { p.root.PushBack(elem) h := elem.Header() - end = !h.Flag().Check(framing.FlagFollow) + end = !h.Flag().Check(core.FlagFollow) return } diff --git a/internal/fragmentation/joiner_test.go b/internal/fragmentation/joiner_test.go index ffb121f..1ea12c5 100644 --- a/internal/fragmentation/joiner_test.go +++ b/internal/fragmentation/joiner_test.go @@ -5,23 +5,24 @@ import ( "log" "testing" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" ) func TestFragmentPayload(t *testing.T) { const totals = 10 const sid = uint32(1) - fr := NewJoiner(framing.NewFramePayload(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow, framing.FlagMetadata)) + fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), core.FlagFollow|core.FlagMetadata)) for i := 0; i < totals; i++ { data := fmt.Sprintf("(data%04d)", i) - var frame *framing.FramePayload + var frame *framing.PayloadFrame if i < 3 { meta := fmt.Sprintf("(meta%04d)", i) - frame = framing.NewFramePayload(sid, []byte(data), []byte(meta), framing.FlagFollow, framing.FlagMetadata) + frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), core.FlagFollow|core.FlagMetadata) } else if i != totals-1 { - frame = framing.NewFramePayload(sid, []byte(data), nil, framing.FlagFollow) + frame = framing.NewPayloadFrame(sid, []byte(data), nil, core.FlagFollow) } else { - frame = framing.NewFramePayload(sid, []byte(data), nil) + frame = framing.NewPayloadFrame(sid, []byte(data), nil, 0) } fr.Push(frame) } diff --git a/internal/fragmentation/splitter.go b/internal/fragmentation/splitter.go index 299612e..a22b0a4 100644 --- a/internal/fragmentation/splitter.go +++ b/internal/fragmentation/splitter.go @@ -1,24 +1,29 @@ package fragmentation import ( + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) -type splitResult struct { - f framing.FrameFlag - b *common.ByteBuff +// HandleSplitResult is callback for fragmentation result. +type HandleSplitResult = func(index int, result SplitResult) + +// SplitResult defines fragmentation result struct. +type SplitResult struct { + Flag core.FrameFlag + Metadata []byte + Data []byte } // Split split data and metadata in frame. -func Split(mtu int, data []byte, metadata []byte, onFrame func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { +func Split(mtu int, data []byte, metadata []byte, onFrame HandleSplitResult) { SplitSkip(mtu, 0, data, metadata, onFrame) } // SplitSkip skip some bytes and split data and metadata in frame. -func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { - ch := make(chan splitResult, 3) - go func(mtu int, skip int, data []byte, metadata []byte, ch chan splitResult) { +func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame HandleSplitResult) { + ch := make(chan SplitResult, 3) + go func(mtu int, skip int, data []byte, metadata []byte, ch chan SplitResult) { defer func() { close(ch) }() @@ -28,8 +33,7 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx var follow bool for { bf = common.NewByteBuff() - var wroteM int - left := mtu - framing.HeaderLen + left := mtu - core.FrameHeaderLen if idx == 0 && skip > 0 { left -= skip for i := 0; i < skip; i++ { @@ -41,51 +45,35 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx hasMetadata := cursor1 < lenM if hasMetadata { left -= 3 - // write metadata length placeholder - if err := bf.WriteUint24(0); err != nil { - panic(err) - } } begin1, begin2 := cursor1, cursor2 for wrote := 0; wrote < left; wrote++ { if cursor1 < lenM { - wroteM++ cursor1++ } else if cursor2 < lenD { cursor2++ } } - if _, err := bf.Write(metadata[begin1:cursor1]); err != nil { - panic(err) - } - if _, err := bf.Write(data[begin2:cursor2]); err != nil { - panic(err) - } + curMetadata := metadata[begin1:cursor1] + curData := data[begin2:cursor2] follow = cursor1+cursor2 < lenM+lenD - var fg framing.FrameFlag + var flag core.FrameFlag if follow { - fg |= framing.FlagFollow + flag |= core.FlagFollow } else { - fg &= ^framing.FlagFollow + flag &= ^core.FlagFollow } - if wroteM > 0 { - // set metadata length - x := common.NewUint24(wroteM) - for i := 0; i < len(x); i++ { - if idx == 0 { - bf.Bytes()[i+skip] = x[i] - } else { - bf.Bytes()[i] = x[i] - } - } - fg |= framing.FlagMetadata + if hasMetadata { + // metadata + flag |= core.FlagMetadata } else { // non-metadata - fg &= ^framing.FlagMetadata + flag &= ^core.FlagMetadata } - ch <- splitResult{ - f: fg, - b: bf, + ch <- SplitResult{ + Flag: flag, + Metadata: curMetadata, + Data: curData, } if !follow { break @@ -96,7 +84,7 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx var idx int for v := range ch { if onFrame != nil { - onFrame(idx, v.f, v.b) + onFrame(idx, v) } idx++ } diff --git a/internal/fragmentation/splitter_benchmark_test.go b/internal/fragmentation/splitter_benchmark_test.go index 8d1d67d..2828d98 100644 --- a/internal/fragmentation/splitter_benchmark_test.go +++ b/internal/fragmentation/splitter_benchmark_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) func BenchmarkToFragments(b *testing.B) { @@ -13,7 +12,7 @@ func BenchmarkToFragments(b *testing.B) { data := []byte(common.RandAlphanumeric(4 * 1024 * 1024)) metadata := []byte(common.RandAlphanumeric(1024 * 1024)) - fn := func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { + fn := func(idx int, result SplitResult) { } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { diff --git a/internal/fragmentation/splitter_test.go b/internal/fragmentation/splitter_test.go index e45725b..783fc5f 100644 --- a/internal/fragmentation/splitter_test.go +++ b/internal/fragmentation/splitter_test.go @@ -3,8 +3,9 @@ package fragmentation import ( "testing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) @@ -23,17 +24,14 @@ func TestSplitter_Split(t *testing.T) { } func split2joiner(mtu int, data, metadata []byte) (joiner Joiner, err error) { - fn := func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { + fn := func(idx int, result SplitResult) { + sid := uint32(77778888) if idx == 0 { - h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, framing.FlagComplete|fg) - joiner = NewJoiner(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - }) + f := framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, core.FlagComplete|result.Flag) + joiner = NewJoiner(f) } else { - h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, fg) - joiner.Push(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - }) + f := framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag) + joiner.Push(f) } } Split(mtu, data, metadata, fn) diff --git a/internal/fragmentation/fragmentation.go b/internal/fragmentation/types.go similarity index 86% rename from internal/fragmentation/fragmentation.go rename to internal/fragmentation/types.go index 058a1b7..3110185 100644 --- a/internal/fragmentation/fragmentation.go +++ b/internal/fragmentation/types.go @@ -4,14 +4,14 @@ import ( "container/list" "fmt" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/payload" ) const ( // MinFragment is minimum fragment size in bytes. - MinFragment = framing.HeaderLen + 4 + MinFragment = core.FrameHeaderLen + 4 // MaxFragment is minimum fragment size in bytes. MaxFragment = common.MaxUint24 - 3 ) @@ -21,15 +21,15 @@ var errInvalidFragmentLen = fmt.Errorf("invalid fragment: [%d,%d]", MinFragment, // HeaderAndPayload is Payload which having a FrameHeader. type HeaderAndPayload interface { payload.Payload - // Header returns a header of frame. - Header() framing.FrameHeader + // FrameHeader returns a header of frame. + Header() core.FrameHeader } // Joiner is used to join frames to a payload. type Joiner interface { HeaderAndPayload // First returns the first frame. - First() framing.Frame + First() core.Frame // Push append a new frame and returns true if joiner is end. Push(elem HeaderAndPayload) (end bool) } diff --git a/internal/framing/frame.go b/internal/framing/frame.go deleted file mode 100644 index 3cb297e..0000000 --- a/internal/framing/frame.go +++ /dev/null @@ -1,277 +0,0 @@ -package framing - -import ( - "errors" - "fmt" - "io" - "strings" - - "github.com/rsocket/rsocket-go/internal/common" -) - -var errIncompleteFrame = errors.New("incomplete frame") - -// FrameType is type of frame. -type FrameType uint8 - -// All frame types -const ( - FrameTypeReserved FrameType = 0x00 - FrameTypeSetup FrameType = 0x01 - FrameTypeLease FrameType = 0x02 - FrameTypeKeepalive FrameType = 0x03 - FrameTypeRequestResponse FrameType = 0x04 - FrameTypeRequestFNF FrameType = 0x05 - FrameTypeRequestStream FrameType = 0x06 - FrameTypeRequestChannel FrameType = 0x07 - FrameTypeRequestN FrameType = 0x08 - FrameTypeCancel FrameType = 0x09 - FrameTypePayload FrameType = 0x0A - FrameTypeError FrameType = 0x0B - FrameTypeMetadataPush FrameType = 0x0C - FrameTypeResume FrameType = 0x0D - FrameTypeResumeOK FrameType = 0x0E - FrameTypeExt FrameType = 0x3F -) - -func (f FrameType) String() string { - switch f { - case FrameTypeReserved: - return "RESERVED" - case FrameTypeSetup: - return "SETUP" - case FrameTypeLease: - return "LEASE" - case FrameTypeKeepalive: - return "KEEPALIVE" - case FrameTypeRequestResponse: - return "REQUEST_RESPONSE" - case FrameTypeRequestFNF: - return "REQUEST_FNF" - case FrameTypeRequestStream: - return "REQUEST_STREAM" - case FrameTypeRequestChannel: - return "REQUEST_CHANNEL" - case FrameTypeRequestN: - return "REQUEST_N" - case FrameTypeCancel: - return "CANCEL" - case FrameTypePayload: - return "PAYLOAD" - case FrameTypeError: - return "ERROR" - case FrameTypeMetadataPush: - return "METADATA_PUSH" - case FrameTypeResume: - return "RESUME" - case FrameTypeResumeOK: - return "RESUME_OK" - case FrameTypeExt: - return "EXT" - default: - return "UNKNOWN" - } -} - -// FrameFlag is flag of frame. -type FrameFlag uint16 - -func (f FrameFlag) String() string { - foo := make([]string, 0) - if f.Check(FlagNext) { - foo = append(foo, "N") - } - if f.Check(FlagComplete) { - foo = append(foo, "CL") - } - if f.Check(FlagFollow) { - foo = append(foo, "FRS") - } - if f.Check(FlagMetadata) { - foo = append(foo, "M") - } - if f.Check(FlagIgnore) { - foo = append(foo, "I") - } - return strings.Join(foo, "|") -} - -// All frame flags -const ( - FlagNext FrameFlag = 1 << (5 + iota) - FlagComplete - FlagFollow - FlagMetadata - FlagIgnore - - FlagResume = FlagFollow - FlagLease = FlagComplete - FlagRespond = FlagFollow -) - -// Check returns true if mask exists. -func (f FrameFlag) Check(mask FrameFlag) bool { - return mask&f == mask -} - -func newFlags(flags ...FrameFlag) FrameFlag { - var fg FrameFlag - for _, it := range flags { - fg |= it - } - return fg -} - -// Frame is a single message containing a request, response, or protocol processing. -type Frame interface { - fmt.Stringer - io.WriterTo - // Header returns frame FrameHeader. - Header() FrameHeader - // Body returns body of frame. - Body() *common.ByteBuff - // Len returns length of frame. - Len() int - // Validate returns error if frame is invalid. - Validate() error - // SetHeader set frame header. - SetHeader(h FrameHeader) - // SetBody set frame body. - SetBody(body *common.ByteBuff) - // Bytes encodes and returns frame in bytes. - Bytes() []byte - // CanResume returns true if frame supports resume. - CanResume() bool - // Done marks current frame has been sent. - Done() (closed bool) - // DoneNotify notifies when frame done. - DoneNotify() <-chan struct{} -} - -// BaseFrame is basic frame implementation. -type BaseFrame struct { - header FrameHeader - body *common.ByteBuff - done chan struct{} -} - -// Done can be invoked when a frame has been been processed. -func (p *BaseFrame) Done() (closed bool) { - defer func() { - if e := recover(); e != nil { - closed = true - } - }() - close(p.done) - return -} - -// DoneNotify notify when frame has been done. -func (p *BaseFrame) DoneNotify() <-chan struct{} { - return p.done -} - -// CanResume returns true if frame supports resume. -func (p *BaseFrame) CanResume() bool { - switch p.header.Type() { - case FrameTypeRequestChannel, FrameTypeRequestStream, FrameTypeRequestResponse, FrameTypeRequestFNF, FrameTypeRequestN, FrameTypeCancel, FrameTypeError, FrameTypePayload: - return true - default: - return false - } -} - -// SetBody set frame body. -func (p *BaseFrame) SetBody(body *common.ByteBuff) { - p.body = body -} - -// Body returns frame body. -func (p *BaseFrame) Body() *common.ByteBuff { - return p.body -} - -// Header returns frame header. -func (p *BaseFrame) Header() FrameHeader { - return p.header -} - -// SetHeader set frame header. -func (p *BaseFrame) SetHeader(h FrameHeader) { - p.header = h -} - -// Len returns length of frame. -func (p *BaseFrame) Len() int { - return HeaderLen + p.body.Len() -} - -// WriteTo write frame to writer. -func (p *BaseFrame) WriteTo(w io.Writer) (n int64, err error) { - var wrote int64 - wrote, err = p.header.WriteTo(w) - if err != nil { - return - } - n += wrote - wrote, err = p.body.WriteTo(w) - if err != nil { - return - } - n += wrote - return -} - -// Bytes returns frame in bytes. -func (p *BaseFrame) Bytes() []byte { - ret := make([]byte, HeaderLen+p.body.Len()) - copy(ret[:HeaderLen], p.header[:]) - copy(ret[HeaderLen:], p.body.Bytes()) - return ret -} - -// NewBaseFrame returns a new BaseFrame. -func NewBaseFrame(h FrameHeader, body *common.ByteBuff) (f *BaseFrame) { - f = &BaseFrame{ - header: h, - body: body, - done: make(chan struct{}), - } - return -} - -func (p *BaseFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { - raw := p.body.Bytes() - if offset > 0 { - raw = raw[offset:] - } - hasMetadata = p.header.Flag().Check(FlagMetadata) - if !hasMetadata { - return - } - if len(raw) < 3 { - n = -1 - } else { - n = common.NewUint24Bytes(raw).AsInt() - } - return -} - -func (p *BaseFrame) trySliceMetadata(offset int) ([]byte, bool) { - n, ok := p.trySeekMetadataLen(offset) - if !ok || n < 0 { - return nil, false - } - return p.body.Bytes()[offset+3 : offset+3+n], true -} - -func (p *BaseFrame) trySliceData(offset int) []byte { - n, ok := p.trySeekMetadataLen(offset) - if !ok { - return p.body.Bytes()[offset:] - } - if n < 0 { - return nil - } - return p.body.Bytes()[offset+n+3:] -} diff --git a/internal/framing/frame_cancel.go b/internal/framing/frame_cancel.go deleted file mode 100644 index 1849440..0000000 --- a/internal/framing/frame_cancel.go +++ /dev/null @@ -1,28 +0,0 @@ -package framing - -import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -// FrameCancel is frame of cancel. -type FrameCancel struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameCancel) Validate() (err error) { - return -} - -func (p *FrameCancel) String() string { - return fmt.Sprintf("FrameCancel{%s}", p.header) -} - -// NewFrameCancel returns a new cancel frame. -func NewFrameCancel(sid uint32) *FrameCancel { - return &FrameCancel{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeCancel), common.NewByteBuff()), - } -} diff --git a/internal/framing/frame_error.go b/internal/framing/frame_error.go deleted file mode 100644 index 19812bc..0000000 --- a/internal/framing/frame_error.go +++ /dev/null @@ -1,64 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - errCodeLen = 4 - errDataOff = errCodeLen - minErrorFrameLen = errCodeLen -) - -// FrameError is error frame. -type FrameError struct { - *BaseFrame -} - -func (p *FrameError) String() string { - return fmt.Sprintf("FrameError{%s,code=%s,data=%s}", p.header, p.ErrorCode(), p.ErrorData()) -} - -// Validate returns error if frame is invalid. -func (p *FrameError) Validate() (err error) { - if p.Len() < minErrorFrameLen { - err = errIncompleteFrame - } - return -} - -func (p *FrameError) Error() string { - return fmt.Sprintf("%s: %s", p.ErrorCode(), string(p.ErrorData())) -} - -// ErrorCode returns error code. -func (p *FrameError) ErrorCode() common.ErrorCode { - v := binary.BigEndian.Uint32(p.body.Bytes()) - return common.ErrorCode(v) -} - -// ErrorData returns error data bytes. -func (p *FrameError) ErrorData() []byte { - return p.body.Bytes()[errDataOff:] -} - -// NewFrameError returns a new error frame. -func NewFrameError(streamID uint32, code common.ErrorCode, data []byte) *FrameError { - bf := common.NewByteBuff() - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], uint32(code)) - if _, err := bf.Write(b4[:]); err != nil { - - panic(err) - } - if _, err := bf.Write(data); err != nil { - - panic(err) - } - return &FrameError{ - NewBaseFrame(NewFrameHeader(streamID, FrameTypeError), bf), - } -} diff --git a/internal/framing/frame_fnf.go b/internal/framing/frame_fnf.go deleted file mode 100644 index 9b68584..0000000 --- a/internal/framing/frame_fnf.go +++ /dev/null @@ -1,67 +0,0 @@ -package framing - -import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -// FrameFNF is fire and forget frame. -type FrameFNF struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameFNF) Validate() (err error) { - return -} - -func (p *FrameFNF) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameFNF{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) -} - -// Metadata returns metadata bytes. -func (p *FrameFNF) Metadata() ([]byte, bool) { - return p.trySliceMetadata(0) -} - -// Data returns data bytes. -func (p *FrameFNF) Data() []byte { - return p.trySliceData(0) -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameFNF) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameFNF) DataUTF8() string { - return string(p.Data()) -} - -// NewFrameFNF returns a new fire and forget frame. -func NewFrameFNF(sid uint32, data, metadata []byte, flags ...FrameFlag) *FrameFNF { - fg := newFlags(flags...) - bf := common.NewByteBuff() - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if _, err := bf.Write(data); err != nil { - panic(err) - } - return &FrameFNF{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestFNF, fg), bf), - } -} diff --git a/internal/framing/frame_keepalive.go b/internal/framing/frame_keepalive.go deleted file mode 100644 index ab7192c..0000000 --- a/internal/framing/frame_keepalive.go +++ /dev/null @@ -1,62 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - lastRecvPosLen = 8 - minKeepaliveFrameLen = lastRecvPosLen -) - -// FrameKeepalive is keepalive frame. -type FrameKeepalive struct { - *BaseFrame -} - -func (p *FrameKeepalive) String() string { - return fmt.Sprintf("FrameKeepalive{%s,lastReceivedPosition=%d,data=%s}", p.header, p.LastReceivedPosition(), string(p.Data())) -} - -// Validate returns error if frame is invalid. -func (p *FrameKeepalive) Validate() (err error) { - if p.body.Len() < minKeepaliveFrameLen { - err = errIncompleteFrame - } - return -} - -// LastReceivedPosition returns last received position. -func (p *FrameKeepalive) LastReceivedPosition() uint64 { - return binary.BigEndian.Uint64(p.body.Bytes()) -} - -// Data returns data bytes. -func (p *FrameKeepalive) Data() []byte { - return p.body.Bytes()[lastRecvPosLen:] -} - -// NewFrameKeepalive returns a new keepalive frame. -func NewFrameKeepalive(position uint64, data []byte, respond bool) *FrameKeepalive { - var fg FrameFlag - if respond { - fg |= FlagRespond - } - bf := common.NewByteBuff() - var b8 [8]byte - binary.BigEndian.PutUint64(b8[:], position) - if _, err := bf.Write(b8[:]); err != nil { - panic(err) - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FrameKeepalive{ - NewBaseFrame(NewFrameHeader(0, FrameTypeKeepalive, fg), bf), - } -} diff --git a/internal/framing/frame_lease.go b/internal/framing/frame_lease.go deleted file mode 100644 index 9b1c240..0000000 --- a/internal/framing/frame_lease.go +++ /dev/null @@ -1,68 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - "time" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - ttlLen = 4 - reqOff = ttlLen - reqLen = 4 - minLeaseFrame = ttlLen + reqLen -) - -// FrameLease is lease frame. -type FrameLease struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameLease) Validate() (err error) { - if p.body.Len() < minLeaseFrame { - err = errIncompleteFrame - } - return -} - -func (p *FrameLease) String() string { - return fmt.Sprintf("FrameLease{%s,timeToLive=%s,numberOfRequests=%d,metadata=%s}", - p.header, p.TimeToLive(), p.NumberOfRequests(), string(p.Metadata())) -} - -// TimeToLive returns time to live duration. -func (p *FrameLease) TimeToLive() time.Duration { - v := binary.BigEndian.Uint32(p.body.Bytes()) - return time.Millisecond * time.Duration(v) -} - -// NumberOfRequests returns number of requests. -func (p *FrameLease) NumberOfRequests() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()[reqOff:]) -} - -// Metadata returns metadata bytes. -func (p *FrameLease) Metadata() []byte { - if !p.header.Flag().Check(FlagMetadata) { - return nil - } - return p.body.Bytes()[8:] -} - -func NewFrameLease(ttl time.Duration, n uint32, metadata []byte) *FrameLease { - bf := common.NewByteBuff() - if err := binary.Write(bf, binary.BigEndian, uint32(ttl.Milliseconds())); err != nil { - panic(err) - } - if err := binary.Write(bf, binary.BigEndian, n); err != nil { - panic(err) - } - var fg FrameFlag - if len(metadata) > 0 { - fg |= FlagMetadata - } - return &FrameLease{NewBaseFrame(NewFrameHeader(0, FrameTypeLease, fg), bf)} -} diff --git a/internal/framing/frame_metadata_push.go b/internal/framing/frame_metadata_push.go deleted file mode 100644 index 495cf6c..0000000 --- a/internal/framing/frame_metadata_push.go +++ /dev/null @@ -1,59 +0,0 @@ -package framing - -import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -var defaultFrameMetadataPushHeader = NewFrameHeader(0, FrameTypeMetadataPush, FlagMetadata) - -// FrameMetadataPush is metadata push frame. -type FrameMetadataPush struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameMetadataPush) Validate() (err error) { - return -} - -func (p *FrameMetadataPush) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameMetadataPush{%s,metadata=%s}", p.header, m) -} - -// Metadata returns metadata bytes. -func (p *FrameMetadataPush) Metadata() ([]byte, bool) { - return p.body.Bytes(), true -} - -// Data returns data bytes. -func (p *FrameMetadataPush) Data() []byte { - return nil -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameMetadataPush) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameMetadataPush) DataUTF8() (data string) { - return -} - -// NewFrameMetadataPush returns a new metadata push frame. -func NewFrameMetadataPush(metadata []byte) *FrameMetadataPush { - bf := common.NewByteBuff() - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - return &FrameMetadataPush{ - NewBaseFrame(defaultFrameMetadataPushHeader, bf), - } -} diff --git a/internal/framing/frame_payload.go b/internal/framing/frame_payload.go deleted file mode 100644 index 531b0c9..0000000 --- a/internal/framing/frame_payload.go +++ /dev/null @@ -1,69 +0,0 @@ -package framing - -import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -// FramePayload is payload frame. -type FramePayload struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FramePayload) Validate() (err error) { - return -} - -func (p *FramePayload) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FramePayload{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) -} - -// Metadata returns metadata bytes. -func (p *FramePayload) Metadata() ([]byte, bool) { - return p.trySliceMetadata(0) -} - -// Data returns data bytes. -func (p *FramePayload) Data() []byte { - return p.trySliceData(0) -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FramePayload) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FramePayload) DataUTF8() string { - return string(p.Data()) -} - -// NewFramePayload returns a new payload frame. -func NewFramePayload(id uint32, data, metadata []byte, flags ...FrameFlag) *FramePayload { - fg := newFlags(flags...) - bf := common.NewByteBuff() - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FramePayload{ - NewBaseFrame(NewFrameHeader(id, FrameTypePayload, fg), bf), - } -} diff --git a/internal/framing/frame_payload_test.go b/internal/framing/frame_payload_test.go deleted file mode 100644 index 4346a34..0000000 --- a/internal/framing/frame_payload_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package framing - -import ( - "bytes" - "testing" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestFramePayload_Basic(t *testing.T) { - metadata := []byte("hello") - data := []byte("world") - f := NewFramePayload(123, data, metadata) - - metadata1, _ := f.Metadata() - assert.Equal(t, metadata, metadata1, "metadata failed") - assert.Equal(t, data, f.Data(), "data failed") - - bf := &bytes.Buffer{} - _, _ = f.WriteTo(bf) - bs := bf.Bytes() - - bb := common.NewByteBuff() - _, _ = bb.Write(bs[HeaderLen:]) - f2 := &FramePayload{ - NewBaseFrame(ParseFrameHeader(bs[:HeaderLen]), bb), - } - - metadata2, _ := f2.Metadata() - assert.Equal(t, metadata, metadata2, "metadata failed 2") - assert.Equal(t, data, f2.Data(), "data failed 2") - assert.Equal(t, f2.header[:], f2.header[:]) -} diff --git a/internal/framing/frame_request_channel.go b/internal/framing/frame_request_channel.go deleted file mode 100644 index 9e25cea..0000000 --- a/internal/framing/frame_request_channel.go +++ /dev/null @@ -1,89 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - initReqLen = 4 - minRequestChannelFrameLen = initReqLen -) - -// FrameRequestChannel is frame for RequestChannel. -type FrameRequestChannel struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameRequestChannel) Validate() (err error) { - if p.body.Len() < minRequestChannelFrameLen { - err = errIncompleteFrame - } - return -} - -func (p *FrameRequestChannel) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestChannel{%s,data=%s,metadata=%s,initialRequestN=%d}", - p.header, p.DataUTF8(), m, p.InitialRequestN()) -} - -// InitialRequestN returns initial N. -func (p *FrameRequestChannel) InitialRequestN() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) -} - -// Metadata returns metadata bytes. -func (p *FrameRequestChannel) Metadata() ([]byte, bool) { - return p.trySliceMetadata(initReqLen) -} - -// Data returns data bytes. -func (p *FrameRequestChannel) Data() []byte { - return p.trySliceData(initReqLen) -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestChannel) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameRequestChannel) DataUTF8() string { - return string(p.Data()) -} - -// NewFrameRequestChannel returns a new RequestChannel frame. -func NewFrameRequestChannel(sid uint32, n uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestChannel { - fg := newFlags(flags...) - bf := common.NewByteBuff() - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], n) - if _, err := bf.Write(b4[:]); err != nil { - panic(err) - } - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FrameRequestChannel{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestChannel, fg), bf), - } -} diff --git a/internal/framing/frame_request_n.go b/internal/framing/frame_request_n.go deleted file mode 100644 index 819c2b8..0000000 --- a/internal/framing/frame_request_n.go +++ /dev/null @@ -1,50 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - reqNLen = 4 - minRequestNFrameLen = reqNLen -) - -// FrameRequestN is RequestN frame. -type FrameRequestN struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameRequestN) Validate() (err error) { - if p.body.Len() < minRequestNFrameLen { - err = errIncompleteFrame - } - return -} - -func (p *FrameRequestN) String() string { - return fmt.Sprintf("FrameRequestN{%s,n=%d}", p.header, p.N()) -} - -// N returns N in RequestN. -func (p *FrameRequestN) N() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) -} - -// NewFrameRequestN returns a new RequestN frame. -func NewFrameRequestN(sid, n uint32, flags ...FrameFlag) *FrameRequestN { - fg := newFlags(flags...) - bf := common.NewByteBuff() - - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], n) - if _, err := bf.Write(b4[:]); err != nil { - panic(err) - } - return &FrameRequestN{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestN, fg), bf), - } -} diff --git a/internal/framing/frame_request_response.go b/internal/framing/frame_request_response.go deleted file mode 100644 index fbfb508..0000000 --- a/internal/framing/frame_request_response.go +++ /dev/null @@ -1,69 +0,0 @@ -package framing - -import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -// FrameRequestResponse is frame for requesting single response. -type FrameRequestResponse struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameRequestResponse) Validate() (err error) { - return -} - -func (p *FrameRequestResponse) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestResponse{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) -} - -// Metadata returns metadata bytes. -func (p *FrameRequestResponse) Metadata() ([]byte, bool) { - return p.trySliceMetadata(0) -} - -// Data returns data bytes. -func (p *FrameRequestResponse) Data() []byte { - return p.trySliceData(0) -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestResponse) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameRequestResponse) DataUTF8() string { - return string(p.Data()) -} - -// NewFrameRequestResponse returns a new RequestResponse frame. -func NewFrameRequestResponse(id uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestResponse { - fg := newFlags(flags...) - bf := common.NewByteBuff() - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FrameRequestResponse{ - NewBaseFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg), bf), - } -} diff --git a/internal/framing/frame_request_stream.go b/internal/framing/frame_request_stream.go deleted file mode 100644 index a274d81..0000000 --- a/internal/framing/frame_request_stream.go +++ /dev/null @@ -1,88 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - minRequestStreamFrameLen = initReqLen -) - -// FrameRequestStream is frame for requesting a completable stream. -type FrameRequestStream struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameRequestStream) Validate() (err error) { - if p.body.Len() < minRequestStreamFrameLen { - err = errIncompleteFrame - } - return -} - -func (p *FrameRequestStream) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestStream{%s,data=%s,metadata=%s,initialRequestN=%d}", - p.header, p.DataUTF8(), m, p.InitialRequestN()) -} - -// InitialRequestN returns initial request N. -func (p *FrameRequestStream) InitialRequestN() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) -} - -// Metadata returns metadata bytes. -func (p *FrameRequestStream) Metadata() ([]byte, bool) { - return p.trySliceMetadata(4) -} - -// Data returns data bytes. -func (p *FrameRequestStream) Data() []byte { - return p.trySliceData(4) -} - -// MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestStream) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameRequestStream) DataUTF8() string { - return string(p.Data()) -} - -// NewFrameRequestStream returns a new request stream frame. -func NewFrameRequestStream(id uint32, n uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestStream { - fg := newFlags(flags...) - bf := common.NewByteBuff() - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], n) - if _, err := bf.Write(b4[:]); err != nil { - panic(err) - } - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FrameRequestStream{ - NewBaseFrame(NewFrameHeader(id, FrameTypeRequestStream, fg), bf), - } -} diff --git a/internal/framing/frame_resume.go b/internal/framing/frame_resume.go deleted file mode 100644 index 1b2fe77..0000000 --- a/internal/framing/frame_resume.go +++ /dev/null @@ -1,87 +0,0 @@ -package framing - -import ( - "encoding/binary" - "errors" - "fmt" - "math" - - "github.com/rsocket/rsocket-go/internal/common" -) - -var errResumeTokenTooLarge = errors.New("max length of resume token is 65535") - -// FrameResume represents a frame of Resume. -type FrameResume struct { - *BaseFrame -} - -func (p *FrameResume) String() string { - return fmt.Sprintf( - "FrameResume{%s,version=%s,token=0x%02x,lastReceivedServerPosition=%d,firstAvailableClientPosition=%d}", - p.header, p.Version(), p.Token(), p.LastReceivedServerPosition(), p.FirstAvailableClientPosition(), - ) -} - -// Validate validate current frame. -func (p *FrameResume) Validate() (err error) { - return -} - -// Version returns version. -func (p *FrameResume) Version() common.Version { - raw := p.body.Bytes() - major := binary.BigEndian.Uint16(raw) - minor := binary.BigEndian.Uint16(raw[2:]) - return [2]uint16{major, minor} -} - -// Token returns resume token in bytes. -func (p *FrameResume) Token() []byte { - raw := p.body.Bytes() - tokenLen := binary.BigEndian.Uint16(raw[4:6]) - return raw[6 : 6+tokenLen] -} - -// LastReceivedServerPosition returns last received server position. -func (p *FrameResume) LastReceivedServerPosition() uint64 { - raw := p.body.Bytes() - offset := 6 + binary.BigEndian.Uint16(raw[4:6]) - return binary.BigEndian.Uint64(raw[offset:]) -} - -// FirstAvailableClientPosition returns first available client position. -func (p *FrameResume) FirstAvailableClientPosition() uint64 { - raw := p.body.Bytes() - offset := 6 + binary.BigEndian.Uint16(raw[4:6]) + 8 - return binary.BigEndian.Uint64(raw[offset:]) -} - -// NewFrameResume creates a new frame of Resume. -func NewFrameResume(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *FrameResume { - n := len(token) - if n > math.MaxUint16 { - panic(errResumeTokenTooLarge) - } - bf := common.NewByteBuff() - if _, err := bf.Write(version.Bytes()); err != nil { - panic(err) - } - if err := binary.Write(bf, binary.BigEndian, uint16(n)); err != nil { - panic(err) - } - if n > 0 { - if _, err := bf.Write(token); err != nil { - panic(err) - } - } - if err := binary.Write(bf, binary.BigEndian, lastReceivedServerPosition); err != nil { - panic(err) - } - if err := binary.Write(bf, binary.BigEndian, firstAvailableClientPosition); err != nil { - panic(err) - } - return &FrameResume{ - NewBaseFrame(NewFrameHeader(0, FrameTypeResume), bf), - } -} diff --git a/internal/framing/frame_resume_ok.go b/internal/framing/frame_resume_ok.go deleted file mode 100644 index 32125d6..0000000 --- a/internal/framing/frame_resume_ok.go +++ /dev/null @@ -1,42 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" -) - -// FrameResumeOK represents a frame of ResumeOK. -type FrameResumeOK struct { - *BaseFrame -} - -func (p *FrameResumeOK) String() string { - return fmt.Sprintf("FrameResumeOK{%s,lastReceivedClientPosition=%d}", p.header, p.LastReceivedClientPosition()) -} - -// Validate validate current frame. -func (p *FrameResumeOK) Validate() (err error) { - return -} - -// LastReceivedClientPosition returns last received client position. -func (p *FrameResumeOK) LastReceivedClientPosition() uint64 { - raw := p.body.Bytes() - return binary.BigEndian.Uint64(raw) -} - -// NewResumeOK creates a new frame of ResumeOK. -func NewResumeOK(position uint64) *FrameResumeOK { - var b8 [8]byte - binary.BigEndian.PutUint64(b8[:], position) - bf := common.NewByteBuff() - _, err := bf.Write(b8[:]) - if err != nil { - panic(err) - } - return &FrameResumeOK{ - NewBaseFrame(NewFrameHeader(0, FrameTypeResumeOK), bf), - } -} diff --git a/internal/framing/frame_setup.go b/internal/framing/frame_setup.go deleted file mode 100644 index bcac659..0000000 --- a/internal/framing/frame_setup.go +++ /dev/null @@ -1,213 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - "time" - - "github.com/rsocket/rsocket-go/internal/common" -) - -const ( - versionLen = 4 - timeLen = 4 - tokenLen = 2 - metadataLen = 1 - dataLen = 1 - minSetupFrameLen = versionLen + timeLen*2 + tokenLen + metadataLen + dataLen -) - -// FrameSetup is sent by client to initiate protocol processing. -type FrameSetup struct { - *BaseFrame -} - -// Validate returns error if frame is invalid. -func (p *FrameSetup) Validate() (err error) { - if p.Len() < minSetupFrameLen { - err = errIncompleteFrame - } - return -} - -func (p *FrameSetup) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf( - "FrameSetup{%s,version=%s,keepaliveInterval=%s,keepaliveMaxLifetime=%s,token=0x%02x,dataMimeType=%s,metadataMimeType=%s,data=%s,metadata=%s}", - p.header, - p.Version(), - p.TimeBetweenKeepalive(), - p.MaxLifetime(), - p.Token(), - p.DataMimeType(), - p.MetadataMimeType(), - p.DataUTF8(), - m, - ) -} - -// Version returns version. -func (p *FrameSetup) Version() common.Version { - major := binary.BigEndian.Uint16(p.body.Bytes()) - minor := binary.BigEndian.Uint16(p.body.Bytes()[2:]) - return [2]uint16{major, minor} -} - -// TimeBetweenKeepalive returns keepalive interval duration. -func (p *FrameSetup) TimeBetweenKeepalive() time.Duration { - return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[4:])) -} - -// MaxLifetime returns keepalive max lifetime. -func (p *FrameSetup) MaxLifetime() time.Duration { - return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[8:])) -} - -// Token returns token of setup. -func (p *FrameSetup) Token() []byte { - if !p.header.Flag().Check(FlagResume) { - return nil - } - raw := p.body.Bytes() - tokenLength := binary.BigEndian.Uint16(raw[12:]) - return raw[14 : 14+tokenLength] -} - -// DataMimeType returns MIME of data. -func (p *FrameSetup) DataMimeType() (mime string) { - _, b := p.mime() - return string(b) -} - -// MetadataMimeType returns MIME of metadata. -func (p *FrameSetup) MetadataMimeType() string { - a, _ := p.mime() - return string(a) -} - -// Metadata returns metadata bytes. -func (p *FrameSetup) Metadata() ([]byte, bool) { - if !p.header.Flag().Check(FlagMetadata) { - return nil, false - } - offset := p.seekMIME() - m1, m2 := p.mime() - offset += 2 + len(m1) + len(m2) - return p.trySliceMetadata(offset) -} - -// Data returns data bytes. -func (p *FrameSetup) Data() []byte { - offset := p.seekMIME() - m1, m2 := p.mime() - offset += 2 + len(m1) + len(m2) - if !p.header.Flag().Check(FlagMetadata) { - return p.Body().Bytes()[offset:] - } - return p.trySliceData(offset) -} - -// MetadataUTF8 returns metadata as UTF8 string -func (p *FrameSetup) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() - if ok { - metadata = string(raw) - } - return -} - -// DataUTF8 returns data as UTF8 string. -func (p *FrameSetup) DataUTF8() string { - return string(p.Data()) -} - -func (p *FrameSetup) mime() (metadata []byte, data []byte) { - offset := p.seekMIME() - raw := p.body.Bytes() - l1 := int(raw[offset]) - offset++ - m1 := raw[offset : offset+l1] - offset += l1 - l2 := int(raw[offset]) - offset++ - m2 := raw[offset : offset+l2] - return m1, m2 -} - -func (p *FrameSetup) seekMIME() int { - if !p.header.Flag().Check(FlagResume) { - return 12 - } - l := binary.BigEndian.Uint16(p.body.Bytes()[12:]) - return 14 + int(l) -} - -// NewFrameSetup returns a new setup frame. -func NewFrameSetup( - version common.Version, - timeBetweenKeepalive, - maxLifetime time.Duration, - token []byte, - mimeMetadata []byte, - mimeData []byte, - data []byte, - metadata []byte, - lease bool, -) *FrameSetup { - var fg FrameFlag - bf := common.NewByteBuff() - if _, err := bf.Write(version.Bytes()); err != nil { - panic(err) - } - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], uint32(timeBetweenKeepalive.Nanoseconds()/1e6)) - if _, err := bf.Write(b4[:]); err != nil { - panic(err) - } - binary.BigEndian.PutUint32(b4[:], uint32(maxLifetime.Nanoseconds()/1e6)) - if _, err := bf.Write(b4[:]); err != nil { - panic(err) - } - if lease { - fg |= FlagLease - } - if len(token) > 0 { - fg |= FlagResume - binary.BigEndian.PutUint16(b4[:2], uint16(len(token))) - if _, err := bf.Write(b4[:2]); err != nil { - panic(err) - } - if _, err := bf.Write(token); err != nil { - panic(err) - } - } - if err := bf.WriteByte(byte(len(mimeMetadata))); err != nil { - panic(err) - } - if _, err := bf.Write(mimeMetadata); err != nil { - panic(err) - } - if err := bf.WriteByte(byte(len(mimeData))); err != nil { - panic(err) - } - if _, err := bf.Write(mimeData); err != nil { - panic(err) - } - if len(metadata) > 0 { - fg |= FlagMetadata - if err := bf.WriteUint24(len(metadata)); err != nil { - panic(err) - } - if _, err := bf.Write(metadata); err != nil { - panic(err) - } - } - if len(data) > 0 { - if _, err := bf.Write(data); err != nil { - panic(err) - } - } - return &FrameSetup{ - NewBaseFrame(NewFrameHeader(0, FrameTypeSetup, fg), bf), - } -} diff --git a/internal/framing/frame_setup_test.go b/internal/framing/frame_setup_test.go deleted file mode 100644 index 3c553d0..0000000 --- a/internal/framing/frame_setup_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package framing - -import ( - "bytes" - "log" - "testing" - "time" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestDecodeFrameSetup(t *testing.T) { - metadata := []byte("hello") - data := []byte("world") - mimeMetadata, mimeData := []byte("text/plain"), []byte("application/json") - token := []byte(common.RandAlphanumeric(16)) - setup := NewFrameSetup( - common.DefaultVersion, - 30*time.Second, - 90*time.Second, - token, - mimeMetadata, - mimeData, - data, - metadata, - false, - ) - log.Println(setup) - assert.Equal(t, "1.0", setup.Version().String()) - assert.True(t, bytes.Equal(token, setup.Token()), "bad token") - assert.Equal(t, string(mimeMetadata), setup.MetadataMimeType(), "bad mime metadata") - assert.Equal(t, string(mimeData), setup.DataMimeType(), "bad mime data") - m, _ := setup.Metadata() - assert.Equal(t, metadata, m, "bad metadata") - assert.Equal(t, data, setup.Data(), "bad data") -} diff --git a/internal/framing/frame_test.go b/internal/framing/frame_test.go deleted file mode 100644 index c5806ac..0000000 --- a/internal/framing/frame_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package framing - -import ( - "encoding/hex" - "log" - "testing" - "time" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestDecode_Payload(t *testing.T) { - //s := "000000012940000005776f726c6468656c6c6f" // go - //s := "00000001296000000966726f6d5f6a617661706f6e67" //java - - var all []string - all = append(all, "0000000004400001000000004e2000015f90126170706c69636174696f6e2f62696e617279126170706c69636174696f6e2f62696e617279") - all = append(all, "00000000090000000bb800000005") - all = append(all, "00000000090000001b5800000005") - all = append(all, "000000011100000000436c69656e74207265717565737420547565204f63742032322032303a31373a3333204353542032303139") - all = append(all, "00000001286053657276657220526573706f6e736520547565204f63742032322032303a31373a3333204353542032303139") - - for _, s := range all { - bs, err := hex.DecodeString(s) - assert.NoError(t, err, "bad bytes") - h := ParseFrameHeader(bs[:HeaderLen]) - //log.Println(h) - bf := common.NewByteBuff() - _, _ = bf.Write(bs[HeaderLen:]) - f, err := NewFromBase(NewBaseFrame(h, bf)) - assert.NoError(t, err, "decode failed") - log.Println(f) - } - - lease := NewFrameLease(3*time.Second, 5, nil) - log.Println("actual:", hex.EncodeToString(lease.Bytes())) - log.Println("should: 00000000090000000bb800000005") -} diff --git a/internal/framing/header.go b/internal/framing/header.go deleted file mode 100644 index 925f0e1..0000000 --- a/internal/framing/header.go +++ /dev/null @@ -1,64 +0,0 @@ -package framing - -import ( - "encoding/binary" - "fmt" - "io" -) - -const ( - // HeaderLen is len of header. - HeaderLen = 6 -) - -// FrameHeader is the header fo a RSocket frame. -// RSocket frames begin with a RSocket Frame Header. -// It includes StreamID, FrameType and Flags. -type FrameHeader [HeaderLen]byte - -func (p FrameHeader) String() string { - return fmt.Sprintf("FrameHeader{id=%d,type=%s,flag=%s}", p.StreamID(), p.Type(), p.Flag()) -} - -// WriteTo writes frame header to a writer. -func (p FrameHeader) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write(p[:]) - return int64(n), err -} - -// StreamID returns StreamID. -func (p FrameHeader) StreamID() uint32 { - return binary.BigEndian.Uint32(p[:4]) -} - -// Type returns frame type. -func (p FrameHeader) Type() FrameType { - return FrameType((p.n() & 0xFC00) >> 10) -} - -// Flag returns flag of a frame. -func (p FrameHeader) Flag() FrameFlag { - return FrameFlag(p.n() & 0x03FF) -} - -func (p FrameHeader) n() uint16 { - return binary.BigEndian.Uint16(p[4:]) -} - -// NewFrameHeader returns a new frame header. -func NewFrameHeader(streamID uint32, frameType FrameType, flags ...FrameFlag) FrameHeader { - fg := newFlags(flags...) - var h [HeaderLen]byte - binary.BigEndian.PutUint32(h[:], streamID) - binary.BigEndian.PutUint16(h[4:], uint16(frameType)<<10|uint16(fg)) - return h - -} - -// ParseFrameHeader parse a header from bytes. -func ParseFrameHeader(bs []byte) FrameHeader { - _ = bs[HeaderLen-1] - var bb [HeaderLen]byte - copy(bb[:], bs[:HeaderLen]) - return bb -} diff --git a/internal/framing/header_test.go b/internal/framing/header_test.go deleted file mode 100644 index 44ee867..0000000 --- a/internal/framing/header_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package framing - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHeader_All(t *testing.T) { - h1 := NewFrameHeader(134, FrameTypePayload, FlagMetadata|FlagComplete|FlagNext) - h := ParseFrameHeader(h1[:]) - assert.Equal(t, h1.StreamID(), h.StreamID()) - assert.Equal(t, h1.Type(), h.Type()) - assert.Equal(t, h1.Flag(), h.Flag()) - - fmt.Println("streamID:", h.StreamID()) - fmt.Println("type:", h.Type()) - fmt.Println("flag:", h.Flag()) -} diff --git a/internal/framing/misc.go b/internal/framing/misc.go deleted file mode 100644 index 3a53c71..0000000 --- a/internal/framing/misc.go +++ /dev/null @@ -1,51 +0,0 @@ -package framing - -import ( - "github.com/rsocket/rsocket-go/internal/common" -) - -// CalcPayloadFrameSize returns payload frame size. -func CalcPayloadFrameSize(data, metadata []byte) int { - size := HeaderLen + len(data) - if n := len(metadata); n > 0 { - size += 3 + n - } - return size -} - -// NewFromBase creates a frame from a BaseFrame. -func NewFromBase(f *BaseFrame) (frame Frame, err error) { - switch f.header.Type() { - case FrameTypeSetup: - frame = &FrameSetup{BaseFrame: f} - case FrameTypeKeepalive: - frame = &FrameKeepalive{BaseFrame: f} - case FrameTypeRequestResponse: - frame = &FrameRequestResponse{BaseFrame: f} - case FrameTypeRequestFNF: - frame = &FrameFNF{BaseFrame: f} - case FrameTypeRequestStream: - frame = &FrameRequestStream{BaseFrame: f} - case FrameTypeRequestChannel: - frame = &FrameRequestChannel{BaseFrame: f} - case FrameTypeCancel: - frame = &FrameCancel{BaseFrame: f} - case FrameTypePayload: - frame = &FramePayload{BaseFrame: f} - case FrameTypeMetadataPush: - frame = &FrameMetadataPush{BaseFrame: f} - case FrameTypeError: - frame = &FrameError{BaseFrame: f} - case FrameTypeRequestN: - frame = &FrameRequestN{BaseFrame: f} - case FrameTypeLease: - frame = &FrameLease{BaseFrame: f} - case FrameTypeResume: - frame = &FrameResume{BaseFrame: f} - case FrameTypeResumeOK: - frame = &FrameResumeOK{BaseFrame: f} - default: - err = common.ErrInvalidFrame - } - return -} diff --git a/internal/session/manager.go b/internal/session/manager.go index efcfa6d..74575d5 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -15,8 +15,8 @@ type Manager struct { // Len returns size of session in current manager. func (p *Manager) Len() (n int) { p.locker.RLock() + defer p.locker.RUnlock() n = len(*p.h) - p.locker.RUnlock() return } @@ -31,32 +31,32 @@ func (p *Manager) Push(session *Session) { // Load returns session with custom token. func (p *Manager) Load(token []byte) (session *Session, ok bool) { p.locker.RLock() + defer p.locker.RUnlock() session, ok = p.m[(string)(token)] - p.locker.RUnlock() return } // Remove remove a session with custom token. func (p *Manager) Remove(token []byte) (session *Session, ok bool) { p.locker.Lock() + defer p.locker.Unlock() session, ok = p.m[(string)(token)] if ok && session.index > -1 { heap.Remove(p.h, session.index) delete(p.m, (string)(token)) session.index = -1 } - p.locker.Unlock() return } // Pop pop earliest session. func (p *Manager) Pop() (session *Session) { p.locker.Lock() + defer p.locker.Unlock() session = heap.Pop(p.h).(*Session) if session != nil { delete(p.m, (string)(session.Token())) } - p.locker.Unlock() return } diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 371bcba..300471e 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -1,28 +1,40 @@ -package session +package session_test import ( - "log" + "fmt" "testing" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/internal/session" "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" ) func TestSession(t *testing.T) { - manager := NewManager() - for i := 0; i < 3; i++ { - deadline := time.Now().Add(time.Duration(common.RandIntn(30)) * time.Second) - token := common.RandAlphanumeric(32) - manager.Push(NewSession(deadline, socket.NewServerResume(nil, []byte(token)))) + const total = 100 + var tokens []string + manager := session.NewManager() + for i := 0; i < total; i++ { + deadline := time.Now().Add(time.Duration(i+1) * time.Second) + token := fmt.Sprintf("token_%d", i) + tokens = append(tokens, token) + manager.Push(session.NewSession(deadline, socket.NewServerResume(nil, []byte(token)))) } - for _, value := range *(manager.h) { - session, ok := manager.Load(value.Token()) - log.Printf("session=%s,ok=%t\n", session, ok) + for _, token := range tokens { + s, ok := manager.Load([]byte(token)) + assert.True(t, ok) + assert.NotNil(t, s, "session is nil") } - for manager.Len() > 0 { - log.Println("session:", manager.Pop()) + firstToken := []byte(tokens[0]) + s, ok := manager.Remove(firstToken) + assert.True(t, ok) + assert.NotNil(t, s) + assert.Equal(t, firstToken, s.Token()) + assert.Equal(t, len(tokens)-1, manager.Len()) + for i := 1; i < len(tokens); i++ { + manager.Pop() } + assert.Equal(t, 0, manager.Len()) } diff --git a/internal/socket/abstract_socket.go b/internal/socket/abstract_socket.go new file mode 100644 index 0000000..b31749a --- /dev/null +++ b/internal/socket/abstract_socket.go @@ -0,0 +1,69 @@ +package socket + +import ( + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/logger" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +var ( + errUnimplementedMetadataPush = errors.New("METADATA_PUSH is unimplemented") + errUnimplementedFireAndForget = errors.New("FIRE_AND_FORGET is unimplemented") + errUnimplementedRequestResponse = errors.New("REQUEST_RESPONSE is unimplemented") + errUnimplementedRequestStream = errors.New("REQUEST_STREAM is unimplemented") + errUnimplementedRequestChannel = errors.New("REQUEST_CHANNEL is unimplemented") +) + +// AbstractRSocket represents an abstract RSocket. +type AbstractRSocket struct { + FF func(payload.Payload) + MP func(payload.Payload) + RR func(payload.Payload) mono.Mono + RS func(payload.Payload) flux.Flux + RC func(rx.Publisher) flux.Flux +} + +// MetadataPush starts a request of MetadataPush. +func (p AbstractRSocket) MetadataPush(message payload.Payload) { + if p.MP == nil { + logger.Errorf("%s\n", errUnimplementedMetadataPush) + return + } + p.MP(message) +} + +// FireAndForget starts a request of FireAndForget. +func (p AbstractRSocket) FireAndForget(message payload.Payload) { + if p.FF == nil { + logger.Errorf("%s\n", errUnimplementedFireAndForget) + return + } + p.FF(message) +} + +// RequestResponse starts a request of RequestResponse. +func (p AbstractRSocket) RequestResponse(message payload.Payload) mono.Mono { + if p.RR == nil { + return mono.Error(errUnimplementedRequestResponse) + } + return p.RR(message) +} + +// RequestStream starts a request of RequestStream. +func (p AbstractRSocket) RequestStream(message payload.Payload) flux.Flux { + if p.RS == nil { + return flux.Error(errUnimplementedRequestStream) + } + return p.RS(message) +} + +// RequestChannel starts a request of RequestChannel. +func (p AbstractRSocket) RequestChannel(messages rx.Publisher) flux.Flux { + if p.RC == nil { + return flux.Error(errUnimplementedRequestChannel) + } + return p.RC(messages) +} diff --git a/internal/socket/abstract_socket_test.go b/internal/socket/abstract_socket_test.go new file mode 100644 index 0000000..f042281 --- /dev/null +++ b/internal/socket/abstract_socket_test.go @@ -0,0 +1,107 @@ +package socket_test + +import ( + "context" + "testing" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +var emptyAbstractRSocket = &socket.AbstractRSocket{} +var fakeRequest = payload.New(fakeData, fakeMetadata) + +func TestAbstractRSocket_FireAndForget(t *testing.T) { + called := atomic.NewBool(false) + s := &socket.AbstractRSocket{ + FF: func(payload payload.Payload) { + called.CAS(false, true) + }, + } + s.FireAndForget(fakeRequest) + assert.True(t, called.Load()) + + assert.NotPanics(t, func() { + emptyAbstractRSocket.FireAndForget(fakeRequest) + }) +} + +func TestAbstractRSocket_RequestResponse(t *testing.T) { + s := &socket.AbstractRSocket{ + RR: func(p payload.Payload) mono.Mono { + return mono.Just(p) + }, + } + res, err := s.RequestResponse(fakeRequest).Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fakeData, res.Data()) + assert.Equal(t, fakeMetadata, extractMetadata(res)) + + _, err = emptyAbstractRSocket.RequestResponse(fakeRequest).Block(context.Background()) + assert.Error(t, err, "should return an error") +} + +func TestAbstractRSocket_MetadataPush(t *testing.T) { + called := atomic.NewBool(false) + s := &socket.AbstractRSocket{ + MP: func(p payload.Payload) { + called.CAS(false, true) + }, + } + + assert.NotPanics(t, func() { + s.MetadataPush(fakeRequest) + }) + assert.NotPanics(t, func() { + emptyAbstractRSocket.MetadataPush(fakeRequest) + }) +} + +func TestAbstractRSocket_RequestStream(t *testing.T) { + s := &socket.AbstractRSocket{ + RS: func(p payload.Payload) flux.Flux { + return flux.Just(p) + }, + } + + var res []payload.Payload + + _, err := s.RequestStream(fakeRequest). + DoOnNext(func(input payload.Payload) error { + res = append(res, input) + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, fakeRequest, res[0]) + + _, err = emptyAbstractRSocket.RequestStream(fakeRequest).BlockLast(context.Background()) + assert.Error(t, err, "should return an error") +} + +func TestAbstractRSocket_RequestChannel(t *testing.T) { + s := &socket.AbstractRSocket{ + RC: func(publisher rx.Publisher) flux.Flux { + return flux.Clone(publisher) + }, + } + var res []payload.Payload + _, err := s.RequestChannel(flux.Just(fakeRequest)). + DoOnNext(func(input payload.Payload) error { + res = append(res, input) + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, fakeRequest, res[0]) + + _, err = emptyAbstractRSocket.RequestChannel(flux.Just(fakeRequest)).BlockFirst(context.Background()) + assert.Error(t, err, "should return an error") +} diff --git a/internal/socket/base_socket.go b/internal/socket/base_socket.go new file mode 100644 index 0000000..126a989 --- /dev/null +++ b/internal/socket/base_socket.go @@ -0,0 +1,89 @@ +package socket + +import ( + "sync" + "time" + + "github.com/rsocket/rsocket-go/logger" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +type BaseSocket struct { + socket *DuplexConnection + closers []func(error) + once sync.Once + reqLease *leaser +} + +func (p *BaseSocket) refreshLease(ttl time.Duration, n int64) { + deadline := time.Now().Add(ttl) + if p.reqLease == nil { + p.reqLease = newLeaser(deadline, n) + } else { + p.reqLease.refresh(deadline, n) + } +} + +func (p *BaseSocket) FireAndForget(message payload.Payload) { + if err := p.reqLease.allow(); err != nil { + logger.Warnf("request FireAndForget failed: %v\n", err) + } + p.socket.FireAndForget(message) +} + +func (p *BaseSocket) MetadataPush(message payload.Payload) { + p.socket.MetadataPush(message) +} + +func (p *BaseSocket) RequestResponse(message payload.Payload) mono.Mono { + if err := p.reqLease.allow(); err != nil { + return mono.Error(err) + } + return p.socket.RequestResponse(message) +} + +func (p *BaseSocket) RequestStream(message payload.Payload) flux.Flux { + if err := p.reqLease.allow(); err != nil { + return flux.Error(err) + } + return p.socket.RequestStream(message) +} + +func (p *BaseSocket) RequestChannel(messages rx.Publisher) flux.Flux { + if err := p.reqLease.allow(); err != nil { + return flux.Error(err) + } + return p.socket.RequestChannel(messages) +} + +func (p *BaseSocket) OnClose(fn func(error)) { + if fn != nil { + p.closers = append(p.closers, fn) + } +} + +func (p *BaseSocket) Close() (err error) { + p.once.Do(func() { + err = p.socket.Close() + for i, l := 0, len(p.closers); i < l; i++ { + func(fn func(error)) { + defer func() { + if e := tryRecover(recover()); e != nil { + logger.Errorf("handle socket closer failed: %s\n", e) + } + }() + fn(err) + }(p.closers[l-i-1]) + } + }) + return +} + +func NewBaseSocket(rawSocket *DuplexConnection) *BaseSocket { + return &BaseSocket{ + socket: rawSocket, + } +} diff --git a/internal/socket/base_socket_test.go b/internal/socket/base_socket_test.go new file mode 100644 index 0000000..d427495 --- /dev/null +++ b/internal/socket/base_socket_test.go @@ -0,0 +1,60 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestBaseSocket(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + conn.EXPECT().Close().Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().Return(nil, io.EOF).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + duplex := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + duplex.SetTransport(tp) + + go func() { + _ = duplex.LoopWrite(context.Background()) + }() + + s := socket.NewBaseSocket(duplex) + + onClosedCalled := atomic.NewBool(false) + + s.OnClose(func(err error) { + onClosedCalled.CAS(false, true) + }) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = tp.Start(context.Background()) + }() + + assert.NotPanics(t, func() { + s.MetadataPush(fakeRequest) + s.FireAndForget(fakeRequest) + s.RequestResponse(fakeRequest) + s.RequestStream(fakeRequest) + s.RequestChannel(flux.Just(fakeRequest)) + }) + + <-done + + _ = s.Close() + assert.Equal(t, true, onClosedCalled.Load()) +} diff --git a/internal/socket/callback.go b/internal/socket/callback.go new file mode 100644 index 0000000..6afea4b --- /dev/null +++ b/internal/socket/callback.go @@ -0,0 +1,66 @@ +package socket + +import ( + "github.com/jjeffcaii/reactor-go" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +type callback interface { + Close(error) +} + +type requestStreamCallback struct { + pc flux.Processor +} + +func (s requestStreamCallback) Close(err error) { + s.pc.Error(err) +} + +type requestResponseCallback struct { + pc mono.Processor +} + +func (s requestResponseCallback) Close(err error) { + s.pc.Error(err) +} + +type requestChannelCallback struct { + snd rx.Subscription + rcv flux.Processor +} + +func (s requestChannelCallback) Close(err error) { + s.snd.Cancel() + s.rcv.Error(err) +} + +type requestResponseCallbackReverse struct { + su reactor.Subscription +} + +func (s requestResponseCallbackReverse) Close(err error) { + s.su.Cancel() + // TODO: fill err +} + +type requestStreamCallbackReverse struct { + su rx.Subscription +} + +func (s requestStreamCallbackReverse) Close(err error) { + s.su.Cancel() + // TODO: fill error +} + +type requestChannelCallbackReverse struct { + snd rx.Subscription + rcv flux.Processor +} + +func (s requestChannelCallbackReverse) Close(err error) { + s.rcv.Error(err) + s.snd.Cancel() +} diff --git a/internal/socket/client_default.go b/internal/socket/client_default.go deleted file mode 100644 index 7cb05bb..0000000 --- a/internal/socket/client_default.go +++ /dev/null @@ -1,67 +0,0 @@ -package socket - -import ( - "context" - "crypto/tls" - - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" - "github.com/rsocket/rsocket-go/logger" -) - -type defaultClientSocket struct { - *baseSocket - uri *transport.URI - headers map[string][]string - tls *tls.Config -} - -func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { - tp, err := p.uri.MakeClientTransport(p.tls, p.headers) - if err != nil { - return - } - tp.Connection().SetCounter(p.socket.counter) - tp.SetLifetime(setup.KeepaliveLifetime) - - p.socket.SetTransport(tp) - - if setup.Lease { - p.refreshLease(0, 0) - tp.HandleLease(func(frame framing.Frame) (err error) { - lease := frame.(*framing.FrameLease) - p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) - logger.Infof(">>>>> refresh lease: %v\n", lease) - return - }) - } - - tp.HandleDisaster(func(frame framing.Frame) (err error) { - p.socket.SetError(frame.(*framing.FrameError)) - return - }) - - go func(ctx context.Context, tp *transport.Transport) { - if err := tp.Start(ctx); err != nil { - logger.Warnf("client exit failed: %+v\n", err) - } - _ = p.Close() - }(ctx, tp) - - go func(ctx context.Context) { - _ = p.socket.loopWrite(ctx) - }(ctx) - setupFrame := setup.toFrame() - err = p.socket.tp.Send(setupFrame, true) - return -} - -// NewClient create a simple client-side socket. -func NewClient(uri *transport.URI, socket *DuplexRSocket, tc *tls.Config, headers map[string][]string) ClientSocket { - return &defaultClientSocket{ - baseSocket: newBaseSocket(socket), - uri: uri, - headers: headers, - tls: tc, - } -} diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index c738482..46a4ebf 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -2,70 +2,83 @@ package socket import ( "context" - "encoding/binary" - "errors" "fmt" - "io" "sync" "time" "github.com/jjeffcaii/reactor-go/scheduler" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/fragmentation" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/lease" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" - "go.uber.org/atomic" ) -const outsSize = 64 +const _outChanSize = 64 + +var errSocketClosed = errors.New("socket closed already") var ( - errSocketClosed = errors.New("socket closed already") unsupportedRequestStream = []byte("Request-Stream not implemented.") unsupportedRequestResponse = []byte("Request-Response not implemented.") unsupportedRequestChannel = []byte("Request-Channel not implemented.") ) -// DuplexRSocket represents a socket of RSocket which can be a requester or a responder. -type DuplexRSocket struct { - counter *transport.Counter +// IsSocketClosedError returns true if input error is for socket closed. +func IsSocketClosedError(err error) bool { + return err == errSocketClosed +} + +// DuplexConnection represents a socket of RSocket which can be a requester or a responder. +type DuplexConnection struct { + l *sync.RWMutex + counter *core.TrafficCounter tp *transport.Transport - outs chan framing.Frame - outsPriority []framing.Frame + outs chan core.WriteableFrame + outsPriority []core.WriteableFrame responder Responder - messages *u32map - sids streamIDs + messages *sync.Map + sids StreamID mtu int - fragments *u32map // key=streamID, value=Joiner - closed *atomic.Bool - done chan struct{} - keepaliver *keepaliver + fragments *sync.Map // common.U32Map // key=streamID, value=Joiner + writeDone chan struct{} + keepaliver *Keepaliver cond *sync.Cond singleScheduler scheduler.Scheduler e error leases lease.Leases + closeOnce sync.Once } // SetError sets error for current socket. -func (p *DuplexRSocket) SetError(e error) { - p.e = e +func (dc *DuplexConnection) SetError(err error) { + dc.l.Lock() + defer dc.l.Unlock() + dc.e = err } -func (p *DuplexRSocket) nextStreamID() (sid uint32) { - var lap1st bool +// GetError get the error set. +func (dc *DuplexConnection) GetError() error { + dc.l.RLock() + defer dc.l.RUnlock() + return dc.e +} + +func (dc *DuplexConnection) nextStreamID() (sid uint32) { + var firstLap bool for { - // There's no necessery to check StreamID conflicts. - sid, lap1st = p.sids.next() - if lap1st { + // There's no required to check StreamID conflicts. + sid, firstLap = dc.sids.Next() + if firstLap { return } - _, ok := p.messages.Load(sid) + _, ok := dc.messages.Load(sid) if !ok { return } @@ -73,150 +86,130 @@ func (p *DuplexRSocket) nextStreamID() (sid uint32) { } // Close close current socket. -func (p *DuplexRSocket) Close() error { - if !p.closed.CAS(false, true) { - return nil - } - if p.keepaliver != nil { - p.keepaliver.Stop() +func (dc *DuplexConnection) Close() (err error) { + dc.closeOnce.Do(func() { + err = dc.innerClose() + }) + return +} + +func (dc *DuplexConnection) innerClose() error { + if dc.keepaliver != nil { + dc.keepaliver.Stop() } - _ = p.singleScheduler.(io.Closer).Close() - close(p.outs) - p.cond.L.Lock() - p.cond.Broadcast() - p.cond.L.Unlock() - - _ = p.fragments.Close() - <-p.done - - if p.tp != nil { - if p.e == nil { - p.e = p.tp.Close() + _ = dc.singleScheduler.Close() + close(dc.outs) + dc.cond.L.Lock() + dc.cond.Broadcast() + dc.cond.L.Unlock() + + <-dc.writeDone + + if dc.tp != nil { + if dc.e == nil { + dc.e = dc.tp.Close() } else { - _ = p.tp.Close() + _ = dc.tp.Close() } } - - p.fragments.Range(func(key uint32, value interface{}) bool { - return true - }) - _ = p.fragments.Close() - - p.messages.Range(func(key uint32, value interface{}) bool { - if cc, ok := value.(closerWithError); ok { - if p.e == nil { - go func() { - cc.Close(errSocketClosed) - }() - } else { - go func(e error) { - cc.Close(e) - }(p.e) + dc.messages.Range(func(_, v interface{}) bool { + if cb, ok := v.(callback); ok { + err := dc.e + if err == nil { + err = errSocketClosed } + go cb.Close(err) } return true }) - _ = p.messages.Close() - return p.e + return dc.e } // FireAndForget start a request of FireAndForget. -func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { +func (dc *DuplexConnection) FireAndForget(sending payload.Payload) { data := sending.Data() - size := framing.HeaderLen + len(sending.Data()) + size := core.FrameHeaderLen + len(sending.Data()) m, ok := sending.Metadata() if ok { size += 3 + len(m) } - sid := p.nextStreamID() - if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameFNF(sid, data, m)) + sid := dc.nextStreamID() + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableFireAndForgetFrame(sid, data, m, 0)) return } - p.doSplit(data, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestFNF, fg) - f = &framing.FrameFNF{ - BaseFrame: framing.NewBaseFrame(h, body), - } + dc.doSplit(data, m, func(index int, result fragmentation.SplitResult) { + var f core.WriteableFrame + if index == 0 { + f = framing.NewWriteableFireAndForgetFrame(sid, result.Data, result.Metadata, result.Flag) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - } + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) } // MetadataPush start a request of MetadataPush. -func (p *DuplexRSocket) MetadataPush(payload payload.Payload) { +func (dc *DuplexConnection) MetadataPush(payload payload.Payload) { metadata, _ := payload.Metadata() - p.sendFrame(framing.NewFrameMetadataPush(metadata)) + dc.sendFrame(framing.NewWriteableMetadataPushFrame(metadata)) } // RequestResponse start a request of RequestResponse. -func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestResponse(pl payload.Payload) (mo mono.Mono) { + sid := dc.nextStreamID() resp := mono.CreateProcessor() - p.register(sid, reqRR{pc: resp}) + dc.register(sid, requestResponseCallback{pc: resp}) data := pl.Data() metadata, _ := pl.Metadata() + mo = resp. DoFinally(func(s rx.SignalType) { if s == rx.SignalCancel { - p.sendFrame(framing.NewFrameCancel(sid)) + dc.sendFrame(framing.NewWriteableCancelFrame(sid)) } - p.unregister(sid) + dc.unregister(sid) }) - p.singleScheduler.Worker().Do(func() { - // sending... - size := framing.CalcPayloadFrameSize(data, metadata) - if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameRequestResponse(sid, data, metadata)) - return + // sending... + size := framing.CalcPayloadFrameSize(data, metadata) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableRequestResponseFrame(sid, data, metadata, 0)) + return + } + dc.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { + var f core.WriteableFrame + if index == 0 { + f = framing.NewWriteableRequestResponseFrame(sid, result.Data, result.Metadata, result.Flag) + } else { + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.doSplit(data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestResponse, fg) - f = &framing.FrameRequestResponse{ - BaseFrame: framing.NewBaseFrame(h, body), - } - } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - } - } - p.sendFrame(f) - }) + dc.sendFrame(f) }) + return } // RequestStream start a request of RequestStream. -func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux) { + sid := dc.nextStreamID() pc := flux.CreateProcessor() - p.register(sid, reqRS{pc: pc}) + dc.register(sid, requestStreamCallback{pc: pc}) requested := make(chan struct{}) ret = pc. DoFinally(func(sig rx.SignalType) { if sig == rx.SignalCancel { - p.sendFrame(framing.NewFrameCancel(sid)) + dc.sendFrame(framing.NewWriteableCancelFrame(sid)) } - p.unregister(sid) + dc.unregister(sid) }). DoOnRequest(func(n int) { - n32 := toU32N(n) + n32 := ToUint32RequestN(n) var newborn bool select { @@ -227,8 +220,8 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { } if !newborn { - frameN := framing.NewFrameRequestN(sid, n32) - p.sendFrame(frameN) + frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) + dc.sendFrame(frameN) <-frameN.DoneNotify() return } @@ -237,34 +230,27 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { metadata, _ := sending.Metadata() size := framing.CalcPayloadFrameSize(data, metadata) + 4 - if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameRequestStream(sid, n32, data, metadata)) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableRequestStreamFrame(sid, n32, data, metadata, 0)) return } - p.doSplitSkip(4, data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestStream, fg) - // write init RequestN - binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.FrameRequestStream{ - BaseFrame: framing.NewBaseFrame(h, body), - } + + dc.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { + var f core.WriteableFrame + if index == 0 { + f = framing.NewWriteableRequestStreamFrame(sid, n32, result.Data, result.Metadata, result.Flag) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - } + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) }) return } // RequestChannel start a request of RequestChannel. -func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { + sid := dc.nextStreamID() sending := publisher.(flux.Flux) receiving := flux.CreateProcessor() @@ -273,10 +259,10 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { ret = receiving. DoFinally(func(sig rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). DoOnRequest(func(n int) { - n32 := toU32N(n) + n32 := ToUint32RequestN(n) var newborn bool select { case <-rcvRequested: @@ -285,15 +271,15 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { close(rcvRequested) } if !newborn { - frameN := framing.NewFrameRequestN(sid, n32) - p.sendFrame(frameN) + frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) + dc.sendFrame(frameN) <-frameN.DoneNotify() return } sndRequested := make(chan struct{}) sub := rx.NewSubscriber( - rx.OnNext(func(item payload.Payload) { + rx.OnNext(func(item payload.Payload) (err error) { var newborn bool select { case <-sndRequested: @@ -302,38 +288,32 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { close(sndRequested) } if !newborn { - p.sendPayload(sid, item, framing.FlagNext) + dc.sendPayload(sid, item, core.FlagNext) return } d := item.Data() m, _ := item.Metadata() size := framing.CalcPayloadFrameSize(d, m) + 4 - if !p.shouldSplit(size) { + if !dc.shouldSplit(size) { metadata, _ := item.Metadata() - p.sendFrame(framing.NewFrameRequestChannel(sid, n32, item.Data(), metadata, framing.FlagNext)) + dc.sendFrame(framing.NewWriteableRequestChannelFrame(sid, n32, item.Data(), metadata, core.FlagNext)) return } - p.doSplitSkip(4, d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestChannel, fg|framing.FlagNext) - // write init RequestN - binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.FrameRequestChannel{ - BaseFrame: framing.NewBaseFrame(h, body), - } + + dc.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { + var f core.WriteableFrame + if index == 0 { + f = framing.NewWriteableRequestChannelFrame(sid, n32, result.Data, result.Metadata, result.Flag|core.FlagNext) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - } + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) + return }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, reqRC{rcv: receiving, snd: s}) + dc.register(sid, requestChannelCallback{rcv: receiving, snd: s}) s.Request(1) }), ) @@ -342,29 +322,29 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { // TODO: handle cancel or error switch sig { case rx.SignalComplete: - complete := framing.NewFramePayload(sid, nil, nil, framing.FlagComplete) - p.sendFrame(complete) + complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) + dc.sendFrame(complete) <-complete.DoneNotify() default: panic(fmt.Errorf("unsupported sending channel signal: %s", sig)) } }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) }) return ret } -func (p *DuplexRSocket) onFrameRequestResponse(frame framing.Frame) error { +func (dc *DuplexConnection) onFrameRequestResponse(frame core.Frame) error { // fragment - receiving, ok := p.doFragment(frame.(*framing.FrameRequestResponse)) + receiving, ok := dc.doFragment(frame.(*framing.RequestResponseFrame)) if !ok { return nil } - return p.respondRequestResponse(receiving) + return dc.respondRequestResponse(receiving) } -func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // 1. execute socket handler @@ -372,58 +352,59 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd defer func() { err = tryRecover(recover()) }() - mono = p.responder.RequestResponse(receiving) + mono = dc.responder.RequestResponse(receiving) return }() // 2. sending error with panic if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } // 3. sending error with unsupported handler if sending == nil { - p.writeError(sid, framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestResponse)) + dc.writeError(sid, framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) return nil } // 4. async subscribe publisher sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { - p.sendPayload(sid, input, framing.FlagNext|framing.FlagComplete) + rx.OnNext(func(input payload.Payload) error { + dc.sendPayload(sid, input, core.FlagNext|core.FlagComplete) + return nil }), rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRR{su: s}) + dc.register(sid, requestResponseCallbackReverse{su: s}) s.Request(rx.RequestMax) }), ) sending. DoFinally(func(sig rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } -func (p *DuplexRSocket) onFrameRequestChannel(input framing.Frame) error { - receiving, ok := p.doFragment(input.(*framing.FrameRequestChannel)) +func (dc *DuplexConnection) onFrameRequestChannel(input core.Frame) error { + receiving, ok := dc.doFragment(input.(*framing.RequestChannelFrame)) if !ok { return nil } - return p.respondRequestChannel(receiving) + return dc.respondRequestChannel(receiving) } -func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { // seek initRequestN var initRequestN int switch v := pl.(type) { - case *framing.FrameRequestChannel: - initRequestN = toIntN(v.InitialRequestN()) + case *framing.RequestChannelFrame: + initRequestN = ToIntRequestN(v.InitialRequestN()) case fragmentation.Joiner: - initRequestN = toIntN(v.First().(*framing.FrameRequestChannel).InitialRequestN()) + initRequestN = ToIntRequestN(v.First().(*framing.RequestChannelFrame).InitialRequestN()) default: panic("unreachable") } @@ -441,18 +422,18 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) case _, ok := <-ch: if ok { close(ch) - p.unregister(sid) + dc.unregister(sid) } default: } }). DoOnRequest(func(n int) { - frameN := framing.NewFrameRequestN(sid, toU32N(n)) - p.sendFrame(frameN) + frameN := framing.NewWriteableRequestNFrame(sid, ToUint32RequestN(n), 0) + dc.sendFrame(frameN) <-frameN.DoneNotify() }) - p.singleScheduler.Worker().Do(func() { + _ = dc.singleScheduler.Worker().Do(func() { receivingProcessor.Next(pl) }) @@ -461,15 +442,15 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) defer func() { err = tryRecover(recover()) }() - flux = p.responder.RequestChannel(receiving) + flux = dc.responder.RequestChannel(receiving) if flux == nil { - err = framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestChannel) + err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) } return }() if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } @@ -478,20 +459,21 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) sub := rx.NewSubscriber( rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnComplete(func() { - complete := framing.NewFramePayload(sid, nil, nil, framing.FlagComplete) - p.sendFrame(complete) + complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) + dc.sendFrame(complete) <-complete.DoneNotify() }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRC{rcv: receivingProcessor, snd: s}) + dc.register(sid, requestChannelCallbackReverse{rcv: receivingProcessor, snd: s}) close(mustSub) s.Request(initRequestN) }), - rx.OnNext(func(elem payload.Payload) { - p.sendPayload(sid, elem, framing.FlagNext) + rx.OnNext(func(elem payload.Payload) error { + dc.sendPayload(sid, elem, core.FlagNext) + return nil }), ) @@ -503,55 +485,55 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) case _, ok := <-ch: if ok { close(ch) - p.unregister(sid) + dc.unregister(sid) } default: } }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) <-mustSub return nil } -func (p *DuplexRSocket) respondMetadataPush(input framing.Frame) (err error) { +func (dc *DuplexConnection) respondMetadataPush(input core.Frame) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond METADATA_PUSH failed: %s\n", e) } }() - p.responder.MetadataPush(input.(*framing.FrameMetadataPush)) + dc.responder.MetadataPush(input.(*framing.MetadataPushFrame)) return } -func (p *DuplexRSocket) onFrameFNF(frame framing.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.FrameFNF)) +func (dc *DuplexConnection) onFrameFNF(frame core.Frame) error { + receiving, ok := dc.doFragment(frame.(*framing.FireAndForgetFrame)) if !ok { return nil } - return p.respondFNF(receiving) + return dc.respondFNF(receiving) } -func (p *DuplexRSocket) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { +func (dc *DuplexConnection) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond FireAndForget failed: %s\n", e) } }() - p.responder.FireAndForget(receiving) + dc.responder.FireAndForget(receiving) return } -func (p *DuplexRSocket) onFrameRequestStream(frame framing.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.FrameRequestStream)) +func (dc *DuplexConnection) onFrameRequestStream(frame core.Frame) error { + receiving, ok := dc.doFragment(frame.(*framing.RequestStreamFrame)) if !ok { return nil } - return p.respondRequestStream(receiving) + return dc.respondRequestStream(receiving) } -func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // execute request stream handler @@ -559,126 +541,128 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa defer func() { err = tryRecover(recover()) }() - resp = p.responder.RequestStream(receiving) + resp = dc.responder.RequestStream(receiving) if resp == nil { - err = framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestStream) + err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestStream) } return }() // send error with panic if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } // seek n32 var n32 int switch v := receiving.(type) { - case *framing.FrameRequestStream: + case *framing.RequestStreamFrame: n32 = int(v.InitialRequestN()) case fragmentation.Joiner: - n32 = int(v.First().(*framing.FrameRequestStream).InitialRequestN()) + n32 = int(v.First().(*framing.RequestStreamFrame).InitialRequestN()) default: panic("unreachable") } sub := rx.NewSubscriber( - rx.OnNext(func(elem payload.Payload) { - p.sendPayload(sid, elem, framing.FlagNext) + rx.OnNext(func(elem payload.Payload) error { + dc.sendPayload(sid, elem, core.FlagNext) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRS{su: s}) + dc.register(sid, requestStreamCallbackReverse{su: s}) s.Request(n32) }), rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnComplete(func() { - p.sendFrame(framing.NewFramePayload(sid, nil, nil, framing.FlagComplete)) + dc.sendFrame(framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete)) }), ) // async subscribe publisher sending. DoFinally(func(s rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } -func (p *DuplexRSocket) writeError(sid uint32, e error) { +func (dc *DuplexConnection) writeError(sid uint32, e error) { // ignore sending error because current socket has been closed. - if e == errSocketClosed { + if IsSocketClosedError(e) { return } switch err := e.(type) { - case *framing.FrameError: - p.sendFrame(err) - case common.CustomError: - p.sendFrame(framing.NewFrameError(sid, err.ErrorCode(), err.ErrorData())) + case *framing.ErrorFrame: + dc.sendFrame(err) + case core.CustomError: + dc.sendFrame(framing.NewWriteableErrorFrame(sid, err.ErrorCode(), err.ErrorData())) default: - p.sendFrame(framing.NewFrameError(sid, common.ErrorCodeApplicationError, []byte(e.Error()))) + dc.sendFrame(framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) } } // SetResponder sets a responder for current socket. -func (p *DuplexRSocket) SetResponder(responder Responder) { - p.responder = responder +func (dc *DuplexConnection) SetResponder(responder Responder) { + dc.responder = responder } -func (p *DuplexRSocket) onFrameKeepalive(frame framing.Frame) (err error) { - f := frame.(*framing.FrameKeepalive) - if f.Header().Flag().Check(framing.FlagRespond) { - f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) - p.sendFrame(f) +func (dc *DuplexConnection) onFrameKeepalive(frame core.Frame) (err error) { + f := frame.(*framing.KeepaliveFrame) + if f.Header().Flag().Check(core.FlagRespond) { + k := framing.NewKeepaliveFrame(f.LastReceivedPosition(), f.Data(), false) + //f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) + dc.sendFrame(k) } return } -func (p *DuplexRSocket) onFrameCancel(frame framing.Frame) (err error) { +func (dc *DuplexConnection) onFrameCancel(frame core.Frame) (err error) { sid := frame.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { logger.Warnf("nothing cancelled: sid=%d\n", sid) return } switch vv := v.(type) { - case resRR: + case requestResponseCallbackReverse: vv.su.Cancel() - case resRS: + case requestStreamCallbackReverse: vv.su.Cancel() default: panic(fmt.Errorf("illegal cancel target: %v", vv)) } - if _, ok := p.fragments.Load(sid); ok { - p.fragments.Delete(sid) + if _, ok := dc.fragments.Load(sid); ok { + dc.fragments.Delete(sid) } return } -func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { - f := input.(*framing.FrameError) +func (dc *DuplexConnection) onFrameError(input core.Frame) (err error) { + f := input.(*framing.ErrorFrame) logger.Errorf("handle error frame: %s\n", f) sid := f.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { err = fmt.Errorf("invalid stream id: %d", sid) return } switch vv := v.(type) { - case reqRR: + case requestResponseCallback: vv.pc.Error(f) - case reqRS: + case requestStreamCallback: vv.pc.Error(f) - case reqRC: + case requestChannelCallback: vv.rcv.Error(f) default: panic(fmt.Errorf("illegal value for error: %v", vv)) @@ -686,23 +670,23 @@ func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameRequestN(input framing.Frame) (err error) { - f := input.(*framing.FrameRequestN) +func (dc *DuplexConnection) onFrameRequestN(input core.Frame) (err error) { + f := input.(*framing.RequestNFrame) sid := f.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { if logger.IsDebugEnabled() { logger.Debugf("ignore non-exists RequestN: id=%d\n", sid) } return } - n := toIntN(f.N()) + n := ToIntRequestN(f.N()) switch vv := v.(type) { - case resRS: + case requestStreamCallbackReverse: vv.su.Request(n) - case reqRC: + case requestChannelCallback: vv.snd.Request(n) - case resRC: + case requestChannelCallbackReverse: vv.snd.Request(n) default: panic(fmt.Errorf("illegal requestN for %+v", vv)) @@ -710,84 +694,84 @@ func (p *DuplexRSocket) onFrameRequestN(input framing.Frame) (err error) { return } -func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { +func (dc *DuplexConnection) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { h := input.Header() sid := h.StreamID() - v, exist := p.fragments.Load(sid) + v, exist := dc.fragments.Load(sid) if exist { joiner := v.(fragmentation.Joiner) ok = joiner.Push(input) if ok { - p.fragments.Delete(sid) + dc.fragments.Delete(sid) out = joiner } return } - ok = !h.Flag().Check(framing.FlagFollow) + ok = !h.Flag().Check(core.FlagFollow) if ok { out = input return } - p.fragments.Store(sid, fragmentation.NewJoiner(input)) + dc.fragments.Store(sid, fragmentation.NewJoiner(input)) return } -func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { - pl, ok := p.doFragment(frame.(*framing.FramePayload)) +func (dc *DuplexConnection) onFramePayload(frame core.Frame) error { + pl, ok := dc.doFragment(frame.(*framing.PayloadFrame)) if !ok { return nil } h := pl.Header() t := h.Type() - if t == framing.FrameTypeRequestFNF { - return p.respondFNF(pl) + if t == core.FrameTypeRequestFNF { + return dc.respondFNF(pl) } - if t == framing.FrameTypeRequestResponse { - return p.respondRequestResponse(pl) + if t == core.FrameTypeRequestResponse { + return dc.respondRequestResponse(pl) } - if t == framing.FrameTypeRequestStream { - return p.respondRequestStream(pl) + if t == core.FrameTypeRequestStream { + return dc.respondRequestStream(pl) } - if t == framing.FrameTypeRequestChannel { - return p.respondRequestChannel(pl) + if t == core.FrameTypeRequestChannel { + return dc.respondRequestChannel(pl) } sid := h.StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { - logger.Warnf("unoccupied Payload(id=%d), maybe it has been canceled(server=%T)\n", sid, p.sids) + logger.Warnf("unoccupied Payload(id=%d), maybe it has been canceled(server=%T)\n", sid, dc.sids) return nil } switch vv := v.(type) { - case reqRR: + case requestResponseCallback: vv.pc.Success(pl) - case reqRS: + case requestStreamCallback: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.pc.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { // Release pure complete payload vv.pc.Complete() } - case reqRC: + case requestChannelCallback: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.rcv.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { vv.rcv.Complete() } - case resRC: + case requestChannelCallbackReverse: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.rcv.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { vv.rcv.Complete() } default: @@ -796,80 +780,78 @@ func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { return nil } -func (p *DuplexRSocket) clearTransport() { - p.cond.L.Lock() - p.tp = nil - p.cond.L.Unlock() +func (dc *DuplexConnection) clearTransport() { + dc.cond.L.Lock() + dc.tp = nil + dc.cond.L.Unlock() } // SetTransport sets a transport for current socket. -func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { - tp.HandleCancel(p.onFrameCancel) - tp.HandleError(p.onFrameError) - tp.HandleRequestN(p.onFrameRequestN) - tp.HandlePayload(p.onFramePayload) - tp.HandleKeepalive(p.onFrameKeepalive) - - if p.responder != nil { - tp.HandleRequestResponse(p.onFrameRequestResponse) - tp.HandleMetadataPush(p.respondMetadataPush) - tp.HandleFNF(p.onFrameFNF) - tp.HandleRequestStream(p.onFrameRequestStream) - tp.HandleRequestChannel(p.onFrameRequestChannel) +func (dc *DuplexConnection) SetTransport(tp *transport.Transport) { + tp.RegisterHandler(transport.OnCancel, dc.onFrameCancel) + tp.RegisterHandler(transport.OnError, dc.onFrameError) + tp.RegisterHandler(transport.OnRequestN, dc.onFrameRequestN) + tp.RegisterHandler(transport.OnPayload, dc.onFramePayload) + tp.RegisterHandler(transport.OnKeepalive, dc.onFrameKeepalive) + + if dc.responder != nil { + tp.RegisterHandler(transport.OnRequestResponse, dc.onFrameRequestResponse) + tp.RegisterHandler(transport.OnMetadataPush, dc.respondMetadataPush) + tp.RegisterHandler(transport.OnFireAndForget, dc.onFrameFNF) + tp.RegisterHandler(transport.OnRequestStream, dc.onFrameRequestStream) + tp.RegisterHandler(transport.OnRequestChannel, dc.onFrameRequestChannel) } - p.cond.L.Lock() - p.tp = tp - p.cond.Signal() - p.cond.L.Unlock() + dc.cond.L.Lock() + dc.tp = tp + dc.cond.Signal() + dc.cond.L.Unlock() } -func (p *DuplexRSocket) sendFrame(f framing.Frame) { +func (dc *DuplexConnection) sendFrame(f core.WriteableFrame) { defer func() { if e := recover(); e != nil { logger.Warnf("send frame failed: %s\n", e) } }() - p.outs <- f + dc.outs <- f } -func (p *DuplexRSocket) sendPayload( +func (dc *DuplexConnection) sendPayload( sid uint32, sending payload.Payload, - frameFlag framing.FrameFlag, + frameFlag core.FrameFlag, ) { d := sending.Data() m, _ := sending.Metadata() size := framing.CalcPayloadFrameSize(d, m) - if !p.shouldSplit(size) { - p.sendFrame(framing.NewFramePayload(sid, d, m, frameFlag)) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteablePayloadFrame(sid, d, m, frameFlag)) return } - p.doSplit(d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var h framing.FrameHeader - if idx == 0 { - h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|frameFlag) + dc.doSplit(d, m, func(index int, result fragmentation.SplitResult) { + flag := result.Flag + if index == 0 { + flag |= frameFlag } else { - h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) + flag |= core.FlagNext } - p.sendFrame(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), - }) + dc.sendFrame(framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, flag)) }) } -func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { - if len(p.outs) > 0 { - p.drain(nil) +func (dc *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { + if len(dc.outs) > 0 { + dc.drain(nil) } - var out framing.Frame + var out core.WriteableFrame select { - case <-p.keepaliver.C(): + case <-dc.keepaliver.C(): ok = true - out = framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(out, true) + out = framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(out, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } @@ -879,60 +861,60 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) if !ok { return } - out = framing.NewFrameLease(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + out = framing.NewWriteableLeaseFrame(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } - case out, ok = <-p.outs: + case out, ok = <-dc.outs: if !ok { return } - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } } return } -func (p *DuplexRSocket) drainWithKeepalive() (ok bool) { - if len(p.outs) > 0 { - p.drain(nil) +func (dc *DuplexConnection) drainWithKeepalive() (ok bool) { + if len(dc.outs) > 0 { + dc.drain(nil) } - var out framing.Frame + var out core.WriteableFrame select { - case <-p.keepaliver.C(): + case <-dc.keepaliver.C(): ok = true - out = framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(out, true) + out = framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(out, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } } - case out, ok = <-p.outs: + case out, ok = <-dc.outs: if !ok { return } - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } } return } -func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { +func (dc *DuplexConnection) drain(leaseChan <-chan lease.Lease) bool { var flush bool - cycle := len(p.outs) + cycle := len(dc.outs) if cycle < 1 { cycle = 1 } @@ -942,34 +924,34 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { if !ok { return false } - if p.drainOne(framing.NewFrameLease(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { + if dc.drainOne(framing.NewWriteableLeaseFrame(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { flush = true } - case out, ok := <-p.outs: + case out, ok := <-dc.outs: if !ok { return false } - if p.drainOne(out) { + if dc.drainOne(out) { flush = true } } } if flush { - if err := p.tp.Flush(); err != nil { + if err := dc.tp.Flush(); err != nil { logger.Errorf("flush failed: %v\n", err) } } return true } -func (p *DuplexRSocket) drainOne(out framing.Frame) (wrote bool) { - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) +func (dc *DuplexConnection) drainOne(out core.WriteableFrame) (wrote bool) { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) return } - err := p.tp.Send(out, false) + err := dc.tp.Send(out, false) if err != nil { - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) logger.Errorf("send frame failed: %s\n", err.Error()) return } @@ -977,50 +959,50 @@ func (p *DuplexRSocket) drainOne(out framing.Frame) (wrote bool) { return } -func (p *DuplexRSocket) drainOutBack() { - if len(p.outsPriority) < 1 { +func (dc *DuplexConnection) drainOutBack() { + if len(dc.outsPriority) < 1 { return } defer func() { - p.outsPriority = p.outsPriority[:0] + dc.outsPriority = dc.outsPriority[:0] }() - if p.tp == nil { + if dc.tp == nil { return } - var out framing.Frame - for i := range p.outsPriority { - out = p.outsPriority[i] - if err := p.tp.Send(out, false); err != nil { + var out core.WriteableFrame + for i := range dc.outsPriority { + out = dc.outsPriority[i] + if err := dc.tp.Send(out, false); err != nil { out.Done() logger.Errorf("send frame failed: %v\n", err) } } - if err := p.tp.Flush(); err != nil { + if err := dc.tp.Flush(); err != nil { logger.Errorf("flush failed: %v\n", err) } } -func (p *DuplexRSocket) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { +func (dc *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { for { - if p.tp == nil { - p.cond.L.Lock() - p.cond.Wait() - p.cond.L.Unlock() + if dc.tp == nil { + dc.cond.L.Lock() + dc.cond.Wait() + dc.cond.L.Unlock() } select { case <-ctx.Done(): - p.cleanOuts() + dc.cleanOuts() return ctx.Err() default: // ignore } select { - case <-p.keepaliver.C(): - kf := framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(kf, true) + case <-dc.keepaliver.C(): + kf := framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(kf, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } @@ -1028,117 +1010,106 @@ func (p *DuplexRSocket) loopWriteWithKeepaliver(ctx context.Context, leaseChan < default: } - p.drainOutBack() - if leaseChan == nil && !p.drainWithKeepalive() { + dc.drainOutBack() + if leaseChan == nil && !dc.drainWithKeepalive() { break } - if leaseChan != nil && !p.drainWithKeepaliveAndLease(leaseChan) { + if leaseChan != nil && !dc.drainWithKeepaliveAndLease(leaseChan) { break } } return nil } -func (p *DuplexRSocket) cleanOuts() { - p.outsPriority = nil +func (dc *DuplexConnection) cleanOuts() { + dc.outsPriority = nil } -func (p *DuplexRSocket) loopWrite(ctx context.Context) error { - defer close(p.done) +func (dc *DuplexConnection) LoopWrite(ctx context.Context) error { + defer close(dc.writeDone) var leaseChan chan lease.Lease - if p.leases != nil { + if dc.leases != nil { leaseCtx, cancel := context.WithCancel(ctx) defer func() { cancel() }() - if c, ok := p.leases.Next(leaseCtx); ok { + if c, ok := dc.leases.Next(leaseCtx); ok { leaseChan = c } } - if p.keepaliver != nil { - defer p.keepaliver.Stop() - return p.loopWriteWithKeepaliver(ctx, leaseChan) + if dc.keepaliver != nil { + defer dc.keepaliver.Stop() + return dc.loopWriteWithKeepaliver(ctx, leaseChan) } for { - if p.tp == nil { - p.cond.L.Lock() - p.cond.Wait() - p.cond.L.Unlock() + if dc.tp == nil { + dc.cond.L.Lock() + dc.cond.Wait() + dc.cond.L.Unlock() } select { case <-ctx.Done(): - p.cleanOuts() + dc.cleanOuts() return ctx.Err() default: } - p.drainOutBack() - if !p.drain(leaseChan) { + dc.drainOutBack() + if !dc.drain(leaseChan) { break } } return nil } -func (p *DuplexRSocket) doSplit(data, metadata []byte, handler func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { - fragmentation.Split(p.mtu, data, metadata, handler) +func (dc *DuplexConnection) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { + fragmentation.Split(dc.mtu, data, metadata, handler) } -func (p *DuplexRSocket) doSplitSkip(skip int, data, metadata []byte, handler func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { - fragmentation.SplitSkip(p.mtu, skip, data, metadata, handler) +func (dc *DuplexConnection) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { + fragmentation.SplitSkip(dc.mtu, skip, data, metadata, handler) } -func (p *DuplexRSocket) shouldSplit(size int) bool { - return size > p.mtu +func (dc *DuplexConnection) shouldSplit(size int) bool { + return size > dc.mtu } -func (p *DuplexRSocket) register(sid uint32, msg interface{}) { - p.messages.Store(sid, msg) +func (dc *DuplexConnection) register(sid uint32, msg interface{}) { + dc.messages.Store(sid, msg) } -func (p *DuplexRSocket) unregister(sid uint32) { - p.messages.Delete(sid) - p.fragments.Delete(sid) +func (dc *DuplexConnection) unregister(sid uint32) { + dc.messages.Delete(sid) + dc.fragments.Delete(sid) } -// NewServerDuplexRSocket creates a new server-side DuplexRSocket. -func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { - return &DuplexRSocket{ - closed: atomic.NewBool(false), - leases: leases, - outs: make(chan framing.Frame, outsSize), - mtu: mtu, - messages: newU32Map(), - sids: &serverStreamIDs{}, - fragments: newU32Map(), - done: make(chan struct{}), - cond: sync.NewCond(&sync.Mutex{}), - counter: transport.NewCounter(), - singleScheduler: scheduler.NewSingle(64), - } +// NewServerDuplexConnection creates a new server-side DuplexConnection. +func NewServerDuplexConnection(mtu int, leases lease.Leases) *DuplexConnection { + return newDuplexConnection(mtu, nil, &serverStreamIDs{}, leases) +} + +// NewClientDuplexConnection creates a new client-side DuplexConnection. +func NewClientDuplexConnection(mtu int, keepaliveInterval time.Duration) *DuplexConnection { + return newDuplexConnection(mtu, NewKeepaliver(keepaliveInterval), &clientStreamIDs{}, nil) } -// NewClientDuplexRSocket creates a new client-side DuplexRSocket. -func NewClientDuplexRSocket( - mtu int, - keepaliveInterval time.Duration, -) (s *DuplexRSocket) { - ka := newKeepaliver(keepaliveInterval) - s = &DuplexRSocket{ - closed: atomic.NewBool(false), - outs: make(chan framing.Frame, outsSize), +func newDuplexConnection(mtu int, ka *Keepaliver, sids StreamID, leases lease.Leases) *DuplexConnection { + l := &sync.RWMutex{} + return &DuplexConnection{ + l: l, + leases: leases, + outs: make(chan core.WriteableFrame, _outChanSize), mtu: mtu, - messages: newU32Map(), - sids: &clientStreamIDs{}, - fragments: newU32Map(), - done: make(chan struct{}), - cond: sync.NewCond(&sync.Mutex{}), - counter: transport.NewCounter(), + messages: &sync.Map{}, + sids: sids, + fragments: &sync.Map{}, + writeDone: make(chan struct{}), + cond: sync.NewCond(l), + counter: core.NewTrafficCounter(), keepaliver: ka, singleScheduler: scheduler.NewSingle(64), } - return } diff --git a/internal/socket/keepaliver.go b/internal/socket/keepaliver.go index a1a5f94..d7ade0c 100644 --- a/internal/socket/keepaliver.go +++ b/internal/socket/keepaliver.go @@ -4,31 +4,30 @@ import ( "time" ) -type keepaliver struct { - ticker *time.Ticker - interval time.Duration +type Keepaliver struct { + ticker *time.Ticker + done chan struct{} } -func (p *keepaliver) C() <-chan time.Time { +func (p Keepaliver) C() <-chan time.Time { return p.ticker.C } -func (p *keepaliver) Stop() { - if p.ticker != nil { - p.ticker.Stop() - } +func (p Keepaliver) Done() <-chan struct{} { + return p.done } -func (p *keepaliver) Reset(interval time.Duration) { - if interval != p.interval { - p.ticker.Stop() - p.ticker = time.NewTicker(interval) - } +func (p Keepaliver) Stop() { + defer func() { + _ = recover() + }() + p.ticker.Stop() + close(p.done) } -func newKeepaliver(interval time.Duration) *keepaliver { - return &keepaliver{ - interval: interval, - ticker: time.NewTicker(interval), +func NewKeepaliver(interval time.Duration) *Keepaliver { + return &Keepaliver{ + ticker: time.NewTicker(interval), + done: make(chan struct{}), } } diff --git a/internal/socket/keepaliver_test.go b/internal/socket/keepaliver_test.go new file mode 100644 index 0000000..3cb01eb --- /dev/null +++ b/internal/socket/keepaliver_test.go @@ -0,0 +1,33 @@ +package socket_test + +import ( + "fmt" + "testing" + "time" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +func TestKeepaliver(t *testing.T) { + k := socket.NewKeepaliver(100 * time.Millisecond) + + time.AfterFunc(time.Second+50*time.Millisecond, func() { + k.Stop() + // stop again + k.Stop() + }) + + beats := 0 +L: + for { + select { + case v := <-k.C(): + fmt.Println(v) + beats++ + case <-k.Done(): + break L + } + } + assert.Equal(t, 10, beats, "beats should be 10") +} diff --git a/internal/socket/misc.go b/internal/socket/misc.go index 1d634ef..d387396 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -1,68 +1,18 @@ package socket import ( - "sync" "time" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/rx" ) -type u32map struct { - k sync.RWMutex - m map[uint32]interface{} -} - -func (p *u32map) Close() error { - p.k.Lock() - p.m = nil - p.k.Unlock() - return nil -} - -func (p *u32map) Range(fn func(uint32, interface{}) bool) { - p.k.RLock() - for key, value := range p.m { - if !fn(key, value) { - break - } - } - p.k.RUnlock() -} - -func (p *u32map) Load(key uint32) (v interface{}, ok bool) { - p.k.RLock() - v, ok = p.m[key] - p.k.RUnlock() - return -} - -func (p *u32map) Store(key uint32, value interface{}) { - p.k.Lock() - if p.m != nil { - p.m[key] = value - } - p.k.Unlock() -} - -func (p *u32map) Delete(key uint32) { - p.k.Lock() - delete(p.m, key) - p.k.Unlock() -} - -func newU32Map() *u32map { - return &u32map{ - m: make(map[uint32]interface{}), - } -} - // SetupInfo represents basic info of setup. type SetupInfo struct { Lease bool - Version common.Version + Version core.Version KeepaliveInterval time.Duration KeepaliveLifetime time.Duration Token []byte @@ -72,8 +22,8 @@ type SetupInfo struct { Metadata []byte } -func (p *SetupInfo) toFrame() *framing.FrameSetup { - return framing.NewFrameSetup( +func (p *SetupInfo) toFrame() core.WriteableFrame { + return framing.NewWriteableSetupFrame( p.Version, p.KeepaliveInterval, p.KeepaliveLifetime, @@ -101,14 +51,17 @@ func tryRecover(e interface{}) (err error) { return } -func toIntN(n uint32) int { +func ToIntRequestN(n uint32) int { if n > rx.RequestMax { return rx.RequestMax } return int(n) } -func toU32N(n int) uint32 { +func ToUint32RequestN(n int) uint32 { + if n < 0 { + panic("invalid negative int") + } if n > rx.RequestMax { return rx.RequestMax } diff --git a/internal/socket/misc_test.go b/internal/socket/misc_test.go new file mode 100644 index 0000000..f448896 --- /dev/null +++ b/internal/socket/misc_test.go @@ -0,0 +1,23 @@ +package socket_test + +import ( + "math" + "testing" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/rx" + "github.com/stretchr/testify/assert" +) + +func TestToUint32RequestN(t *testing.T) { + assert.Equal(t, uint32(1), socket.ToUint32RequestN(1)) + assert.Panics(t, func() { + socket.ToUint32RequestN(-1) + }, "should panic") + assert.Equal(t, uint32(rx.RequestMax), socket.ToUint32RequestN(math.MaxInt64)) +} + +func TestToIntRequestN(t *testing.T) { + assert.Equal(t, 1, socket.ToIntRequestN(1)) + assert.Equal(t, rx.RequestMax, socket.ToIntRequestN(math.MaxUint32)) +} diff --git a/internal/socket/mock_conn_test.go b/internal/socket/mock_conn_test.go new file mode 100644 index 0000000..8442e68 --- /dev/null +++ b/internal/socket/mock_conn_test.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/types.go + +// Package socket_test is a generated GoMock package. +package socket_test + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + core "github.com/rsocket/rsocket-go/core" +) + +// MockConn is a mock of Conn interface +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder +} + +// MockConnMockRecorder is the mock recorder for MockConn +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// SetDeadline mocks base method +func (m *MockConn) SetDeadline(deadline time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", deadline) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), deadline) +} + +// SetCounter mocks base method +func (m *MockConn) SetCounter(c *core.TrafficCounter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCounter", c) +} + +// SetCounter indicates an expected call of SetCounter +func (mr *MockConnMockRecorder) SetCounter(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCounter", reflect.TypeOf((*MockConn)(nil).SetCounter), c) +} + +// Read mocks base method +func (m *MockConn) Read() (core.Frame, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read") + ret0, _ := ret[0].(core.Frame) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockConnMockRecorder) Read() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read)) +} + +// Write mocks base method +func (m *MockConn) Write(arg0 core.WriteableFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write +func (mr *MockConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), arg0) +} + +// Flush mocks base method +func (m *MockConn) Flush() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +// Flush indicates an expected call of Flush +func (mr *MockConnMockRecorder) Flush() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockConn)(nil).Flush)) +} diff --git a/internal/socket/msg.go b/internal/socket/msg.go deleted file mode 100644 index d088e97..0000000 --- a/internal/socket/msg.go +++ /dev/null @@ -1,66 +0,0 @@ -package socket - -import ( - rs "github.com/jjeffcaii/reactor-go" - "github.com/rsocket/rsocket-go/rx" - "github.com/rsocket/rsocket-go/rx/flux" - "github.com/rsocket/rsocket-go/rx/mono" -) - -type closerWithError interface { - Close(error) -} - -type reqRS struct { - pc flux.Processor -} - -func (s reqRS) Close(err error) { - s.pc.Error(err) -} - -type reqRR struct { - pc mono.Processor -} - -func (s reqRR) Close(err error) { - s.pc.Error(err) -} - -type reqRC struct { - snd rx.Subscription - rcv flux.Processor -} - -func (s reqRC) Close(err error) { - s.snd.Cancel() - s.rcv.Error(err) -} - -type resRR struct { - su rs.Subscription -} - -func (s resRR) Close(err error) { - s.su.Cancel() - // TODO: fill err -} - -type resRS struct { - su rx.Subscription -} - -func (s resRS) Close(err error) { - s.su.Cancel() - // TODO: fill error -} - -type resRC struct { - snd rx.Subscription - rcv flux.Processor -} - -func (s resRC) Close(err error) { - s.rcv.Error(err) - s.snd.Cancel() -} diff --git a/internal/socket/client_resume.go b/internal/socket/resumable_client_socket.go similarity index 69% rename from internal/socket/client_resume.go rename to internal/socket/resumable_client_socket.go index cac1323..60d08b7 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/resumable_client_socket.go @@ -2,14 +2,13 @@ package socket import ( "context" - "crypto/tls" "errors" "math" "time" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/logger" "go.uber.org/atomic" ) @@ -17,18 +16,16 @@ import ( const reconnectDelay = 1 * time.Second type resumeClientSocket struct { - *baseSocket + *BaseSocket connects *atomic.Int32 - uri *transport.URI - headers map[string][]string setup *SetupInfo - tc *tls.Config + tp transport.ClientTransportFunc } func (p *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { p.setup = setup go func(ctx context.Context) { - _ = p.socket.loopWrite(ctx) + _ = p.socket.LoopWrite(ctx) }(ctx) return p.connect(ctx) } @@ -50,7 +47,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { _ = p.Close() return } - tp, err := p.uri.MakeClientTransport(p.tc, p.headers) + tp, err := p.tp(ctx) if err != nil { if connects == 1 { return @@ -78,24 +75,23 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { } }(ctx, tp) - var f framing.Frame + var f core.WriteableFrame // connect first time. if len(p.setup.Token) < 1 || connects == 1 { - tp.HandleDisaster(func(frame framing.Frame) (err error) { - p.socket.SetError(frame.(*framing.FrameError)) + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { + p.socket.SetError(frame.(*framing.ErrorFrame)) p.markClosing() return }) - f = p.setup.toFrame() err = tp.Send(f, true) p.socket.SetTransport(tp) return } - f = framing.NewFrameResume( - common.DefaultVersion, + f = framing.NewWriteableResumeFrame( + core.DefaultVersion, p.setup.Token, p.socket.counter.WriteBytes(), p.socket.counter.ReadBytes(), @@ -103,15 +99,15 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { resumeErr := make(chan string) - tp.HandleResumeOK(func(frame framing.Frame) (err error) { + tp.RegisterHandler(transport.OnResumeOK, func(frame core.Frame) (err error) { close(resumeErr) return }) - tp.HandleDisaster(func(frame framing.Frame) (err error) { + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { // TODO: process other error with zero StreamID - f := frame.(*framing.FrameError) - if f.ErrorCode() == common.ErrorCodeRejectedResume { + f := frame.(*framing.ErrorFrame) + if f.ErrorCode() == core.ErrorCodeRejectedResume { resumeErr <- f.Error() close(resumeErr) } @@ -148,13 +144,11 @@ func (p *resumeClientSocket) isClosed() bool { return p.connects.Load() < 0 } -// NewClientResume creates a client-side socket with resume support. -func NewClientResume(uri *transport.URI, socket *DuplexRSocket, tc *tls.Config, headers map[string][]string) ClientSocket { +// NewResumableClientSocket creates a client-side socket with resume support. +func NewResumableClientSocket(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { return &resumeClientSocket{ - baseSocket: newBaseSocket(socket), - uri: uri, - tc: tc, - headers: headers, + BaseSocket: NewBaseSocket(socket), connects: atomic.NewInt32(0), + tp: tp, } } diff --git a/internal/socket/resumable_client_socket_test.go b/internal/socket/resumable_client_socket_test.go new file mode 100644 index 0000000..565230a --- /dev/null +++ b/internal/socket/resumable_client_socket_test.go @@ -0,0 +1,47 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +func TestNewResumableClientSocket(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + + rcs := socket.NewResumableClientSocket(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer rcs.Close() + + err := rcs.Setup(context.Background(), fakeSetup) + assert.NoError(t, err) +} diff --git a/internal/socket/server_resume.go b/internal/socket/resumable_server_socket.go similarity index 77% rename from internal/socket/server_resume.go rename to internal/socket/resumable_server_socket.go index decf673..e9a580c 100644 --- a/internal/socket/server_resume.go +++ b/internal/socket/resumable_server_socket.go @@ -3,11 +3,11 @@ package socket import ( "context" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core/transport" ) type resumeServerSocket struct { - *baseSocket + *BaseSocket token []byte } @@ -33,13 +33,13 @@ func (p *resumeServerSocket) Start(ctx context.Context) error { defer func() { _ = p.Close() }() - return p.socket.loopWrite(ctx) + return p.socket.LoopWrite(ctx) } // NewServerResume creates a new server-side socket with resume support. -func NewServerResume(socket *DuplexRSocket, token []byte) ServerSocket { +func NewServerResume(socket *DuplexConnection, token []byte) ServerSocket { return &resumeServerSocket{ - baseSocket: newBaseSocket(socket), + BaseSocket: NewBaseSocket(socket), token: token, } } diff --git a/internal/socket/server_default.go b/internal/socket/server_default.go deleted file mode 100644 index 74de7f8..0000000 --- a/internal/socket/server_default.go +++ /dev/null @@ -1,41 +0,0 @@ -package socket - -import ( - "context" - - "github.com/rsocket/rsocket-go/internal/transport" -) - -type serverSocket struct { - *baseSocket -} - -func (p *serverSocket) Pause() bool { - return false -} - -func (p *serverSocket) SetResponder(responder Responder) { - p.socket.SetResponder(responder) -} - -func (p *serverSocket) SetTransport(tp *transport.Transport) { - p.socket.SetTransport(tp) -} - -func (p *serverSocket) Token() (token []byte, ok bool) { - return -} - -func (p *serverSocket) Start(ctx context.Context) error { - defer func() { - _ = p.Close() - }() - return p.socket.loopWrite(ctx) -} - -// NewServer creates a new server-side socket. -func NewServer(socket *DuplexRSocket) ServerSocket { - return &serverSocket{ - baseSocket: newBaseSocket(socket), - } -} diff --git a/internal/socket/simple_client_socket.go b/internal/socket/simple_client_socket.go new file mode 100644 index 0000000..e1770f3 --- /dev/null +++ b/internal/socket/simple_client_socket.go @@ -0,0 +1,62 @@ +package socket + +import ( + "context" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/logger" +) + +type simpleClientSocket struct { + *BaseSocket + tp transport.ClientTransportFunc +} + +func (p *simpleClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { + tp, err := p.tp(ctx) + if err != nil { + return + } + tp.Connection().SetCounter(p.socket.counter) + tp.SetLifetime(setup.KeepaliveLifetime) + + p.socket.SetTransport(tp) + + if setup.Lease { + p.refreshLease(0, 0) + tp.RegisterHandler(transport.OnLease, func(frame core.Frame) (err error) { + lease := frame.(*framing.LeaseFrame) + p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) + return + }) + } + + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { + p.socket.SetError(frame.(*framing.ErrorFrame)) + return + }) + + go func(ctx context.Context, tp *transport.Transport) { + if err := tp.Start(ctx); err != nil { + logger.Warnf("client exit failed: %+v\n", err) + } + _ = p.Close() + }(ctx, tp) + + go func() { + _ = p.socket.LoopWrite(ctx) + }() + setupFrame := setup.toFrame() + err = p.socket.tp.Send(setupFrame, true) + return +} + +// NewClient create a simple client-side socket. +func NewClient(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { + return &simpleClientSocket{ + BaseSocket: NewBaseSocket(socket), + tp: tp, + } +} diff --git a/internal/socket/simple_client_socket_test.go b/internal/socket/simple_client_socket_test.go new file mode 100644 index 0000000..43fb872 --- /dev/null +++ b/internal/socket/simple_client_socket_test.go @@ -0,0 +1,144 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestNewClientWithBrokenTransporter(t *testing.T) { + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + // Must failed transporter + transporter := func(ctx context.Context) (*transport.Transport, error) { + return nil, fakeErr + } + cli := socket.NewClient(transporter, ds) + err := cli.Setup(context.Background(), fakeSetup) + assert.Equal(t, fakeErr, err, "should be fake error") +} + +func TestNewClient(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + cli := socket.NewClient(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer func() { + err := cli.Close() + assert.NoError(t, err, "close client failed") + }() + + err := cli.Setup(context.Background(), fakeSetup) + assert.NoError(t, err, "setup client failed") + + requestId := atomic.NewUint32(1) + nextRequestId := func() uint32 { + return requestId.Add(2) - 2 + } + + result, err := cli.RequestResponse(payload.New(fakeData, fakeMetadata)). + DoOnSubscribe(func(s rx.Subscription) { + readChan <- framing.NewPayloadFrame(nextRequestId(), fakeData, fakeMetadata, core.FlagComplete) + }). + Block(context.Background()) + assert.NoError(t, err, "request response failed") + assert.Equal(t, fakeData, result.Data(), "response data doesn't match") + assert.Equal(t, fakeMetadata, extractMetadata(result), "response metadata doesn't match") + + var stream []payload.Payload + _, err = cli.RequestStream(payload.New(fakeData, fakeMetadata)). + DoOnNext(func(input payload.Payload) error { + stream = append(stream, input) + return nil + }). + DoOnSubscribe(func(s rx.Subscription) { + nextId := nextRequestId() + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext|core.FlagComplete) + }). + BlockLast(context.Background()) + assert.NoError(t, err, "request stream failed") + + // When a fatal error occurred, client should be stopped immediately. + fatalErr := []byte("fatal error") + readChan <- framing.NewErrorFrame(0, core.ErrorCodeRejected, fatalErr) + time.Sleep(100 * time.Millisecond) + err = ds.GetError() + assert.Error(t, err, "should get error") + assert.Equal(t, fatalErr, err.(core.CustomError).ErrorData()) +} + +func TestLease(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + cli := socket.NewClient(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer func() { + err := cli.Close() + assert.NoError(t, err, "close client failed") + }() + + setup := *fakeSetup + setup.Lease = true + err := cli.Setup(context.Background(), &setup) + assert.NoError(t, err, "setup client failed") + readChan <- framing.NewLeaseFrame(10*time.Second, 10, fakeMetadata) + time.Sleep(3 * time.Second) +} + +func extractMetadata(p payload.Payload) []byte { + m, _ := p.Metadata() + return m +} diff --git a/internal/socket/simple_server_socket.go b/internal/socket/simple_server_socket.go new file mode 100644 index 0000000..6247b07 --- /dev/null +++ b/internal/socket/simple_server_socket.go @@ -0,0 +1,41 @@ +package socket + +import ( + "context" + + "github.com/rsocket/rsocket-go/core/transport" +) + +type simpleServerSocket struct { + *BaseSocket +} + +func (p *simpleServerSocket) Pause() bool { + return false +} + +func (p *simpleServerSocket) SetResponder(responder Responder) { + p.socket.SetResponder(responder) +} + +func (p *simpleServerSocket) SetTransport(tp *transport.Transport) { + p.socket.SetTransport(tp) +} + +func (p *simpleServerSocket) Token() (token []byte, ok bool) { + return +} + +func (p *simpleServerSocket) Start(ctx context.Context) error { + defer func() { + _ = p.Close() + }() + return p.socket.LoopWrite(ctx) +} + +// NewSimpleServerSocket creates a new server-side socket. +func NewSimpleServerSocket(socket *DuplexConnection) ServerSocket { + return &simpleServerSocket{ + BaseSocket: NewBaseSocket(socket), + } +} diff --git a/internal/socket/simple_server_socket_test.go b/internal/socket/simple_server_socket_test.go new file mode 100644 index 0000000..74551e2 --- /dev/null +++ b/internal/socket/simple_server_socket_test.go @@ -0,0 +1,81 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +var fakeResponder = rsocket.NewAbstractSocket() + +func TestSimpleServerSocket_Start(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + setupFrame := framing.NewSetupFrame( + core.DefaultVersion, + 30*time.Second, + 90*time.Second, + nil, + fakeMimeType, + fakeMimeType, + fakeData, + fakeMetadata, + false, + ) + readChan <- setupFrame + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).AnyTimes() + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + firstFrame, err := tp.ReadFirst(context.Background()) + assert.NoError(t, err, "read first frame failed") + assert.Equal(t, setupFrame, firstFrame, "first should be setup frame") + + close(readChan) + + ds := socket.NewServerDuplexConnection(fragmentation.MaxFragment, nil) + ss := socket.NewSimpleServerSocket(ds) + ss.SetResponder(fakeResponder) + ss.SetTransport(tp) + + assert.Equal(t, false, ss.Pause(), "should always returns false") + token, ok := ss.Token() + assert.False(t, ok) + assert.Nil(t, token, "token should be nil") + + done := make(chan struct{}) + go func() { + defer close(done) + err := ss.Start(context.Background()) + assert.NoError(t, err, "start server socket failed") + }() + + err = tp.Start(context.Background()) + assert.NoError(t, err, "start transport failed") + + _ = ds.Close() + + <-done +} diff --git a/internal/socket/smap_test.go b/internal/socket/smap_test.go deleted file mode 100644 index 6a698b0..0000000 --- a/internal/socket/smap_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package socket - -import ( - "crypto/rand" - "encoding/binary" - "sync" - "testing" -) - -func nextID() uint32 { - b4 := make([]byte, 4) - _, _ = rand.Read(b4) - return uint32(maskStreamID) & binary.BigEndian.Uint32(b4) -} - -func BenchmarkLock(b *testing.B) { - var v reqRC - m := make(map[uint32]interface{}) - var lk sync.RWMutex - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - id := nextID() - switch id % 3 { - case 1: - lk.Lock() - m[id] = v - lk.Unlock() - case 2: - lk.RLock() - _ = m[id] - lk.RUnlock() - default: - lk.Lock() - delete(m, id) - lk.Unlock() - } - } - }) -} - -func BenchmarkSync(b *testing.B) { - var v reqRC - m := &sync.Map{} - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - id := nextID() - switch id % 3 { - case 1: - m.Store(id, v) - case 2: - m.Load(id) - default: - m.Delete(id) - } - } - }) - -} diff --git a/internal/socket/socket.go b/internal/socket/socket.go deleted file mode 100644 index f4fee83..0000000 --- a/internal/socket/socket.go +++ /dev/null @@ -1,197 +0,0 @@ -package socket - -import ( - "context" - "io" - "sync" - "time" - - "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/transport" - "github.com/rsocket/rsocket-go/logger" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx" - "github.com/rsocket/rsocket-go/rx/flux" - "github.com/rsocket/rsocket-go/rx/mono" -) - -var ( - errUnimplementedMetadataPush = errors.New("METADATA_PUSH is unimplemented") - errUnimplementedFireAndForget = errors.New("FIRE_AND_FORGET is unimplemented") - errUnimplementedRequestResponse = errors.New("REQUEST_RESPONSE is unimplemented") - errUnimplementedRequestStream = errors.New("REQUEST_STREAM is unimplemented") - errUnimplementedRequestChannel = errors.New("REQUEST_CHANNEL is unimplemented") -) - -// Closeable represents a closeable target. -type Closeable interface { - io.Closer - // OnClose bind a handler when closing. - OnClose(closer func(error)) -} - -// Responder is a contract providing different interaction models for RSocket protocol. -type Responder interface { - // FireAndForget is a single one-way message. - FireAndForget(message payload.Payload) - // MetadataPush sends asynchronous Metadata frame. - MetadataPush(message payload.Payload) - // RequestResponse request single response. - RequestResponse(message payload.Payload) mono.Mono - // RequestStream request a completable stream. - RequestStream(message payload.Payload) flux.Flux - // RequestChannel request a completable stream in both directions. - RequestChannel(messages rx.Publisher) flux.Flux -} - -// ClientSocket represents a client-side socket. -type ClientSocket interface { - Closeable - Responder - // Setup setups current socket. - Setup(ctx context.Context, setup *SetupInfo) (err error) -} - -// ServerSocket represents a server-side socket. -type ServerSocket interface { - Closeable - Responder - // SetResponder sets a responder for current socket. - SetResponder(responder Responder) - // SetTransport sets a transport for current socket. - SetTransport(tp *transport.Transport) - // Pause pause current socket. - Pause() bool - // Start starts current socket. - Start(ctx context.Context) error - // Token returns token of socket. - Token() (token []byte, ok bool) -} - -// AbstractRSocket represents an abstract RSocket. -type AbstractRSocket struct { - FF func(payload.Payload) - MP func(payload.Payload) - RR func(payload.Payload) mono.Mono - RS func(payload.Payload) flux.Flux - RC func(rx.Publisher) flux.Flux -} - -// MetadataPush starts a request of MetadataPush. -func (p AbstractRSocket) MetadataPush(message payload.Payload) { - if p.MP == nil { - logger.Errorf("%s\n", errUnimplementedMetadataPush) - return - } - p.MP(message) -} - -// FireAndForget starts a request of FireAndForget. -func (p AbstractRSocket) FireAndForget(message payload.Payload) { - if p.FF == nil { - logger.Errorf("%s\n", errUnimplementedFireAndForget) - return - } - p.FF(message) -} - -// RequestResponse starts a request of RequestResponse. -func (p AbstractRSocket) RequestResponse(message payload.Payload) mono.Mono { - if p.RR == nil { - return mono.Error(errUnimplementedRequestResponse) - } - return p.RR(message) -} - -// RequestStream starts a request of RequestStream. -func (p AbstractRSocket) RequestStream(message payload.Payload) flux.Flux { - if p.RS == nil { - return flux.Error(errUnimplementedRequestStream) - } - return p.RS(message) -} - -// RequestChannel starts a request of RequestChannel. -func (p AbstractRSocket) RequestChannel(messages rx.Publisher) flux.Flux { - if p.RC == nil { - return flux.Error(errUnimplementedRequestChannel) - } - return p.RC(messages) -} - -type baseSocket struct { - socket *DuplexRSocket - closers []func(error) - once sync.Once - reqLease *leaser -} - -func (p *baseSocket) refreshLease(ttl time.Duration, n int64) { - deadline := time.Now().Add(ttl) - if p.reqLease == nil { - p.reqLease = newLeaser(deadline, n) - } else { - p.reqLease.refresh(deadline, n) - } -} - -func (p *baseSocket) FireAndForget(message payload.Payload) { - if err := p.reqLease.allow(); err != nil { - logger.Warnf("request FireAndForget failed: %v\n", err) - } - p.socket.FireAndForget(message) -} - -func (p *baseSocket) MetadataPush(message payload.Payload) { - p.socket.MetadataPush(message) -} - -func (p *baseSocket) RequestResponse(message payload.Payload) mono.Mono { - if err := p.reqLease.allow(); err != nil { - return mono.Error(err) - } - return p.socket.RequestResponse(message) -} - -func (p *baseSocket) RequestStream(message payload.Payload) flux.Flux { - if err := p.reqLease.allow(); err != nil { - return flux.Error(err) - } - return p.socket.RequestStream(message) -} - -func (p *baseSocket) RequestChannel(messages rx.Publisher) flux.Flux { - if err := p.reqLease.allow(); err != nil { - return flux.Error(err) - } - return p.socket.RequestChannel(messages) -} - -func (p *baseSocket) OnClose(fn func(error)) { - if fn != nil { - p.closers = append(p.closers, fn) - } -} - -func (p *baseSocket) Close() (err error) { - p.once.Do(func() { - err = p.socket.Close() - for i, l := 0, len(p.closers); i < l; i++ { - func(fn func(error)) { - defer func() { - if e := tryRecover(recover()); e != nil { - logger.Errorf("handle socket closer failed: %s\n", e) - } - }() - fn(err) - }(p.closers[l-i-1]) - } - }) - return -} - -func newBaseSocket(rawSocket *DuplexRSocket) *baseSocket { - return &baseSocket{ - socket: rawSocket, - } -} diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go new file mode 100644 index 0000000..aee5693 --- /dev/null +++ b/internal/socket/socket_test.go @@ -0,0 +1,42 @@ +package socket_test + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/logger" +) + +var ( + fakeErr = errors.New("fake error") + + fakeMetadata = []byte("fake-metadata") + fakeData = []byte("fake-data") + fakeMimeType = []byte("fake-mime-type") + + fakeSetup = &socket.SetupInfo{ + Version: core.DefaultVersion, + MetadataMimeType: fakeMimeType, + DataMimeType: fakeMimeType, + Metadata: fakeMetadata, + Data: fakeData, + KeepaliveLifetime: 90 * time.Second, + KeepaliveInterval: 30 * time.Second, + } +) + +func InitTransport(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { + ctrl := gomock.NewController(t) + conn := NewMockConn(ctrl) + tp := transport.NewTransport(conn) + return ctrl, conn, tp +} + +func init() { + logger.SetLevel(logger.LevelError) +} diff --git a/internal/socket/stream_id.go b/internal/socket/stream_id.go index 3bec2da..771e2a8 100644 --- a/internal/socket/stream_id.go +++ b/internal/socket/stream_id.go @@ -5,38 +5,40 @@ import ( ) const ( - maskStreamID uint64 = 0x7FFFFFFF - halfSeed uint64 = 0x40000000 + _maskStreamID uint64 = 0x7FFFFFFF + _halfSeed uint64 = 0x40000000 ) -type streamIDs interface { - next() (id uint32, lap1st bool) +// StreamID can be used to generate stream ids. +type StreamID interface { + // Next returns next stream id. + Next() (id uint32, firstLoop bool) } type serverStreamIDs struct { cur uint64 } -func (p *serverStreamIDs) next() (uint32, bool) { +func (p *serverStreamIDs) Next() (uint32, bool) { // 2,4,6,8... seed := atomic.AddUint64(&p.cur, 1) v := 2 * seed if v != 0 { - return uint32(maskStreamID & v), seed <= halfSeed + return uint32(_maskStreamID & v), seed <= _halfSeed } - return p.next() + return p.Next() } type clientStreamIDs struct { cur uint64 } -func (p *clientStreamIDs) next() (uint32, bool) { +func (p *clientStreamIDs) Next() (uint32, bool) { // 1,3,5,7 seed := atomic.AddUint64(&p.cur, 1) v := 2*(seed-1) + 1 if v != 0 { - return uint32(maskStreamID & v), seed <= halfSeed + return uint32(_maskStreamID & v), seed <= _halfSeed } - return p.next() + return p.Next() } diff --git a/internal/socket/stream_id_test.go b/internal/socket/stream_id_test.go new file mode 100644 index 0000000..75f455a --- /dev/null +++ b/internal/socket/stream_id_test.go @@ -0,0 +1,39 @@ +package socket + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClientStreamIDs_Next(t *testing.T) { + ids := clientStreamIDs{} + id, firstLap := ids.Next() + assert.Equal(t, uint32(1), id) + assert.True(t, firstLap) +} + +func TestServerStreamIDs_Next(t *testing.T) { + ids := serverStreamIDs{} + id, firstLap := ids.Next() + assert.Equal(t, uint32(2), id) + assert.True(t, firstLap) +} + +func BenchmarkServerStreamIDs_Next(b *testing.B) { + ids := serverStreamIDs{} + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ids.Next() + } + }) +} + +func BenchmarkClientStreamIDs_Next(b *testing.B) { + ids := clientStreamIDs{} + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ids.Next() + } + }) +} diff --git a/internal/socket/types.go b/internal/socket/types.go new file mode 100644 index 0000000..6c109c2 --- /dev/null +++ b/internal/socket/types.go @@ -0,0 +1,57 @@ +package socket + +import ( + "context" + "io" + + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +// Closeable represents a closeable target. +type Closeable interface { + io.Closer + // OnClose bind a handler when closing. + OnClose(closer func(error)) +} + +// Responder is a contract providing different interaction models for RSocket protocol. +type Responder interface { + // FireAndForget is a single one-way message. + FireAndForget(message payload.Payload) + // MetadataPush sends asynchronous Metadata frame. + MetadataPush(message payload.Payload) + // RequestResponse request single response. + RequestResponse(message payload.Payload) mono.Mono + // RequestStream request a completable stream. + RequestStream(message payload.Payload) flux.Flux + // RequestChannel request a completable stream in both directions. + RequestChannel(messages rx.Publisher) flux.Flux +} + +// ClientSocket represents a client-side socket. +type ClientSocket interface { + Closeable + Responder + // Setup setups current socket. + Setup(ctx context.Context, setup *SetupInfo) error +} + +// ServerSocket represents a server-side socket. +type ServerSocket interface { + Closeable + Responder + // SetResponder sets a responder for current socket. + SetResponder(responder Responder) + // SetTransport sets a transport for current socket. + SetTransport(tp *transport.Transport) + // Pause pause current socket. + Pause() bool + // Start starts current socket. + Start(ctx context.Context) error + // Token returns token of socket. + Token() (token []byte, ok bool) +} diff --git a/internal/transport/connection_tcp.go b/internal/transport/connection_tcp.go deleted file mode 100644 index e045278..0000000 --- a/internal/transport/connection_tcp.go +++ /dev/null @@ -1,109 +0,0 @@ -package transport - -import ( - "bufio" - "io" - "net" - "time" - - "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/logger" -) - -type tcpConn struct { - rawConn net.Conn - writer *bufio.Writer - decoder *LengthBasedFrameDecoder - counter *Counter -} - -func (p *tcpConn) SetCounter(c *Counter) { - p.counter = c -} - -func (p *tcpConn) SetDeadline(deadline time.Time) error { - return p.rawConn.SetReadDeadline(deadline) -} - -func (p *tcpConn) Read() (f framing.Frame, err error) { - raw, err := p.decoder.Read() - if err == io.EOF { - return - } - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - h := framing.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, err = bf.Write(raw[framing.HeaderLen:]) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - base := framing.NewBaseFrame(h, bf) - if p.counter != nil && base.CanResume() { - p.counter.incrReadBytes(base.Len()) - } - f, err = framing.NewFromBase(base) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - err = f.Validate() - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - if logger.IsDebugEnabled() { - logger.Debugf("<--- rcv: %s\n", f) - } - return -} - -func (p *tcpConn) Flush() (err error) { - err = p.writer.Flush() - if err != nil { - err = errors.Wrap(err, "flush failed") - } - return -} - -func (p *tcpConn) Write(frame framing.Frame) (err error) { - size := frame.Len() - if p.counter != nil && frame.CanResume() { - p.counter.incrWriteBytes(size) - } - _, err = common.NewUint24(size).WriteTo(p.writer) - if err != nil { - err = errors.Wrap(err, "write frame failed") - return - } - var debugStr string - if logger.IsDebugEnabled() { - debugStr = frame.String() - } - _, err = frame.WriteTo(p.writer) - if err != nil { - err = errors.Wrap(err, "write frame failed") - return - } - if logger.IsDebugEnabled() { - logger.Debugf("---> snd: %s\n", debugStr) - } - return -} - -func (p *tcpConn) Close() error { - return p.rawConn.Close() -} - -func newTCPRConnection(rawConn net.Conn) *tcpConn { - return &tcpConn{ - rawConn: rawConn, - writer: bufio.NewWriter(rawConn), - decoder: NewLengthBasedFrameDecoder(rawConn), - } -} diff --git a/internal/transport/connection_ws.go b/internal/transport/connection_ws.go deleted file mode 100644 index 1bbd2d5..0000000 --- a/internal/transport/connection_ws.go +++ /dev/null @@ -1,93 +0,0 @@ -package transport - -import ( - "io" - "time" - - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/logger" -) - -type wsConnection struct { - c *websocket.Conn - counter *Counter -} - -func (p *wsConnection) SetCounter(c *Counter) { - p.counter = c -} - -func (p *wsConnection) SetDeadline(deadline time.Time) error { - return p.c.SetReadDeadline(deadline) -} - -func (p *wsConnection) Read() (f framing.Frame, err error) { - t, raw, err := p.c.ReadMessage() - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - if t != websocket.BinaryMessage { - logger.Warnf("omit non-binary message %d\n", t) - return p.Read() - } - // validate min length - if len(raw) < framing.HeaderLen { - err = errors.Wrap(ErrIncompleteHeader, "read frame failed") - return - } - header := framing.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, err = bf.Write(raw[framing.HeaderLen:]) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - base := framing.NewBaseFrame(header, bf) - f, err = framing.NewFromBase(base) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - err = f.Validate() - if err != nil { - err = errors.Wrap(err, "read frame failed") - return - } - if logger.IsDebugEnabled() { - logger.Debugf("<--- rcv: %s\n", f) - } - return -} - -func (p *wsConnection) Flush() (err error) { - return -} - -func (p *wsConnection) Write(frame framing.Frame) (err error) { - err = p.c.WriteMessage(websocket.BinaryMessage, frame.Bytes()) - if err == io.EOF { - return - } - if err != nil { - err = errors.Wrap(err, "write frame failed") - return - } - if logger.IsDebugEnabled() { - logger.Debugf("---> snd: %s\n", frame) - } - return -} - -func (p *wsConnection) Close() error { - return p.c.Close() -} - -func newWebsocketConnection(rawConn *websocket.Conn) *wsConnection { - return &wsConnection{ - c: rawConn, - } -} diff --git a/internal/transport/counter.go b/internal/transport/counter.go deleted file mode 100644 index a8fd46b..0000000 --- a/internal/transport/counter.go +++ /dev/null @@ -1,36 +0,0 @@ -package transport - -import ( - "go.uber.org/atomic" -) - -// Counter represents a counter of read/write bytes. -type Counter struct { - r, w *atomic.Uint64 -} - -// ReadBytes returns the number of bytes that have been read. -func (p Counter) ReadBytes() uint64 { - return p.r.Load() -} - -// WriteBytes returns the number of bytes that have been written. -func (p Counter) WriteBytes() uint64 { - return p.w.Load() -} - -func (p Counter) incrWriteBytes(n int) { - p.w.Add(uint64(n)) -} - -func (p Counter) incrReadBytes(n int) { - p.r.Add(uint64(n)) -} - -// NewCounter returns a new counter. -func NewCounter() *Counter { - return &Counter{ - r: atomic.NewUint64(0), - w: atomic.NewUint64(0), - } -} diff --git a/internal/transport/decoder_test.go b/internal/transport/decoder_test.go deleted file mode 100644 index 39c362a..0000000 --- a/internal/transport/decoder_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package transport - -import ( - "bytes" - "encoding/hex" - "fmt" - "testing" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" -) - -func TestDecoder(t *testing.T) { - bs, _ := hex.DecodeString("000012000000012920000003797979776f726c6432000006000000012840") - r := bytes.NewBuffer(bs) - - d := NewLengthBasedFrameDecoder(r) - - for { - raw, err := d.Read() - if err != nil { - break - } - h := framing.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, _ = bf.Write(raw[framing.HeaderLen:]) - f, err := framing.NewFromBase(framing.NewBaseFrame(h, bf)) - if err != nil { - panic(err) - } - fmt.Println(f) - } - -} diff --git a/internal/transport/transport.go b/internal/transport/transport.go deleted file mode 100644 index 58690f2..0000000 --- a/internal/transport/transport.go +++ /dev/null @@ -1,317 +0,0 @@ -package transport - -import ( - "context" - "io" - "sync" - "time" - - "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/logger" -) - -type ( - // FrameHandler is an alias of frame handler. - FrameHandler = func(frame framing.Frame) (err error) - // ServerTransportAcceptor is an alias of server transport handler. - ServerTransportAcceptor = func(ctx context.Context, tp *Transport) -) - -var errTransportClosed = errors.New("transport closed") - -// ServerTransport is server-side RSocket transport. -type ServerTransport interface { - io.Closer - // Accept register incoming connection handler. - Accept(acceptor ServerTransportAcceptor) - // Listen listens on the network address addr and handles requests on incoming connections. - // You can specify onReady handler, it'll be invoked when server begin listening. - // It always returns a non-nil error. - Listen(ctx context.Context, notifier chan<- struct{}) error -} - -// Transport is RSocket transport which is used to carry RSocket frames. -type Transport struct { - conn Conn - maxLifetime time.Duration - lastRcvPos uint64 - once sync.Once - - hSetup FrameHandler - hResume FrameHandler - hLease FrameHandler - hResumeOK FrameHandler - hFireAndForget FrameHandler - hMetadataPush FrameHandler - hRequestResponse FrameHandler - hRequestStream FrameHandler - hRequestChannel FrameHandler - hPayload FrameHandler - hRequestN FrameHandler - hError FrameHandler - hError0 FrameHandler - hCancel FrameHandler - hKeepalive FrameHandler -} - -// HandleDisaster registers handler when receiving frame of DISASTER Error with zero StreamID. -func (p *Transport) HandleDisaster(handler FrameHandler) { - p.hError0 = handler -} - -// Connection returns current connection. -func (p *Transport) Connection() Conn { - return p.conn -} - -// SetLifetime set max lifetime for current transport. -func (p *Transport) SetLifetime(lifetime time.Duration) { - if lifetime < 1 { - return - } - p.maxLifetime = lifetime -} - -// Send send a frame. -func (p *Transport) Send(frame framing.Frame, flush bool) (err error) { - defer func() { - // ensure frame done when send success. - if err == nil { - frame.Done() - } - }() - if p == nil || p.conn == nil { - err = errTransportClosed - return - } - err = p.conn.Write(frame) - if err != nil { - return - } - if !flush { - return - } - err = p.conn.Flush() - return -} - -// Flush flush all bytes in current connection. -func (p *Transport) Flush() (err error) { - if p == nil || p.conn == nil { - err = errTransportClosed - return - } - err = p.conn.Flush() - return -} - -// Close close current transport. -func (p *Transport) Close() (err error) { - p.once.Do(func() { - err = p.conn.Close() - }) - return -} - -// ReadFirst reads first frame. -func (p *Transport) ReadFirst(ctx context.Context) (frame framing.Frame, err error) { - select { - case <-ctx.Done(): - err = ctx.Err() - default: - frame, err = p.conn.Read() - if err != nil { - err = errors.Wrap(err, "read first frame failed") - } - } - if err != nil { - _ = p.Close() - } - return -} - -// Start start transport. -func (p *Transport) Start(ctx context.Context) (err error) { - defer func() { - _ = p.Close() - }() -L: - for { - select { - case <-ctx.Done(): - err = ctx.Err() - return - default: - f, err := p.conn.Read() - if err != nil { - break L - } - err = p.DeliveryFrame(ctx, f) - if err != nil { - break L - } - } - } - if err == io.EOF { - err = nil - return - } - if err != nil { - err = errors.Wrap(err, "read and delivery frame failed") - } - return -} - -// HandleSetup registers handler when receiving a frame of Setup. -func (p *Transport) HandleSetup(handler FrameHandler) { - p.hSetup = handler -} - -// HandleResume registers handler when receiving a frame of Resume. -func (p *Transport) HandleResume(handler FrameHandler) { - p.hResume = handler -} - -func (p *Transport) HandleLease(handler FrameHandler) { - p.hLease = handler -} - -// HandleResumeOK registers handler when receiving a frame of ResumeOK. -func (p *Transport) HandleResumeOK(handler FrameHandler) { - p.hResumeOK = handler -} - -// HandleFNF registers handler when receiving a frame of FireAndForget. -func (p *Transport) HandleFNF(handler FrameHandler) { - p.hFireAndForget = handler -} - -// HandleMetadataPush registers handler when receiving a frame of MetadataPush. -func (p *Transport) HandleMetadataPush(handler FrameHandler) { - p.hMetadataPush = handler -} - -// HandleRequestResponse registers handler when receiving a frame of RequestResponse. -func (p *Transport) HandleRequestResponse(handler FrameHandler) { - p.hRequestResponse = handler -} - -// HandleRequestStream registers handler when receiving a frame of RequestStream. -func (p *Transport) HandleRequestStream(handler FrameHandler) { - p.hRequestStream = handler -} - -// HandleRequestChannel registers handler when receiving a frame of RequestChannel. -func (p *Transport) HandleRequestChannel(handler FrameHandler) { - p.hRequestChannel = handler -} - -// HandlePayload registers handler when receiving a frame of Payload. -func (p *Transport) HandlePayload(handler FrameHandler) { - p.hPayload = handler -} - -// HandleRequestN registers handler when receiving a frame of RequestN. -func (p *Transport) HandleRequestN(handler FrameHandler) { - p.hRequestN = handler -} - -// HandleError registers handler when receiving a frame of Error. -func (p *Transport) HandleError(handler FrameHandler) { - p.hError = handler -} - -// HandleCancel registers handler when receiving a frame of Cancel. -func (p *Transport) HandleCancel(handler FrameHandler) { - p.hCancel = handler -} - -// HandleKeepalive registers handler when receiving a frame of Keepalive. -func (p *Transport) HandleKeepalive(handler FrameHandler) { - p.hKeepalive = handler -} - -// DeliveryFrame delivery incoming frames. -func (p *Transport) DeliveryFrame(_ context.Context, frame framing.Frame) (err error) { - header := frame.Header() - t := header.Type() - sid := header.StreamID() - - var handler FrameHandler - - switch t { - case framing.FrameTypeSetup: - p.maxLifetime = frame.(*framing.FrameSetup).MaxLifetime() - handler = p.hSetup - case framing.FrameTypeResume: - handler = p.hResume - case framing.FrameTypeResumeOK: - p.lastRcvPos = frame.(*framing.FrameResumeOK).LastReceivedClientPosition() - handler = p.hResumeOK - case framing.FrameTypeRequestFNF: - handler = p.hFireAndForget - case framing.FrameTypeMetadataPush: - if sid != 0 { - // skip invalid metadata push - logger.Warnf("rsocket.Transport: omit MetadataPush with non-zero stream id %d\n", sid) - return - } - handler = p.hMetadataPush - case framing.FrameTypeRequestResponse: - handler = p.hRequestResponse - case framing.FrameTypeRequestStream: - handler = p.hRequestStream - case framing.FrameTypeRequestChannel: - handler = p.hRequestChannel - case framing.FrameTypePayload: - handler = p.hPayload - case framing.FrameTypeRequestN: - handler = p.hRequestN - case framing.FrameTypeError: - if sid == 0 { - err = errors.New(frame.(*framing.FrameError).Error()) - if p.hError0 != nil { - _ = p.hError0(frame) - } - return - } - handler = p.hError - case framing.FrameTypeCancel: - handler = p.hCancel - case framing.FrameTypeKeepalive: - ka := frame.(*framing.FrameKeepalive) - p.lastRcvPos = ka.LastReceivedPosition() - handler = p.hKeepalive - case framing.FrameTypeLease: - handler = p.hLease - } - - // Set deadline. - deadline := time.Now().Add(p.maxLifetime) - err = p.conn.SetDeadline(deadline) - if err != nil { - return - } - - // missing handler - if handler == nil { - err = errors.Errorf("missing frame handler: type=%s", t) - return - } - - // trigger handler - err = handler(frame) - if err != nil { - err = errors.Wrap(err, "exec frame handler failed") - } - return -} - -func newTransportClient(c Conn) *Transport { - return &Transport{ - conn: c, - maxLifetime: common.DefaultKeepaliveMaxLifetime, - } -} diff --git a/internal/transport/transport_tcp.go b/internal/transport/transport_tcp.go deleted file mode 100644 index 372efc5..0000000 --- a/internal/transport/transport_tcp.go +++ /dev/null @@ -1,122 +0,0 @@ -package transport - -import ( - "context" - "crypto/tls" - "io" - "net" - "os" - "os/signal" - "sync" - "syscall" - - "github.com/pkg/errors" -) - -type tcpServerTransport struct { - network, addr string - acceptor ServerTransportAcceptor - listener net.Listener - onceClose sync.Once - tls *tls.Config -} - -func (p *tcpServerTransport) Accept(acceptor ServerTransportAcceptor) { - p.acceptor = acceptor -} - -func (p *tcpServerTransport) Close() (err error) { - if p.listener == nil { - return - } - p.onceClose.Do(func() { - err = p.listener.Close() - }) - return -} - -func (p *tcpServerTransport) Listen(ctx context.Context, notifier chan<- struct{}) (err error) { - if p.tls == nil { - p.listener, err = net.Listen(p.network, p.addr) - if err != nil { - err = errors.Wrap(err, "server listen failed") - return - } - } else { - p.listener, err = tls.Listen(p.network, p.addr, p.tls) - if err != nil { - err = errors.Wrap(err, "server listen failed") - return - } - } - notifier <- struct{}{} - return p.listen(ctx) -} - -func (p *tcpServerTransport) listen(ctx context.Context) (err error) { - // Remove unix socket file before exit. - if p.network == schemaUNIX { - // Monitor signal of current process and unlink unix socket file. - go func(sock string) { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - _ = p.Close() - }(p.addr) - } - - stop := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - go func(ctx context.Context, stop chan struct{}) { - defer func() { - _ = p.Close() - close(stop) - }() - <-ctx.Done() - }(ctx, stop) - - // Start loop of accepting connections. - var c net.Conn - for { - c, err = p.listener.Accept() - if err == io.EOF || isClosedErr(err) { - err = nil - break - } - if err != nil { - err = errors.Wrap(err, "accept next conn failed") - break - } - // Dispatch raw conn. - go func(ctx context.Context, rawConn net.Conn) { - conn := newTCPRConnection(rawConn) - tp := newTransportClient(conn) - p.acceptor(ctx, tp) - }(ctx, c) - } - cancel() - <-stop - return -} - -func newTCPServerTransport(network, addr string, c *tls.Config) *tcpServerTransport { - return &tcpServerTransport{ - network: network, - addr: addr, - tls: c, - } -} - -func newTCPClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { - var rawConn net.Conn - if tlsConfig == nil { - rawConn, err = net.Dial(network, addr) - } else { - rawConn, err = tls.Dial(network, addr, tlsConfig) - } - if err != nil { - return - } - tp = newTransportClient(newTCPRConnection(rawConn)) - return -} diff --git a/internal/transport/uri.go b/internal/transport/uri.go deleted file mode 100644 index d7dd366..0000000 --- a/internal/transport/uri.go +++ /dev/null @@ -1,95 +0,0 @@ -package transport - -import ( - "crypto/tls" - "net/url" - "strings" - - "github.com/pkg/errors" -) - -const ( - schemaUNIX = "unix" - schemaTCP = "tcp" - schemaWebsocket = "ws" - schemaWebsocketSecure = "wss" -) - -// URI represents a URI of RSocket transport. -type URI url.URL - -var tlsInsecure = &tls.Config{ - InsecureSkipVerify: true, -} - -// IsWebsocket returns true if current uri is websocket. -func (p *URI) IsWebsocket() bool { - switch strings.ToLower(p.Scheme) { - case schemaWebsocket, schemaWebsocketSecure: - return true - default: - return false - } -} - -// MakeClientTransport creates a new client-side transport. -func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) (*Transport, error) { - switch strings.ToLower(p.Scheme) { - case schemaTCP: - return newTCPClientTransport(schemaTCP, p.Host, tc) - case schemaWebsocket: - if tc == nil { - return newWebsocketClientTransport(p.pp().String(), nil, headers) - } - var clone = (url.URL)(*p) - clone.Scheme = "wss" - return newWebsocketClientTransport(clone.String(), tc, headers) - case schemaWebsocketSecure: - if tc == nil { - tc = tlsInsecure - } - return newWebsocketClientTransport(p.pp().String(), tc, headers) - case schemaUNIX: - return newTCPClientTransport(schemaUNIX, p.Path, tc) - default: - return nil, errors.Errorf("unsupported transport url: %s", p.pp().String()) - } -} - -// MakeServerTransport creates a new server-side transport. -func (p *URI) MakeServerTransport(c *tls.Config) (tp ServerTransport, err error) { - switch strings.ToLower(p.Scheme) { - case schemaTCP: - tp = newTCPServerTransport(schemaTCP, p.Host, c) - case schemaWebsocket: - tp = newWebsocketServerTransport(p.Host, p.Path, c) - case schemaWebsocketSecure: - if c == nil { - err = errors.Errorf("missing TLS Config for proto %s", schemaWebsocketSecure) - return - } - tp = newWebsocketServerTransport(p.Host, p.Path, c) - case schemaUNIX: - tp = newTCPServerTransport(schemaUNIX, p.Path, c) - default: - err = errors.Errorf("unsupported transport url: %s", p.pp().String()) - } - return -} - -func (p *URI) String() string { - return p.pp().String() -} - -func (p *URI) pp() *url.URL { - return (*url.URL)(p) -} - -// ParseURI parse URI string and returns a URI. -func ParseURI(rawUrl string) (*URI, error) { - u, err := url.Parse(rawUrl) - if err != nil { - return nil, errors.Wrapf(err, "parse url failed: %s", rawUrl) - } - return (*URI)(u), nil -} diff --git a/internal/transport/uri_test.go b/internal/transport/uri_test.go deleted file mode 100644 index 3f0c345..0000000 --- a/internal/transport/uri_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package transport - -import ( - "log" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseUTI(t *testing.T) { - uri, err := ParseURI("tcp://127.0.0.1:8080") - assert.NoError(t, err, "bad URI") - log.Println(uri) -} - -func TestName(t *testing.T) { - //u, err := url.Parse("unix:///tmp/rsocket.sock") - u, err := url.Parse("tcp://127.0.0.1:8080") - require.NoError(t, err, "bad parse") - log.Println("schema:", u.Scheme) - log.Println("host:", u.Host) - log.Println("path:", u.Path) -} diff --git a/justfile b/justfile new file mode 100644 index 0000000..52d01f0 --- /dev/null +++ b/justfile @@ -0,0 +1,12 @@ +default: + echo 'Hello, world!' +lint: + golangci-lint run ./... +test: + go test -count=1 -race -coverprofile=coverage.out ./... +test-no-cover: + go test -count=1 -race ./... +fmt: + @go fmt ./... +cover: + @go tool cover -html=coverage.out diff --git a/lease/lease_test.go b/lease/lease_test.go new file mode 100644 index 0000000..fd2ccf7 --- /dev/null +++ b/lease/lease_test.go @@ -0,0 +1,21 @@ +package lease_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/rsocket/rsocket-go/lease" + "github.com/stretchr/testify/assert" +) + +func TestSimpleLease_Next(t *testing.T) { + l, err := lease.NewSimpleLease(3*time.Second, 1*time.Second, 1*time.Second, 1) + assert.NoError(t, err, "create simple lease failed") + lease, ok := l.Next(context.Background()) + assert.True(t, ok, "get next lease chan failed") + next, ok := <-lease + assert.True(t, ok, "get lease failed") + fmt.Println(next) +} diff --git a/logger/logger.go b/logger/logger.go index 24d486e..b88c7e0 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,30 +1,11 @@ package logger -import ( - "fmt" - "log" -) - -// Func is alias of logger function. -type Func = func(string, ...interface{}) - -// Level is level of logger. -type Level int8 +import "log" -func (s Level) String() string { - switch s { - case LevelDebug: - return "DEBUG" - case LevelInfo: - return "INFO" - case LevelWarn: - return "WARN" - case LevelError: - return "ERROR" - default: - return "UNKNOWN" - } -} +var ( + _level = LevelInfo + _logger Logger = simpleLogger{} +) const ( // LevelDebug is DEBUG level. @@ -37,93 +18,90 @@ const ( LevelError ) -var ( - lvl = LevelInfo - i, d, w, e = log.Printf, log.Printf, log.Printf, log.Printf - prefix = true -) +// Logger is used to print logs. +type Logger interface { + // Debugf print to the debug level logs. + Debugf(format string, args ...interface{}) + // Infof print to the info level logs. + Infof(format string, args ...interface{}) + // Warnf print to the info level logs. + Warnf(format string, args ...interface{}) + // Errorf print to the info level logs. + Errorf(format string, args ...interface{}) +} + +// Level is level of logger. +type Level int8 // SetLevel set global RSocket log level. -// Available levels are `LogLevelDebug`, `LogLevelInfo`, `LogLevelWarn` and `LogLevelError`. +// Available levels are `LevelDebug`, `LevelInfo`, `LevelWarn` and `LevelError`. func SetLevel(level Level) { - lvl = level + _level = level } -// DisablePrefix disable print level prefix. -func DisablePrefix() { - prefix = false +// SetLogger customize the global logger. +// A standard log implementation will be used by default. +func SetLogger(logger Logger) { + _logger = logger } // GetLevel returns current logger level. func GetLevel() Level { - return lvl -} - -// SetFunc set logger func for custom level. -func SetFunc(level Level, fn Func) { - if fn == nil { - return - } - if level&LevelDebug != 0 { - d = fn - } - if level&LevelInfo != 0 { - i = fn - } - if level&LevelWarn != 0 { - w = fn - } - if level&LevelError != 0 { - e = fn - } + return _level } // IsDebugEnabled returns true if debug level is open. func IsDebugEnabled() bool { - return lvl <= LevelDebug + return _level <= LevelDebug } // Debugf prints debug level log. -func Debugf(format string, v ...interface{}) { - if lvl > LevelDebug { +func Debugf(format string, args ...interface{}) { + if _logger == nil || _level > LevelDebug { return } - if prefix { - d(fmt.Sprintf("[%s] %s", LevelDebug, format), v...) - } else { - d(format, v...) - } + _logger.Debugf(format, args...) } // Infof prints info level log. -func Infof(format string, v ...interface{}) { - if lvl > LevelInfo { +func Infof(format string, args ...interface{}) { + if _logger == nil || _level > LevelInfo { return } - if prefix { - i(fmt.Sprintf("[%s] %s", LevelInfo, format), v...) - } else { - i(format, v...) - } + _logger.Infof(format, args...) } // Warnf prints warn level log. -func Warnf(format string, v ...interface{}) { - if lvl > LevelWarn { +func Warnf(format string, args ...interface{}) { + if _logger == nil || _level > LevelWarn { return } - if prefix { - w(fmt.Sprintf("[%s] %s", LevelWarn, format), v...) - } else { - w(format, v...) - } + _logger.Warnf(format, args...) } // Errorf prints error level log. -func Errorf(format string, v ...interface{}) { - if prefix { - e(fmt.Sprintf("[%s] %s", LevelError, format), v...) - } else { - e(format, v...) +func Errorf(format string, args ...interface{}) { + if _logger == nil || _level > LevelError { + return } + _logger.Errorf(format, args...) +} + +type simpleLogger struct { +} + +func (s simpleLogger) Debugf(format string, args ...interface{}) { + log.Printf("[DEBUG] "+format, args...) +} + +func (s simpleLogger) Infof(format string, args ...interface{}) { + log.Printf("[INFO] "+format, args...) +} + +func (s simpleLogger) Warnf(format string, args ...interface{}) { + log.Printf("[WARN] "+format, args...) +} + +func (s simpleLogger) Errorf(format string, args ...interface{}) { + log.Printf("[ERROR] "+format, args...) } diff --git a/logger/logger_mock_test.go b/logger/logger_mock_test.go new file mode 100644 index 0000000..05b4380 --- /dev/null +++ b/logger/logger_mock_test.go @@ -0,0 +1,101 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: logger/logger.go + +// Package logger_test is a generated GoMock package. +package logger_test + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockLogger is a mock of Logger interface +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debugf mocks base method +func (m *MockLogger) Debugf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf +func (mr *MockLoggerMockRecorder) Debugf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Infof mocks base method +func (m *MockLogger) Infof(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Infof", varargs...) +} + +// Infof indicates an expected call of Infof +func (mr *MockLoggerMockRecorder) Infof(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...) +} + +// Warnf mocks base method +func (m *MockLogger) Warnf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf +func (mr *MockLoggerMockRecorder) Warnf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...) +} + +// Errorf mocks base method +func (m *MockLogger) Errorf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf +func (mr *MockLoggerMockRecorder) Errorf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} diff --git a/logger/logger_test.go b/logger/logger_test.go new file mode 100644 index 0000000..154ee96 --- /dev/null +++ b/logger/logger_test.go @@ -0,0 +1,49 @@ +package logger_test + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +var ( + fakeFormat = "fake format: %v" + fakeArgs = []interface{}{"fake args"} +) + +func TestSetLogger(t *testing.T) { + logger.SetLevel(logger.LevelDebug) + + call := func() { + logger.Debugf(fakeFormat, fakeArgs...) + logger.Infof(fakeFormat, fakeArgs...) + logger.Warnf(fakeFormat, fakeArgs...) + logger.Errorf(fakeFormat, fakeArgs...) + } + + call() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := NewMockLogger(ctrl) + l.EXPECT().Debugf(gomock.Any(), gomock.Any()).Times(1) + l.EXPECT().Infof(gomock.Any(), gomock.Any()).Times(2) + l.EXPECT().Warnf(gomock.Any(), gomock.Any()).Times(2) + l.EXPECT().Errorf(gomock.Any(), gomock.Any()).Times(2) + + logger.SetLogger(l) + assert.Equal(t, logger.LevelDebug, logger.GetLevel(), "wrong logger level") + assert.True(t, logger.IsDebugEnabled(), "should be enabled") + + call() + + logger.SetLevel(logger.LevelInfo) + call() + + logger.SetLevel(logger.LevelDebug) + logger.SetLogger(nil) + call() +} diff --git a/payload/payload.go b/payload/payload.go index 53812c9..4a60531 100644 --- a/payload/payload.go +++ b/payload/payload.go @@ -1,10 +1,11 @@ package payload import ( + "bytes" "io/ioutil" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" ) type ( @@ -36,24 +37,45 @@ type ( // MaxLifetime returns max lifetime of RSocket connection. MaxLifetime() time.Duration // Version return RSocket protocol version. - Version() common.Version + Version() core.Version } ) // Clone create a copy of original payload. func Clone(payload Payload) Payload { - ret := &rawPayload{} - if d := payload.Data(); len(d) > 0 { - clone := make([]byte, len(d)) - copy(clone, d) - ret.data = clone + if payload == nil { + return nil } - if m, ok := payload.Metadata(); ok && len(m) > 0 { - clone := make([]byte, len(m)) - copy(clone, m) - ret.metadata = clone + switch v := payload.(type) { + case *rawPayload: + var data []byte + if v.data != nil { + data = make([]byte, len(v.data)) + copy(data, v.data) + } + var metadata []byte + if v.metadata != nil { + metadata = make([]byte, len(v.metadata)) + copy(metadata, v.metadata) + } + return &rawPayload{data: data, metadata: metadata} + case *strPayload: + return &strPayload{data: v.data, metadata: v.metadata} + default: + ret := &rawPayload{} + if d := payload.Data(); len(d) > 0 { + clone := make([]byte, len(d)) + copy(clone, d) + ret.data = clone + } + if m, ok := payload.Metadata(); ok && len(m) > 0 { + clone := make([]byte, len(m)) + copy(clone, m) + ret.metadata = clone + } + return ret } - return ret + } // New create a new payload with bytes. @@ -89,3 +111,20 @@ func MustNewFile(filename string, metadata []byte) Payload { } return foo } + +func Equal(a Payload, b Payload) bool { + if a == b { + return true + } + if !bytes.Equal(a.Data(), b.Data()) { + return false + } + + m1, ok1 := a.Metadata() + m2, ok2 := b.Metadata() + if ok1 != ok2 { + return false + } + + return bytes.Equal(m1, m2) +} diff --git a/payload/payload_raw.go b/payload/payload_raw.go index 5092a2d..e92e82c 100644 --- a/payload/payload_raw.go +++ b/payload/payload_raw.go @@ -1,19 +1,10 @@ package payload -import ( - "fmt" -) - type rawPayload struct { data []byte metadata []byte } -func (p *rawPayload) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("Payload{data=%s,metadata=%s}", p.DataUTF8(), m) -} - func (p *rawPayload) Metadata() (metadata []byte, ok bool) { return p.metadata, len(p.metadata) > 0 } diff --git a/payload/payload_str.go b/payload/payload_str.go index 128f423..4cb0fad 100644 --- a/payload/payload_str.go +++ b/payload/payload_str.go @@ -1,16 +1,10 @@ package payload -import "fmt" - type strPayload struct { data string metadata string } -func (p *strPayload) String() string { - return fmt.Sprintf("Payload{data=%s,metadata=%s}", p.data, p.metadata) -} - func (p *strPayload) Metadata() (metadata []byte, ok bool) { ok = len(p.metadata) > 0 if ok { diff --git a/payload/payload_test.go b/payload/payload_test.go index 179ba3f..e65664e 100644 --- a/payload/payload_test.go +++ b/payload/payload_test.go @@ -1,29 +1,94 @@ -package payload +package payload_test import ( "fmt" "testing" + "unicode/utf8" + "github.com/rsocket/rsocket-go/payload" "github.com/stretchr/testify/assert" ) -func TestPayload_new(t *testing.T) { +type customPayload [2][]byte + +func (c customPayload) Metadata() (metadata []byte, ok bool) { + return c[1], len(c[1]) > 0 +} + +func (c customPayload) MetadataUTF8() (metadata string, ok bool) { + return string(c[1]), len(c[1]) > 0 +} + +func (c customPayload) Data() []byte { + return c[0] +} + +func (c customPayload) DataUTF8() string { + return string(c[0]) +} + +func TestRawPayload(t *testing.T) { + data, metadata := []byte("hello"), []byte("world") + p := payload.New(data, metadata) + fmt.Println("new binary payload:", p) + assert.Equal(t, data, p.Data(), "wrong data") + m, ok := p.Metadata() + assert.True(t, ok, "ok should be true") + assert.Equal(t, metadata, m, "wrong metadata") + assert.Equal(t, "hello", p.DataUTF8(), "wrong data string") + mu, _ := p.MetadataUTF8() + assert.Equal(t, "world", mu, "wrong metadata string") + + invalid := []byte{0xff, 0xfe, 0xfd} + badPayload := payload.New(invalid, invalid) + s := badPayload.DataUTF8() + assert.False(t, utf8.Valid([]byte(s))) +} + +func TestStrPayload(t *testing.T) { data, metadata := "hello", "world" - p1 := New([]byte(data), []byte(metadata)) + p := payload.NewString(data, metadata) + fmt.Println("new string payload:", p) + assert.Equal(t, []byte(data), p.Data(), "wrong data") + m, ok := p.Metadata() + assert.True(t, ok, "ok should be true") + assert.Equal(t, []byte(metadata), m, "wrong metadata") + assert.Equal(t, data, p.DataUTF8(), "wrong data string") + mu, _ := p.MetadataUTF8() + assert.Equal(t, metadata, mu, "wrong metadata string") +} + +func TestClone(t *testing.T) { + check := func(p1 payload.Payload) { + p2 := payload.Clone(p1) + assert.Equal(t, p1.Data(), p2.Data(), "bad data") + m1, _ := p1.Metadata() + m2, _ := p2.Metadata() + assert.Equal(t, m1, m2, "bad metadata") + } - assert.Equal(t, data, p1.DataUTF8(), "bad data") - metadata2, ok := p1.MetadataUTF8() - assert.True(t, ok, "bad metadata") - assert.Equal(t, metadata, metadata2, "bad metadata") + check(payload.NewString("hello", "world")) + check(payload.New([]byte("hello"), []byte("world"))) + // Check custom payload + custom := customPayload([2][]byte{[]byte("hello"), []byte("world")}) + check(custom) - p1 = New([]byte(data), nil) - metadata2, ok = p1.MetadataUTF8() - assert.False(t, ok) - assert.Equal(t, "", metadata2) + // Check clone nil + nilCloned := payload.Clone(nil) + assert.Nil(t, nilCloned, "should return nil") } func TestNewFile(t *testing.T) { - pl, err := NewFile("/etc/hosts", nil) - assert.NoError(t, err, "bad file") - fmt.Print(pl.DataUTF8()) + p := payload.MustNewFile("/etc/hosts", nil) + assert.NotEmpty(t, p.Data(), "empty data") + _, err := payload.NewFile("/not/existing", nil) + assert.Error(t, err, "should return error") + + func() { + defer func() { + e := recover() + assert.Error(t, e.(error), "should panic error") + }() + payload.MustNewFile("/not/existing", nil) + }() } diff --git a/rsocket.go b/rsocket.go index 093b61b..c9b5a7f 100644 --- a/rsocket.go +++ b/rsocket.go @@ -1,7 +1,7 @@ package rsocket import ( - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/socket" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -11,33 +11,33 @@ import ( const ( // ErrorCodeInvalidSetup means the setup frame is invalid for the server. - ErrorCodeInvalidSetup = common.ErrorCodeInvalidSetup + ErrorCodeInvalidSetup = core.ErrorCodeInvalidSetup // ErrorCodeUnsupportedSetup means some (or all) of the parameters specified by the client are unsupported by the server. - ErrorCodeUnsupportedSetup = common.ErrorCodeUnsupportedSetup + ErrorCodeUnsupportedSetup = core.ErrorCodeUnsupportedSetup // ErrorCodeRejectedSetup means server rejected the setup, it can specify the reason in the payload. - ErrorCodeRejectedSetup = common.ErrorCodeRejectedSetup + ErrorCodeRejectedSetup = core.ErrorCodeRejectedSetup // ErrorCodeRejectedResume means server rejected the resume, it can specify the reason in the payload. - ErrorCodeRejectedResume = common.ErrorCodeRejectedResume + ErrorCodeRejectedResume = core.ErrorCodeRejectedResume // ErrorCodeConnectionError means the connection is being terminated. - ErrorCodeConnectionError = common.ErrorCodeConnectionError + ErrorCodeConnectionError = core.ErrorCodeConnectionError // ErrorCodeConnectionClose means the connection is being terminated. - ErrorCodeConnectionClose = common.ErrorCodeConnectionClose + ErrorCodeConnectionClose = core.ErrorCodeConnectionClose // ErrorCodeApplicationError means application layer logic generating a Reactive Streams onError event. - ErrorCodeApplicationError = common.ErrorCodeApplicationError + ErrorCodeApplicationError = core.ErrorCodeApplicationError // ErrorCodeRejected means Responder reject it. - ErrorCodeRejected = common.ErrorCodeRejected + ErrorCodeRejected = core.ErrorCodeRejected // ErrorCodeCanceled means the Responder canceled the request but may have started processing it (similar to REJECTED but doesn't guarantee lack of side-effects). - ErrorCodeCanceled = common.ErrorCodeCanceled + ErrorCodeCanceled = core.ErrorCodeCanceled // ErrorCodeInvalid means the request is invalid. - ErrorCodeInvalid = common.ErrorCodeInvalid + ErrorCodeInvalid = core.ErrorCodeInvalid ) // Aliases for Error defines. type ( // ErrorCode is code for RSocket error. - ErrorCode = common.ErrorCode + ErrorCode = core.ErrorCode // Error provides a method of accessing code and data. - Error = common.CustomError + Error = core.CustomError ) type ( diff --git a/rsocket_example_test.go b/rsocket_example_test.go index 9d279c3..4b3bbae 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -17,8 +17,6 @@ import ( func Example() { // Serve a server err := rsocket.Receive(). - Resume(). // Enable RESUME - //Lease(). Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { return rsocket.NewAbstractSocket( rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { @@ -27,7 +25,7 @@ func Example() { }), ), nil }). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). Serve(context.Background()) if err != nil { panic(err) @@ -36,17 +34,16 @@ func Example() { // Connect to a server. cli, err := rsocket.Connect(). SetupPayload(payload.NewString("Hello World", "From Golang")). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()). Start(context.Background()) if err != nil { panic(err) } - defer func() { - _ = cli.Close() - }() + defer cli.Close() cli.RequestResponse(payload.NewString("Ping", time.Now().String())). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("incoming response:", elem) + return nil }). Subscribe(context.Background()) } @@ -66,10 +63,11 @@ func ExampleReceive() { // Request to client. sendingSocket.RequestResponse(payload.NewString("Ping", time.Now().String())). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("response of Ping from client:", elem) + return nil }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) // Return responser which just echo. return rsocket.NewAbstractSocket( @@ -92,7 +90,7 @@ func ExampleReceive() { }), ), nil }). - Transport("tcp://0.0.0.0:7878"). + Transport(rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).Build()). Serve(context.Background()) panic(err) } @@ -100,7 +98,7 @@ func ExampleReceive() { func ExampleConnect() { cli, err := rsocket.Connect(). Resume(). // Enable RESUME. - Lease(). // Enable LEASE. + Lease(). // Enable LEASE. Fragment(4096). SetupPayload(payload.NewString("Hello", "World")). Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { @@ -110,7 +108,7 @@ func ExampleConnect() { }), ) }). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.TcpClient().SetAddr("127.0.0.1:7878").Build()). Start(context.Background()) if err != nil { panic(err) @@ -122,16 +120,18 @@ func ExampleConnect() { cli.FireAndForget(payload.NewString("This is a FNF message.", "")) // Simple RequestResponse. cli.RequestResponse(payload.NewString("This is a RequestResponse message.", "")). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("response:", elem) + return nil }). Subscribe(context.Background()) var s rx.Subscription // RequestStream with backpressure. (one by one) cli.RequestStream(payload.NewString("This is a RequestStream message.", "")). - DoOnNext(func(elem payload.Payload) { + DoOnNext(func(elem payload.Payload) error { log.Println("next element in stream:", elem) s.Request(1) + return nil }). Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { s.Request(1) @@ -144,8 +144,9 @@ func ExampleConnect() { s.Complete() }) cli.RequestChannel(sendFlux). - DoOnNext(func(elem payload.Payload) { + DoOnNext(func(elem payload.Payload) error { log.Println("next element in channel:", elem) + return nil }). Subscribe(context.Background()) } diff --git a/rsocket_test.go b/rsocket_test.go index aa1ff6c..a8cb397 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -7,7 +7,7 @@ import ( "testing" . "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/logger" + "github.com/rsocket/rsocket-go/core/transport" . "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" @@ -21,36 +21,29 @@ const ( channelElements = int32(2) ) -func init() { - //logger.SetLevel(logger.LevelDebug) - logger.SetFunc(logger.LevelInfo, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelDebug, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelWarn, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelError, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) -} - var testData = "Hello World!" func TestSuite(t *testing.T) { - addresses := map[string]string{ - //"unix": "unix:///tmp/rsocket.test.sock", - "tcp": "tcp://localhost:7878", - "websocket": "ws://localhost:8080/test", + m := []string{ + "tcp", + "websocket", + } + c := []transport.ClientTransportFunc{ + TcpClient().SetHostAndPort("127.0.0.1", 7878).Build(), + WebsocketClient().SetUrl("ws://127.0.0.1:8080/test").Build(), } - for k, v := range addresses { - testAll(k, v, t) + s := []transport.ServerTransportFunc{ + TcpServer().SetAddr(":7878").Build(), + WebsocketServer().SetAddr("127.0.0.1:8080").SetPath("/test").Build(), } + + for i := 0; i < len(m); i++ { + testAll(t, m[i], c[i], s[i]) + } + } -func testAll(proto string, addr string, t *testing.T) { +func testAll(t *testing.T, proto string, clientTp transport.ClientTransportFunc, serverTp transport.ServerTransportFunc) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -98,9 +91,10 @@ func testAll(proto string, addr string, t *testing.T) { inputs.(flux.Flux).DoFinally(func(s rx.SignalType) { close(receives) - }).Subscribe(context.Background(), rx.OnNext(func(input Payload) { + }).Subscribe(context.Background(), rx.OnNext(func(input Payload) error { //fmt.Println("rcv from channel:", input) receives <- input + return nil })) return flux.Create(func(ctx context.Context, s flux.Sink) { @@ -112,7 +106,7 @@ func testAll(proto string, addr string, t *testing.T) { }), ), nil }). - Transport(addr). + Transport(serverTp). Serve(ctx) fmt.Println("SERVER STOPPED!!!!!") if err != nil { @@ -126,7 +120,7 @@ func testAll(proto string, addr string, t *testing.T) { cli, err := Connect(). Fragment(192). SetupPayload(NewString(setupData, setupMetadata)). - Transport(addr). + Transport(clientTp). Start(context.Background()) assert.NoError(t, err, "connect failed") defer func() { @@ -167,11 +161,12 @@ func testRequestStream(ctx context.Context, cli Client, t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d", atomic.LoadInt32(&seq)), m, "bad stream metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad stream data") atomic.AddInt32(&seq, 1) + return nil }). BlockLast(ctx) <-done @@ -187,12 +182,13 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d", atomic.LoadInt32(&seq)), m, "bad stream metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad stream data") atomic.AddInt32(&seq, 1) su.Request(1) + return nil }). Subscribe(ctx, rx.OnSubscribe(func(s rx.Subscription) { su = s @@ -214,12 +210,13 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) { var seq int _, err := cli.RequestChannel(send). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { //fmt.Println(elem) m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d_from_server", seq), m, "bad channel metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad channel data") seq++ + return nil }). BlockLast(ctx) assert.NoError(t, err, "block last failed") @@ -245,15 +242,17 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) { assert.Equal(t, rx.SignalComplete, s, "bad signal type") close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { fmt.Println(elem) m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d_from_server", seq), m, "bad channel metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad channel data") seq++ + return nil }). - Subscribe(ctx, rx.OnNext(func(elem Payload) { + Subscribe(ctx, rx.OnNext(func(elem Payload) error { su.Request(1) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) diff --git a/rx/flux/flux.go b/rx/flux/flux.go index 91607e6..9f5a10f 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -44,12 +44,14 @@ type Flux interface { // DoOnSubscribe add behavior triggered when the Flux is done being subscribed. DoOnSubscribe(rx.FnOnSubscribe) Flux // Map transform the items emitted by this Flux by applying a synchronous function to each item. - Map(func(payload.Payload) payload.Payload) Flux + Map(func(payload.Payload) (payload.Payload, error)) Flux // SwitchOnFirst transform the current Flux once it emits its first element, making a conditional transformation possible. SwitchOnFirst(FnSwitchOnFirst) Flux // SubscribeOn run subscribe, onSubscribe and request on a specified scheduler. SubscribeOn(scheduler.Scheduler) Flux - // Raw returns Native Flux in reactor-go. + // SubscribeWithChan subscribe to this Flux and puts items/error into a chan. + SubscribeWithChan(ctx context.Context, values chan<- payload.Payload, err chan<- error) + // Raw returns low-level reactor.Flux which defined in reactor-go library. Raw() flux.Flux // BlockFirst subscribe to this Flux and block indefinitely until the upstream signals its first value or completes. // Returns that value, error if Flux completes error, or nil if the Flux completes empty. @@ -60,6 +62,8 @@ type Flux interface { // ToChan subscribe to this Flux and puts items into a chan. // It also puts errors into another chan. ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e <-chan error) + // BlockSlice subscribe to this Flux and convert to payload slice. + BlockSlice(context.Context) ([]payload.Payload, error) } // Processor represent a base processor that exposes Flux API for Processor. diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 37bfe4c..c3eb818 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -2,24 +2,126 @@ package flux_test import ( "context" - "errors" "fmt" - "log" "strconv" "testing" "time" + "github.com/jjeffcaii/reactor-go" + reactorFlux "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func TestEmpty(t *testing.T) { + last, err := flux.Empty(). + DoOnNext(func(input payload.Payload) error { + assert.FailNow(t, "unreachable") + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Nil(t, last) + first, err := flux.Empty().BlockFirst(context.Background()) + assert.NoError(t, err) + assert.Nil(t, first) +} + +func TestError(t *testing.T) { + err := errors.New("boom") + _, _ = flux.Error(err). + DoOnNext(func(input payload.Payload) error { + assert.FailNow(t, "unreachable") + return nil + }). + DoOnError(func(e error) { + assert.Equal(t, err, e) + }). + BlockLast(context.Background()) +} + +func TestClone(t *testing.T) { + const total = 10 + source := flux.Create(func(ctx context.Context, s flux.Sink) { + for i := 0; i < total; i++ { + s.Next(payload.NewString(fmt.Sprintf("data_%d", i), "")) + } + s.Complete() + }) + clone := flux.Clone(source) + + c := atomic.NewInt32(0) + last, err := clone. + DoOnNext(func(input payload.Payload) error { + c.Inc() + return nil + }). + DoOnError(func(e error) { + assert.FailNow(t, "unreachable") + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("data_%d", total-1), last.DataUTF8()) + assert.Equal(t, int32(total), c.Load()) +} + +func TestRaw(t *testing.T) { + const total = 10 + c := atomic.NewInt32(0) + f := flux. + Raw(reactorFlux.Range(0, total).Map(func(v reactor.Any) (reactor.Any, error) { + return payload.NewString(fmt.Sprintf("data_%d", v.(int)), ""), nil + })) + last, err := f. + DoOnNext(func(input payload.Payload) error { + c.Inc() + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(total), c.Load()) + assert.Equal(t, fmt.Sprintf("data_%d", total-1), last.DataUTF8()) + + c.Store(0) + const take = 3 + last, err = f.Take(take). + DoOnNext(func(input payload.Payload) error { + c.Inc() + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "data_2", last.DataUTF8()) + assert.Equal(t, int32(take), c.Load()) +} + func TestJust(t *testing.T) { - done := make(chan struct{}) + c := atomic.NewInt32(0) + last, err := flux. + Just( + payload.NewString("foo", ""), + payload.NewString("bar", ""), + payload.NewString("qux", ""), + ). + DoOnNext(func(input payload.Payload) error { + c.Inc() + return nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(3), c.Load()) + assert.Equal(t, "qux", last.DataUTF8()) +} + +func TestCreate(t *testing.T) { + const total = 10 f := flux.Create(func(i context.Context, sink flux.Sink) { - for i := 0; i < 10; i++ { + for i := 0; i < total; i++ { sink.Next(payload.NewString(fmt.Sprintf("foo_%04d", i), fmt.Sprintf("bar_%04d", i))) } sink.Complete() @@ -27,43 +129,61 @@ func TestJust(t *testing.T) { var su rx.Subscription + done := make(chan struct{}) + nextRequests := atomic.NewInt32(0) + f. - DoOnNext(func(input payload.Payload) { - log.Println("next:", input) + DoOnNext(func(input payload.Payload) error { + fmt.Println("next:", input) su.Request(1) + return nil }). DoOnRequest(func(n int) { - log.Println("request:", n) + fmt.Println("request:", n) + nextRequests.Add(int32(n)) }). DoFinally(func(s rx.SignalType) { - log.Println("finally") + fmt.Println("finally") close(done) }). DoOnComplete(func() { - log.Println("complete") + fmt.Println("complete") }). Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) })) <-done + assert.Equal(t, int32(total+1), nextRequests.Load()) +} + +func TestMap(t *testing.T) { + last, err := flux. + Just(payload.NewString("hello", "")). + Map(func(p payload.Payload) (payload.Payload, error) { + return payload.NewString(p.DataUTF8()+" world", ""), nil + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "hello world", last.DataUTF8()) } func TestProcessor(t *testing.T) { - proc := flux.CreateProcessor() + processor := flux.CreateProcessor() time.AfterFunc(1*time.Second, func() { - proc.Next(payload.NewString("111", "")) + processor.Next(payload.NewString("111", "")) }) time.AfterFunc(2*time.Second, func() { - proc.Next(payload.NewString("222", "")) - proc.Complete() + processor.Next(payload.NewString("222", "")) + processor.Complete() }) done := make(chan struct{}) - proc. - DoOnNext(func(input payload.Payload) { - log.Println("next:", input) + processor. + DoOnNext(func(input payload.Payload) error { + fmt.Println("next:", input) + return nil }). DoFinally(func(s rx.SignalType) { close(done) @@ -89,8 +209,9 @@ func TestSwitchOnFirst(t *testing.T) { n, _ := strconv.Atoi(input.DataUTF8()) return n > first }) - }).Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) { + }).Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) error { fmt.Println("next:", input.DataUTF8()) + return nil })) } @@ -105,17 +226,18 @@ func TestFluxRequest(t *testing.T) { var su rx.Subscription sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { - log.Println("onNext:", input) + rx.OnNext(func(input payload.Payload) error { + fmt.Println("onNext:", input) su.Request(1) + return nil }), rx.OnComplete(func() { - log.Println("complete") + fmt.Println("complete") }), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) - log.Println("request:", 1) + fmt.Println("request:", 1) }), ) @@ -131,7 +253,7 @@ func TestProxy_BlockLast(t *testing.T) { s.Complete() }).BlockLast(context.Background()) assert.NoError(t, err, "err occurred") - log.Println(last) + fmt.Println(last) } func TestFluxProcessorWithRequest(t *testing.T) { @@ -145,8 +267,9 @@ func TestFluxProcessorWithRequest(t *testing.T) { var su rx.Subscription sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { su.Request(1) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { su = s @@ -160,7 +283,7 @@ func TestFluxProcessorWithRequest(t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) <-done } @@ -245,26 +368,21 @@ func TestToChannel(t *testing.T) { f := flux.CreateFromChannel(payloads, err) - channel, chanerrors := f.ToChan(context.Background(), 0) + valueChan, errChan := f.ToChan(context.Background(), 0) var count int loop: for { select { - case _, o := <-channel: - if o { - count++ - } else { - break loop - } - case err := <-chanerrors: - if err != nil { - t.Error(err) + case _, ok := <-valueChan: + if !ok { break loop } + count++ + case err := <-errChan: + assert.NoError(t, err) } } - assert.Equal(t, 10, count) } @@ -277,30 +395,72 @@ func TestToChannelEmitError(t *testing.T) { defer close(err) for i := 1; i <= 10; i++ { - err <- errors.New("boom!") + err <- errors.New("boom") } }() f := flux.CreateFromChannel(payloads, err) - channel, chanerrors := f.ToChan(context.Background(), 0) + valChan, errChan := f.ToChan(context.Background(), 0) loop: for { select { - case _, o := <-channel: - if o { - t.Fail() - } else { - break loop - } - case err := <-chanerrors: - if err != nil { + case _, ok := <-valChan: + if !ok { break loop - } else { - t.Fail() } + assert.Fail(t, "should be unreachable") + case err := <-errChan: + assert.Error(t, err, "should return error") + break loop } } } + +func TestFlux_BlockSlice(t *testing.T) { + const n = 10 + arr, err := genRandomFlux(n).BlockSlice(context.Background()) + assert.NoError(t, err) + assert.Len(t, arr, n) +} + +func TestFlux_SubscribeWithChan(t *testing.T) { + ch := make(chan payload.Payload) + err := make(chan error) + done := make(chan struct{}) + + const n = 10 + genRandomFlux(n). + DoFinally(func(s rx.SignalType) { + close(done) + }). + SubscribeOn(scheduler.Parallel()). + SubscribeWithChan(context.Background(), ch, err) + + var results []payload.Payload + +L: + for { + select { + case v := <-ch: + results = append(results, v) + case e := <-err: + assert.NoError(t, e) + case <-done: + break L + } + } + assert.Len(t, results, n) +} + +func genRandomFlux(n int) flux.Flux { + return flux. + Create(func(ctx context.Context, s flux.Sink) { + for i := 0; i < n; i++ { + s.Next(payload.NewString("hello", strconv.Itoa(i))) + } + s.Complete() + }) +} diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index e7b3553..b05ed09 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -3,11 +3,10 @@ package flux import ( "context" - reactor "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" ) @@ -20,20 +19,12 @@ func (p proxy) Raw() flux.Flux { return p.Flux } -func (p proxy) mustProcessor() flux.Processor { - processor, ok := p.Flux.(flux.Processor) - if !ok { - panic(errors.New("require flux.Processor")) - } - return processor -} - func (p proxy) Next(v payload.Payload) { p.mustProcessor().Next(v) } -func (p proxy) Map(fn func(in payload.Payload) payload.Payload) Flux { - return newProxy(p.Flux.Map(func(i interface{}) interface{} { +func (p proxy) Map(fn func(in payload.Payload) (payload.Payload, error)) Flux { + return newProxy(p.Flux.Map(func(i reactor.Any) (reactor.Any, error) { return fn(i.(payload.Payload)) })) } @@ -65,37 +56,28 @@ func (p proxy) DoOnError(fn rx.FnOnError) Flux { } func (p proxy) DoOnNext(fn rx.FnOnNext) Flux { - return newProxy(p.Flux.DoOnNext(func(v interface{}) { - fn(v.(payload.Payload)) + return newProxy(p.Flux.DoOnNext(func(v reactor.Any) error { + return fn(v.(payload.Payload)) })) } -func (p proxy) ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e <-chan error) { +func (p proxy) ToChan(ctx context.Context, cap int) (<-chan payload.Payload, <-chan error) { if cap < 1 { cap = 1 } ch := make(chan payload.Payload, cap) err := make(chan error, 1) - p. - DoFinally(func(s rx.SignalType) { - if s == rx.SignalCancel { + p.Flux. + DoFinally(func(s reactor.SignalType) { + defer func() { + close(ch) + close(err) + }() + if s == reactor.SignalTypeCancel { err <- reactor.ErrSubscribeCancelled } - close(ch) - close(err) }). - Subscribe(ctx, - rx.OnNext(func(v payload.Payload) { - if _, ok := v.(framing.Frame); ok { - ch <- payload.Clone(v) - } else { - ch <- v - } - }), - rx.OnError(func(e error) { - err <- e - }), - ) + SubscribeWithChan(ctx, ch, err) return ch, err } @@ -115,9 +97,37 @@ func (p proxy) BlockLast(ctx context.Context) (last payload.Payload, err error) if err != nil { return } - if v != nil { - last = v.(payload.Payload) + if v == nil { + return } + last = v.(payload.Payload) + return +} + +func (p proxy) SubscribeWithChan(ctx context.Context, payloads chan<- payload.Payload, err chan<- error) { + p.Flux.SubscribeWithChan(ctx, payloads, err) +} + +func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err error) { + done := make(chan struct{}) + p.Flux. + DoFinally(func(s reactor.SignalType) { + defer close(done) + if s == reactor.SignalTypeCancel { + err = reactor.ErrSubscribeCancelled + } + }). + Subscribe( + ctx, + reactor.OnNext(func(v reactor.Any) error { + results = append(results, v.(payload.Payload)) + return nil + }), + reactor.OnError(func(e error) { + err = e + }), + ) + <-done return } @@ -157,8 +167,8 @@ func (p proxy) SubscribeWith(ctx context.Context, s rx.Subscriber) { sub = rx.EmptyRawSubscriber } else { sub = reactor.NewSubscriber( - reactor.OnNext(func(v interface{}) { - s.OnNext(v.(payload.Payload)) + reactor.OnNext(func(v reactor.Any) error { + return s.OnNext(v.(payload.Payload)) }), reactor.OnError(func(e error) { s.OnError(e) @@ -174,6 +184,14 @@ func (p proxy) SubscribeWith(ctx context.Context, s rx.Subscriber) { p.Flux.SubscribeWith(ctx, sub) } +func (p proxy) mustProcessor() flux.Processor { + processor, ok := p.Flux.(flux.Processor) + if !ok { + panic(errors.New("require flux.Processor")) + } + return processor +} + type sinkProxy struct { flux.Sink } diff --git a/rx/flux/utils.go b/rx/flux/utils.go index a0ea099..a5fed25 100644 --- a/rx/flux/utils.go +++ b/rx/flux/utils.go @@ -53,20 +53,25 @@ func Create(gen func(ctx context.Context, s Sink)) Flux { // CreateProcessor creates a new Processor. func CreateProcessor() Processor { - proc := flux.NewUnicastProcessor() - return newProxy(proc) + p := flux.NewUnicastProcessor() + return newProxy(p) } // Clone clones a Publisher to a Flux. func Clone(source rx.Publisher) Flux { return Create(func(ctx context.Context, s Sink) { - source.Subscribe(ctx, rx.OnNext(func(input payload.Payload) { - s.Next(input) - }), rx.OnComplete(func() { - s.Complete() - }), rx.OnError(func(e error) { - s.Error(e) - })) + source.Subscribe(ctx, + rx.OnNext(func(input payload.Payload) error { + s.Next(input) + return nil + }), + rx.OnComplete(func() { + s.Complete() + }), + rx.OnError(func(e error) { + s.Error(e) + }), + ) }) } diff --git a/rx/mono/mono.go b/rx/mono/mono.go index 6c2860c..a6b9abf 100644 --- a/rx/mono/mono.go +++ b/rx/mono/mono.go @@ -9,28 +9,46 @@ import ( "github.com/rsocket/rsocket-go/rx" ) +// Mono is a Reactive Streams Publisher with basic rx operators that completes successfully by emitting an element, or with an error. type Mono interface { rx.Publisher + // Filter evaluate each source value against the given Predicate. + // If the predicate test succeeds, the value is emitted. Filter(rx.FnPredicate) Mono + // DoFinally adds behavior (side-effect) triggered after the Mono terminates for any reason, including cancellation. DoFinally(rx.FnFinally) Mono + // DoOnError adds behavior (side-effect) triggered when the Mono completes with an error. DoOnError(rx.FnOnError) Mono + // DoOnSuccess adds behavior (side-effect) triggered when the Mono completes with an success. DoOnSuccess(rx.FnOnNext) Mono + // DoOnCancel add behavior (side-effect) triggered when the Mono is cancelled. DoOnCancel(rx.FnOnCancel) Mono + // DoOnSubscribe add behavior (side-effect) triggered when the Mono is done being subscribed. DoOnSubscribe(rx.FnOnSubscribe) Mono + // SubscribeOn customize a Scheduler running Subscribe, OnSubscribe and Request. SubscribeOn(scheduler.Scheduler) Mono + // SubscribeWithChan subscribe to this Mono and puts item/error into channels. + SubscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error) + // Block blocks Mono and returns data and error. Block(context.Context) (payload.Payload, error) + //SwitchIfEmpty switch to an alternative Publisher if this Mono is completed without any data. SwitchIfEmpty(alternative Mono) Mono + // Raw returns low-level reactor.Mono which defined in reactor-go library. Raw() mono.Mono // ToChan subscribe Mono and puts items into a chan. // It also puts errors into another chan. ToChan(ctx context.Context) (c <-chan payload.Payload, e <-chan error) } +// Sink is a wrapper API around an actual downstream Subscriber for emitting nothing, a single value or an error (mutually exclusive). type Sink interface { + // Success emits a single value then complete current Sink. Success(payload.Payload) + // Error emits an error then complete current Sink. Error(error) } +// Processor combine Sink and Mono. type Processor interface { Sink Mono diff --git a/rx/mono/mono_test.go b/rx/mono/mono_test.go index f6a79ff..bd910a0 100644 --- a/rx/mono/mono_test.go +++ b/rx/mono/mono_test.go @@ -13,24 +13,64 @@ import ( "github.com/rsocket/rsocket-go/rx" . "github.com/rsocket/rsocket-go/rx/mono" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func TestProxy_Error(t *testing.T) { + originErr := errors.New("error testing") + errCount := atomic.NewInt32(0) + _, err := Error(originErr). + DoOnError(func(e error) { + assert.Equal(t, originErr, e, "bad error") + errCount.Inc() + }). + Block(context.Background()) + assert.Error(t, err, "should got error") + assert.Equal(t, originErr, err, "bad blocked error") + assert.Equal(t, int32(1), errCount.Load(), "error count should be 1") +} + +func TestEmpty(t *testing.T) { + res, err := Empty().Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.Nil(t, res, "result should be nil") +} + +func TestJustOrEmpty(t *testing.T) { + // Give normal payload + res, err := JustOrEmpty(payload.NewString("hello", "world")).Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.NotNil(t, res, "result should not be nil") + // Give nil payload + res, err = JustOrEmpty(nil).Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.Nil(t, res, "result should be nil") +} + func TestJust(t *testing.T) { Just(payload.NewString("hello", "world")). - Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) { + Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) error { log.Println("next:", i) + return nil })) } +func TestMono_Raw(t *testing.T) { + + Just(payload.NewString("hello", "world")).Raw() + +} + func TestProxy_SubscribeOn(t *testing.T) { v, err := Create(func(i context.Context, sink Sink) { time.AfterFunc(time.Second, func() { sink.Success(payload.NewString("foo", "bar")) }) }). - SubscribeOn(scheduler.Elastic()). - DoOnSuccess(func(i payload.Payload) { + SubscribeOn(scheduler.Parallel()). + DoOnSuccess(func(i payload.Payload) error { log.Println("success:", i) + return nil }). Block(context.Background()) assert.NoError(t, err) @@ -63,8 +103,9 @@ func TestProxy_Filter(t *testing.T) { Filter(func(i payload.Payload) bool { return strings.EqualFold("hello_no", i.DataUTF8()) }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { assert.Fail(t, "should never run here") + return nil }). DoFinally(func(i rx.SignalType) { log.Println("finally:", i) @@ -76,14 +117,16 @@ func TestCreate(t *testing.T) { Create(func(i context.Context, sink Sink) { sink.Success(payload.NewString("hello", "world")) }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { log.Println("doOnNext:", i) + return nil }). DoFinally(func(s rx.SignalType) { log.Println("doFinally:", s) }). - Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) { + Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) error { log.Println("next:", i) + return nil })) Create(func(i context.Context, sink Sink) { @@ -92,8 +135,9 @@ func TestCreate(t *testing.T) { DoOnError(func(e error) { assert.Equal(t, "foobar", e.Error(), "bad error") }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { assert.Fail(t, "should never run here") + return nil }). Subscribe(context.Background()) } diff --git a/rx/mono/proxy.go b/rx/mono/proxy.go index 10aafe2..3cdafd2 100644 --- a/rx/mono/proxy.go +++ b/rx/mono/proxy.go @@ -35,28 +35,44 @@ func (p proxy) Error(e error) { p.mustProcessor().Error(e) } -func (p proxy) ToChan(ctx context.Context) (c <-chan payload.Payload, e <-chan error) { - errorChannel := make(chan error, 1) - payloadChannel := make(chan payload.Payload, 1) - p. - DoOnSuccess(func(input payload.Payload) { - payloadChannel <- input - }). - DoOnError(func(e error) { - errorChannel <- e - }). - DoFinally(func(s rx.SignalType) { - close(payloadChannel) - close(errorChannel) - }). - Subscribe(ctx) - return payloadChannel, errorChannel +func (p proxy) ToChan(ctx context.Context) (<-chan payload.Payload, <-chan error) { + value := make(chan payload.Payload, 1) + err := make(chan error, 1) + p.subscribeWithChan(ctx, value, err, true) + return value, err } func (p proxy) SubscribeOn(sc scheduler.Scheduler) Mono { return newProxy(p.Mono.SubscribeOn(sc)) } +func (p proxy) subscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error, autoClose bool) { + p.Mono. + DoFinally(func(s reactor.SignalType) { + if autoClose { + defer close(valueChan) + defer close(errChan) + } + if s == reactor.SignalTypeCancel { + errChan <- reactor.ErrSubscribeCancelled + } + }). + Subscribe( + ctx, + reactor.OnNext(func(v reactor.Any) error { + valueChan <- v.(payload.Payload) + return nil + }), + reactor.OnError(func(e error) { + errChan <- e + }), + ) +} + +func (p proxy) SubscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error) { + p.subscribeWithChan(ctx, valueChan, errChan, false) +} + func (p proxy) Block(ctx context.Context) (pa payload.Payload, err error) { v, err := p.Mono.Block(ctx) if err != nil { @@ -75,7 +91,7 @@ func (p proxy) Filter(fn rx.FnPredicate) Mono { } func (p proxy) DoFinally(fn rx.FnFinally) Mono { - return newProxy(p.Mono.DoFinally(func(signal rs.SignalType) { + return newProxy(p.Mono.DoFinally(func(signal reactor.SignalType) { fn(rx.SignalType(signal)) })) } @@ -86,13 +102,13 @@ func (p proxy) DoOnError(fn rx.FnOnError) Mono { })) } func (p proxy) DoOnSuccess(next rx.FnOnNext) Mono { - return newProxy(p.Mono.DoOnNext(func(v interface{}) { - next(v.(payload.Payload)) + return newProxy(p.Mono.DoOnNext(func(v reactor.Any) error { + return next(v.(payload.Payload)) })) } func (p proxy) DoOnSubscribe(fn rx.FnOnSubscribe) Mono { - return newProxy(p.Mono.DoOnSubscribe(func(su rs.Subscription) { + return newProxy(p.Mono.DoOnSubscribe(func(su reactor.Subscription) { fn(su) })) } @@ -110,21 +126,21 @@ func (p proxy) Subscribe(ctx context.Context, options ...rx.SubscriberOption) { } func (p proxy) SubscribeWith(ctx context.Context, actual rx.Subscriber) { - var sub rs.Subscriber + var sub reactor.Subscriber if actual == rx.EmptySubscriber { sub = rx.EmptyRawSubscriber } else { - sub = rs.NewSubscriber( - rs.OnNext(func(v interface{}) { - actual.OnNext(v.(payload.Payload)) + sub = reactor.NewSubscriber( + reactor.OnNext(func(v reactor.Any) error { + return actual.OnNext(v.(payload.Payload)) }), - rs.OnComplete(func() { + reactor.OnComplete(func() { actual.OnComplete() }), - rs.OnSubscribe(func(su rs.Subscription) { + reactor.OnSubscribe(func(su reactor.Subscription) { actual.OnSubscribe(su) }), - rs.OnError(func(e error) { + reactor.OnError(func(e error) { actual.OnError(e) }), ) diff --git a/rx/mono/utils.go b/rx/mono/utils.go index ec571eb..2b87669 100644 --- a/rx/mono/utils.go +++ b/rx/mono/utils.go @@ -9,32 +9,40 @@ import ( var empty = newProxy(mono.Empty()) +// Raw wrap a low-level Mono. func Raw(input mono.Mono) Mono { return newProxy(input) } +// Just wrap an exist Payload to a Mono. func Just(input payload.Payload) Mono { return newProxy(mono.Just(input)) } +// JustOrEmpty wrap an exist Payload to Mono. +// Payload could be nil here. func JustOrEmpty(input payload.Payload) Mono { return newProxy(mono.JustOrEmpty(input)) } +// Empty returns an empty Mono. func Empty() Mono { return empty } +// Error wrap an error to a Mono. func Error(err error) Mono { return newProxy(mono.Error(err)) } +// Create wrap a generator function to a Mono. func Create(gen func(context.Context, Sink)) Mono { return newProxy(mono.Create(func(i context.Context, sink mono.Sink) { gen(i, sinkProxy{sink}) })) } +// CreateProcessor creates a Processor. func CreateProcessor() Processor { return newProxy(mono.CreateProcessor()) } diff --git a/rx/mono/utils_test.go b/rx/mono/utils_test.go index 04127ce..5f0d326 100644 --- a/rx/mono/utils_test.go +++ b/rx/mono/utils_test.go @@ -4,6 +4,9 @@ import ( "context" "testing" + rsMono "github.com/jjeffcaii/reactor-go/mono" + "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx/mono" @@ -72,19 +75,19 @@ func TestToChannel(t *testing.T) { payloads <- p }() - channel, chanerrors := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) + valueChan, errChan := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) loop: for { select { - case p, ok := <-channel: + case p, ok := <-valueChan: if !ok { break loop } assert.Equal(t, "data", p.DataUTF8()) md, _ := p.MetadataUTF8() assert.Equal(t, "metadata", md) - case err := <-chanerrors: + case err := <-errChan: if err != nil { assert.NoError(t, err) } @@ -107,17 +110,17 @@ func TestToChannelEmitError(t *testing.T) { } }() - channel, chanerrors := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) + valueChan, errChan := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) loop: for { select { - case _, ok := <-channel: + case _, ok := <-valueChan: if !ok { break loop } assert.Fail(t, "should never receive anything") - case err := <-chanerrors: + case err := <-errChan: if err != nil { break loop } @@ -126,3 +129,30 @@ loop: } } + +func TestRaw(t *testing.T) { + fakePayload := payload.NewString("fake", "payload") + res, err := mono.Raw(rsMono.Just(fakePayload)).Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fakePayload, res) +} + +func TestSubscribeWithChan(t *testing.T) { + valueChan := make(chan payload.Payload) + errChan := make(chan error) + + fakePayload := payload.NewString("fake data", "fake metadata") + + mono. + Create(func(ctx context.Context, sink mono.Sink) { + sink.Success(fakePayload) + }). + SubscribeOn(scheduler.Parallel()). + SubscribeWithChan(context.Background(), valueChan, errChan) + select { + case next := <-valueChan: + assert.True(t, payload.Equal(fakePayload, next), "result doesn't match") + case err := <-errChan: + assert.NoError(t, err, "should not return error") + } +} diff --git a/rx/rx.go b/rx/rx.go index 6ed6c57..f06f490 100644 --- a/rx/rx.go +++ b/rx/rx.go @@ -23,7 +23,7 @@ type ( // FnOnComplete is alias of function for signal when no more elements are available FnOnComplete = func() // FnOnNext is alias of function for signal when next element arrived. - FnOnNext = func(input payload.Payload) + FnOnNext = func(input payload.Payload) error // FnOnSubscribe is alias of function for signal when subscribe begin. FnOnSubscribe = func(s Subscription) // FnOnError is alias of function for signal when an error occurred. @@ -49,7 +49,7 @@ type RawPublisher interface { type Publisher interface { RawPublisher // Subscribe subscribe elements from a publisher, returns a Disposable. - // You can add some custome options. + // You can add some custom options. // Using `OnSubscribe`, `OnNext`, `OnComplete` and `OnError` as handler wrapper. Subscribe(ctx context.Context, options ...SubscriberOption) } diff --git a/rx/subscriber.go b/rx/subscriber.go index c805c18..4a2ec85 100644 --- a/rx/subscriber.go +++ b/rx/subscriber.go @@ -1,7 +1,7 @@ package rx import ( - reactor "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go" "github.com/rsocket/rsocket-go/payload" ) @@ -9,7 +9,8 @@ var ( // EmptySubscriber is a blank Subscriber. EmptySubscriber Subscriber = &subscriber{} // EmptyRawSubscriber is a blank native Subscriber in reactor-go. - EmptyRawSubscriber = reactor.NewSubscriber(reactor.OnNext(func(v interface{}) { + EmptyRawSubscriber = reactor.NewSubscriber(reactor.OnNext(func(v reactor.Any) error { + return nil })) ) @@ -19,7 +20,7 @@ type Subscription reactor.Subscription // Subscriber will receive call to OnSubscribe(Subscription) once after passing an instance of Subscriber to Publisher#SubscribeWith type Subscriber interface { // OnNext represents data notification sent by the Publisher in response to requests to Subscription#Request. - OnNext(payload payload.Payload) + OnNext(payload payload.Payload) error // OnError represents failed terminal state. OnError(error) // OnComplete represents successful terminal state. @@ -36,10 +37,11 @@ type subscriber struct { fnOnError FnOnError } -func (s *subscriber) OnNext(payload payload.Payload) { - if s != nil && s.fnOnNext != nil { - s.fnOnNext(payload) +func (s *subscriber) OnNext(payload payload.Payload) error { + if s == nil || s.fnOnNext == nil { + return nil } + return s.fnOnNext(payload) } func (s *subscriber) OnError(err error) { diff --git a/server.go b/server.go index 1a05b5a..fc0e65b 100644 --- a/server.go +++ b/server.go @@ -2,15 +2,14 @@ package rsocket import ( "context" - "crypto/tls" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/fragmentation" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/internal/session" "github.com/rsocket/rsocket-go/internal/socket" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/lease" "github.com/rsocket/rsocket-go/logger" ) @@ -38,46 +37,21 @@ type ( // Resume enable resume for current server. Resume(opts ...OpServerResume) ServerBuilder // Acceptor register server acceptor which is used to handle incoming RSockets. - Acceptor(acceptor ServerAcceptor) ServerTransportBuilder + Acceptor(acceptor ServerAcceptor) ToServerStarter // OnStart register a handler when serve success. OnStart(onStart func()) ServerBuilder } - // ServerTransportBuilder is used to build a RSocket server with custom Transport string. - ServerTransportBuilder interface { - // Transport specify transport string. - Transport(transport string) Start + // ToServerStarter is used to build a RSocket server with custom Transport string. + ToServerStarter interface { + // Transport specify transport generator func. + Transport(t transport.ServerTransportFunc) Start } // Start start a RSocket server. Start interface { // Serve serve RSocket server. Serve(ctx context.Context) error - // Serve serve RSocket server with TLS. - // - // You can generate cert.pem and key.pem for local testing: - // - // go run $GOROOT/src/crypto/tls/generate_cert.go --host localhost - // - // Load X509 - // cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") - // if err != nil { - // panic(err) - // } - // // Init TLS configuration. - // tc := &tls.Config{ - // MinVersion: tls.VersionTLS12, - // CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - // PreferServerCipherSuites: true, - // CipherSuites: []uint16{ - // tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_RSA_WITH_AES_256_CBC_SHA, - // }, - // Certificates: []tls.Certificate{cert}, - // } - ServeTLS(ctx context.Context, c *tls.Config) error } ) @@ -99,9 +73,9 @@ type serverResumeOptions struct { } type server struct { + tp transport.ServerTransportFunc resumeOpts *serverResumeOptions fragment int - addr string acc ServerAcceptor sm *session.Manager done chan struct{} @@ -134,34 +108,23 @@ func (p *server) Fragment(mtu int) ServerBuilder { return p } -func (p *server) Acceptor(acceptor ServerAcceptor) ServerTransportBuilder { +func (p *server) Acceptor(acceptor ServerAcceptor) ToServerStarter { p.acc = acceptor return p } -func (p *server) Transport(transport string) Start { - p.addr = transport +func (p *server) Transport(t transport.ServerTransportFunc) Start { + p.tp = t return p } -func (p *server) ServeTLS(ctx context.Context, c *tls.Config) error { - return p.serve(ctx, c) -} - func (p *server) Serve(ctx context.Context) error { - return p.serve(ctx, nil) -} - -func (p *server) serve(ctx context.Context, tc *tls.Config) error { - u, err := transport.ParseURI(p.addr) - if err != nil { - return err - } - err = fragmentation.IsValidFragment(p.fragment) + err := fragmentation.IsValidFragment(p.fragment) if err != nil { return err } - t, err := u.MakeServerTransport(tc) + + t, err := p.tp(ctx) if err != nil { return err } @@ -173,8 +136,8 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { go func(ctx context.Context) { _ = p.loopCleanSession(ctx) }(ctx) - - t.Accept(func(ctx context.Context, tp *transport.Transport) { + t.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) socketChan := make(chan socket.ServerSocket, 1) defer func() { select { @@ -207,9 +170,9 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { } switch frame := first.(type) { - case *framing.FrameResume: + case *framing.ResumeFrame: p.doResume(frame, tp, socketChan) - case *framing.FrameSetup: + case *framing.SetupFrame: sendingSocket, err := p.doSetup(frame, tp, socketChan) if err != nil { _ = tp.Send(err, true) @@ -222,13 +185,13 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { } }(ctx, sendingSocket) default: - err := framing.NewFrameError(0, common.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) + err := framing.NewWriteableErrorFrame(0, core.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) _ = tp.Send(err, true) _ = tp.Close() return } if err := tp.Start(ctx); err != nil { - logger.Warnf("transport exit: %s\n", err.Error()) + logger.Warnf("transport exit: %+v\n", err) } }) @@ -242,31 +205,28 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { return t.Listen(ctx, serveNotifier) } -func (p *server) doSetup( - frame *framing.FrameSetup, - tp *transport.Transport, - socketChan chan<- socket.ServerSocket, -) (sendingSocket socket.ServerSocket, err *framing.FrameError) { - if frame.Header().Flag().Check(framing.FlagLease) && p.leases == nil { - err = framing.NewFrameError(0, common.ErrorCodeUnsupportedSetup, errUnavailableLease) +func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) (sendingSocket socket.ServerSocket, err *framing.WriteableErrorFrame) { + + if frame.Header().Flag().Check(core.FlagLease) && p.leases == nil { + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeUnsupportedSetup, errUnavailableLease) return } - isResume := frame.Header().Flag().Check(framing.FlagResume) + isResume := frame.Header().Flag().Check(core.FlagResume) // 1. receive a token but server doesn't support resume. if isResume && !p.resumeOpts.enable { - err = framing.NewFrameError(0, common.ErrorCodeUnsupportedSetup, errUnavailableResume) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeUnsupportedSetup, errUnavailableResume) return } - rawSocket := socket.NewServerDuplexRSocket(p.fragment, p.leases) + rawSocket := socket.NewServerDuplexConnection(p.fragment, p.leases) // 2. no resume if !isResume { - sendingSocket = socket.NewServer(rawSocket) + sendingSocket = socket.NewSimpleServerSocket(rawSocket) if responder, e := p.acc(frame, sendingSocket); e != nil { - err = framing.NewFrameError(0, common.ErrorCodeRejectedSetup, []byte(e.Error())) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedSetup, []byte(e.Error())) } else { sendingSocket.SetResponder(responder) sendingSocket.SetTransport(tp) @@ -279,7 +239,7 @@ func (p *server) doSetup( // 3. resume reject because of duplicated token. if _, ok := p.sm.Load(token); ok { - err = framing.NewFrameError(0, common.ErrorCodeRejectedSetup, errDuplicatedSetupToken) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedSetup, errDuplicatedSetupToken) return } @@ -288,10 +248,10 @@ func (p *server) doSetup( sendingSocket = socket.NewServerResume(rawSocket, token) if responder, e := p.acc(frame, sendingSocket); e != nil { switch vv := e.(type) { - case *framing.FrameError: - err = framing.NewFrameError(0, vv.ErrorCode(), vv.ErrorData()) + case *framing.ErrorFrame: + err = framing.NewWriteableErrorFrame(0, vv.ErrorCode(), vv.ErrorData()) default: - err = framing.NewFrameError(0, common.ErrorCodeInvalidSetup, []byte(e.Error())) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeInvalidSetup, []byte(e.Error())) } } else { sendingSocket.SetResponder(responder) @@ -301,21 +261,21 @@ func (p *server) doSetup( return } -func (p *server) doResume(frame *framing.FrameResume, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { - var sending framing.Frame +func (p *server) doResume(frame *framing.ResumeFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { + var sending core.WriteableFrame if !p.resumeOpts.enable { - sending = framing.NewFrameError(0, common.ErrorCodeRejectedResume, errUnavailableResume) + sending = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedResume, errUnavailableResume) } else if s, ok := p.sm.Load(frame.Token()); ok { - sending = framing.NewResumeOK(0) + sending = framing.NewWriteableResumeOKFrame(0) s.Socket().SetTransport(tp) socketChan <- s.Socket() if logger.IsDebugEnabled() { logger.Debugf("recover session: %s\n", s) } } else { - sending = framing.NewFrameError( + sending = framing.NewWriteableErrorFrame( 0, - common.ErrorCodeRejectedResume, + core.ErrorCodeRejectedResume, []byte("no such session"), ) } @@ -358,16 +318,16 @@ func (p *server) destroySessions() { } func (p *server) doCleanSession() { - deads := make(chan *session.Session) - go func(deads chan *session.Session) { - for it := range deads { + deadSessions := make(chan *session.Session) + go func(deadSessions chan *session.Session) { + for it := range deadSessions { if err := it.Close(); err != nil { logger.Warnf("close dead session failed: %s\n", err) } else if logger.IsDebugEnabled() { logger.Debugf("close dead session success: %s\n", it) } } - }(deads) + }(deadSessions) var cur *session.Session for p.sm.Len() > 0 { cur = p.sm.Pop() @@ -376,9 +336,9 @@ func (p *server) doCleanSession() { p.sm.Push(cur) break } - deads <- cur + deadSessions <- cur } - close(deads) + close(deadSessions) } // WithServerResumeSessionDuration sets resume session duration for RSocket server. diff --git a/transporter.go b/transporter.go new file mode 100644 index 0000000..9d1510b --- /dev/null +++ b/transporter.go @@ -0,0 +1,190 @@ +package rsocket + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "os" + + "github.com/rsocket/rsocket-go/core/transport" +) + +const DefaultUnixSockPath = "/var/run/rsocket.sock" +const DefaultPort = 7878 + +type TcpClientBuilder struct { + addr string + tlsCfg *tls.Config +} + +type TcpServerBuilder struct { + addr string + tlsCfg *tls.Config +} + +type WebsocketClientBuilder struct { + url string + tlsCfg *tls.Config + header http.Header +} + +type WebsocketServerBuilder struct { + addr string + path string + tlsConfig *tls.Config +} + +type UnixClientBuilder struct { + path string +} + +type UnixServerBuilder struct { + path string +} + +func (us *UnixServerBuilder) SetPath(path string) *UnixServerBuilder { + us.path = path + return us +} + +func (us *UnixServerBuilder) Build() transport.ServerTransportFunc { + return func(ctx context.Context) (transport.ServerTransport, error) { + if _, err := os.Stat(us.path); !os.IsNotExist(err) { + return nil, err + } + return transport.NewTcpServerTransportWithAddr("unix", us.path, nil), nil + } +} + +func (uc *UnixClientBuilder) SetPath(path string) *UnixClientBuilder { + uc.path = path + return uc +} + +func (uc UnixClientBuilder) Build() transport.ClientTransportFunc { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewTcpClientTransportWithAddr("unix", uc.path, nil) + } +} + +func (ws *WebsocketServerBuilder) SetAddr(addr string) *WebsocketServerBuilder { + ws.addr = addr + return ws +} + +func (ws *WebsocketServerBuilder) SetPath(path string) *WebsocketServerBuilder { + ws.path = path + return ws +} + +func (ws *WebsocketServerBuilder) SetTlsConfig(c *tls.Config) *WebsocketServerBuilder { + ws.tlsConfig = c + return ws +} + +func (ws *WebsocketServerBuilder) Build() transport.ServerTransportFunc { + return func(ctx context.Context) (transport.ServerTransport, error) { + return transport.NewWebsocketServerTransportWithAddr(ws.addr, ws.path, ws.tlsConfig), nil + } +} + +func (wc *WebsocketClientBuilder) SetTlsConfig(c *tls.Config) *WebsocketClientBuilder { + wc.tlsCfg = c + return wc +} + +func (wc *WebsocketClientBuilder) SetUrl(url string) *WebsocketClientBuilder { + wc.url = url + return wc +} + +func (wc *WebsocketClientBuilder) SetHeader(h http.Header) *WebsocketClientBuilder { + wc.header = h + return wc +} + +func (wc *WebsocketClientBuilder) Build() transport.ClientTransportFunc { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewWebsocketClientTransport(wc.url, wc.tlsCfg, wc.header) + } +} + +func (ts *TcpServerBuilder) SetHostAndPort(host string, port int) *TcpServerBuilder { + ts.addr = fmt.Sprintf("%s:%d", host, port) + return ts +} + +func (ts *TcpServerBuilder) SetAddr(addr string) *TcpServerBuilder { + ts.addr = addr + return ts +} + +func (ts *TcpServerBuilder) SetTlsConfig(c *tls.Config) *TcpServerBuilder { + ts.tlsCfg = c + return ts +} + +func (ts *TcpServerBuilder) Build() transport.ServerTransportFunc { + return func(ctx context.Context) (transport.ServerTransport, error) { + return transport.NewTcpServerTransportWithAddr("tcp", ts.addr, ts.tlsCfg), nil + } +} + +func (tc *TcpClientBuilder) SetHostAndPort(host string, port int) *TcpClientBuilder { + tc.addr = fmt.Sprintf("%s:%d", host, port) + return tc +} + +func (tc *TcpClientBuilder) SetAddr(addr string) *TcpClientBuilder { + tc.addr = addr + return tc +} + +func (tc *TcpClientBuilder) SetTlsConfig(c *tls.Config) *TcpClientBuilder { + tc.tlsCfg = c + return tc +} + +func (tc *TcpClientBuilder) Build() transport.ClientTransportFunc { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewTcpClientTransportWithAddr("tcp", tc.addr, tc.tlsCfg) + } +} + +func TcpClient() *TcpClientBuilder { + return &TcpClientBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), + } +} + +func TcpServer() *TcpServerBuilder { + return &TcpServerBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), + } +} + +func WebsocketClient() *WebsocketClientBuilder { + return &WebsocketClientBuilder{ + url: fmt.Sprintf("ws://127.0.0.1:%d", DefaultPort), + } +} + +func WebsocketServer() *WebsocketServerBuilder { + return &WebsocketServerBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), + path: "/", + } +} + +func UnixClient() *UnixClientBuilder { + return &UnixClientBuilder{ + path: DefaultUnixSockPath, + } +} + +func UnixServer() *UnixServerBuilder { + return &UnixServerBuilder{ + path: DefaultUnixSockPath, + } +} diff --git a/transporter_test.go b/transporter_test.go new file mode 100644 index 0000000..8e39299 --- /dev/null +++ b/transporter_test.go @@ -0,0 +1,71 @@ +package rsocket_test + +import ( + "context" + "fmt" + "net/http" + "os" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/rsocket/rsocket-go" + "github.com/stretchr/testify/assert" +) + +var fakeSockFile string + +func init() { + fmt.Println(os.TempDir()) + fakeSockFile = fmt.Sprintf("%s/test-rsocket-%s.sock", strings.TrimRight(os.TempDir(), "/"), uuid.New().String()) +} + +func TestUnixServer(t *testing.T) { + defer os.Remove(fakeSockFile) + u := rsocket.UnixServer().SetPath(fakeSockFile).Build() + assert.NotNil(t, u) + _, err := u(context.Background()) + assert.NoError(t, err) +} + +func TestUnixClient(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.UnixClient().SetPath(fakeSockFile).Build() + }) +} + +func TestTcpClient(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.TcpClient(). + SetAddr(":7878"). + SetHostAndPort("127.0.0.1", 7878). + Build() + }) +} + +func TestTcpServerBuilder(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.TcpServer().SetAddr(":7878").Build() + }) +} + +func TestWebsocketClient(t *testing.T) { + assert.NotPanics(t, func() { + h := make(http.Header) + h.Set("x-foo-bar", "qux") + rsocket.WebsocketClient(). + SetUrl("ws://127.0.0.1:8080/fake/path"). + SetHeader(h). + Build() + }) +} + +func TestWebsocketServer(t *testing.T) { + assert.NotPanics(t, func() { + tp := rsocket.WebsocketServer(). + SetAddr(":7878"). + SetPath("/fake"). + Build() + assert.NotNil(t, tp) + }) +}