From b6cdbbf4ccf102cb54dd2c268df7331ff5673847 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Thu, 28 May 2020 22:57:23 +0800 Subject: [PATCH 01/26] Setup coveralls. --- .travis.yml | 13 +++--- README.md | 3 +- extension/composite_metadata.go | 2 +- internal/common/bytebuffer.go | 41 ++++++++++------- internal/common/bytebuffer_test.go | 52 ++++++++++++++++++++++ internal/common/errors.go | 4 +- internal/common/u24.go | 41 ++++++++++++++--- internal/common/u24_benchmark_test.go | 17 -------- internal/common/u24_test.go | 63 +++++++++++++++++++++++++-- internal/fragmentation/splitter.go | 2 +- internal/transport/connection_tcp.go | 2 +- 11 files changed, 184 insertions(+), 56 deletions(-) create mode 100644 internal/common/bytebuffer_test.go delete mode 100644 internal/common/u24_benchmark_test.go diff --git a/.travis.yml b/.travis.yml index 4af9046..57a5dde 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,14 +3,13 @@ language: go go: - 1.x -env: - - GO111MODULE=on - -before_install: - - go get -u golang.org/x/lint/golint +install: + - go get golang.org/x/lint/golint - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.27.0 + - go get golang.org/x/tools/cmd/cover + - go get github.com/mattn/goveralls script: -# - golint ./... - golangci-lint run ./... - - go test -race -count=1 . -v + - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 . + - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/README.md b/README.md index 5822ad6..821ae84 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,10 @@ ![logo](./logo.jpg) [![Build Status](https://travis-ci.com/rsocket/rsocket-go.svg?branch=master)](https://travis-ci.com/rsocket/rsocket-go) +[![Coverage Status](https://coveralls.io/repos/github/rsocket/rsocket-go/badge.svg?branch=master)](https://coveralls.io/github/rsocket/rsocket-go?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/rsocket/rsocket-go)](https://goreportcard.com/report/github.com/rsocket/rsocket-go) [![Slack](https://img.shields.io/badge/slack-rsocket--go-blue.svg)](https://rsocket.slack.com/messages/C9VGZ5MV3) [![GoDoc](https://godoc.org/github.com/rsocket/rsocket-go?status.svg)](https://godoc.org/github.com/rsocket/rsocket-go) -[![Go Report Card](https://goreportcard.com/badge/github.com/rsocket/rsocket-go)](https://goreportcard.com/report/github.com/rsocket/rsocket-go) [![License](https://img.shields.io/github/license/rsocket/rsocket-go.svg)](https://github.com/rsocket/rsocket-go/blob/master/LICENSE) [![GitHub Release](https://img.shields.io/github/release-pre/rsocket/rsocket-go.svg)](https://github.com/rsocket/rsocket-go/releases) diff --git a/extension/composite_metadata.go b/extension/composite_metadata.go index acbbd7c..99cb059 100644 --- a/extension/composite_metadata.go +++ b/extension/composite_metadata.go @@ -120,7 +120,7 @@ func (c *CompositeMetadataBuilder) Build() (CompositeMetadata, error) { } metadata := c.v[i] metadataLen := len(metadata) - bf.Write(common.NewUint24(metadataLen).Bytes()) + bf.Write(common.MustNewUint24(metadataLen).Bytes()) if metadataLen > 0 { bf.Write(metadata) } diff --git a/internal/common/bytebuffer.go b/internal/common/bytebuffer.go index dcf9500..a328f18 100644 --- a/internal/common/bytebuffer.go +++ b/internal/common/bytebuffer.go @@ -8,40 +8,49 @@ import ( // ByteBuff provides byte buffer, which can be used for minimizing. type ByteBuff bytes.Buffer -func (p *ByteBuff) pp() *bytes.Buffer { - return (*bytes.Buffer)(p) +func (b *ByteBuff) pp() *bytes.Buffer { + return (*bytes.Buffer)(b) } // Len returns size of ByteBuff. -func (p *ByteBuff) Len() (n int) { - return p.pp().Len() +func (b *ByteBuff) Len() (n int) { + return b.pp().Len() } // WriteTo write bytes to writer. -func (p *ByteBuff) WriteTo(w io.Writer) (int64, error) { - return p.pp().WriteTo(w) +func (b *ByteBuff) WriteTo(w io.Writer) (int64, error) { + return b.pp().WriteTo(w) } // Writer write bytes to current ByteBuff. -func (p *ByteBuff) Write(bs []byte) (int, error) { - return p.pp().Write(bs) +func (b *ByteBuff) Write(bs []byte) (int, error) { + return b.pp().Write(bs) } // WriteUint24 encode and write Uint24 to current ByteBuff. -func (p *ByteBuff) WriteUint24(n int) (err error) { - v := NewUint24(n) - _, err = p.Write(v[:]) +func (b *ByteBuff) WriteUint24(n int) (err error) { + if n > MaxUint24 { + return errExceedMaxUint24 + } + v := MustNewUint24(n) + _, err = b.Write(v[:]) return } -// WriteByte write a byte to current ByteBuff. -func (p *ByteBuff) WriteByte(b byte) error { - return p.pp().WriteByte(b) +// WriteByte writes a byte to current ByteBuff. +func (b *ByteBuff) WriteByte(c byte) error { + return b.pp().WriteByte(c) +} + +// WriteString writes a string to current ByteBuff. +func (b *ByteBuff) WriteString(s string) (err error) { + _, err = b.pp().Write([]byte(s)) + return } // Bytes returns all bytes in ByteBuff. -func (p *ByteBuff) Bytes() []byte { - return p.pp().Bytes() +func (b *ByteBuff) Bytes() []byte { + return b.pp().Bytes() } // NewByteBuff creates a new ByteBuff. diff --git a/internal/common/bytebuffer_test.go b/internal/common/bytebuffer_test.go new file mode 100644 index 0000000..ea0ce94 --- /dev/null +++ b/internal/common/bytebuffer_test.go @@ -0,0 +1,52 @@ +package common_test + +import ( + "os" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestByteBuff_Bytes(t *testing.T) { + data := []byte("foobar") + b := common.NewByteBuff() + wrote, err := b.Write(data) + assert.NoError(t, err, "write failed") + assert.Equal(t, len(data), wrote, "wrong wrote size") + assert.Equal(t, data, b.Bytes(), "wrong data") +} + +func TestByteBuff_WriteUint24(t *testing.T) { + b := common.NewByteBuff() + var err error + err = b.WriteUint24(0) + assert.NoError(t, err, "write uint24 failed") + err = b.WriteUint24(common.MaxUint24) + assert.NoError(t, err, "write maximum uint24 failed") + err = b.WriteUint24(0x01FFFFFF) + assert.Error(t, err, "should write failed") +} + +func TestByteBuff_Len(t *testing.T) { + b := common.NewByteBuff() + // 3+1+6 + _ = b.WriteUint24(1) + _ = b.WriteByte('c') + _, _ = b.Write([]byte("foobar")) + assert.Equal(t, 10, b.Len(), "wrong length") +} + +func TestByteBuff_WriteTo(t *testing.T) { + b := common.NewByteBuff() + f, err := os.OpenFile("/dev/null", os.O_WRONLY, os.ModeAppend) + assert.NoError(t, err, "open /dev/null failed") + defer f.Close() + // 16MB + s := common.RandAlphanumeric(16 * 1024 * 1024) + err = b.WriteString(s) + assert.NoError(t, err) + n, err := b.WriteTo(f) + assert.NoError(t, err, "WriteTo failed") + assert.Equal(t, len(s), int(n), "wrong length") +} diff --git a/internal/common/errors.go b/internal/common/errors.go index e84487b..976e9dc 100644 --- a/internal/common/errors.go +++ b/internal/common/errors.go @@ -14,8 +14,8 @@ type CustomError interface { ErrorData() []byte } -func (p ErrorCode) String() string { - switch p { +func (e ErrorCode) String() string { + switch e { case ErrorCodeInvalidSetup: return "INVALID_SETUP" case ErrorCodeUnsupportedSetup: diff --git a/internal/common/u24.go b/internal/common/u24.go index f2ca3a3..daf2860 100644 --- a/internal/common/u24.go +++ b/internal/common/u24.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "io" ) @@ -8,7 +9,20 @@ import ( // MaxUint24 is the max value of Uint24. const MaxUint24 = 16777215 -var errMaxUint24 = fmt.Errorf("uint24 exceed max value: %d", MaxUint24) +var ( + errExceedMaxUint24 = fmt.Errorf("uint24 exceed max value: %d", MaxUint24) + errNegativeNumber = errors.New("negative number is illegal") +) + +// IsExceedMaximumUint24Error returns true if exceed maximum Uint24. (16777215) +func IsExceedMaximumUint24Error(err error) bool { + return err == errExceedMaxUint24 +} + +// IsNegativeUint24Error returns true if number is negative. +func IsNegativeUint24Error(err error) bool { + return err == errNegativeNumber +} // Uint24 is 3 bytes unsigned integer. type Uint24 [3]byte @@ -29,14 +43,27 @@ func (p Uint24) AsInt() int { return int(p[0])<<16 + int(p[1])<<8 + int(p[2]) } +// MustNewUint24 returns a new uint24. +func MustNewUint24(n int) Uint24 { + v, err := NewUint24(n) + if err != nil { + panic(err) + } + return v +} + // NewUint24 returns a new uint24. -func NewUint24(n int) (v Uint24) { - if n > MaxUint24 { - panic(errMaxUint24) +func NewUint24(v int) (n Uint24, err error) { + if v < 0 { + err = errNegativeNumber + return + } + if v > MaxUint24 { + err = errExceedMaxUint24 } - v[0] = byte(n >> 16) - v[1] = byte(n >> 8) - v[2] = byte(n) + n[0] = byte(v >> 16) + n[1] = byte(v >> 8) + n[2] = byte(v) return } diff --git a/internal/common/u24_benchmark_test.go b/internal/common/u24_benchmark_test.go deleted file mode 100644 index 39df2fd..0000000 --- a/internal/common/u24_benchmark_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package common - -import "testing" - -func BenchmarkNewUint24(b *testing.B) { - n := RandIntn(MaxUint24) - for i := 0; i < b.N; i++ { - _ = NewUint24(n) - } -} - -func BenchmarkUint24_AsInt(b *testing.B) { - n := NewUint24(RandIntn(MaxUint24)) - for i := 0; i < b.N; i++ { - _ = n.AsInt() - } -} diff --git a/internal/common/u24_test.go b/internal/common/u24_test.go index ff59106..fdfb9fb 100644 --- a/internal/common/u24_test.go +++ b/internal/common/u24_test.go @@ -1,14 +1,71 @@ -package common +package common_test import ( "testing" + . "github.com/rsocket/rsocket-go/internal/common" "github.com/stretchr/testify/assert" ) -func TestUint24(t *testing.T) { +func BenchmarkNewUint24(b *testing.B) { n := RandIntn(MaxUint24) - x := NewUint24(n) + for i := 0; i < b.N; i++ { + _ = MustNewUint24(n) + } +} + +func BenchmarkUint24_AsInt(b *testing.B) { + n := MustNewUint24(RandIntn(MaxUint24)) + for i := 0; i < b.N; i++ { + _ = n.AsInt() + } +} + +func TestMustNewUint24(t *testing.T) { + func() { + defer func() { + e := recover() + assert.True(t, IsExceedMaximumUint24Error(e.(error)), "should failed") + }() + _ = MustNewUint24(MaxUint24 + 1) + }() + func() { + defer func() { + e := recover() + assert.True(t, IsNegativeUint24Error(e.(error)), "should failed") + }() + _ = MustNewUint24(-1) + }() +} + +func TestUint24(t *testing.T) { + testSingle(t, 0) + for range [1_000_000]struct{}{} { + testSingle(t, RandIntn(MaxUint24)) + } + testSingle(t, MaxUint24) + // negative + _, err := NewUint24(-1) + assert.Error(t, err, "negative number should failed") + + // over maximum number + _, err = NewUint24(MaxUint24 + 1) + assert.Error(t, err, "over maximum number should failed") +} + +func TestUint24_WriteTo(t *testing.T) { + for _, n := range []int{0, 1, RandIntn(MaxUint24), MaxUint24} { + v := MustNewUint24(n) + b := NewByteBuff() + wrote, err := v.WriteTo(b) + assert.NoError(t, err, "write uint24 failed") + assert.Equal(t, int64(3), wrote, "wrote bytes length should be 3") + assert.Equal(t, n, NewUint24Bytes(b.Bytes()).AsInt(), "bad uint24 result") + } +} + +func testSingle(t *testing.T, n int) { + x := MustNewUint24(n) assert.Equal(t, n, x.AsInt(), "bad new from int") y := NewUint24Bytes(x.Bytes()) assert.Equal(t, n, y.AsInt(), "bad new from bytes") diff --git a/internal/fragmentation/splitter.go b/internal/fragmentation/splitter.go index 299612e..124736e 100644 --- a/internal/fragmentation/splitter.go +++ b/internal/fragmentation/splitter.go @@ -70,7 +70,7 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx } if wroteM > 0 { // set metadata length - x := common.NewUint24(wroteM) + x := common.MustNewUint24(wroteM) for i := 0; i < len(x); i++ { if idx == 0 { bf.Bytes()[i+skip] = x[i] diff --git a/internal/transport/connection_tcp.go b/internal/transport/connection_tcp.go index e045278..4aeda3b 100644 --- a/internal/transport/connection_tcp.go +++ b/internal/transport/connection_tcp.go @@ -76,7 +76,7 @@ func (p *tcpConn) Write(frame framing.Frame) (err error) { if p.counter != nil && frame.CanResume() { p.counter.incrWriteBytes(size) } - _, err = common.NewUint24(size).WriteTo(p.writer) + _, err = common.MustNewUint24(size).WriteTo(p.writer) if err != nil { err = errors.Wrap(err, "write frame failed") return From cec8eb2139ad5d64b0df5234b6f000a86de970ab Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Mon, 1 Jun 2020 22:41:13 +0800 Subject: [PATCH 02/26] Add some unit tests. --- extension/authentication.go | 48 ++++++++++++---- extension/authentication_test.go | 38 ++++++++++++- internal/common/version.go | 28 ++++++++-- internal/common/version_test.go | 43 ++++++++++++++ internal/socket/duplex.go | 13 +++-- internal/socket/stream_id.go | 12 ++-- internal/socket/stream_id_test.go | 7 +++ payload/payload.go | 41 ++++++++++---- payload/payload_raw.go | 31 ++++++++++- payload/payload_str.go | 12 +++- payload/payload_test.go | 93 ++++++++++++++++++++++++++----- 11 files changed, 308 insertions(+), 58 deletions(-) create mode 100644 internal/common/version_test.go create mode 100644 internal/socket/stream_id_test.go diff --git a/extension/authentication.go b/extension/authentication.go index e8dfa0a..98fae6a 100644 --- a/extension/authentication.go +++ b/extension/authentication.go @@ -15,7 +15,10 @@ const ( _authenticationBearer wellKnownAuthenticationType = 0x01 ) -var errInvalidAuthBytes = errors.New("invalid authentication bytes") +var ( + _errInvalidAuthBytes = errors.New("invalid authentication bytes") + _errAuthTypeLengthExceed = errors.New("invalid authType length: exceed 127 bytes") +) type wellKnownAuthenticationType uint8 @@ -55,14 +58,23 @@ func (a Authentication) IsWellKnown() (ok bool) { } // NewAuthentication creates a new Authentication -func NewAuthentication(authType string, payload []byte) Authentication { +func NewAuthentication(authType string, payload []byte) (*Authentication, error) { if len(authType) > 0x7F { - panic("illegal authType length: exceed 127 bytes") + return nil, _errAuthTypeLengthExceed } - return Authentication{ + return &Authentication{ typ: authType, payload: payload, + }, nil +} + +// MustNewAuthentication creates a new Authentication +func MustNewAuthentication(authType string, payload []byte) *Authentication { + auth, err := NewAuthentication(authType, payload) + if err != nil { + panic(err) } + return auth } // Bytes encodes current Authentication to byte slice. @@ -78,28 +90,42 @@ func (a Authentication) Bytes() (raw []byte) { } // ParseAuthentication parse Authentication from raw bytes. -func ParseAuthentication(raw []byte) (auth Authentication, err error) { +func ParseAuthentication(raw []byte) (auth *Authentication, err error) { totals := len(raw) if totals < 1 { - err = errInvalidAuthBytes + err = _errInvalidAuthBytes return } first := raw[0] n := 0x7F & first if first&0x80 != 0 { - auth.typ = wellKnownAuthenticationType(n).String() - auth.payload = raw[1:] + auth = &Authentication{ + typ: wellKnownAuthenticationType(n).String(), + payload: raw[1:], + } return } if totals < int(n+1) { - err = errInvalidAuthBytes + err = _errInvalidAuthBytes return } - auth.typ = string(raw[1 : 1+n]) - auth.payload = raw[n+1:] + auth = &Authentication{ + typ: string(raw[1 : 1+n]), + payload: raw[n+1:], + } return } +// IsInvalidAuthenticationBytes returns true if input error is for invalid bytes. +func IsInvalidAuthenticationBytes(err error) bool { + return err == _errInvalidAuthBytes +} + +// IsAuthTypeLengthExceed returns true if input error is for AuthType length exceed. +func IsAuthTypeLengthExceed(err error) bool { + return err == _errAuthTypeLengthExceed +} + func parseWellKnownAuthenticateType(typ string) (au wellKnownAuthenticationType, ok bool) { switch typ { case _simpleAuth: diff --git a/extension/authentication_test.go b/extension/authentication_test.go index 34d65c4..35151e1 100644 --- a/extension/authentication_test.go +++ b/extension/authentication_test.go @@ -1,14 +1,46 @@ package extension_test import ( + "math/rand" + "strings" "testing" + "time" "github.com/rsocket/rsocket-go/extension" "github.com/stretchr/testify/assert" ) +func TestNewAuthentication(t *testing.T) { + payload := []byte("foobar") + + for _, authType := range []string{"bearer", "simple"} { + au, err := extension.NewAuthentication(authType, payload) + assert.NoError(t, err, "create Authentication failed!") + assert.Equal(t, authType, au.Type(), "wrong type") + assert.Equal(t, payload, au.Payload(), "wrong payload") + assert.Equal(t, true, au.IsWellKnown(), "well-known should be true") + b := au.Bytes() + au2, err := extension.ParseAuthentication(b) + assert.NoError(t, err, "parse Authentication failed!") + assert.Equal(t, au.Type(), au2.Type(), "authType doesn't match") + assert.Equal(t, au.Payload(), au2.Payload(), "payload doesn't match") + } + + _, err := extension.NewAuthentication(strings.Repeat("0", 128), payload) + assert.True(t, extension.IsAuthTypeLengthExceed(err), "should error") +} + +func TestParseAuthentication(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + input := make([]byte, 2) + rand.Read(input) + _, err := extension.ParseAuthentication(input) + assert.True(t, extension.IsInvalidAuthenticationBytes(err), "should error") + +} + func TestAuthentication(t *testing.T) { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) raw := au.Bytes() au2, err := extension.ParseAuthentication(raw) assert.NoError(t, err, "bad authentication bytes") @@ -17,13 +49,13 @@ func TestAuthentication(t *testing.T) { func BenchmarkAuthentication_Bytes(b *testing.B) { for i := 0; i < b.N; i++ { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) _ = au.Bytes() } } func BenchmarkParseAuthentication(b *testing.B) { - au := extension.NewAuthentication("simple", []byte("foobar")) + au := extension.MustNewAuthentication("simple", []byte("foobar")) raw := au.Bytes() b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/internal/common/version.go b/internal/common/version.go index 9f83d81..52365e3 100644 --- a/internal/common/version.go +++ b/internal/common/version.go @@ -2,8 +2,9 @@ package common import ( "encoding/binary" - "fmt" "io" + "strconv" + "strings" ) // DefaultVersion is default protocol version. @@ -33,14 +34,29 @@ func (p Version) Minor() uint16 { // WriteTo write raw version bytes to a writer. func (p Version) WriteTo(w io.Writer) (n int64, err error) { - var wrote int - wrote, err = w.Write(p.Bytes()) - if err == nil { - n += int64(wrote) + err = binary.Write(w, binary.BigEndian, p[0]) + if err != nil { + return } + err = binary.Write(w, binary.BigEndian, p[1]) + if err != nil { + return + } + n = 4 return } func (p Version) String() string { - return fmt.Sprintf("%d.%d", p[0], p[1]) + b := strings.Builder{} + b.WriteString(strconv.Itoa(int(p[0]))) + b.WriteByte('.') + b.WriteString(strconv.Itoa(int(p[1]))) + return b.String() +} + +// NewVersion creates a new Version from major and minor. +func NewVersion(major, minor uint16) Version { + return Version{ + major, minor, + } } diff --git a/internal/common/version_test.go b/internal/common/version_test.go new file mode 100644 index 0000000..f9ad917 --- /dev/null +++ b/internal/common/version_test.go @@ -0,0 +1,43 @@ +package common_test + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func BenchmarkVersion_String(b *testing.B) { + v := common.NewVersion(2, 3) + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.String() + } +} + +func TestVersion(t *testing.T) { + var ( + major uint16 = 2 + minor uint16 = 1 + ) + v := common.NewVersion(major, minor) + assert.Equal(t, "2.1", v.String()) + assert.Equal(t, uint16(2), v.Major(), "wrong major version") + assert.Equal(t, uint16(1), v.Minor(), "wrong minor version") + checkBytes(t, v.Bytes(), 2, 1) + w := bytes.Buffer{} + n, err := v.WriteTo(&w) + assert.NoError(t, err, "write version failed") + assert.Equal(t, int64(4), n, "wrong wrote bytes length") + checkBytes(t, v.Bytes(), 2, 1) +} + +func checkBytes(t *testing.T, b []byte, expectMajor, expectMinor uint16) { + assert.Equal(t, 4, len(b), "wrong version bytes") + major := binary.BigEndian.Uint16(b[:2]) + minor := binary.BigEndian.Uint16(b[2:]) + assert.Equal(t, expectMajor, major, "wrong major version") + assert.Equal(t, expectMinor, minor, "wrong minor version") +} diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index c738482..d9ca4d9 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -32,6 +32,11 @@ var ( unsupportedRequestChannel = []byte("Request-Channel not implemented.") ) +// IsSocketClosedError returns true if input error is for socket closed. +func IsSocketClosedError(err error) bool { + return err == errSocketClosed +} + // DuplexRSocket represents a socket of RSocket which can be a requester or a responder. type DuplexRSocket struct { counter *transport.Counter @@ -40,7 +45,7 @@ type DuplexRSocket struct { outsPriority []framing.Frame responder Responder messages *u32map - sids streamIDs + sids StreamID mtu int fragments *u32map // key=streamID, value=Joiner closed *atomic.Bool @@ -60,8 +65,8 @@ func (p *DuplexRSocket) SetError(e error) { func (p *DuplexRSocket) nextStreamID() (sid uint32) { var lap1st bool for { - // There's no necessery to check StreamID conflicts. - sid, lap1st = p.sids.next() + // There's no required to check StreamID conflicts. + sid, lap1st = p.sids.Next() if lap1st { return } @@ -611,7 +616,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa func (p *DuplexRSocket) writeError(sid uint32, e error) { // ignore sending error because current socket has been closed. - if e == errSocketClosed { + if IsSocketClosedError(e) { return } switch err := e.(type) { diff --git a/internal/socket/stream_id.go b/internal/socket/stream_id.go index 3bec2da..0282cc5 100644 --- a/internal/socket/stream_id.go +++ b/internal/socket/stream_id.go @@ -9,34 +9,34 @@ const ( halfSeed uint64 = 0x40000000 ) -type streamIDs interface { - next() (id uint32, lap1st bool) +type StreamID interface { + Next() (id uint32, lap1st bool) } type serverStreamIDs struct { cur uint64 } -func (p *serverStreamIDs) next() (uint32, bool) { +func (p *serverStreamIDs) Next() (uint32, bool) { // 2,4,6,8... seed := atomic.AddUint64(&p.cur, 1) v := 2 * seed if v != 0 { return uint32(maskStreamID & v), seed <= halfSeed } - return p.next() + return p.Next() } type clientStreamIDs struct { cur uint64 } -func (p *clientStreamIDs) next() (uint32, bool) { +func (p *clientStreamIDs) Next() (uint32, bool) { // 1,3,5,7 seed := atomic.AddUint64(&p.cur, 1) v := 2*(seed-1) + 1 if v != 0 { return uint32(maskStreamID & v), seed <= halfSeed } - return p.next() + return p.Next() } diff --git a/internal/socket/stream_id_test.go b/internal/socket/stream_id_test.go new file mode 100644 index 0000000..7f44e08 --- /dev/null +++ b/internal/socket/stream_id_test.go @@ -0,0 +1,7 @@ +package socket + +import "testing" + +func TestSt(t *testing.T) { + +} diff --git a/payload/payload.go b/payload/payload.go index 53812c9..e91259b 100644 --- a/payload/payload.go +++ b/payload/payload.go @@ -42,18 +42,39 @@ type ( // Clone create a copy of original payload. func Clone(payload Payload) Payload { - ret := &rawPayload{} - if d := payload.Data(); len(d) > 0 { - clone := make([]byte, len(d)) - copy(clone, d) - ret.data = clone + if payload == nil { + return nil } - if m, ok := payload.Metadata(); ok && len(m) > 0 { - clone := make([]byte, len(m)) - copy(clone, m) - ret.metadata = clone + switch v := payload.(type) { + case *rawPayload: + var data []byte + if v.data != nil { + data = make([]byte, len(v.data)) + copy(data, v.data) + } + var metadata []byte + if v.metadata != nil { + metadata = make([]byte, len(v.metadata)) + copy(metadata, v.metadata) + } + return &rawPayload{data: data, metadata: metadata} + case *strPayload: + return &strPayload{data: v.data, metadata: v.metadata} + default: + ret := &rawPayload{} + if d := payload.Data(); len(d) > 0 { + clone := make([]byte, len(d)) + copy(clone, d) + ret.data = clone + } + if m, ok := payload.Metadata(); ok && len(m) > 0 { + clone := make([]byte, len(m)) + copy(clone, m) + ret.metadata = clone + } + return ret } - return ret + } // New create a new payload with bytes. diff --git a/payload/payload_raw.go b/payload/payload_raw.go index 5092a2d..23c563c 100644 --- a/payload/payload_raw.go +++ b/payload/payload_raw.go @@ -2,6 +2,8 @@ package payload import ( "fmt" + "strings" + "unicode/utf8" ) type rawPayload struct { @@ -10,8 +12,33 @@ type rawPayload struct { } func (p *rawPayload) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("Payload{data=%s,metadata=%s}", p.DataUTF8(), m) + bu := strings.Builder{} + bu.WriteString("Payload{data=") + if utf8.Valid(p.data) { + bu.Write(p.data) + } else { + bu.WriteByte('[') + for _, b := range p.data { + bu.WriteString(fmt.Sprintf(" 0x%x", b)) + } + bu.WriteByte(' ') + bu.WriteByte(']') + } + bu.WriteString(",metadata=") + if len(p.metadata) > 0 { + if utf8.Valid(p.metadata) { + bu.Write(p.metadata) + } else { + bu.WriteByte('[') + for _, b := range p.metadata { + bu.WriteString(fmt.Sprintf(" 0x%x", b)) + } + bu.WriteByte(' ') + bu.WriteByte(']') + } + } + bu.WriteByte('}') + return bu.String() } func (p *rawPayload) Metadata() (metadata []byte, ok bool) { diff --git a/payload/payload_str.go b/payload/payload_str.go index 128f423..502c21b 100644 --- a/payload/payload_str.go +++ b/payload/payload_str.go @@ -1,6 +1,8 @@ package payload -import "fmt" +import ( + "strings" +) type strPayload struct { data string @@ -8,7 +10,13 @@ type strPayload struct { } func (p *strPayload) String() string { - return fmt.Sprintf("Payload{data=%s,metadata=%s}", p.data, p.metadata) + bu := strings.Builder{} + bu.WriteString("Payload{data=") + bu.WriteString(p.data) + bu.WriteString("metadata=") + bu.WriteString(p.metadata) + bu.WriteByte('}') + return bu.String() } func (p *strPayload) Metadata() (metadata []byte, ok bool) { diff --git a/payload/payload_test.go b/payload/payload_test.go index 179ba3f..771b5ef 100644 --- a/payload/payload_test.go +++ b/payload/payload_test.go @@ -1,29 +1,94 @@ -package payload +package payload_test import ( "fmt" "testing" + "github.com/rsocket/rsocket-go/payload" "github.com/stretchr/testify/assert" ) -func TestPayload_new(t *testing.T) { +type customPayload [2][]byte + +func (c customPayload) Metadata() (metadata []byte, ok bool) { + return c[1], len(c[1]) > 0 +} + +func (c customPayload) MetadataUTF8() (metadata string, ok bool) { + return string(c[1]), len(c[1]) > 0 +} + +func (c customPayload) Data() []byte { + return c[0] +} + +func (c customPayload) DataUTF8() string { + return string(c[0]) +} + +func TestRawPayload(t *testing.T) { + data, metadata := []byte("hello"), []byte("world") + p := payload.New(data, metadata) + fmt.Println("new binary payload:", p) + assert.Equal(t, data, p.Data(), "wrong data") + m, ok := p.Metadata() + assert.True(t, ok, "ok should be true") + assert.Equal(t, metadata, m, "wrong metadata") + assert.Equal(t, "hello", p.DataUTF8(), "wrong data string") + mu, _ := p.MetadataUTF8() + assert.Equal(t, "world", mu, "wrong metadata string") + + invalid := []byte{0xff, 0xfe, 0xfd} + badPayload := payload.New(invalid, invalid) + s := badPayload.(fmt.Stringer).String() + fmt.Println("no utf8 payload:", s) + assert.NotEmpty(t, s) +} + +func TestStrPayload(t *testing.T) { data, metadata := "hello", "world" - p1 := New([]byte(data), []byte(metadata)) + p := payload.NewString(data, metadata) + fmt.Println("new string payload:", p) + assert.Equal(t, []byte(data), p.Data(), "wrong data") + m, ok := p.Metadata() + assert.True(t, ok, "ok should be true") + assert.Equal(t, []byte(metadata), m, "wrong metadata") + assert.Equal(t, data, p.DataUTF8(), "wrong data string") + mu, _ := p.MetadataUTF8() + assert.Equal(t, metadata, mu, "wrong metadata string") +} + +func TestClone(t *testing.T) { + check := func(p1 payload.Payload) { + p2 := payload.Clone(p1) + assert.Equal(t, p1.Data(), p2.Data(), "bad data") + m1, _ := p1.Metadata() + m2, _ := p2.Metadata() + assert.Equal(t, m1, m2, "bad metadata") + } - assert.Equal(t, data, p1.DataUTF8(), "bad data") - metadata2, ok := p1.MetadataUTF8() - assert.True(t, ok, "bad metadata") - assert.Equal(t, metadata, metadata2, "bad metadata") + check(payload.NewString("hello", "world")) + check(payload.New([]byte("hello"), []byte("world"))) + // Check custom payload + custom := customPayload([2][]byte{[]byte("hello"), []byte("world")}) + check(custom) - p1 = New([]byte(data), nil) - metadata2, ok = p1.MetadataUTF8() - assert.False(t, ok) - assert.Equal(t, "", metadata2) + // Check clone nil + nilCloned := payload.Clone(nil) + assert.Nil(t, nilCloned, "should return nil") } func TestNewFile(t *testing.T) { - pl, err := NewFile("/etc/hosts", nil) - assert.NoError(t, err, "bad file") - fmt.Print(pl.DataUTF8()) + p := payload.MustNewFile("/etc/hosts", nil) + assert.NotEmpty(t, p.Data(), "empty data") + _, err := payload.NewFile("/not/existing", nil) + assert.Error(t, err, "should return error") + + func() { + defer func() { + e := recover() + assert.Error(t, e.(error), "should panic error") + }() + payload.MustNewFile("/not/existing", nil) + }() } From eea0365ff4f46bd058050f1017ee74432ac66b23 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 9 Jun 2020 22:54:41 +0800 Subject: [PATCH 03/26] fix lint. fix fuzz foobar foobar foobar fix fuzz --- fuzz.go | 27 +-- internal/common/{misc.go => common.go} | 0 internal/common/errors_test.go | 28 ++++ internal/common/rand_test.go | 35 ++++ internal/common/u32map.go | 120 ++++++++++++++ internal/common/u32map_test.go | 63 +++++++ internal/common/version_test.go | 2 +- internal/framing/frame_cancel.go | 4 + internal/framing/frame_error.go | 10 +- internal/framing/frame_fnf.go | 3 + internal/framing/frame_lease.go | 3 + internal/framing/frame_payload.go | 4 + internal/framing/frame_request_channel.go | 12 +- internal/framing/frame_request_n.go | 7 +- internal/framing/frame_request_response.go | 3 + internal/framing/frame_request_stream.go | 12 +- internal/framing/frame_resume.go | 11 ++ internal/framing/frame_resume_ok.go | 4 + internal/framing/frame_setup.go | 14 +- internal/framing/frame_test.go | 181 ++++++++++++++++++++- internal/framing/header.go | 13 +- internal/framing/header_test.go | 29 ++-- internal/socket/duplex.go | 18 +- internal/socket/misc.go | 50 ------ internal/socket/misc_test.go | 1 + internal/socket/stream_id_test.go | 38 ++++- rx/mono/mono_test.go | 38 +++++ 27 files changed, 612 insertions(+), 118 deletions(-) rename internal/common/{misc.go => common.go} (100%) create mode 100644 internal/common/errors_test.go create mode 100644 internal/common/rand_test.go create mode 100644 internal/common/u32map.go create mode 100644 internal/common/u32map_test.go create mode 100644 internal/socket/misc_test.go diff --git a/fuzz.go b/fuzz.go index 9e70722..f7ad744 100644 --- a/fuzz.go +++ b/fuzz.go @@ -6,7 +6,6 @@ package rsocket import ( "bytes" "errors" - "fmt" "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/internal/framing" @@ -20,12 +19,16 @@ func Fuzz(data []byte) int { if err == nil { err = handleRaw(raw) } - if err == nil || err == common.ErrInvalidFrame || err == transport.ErrIncompleteHeader { + if err == nil || isExpectedError(err) { return 0 } return 1 } +func isExpectedError(err error) bool { + return err == common.ErrInvalidFrame || err == transport.ErrIncompleteHeader +} + func handleRaw(raw []byte) (err error) { h := framing.ParseFrameHeader(raw) bf := common.NewByteBuff() @@ -38,21 +41,9 @@ func handleRaw(raw []byte) (err error) { if err != nil { return } - - switch f := frame.(type) { - case fmt.Stringer: - s := f.String() - if len(s) > 0 { - return - } - case error: - e := f.Error() - if len(e) > 0 { - return - } - default: - panic("unreachable") + if frame.Len() >= framing.HeaderLen { + return } - - return errors.New("???") + err = errors.New("broken frame") + return } diff --git a/internal/common/misc.go b/internal/common/common.go similarity index 100% rename from internal/common/misc.go rename to internal/common/common.go diff --git a/internal/common/errors_test.go b/internal/common/errors_test.go new file mode 100644 index 0000000..a29ea86 --- /dev/null +++ b/internal/common/errors_test.go @@ -0,0 +1,28 @@ +package common_test + +import ( + "math" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestErrorCode_String(t *testing.T) { + all := []common.ErrorCode{ + common.ErrorCodeInvalidSetup, + common.ErrorCodeUnsupportedSetup, + common.ErrorCodeRejectedSetup, + common.ErrorCodeRejectedResume, + common.ErrorCodeConnectionError, + common.ErrorCodeConnectionClose, + common.ErrorCodeApplicationError, + common.ErrorCodeRejected, + common.ErrorCodeCanceled, + common.ErrorCodeInvalid, + } + for _, code := range all { + assert.NotEqual(t, "UNKNOWN", code.String()) + } + assert.Equal(t, "UNKNOWN", common.ErrorCode(math.MaxUint32).String()) +} diff --git a/internal/common/rand_test.go b/internal/common/rand_test.go new file mode 100644 index 0000000..402ea60 --- /dev/null +++ b/internal/common/rand_test.go @@ -0,0 +1,35 @@ +package common_test + +import ( + "regexp" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestRandAlphanumeric(t *testing.T) { + s := common.RandAlphanumeric(10) + r := regexp.MustCompile("^[a-zA-Z0-9]{10}$") + assert.True(t, r.MatchString(s)) + s = common.RandAlphanumeric(0) + assert.Empty(t, s) +} + +func TestRandAlphabetic(t *testing.T) { + s := common.RandAlphabetic(10) + r := regexp.MustCompile("^[a-zA-Z]{10}$") + assert.True(t, r.MatchString(s)) + s = common.RandAlphabetic(0) + assert.Empty(t, s) +} + +func TestRandFloat64(t *testing.T) { + f := common.RandFloat64() + assert.True(t, f < 1 && f > 0) +} + +func TestRandIntn(t *testing.T) { + n := common.RandIntn(10) + assert.True(t, n >= 0 && n < 10) +} diff --git a/internal/common/u32map.go b/internal/common/u32map.go new file mode 100644 index 0000000..b35bd75 --- /dev/null +++ b/internal/common/u32map.go @@ -0,0 +1,120 @@ +package common + +import ( + "sync" +) + +const _slots = 2 * 2 * 2 * 2 + +type U32Map interface { + Clear() + Range(fn func(k uint32, v interface{}) bool) + Load(key uint32) (v interface{}, ok bool) + Store(key uint32, value interface{}) + Delete(key uint32) +} + +type u32map struct { + slots [_slots]*u32slot +} + +func (u *u32map) Clear() { + for _, slot := range u.slots { + slot.Clear() + } +} + +func (u *u32map) Range(fn func(k uint32, v interface{}) bool) { + for _, slot := range u.slots { + if !slot.innerRange(fn) { + return + } + } +} + +func (u *u32map) Load(key uint32) (v interface{}, ok bool) { + return u.seek(key).Load(key) +} + +func (u *u32map) Store(key uint32, value interface{}) { + u.seek(key).Store(key, value) +} + +func (u *u32map) Delete(key uint32) { + u.seek(key).Delete(key) +} + +func (u *u32map) seek(key uint32) *u32slot { + k := key & (_slots - 1) + return u.slots[k] +} + +type u32slot struct { + k sync.RWMutex + m map[uint32]interface{} +} + +func (u *u32slot) Clear() { + if u == nil || u.m == nil { + return + } + u.k.Lock() + u.m = nil + u.k.Unlock() +} + +func (u *u32slot) Range(fn func(k uint32, v interface{}) bool) { + u.innerRange(fn) +} + +func (u *u32slot) Load(key uint32) (v interface{}, ok bool) { + if u == nil || u.m == nil { + return + } + u.k.RLock() + v, ok = u.m[key] + u.k.RUnlock() + return +} + +func (u *u32slot) Store(key uint32, value interface{}) { + if u == nil || u.m == nil { + return + } + u.k.Lock() + u.m[key] = value + u.k.Unlock() +} + +func (u *u32slot) Delete(key uint32) { + if u == nil || u.m == nil { + return + } + u.k.Lock() + delete(u.m, key) + u.k.Unlock() +} + +func (u *u32slot) innerRange(fn func(k uint32, v interface{}) bool) bool { + if u == nil || u.m == nil { + return false + } + u.k.RLock() + defer u.k.RUnlock() + for key, value := range u.m { + if !fn(key, value) { + return false + } + } + return true +} + +func NewU32Map() U32Map { + var slots [_slots]*u32slot + for i := 0; i < len(slots); i++ { + slots[i] = &u32slot{ + m: make(map[uint32]interface{}), + } + } + return &u32map{slots: slots} +} diff --git a/internal/common/u32map_test.go b/internal/common/u32map_test.go new file mode 100644 index 0000000..a49a458 --- /dev/null +++ b/internal/common/u32map_test.go @@ -0,0 +1,63 @@ +package common_test + +import ( + "sort" + "sync/atomic" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestU32map(t *testing.T) { + var keys []int + value := common.RandAlphanumeric(10) + m := common.NewU32Map() + for i := uint32(0); i < 10; i++ { + m.Store(i, value) + keys = append(keys, int(i)) + } + v, ok := m.Load(1) + assert.True(t, ok, "key not found") + assert.Equal(t, value, v, "value doesn't match") + + _, ok = m.Load(10) + assert.False(t, ok, "key should not exist") + + var keys2 []int + m.Range(func(k uint32, _ interface{}) bool { + keys2 = append(keys2, int(k)) + return true + }) + sort.Ints(keys) + sort.Ints(keys2) + assert.Equal(t, keys, keys2, "keys doesn't match") + + m.Delete(1) + _, ok = m.Load(1) + assert.False(t, ok, "key should be deleted") + + var c int + m.Range(func(k uint32, v interface{}) bool { + c++ + return false + }) + assert.Equal(t, 1, c, "should be 1") + + m.Clear() + _, ok = m.Load(2) + assert.False(t, ok, "should be closed already") +} + +func BenchmarkU32Map(b *testing.B) { + const value = "foobar" + m := common.NewU32Map() + next := uint32(0) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + m.Store(atomic.AddUint32(&next, 1)-1, value) + } + }) +} diff --git a/internal/common/version_test.go b/internal/common/version_test.go index f9ad917..9bde3ef 100644 --- a/internal/common/version_test.go +++ b/internal/common/version_test.go @@ -13,7 +13,7 @@ func BenchmarkVersion_String(b *testing.B) { v := common.NewVersion(2, 3) b.ResetTimer() for i := 0; i < b.N; i++ { - v.String() + _ = v.String() } } diff --git a/internal/framing/frame_cancel.go b/internal/framing/frame_cancel.go index 1849440..6d2eb0c 100644 --- a/internal/framing/frame_cancel.go +++ b/internal/framing/frame_cancel.go @@ -13,6 +13,10 @@ type FrameCancel struct { // Validate returns error if frame is invalid. func (p *FrameCancel) Validate() (err error) { + // Cancel frame doesn't need any binary body. + if p.body != nil && p.body.Len() > 0 { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_error.go b/internal/framing/frame_error.go index 19812bc..494e838 100644 --- a/internal/framing/frame_error.go +++ b/internal/framing/frame_error.go @@ -3,6 +3,7 @@ package framing import ( "encoding/binary" "fmt" + "strings" "github.com/rsocket/rsocket-go/internal/common" ) @@ -24,14 +25,19 @@ func (p *FrameError) String() string { // Validate returns error if frame is invalid. func (p *FrameError) Validate() (err error) { - if p.Len() < minErrorFrameLen { + if p.body.Len() < minErrorFrameLen { err = errIncompleteFrame } return } func (p *FrameError) Error() string { - return fmt.Sprintf("%s: %s", p.ErrorCode(), string(p.ErrorData())) + bu := strings.Builder{} + bu.WriteString(p.ErrorCode().String()) + bu.WriteByte(':') + bu.WriteByte(' ') + bu.Write(p.ErrorData()) + return bu.String() } // ErrorCode returns error code. diff --git a/internal/framing/frame_fnf.go b/internal/framing/frame_fnf.go index 9b68584..fad43ad 100644 --- a/internal/framing/frame_fnf.go +++ b/internal/framing/frame_fnf.go @@ -13,6 +13,9 @@ type FrameFNF struct { // Validate returns error if frame is invalid. func (p *FrameFNF) Validate() (err error) { + if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_lease.go b/internal/framing/frame_lease.go index 9b1c240..9dd8d54 100644 --- a/internal/framing/frame_lease.go +++ b/internal/framing/frame_lease.go @@ -63,6 +63,9 @@ func NewFrameLease(ttl time.Duration, n uint32, metadata []byte) *FrameLease { var fg FrameFlag if len(metadata) > 0 { fg |= FlagMetadata + if _, err := bf.Write(metadata); err != nil { + panic(err) + } } return &FrameLease{NewBaseFrame(NewFrameHeader(0, FrameTypeLease, fg), bf)} } diff --git a/internal/framing/frame_payload.go b/internal/framing/frame_payload.go index 531b0c9..11ab00d 100644 --- a/internal/framing/frame_payload.go +++ b/internal/framing/frame_payload.go @@ -13,6 +13,10 @@ type FramePayload struct { // Validate returns error if frame is invalid. func (p *FramePayload) Validate() (err error) { + // Minimal length should be 3 if metadata exists. + if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_request_channel.go b/internal/framing/frame_request_channel.go index 9e25cea..7a8449b 100644 --- a/internal/framing/frame_request_channel.go +++ b/internal/framing/frame_request_channel.go @@ -18,11 +18,15 @@ type FrameRequestChannel struct { } // Validate returns error if frame is invalid. -func (p *FrameRequestChannel) Validate() (err error) { - if p.body.Len() < minRequestChannelFrameLen { - err = errIncompleteFrame +func (p *FrameRequestChannel) Validate() error { + l := p.body.Len() + if l < minRequestChannelFrameLen { + return errIncompleteFrame } - return + if p.header.Flag().Check(FlagMetadata) && l < minRequestChannelFrameLen+3 { + return errIncompleteFrame + } + return nil } func (p *FrameRequestChannel) String() string { diff --git a/internal/framing/frame_request_n.go b/internal/framing/frame_request_n.go index 819c2b8..713f95b 100644 --- a/internal/framing/frame_request_n.go +++ b/internal/framing/frame_request_n.go @@ -7,11 +7,6 @@ import ( "github.com/rsocket/rsocket-go/internal/common" ) -const ( - reqNLen = 4 - minRequestNFrameLen = reqNLen -) - // FrameRequestN is RequestN frame. type FrameRequestN struct { *BaseFrame @@ -19,7 +14,7 @@ type FrameRequestN struct { // Validate returns error if frame is invalid. func (p *FrameRequestN) Validate() (err error) { - if p.body.Len() < minRequestNFrameLen { + if p.body.Len() != 4 { err = errIncompleteFrame } return diff --git a/internal/framing/frame_request_response.go b/internal/framing/frame_request_response.go index fbfb508..1d1738f 100644 --- a/internal/framing/frame_request_response.go +++ b/internal/framing/frame_request_response.go @@ -13,6 +13,9 @@ type FrameRequestResponse struct { // Validate returns error if frame is invalid. func (p *FrameRequestResponse) Validate() (err error) { + if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_request_stream.go b/internal/framing/frame_request_stream.go index a274d81..3a8d8bb 100644 --- a/internal/framing/frame_request_stream.go +++ b/internal/framing/frame_request_stream.go @@ -17,11 +17,15 @@ type FrameRequestStream struct { } // Validate returns error if frame is invalid. -func (p *FrameRequestStream) Validate() (err error) { - if p.body.Len() < minRequestStreamFrameLen { - err = errIncompleteFrame +func (p *FrameRequestStream) Validate() error { + l := p.body.Len() + if l < minRequestStreamFrameLen { + return errIncompleteFrame } - return + if p.header.Flag().Check(FlagMetadata) && l < minRequestStreamFrameLen+3 { + return errIncompleteFrame + } + return nil } func (p *FrameRequestStream) String() string { diff --git a/internal/framing/frame_resume.go b/internal/framing/frame_resume.go index 1b2fe77..766fcb0 100644 --- a/internal/framing/frame_resume.go +++ b/internal/framing/frame_resume.go @@ -11,6 +11,14 @@ import ( var errResumeTokenTooLarge = errors.New("max length of resume token is 65535") +const ( + _lenVersion = 4 + _lenTokenLength = 2 + _lenLastRecvPos = 8 + _lenFirstPos = 8 + _minResumeLength = _lenVersion + _lenTokenLength + _lenLastRecvPos + _lenFirstPos +) + // FrameResume represents a frame of Resume. type FrameResume struct { *BaseFrame @@ -25,6 +33,9 @@ func (p *FrameResume) String() string { // Validate validate current frame. func (p *FrameResume) Validate() (err error) { + if p.body.Len() < _minResumeLength { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_resume_ok.go b/internal/framing/frame_resume_ok.go index 32125d6..b1bb845 100644 --- a/internal/framing/frame_resume_ok.go +++ b/internal/framing/frame_resume_ok.go @@ -18,6 +18,10 @@ func (p *FrameResumeOK) String() string { // Validate validate current frame. func (p *FrameResumeOK) Validate() (err error) { + // Length of frame body should be 8 + if p.body.Len() != 8 { + err = errIncompleteFrame + } return } diff --git a/internal/framing/frame_setup.go b/internal/framing/frame_setup.go index bcac659..6244c09 100644 --- a/internal/framing/frame_setup.go +++ b/internal/framing/frame_setup.go @@ -9,12 +9,12 @@ import ( ) const ( - versionLen = 4 - timeLen = 4 - tokenLen = 2 - metadataLen = 1 - dataLen = 1 - minSetupFrameLen = versionLen + timeLen*2 + tokenLen + metadataLen + dataLen + _versionLen = 4 + _timeLen = 4 + _tokenLen = 2 + _metadataLen = 1 + _dataLen = 1 + _minSetupFrameLen = _versionLen + _timeLen*2 + _tokenLen + _metadataLen + _dataLen ) // FrameSetup is sent by client to initiate protocol processing. @@ -24,7 +24,7 @@ type FrameSetup struct { // Validate returns error if frame is invalid. func (p *FrameSetup) Validate() (err error) { - if p.Len() < minSetupFrameLen { + if p.Len() < _minSetupFrameLen { err = errIncompleteFrame } return diff --git a/internal/framing/frame_test.go b/internal/framing/frame_test.go index c5806ac..d0af3f4 100644 --- a/internal/framing/frame_test.go +++ b/internal/framing/frame_test.go @@ -1,15 +1,182 @@ -package framing +package framing_test import ( "encoding/hex" "log" + "math" "testing" "time" "github.com/rsocket/rsocket-go/internal/common" + . "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) +const _sid uint32 = 1 + +func TestFrameCancel(t *testing.T) { + f := NewFrameCancel(_sid) + basicCheck(t, f, FrameTypeCancel) +} + +func TestFrameError(t *testing.T) { + errData := []byte(common.RandAlphanumeric(100)) + f := NewFrameError(_sid, common.ErrorCodeApplicationError, errData) + basicCheck(t, f, FrameTypeError) + assert.Equal(t, common.ErrorCodeApplicationError, f.ErrorCode()) + assert.Equal(t, errData, f.ErrorData()) + assert.NotEmpty(t, f.Error()) +} + +func TestFrameFNF(t *testing.T) { + b := []byte(common.RandAlphanumeric(100)) + // Without Metadata + f := NewFrameFNF(_sid, b, nil, FlagNext) + basicCheck(t, f, FrameTypeRequestFNF) + assert.Equal(t, b, f.Data()) + metadata, ok := f.Metadata() + assert.False(t, ok) + assert.Nil(t, metadata) + assert.True(t, f.Header().Flag().Check(FlagNext)) + assert.False(t, f.Header().Flag().Check(FlagMetadata)) + // With Metadata + f = NewFrameFNF(_sid, nil, b, FlagNext) + basicCheck(t, f, FrameTypeRequestFNF) + assert.Empty(t, f.Data()) + metadata, ok = f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, metadata) + assert.True(t, f.Header().Flag().Check(FlagNext)) + assert.True(t, f.Header().Flag().Check(FlagMetadata)) +} + +func TestFrameKeepalive(t *testing.T) { + pos := uint64(common.RandIntn(math.MaxInt32)) + d := []byte(common.RandAlphanumeric(100)) + f := NewFrameKeepalive(pos, d, true) + basicCheck(t, f, FrameTypeKeepalive) + assert.Equal(t, d, f.Data()) + assert.Equal(t, pos, f.LastReceivedPosition()) + assert.True(t, f.Header().Flag().Check(FlagRespond)) +} + +func TestFrameLease(t *testing.T) { + metadata := []byte("foobar") + n := uint32(4444) + f := NewFrameLease(time.Second, n, metadata) + basicCheck(t, f, FrameTypeLease) + assert.Equal(t, time.Second, f.TimeToLive()) + assert.Equal(t, n, f.NumberOfRequests()) + assert.Equal(t, metadata, f.Metadata()) +} + +func TestFrameMetadataPush(t *testing.T) { + metadata := []byte("foobar") + f := NewFrameMetadataPush(metadata) + basicCheck(t, f, FrameTypeMetadataPush) + metadata2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, metadata, metadata2) +} + +func TestFramePayload(t *testing.T) { + b := []byte("foobar") + f := NewFramePayload(_sid, b, b, FlagNext) + basicCheck(t, f, FrameTypePayload) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, f.Data()) + assert.Equal(t, b, m) + assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) +} + +func TestFrameRequestChannel(t *testing.T) { + b := []byte("foobar") + n := uint32(1) + f := NewFrameRequestChannel(_sid, n, b, b, FlagNext) + basicCheck(t, f, FrameTypeRequestChannel) + assert.Equal(t, n, f.InitialRequestN()) + assert.Equal(t, b, f.Data()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) +} + +func TestFrameRequestN(t *testing.T) { + n := uint32(1234) + f := NewFrameRequestN(_sid, n) + basicCheck(t, f, FrameTypeRequestN) + assert.Equal(t, n, f.N()) +} + +func TestFrameRequestResponse(t *testing.T) { + b := []byte("foobar") + f := NewFrameRequestResponse(_sid, b, b, FlagNext) + basicCheck(t, f, FrameTypeRequestResponse) + assert.Equal(t, b, f.Data()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) + assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) +} + +func TestFrameRequestStream(t *testing.T) { + b := []byte("foobar") + n := uint32(1234) + f := NewFrameRequestStream(_sid, n, b, b, FlagNext) + basicCheck(t, f, FrameTypeRequestStream) + assert.Equal(t, b, f.Data()) + assert.Equal(t, n, f.InitialRequestN()) + m, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m) +} + +func TestFrameResume(t *testing.T) { + v := common.NewVersion(3, 1) + token := []byte("hello") + p1 := uint64(333) + p2 := uint64(444) + f := NewFrameResume(v, token, p1, p2) + basicCheck(t, f, FrameTypeResume) + assert.Equal(t, token, f.Token()) + assert.Equal(t, p1, f.FirstAvailableClientPosition()) + assert.Equal(t, p2, f.LastReceivedServerPosition()) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) +} + +func TestFrameResumeOK(t *testing.T) { + pos := uint64(1234) + f := NewResumeOK(pos) + basicCheck(t, f, FrameTypeResumeOK) + assert.Equal(t, pos, f.LastReceivedClientPosition()) +} + +func TestFrameSetup(t *testing.T) { + v := common.NewVersion(3, 1) + timeKeepalive := 30 * time.Second + maxLifetime := 3 * timeKeepalive + token := []byte("hello") + mimeData := []byte("application/json") + mimeMetadata := []byte("text/plain") + d := []byte(`{"hello":"world"}`) + m := []byte("foobar") + f := NewFrameSetup(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + basicCheck(t, f, FrameTypeSetup) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) + assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) + assert.Equal(t, maxLifetime, f.MaxLifetime()) + assert.Equal(t, token, f.Token()) + assert.Equal(t, string(mimeData), f.DataMimeType()) + assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) + assert.Equal(t, d, f.Data()) + m2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, m, m2) +} + func TestDecode_Payload(t *testing.T) { //s := "000000012940000005776f726c6468656c6c6f" // go //s := "00000001296000000966726f6d5f6a617661706f6e67" //java @@ -37,3 +204,15 @@ func TestDecode_Payload(t *testing.T) { log.Println("actual:", hex.EncodeToString(lease.Bytes())) log.Println("should: 00000000090000000bb800000005") } + +func basicCheck(t *testing.T, f Frame, typ FrameType) { + sid := _sid + switch typ { + case FrameTypeKeepalive, FrameTypeSetup, FrameTypeLease, FrameTypeResume, FrameTypeResumeOK, FrameTypeMetadataPush: + sid = 0 + } + assert.Equal(t, sid, f.Header().StreamID(), "wrong frame stream id") + assert.NoError(t, f.Validate(), "validate frame type failed") + assert.Equal(t, typ, f.Header().Type(), "frame type doesn't match") + assert.NotEmpty(t, f.String(), "empty frame string") +} diff --git a/internal/framing/header.go b/internal/framing/header.go index 925f0e1..2b9a519 100644 --- a/internal/framing/header.go +++ b/internal/framing/header.go @@ -2,8 +2,9 @@ package framing import ( "encoding/binary" - "fmt" "io" + "strconv" + "strings" ) const ( @@ -17,7 +18,15 @@ const ( type FrameHeader [HeaderLen]byte func (p FrameHeader) String() string { - return fmt.Sprintf("FrameHeader{id=%d,type=%s,flag=%s}", p.StreamID(), p.Type(), p.Flag()) + bu := strings.Builder{} + bu.WriteString("FrameHeader{id=") + bu.WriteString(strconv.FormatUint(uint64(p.StreamID()), 10)) + bu.WriteString(",type=") + bu.WriteString(p.Type().String()) + bu.WriteString(",flag=") + bu.WriteString(p.Flag().String()) + bu.WriteByte('}') + return bu.String() } // WriteTo writes frame header to a writer. diff --git a/internal/framing/header_test.go b/internal/framing/header_test.go index 44ee867..6032965 100644 --- a/internal/framing/header_test.go +++ b/internal/framing/header_test.go @@ -1,20 +1,27 @@ -package framing +package framing_test import ( - "fmt" + "bytes" + "math" "testing" + "github.com/rsocket/rsocket-go/internal/common" + . "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) func TestHeader_All(t *testing.T) { - h1 := NewFrameHeader(134, FrameTypePayload, FlagMetadata|FlagComplete|FlagNext) - h := ParseFrameHeader(h1[:]) - assert.Equal(t, h1.StreamID(), h.StreamID()) - assert.Equal(t, h1.Type(), h.Type()) - assert.Equal(t, h1.Flag(), h.Flag()) - - fmt.Println("streamID:", h.StreamID()) - fmt.Println("type:", h.Type()) - fmt.Println("flag:", h.Flag()) + id := uint32(common.RandIntn(math.MaxUint32)) + h1 := NewFrameHeader(id, FrameTypePayload, FlagMetadata|FlagComplete|FlagNext) + assert.NotEmpty(t, h1.String()) + h2 := ParseFrameHeader(h1[:]) + assert.Equal(t, h1.StreamID(), h2.StreamID()) + assert.Equal(t, h1.Type(), h2.Type()) + assert.Equal(t, h1.Flag(), h2.Flag()) + assert.Equal(t, FrameTypePayload, h1.Type()) + assert.Equal(t, FlagMetadata|FlagComplete|FlagNext, h1.Flag()) + bf := &bytes.Buffer{} + n, err := h2.WriteTo(bf) + assert.NoError(t, err) + assert.Equal(t, int64(HeaderLen), n) } diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index d9ca4d9..60a9ebc 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -44,10 +44,10 @@ type DuplexRSocket struct { outs chan framing.Frame outsPriority []framing.Frame responder Responder - messages *u32map + messages common.U32Map sids StreamID mtu int - fragments *u32map // key=streamID, value=Joiner + fragments common.U32Map // key=streamID, value=Joiner closed *atomic.Bool done chan struct{} keepaliver *keepaliver @@ -91,7 +91,7 @@ func (p *DuplexRSocket) Close() error { p.cond.Broadcast() p.cond.L.Unlock() - _ = p.fragments.Close() + p.fragments.Clear() <-p.done if p.tp != nil { @@ -105,7 +105,7 @@ func (p *DuplexRSocket) Close() error { p.fragments.Range(func(key uint32, value interface{}) bool { return true }) - _ = p.fragments.Close() + p.fragments.Clear() p.messages.Range(func(key uint32, value interface{}) bool { if cc, ok := value.(closerWithError); ok { @@ -121,7 +121,7 @@ func (p *DuplexRSocket) Close() error { } return true }) - _ = p.messages.Close() + p.messages.Clear() return p.e } @@ -1116,9 +1116,9 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { leases: leases, outs: make(chan framing.Frame, outsSize), mtu: mtu, - messages: newU32Map(), + messages: common.NewU32Map(), sids: &serverStreamIDs{}, - fragments: newU32Map(), + fragments: common.NewU32Map(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: transport.NewCounter(), @@ -1136,9 +1136,9 @@ func NewClientDuplexRSocket( closed: atomic.NewBool(false), outs: make(chan framing.Frame, outsSize), mtu: mtu, - messages: newU32Map(), + messages: common.NewU32Map(), sids: &clientStreamIDs{}, - fragments: newU32Map(), + fragments: common.NewU32Map(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: transport.NewCounter(), diff --git a/internal/socket/misc.go b/internal/socket/misc.go index 1d634ef..6bcdb8a 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -1,7 +1,6 @@ package socket import ( - "sync" "time" "github.com/pkg/errors" @@ -10,55 +9,6 @@ import ( "github.com/rsocket/rsocket-go/rx" ) -type u32map struct { - k sync.RWMutex - m map[uint32]interface{} -} - -func (p *u32map) Close() error { - p.k.Lock() - p.m = nil - p.k.Unlock() - return nil -} - -func (p *u32map) Range(fn func(uint32, interface{}) bool) { - p.k.RLock() - for key, value := range p.m { - if !fn(key, value) { - break - } - } - p.k.RUnlock() -} - -func (p *u32map) Load(key uint32) (v interface{}, ok bool) { - p.k.RLock() - v, ok = p.m[key] - p.k.RUnlock() - return -} - -func (p *u32map) Store(key uint32, value interface{}) { - p.k.Lock() - if p.m != nil { - p.m[key] = value - } - p.k.Unlock() -} - -func (p *u32map) Delete(key uint32) { - p.k.Lock() - delete(p.m, key) - p.k.Unlock() -} - -func newU32Map() *u32map { - return &u32map{ - m: make(map[uint32]interface{}), - } -} - // SetupInfo represents basic info of setup. type SetupInfo struct { Lease bool diff --git a/internal/socket/misc_test.go b/internal/socket/misc_test.go new file mode 100644 index 0000000..69513dc --- /dev/null +++ b/internal/socket/misc_test.go @@ -0,0 +1 @@ +package socket diff --git a/internal/socket/stream_id_test.go b/internal/socket/stream_id_test.go index 7f44e08..75f455a 100644 --- a/internal/socket/stream_id_test.go +++ b/internal/socket/stream_id_test.go @@ -1,7 +1,39 @@ package socket -import "testing" +import ( + "testing" -func TestSt(t *testing.T) { - + "github.com/stretchr/testify/assert" +) + +func TestClientStreamIDs_Next(t *testing.T) { + ids := clientStreamIDs{} + id, firstLap := ids.Next() + assert.Equal(t, uint32(1), id) + assert.True(t, firstLap) +} + +func TestServerStreamIDs_Next(t *testing.T) { + ids := serverStreamIDs{} + id, firstLap := ids.Next() + assert.Equal(t, uint32(2), id) + assert.True(t, firstLap) +} + +func BenchmarkServerStreamIDs_Next(b *testing.B) { + ids := serverStreamIDs{} + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ids.Next() + } + }) +} + +func BenchmarkClientStreamIDs_Next(b *testing.B) { + ids := clientStreamIDs{} + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + ids.Next() + } + }) } diff --git a/rx/mono/mono_test.go b/rx/mono/mono_test.go index f6a79ff..008d383 100644 --- a/rx/mono/mono_test.go +++ b/rx/mono/mono_test.go @@ -13,8 +13,40 @@ import ( "github.com/rsocket/rsocket-go/rx" . "github.com/rsocket/rsocket-go/rx/mono" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func TestProxy_Error(t *testing.T) { + originErr := errors.New("error testing") + errCount := atomic.NewInt32(0) + _, err := Error(originErr). + DoOnError(func(e error) { + assert.Equal(t, originErr, e, "bad error") + errCount.Inc() + }). + Block(context.Background()) + assert.Error(t, err, "should got error") + assert.Equal(t, originErr, err, "bad blocked error") + assert.Equal(t, int32(1), errCount.Load(), "error count should be 1") +} + +func TestEmpty(t *testing.T) { + res, err := Empty().Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.Nil(t, res, "result should be nil") +} + +func TestJustOrEmpty(t *testing.T) { + // Give normal payload + res, err := JustOrEmpty(payload.NewString("hello", "world")).Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.NotNil(t, res, "result should not be nil") + // Give nil payload + res, err = JustOrEmpty(nil).Block(context.Background()) + assert.NoError(t, err, "an error occurred") + assert.Nil(t, res, "result should be nil") +} + func TestJust(t *testing.T) { Just(payload.NewString("hello", "world")). Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) { @@ -22,6 +54,12 @@ func TestJust(t *testing.T) { })) } +func TestMono_Raw(t *testing.T) { + + Just(payload.NewString("hello", "world")).Raw() + +} + func TestProxy_SubscribeOn(t *testing.T) { v, err := Create(func(i context.Context, sink Sink) { time.AfterFunc(time.Second, func() { From d92b1b4b600739ec099a921bc61ce156347f046c Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Fri, 19 Jun 2020 23:28:47 +0800 Subject: [PATCH 04/26] Tuning for reducing memory cost. --- fuzz.go | 2 +- internal/fragmentation/fragmentation.go | 4 +- internal/fragmentation/joiner.go | 2 +- internal/fragmentation/joiner_test.go | 10 +- internal/fragmentation/splitter_test.go | 8 +- internal/framing/frame.go | 144 +++++++-------- internal/framing/frame_cancel.go | 47 +++-- internal/framing/frame_error.go | 82 ++++++--- internal/framing/frame_fnf.go | 83 ++++++--- internal/framing/frame_keepalive.go | 81 +++++++-- internal/framing/frame_lease.go | 110 +++++++++--- internal/framing/frame_metadata_push.go | 70 +++++--- internal/framing/frame_payload.go | 80 ++++++--- internal/framing/frame_payload_test.go | 34 ---- internal/framing/frame_request_channel.go | 102 ++++++++--- internal/framing/frame_request_n.go | 60 +++++-- internal/framing/frame_request_response.go | 78 +++++--- internal/framing/frame_request_stream.go | 105 +++++++---- internal/framing/frame_resume.go | 116 +++++++++--- internal/framing/frame_resume_ok.go | 60 +++++-- internal/framing/frame_setup.go | 197 ++++++++++++++++----- internal/framing/frame_setup_test.go | 2 +- internal/framing/frame_test.go | 120 +++++++++---- internal/framing/header.go | 53 +++--- internal/framing/misc.go | 63 +++++-- internal/socket/client_default.go | 4 +- internal/socket/client_resume.go | 8 +- internal/socket/duplex.go | 141 +++++++-------- internal/socket/misc.go | 4 +- internal/transport/connection.go | 2 +- internal/transport/connection_tcp.go | 12 +- internal/transport/connection_ws.go | 39 ++-- internal/transport/decoder_test.go | 2 +- internal/transport/transport.go | 16 +- server.go | 37 ++-- 35 files changed, 1332 insertions(+), 646 deletions(-) delete mode 100644 internal/framing/frame_payload_test.go diff --git a/fuzz.go b/fuzz.go index f7ad744..a62f23c 100644 --- a/fuzz.go +++ b/fuzz.go @@ -33,7 +33,7 @@ func handleRaw(raw []byte) (err error) { h := framing.ParseFrameHeader(raw) bf := common.NewByteBuff() var frame framing.Frame - frame, err = framing.NewFromBase(framing.NewBaseFrame(h, bf)) + frame, err = framing.FromRawFrame(framing.NewRawFrame(h, bf)) if err != nil { return } diff --git a/internal/fragmentation/fragmentation.go b/internal/fragmentation/fragmentation.go index 058a1b7..bcd88dd 100644 --- a/internal/fragmentation/fragmentation.go +++ b/internal/fragmentation/fragmentation.go @@ -18,11 +18,11 @@ const ( var errInvalidFragmentLen = fmt.Errorf("invalid fragment: [%d,%d]", MinFragment, MaxFragment) -// HeaderAndPayload is Payload which having a FrameHeader. +// HeaderAndPayload is Payload which having a Header. type HeaderAndPayload interface { payload.Payload // Header returns a header of frame. - Header() framing.FrameHeader + Header() framing.Header } // Joiner is used to join frames to a payload. diff --git a/internal/fragmentation/joiner.go b/internal/fragmentation/joiner.go index 92cf778..cd208e1 100644 --- a/internal/fragmentation/joiner.go +++ b/internal/fragmentation/joiner.go @@ -22,7 +22,7 @@ func (p *implJoiner) First() framing.Frame { return first.Value.(framing.Frame) } -func (p *implJoiner) Header() framing.FrameHeader { +func (p *implJoiner) Header() framing.Header { return p.First().Header() } diff --git a/internal/fragmentation/joiner_test.go b/internal/fragmentation/joiner_test.go index ffb121f..005217f 100644 --- a/internal/fragmentation/joiner_test.go +++ b/internal/fragmentation/joiner_test.go @@ -11,17 +11,17 @@ import ( func TestFragmentPayload(t *testing.T) { const totals = 10 const sid = uint32(1) - fr := NewJoiner(framing.NewFramePayload(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow, framing.FlagMetadata)) + fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow, framing.FlagMetadata)) for i := 0; i < totals; i++ { data := fmt.Sprintf("(data%04d)", i) - var frame *framing.FramePayload + var frame *framing.PayloadFrame if i < 3 { meta := fmt.Sprintf("(meta%04d)", i) - frame = framing.NewFramePayload(sid, []byte(data), []byte(meta), framing.FlagFollow, framing.FlagMetadata) + frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), framing.FlagFollow, framing.FlagMetadata) } else if i != totals-1 { - frame = framing.NewFramePayload(sid, []byte(data), nil, framing.FlagFollow) + frame = framing.NewPayloadFrame(sid, []byte(data), nil, framing.FlagFollow) } else { - frame = framing.NewFramePayload(sid, []byte(data), nil) + frame = framing.NewPayloadFrame(sid, []byte(data), nil) } fr.Push(frame) } diff --git a/internal/fragmentation/splitter_test.go b/internal/fragmentation/splitter_test.go index e45725b..22f2ff1 100644 --- a/internal/fragmentation/splitter_test.go +++ b/internal/fragmentation/splitter_test.go @@ -26,13 +26,13 @@ func split2joiner(mtu int, data, metadata []byte) (joiner Joiner, err error) { fn := func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { if idx == 0 { h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, framing.FlagComplete|fg) - joiner = NewJoiner(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + joiner = NewJoiner(&framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), }) } else { h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, fg) - joiner.Push(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + joiner.Push(&framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), }) } } diff --git a/internal/framing/frame.go b/internal/framing/frame.go index 3cb297e..7079eb0 100644 --- a/internal/framing/frame.go +++ b/internal/framing/frame.go @@ -2,7 +2,6 @@ package framing import ( "errors" - "fmt" "io" "strings" @@ -122,99 +121,79 @@ func newFlags(flags ...FrameFlag) FrameFlag { return fg } -// Frame is a single message containing a request, response, or protocol processing. -type Frame interface { - fmt.Stringer +type FrameSupport interface { io.WriterTo - // Header returns frame FrameHeader. - Header() FrameHeader - // Body returns body of frame. - Body() *common.ByteBuff + // Header returns frame Header. + Header() Header // Len returns length of frame. Len() int - // Validate returns error if frame is invalid. - Validate() error - // SetHeader set frame header. - SetHeader(h FrameHeader) - // SetBody set frame body. - SetBody(body *common.ByteBuff) - // Bytes encodes and returns frame in bytes. - Bytes() []byte - // CanResume returns true if frame supports resume. - CanResume() bool // Done marks current frame has been sent. Done() (closed bool) // DoneNotify notifies when frame done. DoneNotify() <-chan struct{} } -// BaseFrame is basic frame implementation. -type BaseFrame struct { - header FrameHeader - body *common.ByteBuff +func PrintFrame(f FrameSupport) string { + return "// TODO: print frame" +} + +// Frame is a single message containing a request, response, or protocol processing. +type Frame interface { + FrameSupport + // Validate returns error if frame is invalid. + Validate() error +} + +type tinyFrame struct { + header Header done chan struct{} } +func (t *tinyFrame) Header() Header { + return t.header +} + // Done can be invoked when a frame has been been processed. -func (p *BaseFrame) Done() (closed bool) { +func (t *tinyFrame) Done() (closed bool) { defer func() { if e := recover(); e != nil { closed = true } }() - close(p.done) + close(t.done) return } // DoneNotify notify when frame has been done. -func (p *BaseFrame) DoneNotify() <-chan struct{} { - return p.done -} - -// CanResume returns true if frame supports resume. -func (p *BaseFrame) CanResume() bool { - switch p.header.Type() { - case FrameTypeRequestChannel, FrameTypeRequestStream, FrameTypeRequestResponse, FrameTypeRequestFNF, FrameTypeRequestN, FrameTypeCancel, FrameTypeError, FrameTypePayload: - return true - default: - return false - } +func (t *tinyFrame) DoneNotify() <-chan struct{} { + return t.done } -// SetBody set frame body. -func (p *BaseFrame) SetBody(body *common.ByteBuff) { - p.body = body +// RawFrame is basic frame implementation. +type RawFrame struct { + *tinyFrame + body *common.ByteBuff } // Body returns frame body. -func (p *BaseFrame) Body() *common.ByteBuff { - return p.body -} - -// Header returns frame header. -func (p *BaseFrame) Header() FrameHeader { - return p.header -} - -// SetHeader set frame header. -func (p *BaseFrame) SetHeader(h FrameHeader) { - p.header = h +func (f *RawFrame) Body() *common.ByteBuff { + return f.body } // Len returns length of frame. -func (p *BaseFrame) Len() int { - return HeaderLen + p.body.Len() +func (f *RawFrame) Len() int { + return HeaderLen + f.body.Len() } // WriteTo write frame to writer. -func (p *BaseFrame) WriteTo(w io.Writer) (n int64, err error) { +func (f *RawFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 - wrote, err = p.header.WriteTo(w) + wrote, err = f.header.WriteTo(w) if err != nil { return } n += wrote - wrote, err = p.body.WriteTo(w) + wrote, err = f.body.WriteTo(w) if err != nil { return } @@ -223,29 +202,19 @@ func (p *BaseFrame) WriteTo(w io.Writer) (n int64, err error) { } // Bytes returns frame in bytes. -func (p *BaseFrame) Bytes() []byte { - ret := make([]byte, HeaderLen+p.body.Len()) - copy(ret[:HeaderLen], p.header[:]) - copy(ret[HeaderLen:], p.body.Bytes()) +func (f *RawFrame) Bytes() []byte { + ret := make([]byte, HeaderLen+f.body.Len()) + copy(ret[:HeaderLen], f.header.Bytes()) + copy(ret[HeaderLen:], f.body.Bytes()) return ret } -// NewBaseFrame returns a new BaseFrame. -func NewBaseFrame(h FrameHeader, body *common.ByteBuff) (f *BaseFrame) { - f = &BaseFrame{ - header: h, - body: body, - done: make(chan struct{}), - } - return -} - -func (p *BaseFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { - raw := p.body.Bytes() +func (f *RawFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { + raw := f.body.Bytes() if offset > 0 { raw = raw[offset:] } - hasMetadata = p.header.Flag().Check(FlagMetadata) + hasMetadata = f.header.Flag().Check(FlagMetadata) if !hasMetadata { return } @@ -257,21 +226,36 @@ func (p *BaseFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { return } -func (p *BaseFrame) trySliceMetadata(offset int) ([]byte, bool) { - n, ok := p.trySeekMetadataLen(offset) +func (f *RawFrame) trySliceMetadata(offset int) ([]byte, bool) { + n, ok := f.trySeekMetadataLen(offset) if !ok || n < 0 { return nil, false } - return p.body.Bytes()[offset+3 : offset+3+n], true + return f.body.Bytes()[offset+3 : offset+3+n], true } -func (p *BaseFrame) trySliceData(offset int) []byte { - n, ok := p.trySeekMetadataLen(offset) +func (f *RawFrame) trySliceData(offset int) []byte { + n, ok := f.trySeekMetadataLen(offset) if !ok { - return p.body.Bytes()[offset:] + return f.body.Bytes()[offset:] } if n < 0 { return nil } - return p.body.Bytes()[offset+n+3:] + return f.body.Bytes()[offset+n+3:] +} + +func newTinyFrame(header Header) *tinyFrame { + return &tinyFrame{ + header: header, + done: make(chan struct{}), + } +} + +// NewRawFrame returns a new RawFrame. +func NewRawFrame(header Header, body *common.ByteBuff) *RawFrame { + return &RawFrame{ + tinyFrame: newTinyFrame(header), + body: body, + } } diff --git a/internal/framing/frame_cancel.go b/internal/framing/frame_cancel.go index 6d2eb0c..d431d6d 100644 --- a/internal/framing/frame_cancel.go +++ b/internal/framing/frame_cancel.go @@ -1,32 +1,51 @@ package framing import ( - "fmt" - - "github.com/rsocket/rsocket-go/internal/common" + "io" ) -// FrameCancel is frame of cancel. -type FrameCancel struct { - *BaseFrame +// CancelFrame is frame of cancel. +type CancelFrame struct { + *RawFrame +} + +type CancelFrameSupport struct { + *tinyFrame +} + +func (c CancelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = c.header.WriteTo(w) + if err != nil { + return + } + n += wrote + return +} + +func (c CancelFrameSupport) Len() int { + return HeaderLen } // Validate returns error if frame is invalid. -func (p *FrameCancel) Validate() (err error) { +func (f *CancelFrame) Validate() (err error) { // Cancel frame doesn't need any binary body. - if p.body != nil && p.body.Len() > 0 { + if f.body != nil && f.body.Len() > 0 { err = errIncompleteFrame } return } -func (p *FrameCancel) String() string { - return fmt.Sprintf("FrameCancel{%s}", p.header) +func NewCancelFrameSupport(id uint32) *CancelFrameSupport { + h := NewFrameHeader(id, FrameTypeCancel, 0) + return &CancelFrameSupport{ + tinyFrame: newTinyFrame(h), + } } -// NewFrameCancel returns a new cancel frame. -func NewFrameCancel(sid uint32) *FrameCancel { - return &FrameCancel{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeCancel), common.NewByteBuff()), +// NewCancelFrame creates cancel frame. +func NewCancelFrame(sid uint32) *CancelFrame { + return &CancelFrame{ + NewRawFrame(NewFrameHeader(sid, FrameTypeCancel, 0), nil), } } diff --git a/internal/framing/frame_error.go b/internal/framing/frame_error.go index 494e838..b21bbb6 100644 --- a/internal/framing/frame_error.go +++ b/internal/framing/frame_error.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "strings" "github.com/rsocket/rsocket-go/internal/common" @@ -14,57 +14,95 @@ const ( minErrorFrameLen = errCodeLen ) -// FrameError is error frame. -type FrameError struct { - *BaseFrame +// ErrorFrame is error frame. +type ErrorFrame struct { + *RawFrame } -func (p *FrameError) String() string { - return fmt.Sprintf("FrameError{%s,code=%s,data=%s}", p.header, p.ErrorCode(), p.ErrorData()) +type ErrorFrameSupport struct { + *tinyFrame + code common.ErrorCode + data []byte +} + +func (e ErrorFrameSupport) Error() string { + return makeErrorString(e.code, e.data) +} + +func (e ErrorFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = e.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + err = binary.Write(w, binary.BigEndian, uint32(e.code)) + if err != nil { + return + } + n += 4 + return +} + +func (e ErrorFrameSupport) Len() int { + return HeaderLen + 4 + len(e.data) } // Validate returns error if frame is invalid. -func (p *FrameError) Validate() (err error) { +func (p *ErrorFrame) Validate() (err error) { if p.body.Len() < minErrorFrameLen { err = errIncompleteFrame } return } -func (p *FrameError) Error() string { - bu := strings.Builder{} - bu.WriteString(p.ErrorCode().String()) - bu.WriteByte(':') - bu.WriteByte(' ') - bu.Write(p.ErrorData()) - return bu.String() +func (p *ErrorFrame) Error() string { + return makeErrorString(p.ErrorCode(), p.ErrorData()) } // ErrorCode returns error code. -func (p *FrameError) ErrorCode() common.ErrorCode { +func (p *ErrorFrame) ErrorCode() common.ErrorCode { v := binary.BigEndian.Uint32(p.body.Bytes()) return common.ErrorCode(v) } // ErrorData returns error data bytes. -func (p *FrameError) ErrorData() []byte { +func (p *ErrorFrame) ErrorData() []byte { return p.body.Bytes()[errDataOff:] } -// NewFrameError returns a new error frame. -func NewFrameError(streamID uint32, code common.ErrorCode, data []byte) *FrameError { +func NewErrorFrameSupport(id uint32, code common.ErrorCode, data []byte) *ErrorFrameSupport { + h := NewFrameHeader(id, FrameTypeError, 0) + t := newTinyFrame(h) + return &ErrorFrameSupport{ + tinyFrame: t, + code: code, + data: data, + } +} + +// NewErrorFrame returns a new error frame. +func NewErrorFrame(streamID uint32, code common.ErrorCode, data []byte) *ErrorFrame { bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], uint32(code)) if _, err := bf.Write(b4[:]); err != nil { - panic(err) } if _, err := bf.Write(data); err != nil { - panic(err) } - return &FrameError{ - NewBaseFrame(NewFrameHeader(streamID, FrameTypeError), bf), + return &ErrorFrame{ + NewRawFrame(NewFrameHeader(streamID, FrameTypeError, 0), bf), } } + +func makeErrorString(code common.ErrorCode, data []byte) string { + bu := strings.Builder{} + bu.WriteString(code.String()) + bu.WriteByte(':') + bu.WriteByte(' ') + bu.Write(data) + return bu.String() +} diff --git a/internal/framing/frame_fnf.go b/internal/framing/frame_fnf.go index fad43ad..582362e 100644 --- a/internal/framing/frame_fnf.go +++ b/internal/framing/frame_fnf.go @@ -1,42 +1,63 @@ package framing import ( - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -// FrameFNF is fire and forget frame. -type FrameFNF struct { - *BaseFrame +// FireAndForgetFrame is fire and forget frame. +type FireAndForgetFrame struct { + *RawFrame } -// Validate returns error if frame is invalid. -func (p *FrameFNF) Validate() (err error) { - if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { - err = errIncompleteFrame +type FireAndForgetFrameSupport struct { + *tinyFrame + metadata []byte + data []byte +} + +func (f FireAndForgetFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = f.header.WriteTo(w) + if err != nil { + return } + n += wrote + + wrote, err = writePayload(w, f.data, f.metadata) + if err != nil { + return + } + n += wrote return } -func (p *FrameFNF) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameFNF{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) +func (f FireAndForgetFrameSupport) Len() int { + return CalcPayloadFrameSize(f.data, f.metadata) +} + +// Validate returns error if frame is invalid. +func (f *FireAndForgetFrame) Validate() (err error) { + if f.header.Flag().Check(FlagMetadata) && f.body.Len() < 3 { + err = errIncompleteFrame + } + return } // Metadata returns metadata bytes. -func (p *FrameFNF) Metadata() ([]byte, bool) { - return p.trySliceMetadata(0) +func (f *FireAndForgetFrame) Metadata() ([]byte, bool) { + return f.trySliceMetadata(0) } // Data returns data bytes. -func (p *FrameFNF) Data() []byte { - return p.trySliceData(0) +func (f *FireAndForgetFrame) Data() []byte { + return f.trySliceData(0) } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameFNF) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() +func (f *FireAndForgetFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := f.Metadata() if ok { metadata = string(raw) } @@ -44,16 +65,28 @@ func (p *FrameFNF) MetadataUTF8() (metadata string, ok bool) { } // DataUTF8 returns data as UTF8 string. -func (p *FrameFNF) DataUTF8() string { - return string(p.Data()) +func (f *FireAndForgetFrame) DataUTF8() string { + return string(f.Data()) +} + +func NewFireAndForgetFrameSupport(sid uint32, data, metadata []byte, flag FrameFlag) *FireAndForgetFrameSupport { + if len(metadata) > 0 { + flag |= FlagMetadata + } + h := NewFrameHeader(sid, FrameTypeRequestFNF, flag) + t := newTinyFrame(h) + return &FireAndForgetFrameSupport{ + tinyFrame: t, + metadata: metadata, + data: data, + } } -// NewFrameFNF returns a new fire and forget frame. -func NewFrameFNF(sid uint32, data, metadata []byte, flags ...FrameFlag) *FrameFNF { - fg := newFlags(flags...) +// NewFireAndForgetFrame returns a new fire and forget frame. +func NewFireAndForgetFrame(sid uint32, data, metadata []byte, flag FrameFlag) *FireAndForgetFrame { bf := common.NewByteBuff() if len(metadata) > 0 { - fg |= FlagMetadata + flag |= FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -64,7 +97,7 @@ func NewFrameFNF(sid uint32, data, metadata []byte, flags ...FrameFlag) *FrameFN if _, err := bf.Write(data); err != nil { panic(err) } - return &FrameFNF{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestFNF, fg), bf), + return &FireAndForgetFrame{ + NewRawFrame(NewFrameHeader(sid, FrameTypeRequestFNF, flag), bf), } } diff --git a/internal/framing/frame_keepalive.go b/internal/framing/frame_keepalive.go index ab7192c..def5db9 100644 --- a/internal/framing/frame_keepalive.go +++ b/internal/framing/frame_keepalive.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) @@ -12,35 +12,84 @@ const ( minKeepaliveFrameLen = lastRecvPosLen ) -// FrameKeepalive is keepalive frame. -type FrameKeepalive struct { - *BaseFrame +// KeepaliveFrame is keepalive frame. +type KeepaliveFrame struct { + *RawFrame } -func (p *FrameKeepalive) String() string { - return fmt.Sprintf("FrameKeepalive{%s,lastReceivedPosition=%d,data=%s}", p.header, p.LastReceivedPosition(), string(p.Data())) +type KeepaliveFrameSupport struct { + *tinyFrame + pos [8]byte + data []byte +} + +func (k KeepaliveFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = k.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(k.pos[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(k.data) + if err != nil { + return + } + n += int64(v) + + return +} + +func (k KeepaliveFrameSupport) Len() int { + return HeaderLen + 8 + len(k.data) } // Validate returns error if frame is invalid. -func (p *FrameKeepalive) Validate() (err error) { - if p.body.Len() < minKeepaliveFrameLen { +func (k *KeepaliveFrame) Validate() (err error) { + if k.body.Len() < minKeepaliveFrameLen { err = errIncompleteFrame } return } // LastReceivedPosition returns last received position. -func (p *FrameKeepalive) LastReceivedPosition() uint64 { - return binary.BigEndian.Uint64(p.body.Bytes()) +func (k *KeepaliveFrame) LastReceivedPosition() uint64 { + return binary.BigEndian.Uint64(k.body.Bytes()) } // Data returns data bytes. -func (p *FrameKeepalive) Data() []byte { - return p.body.Bytes()[lastRecvPosLen:] +func (k *KeepaliveFrame) Data() []byte { + return k.body.Bytes()[lastRecvPosLen:] +} + +func NewKeepaliveFrameSupport(position uint64, data []byte, respond bool) *KeepaliveFrameSupport { + var flag FrameFlag + if respond { + flag |= FlagRespond + } + + var b [8]byte + binary.BigEndian.PutUint64(b[:], position) + + h := NewFrameHeader(0, FrameTypeKeepalive, flag) + t := newTinyFrame(h) + + return &KeepaliveFrameSupport{ + tinyFrame: t, + pos: b, + data: data, + } } -// NewFrameKeepalive returns a new keepalive frame. -func NewFrameKeepalive(position uint64, data []byte, respond bool) *FrameKeepalive { +// NewKeepaliveFrame returns a new keepalive frame. +func NewKeepaliveFrame(position uint64, data []byte, respond bool) *KeepaliveFrame { var fg FrameFlag if respond { fg |= FlagRespond @@ -56,7 +105,7 @@ func NewFrameKeepalive(position uint64, data []byte, respond bool) *FrameKeepali panic(err) } } - return &FrameKeepalive{ - NewBaseFrame(NewFrameHeader(0, FrameTypeKeepalive, fg), bf), + return &KeepaliveFrame{ + NewRawFrame(NewFrameHeader(0, FrameTypeKeepalive, fg), bf), } } diff --git a/internal/framing/frame_lease.go b/internal/framing/frame_lease.go index 9dd8d54..87c66b9 100644 --- a/internal/framing/frame_lease.go +++ b/internal/framing/frame_lease.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "time" "github.com/rsocket/rsocket-go/internal/common" @@ -15,44 +15,114 @@ const ( minLeaseFrame = ttlLen + reqLen ) -// FrameLease is lease frame. -type FrameLease struct { - *BaseFrame +// LeaseFrame is lease frame. +type LeaseFrame struct { + *RawFrame } // Validate returns error if frame is invalid. -func (p *FrameLease) Validate() (err error) { - if p.body.Len() < minLeaseFrame { +func (l *LeaseFrame) Validate() (err error) { + if l.body.Len() < minLeaseFrame { err = errIncompleteFrame } return } -func (p *FrameLease) String() string { - return fmt.Sprintf("FrameLease{%s,timeToLive=%s,numberOfRequests=%d,metadata=%s}", - p.header, p.TimeToLive(), p.NumberOfRequests(), string(p.Metadata())) -} - // TimeToLive returns time to live duration. -func (p *FrameLease) TimeToLive() time.Duration { - v := binary.BigEndian.Uint32(p.body.Bytes()) +func (l *LeaseFrame) TimeToLive() time.Duration { + v := binary.BigEndian.Uint32(l.body.Bytes()) return time.Millisecond * time.Duration(v) } // NumberOfRequests returns number of requests. -func (p *FrameLease) NumberOfRequests() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()[reqOff:]) +func (l *LeaseFrame) NumberOfRequests() uint32 { + return binary.BigEndian.Uint32(l.body.Bytes()[reqOff:]) } // Metadata returns metadata bytes. -func (p *FrameLease) Metadata() []byte { - if !p.header.Flag().Check(FlagMetadata) { +func (l *LeaseFrame) Metadata() []byte { + if !l.header.Flag().Check(FlagMetadata) { return nil } - return p.body.Bytes()[8:] + return l.body.Bytes()[8:] +} + +type LeaseFrameSupport struct { + *tinyFrame + ttl [4]byte + n [4]byte + metadata []byte +} + +func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = l.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(l.ttl[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(l.n[:]) + if err != nil { + return + } + n += int64(v) + + if !l.header.Flag().Check(FlagMetadata) { + return + } + + u := common.MustNewUint24(len(l.metadata)) + wrote, err = u.WriteTo(w) + if err != nil { + return + } + n += wrote + + v, err = w.Write(l.metadata) + if err != nil { + return + } + n += int64(v) + return +} + +func (l LeaseFrameSupport) Len() int { + n := HeaderLen + 8 + if l.header.Flag().Check(FlagMetadata) { + n += 3 + n += len(l.metadata) + } + return n +} + +func NewLeaseFrameSupport(ttl time.Duration, n uint32, metadata []byte) *LeaseFrameSupport { + var a, b [4]byte + binary.BigEndian.PutUint32(a[:], uint32(ttl.Milliseconds())) + binary.BigEndian.PutUint32(b[:], n) + + var flag FrameFlag + if len(metadata) > 0 { + flag |= FlagMetadata + } + h := NewFrameHeader(0, FrameTypeLease, flag) + t := newTinyFrame(h) + return &LeaseFrameSupport{ + tinyFrame: t, + ttl: a, + n: b, + metadata: metadata, + } } -func NewFrameLease(ttl time.Duration, n uint32, metadata []byte) *FrameLease { +func NewLeaseFrame(ttl time.Duration, n uint32, metadata []byte) *LeaseFrame { bf := common.NewByteBuff() if err := binary.Write(bf, binary.BigEndian, uint32(ttl.Milliseconds())); err != nil { panic(err) @@ -67,5 +137,5 @@ func NewFrameLease(ttl time.Duration, n uint32, metadata []byte) *FrameLease { panic(err) } } - return &FrameLease{NewBaseFrame(NewFrameHeader(0, FrameTypeLease, fg), bf)} + return &LeaseFrame{NewRawFrame(NewFrameHeader(0, FrameTypeLease, fg), bf)} } diff --git a/internal/framing/frame_metadata_push.go b/internal/framing/frame_metadata_push.go index 495cf6c..e2c9e3f 100644 --- a/internal/framing/frame_metadata_push.go +++ b/internal/framing/frame_metadata_push.go @@ -1,59 +1,87 @@ package framing import ( - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -var defaultFrameMetadataPushHeader = NewFrameHeader(0, FrameTypeMetadataPush, FlagMetadata) +var _metadataPushHeader = NewFrameHeader(0, FrameTypeMetadataPush, FlagMetadata) -// FrameMetadataPush is metadata push frame. -type FrameMetadataPush struct { - *BaseFrame +// MetadataPushFrame is metadata push frame. +type MetadataPushFrame struct { + *RawFrame +} +type MetadataPushFrameSupport struct { + *tinyFrame + metadata []byte } // Validate returns error if frame is invalid. -func (p *FrameMetadataPush) Validate() (err error) { +func (m *MetadataPushFrame) Validate() (err error) { return } -func (p *FrameMetadataPush) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameMetadataPush{%s,metadata=%s}", p.header, m) -} - // Metadata returns metadata bytes. -func (p *FrameMetadataPush) Metadata() ([]byte, bool) { - return p.body.Bytes(), true +func (m *MetadataPushFrame) Metadata() ([]byte, bool) { + return m.body.Bytes(), true } // Data returns data bytes. -func (p *FrameMetadataPush) Data() []byte { +func (m *MetadataPushFrame) Data() []byte { return nil } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameMetadataPush) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() +func (m *MetadataPushFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := m.Metadata() if ok { metadata = string(raw) } return } +func (m MetadataPushFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = m.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(m.metadata) + if err != nil { + return + } + n += int64(v) + return +} + +func (m MetadataPushFrameSupport) Len() int { + return HeaderLen + len(m.metadata) +} + // DataUTF8 returns data as UTF8 string. -func (p *FrameMetadataPush) DataUTF8() (data string) { +func (m *MetadataPushFrame) DataUTF8() (data string) { return } -// NewFrameMetadataPush returns a new metadata push frame. -func NewFrameMetadataPush(metadata []byte) *FrameMetadataPush { +func NewMetadataPushFrameSupport(metadata []byte) *MetadataPushFrameSupport { + t := newTinyFrame(_metadataPushHeader) + return &MetadataPushFrameSupport{ + tinyFrame: t, + metadata: metadata, + } +} + +// NewMetadataPushFrame returns a new metadata push frame. +func NewMetadataPushFrame(metadata []byte) *MetadataPushFrame { bf := common.NewByteBuff() if _, err := bf.Write(metadata); err != nil { panic(err) } - return &FrameMetadataPush{ - NewBaseFrame(defaultFrameMetadataPushHeader, bf), + return &MetadataPushFrame{ + NewRawFrame(_metadataPushHeader, bf), } } diff --git a/internal/framing/frame_payload.go b/internal/framing/frame_payload.go index 11ab00d..6d7315c 100644 --- a/internal/framing/frame_payload.go +++ b/internal/framing/frame_payload.go @@ -1,18 +1,24 @@ package framing import ( - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -// FramePayload is payload frame. -type FramePayload struct { - *BaseFrame +// PayloadFrame is payload frame. +type PayloadFrame struct { + *RawFrame +} + +type PayloadFrameSupport struct { + *tinyFrame + metadata []byte + data []byte } // Validate returns error if frame is invalid. -func (p *FramePayload) Validate() (err error) { +func (p *PayloadFrame) Validate() (err error) { // Minimal length should be 3 if metadata exists. if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { err = errIncompleteFrame @@ -20,23 +26,18 @@ func (p *FramePayload) Validate() (err error) { return } -func (p *FramePayload) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FramePayload{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) -} - // Metadata returns metadata bytes. -func (p *FramePayload) Metadata() ([]byte, bool) { +func (p *PayloadFrame) Metadata() ([]byte, bool) { return p.trySliceMetadata(0) } // Data returns data bytes. -func (p *FramePayload) Data() []byte { +func (p *PayloadFrame) Data() []byte { return p.trySliceData(0) } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FramePayload) MetadataUTF8() (metadata string, ok bool) { +func (p *PayloadFrame) MetadataUTF8() (metadata string, ok bool) { raw, ok := p.Metadata() if ok { metadata = string(raw) @@ -44,17 +45,56 @@ func (p *FramePayload) MetadataUTF8() (metadata string, ok bool) { return } +func (p *PayloadFrame) MustMetadataUTF8() string { + s, ok := p.MetadataUTF8() + if !ok { + panic("cannot convert metadata to utf8") + } + return s +} + // DataUTF8 returns data as UTF8 string. -func (p *FramePayload) DataUTF8() string { +func (p *PayloadFrame) DataUTF8() string { return string(p.Data()) } -// NewFramePayload returns a new payload frame. -func NewFramePayload(id uint32, data, metadata []byte, flags ...FrameFlag) *FramePayload { - fg := newFlags(flags...) +func (p PayloadFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = p.header.WriteTo(w) + if err != nil { + return + } + n += wrote + wrote, err = writePayload(w, p.data, p.metadata) + if err == nil { + n += wrote + } + return +} + +func (p PayloadFrameSupport) Len() int { + return CalcPayloadFrameSize(p.data, p.metadata) +} + +// NewPayloadFrameSupport returns a new payload frame. +func NewPayloadFrameSupport(id uint32, data, metadata []byte, flag FrameFlag) *PayloadFrameSupport { + if len(metadata) > 0 { + flag |= FlagMetadata + } + h := NewFrameHeader(id, FrameTypePayload, flag) + t := newTinyFrame(h) + return &PayloadFrameSupport{ + tinyFrame: t, + metadata: metadata, + data: data, + } +} + +// NewPayloadFrame returns a new payload frame. +func NewPayloadFrame(id uint32, data, metadata []byte, flag FrameFlag) *PayloadFrame { bf := common.NewByteBuff() if len(metadata) > 0 { - fg |= FlagMetadata + flag |= FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -67,7 +107,7 @@ func NewFramePayload(id uint32, data, metadata []byte, flags ...FrameFlag) *Fram panic(err) } } - return &FramePayload{ - NewBaseFrame(NewFrameHeader(id, FrameTypePayload, fg), bf), + return &PayloadFrame{ + NewRawFrame(NewFrameHeader(id, FrameTypePayload, flag), bf), } } diff --git a/internal/framing/frame_payload_test.go b/internal/framing/frame_payload_test.go deleted file mode 100644 index 4346a34..0000000 --- a/internal/framing/frame_payload_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package framing - -import ( - "bytes" - "testing" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestFramePayload_Basic(t *testing.T) { - metadata := []byte("hello") - data := []byte("world") - f := NewFramePayload(123, data, metadata) - - metadata1, _ := f.Metadata() - assert.Equal(t, metadata, metadata1, "metadata failed") - assert.Equal(t, data, f.Data(), "data failed") - - bf := &bytes.Buffer{} - _, _ = f.WriteTo(bf) - bs := bf.Bytes() - - bb := common.NewByteBuff() - _, _ = bb.Write(bs[HeaderLen:]) - f2 := &FramePayload{ - NewBaseFrame(ParseFrameHeader(bs[:HeaderLen]), bb), - } - - metadata2, _ := f2.Metadata() - assert.Equal(t, metadata, metadata2, "metadata failed 2") - assert.Equal(t, data, f2.Data(), "data failed 2") - assert.Equal(t, f2.header[:], f2.header[:]) -} diff --git a/internal/framing/frame_request_channel.go b/internal/framing/frame_request_channel.go index 7a8449b..93bd8ff 100644 --- a/internal/framing/frame_request_channel.go +++ b/internal/framing/frame_request_channel.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) @@ -12,47 +12,48 @@ const ( minRequestChannelFrameLen = initReqLen ) -// FrameRequestChannel is frame for RequestChannel. -type FrameRequestChannel struct { - *BaseFrame +// RequestChannelFrame is frame for RequestChannel. +type RequestChannelFrame struct { + *RawFrame +} + +type RequestChannelFrameSupport struct { + *tinyFrame + n [4]byte + metadata []byte + data []byte } // Validate returns error if frame is invalid. -func (p *FrameRequestChannel) Validate() error { - l := p.body.Len() +func (r *RequestChannelFrame) Validate() error { + l := r.body.Len() if l < minRequestChannelFrameLen { return errIncompleteFrame } - if p.header.Flag().Check(FlagMetadata) && l < minRequestChannelFrameLen+3 { + if r.header.Flag().Check(FlagMetadata) && l < minRequestChannelFrameLen+3 { return errIncompleteFrame } return nil } -func (p *FrameRequestChannel) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestChannel{%s,data=%s,metadata=%s,initialRequestN=%d}", - p.header, p.DataUTF8(), m, p.InitialRequestN()) -} - // InitialRequestN returns initial N. -func (p *FrameRequestChannel) InitialRequestN() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) +func (r *RequestChannelFrame) InitialRequestN() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) } // Metadata returns metadata bytes. -func (p *FrameRequestChannel) Metadata() ([]byte, bool) { - return p.trySliceMetadata(initReqLen) +func (r *RequestChannelFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(initReqLen) } // Data returns data bytes. -func (p *FrameRequestChannel) Data() []byte { - return p.trySliceData(initReqLen) +func (r *RequestChannelFrame) Data() []byte { + return r.trySliceData(initReqLen) } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestChannel) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() +func (r *RequestChannelFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() if ok { metadata = string(raw) } @@ -60,13 +61,56 @@ func (p *FrameRequestChannel) MetadataUTF8() (metadata string, ok bool) { } // DataUTF8 returns data as UTF8 string. -func (p *FrameRequestChannel) DataUTF8() string { - return string(p.Data()) +func (r *RequestChannelFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r RequestChannelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(r.n[:]) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, r.data, r.metadata) + if err != nil { + return + } + n += wrote + + return +} + +func (r RequestChannelFrameSupport) Len() int { + return CalcPayloadFrameSize(r.data, r.metadata) + 4 +} + +func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestChannelFrameSupport { + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + if len(metadata) > 0 { + flag |= FlagMetadata + } + h := NewFrameHeader(sid, FrameTypeRequestChannel, flag) + t := newTinyFrame(h) + return &RequestChannelFrameSupport{ + tinyFrame: t, + n: b, + metadata: metadata, + data: data, + } } -// NewFrameRequestChannel returns a new RequestChannel frame. -func NewFrameRequestChannel(sid uint32, n uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestChannel { - fg := newFlags(flags...) +// NewRequestChannelFrame returns a new RequestChannel frame. +func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestChannelFrame { bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) @@ -74,7 +118,7 @@ func NewFrameRequestChannel(sid uint32, n uint32, data, metadata []byte, flags . panic(err) } if len(metadata) > 0 { - fg |= FlagMetadata + flag |= FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -87,7 +131,7 @@ func NewFrameRequestChannel(sid uint32, n uint32, data, metadata []byte, flags . panic(err) } } - return &FrameRequestChannel{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestChannel, fg), bf), + return &RequestChannelFrame{ + NewRawFrame(NewFrameHeader(sid, FrameTypeRequestChannel, flag), bf), } } diff --git a/internal/framing/frame_request_n.go b/internal/framing/frame_request_n.go index 713f95b..540b884 100644 --- a/internal/framing/frame_request_n.go +++ b/internal/framing/frame_request_n.go @@ -2,44 +2,70 @@ package framing import ( "encoding/binary" - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -// FrameRequestN is RequestN frame. -type FrameRequestN struct { - *BaseFrame +// RequestNFrame is RequestN frame. +type RequestNFrame struct { + *RawFrame +} + +type RequestNFrameSupport struct { + *tinyFrame + n [4]byte } // Validate returns error if frame is invalid. -func (p *FrameRequestN) Validate() (err error) { - if p.body.Len() != 4 { +func (r *RequestNFrame) Validate() (err error) { + if r.body.Len() != 4 { err = errIncompleteFrame } return } -func (p *FrameRequestN) String() string { - return fmt.Sprintf("FrameRequestN{%s,n=%d}", p.header, p.N()) +// N returns N in RequestN. +func (r *RequestNFrame) N() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) } -// N returns N in RequestN. -func (p *FrameRequestN) N() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) +func (r RequestNFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + v, err := w.Write(r.n[:]) + if err == nil { + n += int64(v) + } + return } -// NewFrameRequestN returns a new RequestN frame. -func NewFrameRequestN(sid, n uint32, flags ...FrameFlag) *FrameRequestN { - fg := newFlags(flags...) - bf := common.NewByteBuff() +func (r RequestNFrameSupport) Len() int { + return HeaderLen + 4 +} +func NewRequestNFrameSupport(id uint32, n uint32, fg FrameFlag) *RequestNFrameSupport { + var b4 [4]byte + binary.BigEndian.PutUint32(b4[:], n) + return &RequestNFrameSupport{ + tinyFrame: newTinyFrame(NewFrameHeader(id, FrameTypeRequestN, fg)), + n: b4, + } +} + +// NewRequestNFrame returns a new RequestN frame. +func NewRequestNFrame(sid, n uint32, fg FrameFlag) *RequestNFrame { + bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) if _, err := bf.Write(b4[:]); err != nil { panic(err) } - return &FrameRequestN{ - NewBaseFrame(NewFrameHeader(sid, FrameTypeRequestN, fg), bf), + return &RequestNFrame{ + NewRawFrame(NewFrameHeader(sid, FrameTypeRequestN, fg), bf), } } diff --git a/internal/framing/frame_request_response.go b/internal/framing/frame_request_response.go index 1d1738f..326d00c 100644 --- a/internal/framing/frame_request_response.go +++ b/internal/framing/frame_request_response.go @@ -1,42 +1,43 @@ package framing import ( - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -// FrameRequestResponse is frame for requesting single response. -type FrameRequestResponse struct { - *BaseFrame +// RequestResponseFrame is frame for requesting single response. +type RequestResponseFrame struct { + *RawFrame +} + +type RequestResponseFrameSupport struct { + *tinyFrame + metadata []byte + data []byte } // Validate returns error if frame is invalid. -func (p *FrameRequestResponse) Validate() (err error) { - if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { +func (r *RequestResponseFrame) Validate() (err error) { + if r.header.Flag().Check(FlagMetadata) && r.body.Len() < 3 { err = errIncompleteFrame } return } -func (p *FrameRequestResponse) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestResponse{%s,data=%s,metadata=%s}", p.header, p.DataUTF8(), m) -} - // Metadata returns metadata bytes. -func (p *FrameRequestResponse) Metadata() ([]byte, bool) { - return p.trySliceMetadata(0) +func (r *RequestResponseFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(0) } // Data returns data bytes. -func (p *FrameRequestResponse) Data() []byte { - return p.trySliceData(0) +func (r *RequestResponseFrame) Data() []byte { + return r.trySliceData(0) } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestResponse) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() +func (r *RequestResponseFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() if ok { metadata = string(raw) } @@ -44,13 +45,42 @@ func (p *FrameRequestResponse) MetadataUTF8() (metadata string, ok bool) { } // DataUTF8 returns data as UTF8 string. -func (p *FrameRequestResponse) DataUTF8() string { - return string(p.Data()) +func (r *RequestResponseFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r RequestResponseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + wrote, err = writePayload(w, r.data, r.metadata) + if err == nil { + n += wrote + } + return +} + +func (r RequestResponseFrameSupport) Len() int { + return CalcPayloadFrameSize(r.data, r.metadata) +} + +// NewRequestResponseFrameSupport returns a new RequestResponse frame support. +func NewRequestResponseFrameSupport(id uint32, data, metadata []byte, fg FrameFlag) FrameSupport { + if len(metadata) > 0 { + fg |= FlagMetadata + } + return &RequestResponseFrameSupport{ + tinyFrame: newTinyFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg)), + metadata: metadata, + data: data, + } } -// NewFrameRequestResponse returns a new RequestResponse frame. -func NewFrameRequestResponse(id uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestResponse { - fg := newFlags(flags...) +// NewRequestResponseFrame returns a new RequestResponse frame. +func NewRequestResponseFrame(id uint32, data, metadata []byte, fg FrameFlag) *RequestResponseFrame { bf := common.NewByteBuff() if len(metadata) > 0 { fg |= FlagMetadata @@ -66,7 +96,7 @@ func NewFrameRequestResponse(id uint32, data, metadata []byte, flags ...FrameFla panic(err) } } - return &FrameRequestResponse{ - NewBaseFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg), bf), + return &RequestResponseFrame{ + NewRawFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg), bf), } } diff --git a/internal/framing/frame_request_stream.go b/internal/framing/frame_request_stream.go index 3a8d8bb..0a2f469 100644 --- a/internal/framing/frame_request_stream.go +++ b/internal/framing/frame_request_stream.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) @@ -11,47 +11,48 @@ const ( minRequestStreamFrameLen = initReqLen ) -// FrameRequestStream is frame for requesting a completable stream. -type FrameRequestStream struct { - *BaseFrame +// RequestStreamFrame is frame for requesting a completable stream. +type RequestStreamFrame struct { + *RawFrame +} + +type RequestStreamFrameSupport struct { + *tinyFrame + n [4]byte + metadata []byte + data []byte } // Validate returns error if frame is invalid. -func (p *FrameRequestStream) Validate() error { - l := p.body.Len() +func (r *RequestStreamFrame) Validate() error { + l := r.body.Len() if l < minRequestStreamFrameLen { return errIncompleteFrame } - if p.header.Flag().Check(FlagMetadata) && l < minRequestStreamFrameLen+3 { + if r.header.Flag().Check(FlagMetadata) && l < minRequestStreamFrameLen+3 { return errIncompleteFrame } return nil } -func (p *FrameRequestStream) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf("FrameRequestStream{%s,data=%s,metadata=%s,initialRequestN=%d}", - p.header, p.DataUTF8(), m, p.InitialRequestN()) -} - // InitialRequestN returns initial request N. -func (p *FrameRequestStream) InitialRequestN() uint32 { - return binary.BigEndian.Uint32(p.body.Bytes()) +func (r *RequestStreamFrame) InitialRequestN() uint32 { + return binary.BigEndian.Uint32(r.body.Bytes()) } // Metadata returns metadata bytes. -func (p *FrameRequestStream) Metadata() ([]byte, bool) { - return p.trySliceMetadata(4) +func (r *RequestStreamFrame) Metadata() ([]byte, bool) { + return r.trySliceMetadata(4) } // Data returns data bytes. -func (p *FrameRequestStream) Data() []byte { - return p.trySliceData(4) +func (r *RequestStreamFrame) Data() []byte { + return r.trySliceData(4) } // MetadataUTF8 returns metadata as UTF8 string. -func (p *FrameRequestStream) MetadataUTF8() (metadata string, ok bool) { - raw, ok := p.Metadata() +func (r *RequestStreamFrame) MetadataUTF8() (metadata string, ok bool) { + raw, ok := r.Metadata() if ok { metadata = string(raw) } @@ -59,21 +60,61 @@ func (p *FrameRequestStream) MetadataUTF8() (metadata string, ok bool) { } // DataUTF8 returns data as UTF8 string. -func (p *FrameRequestStream) DataUTF8() string { - return string(p.Data()) +func (r *RequestStreamFrame) DataUTF8() string { + return string(r.Data()) +} + +func (r RequestStreamFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(r.n[:]) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, r.data, r.metadata) + if err != nil { + return + } + n += wrote + return +} + +func (r RequestStreamFrameSupport) Len() int { + return 4 + CalcPayloadFrameSize(r.data, r.metadata) +} + +func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, flag FrameFlag) FrameSupport { + if len(metadata) > 0 { + flag |= FlagMetadata + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], n) + h := NewFrameHeader(id, FrameTypeRequestStream, flag) + t := newTinyFrame(h) + return &RequestStreamFrameSupport{ + tinyFrame: t, + n: b, + metadata: metadata, + data: data, + } } -// NewFrameRequestStream returns a new request stream frame. -func NewFrameRequestStream(id uint32, n uint32, data, metadata []byte, flags ...FrameFlag) *FrameRequestStream { - fg := newFlags(flags...) +// NewRequestStreamFrame returns a new request stream frame. +func NewRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestStreamFrame { bf := common.NewByteBuff() - var b4 [4]byte - binary.BigEndian.PutUint32(b4[:], n) - if _, err := bf.Write(b4[:]); err != nil { + if err := binary.Write(bf, binary.BigEndian, n); err != nil { panic(err) } if len(metadata) > 0 { - fg |= FlagMetadata + flag |= FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -86,7 +127,7 @@ func NewFrameRequestStream(id uint32, n uint32, data, metadata []byte, flags ... panic(err) } } - return &FrameRequestStream{ - NewBaseFrame(NewFrameHeader(id, FrameTypeRequestStream, fg), bf), + return &RequestStreamFrame{ + NewRawFrame(NewFrameHeader(id, FrameTypeRequestStream, flag), bf), } } diff --git a/internal/framing/frame_resume.go b/internal/framing/frame_resume.go index 766fcb0..5e1d6ac 100644 --- a/internal/framing/frame_resume.go +++ b/internal/framing/frame_resume.go @@ -3,7 +3,7 @@ package framing import ( "encoding/binary" "errors" - "fmt" + "io" "math" "github.com/rsocket/rsocket-go/internal/common" @@ -19,57 +19,123 @@ const ( _minResumeLength = _lenVersion + _lenTokenLength + _lenLastRecvPos + _lenFirstPos ) -// FrameResume represents a frame of Resume. -type FrameResume struct { - *BaseFrame -} - -func (p *FrameResume) String() string { - return fmt.Sprintf( - "FrameResume{%s,version=%s,token=0x%02x,lastReceivedServerPosition=%d,firstAvailableClientPosition=%d}", - p.header, p.Version(), p.Token(), p.LastReceivedServerPosition(), p.FirstAvailableClientPosition(), - ) +// ResumeFrame represents a frame of Resume. +type ResumeFrame struct { + *RawFrame } // Validate validate current frame. -func (p *FrameResume) Validate() (err error) { - if p.body.Len() < _minResumeLength { +func (r *ResumeFrame) Validate() (err error) { + if r.body.Len() < _minResumeLength { err = errIncompleteFrame } return } // Version returns version. -func (p *FrameResume) Version() common.Version { - raw := p.body.Bytes() +func (r *ResumeFrame) Version() common.Version { + raw := r.body.Bytes() major := binary.BigEndian.Uint16(raw) minor := binary.BigEndian.Uint16(raw[2:]) return [2]uint16{major, minor} } // Token returns resume token in bytes. -func (p *FrameResume) Token() []byte { - raw := p.body.Bytes() +func (r *ResumeFrame) Token() []byte { + raw := r.body.Bytes() tokenLen := binary.BigEndian.Uint16(raw[4:6]) return raw[6 : 6+tokenLen] } // LastReceivedServerPosition returns last received server position. -func (p *FrameResume) LastReceivedServerPosition() uint64 { - raw := p.body.Bytes() +func (r *ResumeFrame) LastReceivedServerPosition() uint64 { + raw := r.body.Bytes() offset := 6 + binary.BigEndian.Uint16(raw[4:6]) return binary.BigEndian.Uint64(raw[offset:]) } // FirstAvailableClientPosition returns first available client position. -func (p *FrameResume) FirstAvailableClientPosition() uint64 { - raw := p.body.Bytes() +func (r *ResumeFrame) FirstAvailableClientPosition() uint64 { + raw := r.body.Bytes() offset := 6 + binary.BigEndian.Uint16(raw[4:6]) + 8 return binary.BigEndian.Uint64(raw[offset:]) } -// NewFrameResume creates a new frame of Resume. -func NewFrameResume(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *FrameResume { +type ResumeFrameSupport struct { + *tinyFrame + version common.Version + token []byte + posFirst [8]byte + posLast [8]byte +} + +func (r ResumeFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + + v, err = w.Write(r.version.Bytes()) + if err != nil { + return + } + n += int64(v) + + lenToken := uint16(len(r.token)) + err = binary.Write(w, binary.BigEndian, lenToken) + if err != nil { + return + } + n += 2 + + v, err = w.Write(r.token) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(r.posLast[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(r.posFirst[:]) + if err != nil { + return + } + n += int64(v) + + return +} + +func (r ResumeFrameSupport) Len() int { + return HeaderLen + _lenTokenLength + _lenFirstPos + _lenLastRecvPos + _lenVersion + len(r.token) +} + +// NewResumeFrameSupport creates a new frame support of Resume. +func NewResumeFrameSupport(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrameSupport { + h := NewFrameHeader(0, FrameTypeResume, 0) + t := newTinyFrame(h) + var a, b [8]byte + binary.BigEndian.PutUint64(a[:], firstAvailableClientPosition) + binary.BigEndian.PutUint64(b[:], lastReceivedServerPosition) + + return &ResumeFrameSupport{ + tinyFrame: t, + version: version, + token: token, + posFirst: a, + posLast: b, + } +} + +// NewResumeFrame creates a new frame of Resume. +func NewResumeFrame(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrame { n := len(token) if n > math.MaxUint16 { panic(errResumeTokenTooLarge) @@ -92,7 +158,7 @@ func NewFrameResume(version common.Version, token []byte, firstAvailableClientPo if err := binary.Write(bf, binary.BigEndian, firstAvailableClientPosition); err != nil { panic(err) } - return &FrameResume{ - NewBaseFrame(NewFrameHeader(0, FrameTypeResume), bf), + return &ResumeFrame{ + NewRawFrame(NewFrameHeader(0, FrameTypeResume, 0), bf), } } diff --git a/internal/framing/frame_resume_ok.go b/internal/framing/frame_resume_ok.go index b1bb845..2026d49 100644 --- a/internal/framing/frame_resume_ok.go +++ b/internal/framing/frame_resume_ok.go @@ -2,37 +2,69 @@ package framing import ( "encoding/binary" - "fmt" + "io" "github.com/rsocket/rsocket-go/internal/common" ) -// FrameResumeOK represents a frame of ResumeOK. -type FrameResumeOK struct { - *BaseFrame +// ResumeOKFrame represents a frame of ResumeOK. +type ResumeOKFrame struct { + *RawFrame } -func (p *FrameResumeOK) String() string { - return fmt.Sprintf("FrameResumeOK{%s,lastReceivedClientPosition=%d}", p.header, p.LastReceivedClientPosition()) +type ResumeOKFrameSupport struct { + *tinyFrame + pos [8]byte } // Validate validate current frame. -func (p *FrameResumeOK) Validate() (err error) { +func (r *ResumeOKFrame) Validate() (err error) { // Length of frame body should be 8 - if p.body.Len() != 8 { + if r.body.Len() != 8 { err = errIncompleteFrame } return } // LastReceivedClientPosition returns last received client position. -func (p *FrameResumeOK) LastReceivedClientPosition() uint64 { - raw := p.body.Bytes() +func (r *ResumeOKFrame) LastReceivedClientPosition() uint64 { + raw := r.body.Bytes() return binary.BigEndian.Uint64(raw) } -// NewResumeOK creates a new frame of ResumeOK. -func NewResumeOK(position uint64) *FrameResumeOK { +func (r ResumeOKFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = r.header.WriteTo(w) + if err != nil { + return + } + n += wrote + var v int + v, err = w.Write(r.pos[:]) + if err != nil { + return + } + n += int64(v) + return +} + +func (r ResumeOKFrameSupport) Len() int { + return HeaderLen + 8 +} + +func NewResumeOKFrameSupport(position uint64) *ResumeOKFrameSupport { + h := NewFrameHeader(0, FrameTypeResumeOK, 0) + t := newTinyFrame(h) + var b [8]byte + binary.BigEndian.PutUint64(b[:], position) + return &ResumeOKFrameSupport{ + tinyFrame: t, + pos: b, + } +} + +// NewResumeOKFrame creates a new frame of ResumeOK. +func NewResumeOKFrame(position uint64) *ResumeOKFrame { var b8 [8]byte binary.BigEndian.PutUint64(b8[:], position) bf := common.NewByteBuff() @@ -40,7 +72,7 @@ func NewResumeOK(position uint64) *FrameResumeOK { if err != nil { panic(err) } - return &FrameResumeOK{ - NewBaseFrame(NewFrameHeader(0, FrameTypeResumeOK), bf), + return &ResumeOKFrame{ + NewRawFrame(NewFrameHeader(0, FrameTypeResumeOK, 0), bf), } } diff --git a/internal/framing/frame_setup.go b/internal/framing/frame_setup.go index 6244c09..e75a303 100644 --- a/internal/framing/frame_setup.go +++ b/internal/framing/frame_setup.go @@ -2,7 +2,7 @@ package framing import ( "encoding/binary" - "fmt" + "io" "time" "github.com/rsocket/rsocket-go/internal/common" @@ -11,60 +11,43 @@ import ( const ( _versionLen = 4 _timeLen = 4 - _tokenLen = 2 _metadataLen = 1 _dataLen = 1 - _minSetupFrameLen = _versionLen + _timeLen*2 + _tokenLen + _metadataLen + _dataLen + _minSetupFrameLen = _versionLen + _timeLen*2 + _metadataLen + _dataLen ) -// FrameSetup is sent by client to initiate protocol processing. -type FrameSetup struct { - *BaseFrame +// SetupFrame is sent by client to initiate protocol processing. +type SetupFrame struct { + *RawFrame } // Validate returns error if frame is invalid. -func (p *FrameSetup) Validate() (err error) { +func (p *SetupFrame) Validate() (err error) { if p.Len() < _minSetupFrameLen { err = errIncompleteFrame } return } -func (p *FrameSetup) String() string { - m, _ := p.MetadataUTF8() - return fmt.Sprintf( - "FrameSetup{%s,version=%s,keepaliveInterval=%s,keepaliveMaxLifetime=%s,token=0x%02x,dataMimeType=%s,metadataMimeType=%s,data=%s,metadata=%s}", - p.header, - p.Version(), - p.TimeBetweenKeepalive(), - p.MaxLifetime(), - p.Token(), - p.DataMimeType(), - p.MetadataMimeType(), - p.DataUTF8(), - m, - ) -} - // Version returns version. -func (p *FrameSetup) Version() common.Version { +func (p *SetupFrame) Version() common.Version { major := binary.BigEndian.Uint16(p.body.Bytes()) minor := binary.BigEndian.Uint16(p.body.Bytes()[2:]) return [2]uint16{major, minor} } // TimeBetweenKeepalive returns keepalive interval duration. -func (p *FrameSetup) TimeBetweenKeepalive() time.Duration { +func (p *SetupFrame) TimeBetweenKeepalive() time.Duration { return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[4:])) } // MaxLifetime returns keepalive max lifetime. -func (p *FrameSetup) MaxLifetime() time.Duration { +func (p *SetupFrame) MaxLifetime() time.Duration { return time.Millisecond * time.Duration(binary.BigEndian.Uint32(p.body.Bytes()[8:])) } // Token returns token of setup. -func (p *FrameSetup) Token() []byte { +func (p *SetupFrame) Token() []byte { if !p.header.Flag().Check(FlagResume) { return nil } @@ -74,19 +57,19 @@ func (p *FrameSetup) Token() []byte { } // DataMimeType returns MIME of data. -func (p *FrameSetup) DataMimeType() (mime string) { +func (p *SetupFrame) DataMimeType() (mime string) { _, b := p.mime() return string(b) } // MetadataMimeType returns MIME of metadata. -func (p *FrameSetup) MetadataMimeType() string { +func (p *SetupFrame) MetadataMimeType() string { a, _ := p.mime() return string(a) } // Metadata returns metadata bytes. -func (p *FrameSetup) Metadata() ([]byte, bool) { +func (p *SetupFrame) Metadata() ([]byte, bool) { if !p.header.Flag().Check(FlagMetadata) { return nil, false } @@ -97,7 +80,7 @@ func (p *FrameSetup) Metadata() ([]byte, bool) { } // Data returns data bytes. -func (p *FrameSetup) Data() []byte { +func (p *SetupFrame) Data() []byte { offset := p.seekMIME() m1, m2 := p.mime() offset += 2 + len(m1) + len(m2) @@ -108,7 +91,7 @@ func (p *FrameSetup) Data() []byte { } // MetadataUTF8 returns metadata as UTF8 string -func (p *FrameSetup) MetadataUTF8() (metadata string, ok bool) { +func (p *SetupFrame) MetadataUTF8() (metadata string, ok bool) { raw, ok := p.Metadata() if ok { metadata = string(raw) @@ -117,11 +100,11 @@ func (p *FrameSetup) MetadataUTF8() (metadata string, ok bool) { } // DataUTF8 returns data as UTF8 string. -func (p *FrameSetup) DataUTF8() string { +func (p *SetupFrame) DataUTF8() string { return string(p.Data()) } -func (p *FrameSetup) mime() (metadata []byte, data []byte) { +func (p *SetupFrame) mime() (metadata []byte, data []byte) { offset := p.seekMIME() raw := p.body.Bytes() l1 := int(raw[offset]) @@ -134,7 +117,7 @@ func (p *FrameSetup) mime() (metadata []byte, data []byte) { return m1, m2 } -func (p *FrameSetup) seekMIME() int { +func (p *SetupFrame) seekMIME() int { if !p.header.Flag().Check(FlagResume) { return 12 } @@ -142,8 +125,142 @@ func (p *FrameSetup) seekMIME() int { return 14 + int(l) } -// NewFrameSetup returns a new setup frame. -func NewFrameSetup( +type SetupFrameSupport struct { + *tinyFrame + version common.Version + keepalive [4]byte + lifetime [4]byte + token []byte + mimeMetadata []byte + mimeData []byte + metadata []byte + data []byte +} + +func (s SetupFrameSupport) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = s.header.WriteTo(w) + if err != nil { + return + } + n += wrote + + wrote, err = s.version.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(s.keepalive[:]) + if err != nil { + return + } + n += int64(v) + + v, err = w.Write(s.lifetime[:]) + if err != nil { + return + } + n += int64(v) + + if s.header.Flag().Check(FlagResume) { + tokenLen := len(s.token) + err = binary.Write(w, binary.BigEndian, uint16(tokenLen)) + if err != nil { + return + } + n += 2 + v, err = w.Write(s.token) + if err != nil { + return + } + n += int64(v) + } + + lenMimeMetadata := len(s.mimeMetadata) + v, err = w.Write([]byte{byte(lenMimeMetadata)}) + if err != nil { + return + } + n += int64(v) + v, err = w.Write(s.mimeMetadata) + if err != nil { + return + } + n += int64(v) + + lenMimeData := len(s.mimeData) + v, err = w.Write([]byte{byte(lenMimeData)}) + if err != nil { + return + } + n += int64(v) + v, err = w.Write(s.mimeData) + if err != nil { + return + } + n += int64(v) + + wrote, err = writePayload(w, s.data, s.metadata) + if err != nil { + return + } + n += wrote + return +} + +func (s SetupFrameSupport) Len() int { + n := _minSetupFrameLen + CalcPayloadFrameSize(s.data, s.metadata) + n += len(s.mimeData) + len(s.mimeMetadata) + if l := len(s.token); l > 0 { + n += 2 + len(s.token) + } + return n +} + +func NewSetupFrameSupport( + version common.Version, + timeBetweenKeepalive, + maxLifetime time.Duration, + token []byte, + mimeMetadata []byte, + mimeData []byte, + data []byte, + metadata []byte, + lease bool, +) *SetupFrameSupport { + var flag FrameFlag + if l := len(token); l > 0 { + flag |= FlagResume + } + if lease { + flag |= FlagLease + } + if l := len(metadata); l > 0 { + flag |= FlagMetadata + } + h := NewFrameHeader(0, FrameTypeSetup, flag) + t := newTinyFrame(h) + + var a, b [4]byte + binary.BigEndian.PutUint32(a[:], uint32(timeBetweenKeepalive.Nanoseconds()/1e6)) + binary.BigEndian.PutUint32(b[:], uint32(maxLifetime.Nanoseconds()/1e6)) + return &SetupFrameSupport{ + tinyFrame: t, + version: version, + keepalive: a, + lifetime: b, + token: token, + mimeMetadata: mimeMetadata, + mimeData: mimeData, + metadata: metadata, + data: data, + } +} + +// NewSetupFrame returns a new setup frame. +func NewSetupFrame( version common.Version, timeBetweenKeepalive, maxLifetime time.Duration, @@ -153,7 +270,7 @@ func NewFrameSetup( data []byte, metadata []byte, lease bool, -) *FrameSetup { +) *SetupFrame { var fg FrameFlag bf := common.NewByteBuff() if _, err := bf.Write(version.Bytes()); err != nil { @@ -207,7 +324,7 @@ func NewFrameSetup( panic(err) } } - return &FrameSetup{ - NewBaseFrame(NewFrameHeader(0, FrameTypeSetup, fg), bf), + return &SetupFrame{ + NewRawFrame(NewFrameHeader(0, FrameTypeSetup, fg), bf), } } diff --git a/internal/framing/frame_setup_test.go b/internal/framing/frame_setup_test.go index 3c553d0..d6db7a7 100644 --- a/internal/framing/frame_setup_test.go +++ b/internal/framing/frame_setup_test.go @@ -15,7 +15,7 @@ func TestDecodeFrameSetup(t *testing.T) { data := []byte("world") mimeMetadata, mimeData := []byte("text/plain"), []byte("application/json") token := []byte(common.RandAlphanumeric(16)) - setup := NewFrameSetup( + setup := NewSetupFrame( common.DefaultVersion, 30*time.Second, 90*time.Second, diff --git a/internal/framing/frame_test.go b/internal/framing/frame_test.go index d0af3f4..f676b1e 100644 --- a/internal/framing/frame_test.go +++ b/internal/framing/frame_test.go @@ -1,7 +1,9 @@ package framing_test import ( + "bytes" "encoding/hex" + "fmt" "log" "math" "testing" @@ -15,13 +17,13 @@ import ( const _sid uint32 = 1 func TestFrameCancel(t *testing.T) { - f := NewFrameCancel(_sid) + f := NewCancelFrame(_sid) basicCheck(t, f, FrameTypeCancel) } func TestFrameError(t *testing.T) { errData := []byte(common.RandAlphanumeric(100)) - f := NewFrameError(_sid, common.ErrorCodeApplicationError, errData) + f := NewErrorFrame(_sid, common.ErrorCodeApplicationError, errData) basicCheck(t, f, FrameTypeError) assert.Equal(t, common.ErrorCodeApplicationError, f.ErrorCode()) assert.Equal(t, errData, f.ErrorData()) @@ -31,7 +33,7 @@ func TestFrameError(t *testing.T) { func TestFrameFNF(t *testing.T) { b := []byte(common.RandAlphanumeric(100)) // Without Metadata - f := NewFrameFNF(_sid, b, nil, FlagNext) + f := NewFireAndForgetFrame(_sid, b, nil, FlagNext) basicCheck(t, f, FrameTypeRequestFNF) assert.Equal(t, b, f.Data()) metadata, ok := f.Metadata() @@ -40,7 +42,7 @@ func TestFrameFNF(t *testing.T) { assert.True(t, f.Header().Flag().Check(FlagNext)) assert.False(t, f.Header().Flag().Check(FlagMetadata)) // With Metadata - f = NewFrameFNF(_sid, nil, b, FlagNext) + f = NewFireAndForgetFrame(_sid, nil, b, FlagNext) basicCheck(t, f, FrameTypeRequestFNF) assert.Empty(t, f.Data()) metadata, ok = f.Metadata() @@ -53,7 +55,7 @@ func TestFrameFNF(t *testing.T) { func TestFrameKeepalive(t *testing.T) { pos := uint64(common.RandIntn(math.MaxInt32)) d := []byte(common.RandAlphanumeric(100)) - f := NewFrameKeepalive(pos, d, true) + f := NewKeepaliveFrame(pos, d, true) basicCheck(t, f, FrameTypeKeepalive) assert.Equal(t, d, f.Data()) assert.Equal(t, pos, f.LastReceivedPosition()) @@ -63,7 +65,7 @@ func TestFrameKeepalive(t *testing.T) { func TestFrameLease(t *testing.T) { metadata := []byte("foobar") n := uint32(4444) - f := NewFrameLease(time.Second, n, metadata) + f := NewLeaseFrame(time.Second, n, metadata) basicCheck(t, f, FrameTypeLease) assert.Equal(t, time.Second, f.TimeToLive()) assert.Equal(t, n, f.NumberOfRequests()) @@ -72,16 +74,16 @@ func TestFrameLease(t *testing.T) { func TestFrameMetadataPush(t *testing.T) { metadata := []byte("foobar") - f := NewFrameMetadataPush(metadata) + f := NewMetadataPushFrame(metadata) basicCheck(t, f, FrameTypeMetadataPush) metadata2, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, metadata, metadata2) } -func TestFramePayload(t *testing.T) { +func TestPayloadFrame(t *testing.T) { b := []byte("foobar") - f := NewFramePayload(_sid, b, b, FlagNext) + f := NewPayloadFrame(_sid, b, b, FlagNext) basicCheck(t, f, FrameTypePayload) m, ok := f.Metadata() assert.True(t, ok) @@ -90,10 +92,29 @@ func TestFramePayload(t *testing.T) { assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) } +func TestPayloadFrameSupport(t *testing.T) { + b := []byte("foobar") + f := NewPayloadFrameSupport(_sid, b, b, FlagNext) + fmt.Println("len:", f.Len()) + bf := &bytes.Buffer{} + _, err := f.WriteTo(bf) + assert.NoError(t, err, "write failed") + raw := bf.Bytes() + bb := common.NewByteBuff() + bb.Write(raw[6:]) + f2, err := FromRawFrame(NewRawFrame(ParseFrameHeader(raw[0:6]), bb)) + assert.NoError(t, err, "new frame failed") + f3 := f2.(*PayloadFrame) + fmt.Println("streamID:", f3.Header().StreamID()) + fmt.Println("data:", f3.DataUTF8()) + fmt.Println("metadata:", f3.MustMetadataUTF8()) + fmt.Println("flags:", f3.Header().Flag()) +} + func TestFrameRequestChannel(t *testing.T) { b := []byte("foobar") n := uint32(1) - f := NewFrameRequestChannel(_sid, n, b, b, FlagNext) + f := NewRequestChannelFrame(_sid, n, b, b, FlagNext) basicCheck(t, f, FrameTypeRequestChannel) assert.Equal(t, n, f.InitialRequestN()) assert.Equal(t, b, f.Data()) @@ -104,14 +125,14 @@ func TestFrameRequestChannel(t *testing.T) { func TestFrameRequestN(t *testing.T) { n := uint32(1234) - f := NewFrameRequestN(_sid, n) + f := NewRequestNFrame(_sid, n, 0) basicCheck(t, f, FrameTypeRequestN) assert.Equal(t, n, f.N()) } func TestFrameRequestResponse(t *testing.T) { b := []byte("foobar") - f := NewFrameRequestResponse(_sid, b, b, FlagNext) + f := NewRequestResponseFrame(_sid, b, b, FlagNext) basicCheck(t, f, FrameTypeRequestResponse) assert.Equal(t, b, f.Data()) m, ok := f.Metadata() @@ -123,7 +144,7 @@ func TestFrameRequestResponse(t *testing.T) { func TestFrameRequestStream(t *testing.T) { b := []byte("foobar") n := uint32(1234) - f := NewFrameRequestStream(_sid, n, b, b, FlagNext) + f := NewRequestStreamFrame(_sid, n, b, b, FlagNext) basicCheck(t, f, FrameTypeRequestStream) assert.Equal(t, b, f.Data()) assert.Equal(t, n, f.InitialRequestN()) @@ -137,7 +158,7 @@ func TestFrameResume(t *testing.T) { token := []byte("hello") p1 := uint64(333) p2 := uint64(444) - f := NewFrameResume(v, token, p1, p2) + f := NewResumeFrame(v, token, p1, p2) basicCheck(t, f, FrameTypeResume) assert.Equal(t, token, f.Token()) assert.Equal(t, p1, f.FirstAvailableClientPosition()) @@ -148,33 +169,57 @@ func TestFrameResume(t *testing.T) { func TestFrameResumeOK(t *testing.T) { pos := uint64(1234) - f := NewResumeOK(pos) + f := NewResumeOKFrame(pos) basicCheck(t, f, FrameTypeResumeOK) assert.Equal(t, pos, f.LastReceivedClientPosition()) } func TestFrameSetup(t *testing.T) { v := common.NewVersion(3, 1) - timeKeepalive := 30 * time.Second - maxLifetime := 3 * timeKeepalive - token := []byte("hello") - mimeData := []byte("application/json") - mimeMetadata := []byte("text/plain") - d := []byte(`{"hello":"world"}`) - m := []byte("foobar") - f := NewFrameSetup(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) - basicCheck(t, f, FrameTypeSetup) - assert.Equal(t, v.Major(), f.Version().Major()) - assert.Equal(t, v.Minor(), f.Version().Minor()) - assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) - assert.Equal(t, maxLifetime, f.MaxLifetime()) - assert.Equal(t, token, f.Token()) - assert.Equal(t, string(mimeData), f.DataMimeType()) - assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) - assert.Equal(t, d, f.Data()) - m2, ok := f.Metadata() - assert.True(t, ok) - assert.Equal(t, m, m2) + timeKeepalive := 20 * time.Second + maxLifetime := time.Minute + 30*time.Second + var token []byte + mimeData := []byte("application/binary") + mimeMetadata := []byte("application/binary") + d := []byte("你好") + m := []byte("世界") + f := NewSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + + doCheck := func(f *SetupFrame) { + fmt.Println("length:", f.Len()) + basicCheck(t, f, FrameTypeSetup) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) + assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) + assert.Equal(t, maxLifetime, f.MaxLifetime()) + assert.Equal(t, token, f.Token()) + assert.Equal(t, string(mimeData), f.DataMimeType()) + assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) + assert.Equal(t, d, f.Data()) + m2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, m, m2) + } + + doCheck(f) + + su := NewSetupFrameSupport(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + bf := &bytes.Buffer{} + + _, err := su.WriteTo(bf) + assert.NoError(t, err, "write failed") + + raw := bf.Bytes() + + assert.Equal(t, f.Len(), su.Len(), "wrong length") + + h := ParseFrameHeader(raw[:6]) + + bb := common.NewByteBuff() + _, _ = bb.Write(raw[6:]) + f2, err := FromRawFrame(NewRawFrame(h, bb)) + assert.NoError(t, err, "recreate setup frame failed") + doCheck(f2.(*SetupFrame)) } func TestDecode_Payload(t *testing.T) { @@ -195,12 +240,12 @@ func TestDecode_Payload(t *testing.T) { //log.Println(h) bf := common.NewByteBuff() _, _ = bf.Write(bs[HeaderLen:]) - f, err := NewFromBase(NewBaseFrame(h, bf)) + f, err := FromRawFrame(NewRawFrame(h, bf)) assert.NoError(t, err, "decode failed") log.Println(f) } - lease := NewFrameLease(3*time.Second, 5, nil) + lease := NewLeaseFrame(3*time.Second, 5, nil) log.Println("actual:", hex.EncodeToString(lease.Bytes())) log.Println("should: 00000000090000000bb800000005") } @@ -214,5 +259,4 @@ func basicCheck(t *testing.T, f Frame, typ FrameType) { assert.Equal(t, sid, f.Header().StreamID(), "wrong frame stream id") assert.NoError(t, f.Validate(), "validate frame type failed") assert.Equal(t, typ, f.Header().Type(), "frame type doesn't match") - assert.NotEmpty(t, f.String(), "empty frame string") } diff --git a/internal/framing/header.go b/internal/framing/header.go index 2b9a519..f4e5409 100644 --- a/internal/framing/header.go +++ b/internal/framing/header.go @@ -12,51 +12,64 @@ const ( HeaderLen = 6 ) -// FrameHeader is the header fo a RSocket frame. +// Header is the header fo a RSocket frame. // RSocket frames begin with a RSocket Frame Header. // It includes StreamID, FrameType and Flags. -type FrameHeader [HeaderLen]byte +type Header [HeaderLen]byte -func (p FrameHeader) String() string { +func (h Header) String() string { bu := strings.Builder{} - bu.WriteString("FrameHeader{id=") - bu.WriteString(strconv.FormatUint(uint64(p.StreamID()), 10)) + bu.WriteString("Header{id=") + bu.WriteString(strconv.FormatUint(uint64(h.StreamID()), 10)) bu.WriteString(",type=") - bu.WriteString(p.Type().String()) + bu.WriteString(h.Type().String()) bu.WriteString(",flag=") - bu.WriteString(p.Flag().String()) + bu.WriteString(h.Flag().String()) bu.WriteByte('}') return bu.String() } +// Resumable returns true if frame supports resume. +func (h Header) Resumable() bool { + switch h.Type() { + case FrameTypeRequestChannel, FrameTypeRequestStream, FrameTypeRequestResponse, FrameTypeRequestFNF, FrameTypeRequestN, FrameTypeCancel, FrameTypeError, FrameTypePayload: + return true + default: + return false + } +} + // WriteTo writes frame header to a writer. -func (p FrameHeader) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write(p[:]) +func (h Header) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(h[:]) return int64(n), err } // StreamID returns StreamID. -func (p FrameHeader) StreamID() uint32 { - return binary.BigEndian.Uint32(p[:4]) +func (h Header) StreamID() uint32 { + return binary.BigEndian.Uint32(h[:4]) } // Type returns frame type. -func (p FrameHeader) Type() FrameType { - return FrameType((p.n() & 0xFC00) >> 10) +func (h Header) Type() FrameType { + return FrameType((h.n() & 0xFC00) >> 10) } // Flag returns flag of a frame. -func (p FrameHeader) Flag() FrameFlag { - return FrameFlag(p.n() & 0x03FF) +func (h Header) Flag() FrameFlag { + return FrameFlag(h.n() & 0x03FF) +} + +func (h Header) Bytes() []byte { + return h[:] } -func (p FrameHeader) n() uint16 { - return binary.BigEndian.Uint16(p[4:]) +func (h Header) n() uint16 { + return binary.BigEndian.Uint16(h[4:]) } // NewFrameHeader returns a new frame header. -func NewFrameHeader(streamID uint32, frameType FrameType, flags ...FrameFlag) FrameHeader { - fg := newFlags(flags...) +func NewFrameHeader(streamID uint32, frameType FrameType, fg FrameFlag) Header { var h [HeaderLen]byte binary.BigEndian.PutUint32(h[:], streamID) binary.BigEndian.PutUint16(h[4:], uint16(frameType)<<10|uint16(fg)) @@ -65,7 +78,7 @@ func NewFrameHeader(streamID uint32, frameType FrameType, flags ...FrameFlag) Fr } // ParseFrameHeader parse a header from bytes. -func ParseFrameHeader(bs []byte) FrameHeader { +func ParseFrameHeader(bs []byte) Header { _ = bs[HeaderLen-1] var bb [HeaderLen]byte copy(bb[:], bs[:HeaderLen]) diff --git a/internal/framing/misc.go b/internal/framing/misc.go index 3a53c71..1cddce1 100644 --- a/internal/framing/misc.go +++ b/internal/framing/misc.go @@ -1,6 +1,8 @@ package framing import ( + "io" + "github.com/rsocket/rsocket-go/internal/common" ) @@ -13,39 +15,68 @@ func CalcPayloadFrameSize(data, metadata []byte) int { return size } -// NewFromBase creates a frame from a BaseFrame. -func NewFromBase(f *BaseFrame) (frame Frame, err error) { +// FromRawFrame creates a frame from a RawFrame. +func FromRawFrame(f *RawFrame) (frame Frame, err error) { switch f.header.Type() { case FrameTypeSetup: - frame = &FrameSetup{BaseFrame: f} + frame = &SetupFrame{RawFrame: f} case FrameTypeKeepalive: - frame = &FrameKeepalive{BaseFrame: f} + frame = &KeepaliveFrame{RawFrame: f} case FrameTypeRequestResponse: - frame = &FrameRequestResponse{BaseFrame: f} + frame = &RequestResponseFrame{RawFrame: f} case FrameTypeRequestFNF: - frame = &FrameFNF{BaseFrame: f} + frame = &FireAndForgetFrame{RawFrame: f} case FrameTypeRequestStream: - frame = &FrameRequestStream{BaseFrame: f} + frame = &RequestStreamFrame{RawFrame: f} case FrameTypeRequestChannel: - frame = &FrameRequestChannel{BaseFrame: f} + frame = &RequestChannelFrame{RawFrame: f} case FrameTypeCancel: - frame = &FrameCancel{BaseFrame: f} + frame = &CancelFrame{RawFrame: f} case FrameTypePayload: - frame = &FramePayload{BaseFrame: f} + frame = &PayloadFrame{RawFrame: f} case FrameTypeMetadataPush: - frame = &FrameMetadataPush{BaseFrame: f} + frame = &MetadataPushFrame{RawFrame: f} case FrameTypeError: - frame = &FrameError{BaseFrame: f} + frame = &ErrorFrame{RawFrame: f} case FrameTypeRequestN: - frame = &FrameRequestN{BaseFrame: f} + frame = &RequestNFrame{RawFrame: f} case FrameTypeLease: - frame = &FrameLease{BaseFrame: f} + frame = &LeaseFrame{RawFrame: f} case FrameTypeResume: - frame = &FrameResume{BaseFrame: f} + frame = &ResumeFrame{RawFrame: f} case FrameTypeResumeOK: - frame = &FrameResumeOK{BaseFrame: f} + frame = &ResumeOKFrame{RawFrame: f} default: err = common.ErrInvalidFrame } return } + +func writePayload(w io.Writer, data []byte, metadata []byte) (n int64, err error) { + if l := len(metadata); l > 0 { + var wrote int64 + u := common.MustNewUint24(l) + wrote, err = u.WriteTo(w) + if err != nil { + return + } + n += wrote + + var v int + v, err = w.Write(metadata) + if err != nil { + return + } + n += int64(v) + } + + if l := len(data); l > 0 { + var v int + v, err = w.Write(data) + if err != nil { + return + } + n += int64(v) + } + return +} diff --git a/internal/socket/client_default.go b/internal/socket/client_default.go index 7cb05bb..31ffefd 100644 --- a/internal/socket/client_default.go +++ b/internal/socket/client_default.go @@ -29,7 +29,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err if setup.Lease { p.refreshLease(0, 0) tp.HandleLease(func(frame framing.Frame) (err error) { - lease := frame.(*framing.FrameLease) + lease := frame.(*framing.LeaseFrame) p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) logger.Infof(">>>>> refresh lease: %v\n", lease) return @@ -37,7 +37,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err } tp.HandleDisaster(func(frame framing.Frame) (err error) { - p.socket.SetError(frame.(*framing.FrameError)) + p.socket.SetError(frame.(*framing.ErrorFrame)) return }) diff --git a/internal/socket/client_resume.go b/internal/socket/client_resume.go index cac1323..3ac8c0f 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/client_resume.go @@ -78,12 +78,12 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { } }(ctx, tp) - var f framing.Frame + var f framing.FrameSupport // connect first time. if len(p.setup.Token) < 1 || connects == 1 { tp.HandleDisaster(func(frame framing.Frame) (err error) { - p.socket.SetError(frame.(*framing.FrameError)) + p.socket.SetError(frame.(*framing.ErrorFrame)) p.markClosing() return }) @@ -94,7 +94,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { return } - f = framing.NewFrameResume( + f = framing.NewResumeFrameSupport( common.DefaultVersion, p.setup.Token, p.socket.counter.WriteBytes(), @@ -110,7 +110,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { tp.HandleDisaster(func(frame framing.Frame) (err error) { // TODO: process other error with zero StreamID - f := frame.(*framing.FrameError) + f := frame.(*framing.ErrorFrame) if f.ErrorCode() == common.ErrorCodeRejectedResume { resumeErr <- f.Error() close(resumeErr) diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 60a9ebc..6a0873f 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -23,7 +23,7 @@ import ( "go.uber.org/atomic" ) -const outsSize = 64 +const _outChanSize = 64 var ( errSocketClosed = errors.New("socket closed already") @@ -41,8 +41,8 @@ func IsSocketClosedError(err error) bool { type DuplexRSocket struct { counter *transport.Counter tp *transport.Transport - outs chan framing.Frame - outsPriority []framing.Frame + outs chan framing.FrameSupport + outsPriority []framing.FrameSupport responder Responder messages common.U32Map sids StreamID @@ -135,20 +135,20 @@ func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { } sid := p.nextStreamID() if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameFNF(sid, data, m)) + p.sendFrame(framing.NewFireAndForgetFrameSupport(sid, data, m, 0)) return } p.doSplit(data, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { var f framing.Frame if idx == 0 { h := framing.NewFrameHeader(sid, framing.FrameTypeRequestFNF, fg) - f = &framing.FrameFNF{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.FireAndForgetFrame{ + RawFrame: framing.NewRawFrame(h, body), } } else { h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), } } p.sendFrame(f) @@ -158,7 +158,7 @@ func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { // MetadataPush start a request of MetadataPush. func (p *DuplexRSocket) MetadataPush(payload payload.Payload) { metadata, _ := payload.Metadata() - p.sendFrame(framing.NewFrameMetadataPush(metadata)) + p.sendFrame(framing.NewMetadataPushFrameSupport(metadata)) } // RequestResponse start a request of RequestResponse. @@ -173,7 +173,7 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { mo = resp. DoFinally(func(s rx.SignalType) { if s == rx.SignalCancel { - p.sendFrame(framing.NewFrameCancel(sid)) + p.sendFrame(framing.NewCancelFrameSupport(sid)) } p.unregister(sid) }) @@ -182,20 +182,21 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { // sending... size := framing.CalcPayloadFrameSize(data, metadata) if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameRequestResponse(sid, data, metadata)) + p.sendFrame(framing.NewRequestResponseFrameSupport(sid, data, metadata, 0)) + //p.sendFrame(framing.NewRequestResponseFrame(sid, data, metadata, 0)) return } p.doSplit(data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { var f framing.Frame if idx == 0 { h := framing.NewFrameHeader(sid, framing.FrameTypeRequestResponse, fg) - f = &framing.FrameRequestResponse{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.RequestResponseFrame{ + RawFrame: framing.NewRawFrame(h, body), } } else { h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), } } p.sendFrame(f) @@ -216,7 +217,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { ret = pc. DoFinally(func(sig rx.SignalType) { if sig == rx.SignalCancel { - p.sendFrame(framing.NewFrameCancel(sid)) + p.sendFrame(framing.NewCancelFrameSupport(sid)) } p.unregister(sid) }). @@ -232,7 +233,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { } if !newborn { - frameN := framing.NewFrameRequestN(sid, n32) + frameN := framing.NewRequestNFrameSupport(sid, n32, 0) p.sendFrame(frameN) <-frameN.DoneNotify() return @@ -243,7 +244,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { size := framing.CalcPayloadFrameSize(data, metadata) + 4 if !p.shouldSplit(size) { - p.sendFrame(framing.NewFrameRequestStream(sid, n32, data, metadata)) + p.sendFrame(framing.NewRequestStreamFrameSupport(sid, n32, data, metadata, 0)) return } p.doSplitSkip(4, data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { @@ -252,13 +253,13 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { h := framing.NewFrameHeader(sid, framing.FrameTypeRequestStream, fg) // write init RequestN binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.FrameRequestStream{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.RequestStreamFrame{ + RawFrame: framing.NewRawFrame(h, body), } } else { h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), } } p.sendFrame(f) @@ -290,7 +291,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { close(rcvRequested) } if !newborn { - frameN := framing.NewFrameRequestN(sid, n32) + frameN := framing.NewRequestNFrameSupport(sid, n32, 0) p.sendFrame(frameN) <-frameN.DoneNotify() return @@ -316,7 +317,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { size := framing.CalcPayloadFrameSize(d, m) + 4 if !p.shouldSplit(size) { metadata, _ := item.Metadata() - p.sendFrame(framing.NewFrameRequestChannel(sid, n32, item.Data(), metadata, framing.FlagNext)) + p.sendFrame(framing.NewRequestChannelFrameSupport(sid, n32, item.Data(), metadata, framing.FlagNext)) return } p.doSplitSkip(4, d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { @@ -325,13 +326,13 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { h := framing.NewFrameHeader(sid, framing.FrameTypeRequestChannel, fg|framing.FlagNext) // write init RequestN binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.FrameRequestChannel{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.RequestChannelFrame{ + RawFrame: framing.NewRawFrame(h, body), } } else { h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + f = &framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), } } p.sendFrame(f) @@ -347,7 +348,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { // TODO: handle cancel or error switch sig { case rx.SignalComplete: - complete := framing.NewFramePayload(sid, nil, nil, framing.FlagComplete) + complete := framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete) p.sendFrame(complete) <-complete.DoneNotify() default: @@ -362,7 +363,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { func (p *DuplexRSocket) onFrameRequestResponse(frame framing.Frame) error { // fragment - receiving, ok := p.doFragment(frame.(*framing.FrameRequestResponse)) + receiving, ok := p.doFragment(frame.(*framing.RequestResponseFrame)) if !ok { return nil } @@ -387,7 +388,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd } // 3. sending error with unsupported handler if sending == nil { - p.writeError(sid, framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestResponse)) + p.writeError(sid, framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestResponse)) return nil } @@ -414,7 +415,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd } func (p *DuplexRSocket) onFrameRequestChannel(input framing.Frame) error { - receiving, ok := p.doFragment(input.(*framing.FrameRequestChannel)) + receiving, ok := p.doFragment(input.(*framing.RequestChannelFrame)) if !ok { return nil } @@ -425,10 +426,10 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) // seek initRequestN var initRequestN int switch v := pl.(type) { - case *framing.FrameRequestChannel: + case *framing.RequestChannelFrame: initRequestN = toIntN(v.InitialRequestN()) case fragmentation.Joiner: - initRequestN = toIntN(v.First().(*framing.FrameRequestChannel).InitialRequestN()) + initRequestN = toIntN(v.First().(*framing.RequestChannelFrame).InitialRequestN()) default: panic("unreachable") } @@ -452,7 +453,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) } }). DoOnRequest(func(n int) { - frameN := framing.NewFrameRequestN(sid, toU32N(n)) + frameN := framing.NewRequestNFrameSupport(sid, toU32N(n), 0) p.sendFrame(frameN) <-frameN.DoneNotify() }) @@ -468,7 +469,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) }() flux = p.responder.RequestChannel(receiving) if flux == nil { - err = framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestChannel) + err = framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestChannel) } return }() @@ -486,7 +487,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) p.writeError(sid, e) }), rx.OnComplete(func() { - complete := framing.NewFramePayload(sid, nil, nil, framing.FlagComplete) + complete := framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete) p.sendFrame(complete) <-complete.DoneNotify() }), @@ -526,12 +527,12 @@ func (p *DuplexRSocket) respondMetadataPush(input framing.Frame) (err error) { logger.Errorf("respond METADATA_PUSH failed: %s\n", e) } }() - p.responder.MetadataPush(input.(*framing.FrameMetadataPush)) + p.responder.MetadataPush(input.(*framing.MetadataPushFrame)) return } func (p *DuplexRSocket) onFrameFNF(frame framing.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.FrameFNF)) + receiving, ok := p.doFragment(frame.(*framing.FireAndForgetFrame)) if !ok { return nil } @@ -549,7 +550,7 @@ func (p *DuplexRSocket) respondFNF(receiving fragmentation.HeaderAndPayload) (er } func (p *DuplexRSocket) onFrameRequestStream(frame framing.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.FrameRequestStream)) + receiving, ok := p.doFragment(frame.(*framing.RequestStreamFrame)) if !ok { return nil } @@ -566,7 +567,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa }() resp = p.responder.RequestStream(receiving) if resp == nil { - err = framing.NewFrameError(sid, common.ErrorCodeApplicationError, unsupportedRequestStream) + err = framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestStream) } return }() @@ -580,10 +581,10 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa // seek n32 var n32 int switch v := receiving.(type) { - case *framing.FrameRequestStream: + case *framing.RequestStreamFrame: n32 = int(v.InitialRequestN()) case fragmentation.Joiner: - n32 = int(v.First().(*framing.FrameRequestStream).InitialRequestN()) + n32 = int(v.First().(*framing.RequestStreamFrame).InitialRequestN()) default: panic("unreachable") } @@ -600,7 +601,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa p.writeError(sid, e) }), rx.OnComplete(func() { - p.sendFrame(framing.NewFramePayload(sid, nil, nil, framing.FlagComplete)) + p.sendFrame(framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete)) }), ) @@ -620,12 +621,12 @@ func (p *DuplexRSocket) writeError(sid uint32, e error) { return } switch err := e.(type) { - case *framing.FrameError: + case *framing.ErrorFrame: p.sendFrame(err) case common.CustomError: - p.sendFrame(framing.NewFrameError(sid, err.ErrorCode(), err.ErrorData())) + p.sendFrame(framing.NewErrorFrameSupport(sid, err.ErrorCode(), err.ErrorData())) default: - p.sendFrame(framing.NewFrameError(sid, common.ErrorCodeApplicationError, []byte(e.Error()))) + p.sendFrame(framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, []byte(e.Error()))) } } @@ -635,10 +636,11 @@ func (p *DuplexRSocket) SetResponder(responder Responder) { } func (p *DuplexRSocket) onFrameKeepalive(frame framing.Frame) (err error) { - f := frame.(*framing.FrameKeepalive) + f := frame.(*framing.KeepaliveFrame) if f.Header().Flag().Check(framing.FlagRespond) { - f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) - p.sendFrame(f) + k := framing.NewKeepaliveFrame(f.LastReceivedPosition(), f.Data(), false) + //f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) + p.sendFrame(k) } return } @@ -668,7 +670,7 @@ func (p *DuplexRSocket) onFrameCancel(frame framing.Frame) (err error) { } func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { - f := input.(*framing.FrameError) + f := input.(*framing.ErrorFrame) logger.Errorf("handle error frame: %s\n", f) sid := f.Header().StreamID() @@ -692,7 +694,7 @@ func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { } func (p *DuplexRSocket) onFrameRequestN(input framing.Frame) (err error) { - f := input.(*framing.FrameRequestN) + f := input.(*framing.RequestNFrame) sid := f.Header().StreamID() v, ok := p.messages.Load(sid) if !ok { @@ -738,7 +740,7 @@ func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fr } func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { - pl, ok := p.doFragment(frame.(*framing.FramePayload)) + pl, ok := p.doFragment(frame.(*framing.PayloadFrame)) if !ok { return nil } @@ -829,7 +831,7 @@ func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { p.cond.L.Unlock() } -func (p *DuplexRSocket) sendFrame(f framing.Frame) { +func (p *DuplexRSocket) sendFrame(f framing.FrameSupport) { defer func() { if e := recover(); e != nil { logger.Warnf("send frame failed: %s\n", e) @@ -848,18 +850,19 @@ func (p *DuplexRSocket) sendPayload( size := framing.CalcPayloadFrameSize(d, m) if !p.shouldSplit(size) { - p.sendFrame(framing.NewFramePayload(sid, d, m, frameFlag)) + p.sendFrame(framing.NewPayloadFrameSupport(sid, d, m, frameFlag)) + //p.sendFrame(framing.NewPayloadFrame(sid, d, m, frameFlag)) return } p.doSplit(d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var h framing.FrameHeader + var h framing.Header if idx == 0 { h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|frameFlag) } else { h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) } - p.sendFrame(&framing.FramePayload{ - BaseFrame: framing.NewBaseFrame(h, body), + p.sendFrame(&framing.PayloadFrame{ + RawFrame: framing.NewRawFrame(h, body), }) }) } @@ -868,11 +871,11 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) if len(p.outs) > 0 { p.drain(nil) } - var out framing.Frame + var out framing.FrameSupport select { case <-p.keepaliver.C(): ok = true - out = framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) + out = framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) if p.tp != nil { err := p.tp.Send(out, true) if err != nil { @@ -884,7 +887,7 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) if !ok { return } - out = framing.NewFrameLease(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) + out = framing.NewLeaseFrameSupport(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) if p.tp == nil { p.outsPriority = append(p.outsPriority, out) } else if err := p.tp.Send(out, true); err != nil { @@ -909,12 +912,12 @@ func (p *DuplexRSocket) drainWithKeepalive() (ok bool) { if len(p.outs) > 0 { p.drain(nil) } - var out framing.Frame + var out framing.FrameSupport select { case <-p.keepaliver.C(): ok = true - out = framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) + out = framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) if p.tp != nil { err := p.tp.Send(out, true) if err != nil { @@ -947,7 +950,7 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { if !ok { return false } - if p.drainOne(framing.NewFrameLease(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { + if p.drainOne(framing.NewLeaseFrameSupport(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { flush = true } case out, ok := <-p.outs: @@ -967,7 +970,7 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { return true } -func (p *DuplexRSocket) drainOne(out framing.Frame) (wrote bool) { +func (p *DuplexRSocket) drainOne(out framing.FrameSupport) (wrote bool) { if p.tp == nil { p.outsPriority = append(p.outsPriority, out) return @@ -992,7 +995,7 @@ func (p *DuplexRSocket) drainOutBack() { if p.tp == nil { return } - var out framing.Frame + var out framing.FrameSupport for i := range p.outsPriority { out = p.outsPriority[i] if err := p.tp.Send(out, false); err != nil { @@ -1023,7 +1026,7 @@ func (p *DuplexRSocket) loopWriteWithKeepaliver(ctx context.Context, leaseChan < select { case <-p.keepaliver.C(): - kf := framing.NewFrameKeepalive(p.counter.ReadBytes(), nil, true) + kf := framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) if p.tp != nil { err := p.tp.Send(kf, true) if err != nil { @@ -1114,7 +1117,7 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { return &DuplexRSocket{ closed: atomic.NewBool(false), leases: leases, - outs: make(chan framing.Frame, outsSize), + outs: make(chan framing.FrameSupport, _outChanSize), mtu: mtu, messages: common.NewU32Map(), sids: &serverStreamIDs{}, @@ -1134,7 +1137,7 @@ func NewClientDuplexRSocket( ka := newKeepaliver(keepaliveInterval) s = &DuplexRSocket{ closed: atomic.NewBool(false), - outs: make(chan framing.Frame, outsSize), + outs: make(chan framing.FrameSupport, _outChanSize), mtu: mtu, messages: common.NewU32Map(), sids: &clientStreamIDs{}, diff --git a/internal/socket/misc.go b/internal/socket/misc.go index 6bcdb8a..3c258b1 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -22,8 +22,8 @@ type SetupInfo struct { Metadata []byte } -func (p *SetupInfo) toFrame() *framing.FrameSetup { - return framing.NewFrameSetup( +func (p *SetupInfo) toFrame() framing.FrameSupport { + return framing.NewSetupFrameSupport( p.Version, p.KeepaliveInterval, p.KeepaliveLifetime, diff --git a/internal/transport/connection.go b/internal/transport/connection.go index 767cda0..d4badc4 100644 --- a/internal/transport/connection.go +++ b/internal/transport/connection.go @@ -18,7 +18,7 @@ type Conn interface { // Read reads next frame from Conn. Read() (framing.Frame, error) // Write writes a frame to Conn. - Write(frames framing.Frame) error + Write(frames framing.FrameSupport) error // Flush. Flush() error } diff --git a/internal/transport/connection_tcp.go b/internal/transport/connection_tcp.go index 4aeda3b..2998bbe 100644 --- a/internal/transport/connection_tcp.go +++ b/internal/transport/connection_tcp.go @@ -43,11 +43,11 @@ func (p *tcpConn) Read() (f framing.Frame, err error) { err = errors.Wrap(err, "read frame failed") return } - base := framing.NewBaseFrame(h, bf) - if p.counter != nil && base.CanResume() { + base := framing.NewRawFrame(h, bf) + if p.counter != nil && base.Header().Resumable() { p.counter.incrReadBytes(base.Len()) } - f, err = framing.NewFromBase(base) + f, err = framing.FromRawFrame(base) if err != nil { err = errors.Wrap(err, "read frame failed") return @@ -71,9 +71,9 @@ func (p *tcpConn) Flush() (err error) { return } -func (p *tcpConn) Write(frame framing.Frame) (err error) { +func (p *tcpConn) Write(frame framing.FrameSupport) (err error) { size := frame.Len() - if p.counter != nil && frame.CanResume() { + if p.counter != nil && frame.Header().Resumable() { p.counter.incrWriteBytes(size) } _, err = common.MustNewUint24(size).WriteTo(p.writer) @@ -83,7 +83,7 @@ func (p *tcpConn) Write(frame framing.Frame) (err error) { } var debugStr string if logger.IsDebugEnabled() { - debugStr = frame.String() + debugStr = framing.PrintFrame(frame) } _, err = frame.WriteTo(p.writer) if err != nil { diff --git a/internal/transport/connection_ws.go b/internal/transport/connection_ws.go index 1bbd2d5..9fa6f37 100644 --- a/internal/transport/connection_ws.go +++ b/internal/transport/connection_ws.go @@ -1,7 +1,9 @@ package transport import ( + "bytes" "io" + "sync" "time" "github.com/gorilla/websocket" @@ -11,20 +13,24 @@ import ( "github.com/rsocket/rsocket-go/logger" ) -type wsConnection struct { +var _buffPool = sync.Pool{ + New: func() interface{} { return &bytes.Buffer{} }, +} + +type wsConn struct { c *websocket.Conn counter *Counter } -func (p *wsConnection) SetCounter(c *Counter) { +func (p *wsConn) SetCounter(c *Counter) { p.counter = c } -func (p *wsConnection) SetDeadline(deadline time.Time) error { +func (p *wsConn) SetDeadline(deadline time.Time) error { return p.c.SetReadDeadline(deadline) } -func (p *wsConnection) Read() (f framing.Frame, err error) { +func (p *wsConn) Read() (f framing.Frame, err error) { t, raw, err := p.c.ReadMessage() if err != nil { err = errors.Wrap(err, "read frame failed") @@ -46,8 +52,8 @@ func (p *wsConnection) Read() (f framing.Frame, err error) { err = errors.Wrap(err, "read frame failed") return } - base := framing.NewBaseFrame(header, bf) - f, err = framing.NewFromBase(base) + base := framing.NewRawFrame(header, bf) + f, err = framing.FromRawFrame(base) if err != nil { err = errors.Wrap(err, "read frame failed") return @@ -63,12 +69,21 @@ func (p *wsConnection) Read() (f framing.Frame, err error) { return } -func (p *wsConnection) Flush() (err error) { +func (p *wsConn) Flush() (err error) { return } -func (p *wsConnection) Write(frame framing.Frame) (err error) { - err = p.c.WriteMessage(websocket.BinaryMessage, frame.Bytes()) +func (p *wsConn) Write(frame framing.FrameSupport) (err error) { + bf := _buffPool.Get().(*bytes.Buffer) + defer func() { + bf.Reset() + _buffPool.Put(bf) + }() + _, err = frame.WriteTo(bf) + if err != nil { + return + } + err = p.c.WriteMessage(websocket.BinaryMessage, bf.Bytes()) if err == io.EOF { return } @@ -82,12 +97,12 @@ func (p *wsConnection) Write(frame framing.Frame) (err error) { return } -func (p *wsConnection) Close() error { +func (p *wsConn) Close() error { return p.c.Close() } -func newWebsocketConnection(rawConn *websocket.Conn) *wsConnection { - return &wsConnection{ +func newWebsocketConnection(rawConn *websocket.Conn) *wsConn { + return &wsConn{ c: rawConn, } } diff --git a/internal/transport/decoder_test.go b/internal/transport/decoder_test.go index 39c362a..9728c55 100644 --- a/internal/transport/decoder_test.go +++ b/internal/transport/decoder_test.go @@ -24,7 +24,7 @@ func TestDecoder(t *testing.T) { h := framing.ParseFrameHeader(raw) bf := common.NewByteBuff() _, _ = bf.Write(raw[framing.HeaderLen:]) - f, err := framing.NewFromBase(framing.NewBaseFrame(h, bf)) + f, err := framing.FromRawFrame(framing.NewRawFrame(h, bf)) if err != nil { panic(err) } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 58690f2..d420ce7 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -75,7 +75,7 @@ func (p *Transport) SetLifetime(lifetime time.Duration) { } // Send send a frame. -func (p *Transport) Send(frame framing.Frame, flush bool) (err error) { +func (p *Transport) Send(frame framing.FrameSupport, flush bool) (err error) { defer func() { // ensure frame done when send success. if err == nil { @@ -148,7 +148,7 @@ L: if err != nil { break L } - err = p.DeliveryFrame(ctx, f) + err = p.DispatchFrame(ctx, f) if err != nil { break L } @@ -233,8 +233,8 @@ func (p *Transport) HandleKeepalive(handler FrameHandler) { p.hKeepalive = handler } -// DeliveryFrame delivery incoming frames. -func (p *Transport) DeliveryFrame(_ context.Context, frame framing.Frame) (err error) { +// DispatchFrame delivery incoming frames. +func (p *Transport) DispatchFrame(_ context.Context, frame framing.Frame) (err error) { header := frame.Header() t := header.Type() sid := header.StreamID() @@ -243,12 +243,12 @@ func (p *Transport) DeliveryFrame(_ context.Context, frame framing.Frame) (err e switch t { case framing.FrameTypeSetup: - p.maxLifetime = frame.(*framing.FrameSetup).MaxLifetime() + p.maxLifetime = frame.(*framing.SetupFrame).MaxLifetime() handler = p.hSetup case framing.FrameTypeResume: handler = p.hResume case framing.FrameTypeResumeOK: - p.lastRcvPos = frame.(*framing.FrameResumeOK).LastReceivedClientPosition() + p.lastRcvPos = frame.(*framing.ResumeOKFrame).LastReceivedClientPosition() handler = p.hResumeOK case framing.FrameTypeRequestFNF: handler = p.hFireAndForget @@ -271,7 +271,7 @@ func (p *Transport) DeliveryFrame(_ context.Context, frame framing.Frame) (err e handler = p.hRequestN case framing.FrameTypeError: if sid == 0 { - err = errors.New(frame.(*framing.FrameError).Error()) + err = errors.New(frame.(*framing.ErrorFrame).Error()) if p.hError0 != nil { _ = p.hError0(frame) } @@ -281,7 +281,7 @@ func (p *Transport) DeliveryFrame(_ context.Context, frame framing.Frame) (err e case framing.FrameTypeCancel: handler = p.hCancel case framing.FrameTypeKeepalive: - ka := frame.(*framing.FrameKeepalive) + ka := frame.(*framing.KeepaliveFrame) p.lastRcvPos = ka.LastReceivedPosition() handler = p.hKeepalive case framing.FrameTypeLease: diff --git a/server.go b/server.go index 1a05b5a..3d06fb7 100644 --- a/server.go +++ b/server.go @@ -207,9 +207,9 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { } switch frame := first.(type) { - case *framing.FrameResume: + case *framing.ResumeFrame: p.doResume(frame, tp, socketChan) - case *framing.FrameSetup: + case *framing.SetupFrame: sendingSocket, err := p.doSetup(frame, tp, socketChan) if err != nil { _ = tp.Send(err, true) @@ -222,7 +222,7 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { } }(ctx, sendingSocket) default: - err := framing.NewFrameError(0, common.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) + err := framing.NewErrorFrameSupport(0, common.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) _ = tp.Send(err, true) _ = tp.Close() return @@ -242,13 +242,10 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { return t.Listen(ctx, serveNotifier) } -func (p *server) doSetup( - frame *framing.FrameSetup, - tp *transport.Transport, - socketChan chan<- socket.ServerSocket, -) (sendingSocket socket.ServerSocket, err *framing.FrameError) { +func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) (sendingSocket socket.ServerSocket, err *framing.ErrorFrameSupport) { + if frame.Header().Flag().Check(framing.FlagLease) && p.leases == nil { - err = framing.NewFrameError(0, common.ErrorCodeUnsupportedSetup, errUnavailableLease) + err = framing.NewErrorFrameSupport(0, common.ErrorCodeUnsupportedSetup, errUnavailableLease) return } @@ -256,7 +253,7 @@ func (p *server) doSetup( // 1. receive a token but server doesn't support resume. if isResume && !p.resumeOpts.enable { - err = framing.NewFrameError(0, common.ErrorCodeUnsupportedSetup, errUnavailableResume) + err = framing.NewErrorFrameSupport(0, common.ErrorCodeUnsupportedSetup, errUnavailableResume) return } @@ -266,7 +263,7 @@ func (p *server) doSetup( if !isResume { sendingSocket = socket.NewServer(rawSocket) if responder, e := p.acc(frame, sendingSocket); e != nil { - err = framing.NewFrameError(0, common.ErrorCodeRejectedSetup, []byte(e.Error())) + err = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedSetup, []byte(e.Error())) } else { sendingSocket.SetResponder(responder) sendingSocket.SetTransport(tp) @@ -279,7 +276,7 @@ func (p *server) doSetup( // 3. resume reject because of duplicated token. if _, ok := p.sm.Load(token); ok { - err = framing.NewFrameError(0, common.ErrorCodeRejectedSetup, errDuplicatedSetupToken) + err = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedSetup, errDuplicatedSetupToken) return } @@ -288,10 +285,10 @@ func (p *server) doSetup( sendingSocket = socket.NewServerResume(rawSocket, token) if responder, e := p.acc(frame, sendingSocket); e != nil { switch vv := e.(type) { - case *framing.FrameError: - err = framing.NewFrameError(0, vv.ErrorCode(), vv.ErrorData()) + case *framing.ErrorFrame: + err = framing.NewErrorFrameSupport(0, vv.ErrorCode(), vv.ErrorData()) default: - err = framing.NewFrameError(0, common.ErrorCodeInvalidSetup, []byte(e.Error())) + err = framing.NewErrorFrameSupport(0, common.ErrorCodeInvalidSetup, []byte(e.Error())) } } else { sendingSocket.SetResponder(responder) @@ -301,19 +298,19 @@ func (p *server) doSetup( return } -func (p *server) doResume(frame *framing.FrameResume, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { - var sending framing.Frame +func (p *server) doResume(frame *framing.ResumeFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { + var sending framing.FrameSupport if !p.resumeOpts.enable { - sending = framing.NewFrameError(0, common.ErrorCodeRejectedResume, errUnavailableResume) + sending = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedResume, errUnavailableResume) } else if s, ok := p.sm.Load(frame.Token()); ok { - sending = framing.NewResumeOK(0) + sending = framing.NewResumeOKFrameSupport(0) s.Socket().SetTransport(tp) socketChan <- s.Socket() if logger.IsDebugEnabled() { logger.Debugf("recover session: %s\n", s) } } else { - sending = framing.NewFrameError( + sending = framing.NewErrorFrameSupport( 0, common.ErrorCodeRejectedResume, []byte("no such session"), From ae38b8940de6530767b8beb7c0e12a03bd6a1170 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 21 Jun 2020 00:02:17 +0800 Subject: [PATCH 05/26] Fix test compile error. --- internal/fragmentation/joiner_test.go | 6 +++--- internal/framing/frame.go | 16 +++++----------- internal/framing/frame_test.go | 2 +- justfile | 8 ++++++++ 4 files changed, 17 insertions(+), 15 deletions(-) create mode 100644 justfile diff --git a/internal/fragmentation/joiner_test.go b/internal/fragmentation/joiner_test.go index 005217f..eac91cc 100644 --- a/internal/fragmentation/joiner_test.go +++ b/internal/fragmentation/joiner_test.go @@ -11,17 +11,17 @@ import ( func TestFragmentPayload(t *testing.T) { const totals = 10 const sid = uint32(1) - fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow, framing.FlagMetadata)) + fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow|framing.FlagMetadata)) for i := 0; i < totals; i++ { data := fmt.Sprintf("(data%04d)", i) var frame *framing.PayloadFrame if i < 3 { meta := fmt.Sprintf("(meta%04d)", i) - frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), framing.FlagFollow, framing.FlagMetadata) + frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), framing.FlagFollow|framing.FlagMetadata) } else if i != totals-1 { frame = framing.NewPayloadFrame(sid, []byte(data), nil, framing.FlagFollow) } else { - frame = framing.NewPayloadFrame(sid, []byte(data), nil) + frame = framing.NewPayloadFrame(sid, []byte(data), nil, 0) } fr.Push(frame) } diff --git a/internal/framing/frame.go b/internal/framing/frame.go index 7079eb0..32b7d95 100644 --- a/internal/framing/frame.go +++ b/internal/framing/frame.go @@ -2,6 +2,7 @@ package framing import ( "errors" + "fmt" "io" "strings" @@ -109,16 +110,8 @@ const ( ) // Check returns true if mask exists. -func (f FrameFlag) Check(mask FrameFlag) bool { - return mask&f == mask -} - -func newFlags(flags ...FrameFlag) FrameFlag { - var fg FrameFlag - for _, it := range flags { - fg |= it - } - return fg +func (f FrameFlag) Check(flag FrameFlag) bool { + return flag&f == flag } type FrameSupport interface { @@ -134,7 +127,8 @@ type FrameSupport interface { } func PrintFrame(f FrameSupport) string { - return "// TODO: print frame" + // TODO: print frame + return fmt.Sprintf("%+v", f) } // Frame is a single message containing a request, response, or protocol processing. diff --git a/internal/framing/frame_test.go b/internal/framing/frame_test.go index f676b1e..0f4fff5 100644 --- a/internal/framing/frame_test.go +++ b/internal/framing/frame_test.go @@ -101,7 +101,7 @@ func TestPayloadFrameSupport(t *testing.T) { assert.NoError(t, err, "write failed") raw := bf.Bytes() bb := common.NewByteBuff() - bb.Write(raw[6:]) + _, _ = bb.Write(raw[6:]) f2, err := FromRawFrame(NewRawFrame(ParseFrameHeader(raw[0:6]), bb)) assert.NoError(t, err, "new frame failed") f3 := f2.(*PayloadFrame) diff --git a/justfile b/justfile new file mode 100644 index 0000000..f3ec1e7 --- /dev/null +++ b/justfile @@ -0,0 +1,8 @@ +default: + echo 'Hello, world!' +lint: + golangci-lint run ./... +test: + go test -race -count=1 . -v +fmt: + @go fmt ./... From 1ec6a14376a289fc5fd4e00987ad593359003e84 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 21 Jun 2020 16:35:42 +0800 Subject: [PATCH 06/26] Add unit tests and fix some docs. --- extension/authentication_test.go | 2 +- internal/common/u32map.go | 6 + internal/fragmentation/splitter.go | 64 +++---- .../fragmentation/splitter_benchmark_test.go | 3 +- internal/fragmentation/splitter_test.go | 15 +- .../{fragmentation.go => types.go} | 0 internal/framing/frame.go | 21 +-- internal/framing/frame_lease.go | 21 +-- internal/framing/frame_payload.go | 32 +++- internal/framing/frame_setup_test.go | 37 ---- internal/framing/frame_test.go | 178 ++++++++---------- internal/framing/header_test.go | 4 +- internal/session/manager.go | 8 +- internal/session/session_test.go | 38 ++-- internal/socket/duplex.go | 133 +++++-------- internal/socket/msg.go | 26 +-- internal/socket/smap_test.go | 4 +- rx/flux/flux_test.go | 145 ++++++++++++-- rx/flux/utils.go | 22 ++- rx/mono/mono.go | 16 ++ rx/mono/utils.go | 8 + rx/rx.go | 2 +- server.go | 12 +- 23 files changed, 430 insertions(+), 367 deletions(-) rename internal/fragmentation/{fragmentation.go => types.go} (100%) delete mode 100644 internal/framing/frame_setup_test.go diff --git a/extension/authentication_test.go b/extension/authentication_test.go index 35151e1..99a77c6 100644 --- a/extension/authentication_test.go +++ b/extension/authentication_test.go @@ -34,9 +34,9 @@ func TestParseAuthentication(t *testing.T) { rand.Seed(time.Now().UnixNano()) input := make([]byte, 2) rand.Read(input) + input[0] &= ^uint8(0x80) _, err := extension.ParseAuthentication(input) assert.True(t, extension.IsInvalidAuthenticationBytes(err), "should error") - } func TestAuthentication(t *testing.T) { diff --git a/internal/common/u32map.go b/internal/common/u32map.go index b35bd75..a4ea297 100644 --- a/internal/common/u32map.go +++ b/internal/common/u32map.go @@ -109,6 +109,12 @@ func (u *u32slot) innerRange(fn func(k uint32, v interface{}) bool) bool { return true } +func NewU32MapLite() U32Map { + return &u32slot{ + m: make(map[uint32]interface{}), + } +} + func NewU32Map() U32Map { var slots [_slots]*u32slot for i := 0; i < len(slots); i++ { diff --git a/internal/fragmentation/splitter.go b/internal/fragmentation/splitter.go index 124736e..b5fc72d 100644 --- a/internal/fragmentation/splitter.go +++ b/internal/fragmentation/splitter.go @@ -5,20 +5,25 @@ import ( "github.com/rsocket/rsocket-go/internal/framing" ) -type splitResult struct { - f framing.FrameFlag - b *common.ByteBuff +// HandleSplitResult is callback for fragmentation result. +type HandleSplitResult = func(index int, result SplitResult) + +// SplitResult defines fragmentation result struct. +type SplitResult struct { + Flag framing.FrameFlag + Metadata []byte + Data []byte } // Split split data and metadata in frame. -func Split(mtu int, data []byte, metadata []byte, onFrame func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { +func Split(mtu int, data []byte, metadata []byte, onFrame HandleSplitResult) { SplitSkip(mtu, 0, data, metadata, onFrame) } // SplitSkip skip some bytes and split data and metadata in frame. -func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { - ch := make(chan splitResult, 3) - go func(mtu int, skip int, data []byte, metadata []byte, ch chan splitResult) { +func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame HandleSplitResult) { + ch := make(chan SplitResult, 3) + go func(mtu int, skip int, data []byte, metadata []byte, ch chan SplitResult) { defer func() { close(ch) }() @@ -28,7 +33,6 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx var follow bool for { bf = common.NewByteBuff() - var wroteM int left := mtu - framing.HeaderLen if idx == 0 && skip > 0 { left -= skip @@ -41,51 +45,35 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx hasMetadata := cursor1 < lenM if hasMetadata { left -= 3 - // write metadata length placeholder - if err := bf.WriteUint24(0); err != nil { - panic(err) - } } begin1, begin2 := cursor1, cursor2 for wrote := 0; wrote < left; wrote++ { if cursor1 < lenM { - wroteM++ cursor1++ } else if cursor2 < lenD { cursor2++ } } - if _, err := bf.Write(metadata[begin1:cursor1]); err != nil { - panic(err) - } - if _, err := bf.Write(data[begin2:cursor2]); err != nil { - panic(err) - } + curMetadata := metadata[begin1:cursor1] + curData := data[begin2:cursor2] follow = cursor1+cursor2 < lenM+lenD - var fg framing.FrameFlag + var flag framing.FrameFlag if follow { - fg |= framing.FlagFollow + flag |= framing.FlagFollow } else { - fg &= ^framing.FlagFollow + flag &= ^framing.FlagFollow } - if wroteM > 0 { - // set metadata length - x := common.MustNewUint24(wroteM) - for i := 0; i < len(x); i++ { - if idx == 0 { - bf.Bytes()[i+skip] = x[i] - } else { - bf.Bytes()[i] = x[i] - } - } - fg |= framing.FlagMetadata + if hasMetadata { + // metadata + flag |= framing.FlagMetadata } else { // non-metadata - fg &= ^framing.FlagMetadata + flag &= ^framing.FlagMetadata } - ch <- splitResult{ - f: fg, - b: bf, + ch <- SplitResult{ + Flag: flag, + Metadata: curMetadata, + Data: curData, } if !follow { break @@ -96,7 +84,7 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame func(idx var idx int for v := range ch { if onFrame != nil { - onFrame(idx, v.f, v.b) + onFrame(idx, v) } idx++ } diff --git a/internal/fragmentation/splitter_benchmark_test.go b/internal/fragmentation/splitter_benchmark_test.go index 8d1d67d..2828d98 100644 --- a/internal/fragmentation/splitter_benchmark_test.go +++ b/internal/fragmentation/splitter_benchmark_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) func BenchmarkToFragments(b *testing.B) { @@ -13,7 +12,7 @@ func BenchmarkToFragments(b *testing.B) { data := []byte(common.RandAlphanumeric(4 * 1024 * 1024)) metadata := []byte(common.RandAlphanumeric(1024 * 1024)) - fn := func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { + fn := func(idx int, result SplitResult) { } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { diff --git a/internal/fragmentation/splitter_test.go b/internal/fragmentation/splitter_test.go index 22f2ff1..7f99319 100644 --- a/internal/fragmentation/splitter_test.go +++ b/internal/fragmentation/splitter_test.go @@ -23,17 +23,14 @@ func TestSplitter_Split(t *testing.T) { } func split2joiner(mtu int, data, metadata []byte) (joiner Joiner, err error) { - fn := func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { + fn := func(idx int, result SplitResult) { + sid := uint32(77778888) if idx == 0 { - h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, framing.FlagComplete|fg) - joiner = NewJoiner(&framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - }) + f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, framing.FlagComplete|result.Flag) + joiner = NewJoiner(f) } else { - h := framing.NewFrameHeader(77778888, framing.FrameTypePayload, fg) - joiner.Push(&framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - }) + f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag) + joiner.Push(f) } } Split(mtu, data, metadata, fn) diff --git a/internal/fragmentation/fragmentation.go b/internal/fragmentation/types.go similarity index 100% rename from internal/fragmentation/fragmentation.go rename to internal/fragmentation/types.go diff --git a/internal/framing/frame.go b/internal/framing/frame.go index 32b7d95..5fb5845 100644 --- a/internal/framing/frame.go +++ b/internal/framing/frame.go @@ -176,6 +176,9 @@ func (f *RawFrame) Body() *common.ByteBuff { // Len returns length of frame. func (f *RawFrame) Len() int { + if f.body == nil { + return HeaderLen + } return HeaderLen + f.body.Len() } @@ -187,22 +190,16 @@ func (f *RawFrame) WriteTo(w io.Writer) (n int64, err error) { return } n += wrote - wrote, err = f.body.WriteTo(w) - if err != nil { - return + if f.body != nil { + wrote, err = f.body.WriteTo(w) + if err != nil { + return + } + n += wrote } - n += wrote return } -// Bytes returns frame in bytes. -func (f *RawFrame) Bytes() []byte { - ret := make([]byte, HeaderLen+f.body.Len()) - copy(ret[:HeaderLen], f.header.Bytes()) - copy(ret[HeaderLen:], f.body.Bytes()) - return ret -} - func (f *RawFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { raw := f.body.Bytes() if offset > 0 { diff --git a/internal/framing/frame_lease.go b/internal/framing/frame_lease.go index 87c66b9..b4b7621 100644 --- a/internal/framing/frame_lease.go +++ b/internal/framing/frame_lease.go @@ -75,29 +75,20 @@ func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } n += int64(v) - if !l.header.Flag().Check(FlagMetadata) { - return - } - - u := common.MustNewUint24(len(l.metadata)) - wrote, err = u.WriteTo(w) - if err != nil { - return + if l.header.Flag().Check(FlagMetadata) { + v, err = w.Write(l.metadata) + if err != nil { + return + } + n += int64(v) } - n += wrote - v, err = w.Write(l.metadata) - if err != nil { - return - } - n += int64(v) return } func (l LeaseFrameSupport) Len() int { n := HeaderLen + 8 if l.header.Flag().Check(FlagMetadata) { - n += 3 n += len(l.metadata) } return n diff --git a/internal/framing/frame_payload.go b/internal/framing/frame_payload.go index 6d7315c..d97c377 100644 --- a/internal/framing/frame_payload.go +++ b/internal/framing/frame_payload.go @@ -11,12 +11,6 @@ type PayloadFrame struct { *RawFrame } -type PayloadFrameSupport struct { - *tinyFrame - metadata []byte - data []byte -} - // Validate returns error if frame is invalid. func (p *PayloadFrame) Validate() (err error) { // Minimal length should be 3 if metadata exists. @@ -58,6 +52,32 @@ func (p *PayloadFrame) DataUTF8() string { return string(p.Data()) } +type PayloadFrameSupport struct { + *tinyFrame + metadata []byte + data []byte +} + +func (p PayloadFrameSupport) DataUTF8() string { + return string(p.data) +} + +func (p PayloadFrameSupport) MetadataUTF8() (metadata string, ok bool) { + if p.header.Flag().Check(FlagMetadata) { + metadata = string(p.metadata) + ok = true + } + return +} + +func (p PayloadFrameSupport) Data() []byte { + return p.data +} + +func (p PayloadFrameSupport) Metadata() ([]byte, bool) { + return p.metadata, p.header.Flag().Check(FlagMetadata) +} + func (p PayloadFrameSupport) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = p.header.WriteTo(w) diff --git a/internal/framing/frame_setup_test.go b/internal/framing/frame_setup_test.go deleted file mode 100644 index d6db7a7..0000000 --- a/internal/framing/frame_setup_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package framing - -import ( - "bytes" - "log" - "testing" - "time" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestDecodeFrameSetup(t *testing.T) { - metadata := []byte("hello") - data := []byte("world") - mimeMetadata, mimeData := []byte("text/plain"), []byte("application/json") - token := []byte(common.RandAlphanumeric(16)) - setup := NewSetupFrame( - common.DefaultVersion, - 30*time.Second, - 90*time.Second, - token, - mimeMetadata, - mimeData, - data, - metadata, - false, - ) - log.Println(setup) - assert.Equal(t, "1.0", setup.Version().String()) - assert.True(t, bytes.Equal(token, setup.Token()), "bad token") - assert.Equal(t, string(mimeMetadata), setup.MetadataMimeType(), "bad mime metadata") - assert.Equal(t, string(mimeData), setup.DataMimeType(), "bad mime data") - m, _ := setup.Metadata() - assert.Equal(t, metadata, m, "bad metadata") - assert.Equal(t, data, setup.Data(), "bad data") -} diff --git a/internal/framing/frame_test.go b/internal/framing/frame_test.go index 0f4fff5..e9ab7f2 100644 --- a/internal/framing/frame_test.go +++ b/internal/framing/frame_test.go @@ -2,9 +2,6 @@ package framing_test import ( "bytes" - "encoding/hex" - "fmt" - "log" "math" "testing" "time" @@ -18,139 +15,145 @@ const _sid uint32 = 1 func TestFrameCancel(t *testing.T) { f := NewCancelFrame(_sid) - basicCheck(t, f, FrameTypeCancel) + checkBasic(t, f, FrameTypeCancel) + f2 := NewCancelFrameSupport(_sid) + checkBytes(t, f, f2) } func TestFrameError(t *testing.T) { errData := []byte(common.RandAlphanumeric(100)) f := NewErrorFrame(_sid, common.ErrorCodeApplicationError, errData) - basicCheck(t, f, FrameTypeError) + checkBasic(t, f, FrameTypeError) assert.Equal(t, common.ErrorCodeApplicationError, f.ErrorCode()) assert.Equal(t, errData, f.ErrorData()) assert.NotEmpty(t, f.Error()) + f2 := NewErrorFrame(_sid, common.ErrorCodeApplicationError, errData) + checkBytes(t, f, f2) } func TestFrameFNF(t *testing.T) { b := []byte(common.RandAlphanumeric(100)) // Without Metadata f := NewFireAndForgetFrame(_sid, b, nil, FlagNext) - basicCheck(t, f, FrameTypeRequestFNF) + checkBasic(t, f, FrameTypeRequestFNF) assert.Equal(t, b, f.Data()) metadata, ok := f.Metadata() assert.False(t, ok) assert.Nil(t, metadata) assert.True(t, f.Header().Flag().Check(FlagNext)) assert.False(t, f.Header().Flag().Check(FlagMetadata)) + f2 := NewFireAndForgetFrameSupport(_sid, b, nil, FlagNext) + checkBytes(t, f, f2) + // With Metadata f = NewFireAndForgetFrame(_sid, nil, b, FlagNext) - basicCheck(t, f, FrameTypeRequestFNF) + checkBasic(t, f, FrameTypeRequestFNF) assert.Empty(t, f.Data()) metadata, ok = f.Metadata() assert.True(t, ok) assert.Equal(t, b, metadata) assert.True(t, f.Header().Flag().Check(FlagNext)) assert.True(t, f.Header().Flag().Check(FlagMetadata)) + f2 = NewFireAndForgetFrameSupport(_sid, nil, b, FlagNext) + checkBytes(t, f, f2) } func TestFrameKeepalive(t *testing.T) { pos := uint64(common.RandIntn(math.MaxInt32)) d := []byte(common.RandAlphanumeric(100)) f := NewKeepaliveFrame(pos, d, true) - basicCheck(t, f, FrameTypeKeepalive) + checkBasic(t, f, FrameTypeKeepalive) assert.Equal(t, d, f.Data()) assert.Equal(t, pos, f.LastReceivedPosition()) assert.True(t, f.Header().Flag().Check(FlagRespond)) + f2 := NewKeepaliveFrameSupport(pos, d, true) + checkBytes(t, f, f2) } func TestFrameLease(t *testing.T) { metadata := []byte("foobar") n := uint32(4444) f := NewLeaseFrame(time.Second, n, metadata) - basicCheck(t, f, FrameTypeLease) + checkBasic(t, f, FrameTypeLease) assert.Equal(t, time.Second, f.TimeToLive()) assert.Equal(t, n, f.NumberOfRequests()) assert.Equal(t, metadata, f.Metadata()) + f2 := NewLeaseFrameSupport(time.Second, n, metadata) + checkBytes(t, f, f2) } func TestFrameMetadataPush(t *testing.T) { metadata := []byte("foobar") f := NewMetadataPushFrame(metadata) - basicCheck(t, f, FrameTypeMetadataPush) + checkBasic(t, f, FrameTypeMetadataPush) metadata2, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, metadata, metadata2) + f2 := NewMetadataPushFrameSupport(metadata) + checkBytes(t, f, f2) } func TestPayloadFrame(t *testing.T) { b := []byte("foobar") f := NewPayloadFrame(_sid, b, b, FlagNext) - basicCheck(t, f, FrameTypePayload) + checkBasic(t, f, FrameTypePayload) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, f.Data()) assert.Equal(t, b, m) assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) -} - -func TestPayloadFrameSupport(t *testing.T) { - b := []byte("foobar") - f := NewPayloadFrameSupport(_sid, b, b, FlagNext) - fmt.Println("len:", f.Len()) - bf := &bytes.Buffer{} - _, err := f.WriteTo(bf) - assert.NoError(t, err, "write failed") - raw := bf.Bytes() - bb := common.NewByteBuff() - _, _ = bb.Write(raw[6:]) - f2, err := FromRawFrame(NewRawFrame(ParseFrameHeader(raw[0:6]), bb)) - assert.NoError(t, err, "new frame failed") - f3 := f2.(*PayloadFrame) - fmt.Println("streamID:", f3.Header().StreamID()) - fmt.Println("data:", f3.DataUTF8()) - fmt.Println("metadata:", f3.MustMetadataUTF8()) - fmt.Println("flags:", f3.Header().Flag()) + f2 := NewPayloadFrameSupport(_sid, b, b, FlagNext) + checkBytes(t, f, f2) } func TestFrameRequestChannel(t *testing.T) { b := []byte("foobar") n := uint32(1) f := NewRequestChannelFrame(_sid, n, b, b, FlagNext) - basicCheck(t, f, FrameTypeRequestChannel) + checkBasic(t, f, FrameTypeRequestChannel) assert.Equal(t, n, f.InitialRequestN()) assert.Equal(t, b, f.Data()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) + f2 := NewRequestChannelFrameSupport(_sid, n, b, b, FlagNext) + checkBytes(t, f, f2) } func TestFrameRequestN(t *testing.T) { n := uint32(1234) f := NewRequestNFrame(_sid, n, 0) - basicCheck(t, f, FrameTypeRequestN) + checkBasic(t, f, FrameTypeRequestN) assert.Equal(t, n, f.N()) + f2 := NewRequestNFrameSupport(_sid, n, 0) + checkBytes(t, f, f2) } func TestFrameRequestResponse(t *testing.T) { b := []byte("foobar") f := NewRequestResponseFrame(_sid, b, b, FlagNext) - basicCheck(t, f, FrameTypeRequestResponse) + checkBasic(t, f, FrameTypeRequestResponse) assert.Equal(t, b, f.Data()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) + f2 := NewRequestResponseFrameSupport(_sid, b, b, FlagNext) + checkBytes(t, f, f2) } func TestFrameRequestStream(t *testing.T) { b := []byte("foobar") n := uint32(1234) f := NewRequestStreamFrame(_sid, n, b, b, FlagNext) - basicCheck(t, f, FrameTypeRequestStream) + checkBasic(t, f, FrameTypeRequestStream) assert.Equal(t, b, f.Data()) assert.Equal(t, n, f.InitialRequestN()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) + f2 := NewRequestStreamFrameSupport(_sid, n, b, b, FlagNext) + checkBytes(t, f, f2) } func TestFrameResume(t *testing.T) { @@ -159,19 +162,23 @@ func TestFrameResume(t *testing.T) { p1 := uint64(333) p2 := uint64(444) f := NewResumeFrame(v, token, p1, p2) - basicCheck(t, f, FrameTypeResume) + checkBasic(t, f, FrameTypeResume) assert.Equal(t, token, f.Token()) assert.Equal(t, p1, f.FirstAvailableClientPosition()) assert.Equal(t, p2, f.LastReceivedServerPosition()) assert.Equal(t, v.Major(), f.Version().Major()) assert.Equal(t, v.Minor(), f.Version().Minor()) + f2 := NewResumeFrameSupport(v, token, p1, p2) + checkBytes(t, f, f2) } func TestFrameResumeOK(t *testing.T) { pos := uint64(1234) f := NewResumeOKFrame(pos) - basicCheck(t, f, FrameTypeResumeOK) + checkBasic(t, f, FrameTypeResumeOK) assert.Equal(t, pos, f.LastReceivedClientPosition()) + f2 := NewResumeOKFrameSupport(pos) + checkBytes(t, f, f2) } func TestFrameSetup(t *testing.T) { @@ -184,73 +191,25 @@ func TestFrameSetup(t *testing.T) { d := []byte("你好") m := []byte("世界") f := NewSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + checkBasic(t, f, FrameTypeSetup) + assert.Equal(t, v.Major(), f.Version().Major()) + assert.Equal(t, v.Minor(), f.Version().Minor()) + assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) + assert.Equal(t, maxLifetime, f.MaxLifetime()) + assert.Equal(t, token, f.Token()) + assert.Equal(t, string(mimeData), f.DataMimeType()) + assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) + assert.Equal(t, d, f.Data()) + m2, ok := f.Metadata() + assert.True(t, ok) + assert.Equal(t, m, m2) - doCheck := func(f *SetupFrame) { - fmt.Println("length:", f.Len()) - basicCheck(t, f, FrameTypeSetup) - assert.Equal(t, v.Major(), f.Version().Major()) - assert.Equal(t, v.Minor(), f.Version().Minor()) - assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) - assert.Equal(t, maxLifetime, f.MaxLifetime()) - assert.Equal(t, token, f.Token()) - assert.Equal(t, string(mimeData), f.DataMimeType()) - assert.Equal(t, string(mimeMetadata), f.MetadataMimeType()) - assert.Equal(t, d, f.Data()) - m2, ok := f.Metadata() - assert.True(t, ok) - assert.Equal(t, m, m2) - } - - doCheck(f) - - su := NewSetupFrameSupport(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) - bf := &bytes.Buffer{} - - _, err := su.WriteTo(bf) - assert.NoError(t, err, "write failed") - - raw := bf.Bytes() - - assert.Equal(t, f.Len(), su.Len(), "wrong length") - - h := ParseFrameHeader(raw[:6]) - - bb := common.NewByteBuff() - _, _ = bb.Write(raw[6:]) - f2, err := FromRawFrame(NewRawFrame(h, bb)) - assert.NoError(t, err, "recreate setup frame failed") - doCheck(f2.(*SetupFrame)) -} - -func TestDecode_Payload(t *testing.T) { - //s := "000000012940000005776f726c6468656c6c6f" // go - //s := "00000001296000000966726f6d5f6a617661706f6e67" //java - - var all []string - all = append(all, "0000000004400001000000004e2000015f90126170706c69636174696f6e2f62696e617279126170706c69636174696f6e2f62696e617279") - all = append(all, "00000000090000000bb800000005") - all = append(all, "00000000090000001b5800000005") - all = append(all, "000000011100000000436c69656e74207265717565737420547565204f63742032322032303a31373a3333204353542032303139") - all = append(all, "00000001286053657276657220526573706f6e736520547565204f63742032322032303a31373a3333204353542032303139") - - for _, s := range all { - bs, err := hex.DecodeString(s) - assert.NoError(t, err, "bad bytes") - h := ParseFrameHeader(bs[:HeaderLen]) - //log.Println(h) - bf := common.NewByteBuff() - _, _ = bf.Write(bs[HeaderLen:]) - f, err := FromRawFrame(NewRawFrame(h, bf)) - assert.NoError(t, err, "decode failed") - log.Println(f) - } + fs := NewSetupFrameSupport(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) - lease := NewLeaseFrame(3*time.Second, 5, nil) - log.Println("actual:", hex.EncodeToString(lease.Bytes())) - log.Println("should: 00000000090000000bb800000005") + checkBytes(t, f, fs) } -func basicCheck(t *testing.T, f Frame, typ FrameType) { +func checkBasic(t *testing.T, f Frame, typ FrameType) { sid := _sid switch typ { case FrameTypeKeepalive, FrameTypeSetup, FrameTypeLease, FrameTypeResume, FrameTypeResumeOK, FrameTypeMetadataPush: @@ -259,4 +218,25 @@ func basicCheck(t *testing.T, f Frame, typ FrameType) { assert.Equal(t, sid, f.Header().StreamID(), "wrong frame stream id") assert.NoError(t, f.Validate(), "validate frame type failed") assert.Equal(t, typ, f.Header().Type(), "frame type doesn't match") + assert.True(t, f.Header().Type().String() != "UNKNOWN") + go func() { + f.Done() + }() + <-f.DoneNotify() +} + +func checkBytes(t *testing.T, a Frame, b FrameSupport) { + assert.Equal(t, a.Len(), b.Len()) + bf1, bf2 := &bytes.Buffer{}, &bytes.Buffer{} + _, err := a.WriteTo(bf1) + assert.NoError(t, err, "write failed") + _, err = b.WriteTo(bf2) + assert.NoError(t, err, "write failed") + b1, b2 := bf1.Bytes(), bf2.Bytes() + assert.Equal(t, b1, b2, "bytes doesn't match") + bf := common.NewByteBuff() + _, _ = bf.Write(b1[HeaderLen:]) + raw := NewRawFrame(ParseFrameHeader(b1[:HeaderLen]), bf) + _, err = FromRawFrame(raw) + assert.NoError(t, err, "create from raw failed") } diff --git a/internal/framing/header_test.go b/internal/framing/header_test.go index 6032965..6cc88f0 100644 --- a/internal/framing/header_test.go +++ b/internal/framing/header_test.go @@ -13,8 +13,10 @@ import ( func TestHeader_All(t *testing.T) { id := uint32(common.RandIntn(math.MaxUint32)) h1 := NewFrameHeader(id, FrameTypePayload, FlagMetadata|FlagComplete|FlagNext) - assert.NotEmpty(t, h1.String()) + assert.NotEmpty(t, h1.String(), "header string is blank") + assert.True(t, h1.Resumable()) h2 := ParseFrameHeader(h1[:]) + assert.Equal(t, h1[:], h2.Bytes()) assert.Equal(t, h1.StreamID(), h2.StreamID()) assert.Equal(t, h1.Type(), h2.Type()) assert.Equal(t, h1.Flag(), h2.Flag()) diff --git a/internal/session/manager.go b/internal/session/manager.go index efcfa6d..74575d5 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -15,8 +15,8 @@ type Manager struct { // Len returns size of session in current manager. func (p *Manager) Len() (n int) { p.locker.RLock() + defer p.locker.RUnlock() n = len(*p.h) - p.locker.RUnlock() return } @@ -31,32 +31,32 @@ func (p *Manager) Push(session *Session) { // Load returns session with custom token. func (p *Manager) Load(token []byte) (session *Session, ok bool) { p.locker.RLock() + defer p.locker.RUnlock() session, ok = p.m[(string)(token)] - p.locker.RUnlock() return } // Remove remove a session with custom token. func (p *Manager) Remove(token []byte) (session *Session, ok bool) { p.locker.Lock() + defer p.locker.Unlock() session, ok = p.m[(string)(token)] if ok && session.index > -1 { heap.Remove(p.h, session.index) delete(p.m, (string)(token)) session.index = -1 } - p.locker.Unlock() return } // Pop pop earliest session. func (p *Manager) Pop() (session *Session) { p.locker.Lock() + defer p.locker.Unlock() session = heap.Pop(p.h).(*Session) if session != nil { delete(p.m, (string)(session.Token())) } - p.locker.Unlock() return } diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 371bcba..300471e 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -1,28 +1,40 @@ -package session +package session_test import ( - "log" + "fmt" "testing" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/internal/session" "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" ) func TestSession(t *testing.T) { - manager := NewManager() - for i := 0; i < 3; i++ { - deadline := time.Now().Add(time.Duration(common.RandIntn(30)) * time.Second) - token := common.RandAlphanumeric(32) - manager.Push(NewSession(deadline, socket.NewServerResume(nil, []byte(token)))) + const total = 100 + var tokens []string + manager := session.NewManager() + for i := 0; i < total; i++ { + deadline := time.Now().Add(time.Duration(i+1) * time.Second) + token := fmt.Sprintf("token_%d", i) + tokens = append(tokens, token) + manager.Push(session.NewSession(deadline, socket.NewServerResume(nil, []byte(token)))) } - for _, value := range *(manager.h) { - session, ok := manager.Load(value.Token()) - log.Printf("session=%s,ok=%t\n", session, ok) + for _, token := range tokens { + s, ok := manager.Load([]byte(token)) + assert.True(t, ok) + assert.NotNil(t, s, "session is nil") } - for manager.Len() > 0 { - log.Println("session:", manager.Pop()) + firstToken := []byte(tokens[0]) + s, ok := manager.Remove(firstToken) + assert.True(t, ok) + assert.NotNil(t, s) + assert.Equal(t, firstToken, s.Token()) + assert.Equal(t, len(tokens)-1, manager.Len()) + for i := 1; i < len(tokens); i++ { + manager.Pop() } + assert.Equal(t, 0, manager.Len()) } diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 6a0873f..c08484e 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -2,7 +2,6 @@ package socket import ( "context" - "encoding/binary" "errors" "fmt" "io" @@ -108,7 +107,7 @@ func (p *DuplexRSocket) Close() error { p.fragments.Clear() p.messages.Range(func(key uint32, value interface{}) bool { - if cc, ok := value.(closerWithError); ok { + if cc, ok := value.(callback); ok { if p.e == nil { go func() { cc.Close(errSocketClosed) @@ -138,18 +137,12 @@ func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { p.sendFrame(framing.NewFireAndForgetFrameSupport(sid, data, m, 0)) return } - p.doSplit(data, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestFNF, fg) - f = &framing.FireAndForgetFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + p.doSplit(data, m, func(index int, result fragmentation.SplitResult) { + var f framing.FrameSupport + if index == 0 { + f = framing.NewFireAndForgetFrameSupport(sid, result.Data, result.Metadata, result.Flag) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) } p.sendFrame(f) }) @@ -166,7 +159,7 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { sid := p.nextStreamID() resp := mono.CreateProcessor() - p.register(sid, reqRR{pc: resp}) + p.register(sid, requestResponseCallback{pc: resp}) data := pl.Data() metadata, _ := pl.Metadata() @@ -183,21 +176,14 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { size := framing.CalcPayloadFrameSize(data, metadata) if !p.shouldSplit(size) { p.sendFrame(framing.NewRequestResponseFrameSupport(sid, data, metadata, 0)) - //p.sendFrame(framing.NewRequestResponseFrame(sid, data, metadata, 0)) return } - p.doSplit(data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestResponse, fg) - f = &framing.RequestResponseFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + p.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { + var f framing.FrameSupport + if index == 0 { + f = framing.NewRequestResponseFrameSupport(sid, result.Data, result.Metadata, result.Flag) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) } p.sendFrame(f) }) @@ -210,7 +196,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { sid := p.nextStreamID() pc := flux.CreateProcessor() - p.register(sid, reqRS{pc: pc}) + p.register(sid, requestStreamCallback{pc: pc}) requested := make(chan struct{}) @@ -247,20 +233,13 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { p.sendFrame(framing.NewRequestStreamFrameSupport(sid, n32, data, metadata, 0)) return } - p.doSplitSkip(4, data, metadata, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestStream, fg) - // write init RequestN - binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.RequestStreamFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + + p.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { + var f framing.FrameSupport + if index == 0 { + f = framing.NewRequestStreamFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) } p.sendFrame(f) }) @@ -320,26 +299,19 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { p.sendFrame(framing.NewRequestChannelFrameSupport(sid, n32, item.Data(), metadata, framing.FlagNext)) return } - p.doSplitSkip(4, d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var f framing.Frame - if idx == 0 { - h := framing.NewFrameHeader(sid, framing.FrameTypeRequestChannel, fg|framing.FlagNext) - // write init RequestN - binary.BigEndian.PutUint32(body.Bytes(), n32) - f = &framing.RequestChannelFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + + p.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { + var f framing.FrameSupport + if index == 0 { + f = framing.NewRequestChannelFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag|framing.FlagNext) } else { - h := framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) - f = &framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - } + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) } p.sendFrame(f) }) }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, reqRC{rcv: receiving, snd: s}) + p.register(sid, requestChannelCallback{rcv: receiving, snd: s}) s.Request(1) }), ) @@ -401,7 +373,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd p.writeError(sid, e) }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRR{su: s}) + p.register(sid, requestResponseCallbackReverse{su: s}) s.Request(rx.RequestMax) }), ) @@ -492,7 +464,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) <-complete.DoneNotify() }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRC{rcv: receivingProcessor, snd: s}) + p.register(sid, requestChannelCallbackReverse{rcv: receivingProcessor, snd: s}) close(mustSub) s.Request(initRequestN) }), @@ -594,7 +566,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa p.sendPayload(sid, elem, framing.FlagNext) }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, resRS{su: s}) + p.register(sid, requestStreamCallbackReverse{su: s}) s.Request(n32) }), rx.OnError(func(e error) { @@ -655,9 +627,9 @@ func (p *DuplexRSocket) onFrameCancel(frame framing.Frame) (err error) { } switch vv := v.(type) { - case resRR: + case requestResponseCallbackReverse: vv.su.Cancel() - case resRS: + case requestStreamCallbackReverse: vv.su.Cancel() default: panic(fmt.Errorf("illegal cancel target: %v", vv)) @@ -681,11 +653,11 @@ func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { } switch vv := v.(type) { - case reqRR: + case requestResponseCallback: vv.pc.Error(f) - case reqRS: + case requestStreamCallback: vv.pc.Error(f) - case reqRC: + case requestChannelCallback: vv.rcv.Error(f) default: panic(fmt.Errorf("illegal value for error: %v", vv)) @@ -705,11 +677,11 @@ func (p *DuplexRSocket) onFrameRequestN(input framing.Frame) (err error) { } n := toIntN(f.N()) switch vv := v.(type) { - case resRS: + case requestStreamCallbackReverse: vv.su.Request(n) - case reqRC: + case requestChannelCallback: vv.snd.Request(n) - case resRC: + case requestChannelCallbackReverse: vv.snd.Request(n) default: panic(fmt.Errorf("illegal requestN for %+v", vv)) @@ -767,9 +739,9 @@ func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { } switch vv := v.(type) { - case reqRR: + case requestResponseCallback: vv.pc.Success(pl) - case reqRS: + case requestStreamCallback: fg := h.Flag() isNext := fg.Check(framing.FlagNext) if isNext { @@ -779,7 +751,7 @@ func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { // Release pure complete payload vv.pc.Complete() } - case reqRC: + case requestChannelCallback: fg := h.Flag() isNext := fg.Check(framing.FlagNext) if isNext { @@ -788,7 +760,7 @@ func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { if fg.Check(framing.FlagComplete) { vv.rcv.Complete() } - case resRC: + case requestChannelCallbackReverse: fg := h.Flag() isNext := fg.Check(framing.FlagNext) if isNext { @@ -851,19 +823,16 @@ func (p *DuplexRSocket) sendPayload( if !p.shouldSplit(size) { p.sendFrame(framing.NewPayloadFrameSupport(sid, d, m, frameFlag)) - //p.sendFrame(framing.NewPayloadFrame(sid, d, m, frameFlag)) return } - p.doSplit(d, m, func(idx int, fg framing.FrameFlag, body *common.ByteBuff) { - var h framing.Header - if idx == 0 { - h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|frameFlag) + p.doSplit(d, m, func(index int, result fragmentation.SplitResult) { + flag := result.Flag + if index == 0 { + flag |= frameFlag } else { - h = framing.NewFrameHeader(sid, framing.FrameTypePayload, fg|framing.FlagNext) + flag |= framing.FlagNext } - p.sendFrame(&framing.PayloadFrame{ - RawFrame: framing.NewRawFrame(h, body), - }) + p.sendFrame(framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, flag)) }) } @@ -1091,11 +1060,11 @@ func (p *DuplexRSocket) loopWrite(ctx context.Context) error { return nil } -func (p *DuplexRSocket) doSplit(data, metadata []byte, handler func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { +func (p *DuplexRSocket) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { fragmentation.Split(p.mtu, data, metadata, handler) } -func (p *DuplexRSocket) doSplitSkip(skip int, data, metadata []byte, handler func(idx int, fg framing.FrameFlag, body *common.ByteBuff)) { +func (p *DuplexRSocket) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { fragmentation.SplitSkip(p.mtu, skip, data, metadata, handler) } @@ -1121,7 +1090,7 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { mtu: mtu, messages: common.NewU32Map(), sids: &serverStreamIDs{}, - fragments: common.NewU32Map(), + fragments: common.NewU32MapLite(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: transport.NewCounter(), @@ -1141,7 +1110,7 @@ func NewClientDuplexRSocket( mtu: mtu, messages: common.NewU32Map(), sids: &clientStreamIDs{}, - fragments: common.NewU32Map(), + fragments: common.NewU32MapLite(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: transport.NewCounter(), diff --git a/internal/socket/msg.go b/internal/socket/msg.go index d088e97..6277cd2 100644 --- a/internal/socket/msg.go +++ b/internal/socket/msg.go @@ -7,60 +7,60 @@ import ( "github.com/rsocket/rsocket-go/rx/mono" ) -type closerWithError interface { +type callback interface { Close(error) } -type reqRS struct { +type requestStreamCallback struct { pc flux.Processor } -func (s reqRS) Close(err error) { +func (s requestStreamCallback) Close(err error) { s.pc.Error(err) } -type reqRR struct { +type requestResponseCallback struct { pc mono.Processor } -func (s reqRR) Close(err error) { +func (s requestResponseCallback) Close(err error) { s.pc.Error(err) } -type reqRC struct { +type requestChannelCallback struct { snd rx.Subscription rcv flux.Processor } -func (s reqRC) Close(err error) { +func (s requestChannelCallback) Close(err error) { s.snd.Cancel() s.rcv.Error(err) } -type resRR struct { +type requestResponseCallbackReverse struct { su rs.Subscription } -func (s resRR) Close(err error) { +func (s requestResponseCallbackReverse) Close(err error) { s.su.Cancel() // TODO: fill err } -type resRS struct { +type requestStreamCallbackReverse struct { su rx.Subscription } -func (s resRS) Close(err error) { +func (s requestStreamCallbackReverse) Close(err error) { s.su.Cancel() // TODO: fill error } -type resRC struct { +type requestChannelCallbackReverse struct { snd rx.Subscription rcv flux.Processor } -func (s resRC) Close(err error) { +func (s requestChannelCallbackReverse) Close(err error) { s.rcv.Error(err) s.snd.Cancel() } diff --git a/internal/socket/smap_test.go b/internal/socket/smap_test.go index 6a698b0..3160c79 100644 --- a/internal/socket/smap_test.go +++ b/internal/socket/smap_test.go @@ -14,7 +14,7 @@ func nextID() uint32 { } func BenchmarkLock(b *testing.B) { - var v reqRC + var v requestChannelCallback m := make(map[uint32]interface{}) var lk sync.RWMutex b.ResetTimer() @@ -40,7 +40,7 @@ func BenchmarkLock(b *testing.B) { } func BenchmarkSync(b *testing.B) { - var v reqRC + var v requestChannelCallback m := &sync.Map{} b.ResetTimer() b.RunParallel(func(pb *testing.PB) { diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 37bfe4c..8921ed7 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -4,22 +4,117 @@ import ( "context" "errors" "fmt" - "log" "strconv" "testing" "time" + nativeFlux "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func TestEmpty(t *testing.T) { + last, err := flux.Empty(). + DoOnNext(func(input payload.Payload) { + assert.FailNow(t, "unreachable") + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Nil(t, last) + first, err := flux.Empty().BlockFirst(context.Background()) + assert.NoError(t, err) + assert.Nil(t, first) +} + +func TestError(t *testing.T) { + err := errors.New("boom") + _, _ = flux.Error(err). + DoOnNext(func(input payload.Payload) { + assert.FailNow(t, "unreachable") + }). + DoOnError(func(e error) { + assert.Equal(t, err, e) + }). + BlockLast(context.Background()) +} + +func TestClone(t *testing.T) { + const total = 10 + source := flux.Create(func(ctx context.Context, s flux.Sink) { + for i := 0; i < total; i++ { + s.Next(payload.NewString(fmt.Sprintf("data_%d", i), "")) + } + s.Complete() + }) + clone := flux.Clone(source) + + c := atomic.NewInt32(0) + last, err := clone. + DoOnNext(func(input payload.Payload) { + c.Inc() + }). + DoOnError(func(e error) { + assert.FailNow(t, "unreachable") + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("data_%d", total-1), last.DataUTF8()) + assert.Equal(t, int32(total), c.Load()) +} + +func TestRaw(t *testing.T) { + const total = 10 + c := atomic.NewInt32(0) + f := flux. + Raw(nativeFlux.Range(0, total).Map(func(v interface{}) interface{} { + return payload.NewString(fmt.Sprintf("data_%d", v.(int)), "") + })) + last, err := f. + DoOnNext(func(input payload.Payload) { + c.Inc() + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(total), c.Load()) + assert.Equal(t, fmt.Sprintf("data_%d", total-1), last.DataUTF8()) + + c.Store(0) + const take = 3 + last, err = f.Take(take). + DoOnNext(func(input payload.Payload) { + c.Inc() + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "data_2", last.DataUTF8()) + assert.Equal(t, int32(take), c.Load()) +} + func TestJust(t *testing.T) { - done := make(chan struct{}) + c := atomic.NewInt32(0) + last, err := flux. + Just( + payload.NewString("foo", ""), + payload.NewString("bar", ""), + payload.NewString("qux", ""), + ). + DoOnNext(func(input payload.Payload) { + c.Inc() + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int32(3), c.Load()) + assert.Equal(t, "qux", last.DataUTF8()) +} + +func TestCreate(t *testing.T) { + const total = 10 f := flux.Create(func(i context.Context, sink flux.Sink) { - for i := 0; i < 10; i++ { + for i := 0; i < total; i++ { sink.Next(payload.NewString(fmt.Sprintf("foo_%04d", i), fmt.Sprintf("bar_%04d", i))) } sink.Complete() @@ -27,43 +122,59 @@ func TestJust(t *testing.T) { var su rx.Subscription + done := make(chan struct{}) + nextRequests := atomic.NewInt32(0) + f. DoOnNext(func(input payload.Payload) { - log.Println("next:", input) + fmt.Println("next:", input) su.Request(1) }). DoOnRequest(func(n int) { - log.Println("request:", n) + fmt.Println("request:", n) + nextRequests.Add(int32(n)) }). DoFinally(func(s rx.SignalType) { - log.Println("finally") + fmt.Println("finally") close(done) }). DoOnComplete(func() { - log.Println("complete") + fmt.Println("complete") }). Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) })) <-done + assert.Equal(t, int32(total+1), nextRequests.Load()) +} + +func TestMap(t *testing.T) { + last, err := flux. + Just(payload.NewString("hello", "")). + Map(func(p payload.Payload) payload.Payload { + return payload.NewString(p.DataUTF8()+" world", "") + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "hello world", last.DataUTF8()) } func TestProcessor(t *testing.T) { - proc := flux.CreateProcessor() + processor := flux.CreateProcessor() time.AfterFunc(1*time.Second, func() { - proc.Next(payload.NewString("111", "")) + processor.Next(payload.NewString("111", "")) }) time.AfterFunc(2*time.Second, func() { - proc.Next(payload.NewString("222", "")) - proc.Complete() + processor.Next(payload.NewString("222", "")) + processor.Complete() }) done := make(chan struct{}) - proc. + processor. DoOnNext(func(input payload.Payload) { - log.Println("next:", input) + fmt.Println("next:", input) }). DoFinally(func(s rx.SignalType) { close(done) @@ -106,16 +217,16 @@ func TestFluxRequest(t *testing.T) { sub := rx.NewSubscriber( rx.OnNext(func(input payload.Payload) { - log.Println("onNext:", input) + fmt.Println("onNext:", input) su.Request(1) }), rx.OnComplete(func() { - log.Println("complete") + fmt.Println("complete") }), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) - log.Println("request:", 1) + fmt.Println("request:", 1) }), ) @@ -131,7 +242,7 @@ func TestProxy_BlockLast(t *testing.T) { s.Complete() }).BlockLast(context.Background()) assert.NoError(t, err, "err occurred") - log.Println(last) + fmt.Println(last) } func TestFluxProcessorWithRequest(t *testing.T) { diff --git a/rx/flux/utils.go b/rx/flux/utils.go index a0ea099..4ad7333 100644 --- a/rx/flux/utils.go +++ b/rx/flux/utils.go @@ -53,20 +53,24 @@ func Create(gen func(ctx context.Context, s Sink)) Flux { // CreateProcessor creates a new Processor. func CreateProcessor() Processor { - proc := flux.NewUnicastProcessor() - return newProxy(proc) + p := flux.NewUnicastProcessor() + return newProxy(p) } // Clone clones a Publisher to a Flux. func Clone(source rx.Publisher) Flux { return Create(func(ctx context.Context, s Sink) { - source.Subscribe(ctx, rx.OnNext(func(input payload.Payload) { - s.Next(input) - }), rx.OnComplete(func() { - s.Complete() - }), rx.OnError(func(e error) { - s.Error(e) - })) + source.Subscribe(ctx, + rx.OnNext(func(input payload.Payload) { + s.Next(input) + }), + rx.OnComplete(func() { + s.Complete() + }), + rx.OnError(func(e error) { + s.Error(e) + }), + ) }) } diff --git a/rx/mono/mono.go b/rx/mono/mono.go index 6c2860c..b4e74f4 100644 --- a/rx/mono/mono.go +++ b/rx/mono/mono.go @@ -9,28 +9,44 @@ import ( "github.com/rsocket/rsocket-go/rx" ) +// Mono is a Reactive Streams Publisher with basic rx operators that completes successfully by emitting an element, or with an error. type Mono interface { rx.Publisher + // Filter evaluate each source value against the given Predicate. + // If the predicate test succeeds, the value is emitted. Filter(rx.FnPredicate) Mono + // DoFinally adds behavior (side-effect) triggered after the Mono terminates for any reason, including cancellation. DoFinally(rx.FnFinally) Mono + // DoOnError adds behavior (side-effect) triggered when the Mono completes with an error. DoOnError(rx.FnOnError) Mono + // DoOnSuccess adds behavior (side-effect) triggered when the Mono completes with an success. DoOnSuccess(rx.FnOnNext) Mono + // DoOnCancel add behavior (side-effect) triggered when the Mono is cancelled. DoOnCancel(rx.FnOnCancel) Mono + // DoOnSubscribe add behavior (side-effect) triggered when the Mono is done being subscribed. DoOnSubscribe(rx.FnOnSubscribe) Mono + // SubscribeOn customize a Scheduler running Subscribe, OnSubscribe and Request. SubscribeOn(scheduler.Scheduler) Mono + // Block blocks Mono and returns data and error. Block(context.Context) (payload.Payload, error) + //SwitchIfEmpty switch to an alternative Publisher if this Mono is completed without any data. SwitchIfEmpty(alternative Mono) Mono + // Raw returns low-level Mono which defined in upstream reactor library. Raw() mono.Mono // ToChan subscribe Mono and puts items into a chan. // It also puts errors into another chan. ToChan(ctx context.Context) (c <-chan payload.Payload, e <-chan error) } +// Sink is a wrapper API around an actual downstream Subscriber for emitting nothing, a single value or an error (mutually exclusive). type Sink interface { + // Success emits a single value then complete current Sink. Success(payload.Payload) + // Error emits an error then complete current Sink. Error(error) } +// Processor combine Sink and Mono. type Processor interface { Sink Mono diff --git a/rx/mono/utils.go b/rx/mono/utils.go index ec571eb..2b87669 100644 --- a/rx/mono/utils.go +++ b/rx/mono/utils.go @@ -9,32 +9,40 @@ import ( var empty = newProxy(mono.Empty()) +// Raw wrap a low-level Mono. func Raw(input mono.Mono) Mono { return newProxy(input) } +// Just wrap an exist Payload to a Mono. func Just(input payload.Payload) Mono { return newProxy(mono.Just(input)) } +// JustOrEmpty wrap an exist Payload to Mono. +// Payload could be nil here. func JustOrEmpty(input payload.Payload) Mono { return newProxy(mono.JustOrEmpty(input)) } +// Empty returns an empty Mono. func Empty() Mono { return empty } +// Error wrap an error to a Mono. func Error(err error) Mono { return newProxy(mono.Error(err)) } +// Create wrap a generator function to a Mono. func Create(gen func(context.Context, Sink)) Mono { return newProxy(mono.Create(func(i context.Context, sink mono.Sink) { gen(i, sinkProxy{sink}) })) } +// CreateProcessor creates a Processor. func CreateProcessor() Processor { return newProxy(mono.CreateProcessor()) } diff --git a/rx/rx.go b/rx/rx.go index 6ed6c57..7173749 100644 --- a/rx/rx.go +++ b/rx/rx.go @@ -49,7 +49,7 @@ type RawPublisher interface { type Publisher interface { RawPublisher // Subscribe subscribe elements from a publisher, returns a Disposable. - // You can add some custome options. + // You can add some custom options. // Using `OnSubscribe`, `OnNext`, `OnComplete` and `OnError` as handler wrapper. Subscribe(ctx context.Context, options ...SubscriberOption) } diff --git a/server.go b/server.go index 3d06fb7..dc18036 100644 --- a/server.go +++ b/server.go @@ -355,16 +355,16 @@ func (p *server) destroySessions() { } func (p *server) doCleanSession() { - deads := make(chan *session.Session) - go func(deads chan *session.Session) { - for it := range deads { + deadSessions := make(chan *session.Session) + go func(deadSessions chan *session.Session) { + for it := range deadSessions { if err := it.Close(); err != nil { logger.Warnf("close dead session failed: %s\n", err) } else if logger.IsDebugEnabled() { logger.Debugf("close dead session success: %s\n", it) } } - }(deads) + }(deadSessions) var cur *session.Session for p.sm.Len() > 0 { cur = p.sm.Pop() @@ -373,9 +373,9 @@ func (p *server) doCleanSession() { p.sm.Push(cur) break } - deads <- cur + deadSessions <- cur } - close(deads) + close(deadSessions) } // WithServerResumeSessionDuration sets resume session duration for RSocket server. From 4e7b71a9db084d74d873d91db2dc299488e5e302 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 21 Jun 2020 22:26:14 +0800 Subject: [PATCH 07/26] change travis tests. --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 57a5dde..7ac7c9f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,5 @@ install: script: - golangci-lint run ./... - - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 . + - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 ./rx/... ./internal/... ./extension/... ./payload/... . - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN From 47e02405d6bd4def5add2981213eb85df7623045 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Thu, 25 Jun 2020 21:41:53 +0800 Subject: [PATCH 08/26] Optimize round-robin balancer and adjust project struct. --- balancer/balancer.go | 5 +- balancer/group_test.go | 14 +- balancer/round_robin.go | 75 ++++-- balancer/round_robin_test.go | 152 +++++++++--- client.go | 225 ++++++++---------- cmd/rsocket-cli/runner.go | 46 +++- .../transport => cmd/rsocket-cli}/uri.go | 25 +- .../transport => cmd/rsocket-cli}/uri_test.go | 2 +- {internal/transport => core}/counter.go | 6 +- core/counter_test.go | 29 +++ {internal/common => core}/errors.go | 20 +- core/errors_test.go | 28 +++ core/framing/frame.go | 130 ++++++++++ {internal => core}/framing/frame_cancel.go | 8 +- {internal => core}/framing/frame_error.go | 19 +- {internal => core}/framing/frame_fnf.go | 15 +- {internal => core}/framing/frame_keepalive.go | 15 +- {internal => core}/framing/frame_lease.go | 21 +- .../framing/frame_metadata_push.go | 5 +- {internal => core}/framing/frame_payload.go | 19 +- .../framing/frame_request_channel.go | 15 +- {internal => core}/framing/frame_request_n.go | 11 +- .../framing/frame_request_response.go | 15 +- .../framing/frame_request_stream.go | 15 +- {internal => core}/framing/frame_resume.go | 15 +- {internal => core}/framing/frame_resume_ok.go | 7 +- {internal => core}/framing/frame_setup.go | 39 +-- {internal => core}/framing/frame_test.go | 91 +++---- {internal => core}/framing/misc.go | 35 +-- {internal/framing => core}/header.go | 42 ++-- {internal/framing => core}/header_test.go | 6 +- .../transport/connection_tcp.go | 19 +- {internal => core}/transport/connection_ws.go | 17 +- {internal => core}/transport/decoder.go | 10 +- {internal => core}/transport/decoder_test.go | 7 +- {internal => core}/transport/misc.go | 4 + {internal => core}/transport/transport.go | 53 +++-- {internal => core}/transport/transport_tcp.go | 71 +++--- {internal => core}/transport/transport_ws.go | 37 +-- internal/framing/frame.go => core/types.go | 143 ++--------- {internal/common => core}/version.go | 20 +- {internal/common => core}/version_test.go | 40 +++- examples/echo/echo.go | 11 +- .../echo_bench.go} | 54 +++-- examples/fibonacci/main.go | 11 +- examples/lease/main.go | 32 --- examples/lease/main_test.go | 48 ---- examples/word_counter/main.go | 11 +- fuzz.go | 13 +- internal/common/errors_test.go | 28 --- internal/fragmentation/joiner.go | 12 +- internal/fragmentation/joiner_test.go | 9 +- internal/fragmentation/splitter.go | 16 +- internal/fragmentation/splitter_test.go | 5 +- internal/fragmentation/types.go | 12 +- internal/socket/{msg.go => callback.go} | 0 internal/socket/client_default.go | 22 +- internal/socket/client_resume.go | 31 +-- internal/socket/duplex.go | 127 +++++----- internal/socket/keepaliver.go | 33 ++- internal/socket/keepaliver_test.go | 33 +++ internal/socket/misc.go | 8 +- internal/socket/server_default.go | 2 +- internal/socket/server_resume.go | 2 +- internal/socket/smap_test.go | 60 ----- internal/socket/socket.go | 48 ---- internal/socket/stream_id.go | 8 +- internal/socket/types.go | 57 +++++ internal/transport/connection.go | 24 -- lease/lease_test.go | 81 +++++++ payload/payload.go | 4 +- payload/payload_raw.go | 36 --- payload/payload_str.go | 14 -- payload/payload_test.go | 6 +- rsocket.go | 26 +- rsocket_example_test.go | 11 +- rsocket_test.go | 17 +- rx/flux/proxy.go | 4 +- server.go | 91 +++---- transporter.go | 181 ++++++++++++++ transporter_test.go | 30 +++ 81 files changed, 1569 insertions(+), 1220 deletions(-) rename {internal/transport => cmd/rsocket-cli}/uri.go (68%) rename {internal/transport => cmd/rsocket-cli}/uri_test.go (96%) rename {internal/transport => core}/counter.go (85%) create mode 100644 core/counter_test.go rename {internal/common => core}/errors.go (99%) create mode 100644 core/errors_test.go create mode 100644 core/framing/frame.go rename {internal => core}/framing/frame_cancel.go (81%) rename {internal => core}/framing/frame_error.go (75%) rename {internal => core}/framing/frame_fnf.go (83%) rename {internal => core}/framing/frame_keepalive.go (85%) rename {internal => core}/framing/frame_lease.go (83%) rename {internal => core}/framing/frame_metadata_push.go (90%) rename {internal => core}/framing/frame_payload.go (81%) rename {internal => core}/framing/frame_request_channel.go (85%) rename {internal => core}/framing/frame_request_n.go (74%) rename {internal => core}/framing/frame_request_response.go (83%) rename {internal => core}/framing/frame_request_stream.go (85%) rename {internal => core}/framing/frame_resume.go (83%) rename {internal => core}/framing/frame_resume_ok.go (88%) rename {internal => core}/framing/frame_setup.go (89%) rename {internal => core}/framing/frame_test.go (67%) rename {internal => core}/framing/misc.go (69%) rename {internal/framing => core}/header.go (63%) rename {internal/framing => core}/header_test.go (87%) rename {internal => core}/transport/connection_tcp.go (81%) rename {internal => core}/transport/connection_ws.go (82%) rename {internal => core}/transport/decoder.go (86%) rename {internal => core}/transport/decoder_test.go (76%) rename {internal => core}/transport/misc.go (64%) rename {internal => core}/transport/transport.go (87%) rename {internal => core}/transport/transport_tcp.go (63%) rename {internal => core}/transport/transport_ws.go (80%) rename internal/framing/frame.go => core/types.go (53%) rename {internal/common => core}/version.go (75%) rename {internal/common => core}/version_test.go (51%) rename examples/{echo/echo_benchmark_test.go => echo_bench/echo_bench.go} (60%) delete mode 100644 examples/lease/main.go delete mode 100644 examples/lease/main_test.go delete mode 100644 internal/common/errors_test.go rename internal/socket/{msg.go => callback.go} (100%) create mode 100644 internal/socket/keepaliver_test.go delete mode 100644 internal/socket/smap_test.go create mode 100644 internal/socket/types.go delete mode 100644 internal/transport/connection.go create mode 100644 lease/lease_test.go create mode 100644 transporter.go create mode 100644 transporter_test.go diff --git a/balancer/balancer.go b/balancer/balancer.go index ba20b6c..fcc7692 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -2,6 +2,7 @@ package balancer import ( + "context" "io" "github.com/rsocket/rsocket-go" @@ -15,7 +16,9 @@ type Balancer interface { // PutLabel puts a new client with a label. PutLabel(label string, client rsocket.Client) // Next returns next balanced RSocket client. - Next() rsocket.Client + Next(context.Context) (rsocket.Client, bool) + // MustNext returns next balanced RSocket client. + MustNext(context.Context) rsocket.Client // OnLeave handle events when a client exit. OnLeave(fn func(label string)) } diff --git a/balancer/group_test.go b/balancer/group_test.go index 9099900..9ec4cca 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -16,7 +16,11 @@ import ( "github.com/stretchr/testify/require" ) -const uri = "tcp://127.0.0.1:7878" +var tp Transporter + +func init() { + tp = Tcp().HostAndPort("127.0.0.1", 7878).Build() +} func ExampleNewGroup() { group := NewGroup(func() Balancer { @@ -40,10 +44,10 @@ func ExampleNewGroup() { panic(errors.New("missing service ID in metadata")) } log.Println("[broker] redirect request to service", requestServiceID) - return group.Get(requestServiceID).Next().RequestResponse(msg) + return group.Get(requestServiceID).MustNext(context.Background()).RequestResponse(msg) })), nil }). - Transport(uri). + Transport(tp). Serve(context.Background()) if err != nil { panic(err) @@ -72,7 +76,7 @@ func TestServiceSubscribe(t *testing.T) { return mono.Just(result) })) }). - Transport(uri). + Transport(tp). Start(context.Background()) if err != nil { panic(err) @@ -86,7 +90,7 @@ func TestServiceSubscribe(t *testing.T) { // Create a client and request md5 service. cli, err := Connect(). SetupPayload(payload.NewString("This is a Subscriber", "")). - Transport(uri). + Transport(tp). Start(context.Background()) require.NoError(t, err, "create client failed") defer func() { diff --git a/balancer/round_robin.go b/balancer/round_robin.go index fd46193..f098755 100644 --- a/balancer/round_robin.go +++ b/balancer/round_robin.go @@ -1,11 +1,13 @@ package balancer import ( + "context" "sync" "github.com/google/uuid" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/logger" + "go.uber.org/atomic" ) type labelClient struct { @@ -14,12 +16,13 @@ type labelClient struct { } type balancerRoundRobin struct { - cond *sync.Cond - seq int + seq *atomic.Uint32 + mutex sync.RWMutex clients []*labelClient done chan struct{} once sync.Once onLeave []func(string) + cond *sync.Cond } func (p *balancerRoundRobin) OnLeave(fn func(label string)) { @@ -34,7 +37,8 @@ func (p *balancerRoundRobin) Put(client rsocket.Client) { } func (p *balancerRoundRobin) PutLabel(label string, client rsocket.Client) { - p.cond.L.Lock() + p.mutex.Lock() + defer p.mutex.Unlock() p.clients = append(p.clients, &labelClient{ l: label, c: client, @@ -45,29 +49,50 @@ func (p *balancerRoundRobin) PutLabel(label string, client rsocket.Client) { if len(p.clients) == 1 { p.cond.Broadcast() } - p.cond.L.Unlock() } -func (p *balancerRoundRobin) Next() (c rsocket.Client) { - p.cond.L.Lock() - for len(p.clients) < 1 { - select { - case <-p.done: - goto L - default: - p.cond.Wait() - } +func (p *balancerRoundRobin) MustNext(ctx context.Context) rsocket.Client { + c, ok := p.Next(ctx) + if !ok { + panic("cannot get next client from current balancer") } - c = p.choose() -L: - p.cond.L.Unlock() - return + return c } -func (p *balancerRoundRobin) choose() (cli rsocket.Client) { - p.seq = (p.seq + 1) % len(p.clients) - cli = p.clients[p.seq].c - return +func (p *balancerRoundRobin) Next(ctx context.Context) (rsocket.Client, bool) { + p.mutex.RLock() + defer p.mutex.RUnlock() + if n := len(p.clients); n > 0 { + idx := int(p.seq.Inc() % uint32(n)) + return p.clients[idx].c, true + } + + ch := make(chan rsocket.Client, 1) + closed := atomic.NewBool(false) + + go func() { + p.cond.L.Lock() + for len(p.clients) < 1 && !closed.Load() { + p.cond.Wait() + } + p.cond.L.Unlock() + if n := len(p.clients); n > 0 { + idx := int(p.seq.Inc() % uint32(n)) + ch <- p.clients[idx].c + } + }() + + select { + case <-ctx.Done(): + closed.Store(true) + p.cond.Broadcast() + return nil, false + case c, ok := <-ch: + if !ok { + return nil, false + } + return c, true + } } func (p *balancerRoundRobin) Close() (err error) { @@ -93,7 +118,7 @@ func (p *balancerRoundRobin) Close() (err error) { } func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok bool) { - p.cond.L.Lock() + p.mutex.Lock() j := -1 for i, l := 0, len(p.clients); i < l; i++ { if p.clients[i].c == client { @@ -106,7 +131,7 @@ func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok boo label = p.clients[j].l p.clients = append(p.clients[:j], p.clients[j+1:]...) } - p.cond.L.Unlock() + p.mutex.Unlock() if ok && len(p.onLeave) > 0 { go func(label string) { for _, fn := range p.onLeave { @@ -120,8 +145,8 @@ func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok boo // NewRoundRobinBalancer returns a new Round-Robin Balancer. func NewRoundRobinBalancer() Balancer { return &balancerRoundRobin{ - cond: sync.NewCond(&sync.Mutex{}), - seq: -1, + cond: sync.NewCond(new(sync.Mutex)), + seq: atomic.NewUint32(0), done: make(chan struct{}), } } diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index 32c9832..d21c8b8 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -9,52 +9,136 @@ import ( "time" "github.com/jjeffcaii/reactor-go/scheduler" - . "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go" . "github.com/rsocket/rsocket-go/balancer" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/mono" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +func startServer(ctx context.Context, port int, counter *sync.Map) { + _ = rsocket.Receive(). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { + return rsocket.NewAbstractSocket( + rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + cur, _ := counter.LoadOrStore(port, atomic.NewInt32(0)) + cur.(*atomic.Int32).Inc() + return mono.Just(msg) + }), + ), nil + }). + Transport(rsocket.Tcp().HostAndPort("127.0.0.1", port).Build()). + Serve(ctx) +} + func TestRoundRobin(t *testing.T) { + ports := [3]int{7000, 7001, 7002} + counter := &sync.Map{} + ctx0, cancel0 := context.WithCancel(context.Background()) + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx0, ports[0]) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx1, ports[1]) + go func(ctx context.Context, port int) { + startServer(ctx, port, counter) + }(ctx2, ports[2]) + + time.Sleep(1 * time.Second) + b := NewRoundRobinBalancer() b.OnLeave(func(label string) { log.Println("client leave:", label) }) - defer func() { - _ = b.Close() - }() - - const x, y = 3, 1000 - wg := &sync.WaitGroup{} - wg.Add(x * y) - for i := 0; i < x; i++ { - go func(n int) { - for j := 0; j < y; j++ { - b.Next().RequestResponse(payload.NewString(fmt.Sprintf("GO_%04d_%04d", n, j), "go")). - DoOnSuccess(func(elem payload.Payload) { - m, _ := elem.MetadataUTF8() - log.Println("elem:", elem.DataUTF8(), m) - }). - DoFinally(func(st rx.SignalType) { - wg.Done() - }). - SubscribeOn(scheduler.Elastic()). - Subscribe(context.Background()) - time.Sleep(1 * time.Second) - } - }(i) + defer b.Close() + + for i := 0; i < len(ports); i++ { + client, err := rsocket.Connect(). + Transport(rsocket.Tcp().HostAndPort("127.0.0.1", ports[i]).Build()). + Start(context.Background()) + assert.NoError(t, err) + b.PutLabel(fmt.Sprintf("test-client-%d", ports[i]), client) } - for _, port := range []int{17878, 8000, 8001, 8002} { - go func(uri string) { - c, err := Connect(). - SetupPayload(payload.NewString(uri, "hello")). - Transport(uri). - Start(context.Background()) - if err == nil { - b.PutLabel(uri, c) - } - }(fmt.Sprintf("tcp://127.0.0.1:%d", port)) + + req := payload.NewString("foo", "bar") + + const n = 3 + wg := sync.WaitGroup{} + wg.Add(n * len(ports)) + for i := 0; i < n*len(ports); i++ { + b.MustNext(context.Background()).RequestResponse(req). + DoFinally(func(s rx.SignalType) { + wg.Done() + }). + DoOnError(func(e error) { + assert.Fail(t, "should never run here") + }). + SubscribeOn(scheduler.Elastic()). + Subscribe(context.Background()) } wg.Wait() + + counter.Range(func(key, value interface{}) bool { + v := int(value.(*atomic.Int32).Load()) + assert.Equal(t, n, v) + return true + }) + + ac0, ok := counter.Load(ports[0]) + assert.True(t, ok) + + amount0 := ac0.(*atomic.Int32).Load() + + // shutdown server 1 + cancel0() + time.Sleep(100 * time.Millisecond) + + // then send a request + _, err := b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + + var total int + counter.Range(func(key, value interface{}) bool { + total += int(value.(*atomic.Int32).Load()) + return true + }) + assert.Equal(t, n*len(ports)+1, total) + assert.Equal(t, int32(0), ac0.(*atomic.Int32).Load()-amount0) + + _, err = b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + total++ + + // shutdown server 2 + cancel1() + time.Sleep(100 * time.Millisecond) + + const extra = 10 + + for i := 0; i < extra; i++ { + _, err = b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + assert.NoError(t, err) + } + total += 10 + + requestedActual := 0 + counter.Range(func(key, value interface{}) bool { + requestedActual += int(value.(*atomic.Int32).Load()) + return true + }) + assert.Equal(t, total, requestedActual) + + cancel2() + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, ok = b.Next(ctx) + assert.False(t, ok) } diff --git a/client.go b/client.go index b9bb5ec..e5d71ae 100644 --- a/client.go +++ b/client.go @@ -2,14 +2,14 @@ package rsocket import ( "context" - "crypto/tls" "time" "github.com/google/uuid" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/internal/fragmentation" "github.com/rsocket/rsocket-go/internal/socket" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/payload" ) @@ -21,113 +21,79 @@ var ( type ( // ClientResumeOptions represents resume options for client. ClientResumeOptions func(opts *resumeOpts) - - // Client is Client Side of a RSocket socket. Sends Frames to a RSocket Server. - Client interface { - CloseableRSocket - } - - // ClientSocketAcceptor is alias for RSocket handler function. - ClientSocketAcceptor = func(socket RSocket) RSocket - - // ClientStarter can be used to start a client. - ClientStarter interface { - // Start start a client socket. - Start(ctx context.Context) (Client, error) - // Start start a client socket with TLS. - // Here's an example: - // tc := &tls.Config { - // InsecureSkipVerify: true, - // } - StartTLS(ctx context.Context, tc *tls.Config) (Client, error) - } - - // ClientBuilder can be used to build a RSocket client. - ClientBuilder interface { - ClientTransportBuilder - // Fragment set fragmentation size which default is 16_777_215(16MB). - Fragment(mtu int) ClientBuilder - // KeepAlive defines current client keepalive settings. - KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder - // Resume enable the functionality of resume. - Resume(opts ...ClientResumeOptions) ClientBuilder - // Lease enable the functionality of lease. - Lease() ClientBuilder - // DataMimeType is used to set payload data MIME type. - // Default MIME type is `application/binary`. - DataMimeType(mime string) ClientBuilder - // MetadataMimeType is used to set payload metadata MIME type. - // Default MIME type is `application/binary`. - MetadataMimeType(mime string) ClientBuilder - // SetupPayload set the setup payload. - SetupPayload(setup payload.Payload) ClientBuilder - // OnClose register handler when client socket closed. - OnClose(fn func(error)) ClientBuilder - // Acceptor set acceptor for RSocket client. - Acceptor(acceptor ClientSocketAcceptor) ClientTransportBuilder - } - - // ClientTransportBuilder is used to build a RSocket client with custom Transport string. - ClientTransportBuilder interface { - // Transport set Transport for current RSocket client. - // URI is used to create RSocket Transport: - // Example: - // "tcp://127.0.0.1:7878" means a TCP RSocket transport. - // "ws://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport. - // "wss://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport with HTTPS. - Transport(uri string, opts ...TransportOpts) ClientStarter - } - - setupClientSocket interface { - Client - Setup(ctx context.Context, setup *socket.SetupInfo) error - } ) -// Connect create a new RSocket client builder with default settings. -func Connect() ClientBuilder { - return &implClientBuilder{ - fragment: fragmentation.MaxFragment, - setup: &socket.SetupInfo{ - Version: common.DefaultVersion, - KeepaliveInterval: common.DefaultKeepaliveInterval, - KeepaliveLifetime: common.DefaultKeepaliveMaxLifetime, - DataMimeType: _defaultMimeType, - MetadataMimeType: _defaultMimeType, - }, - } -} - -type transportOpts struct { - addr string - headers map[string][]string -} - -// WithWebsocketHeaders attach headers for websocket transport. -func WithWebsocketHeaders(headers map[string][]string) TransportOpts { - return func(opts *transportOpts) { - opts.headers = headers - } -} - -// TransportOpts represents options of transport. -type TransportOpts = func(*transportOpts) - -type implClientBuilder struct { +// Client is Client Side of a RSocket socket. Sends Frames to a RSocket Server. +type Client interface { + CloseableRSocket +} + +// ClientSocketAcceptor is alias for RSocket handler function. +type ClientSocketAcceptor = func(socket RSocket) RSocket + +// ClientStarter can be used to start a client. +type ClientStarter interface { + // Start start a client socket. + Start(ctx context.Context) (Client, error) +} + +// ClientBuilder can be used to build a RSocket client. +type ClientBuilder interface { + ToClientStarter + // Fragment set fragmentation size which default is 16_777_215(16MB). + // Also zero mtu means using default fragmentation size. + Fragment(mtu int) ClientBuilder + // KeepAlive defines current client keepalive settings. + KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder + // Resume enable the functionality of resume. + Resume(opts ...ClientResumeOptions) ClientBuilder + // Lease enable the functionality of lease. + Lease() ClientBuilder + // DataMimeType is used to set payload data MIME type. + // Default MIME type is `application/binary`. + DataMimeType(mime string) ClientBuilder + // MetadataMimeType is used to set payload metadata MIME type. + // Default MIME type is `application/binary`. + MetadataMimeType(mime string) ClientBuilder + // SetupPayload set the setup payload. + SetupPayload(setup payload.Payload) ClientBuilder + // OnClose register handler when client socket closed. + OnClose(fn func(error)) ClientBuilder + // Acceptor set acceptor for RSocket client. + Acceptor(acceptor ClientSocketAcceptor) ToClientStarter +} + +type ToClientStarter interface { + // Transport set Transport for current RSocket client. + // URI is used to create RSocket Transport: + // Example: + // "tcp://127.0.0.1:7878" means a TCP RSocket transport. + // "ws://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport. + // "wss://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport with HTTPS. + Transport(Transporter) ClientStarter +} + +// ToClientStarter is used to build a RSocket client with custom Transport string. +type setupClientSocket interface { + Client + Setup(ctx context.Context, setup *socket.SetupInfo) error +} + +type clientBuilder struct { resume *resumeOpts fragment int - tpOpts *transportOpts + tpGen transport.ToClientTransport setup *socket.SetupInfo acceptor ClientSocketAcceptor onCloses []func(error) } -func (p *implClientBuilder) Lease() ClientBuilder { +func (p *clientBuilder) Lease() ClientBuilder { p.setup.Lease = true return p } -func (p *implClientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { +func (p *clientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { if p.resume == nil { p.resume = newResumeOpts() } @@ -137,33 +103,37 @@ func (p *implClientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { return p } -func (p *implClientBuilder) Fragment(mtu int) ClientBuilder { - p.fragment = mtu +func (p *clientBuilder) Fragment(mtu int) ClientBuilder { + if mtu == 0 { + p.fragment = fragmentation.MaxFragment + } else { + p.fragment = mtu + } return p } -func (p *implClientBuilder) OnClose(fn func(error)) ClientBuilder { +func (p *clientBuilder) OnClose(fn func(error)) ClientBuilder { p.onCloses = append(p.onCloses, fn) return p } -func (p *implClientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { +func (p *clientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { p.setup.KeepaliveInterval = tickPeriod p.setup.KeepaliveLifetime = time.Duration(missedAcks) * ackTimeout return p } -func (p *implClientBuilder) DataMimeType(mime string) ClientBuilder { +func (p *clientBuilder) DataMimeType(mime string) ClientBuilder { p.setup.DataMimeType = []byte(mime) return p } -func (p *implClientBuilder) MetadataMimeType(mime string) ClientBuilder { +func (p *clientBuilder) MetadataMimeType(mime string) ClientBuilder { p.setup.MetadataMimeType = []byte(mime) return p } -func (p *implClientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { +func (p *clientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { p.setup.Data = nil p.setup.Metadata = nil @@ -178,36 +148,17 @@ func (p *implClientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { return p } -func (p *implClientBuilder) Acceptor(acceptor ClientSocketAcceptor) ClientTransportBuilder { +func (p *clientBuilder) Acceptor(acceptor ClientSocketAcceptor) ToClientStarter { p.acceptor = acceptor return p } -func (p *implClientBuilder) Transport(transport string, opts ...TransportOpts) ClientStarter { - p.tpOpts = &transportOpts{ - addr: transport, - } - for i := 0; i < len(opts); i++ { - opts[i](p.tpOpts) - } +func (p *clientBuilder) Transport(support Transporter) ClientStarter { + p.tpGen = support.Client() return p } -func (p *implClientBuilder) StartTLS(ctx context.Context, tc *tls.Config) (Client, error) { - return p.start(ctx, tc) -} - -func (p *implClientBuilder) Start(ctx context.Context) (client Client, err error) { - return p.start(ctx, nil) -} - -func (p *implClientBuilder) start(ctx context.Context, tc *tls.Config) (client Client, err error) { - var uri *transport.URI - uri, err = transport.ParseURI(p.tpOpts.addr) - if err != nil { - return - } - +func (p *clientBuilder) Start(ctx context.Context) (client Client, err error) { // create a blank socket. err = fragmentation.IsValidFragment(p.fragment) if err != nil { @@ -218,17 +169,13 @@ func (p *implClientBuilder) start(ctx context.Context, tc *tls.Config) (client C p.fragment, p.setup.KeepaliveInterval, ) - var headers map[string][]string - if uri.IsWebsocket() { - headers = p.tpOpts.headers - } // create a client. var cs setupClientSocket if p.resume != nil { p.setup.Token = p.resume.tokenGen() - cs = socket.NewClientResume(uri, sk, tc, headers) + cs = socket.NewClientResume(p.tpGen, sk) } else { - cs = socket.NewClient(uri, sk, tc, headers) + cs = socket.NewClient(p.tpGen, sk) } if p.acceptor != nil { sk.SetResponder(p.acceptor(cs)) @@ -272,3 +219,17 @@ func WithClientResumeToken(gen func() []byte) ClientResumeOptions { opts.tokenGen = gen } } + +// Connect create a new RSocket client builder with default settings. +func Connect() ClientBuilder { + return &clientBuilder{ + fragment: fragmentation.MaxFragment, + setup: &socket.SetupInfo{ + Version: core.DefaultVersion, + KeepaliveInterval: common.DefaultKeepaliveInterval, + KeepaliveLifetime: common.DefaultKeepaliveMaxLifetime, + DataMimeType: _defaultMimeType, + MetadataMimeType: _defaultMimeType, + }, + } +} diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index eb73e16..2576842 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "io/ioutil" + "net/url" "os" + "strconv" "strings" "time" @@ -112,11 +114,23 @@ func (p *Runner) runClientMode(ctx context.Context) (err error) { } setupPayload := payload.New(setupData, nil) sendingPayloads := p.createPayload() + + tp, err := makeTransport(p.URI) + if err != nil { + return + } + + // TODO: + + //if ws, ok := tp.(*rsocket.wsTransporter); ok { + // ws.Header(p.wsHeaders) + //} + c, err := cb. DataMimeType(p.DataFormat). MetadataMimeType(p.MetadataFormat). SetupPayload(setupPayload). - Transport(p.URI, rsocket.WithWebsocketHeaders(p.wsHeaders)). + Transport(tp). Start(ctx) if err != nil { return @@ -165,6 +179,12 @@ func (p *Runner) runServerMode(ctx context.Context) error { sb = rsocket.Receive() } ch := make(chan error) + + tp, err := makeTransport(p.URI) + if err != nil { + return err + } + go func() { sendingPayloads := p.createPayload() ch <- sb. @@ -200,7 +220,7 @@ func (p *Runner) runServerMode(ctx context.Context) error { })) return rsocket.NewAbstractSocket(options...), nil }). - Transport(p.URI). + Transport(tp). Serve(ctx) close(ch) }() @@ -318,3 +338,25 @@ func (p *Runner) readData(input string) (data []byte, err error) { } return } + +func makeTransport(s string) (rsocket.Transporter, error) { + u, err := url.Parse(s) + if err != nil { + return nil, err + } + switch u.Scheme { + case "tcp": + port, err := strconv.Atoi(u.Port()) + if err != nil { + return nil, err + } + return rsocket.Tcp().HostAndPort(u.Hostname(), port).Build(), nil + case "unix": + return rsocket.Unix().Path(u.Hostname()).Build(), nil + case "ws", "wss": + return rsocket.Websocket().Url(s).Build(), nil + default: + return nil, fmt.Errorf("invalid transport %s", u.Scheme) + } + +} diff --git a/internal/transport/uri.go b/cmd/rsocket-cli/uri.go similarity index 68% rename from internal/transport/uri.go rename to cmd/rsocket-cli/uri.go index d7dd366..b955023 100644 --- a/internal/transport/uri.go +++ b/cmd/rsocket-cli/uri.go @@ -1,4 +1,4 @@ -package transport +package main import ( "crypto/tls" @@ -6,6 +6,7 @@ import ( "strings" "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" ) const ( @@ -33,44 +34,44 @@ func (p *URI) IsWebsocket() bool { } // MakeClientTransport creates a new client-side transport. -func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) (*Transport, error) { +func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) (*transport.Transport, error) { switch strings.ToLower(p.Scheme) { case schemaTCP: - return newTCPClientTransport(schemaTCP, p.Host, tc) + return transport.NewTcpClientTransport(schemaTCP, p.Host, tc) case schemaWebsocket: if tc == nil { - return newWebsocketClientTransport(p.pp().String(), nil, headers) + return transport.NewWebsocketClientTransport(p.pp().String(), nil, headers) } var clone = (url.URL)(*p) clone.Scheme = "wss" - return newWebsocketClientTransport(clone.String(), tc, headers) + return transport.NewWebsocketClientTransport(clone.String(), tc, headers) case schemaWebsocketSecure: if tc == nil { tc = tlsInsecure } - return newWebsocketClientTransport(p.pp().String(), tc, headers) + return transport.NewWebsocketClientTransport(p.pp().String(), tc, headers) case schemaUNIX: - return newTCPClientTransport(schemaUNIX, p.Path, tc) + return transport.NewTcpClientTransport(schemaUNIX, p.Path, tc) default: return nil, errors.Errorf("unsupported transport url: %s", p.pp().String()) } } // MakeServerTransport creates a new server-side transport. -func (p *URI) MakeServerTransport(c *tls.Config) (tp ServerTransport, err error) { +func (p *URI) MakeServerTransport(c *tls.Config) (tp transport.ServerTransport, err error) { switch strings.ToLower(p.Scheme) { case schemaTCP: - tp = newTCPServerTransport(schemaTCP, p.Host, c) + tp = transport.NewTcpServerTransport(schemaTCP, p.Host, c) case schemaWebsocket: - tp = newWebsocketServerTransport(p.Host, p.Path, c) + tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) case schemaWebsocketSecure: if c == nil { err = errors.Errorf("missing TLS Config for proto %s", schemaWebsocketSecure) return } - tp = newWebsocketServerTransport(p.Host, p.Path, c) + tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) case schemaUNIX: - tp = newTCPServerTransport(schemaUNIX, p.Path, c) + tp = transport.NewTcpServerTransport(schemaUNIX, p.Path, c) default: err = errors.Errorf("unsupported transport url: %s", p.pp().String()) } diff --git a/internal/transport/uri_test.go b/cmd/rsocket-cli/uri_test.go similarity index 96% rename from internal/transport/uri_test.go rename to cmd/rsocket-cli/uri_test.go index 3f0c345..3fa9312 100644 --- a/internal/transport/uri_test.go +++ b/cmd/rsocket-cli/uri_test.go @@ -1,4 +1,4 @@ -package transport +package main import ( "log" diff --git a/internal/transport/counter.go b/core/counter.go similarity index 85% rename from internal/transport/counter.go rename to core/counter.go index a8fd46b..4d54e38 100644 --- a/internal/transport/counter.go +++ b/core/counter.go @@ -1,4 +1,4 @@ -package transport +package core import ( "go.uber.org/atomic" @@ -19,11 +19,11 @@ func (p Counter) WriteBytes() uint64 { return p.w.Load() } -func (p Counter) incrWriteBytes(n int) { +func (p Counter) IncWriteBytes(n int) { p.w.Add(uint64(n)) } -func (p Counter) incrReadBytes(n int) { +func (p Counter) IncReadBytes(n int) { p.r.Add(uint64(n)) } diff --git a/core/counter_test.go b/core/counter_test.go new file mode 100644 index 0000000..bdefc3e --- /dev/null +++ b/core/counter_test.go @@ -0,0 +1,29 @@ +package core_test + +import ( + "sync" + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + const cycle = 1000 + const amount = 1000 + wg := sync.WaitGroup{} + wg.Add(amount) + c := core.NewCounter() + for range [amount]struct{}{} { + go func() { + for range [cycle]struct{}{} { + c.IncWriteBytes(1) + c.IncReadBytes(1) + } + wg.Done() + }() + } + wg.Wait() + assert.Equal(t, uint64(cycle*amount), c.WriteBytes()) + assert.Equal(t, uint64(cycle*amount), c.ReadBytes()) +} diff --git a/internal/common/errors.go b/core/errors.go similarity index 99% rename from internal/common/errors.go rename to core/errors.go index 976e9dc..023435e 100644 --- a/internal/common/errors.go +++ b/core/errors.go @@ -1,19 +1,10 @@ -package common +package core import "errors" // ErrorCode is code for RSocket error. type ErrorCode uint32 -// CustomError provides a method of accessing code and data. -type CustomError interface { - error - // ErrorCode returns error code. - ErrorCode() ErrorCode - // ErrorData returns error data bytes. - ErrorData() []byte -} - func (e ErrorCode) String() string { switch e { case ErrorCodeInvalidSetup: @@ -64,6 +55,15 @@ const ( ErrorCodeInvalid ErrorCode = 0x00000204 ) +// CustomError provides a method of accessing code and data. +type CustomError interface { + error + // ErrorCode returns error code. + ErrorCode() ErrorCode + // ErrorData returns error data bytes. + ErrorData() []byte +} + // Error defines. var ( ErrFrameLengthExceed = errors.New("rsocket: frame length is greater than 24bits") diff --git a/core/errors_test.go b/core/errors_test.go new file mode 100644 index 0000000..dc8b11e --- /dev/null +++ b/core/errors_test.go @@ -0,0 +1,28 @@ +package core_test + +import ( + "math" + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestErrorCode_String(t *testing.T) { + all := []core.ErrorCode{ + core.ErrorCodeInvalidSetup, + core.ErrorCodeUnsupportedSetup, + core.ErrorCodeRejectedSetup, + core.ErrorCodeRejectedResume, + core.ErrorCodeConnectionError, + core.ErrorCodeConnectionClose, + core.ErrorCodeApplicationError, + core.ErrorCodeRejected, + core.ErrorCodeCanceled, + core.ErrorCodeInvalid, + } + for _, code := range all { + assert.NotEqual(t, "UNKNOWN", code.String()) + } + assert.Equal(t, "UNKNOWN", core.ErrorCode(math.MaxUint32).String()) +} diff --git a/core/framing/frame.go b/core/framing/frame.go new file mode 100644 index 0000000..4e5af5f --- /dev/null +++ b/core/framing/frame.go @@ -0,0 +1,130 @@ +package framing + +import ( + "errors" + "fmt" + "io" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/internal/common" +) + +var errIncompleteFrame = errors.New("incomplete frame") + +type tinyFrame struct { + header core.FrameHeader + done chan struct{} +} + +func (t *tinyFrame) Header() core.FrameHeader { + return t.header +} + +// Done can be invoked when a frame has been been processed. +func (t *tinyFrame) Done() (closed bool) { + defer func() { + if e := recover(); e != nil { + closed = true + } + }() + close(t.done) + return +} + +// DoneNotify notify when frame has been done. +func (t *tinyFrame) DoneNotify() <-chan struct{} { + return t.done +} + +// RawFrame is basic frame implementation. +type RawFrame struct { + *tinyFrame + body *common.ByteBuff +} + +// Body returns frame body. +func (f *RawFrame) Body() *common.ByteBuff { + return f.body +} + +// Len returns length of frame. +func (f *RawFrame) Len() int { + if f.body == nil { + return core.FrameHeaderLen + } + return core.FrameHeaderLen + f.body.Len() +} + +// WriteTo write frame to writer. +func (f *RawFrame) WriteTo(w io.Writer) (n int64, err error) { + var wrote int64 + wrote, err = f.header.WriteTo(w) + if err != nil { + return + } + n += wrote + if f.body != nil { + wrote, err = f.body.WriteTo(w) + if err != nil { + return + } + n += wrote + } + return +} + +func (f *RawFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { + raw := f.body.Bytes() + if offset > 0 { + raw = raw[offset:] + } + hasMetadata = f.header.Flag().Check(core.FlagMetadata) + if !hasMetadata { + return + } + if len(raw) < 3 { + n = -1 + } else { + n = common.NewUint24Bytes(raw).AsInt() + } + return +} + +func (f *RawFrame) trySliceMetadata(offset int) ([]byte, bool) { + n, ok := f.trySeekMetadataLen(offset) + if !ok || n < 0 { + return nil, false + } + return f.body.Bytes()[offset+3 : offset+3+n], true +} + +func (f *RawFrame) trySliceData(offset int) []byte { + n, ok := f.trySeekMetadataLen(offset) + if !ok { + return f.body.Bytes()[offset:] + } + if n < 0 { + return nil + } + return f.body.Bytes()[offset+n+3:] +} + +func newTinyFrame(header core.FrameHeader) *tinyFrame { + return &tinyFrame{ + header: header, + done: make(chan struct{}), + } +} + +// NewRawFrame returns a new RawFrame. +func NewRawFrame(header core.FrameHeader, body *common.ByteBuff) *RawFrame { + return &RawFrame{ + tinyFrame: newTinyFrame(header), + body: body, + } +} + +func PrintFrame(f core.FrameSupport) string { + // TODO: print frame + return fmt.Sprintf("%+v", f) +} diff --git a/internal/framing/frame_cancel.go b/core/framing/frame_cancel.go similarity index 81% rename from internal/framing/frame_cancel.go rename to core/framing/frame_cancel.go index d431d6d..9dcbf5e 100644 --- a/internal/framing/frame_cancel.go +++ b/core/framing/frame_cancel.go @@ -2,6 +2,8 @@ package framing import ( "io" + + "github.com/rsocket/rsocket-go/core" ) // CancelFrame is frame of cancel. @@ -24,7 +26,7 @@ func (c CancelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (c CancelFrameSupport) Len() int { - return HeaderLen + return core.FrameHeaderLen } // Validate returns error if frame is invalid. @@ -37,7 +39,7 @@ func (f *CancelFrame) Validate() (err error) { } func NewCancelFrameSupport(id uint32) *CancelFrameSupport { - h := NewFrameHeader(id, FrameTypeCancel, 0) + h := core.NewFrameHeader(id, core.FrameTypeCancel, 0) return &CancelFrameSupport{ tinyFrame: newTinyFrame(h), } @@ -46,6 +48,6 @@ func NewCancelFrameSupport(id uint32) *CancelFrameSupport { // NewCancelFrame creates cancel frame. func NewCancelFrame(sid uint32) *CancelFrame { return &CancelFrame{ - NewRawFrame(NewFrameHeader(sid, FrameTypeCancel, 0), nil), + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeCancel, 0), nil), } } diff --git a/internal/framing/frame_error.go b/core/framing/frame_error.go similarity index 75% rename from internal/framing/frame_error.go rename to core/framing/frame_error.go index b21bbb6..5e45b42 100644 --- a/internal/framing/frame_error.go +++ b/core/framing/frame_error.go @@ -5,6 +5,7 @@ import ( "io" "strings" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -21,7 +22,7 @@ type ErrorFrame struct { type ErrorFrameSupport struct { *tinyFrame - code common.ErrorCode + code core.ErrorCode data []byte } @@ -46,7 +47,7 @@ func (e ErrorFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (e ErrorFrameSupport) Len() int { - return HeaderLen + 4 + len(e.data) + return core.FrameHeaderLen + 4 + len(e.data) } // Validate returns error if frame is invalid. @@ -62,9 +63,9 @@ func (p *ErrorFrame) Error() string { } // ErrorCode returns error code. -func (p *ErrorFrame) ErrorCode() common.ErrorCode { +func (p *ErrorFrame) ErrorCode() core.ErrorCode { v := binary.BigEndian.Uint32(p.body.Bytes()) - return common.ErrorCode(v) + return core.ErrorCode(v) } // ErrorData returns error data bytes. @@ -72,8 +73,8 @@ func (p *ErrorFrame) ErrorData() []byte { return p.body.Bytes()[errDataOff:] } -func NewErrorFrameSupport(id uint32, code common.ErrorCode, data []byte) *ErrorFrameSupport { - h := NewFrameHeader(id, FrameTypeError, 0) +func NewErrorFrameSupport(id uint32, code core.ErrorCode, data []byte) *ErrorFrameSupport { + h := core.NewFrameHeader(id, core.FrameTypeError, 0) t := newTinyFrame(h) return &ErrorFrameSupport{ tinyFrame: t, @@ -83,7 +84,7 @@ func NewErrorFrameSupport(id uint32, code common.ErrorCode, data []byte) *ErrorF } // NewErrorFrame returns a new error frame. -func NewErrorFrame(streamID uint32, code common.ErrorCode, data []byte) *ErrorFrame { +func NewErrorFrame(streamID uint32, code core.ErrorCode, data []byte) *ErrorFrame { bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], uint32(code)) @@ -94,11 +95,11 @@ func NewErrorFrame(streamID uint32, code common.ErrorCode, data []byte) *ErrorFr panic(err) } return &ErrorFrame{ - NewRawFrame(NewFrameHeader(streamID, FrameTypeError, 0), bf), + NewRawFrame(core.NewFrameHeader(streamID, core.FrameTypeError, 0), bf), } } -func makeErrorString(code common.ErrorCode, data []byte) string { +func makeErrorString(code core.ErrorCode, data []byte) string { bu := strings.Builder{} bu.WriteString(code.String()) bu.WriteByte(':') diff --git a/internal/framing/frame_fnf.go b/core/framing/frame_fnf.go similarity index 83% rename from internal/framing/frame_fnf.go rename to core/framing/frame_fnf.go index 582362e..e1b748f 100644 --- a/internal/framing/frame_fnf.go +++ b/core/framing/frame_fnf.go @@ -3,6 +3,7 @@ package framing import ( "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -39,7 +40,7 @@ func (f FireAndForgetFrameSupport) Len() int { // Validate returns error if frame is invalid. func (f *FireAndForgetFrame) Validate() (err error) { - if f.header.Flag().Check(FlagMetadata) && f.body.Len() < 3 { + if f.header.Flag().Check(core.FlagMetadata) && f.body.Len() < 3 { err = errIncompleteFrame } return @@ -69,11 +70,11 @@ func (f *FireAndForgetFrame) DataUTF8() string { return string(f.Data()) } -func NewFireAndForgetFrameSupport(sid uint32, data, metadata []byte, flag FrameFlag) *FireAndForgetFrameSupport { +func NewFireAndForgetFrameSupport(sid uint32, data, metadata []byte, flag core.FrameFlag) *FireAndForgetFrameSupport { if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } - h := NewFrameHeader(sid, FrameTypeRequestFNF, flag) + h := core.NewFrameHeader(sid, core.FrameTypeRequestFNF, flag) t := newTinyFrame(h) return &FireAndForgetFrameSupport{ tinyFrame: t, @@ -83,10 +84,10 @@ func NewFireAndForgetFrameSupport(sid uint32, data, metadata []byte, flag FrameF } // NewFireAndForgetFrame returns a new fire and forget frame. -func NewFireAndForgetFrame(sid uint32, data, metadata []byte, flag FrameFlag) *FireAndForgetFrame { +func NewFireAndForgetFrame(sid uint32, data, metadata []byte, flag core.FrameFlag) *FireAndForgetFrame { bf := common.NewByteBuff() if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -98,6 +99,6 @@ func NewFireAndForgetFrame(sid uint32, data, metadata []byte, flag FrameFlag) *F panic(err) } return &FireAndForgetFrame{ - NewRawFrame(NewFrameHeader(sid, FrameTypeRequestFNF, flag), bf), + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestFNF, flag), bf), } } diff --git a/internal/framing/frame_keepalive.go b/core/framing/frame_keepalive.go similarity index 85% rename from internal/framing/frame_keepalive.go rename to core/framing/frame_keepalive.go index def5db9..ec03da5 100644 --- a/internal/framing/frame_keepalive.go +++ b/core/framing/frame_keepalive.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -48,7 +49,7 @@ func (k KeepaliveFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (k KeepaliveFrameSupport) Len() int { - return HeaderLen + 8 + len(k.data) + return core.FrameHeaderLen + 8 + len(k.data) } // Validate returns error if frame is invalid. @@ -70,15 +71,15 @@ func (k *KeepaliveFrame) Data() []byte { } func NewKeepaliveFrameSupport(position uint64, data []byte, respond bool) *KeepaliveFrameSupport { - var flag FrameFlag + var flag core.FrameFlag if respond { - flag |= FlagRespond + flag |= core.FlagRespond } var b [8]byte binary.BigEndian.PutUint64(b[:], position) - h := NewFrameHeader(0, FrameTypeKeepalive, flag) + h := core.NewFrameHeader(0, core.FrameTypeKeepalive, flag) t := newTinyFrame(h) return &KeepaliveFrameSupport{ @@ -90,9 +91,9 @@ func NewKeepaliveFrameSupport(position uint64, data []byte, respond bool) *Keepa // NewKeepaliveFrame returns a new keepalive frame. func NewKeepaliveFrame(position uint64, data []byte, respond bool) *KeepaliveFrame { - var fg FrameFlag + var fg core.FrameFlag if respond { - fg |= FlagRespond + fg |= core.FlagRespond } bf := common.NewByteBuff() var b8 [8]byte @@ -106,6 +107,6 @@ func NewKeepaliveFrame(position uint64, data []byte, respond bool) *KeepaliveFra } } return &KeepaliveFrame{ - NewRawFrame(NewFrameHeader(0, FrameTypeKeepalive, fg), bf), + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeKeepalive, fg), bf), } } diff --git a/internal/framing/frame_lease.go b/core/framing/frame_lease.go similarity index 83% rename from internal/framing/frame_lease.go rename to core/framing/frame_lease.go index b4b7621..b258491 100644 --- a/internal/framing/frame_lease.go +++ b/core/framing/frame_lease.go @@ -5,6 +5,7 @@ import ( "io" "time" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -41,7 +42,7 @@ func (l *LeaseFrame) NumberOfRequests() uint32 { // Metadata returns metadata bytes. func (l *LeaseFrame) Metadata() []byte { - if !l.header.Flag().Check(FlagMetadata) { + if !l.header.Flag().Check(core.FlagMetadata) { return nil } return l.body.Bytes()[8:] @@ -75,7 +76,7 @@ func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } n += int64(v) - if l.header.Flag().Check(FlagMetadata) { + if l.header.Flag().Check(core.FlagMetadata) { v, err = w.Write(l.metadata) if err != nil { return @@ -87,8 +88,8 @@ func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (l LeaseFrameSupport) Len() int { - n := HeaderLen + 8 - if l.header.Flag().Check(FlagMetadata) { + n := core.FrameHeaderLen + 8 + if l.header.Flag().Check(core.FlagMetadata) { n += len(l.metadata) } return n @@ -99,11 +100,11 @@ func NewLeaseFrameSupport(ttl time.Duration, n uint32, metadata []byte) *LeaseFr binary.BigEndian.PutUint32(a[:], uint32(ttl.Milliseconds())) binary.BigEndian.PutUint32(b[:], n) - var flag FrameFlag + var flag core.FrameFlag if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } - h := NewFrameHeader(0, FrameTypeLease, flag) + h := core.NewFrameHeader(0, core.FrameTypeLease, flag) t := newTinyFrame(h) return &LeaseFrameSupport{ tinyFrame: t, @@ -121,12 +122,12 @@ func NewLeaseFrame(ttl time.Duration, n uint32, metadata []byte) *LeaseFrame { if err := binary.Write(bf, binary.BigEndian, n); err != nil { panic(err) } - var fg FrameFlag + var fg core.FrameFlag if len(metadata) > 0 { - fg |= FlagMetadata + fg |= core.FlagMetadata if _, err := bf.Write(metadata); err != nil { panic(err) } } - return &LeaseFrame{NewRawFrame(NewFrameHeader(0, FrameTypeLease, fg), bf)} + return &LeaseFrame{NewRawFrame(core.NewFrameHeader(0, core.FrameTypeLease, fg), bf)} } diff --git a/internal/framing/frame_metadata_push.go b/core/framing/frame_metadata_push.go similarity index 90% rename from internal/framing/frame_metadata_push.go rename to core/framing/frame_metadata_push.go index e2c9e3f..713e8ed 100644 --- a/internal/framing/frame_metadata_push.go +++ b/core/framing/frame_metadata_push.go @@ -3,10 +3,11 @@ package framing import ( "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) -var _metadataPushHeader = NewFrameHeader(0, FrameTypeMetadataPush, FlagMetadata) +var _metadataPushHeader = core.NewFrameHeader(0, core.FrameTypeMetadataPush, core.FlagMetadata) // MetadataPushFrame is metadata push frame. type MetadataPushFrame struct { @@ -59,7 +60,7 @@ func (m MetadataPushFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (m MetadataPushFrameSupport) Len() int { - return HeaderLen + len(m.metadata) + return core.FrameHeaderLen + len(m.metadata) } // DataUTF8 returns data as UTF8 string. diff --git a/internal/framing/frame_payload.go b/core/framing/frame_payload.go similarity index 81% rename from internal/framing/frame_payload.go rename to core/framing/frame_payload.go index d97c377..1b4e92e 100644 --- a/internal/framing/frame_payload.go +++ b/core/framing/frame_payload.go @@ -3,6 +3,7 @@ package framing import ( "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -14,7 +15,7 @@ type PayloadFrame struct { // Validate returns error if frame is invalid. func (p *PayloadFrame) Validate() (err error) { // Minimal length should be 3 if metadata exists. - if p.header.Flag().Check(FlagMetadata) && p.body.Len() < 3 { + if p.header.Flag().Check(core.FlagMetadata) && p.body.Len() < 3 { err = errIncompleteFrame } return @@ -63,7 +64,7 @@ func (p PayloadFrameSupport) DataUTF8() string { } func (p PayloadFrameSupport) MetadataUTF8() (metadata string, ok bool) { - if p.header.Flag().Check(FlagMetadata) { + if p.header.Flag().Check(core.FlagMetadata) { metadata = string(p.metadata) ok = true } @@ -75,7 +76,7 @@ func (p PayloadFrameSupport) Data() []byte { } func (p PayloadFrameSupport) Metadata() ([]byte, bool) { - return p.metadata, p.header.Flag().Check(FlagMetadata) + return p.metadata, p.header.Flag().Check(core.FlagMetadata) } func (p PayloadFrameSupport) WriteTo(w io.Writer) (n int64, err error) { @@ -97,11 +98,11 @@ func (p PayloadFrameSupport) Len() int { } // NewPayloadFrameSupport returns a new payload frame. -func NewPayloadFrameSupport(id uint32, data, metadata []byte, flag FrameFlag) *PayloadFrameSupport { +func NewPayloadFrameSupport(id uint32, data, metadata []byte, flag core.FrameFlag) *PayloadFrameSupport { if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } - h := NewFrameHeader(id, FrameTypePayload, flag) + h := core.NewFrameHeader(id, core.FrameTypePayload, flag) t := newTinyFrame(h) return &PayloadFrameSupport{ tinyFrame: t, @@ -111,10 +112,10 @@ func NewPayloadFrameSupport(id uint32, data, metadata []byte, flag FrameFlag) *P } // NewPayloadFrame returns a new payload frame. -func NewPayloadFrame(id uint32, data, metadata []byte, flag FrameFlag) *PayloadFrame { +func NewPayloadFrame(id uint32, data, metadata []byte, flag core.FrameFlag) *PayloadFrame { bf := common.NewByteBuff() if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -128,6 +129,6 @@ func NewPayloadFrame(id uint32, data, metadata []byte, flag FrameFlag) *PayloadF } } return &PayloadFrame{ - NewRawFrame(NewFrameHeader(id, FrameTypePayload, flag), bf), + NewRawFrame(core.NewFrameHeader(id, core.FrameTypePayload, flag), bf), } } diff --git a/internal/framing/frame_request_channel.go b/core/framing/frame_request_channel.go similarity index 85% rename from internal/framing/frame_request_channel.go rename to core/framing/frame_request_channel.go index 93bd8ff..466844e 100644 --- a/internal/framing/frame_request_channel.go +++ b/core/framing/frame_request_channel.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -30,7 +31,7 @@ func (r *RequestChannelFrame) Validate() error { if l < minRequestChannelFrameLen { return errIncompleteFrame } - if r.header.Flag().Check(FlagMetadata) && l < minRequestChannelFrameLen+3 { + if r.header.Flag().Check(core.FlagMetadata) && l < minRequestChannelFrameLen+3 { return errIncompleteFrame } return nil @@ -93,13 +94,13 @@ func (r RequestChannelFrameSupport) Len() int { return CalcPayloadFrameSize(r.data, r.metadata) + 4 } -func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestChannelFrameSupport { +func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestChannelFrameSupport { var b [4]byte binary.BigEndian.PutUint32(b[:], n) if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } - h := NewFrameHeader(sid, FrameTypeRequestChannel, flag) + h := core.NewFrameHeader(sid, core.FrameTypeRequestChannel, flag) t := newTinyFrame(h) return &RequestChannelFrameSupport{ tinyFrame: t, @@ -110,7 +111,7 @@ func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, } // NewRequestChannelFrame returns a new RequestChannel frame. -func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestChannelFrame { +func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestChannelFrame { bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) @@ -118,7 +119,7 @@ func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag Fr panic(err) } if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -132,6 +133,6 @@ func NewRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag Fr } } return &RequestChannelFrame{ - NewRawFrame(NewFrameHeader(sid, FrameTypeRequestChannel, flag), bf), + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestChannel, flag), bf), } } diff --git a/internal/framing/frame_request_n.go b/core/framing/frame_request_n.go similarity index 74% rename from internal/framing/frame_request_n.go rename to core/framing/frame_request_n.go index 540b884..394ad8f 100644 --- a/internal/framing/frame_request_n.go +++ b/core/framing/frame_request_n.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -45,20 +46,20 @@ func (r RequestNFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (r RequestNFrameSupport) Len() int { - return HeaderLen + 4 + return core.FrameHeaderLen + 4 } -func NewRequestNFrameSupport(id uint32, n uint32, fg FrameFlag) *RequestNFrameSupport { +func NewRequestNFrameSupport(id uint32, n uint32, fg core.FrameFlag) *RequestNFrameSupport { var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) return &RequestNFrameSupport{ - tinyFrame: newTinyFrame(NewFrameHeader(id, FrameTypeRequestN, fg)), + tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestN, fg)), n: b4, } } // NewRequestNFrame returns a new RequestN frame. -func NewRequestNFrame(sid, n uint32, fg FrameFlag) *RequestNFrame { +func NewRequestNFrame(sid, n uint32, fg core.FrameFlag) *RequestNFrame { bf := common.NewByteBuff() var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) @@ -66,6 +67,6 @@ func NewRequestNFrame(sid, n uint32, fg FrameFlag) *RequestNFrame { panic(err) } return &RequestNFrame{ - NewRawFrame(NewFrameHeader(sid, FrameTypeRequestN, fg), bf), + NewRawFrame(core.NewFrameHeader(sid, core.FrameTypeRequestN, fg), bf), } } diff --git a/internal/framing/frame_request_response.go b/core/framing/frame_request_response.go similarity index 83% rename from internal/framing/frame_request_response.go rename to core/framing/frame_request_response.go index 326d00c..1eb716a 100644 --- a/internal/framing/frame_request_response.go +++ b/core/framing/frame_request_response.go @@ -3,6 +3,7 @@ package framing import ( "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -19,7 +20,7 @@ type RequestResponseFrameSupport struct { // Validate returns error if frame is invalid. func (r *RequestResponseFrame) Validate() (err error) { - if r.header.Flag().Check(FlagMetadata) && r.body.Len() < 3 { + if r.header.Flag().Check(core.FlagMetadata) && r.body.Len() < 3 { err = errIncompleteFrame } return @@ -68,22 +69,22 @@ func (r RequestResponseFrameSupport) Len() int { } // NewRequestResponseFrameSupport returns a new RequestResponse frame support. -func NewRequestResponseFrameSupport(id uint32, data, metadata []byte, fg FrameFlag) FrameSupport { +func NewRequestResponseFrameSupport(id uint32, data, metadata []byte, fg core.FrameFlag) core.FrameSupport { if len(metadata) > 0 { - fg |= FlagMetadata + fg |= core.FlagMetadata } return &RequestResponseFrameSupport{ - tinyFrame: newTinyFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg)), + tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestResponse, fg)), metadata: metadata, data: data, } } // NewRequestResponseFrame returns a new RequestResponse frame. -func NewRequestResponseFrame(id uint32, data, metadata []byte, fg FrameFlag) *RequestResponseFrame { +func NewRequestResponseFrame(id uint32, data, metadata []byte, fg core.FrameFlag) *RequestResponseFrame { bf := common.NewByteBuff() if len(metadata) > 0 { - fg |= FlagMetadata + fg |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -97,6 +98,6 @@ func NewRequestResponseFrame(id uint32, data, metadata []byte, fg FrameFlag) *Re } } return &RequestResponseFrame{ - NewRawFrame(NewFrameHeader(id, FrameTypeRequestResponse, fg), bf), + NewRawFrame(core.NewFrameHeader(id, core.FrameTypeRequestResponse, fg), bf), } } diff --git a/internal/framing/frame_request_stream.go b/core/framing/frame_request_stream.go similarity index 85% rename from internal/framing/frame_request_stream.go rename to core/framing/frame_request_stream.go index 0a2f469..50d3db0 100644 --- a/internal/framing/frame_request_stream.go +++ b/core/framing/frame_request_stream.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -29,7 +30,7 @@ func (r *RequestStreamFrame) Validate() error { if l < minRequestStreamFrameLen { return errIncompleteFrame } - if r.header.Flag().Check(FlagMetadata) && l < minRequestStreamFrameLen+3 { + if r.header.Flag().Check(core.FlagMetadata) && l < minRequestStreamFrameLen+3 { return errIncompleteFrame } return nil @@ -91,13 +92,13 @@ func (r RequestStreamFrameSupport) Len() int { return 4 + CalcPayloadFrameSize(r.data, r.metadata) } -func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, flag FrameFlag) FrameSupport { +func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) core.FrameSupport { if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } var b [4]byte binary.BigEndian.PutUint32(b[:], n) - h := NewFrameHeader(id, FrameTypeRequestStream, flag) + h := core.NewFrameHeader(id, core.FrameTypeRequestStream, flag) t := newTinyFrame(h) return &RequestStreamFrameSupport{ tinyFrame: t, @@ -108,13 +109,13 @@ func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, fl } // NewRequestStreamFrame returns a new request stream frame. -func NewRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag FrameFlag) *RequestStreamFrame { +func NewRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestStreamFrame { bf := common.NewByteBuff() if err := binary.Write(bf, binary.BigEndian, n); err != nil { panic(err) } if len(metadata) > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -128,6 +129,6 @@ func NewRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag Fram } } return &RequestStreamFrame{ - NewRawFrame(NewFrameHeader(id, FrameTypeRequestStream, flag), bf), + NewRawFrame(core.NewFrameHeader(id, core.FrameTypeRequestStream, flag), bf), } } diff --git a/internal/framing/frame_resume.go b/core/framing/frame_resume.go similarity index 83% rename from internal/framing/frame_resume.go rename to core/framing/frame_resume.go index 5e1d6ac..c505651 100644 --- a/internal/framing/frame_resume.go +++ b/core/framing/frame_resume.go @@ -6,6 +6,7 @@ import ( "io" "math" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -33,7 +34,7 @@ func (r *ResumeFrame) Validate() (err error) { } // Version returns version. -func (r *ResumeFrame) Version() common.Version { +func (r *ResumeFrame) Version() core.Version { raw := r.body.Bytes() major := binary.BigEndian.Uint16(raw) minor := binary.BigEndian.Uint16(raw[2:]) @@ -63,7 +64,7 @@ func (r *ResumeFrame) FirstAvailableClientPosition() uint64 { type ResumeFrameSupport struct { *tinyFrame - version common.Version + version core.Version token []byte posFirst [8]byte posLast [8]byte @@ -114,12 +115,12 @@ func (r ResumeFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (r ResumeFrameSupport) Len() int { - return HeaderLen + _lenTokenLength + _lenFirstPos + _lenLastRecvPos + _lenVersion + len(r.token) + return core.FrameHeaderLen + _lenTokenLength + _lenFirstPos + _lenLastRecvPos + _lenVersion + len(r.token) } // NewResumeFrameSupport creates a new frame support of Resume. -func NewResumeFrameSupport(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrameSupport { - h := NewFrameHeader(0, FrameTypeResume, 0) +func NewResumeFrameSupport(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrameSupport { + h := core.NewFrameHeader(0, core.FrameTypeResume, 0) t := newTinyFrame(h) var a, b [8]byte binary.BigEndian.PutUint64(a[:], firstAvailableClientPosition) @@ -135,7 +136,7 @@ func NewResumeFrameSupport(version common.Version, token []byte, firstAvailableC } // NewResumeFrame creates a new frame of Resume. -func NewResumeFrame(version common.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrame { +func NewResumeFrame(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrame { n := len(token) if n > math.MaxUint16 { panic(errResumeTokenTooLarge) @@ -159,6 +160,6 @@ func NewResumeFrame(version common.Version, token []byte, firstAvailableClientPo panic(err) } return &ResumeFrame{ - NewRawFrame(NewFrameHeader(0, FrameTypeResume, 0), bf), + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeResume, 0), bf), } } diff --git a/internal/framing/frame_resume_ok.go b/core/framing/frame_resume_ok.go similarity index 88% rename from internal/framing/frame_resume_ok.go rename to core/framing/frame_resume_ok.go index 2026d49..f4d5401 100644 --- a/internal/framing/frame_resume_ok.go +++ b/core/framing/frame_resume_ok.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -49,11 +50,11 @@ func (r ResumeOKFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } func (r ResumeOKFrameSupport) Len() int { - return HeaderLen + 8 + return core.FrameHeaderLen + 8 } func NewResumeOKFrameSupport(position uint64) *ResumeOKFrameSupport { - h := NewFrameHeader(0, FrameTypeResumeOK, 0) + h := core.NewFrameHeader(0, core.FrameTypeResumeOK, 0) t := newTinyFrame(h) var b [8]byte binary.BigEndian.PutUint64(b[:], position) @@ -73,6 +74,6 @@ func NewResumeOKFrame(position uint64) *ResumeOKFrame { panic(err) } return &ResumeOKFrame{ - NewRawFrame(NewFrameHeader(0, FrameTypeResumeOK, 0), bf), + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeResumeOK, 0), bf), } } diff --git a/internal/framing/frame_setup.go b/core/framing/frame_setup.go similarity index 89% rename from internal/framing/frame_setup.go rename to core/framing/frame_setup.go index e75a303..cf9b86d 100644 --- a/internal/framing/frame_setup.go +++ b/core/framing/frame_setup.go @@ -5,6 +5,7 @@ import ( "io" "time" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) @@ -30,7 +31,7 @@ func (p *SetupFrame) Validate() (err error) { } // Version returns version. -func (p *SetupFrame) Version() common.Version { +func (p *SetupFrame) Version() core.Version { major := binary.BigEndian.Uint16(p.body.Bytes()) minor := binary.BigEndian.Uint16(p.body.Bytes()[2:]) return [2]uint16{major, minor} @@ -48,7 +49,7 @@ func (p *SetupFrame) MaxLifetime() time.Duration { // Token returns token of setup. func (p *SetupFrame) Token() []byte { - if !p.header.Flag().Check(FlagResume) { + if !p.header.Flag().Check(core.FlagResume) { return nil } raw := p.body.Bytes() @@ -70,7 +71,7 @@ func (p *SetupFrame) MetadataMimeType() string { // Metadata returns metadata bytes. func (p *SetupFrame) Metadata() ([]byte, bool) { - if !p.header.Flag().Check(FlagMetadata) { + if !p.header.Flag().Check(core.FlagMetadata) { return nil, false } offset := p.seekMIME() @@ -84,7 +85,7 @@ func (p *SetupFrame) Data() []byte { offset := p.seekMIME() m1, m2 := p.mime() offset += 2 + len(m1) + len(m2) - if !p.header.Flag().Check(FlagMetadata) { + if !p.header.Flag().Check(core.FlagMetadata) { return p.Body().Bytes()[offset:] } return p.trySliceData(offset) @@ -118,7 +119,7 @@ func (p *SetupFrame) mime() (metadata []byte, data []byte) { } func (p *SetupFrame) seekMIME() int { - if !p.header.Flag().Check(FlagResume) { + if !p.header.Flag().Check(core.FlagResume) { return 12 } l := binary.BigEndian.Uint16(p.body.Bytes()[12:]) @@ -127,7 +128,7 @@ func (p *SetupFrame) seekMIME() int { type SetupFrameSupport struct { *tinyFrame - version common.Version + version core.Version keepalive [4]byte lifetime [4]byte token []byte @@ -164,7 +165,7 @@ func (s SetupFrameSupport) WriteTo(w io.Writer) (n int64, err error) { } n += int64(v) - if s.header.Flag().Check(FlagResume) { + if s.header.Flag().Check(core.FlagResume) { tokenLen := len(s.token) err = binary.Write(w, binary.BigEndian, uint16(tokenLen)) if err != nil { @@ -220,7 +221,7 @@ func (s SetupFrameSupport) Len() int { } func NewSetupFrameSupport( - version common.Version, + version core.Version, timeBetweenKeepalive, maxLifetime time.Duration, token []byte, @@ -230,17 +231,17 @@ func NewSetupFrameSupport( metadata []byte, lease bool, ) *SetupFrameSupport { - var flag FrameFlag + var flag core.FrameFlag if l := len(token); l > 0 { - flag |= FlagResume + flag |= core.FlagResume } if lease { - flag |= FlagLease + flag |= core.FlagLease } if l := len(metadata); l > 0 { - flag |= FlagMetadata + flag |= core.FlagMetadata } - h := NewFrameHeader(0, FrameTypeSetup, flag) + h := core.NewFrameHeader(0, core.FrameTypeSetup, flag) t := newTinyFrame(h) var a, b [4]byte @@ -261,7 +262,7 @@ func NewSetupFrameSupport( // NewSetupFrame returns a new setup frame. func NewSetupFrame( - version common.Version, + version core.Version, timeBetweenKeepalive, maxLifetime time.Duration, token []byte, @@ -271,7 +272,7 @@ func NewSetupFrame( metadata []byte, lease bool, ) *SetupFrame { - var fg FrameFlag + var fg core.FrameFlag bf := common.NewByteBuff() if _, err := bf.Write(version.Bytes()); err != nil { panic(err) @@ -286,10 +287,10 @@ func NewSetupFrame( panic(err) } if lease { - fg |= FlagLease + fg |= core.FlagLease } if len(token) > 0 { - fg |= FlagResume + fg |= core.FlagResume binary.BigEndian.PutUint16(b4[:2], uint16(len(token))) if _, err := bf.Write(b4[:2]); err != nil { panic(err) @@ -311,7 +312,7 @@ func NewSetupFrame( panic(err) } if len(metadata) > 0 { - fg |= FlagMetadata + fg |= core.FlagMetadata if err := bf.WriteUint24(len(metadata)); err != nil { panic(err) } @@ -325,6 +326,6 @@ func NewSetupFrame( } } return &SetupFrame{ - NewRawFrame(NewFrameHeader(0, FrameTypeSetup, fg), bf), + NewRawFrame(core.NewFrameHeader(0, core.FrameTypeSetup, fg), bf), } } diff --git a/internal/framing/frame_test.go b/core/framing/frame_test.go similarity index 67% rename from internal/framing/frame_test.go rename to core/framing/frame_test.go index e9ab7f2..ba42b60 100644 --- a/internal/framing/frame_test.go +++ b/core/framing/frame_test.go @@ -6,8 +6,9 @@ import ( "testing" "time" + "github.com/rsocket/rsocket-go/core" + . "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - . "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) @@ -15,46 +16,46 @@ const _sid uint32 = 1 func TestFrameCancel(t *testing.T) { f := NewCancelFrame(_sid) - checkBasic(t, f, FrameTypeCancel) + checkBasic(t, f, core.FrameTypeCancel) f2 := NewCancelFrameSupport(_sid) checkBytes(t, f, f2) } func TestFrameError(t *testing.T) { errData := []byte(common.RandAlphanumeric(100)) - f := NewErrorFrame(_sid, common.ErrorCodeApplicationError, errData) - checkBasic(t, f, FrameTypeError) - assert.Equal(t, common.ErrorCodeApplicationError, f.ErrorCode()) + f := NewErrorFrame(_sid, core.ErrorCodeApplicationError, errData) + checkBasic(t, f, core.FrameTypeError) + assert.Equal(t, core.ErrorCodeApplicationError, f.ErrorCode()) assert.Equal(t, errData, f.ErrorData()) assert.NotEmpty(t, f.Error()) - f2 := NewErrorFrame(_sid, common.ErrorCodeApplicationError, errData) + f2 := NewErrorFrame(_sid, core.ErrorCodeApplicationError, errData) checkBytes(t, f, f2) } func TestFrameFNF(t *testing.T) { b := []byte(common.RandAlphanumeric(100)) // Without Metadata - f := NewFireAndForgetFrame(_sid, b, nil, FlagNext) - checkBasic(t, f, FrameTypeRequestFNF) + f := NewFireAndForgetFrame(_sid, b, nil, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestFNF) assert.Equal(t, b, f.Data()) metadata, ok := f.Metadata() assert.False(t, ok) assert.Nil(t, metadata) - assert.True(t, f.Header().Flag().Check(FlagNext)) - assert.False(t, f.Header().Flag().Check(FlagMetadata)) - f2 := NewFireAndForgetFrameSupport(_sid, b, nil, FlagNext) + assert.True(t, f.Header().Flag().Check(core.FlagNext)) + assert.False(t, f.Header().Flag().Check(core.FlagMetadata)) + f2 := NewFireAndForgetFrameSupport(_sid, b, nil, core.FlagNext) checkBytes(t, f, f2) // With Metadata - f = NewFireAndForgetFrame(_sid, nil, b, FlagNext) - checkBasic(t, f, FrameTypeRequestFNF) + f = NewFireAndForgetFrame(_sid, nil, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestFNF) assert.Empty(t, f.Data()) metadata, ok = f.Metadata() assert.True(t, ok) assert.Equal(t, b, metadata) - assert.True(t, f.Header().Flag().Check(FlagNext)) - assert.True(t, f.Header().Flag().Check(FlagMetadata)) - f2 = NewFireAndForgetFrameSupport(_sid, nil, b, FlagNext) + assert.True(t, f.Header().Flag().Check(core.FlagNext)) + assert.True(t, f.Header().Flag().Check(core.FlagMetadata)) + f2 = NewFireAndForgetFrameSupport(_sid, nil, b, core.FlagNext) checkBytes(t, f, f2) } @@ -62,10 +63,10 @@ func TestFrameKeepalive(t *testing.T) { pos := uint64(common.RandIntn(math.MaxInt32)) d := []byte(common.RandAlphanumeric(100)) f := NewKeepaliveFrame(pos, d, true) - checkBasic(t, f, FrameTypeKeepalive) + checkBasic(t, f, core.FrameTypeKeepalive) assert.Equal(t, d, f.Data()) assert.Equal(t, pos, f.LastReceivedPosition()) - assert.True(t, f.Header().Flag().Check(FlagRespond)) + assert.True(t, f.Header().Flag().Check(core.FlagRespond)) f2 := NewKeepaliveFrameSupport(pos, d, true) checkBytes(t, f, f2) } @@ -74,7 +75,7 @@ func TestFrameLease(t *testing.T) { metadata := []byte("foobar") n := uint32(4444) f := NewLeaseFrame(time.Second, n, metadata) - checkBasic(t, f, FrameTypeLease) + checkBasic(t, f, core.FrameTypeLease) assert.Equal(t, time.Second, f.TimeToLive()) assert.Equal(t, n, f.NumberOfRequests()) assert.Equal(t, metadata, f.Metadata()) @@ -85,7 +86,7 @@ func TestFrameLease(t *testing.T) { func TestFrameMetadataPush(t *testing.T) { metadata := []byte("foobar") f := NewMetadataPushFrame(metadata) - checkBasic(t, f, FrameTypeMetadataPush) + checkBasic(t, f, core.FrameTypeMetadataPush) metadata2, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, metadata, metadata2) @@ -95,35 +96,35 @@ func TestFrameMetadataPush(t *testing.T) { func TestPayloadFrame(t *testing.T) { b := []byte("foobar") - f := NewPayloadFrame(_sid, b, b, FlagNext) - checkBasic(t, f, FrameTypePayload) + f := NewPayloadFrame(_sid, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypePayload) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, f.Data()) assert.Equal(t, b, m) - assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) - f2 := NewPayloadFrameSupport(_sid, b, b, FlagNext) + assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) + f2 := NewPayloadFrameSupport(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) } func TestFrameRequestChannel(t *testing.T) { b := []byte("foobar") n := uint32(1) - f := NewRequestChannelFrame(_sid, n, b, b, FlagNext) - checkBasic(t, f, FrameTypeRequestChannel) + f := NewRequestChannelFrame(_sid, n, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestChannel) assert.Equal(t, n, f.InitialRequestN()) assert.Equal(t, b, f.Data()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) - f2 := NewRequestChannelFrameSupport(_sid, n, b, b, FlagNext) + f2 := NewRequestChannelFrameSupport(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } func TestFrameRequestN(t *testing.T) { n := uint32(1234) f := NewRequestNFrame(_sid, n, 0) - checkBasic(t, f, FrameTypeRequestN) + checkBasic(t, f, core.FrameTypeRequestN) assert.Equal(t, n, f.N()) f2 := NewRequestNFrameSupport(_sid, n, 0) checkBytes(t, f, f2) @@ -131,38 +132,38 @@ func TestFrameRequestN(t *testing.T) { func TestFrameRequestResponse(t *testing.T) { b := []byte("foobar") - f := NewRequestResponseFrame(_sid, b, b, FlagNext) - checkBasic(t, f, FrameTypeRequestResponse) + f := NewRequestResponseFrame(_sid, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestResponse) assert.Equal(t, b, f.Data()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) - assert.Equal(t, FlagNext|FlagMetadata, f.Header().Flag()) - f2 := NewRequestResponseFrameSupport(_sid, b, b, FlagNext) + assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) + f2 := NewRequestResponseFrameSupport(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) } func TestFrameRequestStream(t *testing.T) { b := []byte("foobar") n := uint32(1234) - f := NewRequestStreamFrame(_sid, n, b, b, FlagNext) - checkBasic(t, f, FrameTypeRequestStream) + f := NewRequestStreamFrame(_sid, n, b, b, core.FlagNext) + checkBasic(t, f, core.FrameTypeRequestStream) assert.Equal(t, b, f.Data()) assert.Equal(t, n, f.InitialRequestN()) m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) - f2 := NewRequestStreamFrameSupport(_sid, n, b, b, FlagNext) + f2 := NewRequestStreamFrameSupport(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } func TestFrameResume(t *testing.T) { - v := common.NewVersion(3, 1) + v := core.NewVersion(3, 1) token := []byte("hello") p1 := uint64(333) p2 := uint64(444) f := NewResumeFrame(v, token, p1, p2) - checkBasic(t, f, FrameTypeResume) + checkBasic(t, f, core.FrameTypeResume) assert.Equal(t, token, f.Token()) assert.Equal(t, p1, f.FirstAvailableClientPosition()) assert.Equal(t, p2, f.LastReceivedServerPosition()) @@ -175,14 +176,14 @@ func TestFrameResume(t *testing.T) { func TestFrameResumeOK(t *testing.T) { pos := uint64(1234) f := NewResumeOKFrame(pos) - checkBasic(t, f, FrameTypeResumeOK) + checkBasic(t, f, core.FrameTypeResumeOK) assert.Equal(t, pos, f.LastReceivedClientPosition()) f2 := NewResumeOKFrameSupport(pos) checkBytes(t, f, f2) } func TestFrameSetup(t *testing.T) { - v := common.NewVersion(3, 1) + v := core.NewVersion(3, 1) timeKeepalive := 20 * time.Second maxLifetime := time.Minute + 30*time.Second var token []byte @@ -191,7 +192,7 @@ func TestFrameSetup(t *testing.T) { d := []byte("你好") m := []byte("世界") f := NewSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) - checkBasic(t, f, FrameTypeSetup) + checkBasic(t, f, core.FrameTypeSetup) assert.Equal(t, v.Major(), f.Version().Major()) assert.Equal(t, v.Minor(), f.Version().Minor()) assert.Equal(t, timeKeepalive, f.TimeBetweenKeepalive()) @@ -209,10 +210,10 @@ func TestFrameSetup(t *testing.T) { checkBytes(t, f, fs) } -func checkBasic(t *testing.T, f Frame, typ FrameType) { +func checkBasic(t *testing.T, f core.Frame, typ core.FrameType) { sid := _sid switch typ { - case FrameTypeKeepalive, FrameTypeSetup, FrameTypeLease, FrameTypeResume, FrameTypeResumeOK, FrameTypeMetadataPush: + case core.FrameTypeKeepalive, core.FrameTypeSetup, core.FrameTypeLease, core.FrameTypeResume, core.FrameTypeResumeOK, core.FrameTypeMetadataPush: sid = 0 } assert.Equal(t, sid, f.Header().StreamID(), "wrong frame stream id") @@ -225,7 +226,7 @@ func checkBasic(t *testing.T, f Frame, typ FrameType) { <-f.DoneNotify() } -func checkBytes(t *testing.T, a Frame, b FrameSupport) { +func checkBytes(t *testing.T, a core.Frame, b core.FrameSupport) { assert.Equal(t, a.Len(), b.Len()) bf1, bf2 := &bytes.Buffer{}, &bytes.Buffer{} _, err := a.WriteTo(bf1) @@ -235,8 +236,8 @@ func checkBytes(t *testing.T, a Frame, b FrameSupport) { b1, b2 := bf1.Bytes(), bf2.Bytes() assert.Equal(t, b1, b2, "bytes doesn't match") bf := common.NewByteBuff() - _, _ = bf.Write(b1[HeaderLen:]) - raw := NewRawFrame(ParseFrameHeader(b1[:HeaderLen]), bf) + _, _ = bf.Write(b1[core.FrameHeaderLen:]) + raw := NewRawFrame(core.ParseFrameHeader(b1[:core.FrameHeaderLen]), bf) _, err = FromRawFrame(raw) assert.NoError(t, err, "create from raw failed") } diff --git a/internal/framing/misc.go b/core/framing/misc.go similarity index 69% rename from internal/framing/misc.go rename to core/framing/misc.go index 1cddce1..81b930f 100644 --- a/internal/framing/misc.go +++ b/core/framing/misc.go @@ -3,12 +3,13 @@ package framing import ( "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" ) // CalcPayloadFrameSize returns payload frame size. func CalcPayloadFrameSize(data, metadata []byte) int { - size := HeaderLen + len(data) + size := core.FrameHeaderLen + len(data) if n := len(metadata); n > 0 { size += 3 + n } @@ -16,38 +17,38 @@ func CalcPayloadFrameSize(data, metadata []byte) int { } // FromRawFrame creates a frame from a RawFrame. -func FromRawFrame(f *RawFrame) (frame Frame, err error) { +func FromRawFrame(f *RawFrame) (frame core.Frame, err error) { switch f.header.Type() { - case FrameTypeSetup: + case core.FrameTypeSetup: frame = &SetupFrame{RawFrame: f} - case FrameTypeKeepalive: + case core.FrameTypeKeepalive: frame = &KeepaliveFrame{RawFrame: f} - case FrameTypeRequestResponse: + case core.FrameTypeRequestResponse: frame = &RequestResponseFrame{RawFrame: f} - case FrameTypeRequestFNF: + case core.FrameTypeRequestFNF: frame = &FireAndForgetFrame{RawFrame: f} - case FrameTypeRequestStream: + case core.FrameTypeRequestStream: frame = &RequestStreamFrame{RawFrame: f} - case FrameTypeRequestChannel: + case core.FrameTypeRequestChannel: frame = &RequestChannelFrame{RawFrame: f} - case FrameTypeCancel: + case core.FrameTypeCancel: frame = &CancelFrame{RawFrame: f} - case FrameTypePayload: + case core.FrameTypePayload: frame = &PayloadFrame{RawFrame: f} - case FrameTypeMetadataPush: + case core.FrameTypeMetadataPush: frame = &MetadataPushFrame{RawFrame: f} - case FrameTypeError: + case core.FrameTypeError: frame = &ErrorFrame{RawFrame: f} - case FrameTypeRequestN: + case core.FrameTypeRequestN: frame = &RequestNFrame{RawFrame: f} - case FrameTypeLease: + case core.FrameTypeLease: frame = &LeaseFrame{RawFrame: f} - case FrameTypeResume: + case core.FrameTypeResume: frame = &ResumeFrame{RawFrame: f} - case FrameTypeResumeOK: + case core.FrameTypeResumeOK: frame = &ResumeOKFrame{RawFrame: f} default: - err = common.ErrInvalidFrame + err = core.ErrInvalidFrame } return } diff --git a/internal/framing/header.go b/core/header.go similarity index 63% rename from internal/framing/header.go rename to core/header.go index f4e5409..81713b1 100644 --- a/internal/framing/header.go +++ b/core/header.go @@ -1,4 +1,4 @@ -package framing +package core import ( "encoding/binary" @@ -8,18 +8,18 @@ import ( ) const ( - // HeaderLen is len of header. - HeaderLen = 6 + // FrameHeaderLen is len of header. + FrameHeaderLen = 6 ) -// Header is the header fo a RSocket frame. -// RSocket frames begin with a RSocket Frame Header. +// FrameHeader is the header fo a RSocket frame. +// RSocket frames begin with a RSocket Frame FrameHeader. // It includes StreamID, FrameType and Flags. -type Header [HeaderLen]byte +type FrameHeader [FrameHeaderLen]byte -func (h Header) String() string { +func (h FrameHeader) String() string { bu := strings.Builder{} - bu.WriteString("Header{id=") + bu.WriteString("FrameHeader{id=") bu.WriteString(strconv.FormatUint(uint64(h.StreamID()), 10)) bu.WriteString(",type=") bu.WriteString(h.Type().String()) @@ -30,7 +30,7 @@ func (h Header) String() string { } // Resumable returns true if frame supports resume. -func (h Header) Resumable() bool { +func (h FrameHeader) Resumable() bool { switch h.Type() { case FrameTypeRequestChannel, FrameTypeRequestStream, FrameTypeRequestResponse, FrameTypeRequestFNF, FrameTypeRequestN, FrameTypeCancel, FrameTypeError, FrameTypePayload: return true @@ -40,37 +40,37 @@ func (h Header) Resumable() bool { } // WriteTo writes frame header to a writer. -func (h Header) WriteTo(w io.Writer) (int64, error) { +func (h FrameHeader) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(h[:]) return int64(n), err } // StreamID returns StreamID. -func (h Header) StreamID() uint32 { +func (h FrameHeader) StreamID() uint32 { return binary.BigEndian.Uint32(h[:4]) } // Type returns frame type. -func (h Header) Type() FrameType { +func (h FrameHeader) Type() FrameType { return FrameType((h.n() & 0xFC00) >> 10) } // Flag returns flag of a frame. -func (h Header) Flag() FrameFlag { +func (h FrameHeader) Flag() FrameFlag { return FrameFlag(h.n() & 0x03FF) } -func (h Header) Bytes() []byte { +func (h FrameHeader) Bytes() []byte { return h[:] } -func (h Header) n() uint16 { +func (h FrameHeader) n() uint16 { return binary.BigEndian.Uint16(h[4:]) } // NewFrameHeader returns a new frame header. -func NewFrameHeader(streamID uint32, frameType FrameType, fg FrameFlag) Header { - var h [HeaderLen]byte +func NewFrameHeader(streamID uint32, frameType FrameType, fg FrameFlag) FrameHeader { + var h [FrameHeaderLen]byte binary.BigEndian.PutUint32(h[:], streamID) binary.BigEndian.PutUint16(h[4:], uint16(frameType)<<10|uint16(fg)) return h @@ -78,9 +78,9 @@ func NewFrameHeader(streamID uint32, frameType FrameType, fg FrameFlag) Header { } // ParseFrameHeader parse a header from bytes. -func ParseFrameHeader(bs []byte) Header { - _ = bs[HeaderLen-1] - var bb [HeaderLen]byte - copy(bb[:], bs[:HeaderLen]) +func ParseFrameHeader(bs []byte) FrameHeader { + _ = bs[FrameHeaderLen-1] + var bb [FrameHeaderLen]byte + copy(bb[:], bs[:FrameHeaderLen]) return bb } diff --git a/internal/framing/header_test.go b/core/header_test.go similarity index 87% rename from internal/framing/header_test.go rename to core/header_test.go index 6cc88f0..bc438c7 100644 --- a/internal/framing/header_test.go +++ b/core/header_test.go @@ -1,12 +1,12 @@ -package framing_test +package core_test import ( "bytes" "math" "testing" + . "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - . "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) @@ -25,5 +25,5 @@ func TestHeader_All(t *testing.T) { bf := &bytes.Buffer{} n, err := h2.WriteTo(bf) assert.NoError(t, err) - assert.Equal(t, int64(HeaderLen), n) + assert.Equal(t, int64(FrameHeaderLen), n) } diff --git a/internal/transport/connection_tcp.go b/core/transport/connection_tcp.go similarity index 81% rename from internal/transport/connection_tcp.go rename to core/transport/connection_tcp.go index 2998bbe..c56c63e 100644 --- a/internal/transport/connection_tcp.go +++ b/core/transport/connection_tcp.go @@ -7,8 +7,9 @@ import ( "time" "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/logger" ) @@ -16,10 +17,10 @@ type tcpConn struct { rawConn net.Conn writer *bufio.Writer decoder *LengthBasedFrameDecoder - counter *Counter + counter *core.Counter } -func (p *tcpConn) SetCounter(c *Counter) { +func (p *tcpConn) SetCounter(c *core.Counter) { p.counter = c } @@ -27,7 +28,7 @@ func (p *tcpConn) SetDeadline(deadline time.Time) error { return p.rawConn.SetReadDeadline(deadline) } -func (p *tcpConn) Read() (f framing.Frame, err error) { +func (p *tcpConn) Read() (f core.Frame, err error) { raw, err := p.decoder.Read() if err == io.EOF { return @@ -36,16 +37,16 @@ func (p *tcpConn) Read() (f framing.Frame, err error) { err = errors.Wrap(err, "read frame failed") return } - h := framing.ParseFrameHeader(raw) + h := core.ParseFrameHeader(raw) bf := common.NewByteBuff() - _, err = bf.Write(raw[framing.HeaderLen:]) + _, err = bf.Write(raw[core.FrameHeaderLen:]) if err != nil { err = errors.Wrap(err, "read frame failed") return } base := framing.NewRawFrame(h, bf) if p.counter != nil && base.Header().Resumable() { - p.counter.incrReadBytes(base.Len()) + p.counter.IncReadBytes(base.Len()) } f, err = framing.FromRawFrame(base) if err != nil { @@ -71,10 +72,10 @@ func (p *tcpConn) Flush() (err error) { return } -func (p *tcpConn) Write(frame framing.FrameSupport) (err error) { +func (p *tcpConn) Write(frame core.FrameSupport) (err error) { size := frame.Len() if p.counter != nil && frame.Header().Resumable() { - p.counter.incrWriteBytes(size) + p.counter.IncWriteBytes(size) } _, err = common.MustNewUint24(size).WriteTo(p.writer) if err != nil { diff --git a/internal/transport/connection_ws.go b/core/transport/connection_ws.go similarity index 82% rename from internal/transport/connection_ws.go rename to core/transport/connection_ws.go index 9fa6f37..933daa2 100644 --- a/internal/transport/connection_ws.go +++ b/core/transport/connection_ws.go @@ -8,8 +8,9 @@ import ( "github.com/gorilla/websocket" "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/logger" ) @@ -19,10 +20,10 @@ var _buffPool = sync.Pool{ type wsConn struct { c *websocket.Conn - counter *Counter + counter *core.Counter } -func (p *wsConn) SetCounter(c *Counter) { +func (p *wsConn) SetCounter(c *core.Counter) { p.counter = c } @@ -30,7 +31,7 @@ func (p *wsConn) SetDeadline(deadline time.Time) error { return p.c.SetReadDeadline(deadline) } -func (p *wsConn) Read() (f framing.Frame, err error) { +func (p *wsConn) Read() (f core.Frame, err error) { t, raw, err := p.c.ReadMessage() if err != nil { err = errors.Wrap(err, "read frame failed") @@ -41,13 +42,13 @@ func (p *wsConn) Read() (f framing.Frame, err error) { return p.Read() } // validate min length - if len(raw) < framing.HeaderLen { + if len(raw) < core.FrameHeaderLen { err = errors.Wrap(ErrIncompleteHeader, "read frame failed") return } - header := framing.ParseFrameHeader(raw) + header := core.ParseFrameHeader(raw) bf := common.NewByteBuff() - _, err = bf.Write(raw[framing.HeaderLen:]) + _, err = bf.Write(raw[core.FrameHeaderLen:]) if err != nil { err = errors.Wrap(err, "read frame failed") return @@ -73,7 +74,7 @@ func (p *wsConn) Flush() (err error) { return } -func (p *wsConn) Write(frame framing.FrameSupport) (err error) { +func (p *wsConn) Write(frame core.FrameSupport) (err error) { bf := _buffPool.Get().(*bytes.Buffer) defer func() { bf.Reset() diff --git a/internal/transport/decoder.go b/core/transport/decoder.go similarity index 86% rename from internal/transport/decoder.go rename to core/transport/decoder.go index 05d0655..8630e7c 100644 --- a/internal/transport/decoder.go +++ b/core/transport/decoder.go @@ -5,8 +5,8 @@ import ( "errors" "io" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) const ( @@ -32,14 +32,14 @@ func (p *LengthBasedFrameDecoder) Read() (raw []byte, err error) { return } raw = scanner.Bytes()[lengthFieldSize:] - if len(raw) < framing.HeaderLen { + if len(raw) < core.FrameHeaderLen { err = ErrIncompleteHeader } return } -func doSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF { +func doSplit(data []byte, eof bool) (advance int, token []byte, err error) { + if eof { return } if len(data) < lengthFieldSize { @@ -47,7 +47,7 @@ func doSplit(data []byte, atEOF bool) (advance int, token []byte, err error) { } frameLength := common.NewUint24Bytes(data).AsInt() if frameLength < 1 { - err = common.ErrInvalidFrameLength + err = core.ErrInvalidFrameLength return } frameSize := frameLength + lengthFieldSize diff --git a/internal/transport/decoder_test.go b/core/transport/decoder_test.go similarity index 76% rename from internal/transport/decoder_test.go rename to core/transport/decoder_test.go index 9728c55..30432d6 100644 --- a/internal/transport/decoder_test.go +++ b/core/transport/decoder_test.go @@ -6,8 +6,9 @@ import ( "fmt" "testing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) func TestDecoder(t *testing.T) { @@ -21,9 +22,9 @@ func TestDecoder(t *testing.T) { if err != nil { break } - h := framing.ParseFrameHeader(raw) + h := core.ParseFrameHeader(raw) bf := common.NewByteBuff() - _, _ = bf.Write(raw[framing.HeaderLen:]) + _, _ = bf.Write(raw[core.FrameHeaderLen:]) f, err := framing.FromRawFrame(framing.NewRawFrame(h, bf)) if err != nil { panic(err) diff --git a/internal/transport/misc.go b/core/transport/misc.go similarity index 64% rename from internal/transport/misc.go rename to core/transport/misc.go index 1c779aa..6b28984 100644 --- a/internal/transport/misc.go +++ b/core/transport/misc.go @@ -1,6 +1,7 @@ package transport import ( + "context" "net/http" "strings" ) @@ -17,3 +18,6 @@ func isClosedErr(err error) bool { } return false } + +type ToClientTransport = func(context.Context) (*Transport, error) +type ToServerTransport = func(context.Context) (ServerTransport, error) diff --git a/internal/transport/transport.go b/core/transport/transport.go similarity index 87% rename from internal/transport/transport.go rename to core/transport/transport.go index d420ce7..e223ed9 100644 --- a/internal/transport/transport.go +++ b/core/transport/transport.go @@ -3,20 +3,22 @@ package transport import ( "context" "io" + "log" "sync" "time" "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/logger" ) type ( // FrameHandler is an alias of frame handler. - FrameHandler = func(frame framing.Frame) (err error) + FrameHandler = func(frame core.Frame) (err error) // ServerTransportAcceptor is an alias of server transport handler. - ServerTransportAcceptor = func(ctx context.Context, tp *Transport) + ServerTransportAcceptor = func(ctx context.Context, tp *Transport, onClose func(*Transport)) ) var errTransportClosed = errors.New("transport closed") @@ -34,7 +36,7 @@ type ServerTransport interface { // Transport is RSocket transport which is used to carry RSocket frames. type Transport struct { - conn Conn + conn core.Conn maxLifetime time.Duration lastRcvPos uint64 once sync.Once @@ -62,7 +64,7 @@ func (p *Transport) HandleDisaster(handler FrameHandler) { } // Connection returns current connection. -func (p *Transport) Connection() Conn { +func (p *Transport) Connection() core.Conn { return p.conn } @@ -75,7 +77,7 @@ func (p *Transport) SetLifetime(lifetime time.Duration) { } // Send send a frame. -func (p *Transport) Send(frame framing.FrameSupport, flush bool) (err error) { +func (p *Transport) Send(frame core.FrameSupport, flush bool) (err error) { defer func() { // ensure frame done when send success. if err == nil { @@ -116,7 +118,7 @@ func (p *Transport) Close() (err error) { } // ReadFirst reads first frame. -func (p *Transport) ReadFirst(ctx context.Context) (frame framing.Frame, err error) { +func (p *Transport) ReadFirst(ctx context.Context) (frame core.Frame, err error) { select { case <-ctx.Done(): err = ctx.Err() @@ -134,13 +136,12 @@ func (p *Transport) ReadFirst(ctx context.Context) (frame framing.Frame, err err // Start start transport. func (p *Transport) Start(ctx context.Context) (err error) { - defer func() { - _ = p.Close() - }() + defer p.Close() L: for { select { case <-ctx.Done(): + log.Println("ctx end") err = ctx.Err() return default: @@ -234,7 +235,7 @@ func (p *Transport) HandleKeepalive(handler FrameHandler) { } // DispatchFrame delivery incoming frames. -func (p *Transport) DispatchFrame(_ context.Context, frame framing.Frame) (err error) { +func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err error) { header := frame.Header() t := header.Type() sid := header.StreamID() @@ -242,34 +243,34 @@ func (p *Transport) DispatchFrame(_ context.Context, frame framing.Frame) (err e var handler FrameHandler switch t { - case framing.FrameTypeSetup: + case core.FrameTypeSetup: p.maxLifetime = frame.(*framing.SetupFrame).MaxLifetime() handler = p.hSetup - case framing.FrameTypeResume: + case core.FrameTypeResume: handler = p.hResume - case framing.FrameTypeResumeOK: + case core.FrameTypeResumeOK: p.lastRcvPos = frame.(*framing.ResumeOKFrame).LastReceivedClientPosition() handler = p.hResumeOK - case framing.FrameTypeRequestFNF: + case core.FrameTypeRequestFNF: handler = p.hFireAndForget - case framing.FrameTypeMetadataPush: + case core.FrameTypeMetadataPush: if sid != 0 { // skip invalid metadata push logger.Warnf("rsocket.Transport: omit MetadataPush with non-zero stream id %d\n", sid) return } handler = p.hMetadataPush - case framing.FrameTypeRequestResponse: + case core.FrameTypeRequestResponse: handler = p.hRequestResponse - case framing.FrameTypeRequestStream: + case core.FrameTypeRequestStream: handler = p.hRequestStream - case framing.FrameTypeRequestChannel: + case core.FrameTypeRequestChannel: handler = p.hRequestChannel - case framing.FrameTypePayload: + case core.FrameTypePayload: handler = p.hPayload - case framing.FrameTypeRequestN: + case core.FrameTypeRequestN: handler = p.hRequestN - case framing.FrameTypeError: + case core.FrameTypeError: if sid == 0 { err = errors.New(frame.(*framing.ErrorFrame).Error()) if p.hError0 != nil { @@ -278,13 +279,13 @@ func (p *Transport) DispatchFrame(_ context.Context, frame framing.Frame) (err e return } handler = p.hError - case framing.FrameTypeCancel: + case core.FrameTypeCancel: handler = p.hCancel - case framing.FrameTypeKeepalive: + case core.FrameTypeKeepalive: ka := frame.(*framing.KeepaliveFrame) p.lastRcvPos = ka.LastReceivedPosition() handler = p.hKeepalive - case framing.FrameTypeLease: + case core.FrameTypeLease: handler = p.hLease } @@ -309,7 +310,7 @@ func (p *Transport) DispatchFrame(_ context.Context, frame framing.Frame) (err e return } -func newTransportClient(c Conn) *Transport { +func NewTransport(c core.Conn) *Transport { return &Transport{ conn: c, maxLifetime: common.DefaultKeepaliveMaxLifetime, diff --git a/internal/transport/transport_tcp.go b/core/transport/transport_tcp.go similarity index 63% rename from internal/transport/transport_tcp.go rename to core/transport/transport_tcp.go index 372efc5..12afa9b 100644 --- a/internal/transport/transport_tcp.go +++ b/core/transport/transport_tcp.go @@ -5,10 +5,7 @@ import ( "crypto/tls" "io" "net" - "os" - "os/signal" "sync" - "syscall" "github.com/pkg/errors" ) @@ -19,6 +16,7 @@ type tcpServerTransport struct { listener net.Listener onceClose sync.Once tls *tls.Config + transports *sync.Map } func (p *tcpServerTransport) Accept(acceptor ServerTransportAcceptor) { @@ -31,6 +29,12 @@ func (p *tcpServerTransport) Close() (err error) { } p.onceClose.Do(func() { err = p.listener.Close() + + p.transports.Range(func(key, value interface{}) bool { + _ = key.(*Transport).Close() + return true + }) + }) return } @@ -54,26 +58,24 @@ func (p *tcpServerTransport) Listen(ctx context.Context, notifier chan<- struct{ } func (p *tcpServerTransport) listen(ctx context.Context) (err error) { - // Remove unix socket file before exit. - if p.network == schemaUNIX { - // Monitor signal of current process and unlink unix socket file. - go func(sock string) { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - _ = p.Close() - }(p.addr) - } + done := make(chan struct{}) - stop := make(chan struct{}) - ctx, cancel := context.WithCancel(ctx) - go func(ctx context.Context, stop chan struct{}) { - defer func() { - _ = p.Close() - close(stop) - }() - <-ctx.Done() - }(ctx, stop) + defer func() { + close(done) + _ = p.Close() + }() + + go func() { + for { + select { + case <-ctx.Done(): + _ = p.Close() + return + case <-done: + return + } + } + }() // Start loop of accepting connections. var c net.Conn @@ -88,26 +90,25 @@ func (p *tcpServerTransport) listen(ctx context.Context) (err error) { break } // Dispatch raw conn. - go func(ctx context.Context, rawConn net.Conn) { - conn := newTCPRConnection(rawConn) - tp := newTransportClient(conn) - p.acceptor(ctx, tp) - }(ctx, c) + tp := NewTransport(newTCPRConnection(c)) + p.transports.Store(tp, struct{}{}) + go p.acceptor(ctx, tp, func(t *Transport) { + p.transports.Delete(t) + }) } - cancel() - <-stop return } -func newTCPServerTransport(network, addr string, c *tls.Config) *tcpServerTransport { +func NewTcpServerTransport(network, addr string, c *tls.Config) *tcpServerTransport { return &tcpServerTransport{ - network: network, - addr: addr, - tls: c, + network: network, + addr: addr, + tls: c, + transports: &sync.Map{}, } } -func newTCPClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { +func NewTcpClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { var rawConn net.Conn if tlsConfig == nil { rawConn, err = net.Dial(network, addr) @@ -117,6 +118,6 @@ func newTCPClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Tra if err != nil { return } - tp = newTransportClient(newTCPRConnection(rawConn)) + tp = NewTransport(newTCPRConnection(rawConn)) return } diff --git a/internal/transport/transport_ws.go b/core/transport/transport_ws.go similarity index 80% rename from internal/transport/transport_ws.go rename to core/transport/transport_ws.go index f22a168..a7ff7a1 100644 --- a/internal/transport/transport_ws.go +++ b/core/transport/transport_ws.go @@ -39,12 +39,13 @@ func init() { } type wsServerTransport struct { - addr string - path string - acceptor ServerTransportAcceptor - onceClose sync.Once - listener net.Listener - tls *tls.Config + addr string + path string + acceptor ServerTransportAcceptor + onceClose sync.Once + listener net.Listener + tls *tls.Config + transports *sync.Map } func (p *wsServerTransport) Close() (err error) { @@ -68,11 +69,12 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} logger.Errorf("create websocket conn failed: %s\n", err.Error()) return } - go func(c *websocket.Conn, ctx context.Context) { - conn := newWebsocketConnection(c) - tp := newTransportClient(conn) - p.acceptor(ctx, tp) - }(c, ctx) + + tp := NewTransport(newWebsocketConnection(c)) + p.transports.Store(tp, struct{}{}) + go p.acceptor(ctx, tp, func(tp *Transport) { + p.transports.Delete(tp) + }) }) if p.tls == nil { @@ -110,18 +112,19 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} return } -func newWebsocketServerTransport(addr string, path string, c *tls.Config) *wsServerTransport { +func NewWebsocketServerTransport(addr string, path string, c *tls.Config) *wsServerTransport { if path == "" { path = defaultWebsocketPath } return &wsServerTransport{ - addr: addr, - path: path, - tls: c, + addr: addr, + path: path, + tls: c, + transports: &sync.Map{}, } } -func newWebsocketClientTransport(url string, tc *tls.Config, header http.Header) (*Transport, error) { +func NewWebsocketClientTransport(url string, tc *tls.Config, header http.Header) (*Transport, error) { var d *websocket.Dialer if tc == nil { d = websocket.DefaultDialer @@ -136,5 +139,5 @@ func newWebsocketClientTransport(url string, tc *tls.Config, header http.Header) if err != nil { return nil, errors.Wrap(err, "dial websocket failed") } - return newTransportClient(newWebsocketConnection(wsConn)), nil + return NewTransport(newWebsocketConnection(wsConn)), nil } diff --git a/internal/framing/frame.go b/core/types.go similarity index 53% rename from internal/framing/frame.go rename to core/types.go index 5fb5845..a81ab49 100644 --- a/internal/framing/frame.go +++ b/core/types.go @@ -1,16 +1,11 @@ -package framing +package core import ( - "errors" - "fmt" "io" "strings" - - "github.com/rsocket/rsocket-go/internal/common" + "time" ) -var errIncompleteFrame = errors.New("incomplete frame") - // FrameType is type of frame. type FrameType uint8 @@ -116,8 +111,8 @@ func (f FrameFlag) Check(flag FrameFlag) bool { type FrameSupport interface { io.WriterTo - // Header returns frame Header. - Header() Header + // FrameHeader returns frame FrameHeader. + Header() FrameHeader // Len returns length of frame. Len() int // Done marks current frame has been sent. @@ -126,11 +121,6 @@ type FrameSupport interface { DoneNotify() <-chan struct{} } -func PrintFrame(f FrameSupport) string { - // TODO: print frame - return fmt.Sprintf("%+v", f) -} - // Frame is a single message containing a request, response, or protocol processing. type Frame interface { FrameSupport @@ -138,115 +128,18 @@ type Frame interface { Validate() error } -type tinyFrame struct { - header Header - done chan struct{} -} - -func (t *tinyFrame) Header() Header { - return t.header -} - -// Done can be invoked when a frame has been been processed. -func (t *tinyFrame) Done() (closed bool) { - defer func() { - if e := recover(); e != nil { - closed = true - } - }() - close(t.done) - return -} - -// DoneNotify notify when frame has been done. -func (t *tinyFrame) DoneNotify() <-chan struct{} { - return t.done -} - -// RawFrame is basic frame implementation. -type RawFrame struct { - *tinyFrame - body *common.ByteBuff -} - -// Body returns frame body. -func (f *RawFrame) Body() *common.ByteBuff { - return f.body -} - -// Len returns length of frame. -func (f *RawFrame) Len() int { - if f.body == nil { - return HeaderLen - } - return HeaderLen + f.body.Len() -} - -// WriteTo write frame to writer. -func (f *RawFrame) WriteTo(w io.Writer) (n int64, err error) { - var wrote int64 - wrote, err = f.header.WriteTo(w) - if err != nil { - return - } - n += wrote - if f.body != nil { - wrote, err = f.body.WriteTo(w) - if err != nil { - return - } - n += wrote - } - return -} - -func (f *RawFrame) trySeekMetadataLen(offset int) (n int, hasMetadata bool) { - raw := f.body.Bytes() - if offset > 0 { - raw = raw[offset:] - } - hasMetadata = f.header.Flag().Check(FlagMetadata) - if !hasMetadata { - return - } - if len(raw) < 3 { - n = -1 - } else { - n = common.NewUint24Bytes(raw).AsInt() - } - return -} - -func (f *RawFrame) trySliceMetadata(offset int) ([]byte, bool) { - n, ok := f.trySeekMetadataLen(offset) - if !ok || n < 0 { - return nil, false - } - return f.body.Bytes()[offset+3 : offset+3+n], true -} - -func (f *RawFrame) trySliceData(offset int) []byte { - n, ok := f.trySeekMetadataLen(offset) - if !ok { - return f.body.Bytes()[offset:] - } - if n < 0 { - return nil - } - return f.body.Bytes()[offset+n+3:] -} - -func newTinyFrame(header Header) *tinyFrame { - return &tinyFrame{ - header: header, - done: make(chan struct{}), - } -} - -// NewRawFrame returns a new RawFrame. -func NewRawFrame(header Header, body *common.ByteBuff) *RawFrame { - return &RawFrame{ - tinyFrame: newTinyFrame(header), - body: body, - } +// Conn is connection for RSocket. +type Conn interface { + io.Closer + // SetDeadline set deadline for current connection. + // After this deadline, connection will be closed. + SetDeadline(deadline time.Time) error + // SetCounter bind a counter which can count r/w bytes. + SetCounter(c *Counter) + // Read reads next frame from Conn. + Read() (Frame, error) + // Write writes a frame to Conn. + Write(FrameSupport) error + // Flush. + Flush() error } diff --git a/internal/common/version.go b/core/version.go similarity index 75% rename from internal/common/version.go rename to core/version.go index 52365e3..a847d56 100644 --- a/internal/common/version.go +++ b/core/version.go @@ -1,4 +1,4 @@ -package common +package core import ( "encoding/binary" @@ -32,6 +32,24 @@ func (p Version) Minor() uint16 { return p[1] } +func (p Version) Equals(version Version) bool { + return p.Major() == version.Major() && p.Minor() == version.Minor() +} + +func (p Version) GreaterThan(version Version) bool { + if p[0] == version[0] { + return p[1] > version[1] + } + return p[0] > version[0] +} + +func (p Version) LessThan(version Version) bool { + if p[0] == version[0] { + return p[1] < version[1] + } + return p[0] < version[0] +} + // WriteTo write raw version bytes to a writer. func (p Version) WriteTo(w io.Writer) (n int64, err error) { err = binary.Write(w, binary.BigEndian, p[0]) diff --git a/internal/common/version_test.go b/core/version_test.go similarity index 51% rename from internal/common/version_test.go rename to core/version_test.go index 9bde3ef..0cceb8b 100644 --- a/internal/common/version_test.go +++ b/core/version_test.go @@ -1,16 +1,16 @@ -package common_test +package core_test import ( "bytes" "encoding/binary" "testing" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" "github.com/stretchr/testify/assert" ) func BenchmarkVersion_String(b *testing.B) { - v := common.NewVersion(2, 3) + v := core.NewVersion(2, 3) b.ResetTimer() for i := 0; i < b.N; i++ { _ = v.String() @@ -22,7 +22,7 @@ func TestVersion(t *testing.T) { major uint16 = 2 minor uint16 = 1 ) - v := common.NewVersion(major, minor) + v := core.NewVersion(major, minor) assert.Equal(t, "2.1", v.String()) assert.Equal(t, uint16(2), v.Major(), "wrong major version") assert.Equal(t, uint16(1), v.Minor(), "wrong minor version") @@ -34,6 +34,38 @@ func TestVersion(t *testing.T) { checkBytes(t, v.Bytes(), 2, 1) } +func TestVersion_Equals(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + + assert.True(t, v1.Equals(v2)) + assert.True(t, v2.Equals(v1)) + assert.False(t, v1.Equals(v3)) + assert.False(t, v1.Equals(v4)) +} + +func TestVersion_GreaterThan(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + assert.False(t, v1.GreaterThan(v2)) + assert.False(t, v1.GreaterThan(v3)) + assert.True(t, v1.GreaterThan(v4)) +} + +func TestVersion_LessThan(t *testing.T) { + v1 := core.NewVersion(1, 3) + v2 := core.NewVersion(1, 3) + v3 := core.NewVersion(3, 1) + v4 := core.NewVersion(1, 2) + assert.False(t, v1.LessThan(v2)) + assert.True(t, v1.LessThan(v3)) + assert.False(t, v1.LessThan(v4)) +} + func checkBytes(t *testing.T, b []byte, expectMajor, expectMinor uint16) { assert.Equal(t, 4, len(b), "wrong version bytes") major := binary.BigEndian.Uint16(b[:2]) diff --git a/examples/echo/echo.go b/examples/echo/echo.go index 1d8ed6f..b04b166 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -18,10 +18,11 @@ import ( "github.com/rsocket/rsocket-go/rx/mono" ) -const ListenAt = "tcp://127.0.0.1:7878" +var MyTransporter rsocket.Transporter -//const ListenAt = "unix:///tmp/rsocket.echo.sock" -//const ListenAt = "ws://127.0.0.1:7878/echo" +func init() { + MyTransporter = rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() +} func main() { go func() { @@ -32,7 +33,7 @@ func main() { //Fragment(65535). //Resume(). OnStart(func() { - log.Println("server is listening:", ListenAt) + log.Println("server start success!") }). Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { //log.Println("SETUP BEGIN:----------------") @@ -61,7 +62,7 @@ func main() { } return responder(), nil }). - Transport(ListenAt). + Transport(MyTransporter). Serve(context.Background()) if err != nil { panic(err) diff --git a/examples/echo/echo_benchmark_test.go b/examples/echo_bench/echo_bench.go similarity index 60% rename from examples/echo/echo_benchmark_test.go rename to examples/echo_bench/echo_bench.go index 3507701..bcbb64a 100644 --- a/examples/echo/echo_benchmark_test.go +++ b/examples/echo_bench/echo_bench.go @@ -3,40 +3,53 @@ package main import ( "bytes" "context" - "fmt" + "flag" "log" - _ "net/http/pprof" + "math/rand" "sync" - "testing" "time" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/mono" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestClient_RequestResponse(t *testing.T) { - client, err := createClient(ListenAt) - require.NoError(t, err, "bad client") - defer func() { - _ = client.Close() - }() +var tp rsocket.Transporter + +func init() { + flag.Parse() + tp = rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() + rand.Seed(time.Now().UnixNano()) +} + +func main() { + var ( + n int + payloadSize int + mtu int + ) + flag.IntVar(&n, "n", 100*10000, "request amount.") + flag.IntVar(&payloadSize, "size", 1024, "payload data size.") + flag.IntVar(&mtu, "mtu", 0, "mut size, zero means disabled.") + + client, err := createClient(mtu) + if err != nil { + panic(err) + } + defer client.Close() wg := &sync.WaitGroup{} - n := 100 * 10000 + wg.Add(n) - data := []byte(common.RandAlphanumeric(1024)) + data := make([]byte, payloadSize) + rand.Read(data) now := time.Now() ctx := context.Background() sub := rx.NewSubscriber( rx.OnNext(func(input payload.Payload) { - assert.Equal(t, data, input.Data(), "data doesn't match") //m2, _ := elem.MetadataUTF8() //assert.Equal(t, m1, m2, "metadata doesn't match") wg.Done() @@ -44,20 +57,17 @@ func TestClient_RequestResponse(t *testing.T) { ) for i := 0; i < n; i++ { - m1 := []byte(fmt.Sprintf("benchmark_test_%d", i)) - client.RequestResponse(payload.New(data, m1)).SubscribeOn(scheduler.Elastic()).SubscribeWith(ctx, sub) + client.RequestResponse(payload.New(data, nil)).SubscribeOn(scheduler.Elastic()).SubscribeWith(ctx, sub) } wg.Wait() cost := time.Since(now) log.Println(n, "COST:", cost) log.Println(n, "QPS:", float64(n)/cost.Seconds()) - } -func createClient(uri string) (rsocket.Client, error) { +func createClient(mtu int) (rsocket.Client, error) { return rsocket.Connect(). - //Fragment(1024). - //Resume(). + Fragment(mtu). SetupPayload(payload.NewString("你好", "世界")). Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { return rsocket.NewAbstractSocket( @@ -70,6 +80,6 @@ func createClient(uri string) (rsocket.Client, error) { }), ) }). - Transport(uri). + Transport(tp). Start(context.Background()) } diff --git a/examples/fibonacci/main.go b/examples/fibonacci/main.go index 400e90d..0e97c5b 100644 --- a/examples/fibonacci/main.go +++ b/examples/fibonacci/main.go @@ -14,7 +14,12 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -const transportString = "tcp://127.0.0.1:7878" +var tp rsocket.Transporter + +func init() { + tp = rsocket.Tcp().Addr("127.0.0.1:7878").Build() +} + const number = 13 func main() { @@ -77,7 +82,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestStreamHandler), nil }). // specify transport - Transport(transportString). + Transport(tp). // serve will block execution unless an error occurred Serve(context.Background()) @@ -86,7 +91,7 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection - client, err := rsocket.Connect().Transport(transportString).Start(context.Background()) + client, err := rsocket.Connect().Transport(tp).Start(context.Background()) if err != nil { panic(err) } diff --git a/examples/lease/main.go b/examples/lease/main.go deleted file mode 100644 index b7023b6..0000000 --- a/examples/lease/main.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "context" - "log" - "time" - - "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/lease" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx/mono" -) - -func main() { - les, _ := lease.NewSimpleLease(10*time.Second, 7*time.Second, 1*time.Second, 5) - err := rsocket.Receive(). - Lease(les). - Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (socket rsocket.RSocket, e error) { - socket = rsocket.NewAbstractSocket( - rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - return mono.Just(msg) - }), - ) - return - }). - Transport("tcp://127.0.0.1:7878"). - Serve(context.Background()) - - if err != nil { - log.Fatal(err) - } -} diff --git a/examples/lease/main_test.go b/examples/lease/main_test.go deleted file mode 100644 index ec6c228..0000000 --- a/examples/lease/main_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package main_test - -import ( - "context" - "log" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.uber.org/atomic" - - "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/payload" -) - -func TestClientWithLease(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), 20*time.Second) - cli, err := rsocket.Connect(). - Lease(). - Transport("tcp://127.0.0.1:7878"). - Start(ctx) - if err != nil { - panic(err) - } - defer func() { - _ = cli.Close() - }() - - success := atomic.NewUint32(0) - -Loop: - for { - select { - case <-ctx.Done(): - break Loop - default: - time.Sleep(1 * time.Second) - v, err := cli.RequestResponse(payload.NewString("hello world", "go")).Block(context.Background()) - if err != nil { - log.Println("request failed:", err) - } else { - success.Inc() - log.Println("request success:", v) - } - } - } - assert.Equal(t, uint32(10), success.Load(), "bad requests") -} diff --git a/examples/word_counter/main.go b/examples/word_counter/main.go index 1767745..5fb50d8 100644 --- a/examples/word_counter/main.go +++ b/examples/word_counter/main.go @@ -13,7 +13,12 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -const transportString = "tcp://127.0.0.1:7878" +var tp rsocket.Transporter + +func init() { + tp = rsocket.Tcp().Addr("127.0.0.1:7878").Build() +} + const number = 13 func main() { @@ -54,7 +59,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestChannelHandler), nil }). // specify transport - Transport(transportString). + Transport(tp). // serve will block execution unless an error occurred Serve(context.Background()) @@ -63,7 +68,7 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection - client, err := rsocket.Connect().Transport(transportString).Start(context.Background()) + client, err := rsocket.Connect().Transport(tp).Start(context.Background()) if err != nil { panic(err) } diff --git a/fuzz.go b/fuzz.go index a62f23c..d04aa17 100644 --- a/fuzz.go +++ b/fuzz.go @@ -7,9 +7,10 @@ import ( "bytes" "errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" ) func Fuzz(data []byte) int { @@ -26,13 +27,13 @@ func Fuzz(data []byte) int { } func isExpectedError(err error) bool { - return err == common.ErrInvalidFrame || err == transport.ErrIncompleteHeader + return err == core.ErrInvalidFrame || err == transport.ErrIncompleteHeader } func handleRaw(raw []byte) (err error) { - h := framing.ParseFrameHeader(raw) + h := core.ParseFrameHeader(raw) bf := common.NewByteBuff() - var frame framing.Frame + var frame core.Frame frame, err = framing.FromRawFrame(framing.NewRawFrame(h, bf)) if err != nil { return @@ -41,7 +42,7 @@ func handleRaw(raw []byte) (err error) { if err != nil { return } - if frame.Len() >= framing.HeaderLen { + if frame.Len() >= core.FrameHeaderLen { return } err = errors.New("broken frame") diff --git a/internal/common/errors_test.go b/internal/common/errors_test.go deleted file mode 100644 index a29ea86..0000000 --- a/internal/common/errors_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package common_test - -import ( - "math" - "testing" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestErrorCode_String(t *testing.T) { - all := []common.ErrorCode{ - common.ErrorCodeInvalidSetup, - common.ErrorCodeUnsupportedSetup, - common.ErrorCodeRejectedSetup, - common.ErrorCodeRejectedResume, - common.ErrorCodeConnectionError, - common.ErrorCodeConnectionClose, - common.ErrorCodeApplicationError, - common.ErrorCodeRejected, - common.ErrorCodeCanceled, - common.ErrorCodeInvalid, - } - for _, code := range all { - assert.NotEqual(t, "UNKNOWN", code.String()) - } - assert.Equal(t, "UNKNOWN", common.ErrorCode(math.MaxUint32).String()) -} diff --git a/internal/fragmentation/joiner.go b/internal/fragmentation/joiner.go index cd208e1..0043ca2 100644 --- a/internal/fragmentation/joiner.go +++ b/internal/fragmentation/joiner.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" ) var errNoFrameInJoiner = errors.New("no frames in current joiner") @@ -14,15 +14,15 @@ type implJoiner struct { root *list.List // list of HeaderAndPayload } -func (p *implJoiner) First() framing.Frame { +func (p *implJoiner) First() core.Frame { first := p.root.Front() if first == nil { panic(errNoFrameInJoiner) } - return first.Value.(framing.Frame) + return first.Value.(core.Frame) } -func (p *implJoiner) Header() framing.Header { +func (p *implJoiner) Header() core.FrameHeader { return p.First().Header() } @@ -34,7 +34,7 @@ func (p *implJoiner) String() string { func (p *implJoiner) Metadata() (metadata []byte, ok bool) { for cur := p.root.Front(); cur != nil; cur = cur.Next() { f := cur.Value.(HeaderAndPayload) - if !f.Header().Flag().Check(framing.FlagMetadata) { + if !f.Header().Flag().Check(core.FlagMetadata) { break } if m, has := f.Metadata(); has { @@ -74,6 +74,6 @@ func (p *implJoiner) DataUTF8() (data string) { func (p *implJoiner) Push(elem HeaderAndPayload) (end bool) { p.root.PushBack(elem) h := elem.Header() - end = !h.Flag().Check(framing.FlagFollow) + end = !h.Flag().Check(core.FlagFollow) return } diff --git a/internal/fragmentation/joiner_test.go b/internal/fragmentation/joiner_test.go index eac91cc..1ea12c5 100644 --- a/internal/fragmentation/joiner_test.go +++ b/internal/fragmentation/joiner_test.go @@ -5,21 +5,22 @@ import ( "log" "testing" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" ) func TestFragmentPayload(t *testing.T) { const totals = 10 const sid = uint32(1) - fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), framing.FlagFollow|framing.FlagMetadata)) + fr := NewJoiner(framing.NewPayloadFrame(sid, []byte("(ROOT)"), []byte("(ROOT)"), core.FlagFollow|core.FlagMetadata)) for i := 0; i < totals; i++ { data := fmt.Sprintf("(data%04d)", i) var frame *framing.PayloadFrame if i < 3 { meta := fmt.Sprintf("(meta%04d)", i) - frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), framing.FlagFollow|framing.FlagMetadata) + frame = framing.NewPayloadFrame(sid, []byte(data), []byte(meta), core.FlagFollow|core.FlagMetadata) } else if i != totals-1 { - frame = framing.NewPayloadFrame(sid, []byte(data), nil, framing.FlagFollow) + frame = framing.NewPayloadFrame(sid, []byte(data), nil, core.FlagFollow) } else { frame = framing.NewPayloadFrame(sid, []byte(data), nil, 0) } diff --git a/internal/fragmentation/splitter.go b/internal/fragmentation/splitter.go index b5fc72d..a22b0a4 100644 --- a/internal/fragmentation/splitter.go +++ b/internal/fragmentation/splitter.go @@ -1,8 +1,8 @@ package fragmentation import ( + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" ) // HandleSplitResult is callback for fragmentation result. @@ -10,7 +10,7 @@ type HandleSplitResult = func(index int, result SplitResult) // SplitResult defines fragmentation result struct. type SplitResult struct { - Flag framing.FrameFlag + Flag core.FrameFlag Metadata []byte Data []byte } @@ -33,7 +33,7 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame HandleSp var follow bool for { bf = common.NewByteBuff() - left := mtu - framing.HeaderLen + left := mtu - core.FrameHeaderLen if idx == 0 && skip > 0 { left -= skip for i := 0; i < skip; i++ { @@ -57,18 +57,18 @@ func SplitSkip(mtu int, skip int, data []byte, metadata []byte, onFrame HandleSp curMetadata := metadata[begin1:cursor1] curData := data[begin2:cursor2] follow = cursor1+cursor2 < lenM+lenD - var flag framing.FrameFlag + var flag core.FrameFlag if follow { - flag |= framing.FlagFollow + flag |= core.FlagFollow } else { - flag &= ^framing.FlagFollow + flag &= ^core.FlagFollow } if hasMetadata { // metadata - flag |= framing.FlagMetadata + flag |= core.FlagMetadata } else { // non-metadata - flag &= ^framing.FlagMetadata + flag &= ^core.FlagMetadata } ch <- SplitResult{ Flag: flag, diff --git a/internal/fragmentation/splitter_test.go b/internal/fragmentation/splitter_test.go index 7f99319..443b62a 100644 --- a/internal/fragmentation/splitter_test.go +++ b/internal/fragmentation/splitter_test.go @@ -3,8 +3,9 @@ package fragmentation import ( "testing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/stretchr/testify/assert" ) @@ -26,7 +27,7 @@ func split2joiner(mtu int, data, metadata []byte) (joiner Joiner, err error) { fn := func(idx int, result SplitResult) { sid := uint32(77778888) if idx == 0 { - f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, framing.FlagComplete|result.Flag) + f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, core.FlagComplete|result.Flag) joiner = NewJoiner(f) } else { f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag) diff --git a/internal/fragmentation/types.go b/internal/fragmentation/types.go index bcd88dd..3110185 100644 --- a/internal/fragmentation/types.go +++ b/internal/fragmentation/types.go @@ -4,32 +4,32 @@ import ( "container/list" "fmt" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/payload" ) const ( // MinFragment is minimum fragment size in bytes. - MinFragment = framing.HeaderLen + 4 + MinFragment = core.FrameHeaderLen + 4 // MaxFragment is minimum fragment size in bytes. MaxFragment = common.MaxUint24 - 3 ) var errInvalidFragmentLen = fmt.Errorf("invalid fragment: [%d,%d]", MinFragment, MaxFragment) -// HeaderAndPayload is Payload which having a Header. +// HeaderAndPayload is Payload which having a FrameHeader. type HeaderAndPayload interface { payload.Payload - // Header returns a header of frame. - Header() framing.Header + // FrameHeader returns a header of frame. + Header() core.FrameHeader } // Joiner is used to join frames to a payload. type Joiner interface { HeaderAndPayload // First returns the first frame. - First() framing.Frame + First() core.Frame // Push append a new frame and returns true if joiner is end. Push(elem HeaderAndPayload) (end bool) } diff --git a/internal/socket/msg.go b/internal/socket/callback.go similarity index 100% rename from internal/socket/msg.go rename to internal/socket/callback.go diff --git a/internal/socket/client_default.go b/internal/socket/client_default.go index 31ffefd..0969bb6 100644 --- a/internal/socket/client_default.go +++ b/internal/socket/client_default.go @@ -2,22 +2,20 @@ package socket import ( "context" - "crypto/tls" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/logger" ) type defaultClientSocket struct { *baseSocket - uri *transport.URI - headers map[string][]string - tls *tls.Config + tp transport.ToClientTransport } func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { - tp, err := p.uri.MakeClientTransport(p.tls, p.headers) + tp, err := p.tp(ctx) if err != nil { return } @@ -28,7 +26,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err if setup.Lease { p.refreshLease(0, 0) - tp.HandleLease(func(frame framing.Frame) (err error) { + tp.HandleLease(func(frame core.Frame) (err error) { lease := frame.(*framing.LeaseFrame) p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) logger.Infof(">>>>> refresh lease: %v\n", lease) @@ -36,7 +34,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err }) } - tp.HandleDisaster(func(frame framing.Frame) (err error) { + tp.HandleDisaster(func(frame core.Frame) (err error) { p.socket.SetError(frame.(*framing.ErrorFrame)) return }) @@ -57,11 +55,9 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err } // NewClient create a simple client-side socket. -func NewClient(uri *transport.URI, socket *DuplexRSocket, tc *tls.Config, headers map[string][]string) ClientSocket { +func NewClient(tp transport.ToClientTransport, socket *DuplexRSocket) ClientSocket { return &defaultClientSocket{ baseSocket: newBaseSocket(socket), - uri: uri, - headers: headers, - tls: tc, + tp: tp, } } diff --git a/internal/socket/client_resume.go b/internal/socket/client_resume.go index 3ac8c0f..8636336 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/client_resume.go @@ -2,14 +2,13 @@ package socket import ( "context" - "crypto/tls" "errors" "math" "time" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/logger" "go.uber.org/atomic" ) @@ -19,10 +18,8 @@ const reconnectDelay = 1 * time.Second type resumeClientSocket struct { *baseSocket connects *atomic.Int32 - uri *transport.URI - headers map[string][]string setup *SetupInfo - tc *tls.Config + tp transport.ToClientTransport } func (p *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { @@ -50,7 +47,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { _ = p.Close() return } - tp, err := p.uri.MakeClientTransport(p.tc, p.headers) + tp, err := p.tp(ctx) if err != nil { if connects == 1 { return @@ -78,11 +75,11 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { } }(ctx, tp) - var f framing.FrameSupport + var f core.FrameSupport // connect first time. if len(p.setup.Token) < 1 || connects == 1 { - tp.HandleDisaster(func(frame framing.Frame) (err error) { + tp.HandleDisaster(func(frame core.Frame) (err error) { p.socket.SetError(frame.(*framing.ErrorFrame)) p.markClosing() return @@ -95,7 +92,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { } f = framing.NewResumeFrameSupport( - common.DefaultVersion, + core.DefaultVersion, p.setup.Token, p.socket.counter.WriteBytes(), p.socket.counter.ReadBytes(), @@ -103,15 +100,15 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { resumeErr := make(chan string) - tp.HandleResumeOK(func(frame framing.Frame) (err error) { + tp.HandleResumeOK(func(frame core.Frame) (err error) { close(resumeErr) return }) - tp.HandleDisaster(func(frame framing.Frame) (err error) { + tp.HandleDisaster(func(frame core.Frame) (err error) { // TODO: process other error with zero StreamID f := frame.(*framing.ErrorFrame) - if f.ErrorCode() == common.ErrorCodeRejectedResume { + if f.ErrorCode() == core.ErrorCodeRejectedResume { resumeErr <- f.Error() close(resumeErr) } @@ -149,12 +146,10 @@ func (p *resumeClientSocket) isClosed() bool { } // NewClientResume creates a client-side socket with resume support. -func NewClientResume(uri *transport.URI, socket *DuplexRSocket, tc *tls.Config, headers map[string][]string) ClientSocket { +func NewClientResume(tp transport.ToClientTransport, socket *DuplexRSocket) ClientSocket { return &resumeClientSocket{ baseSocket: newBaseSocket(socket), - uri: uri, - tc: tc, - headers: headers, connects: atomic.NewInt32(0), + tp: tp, } } diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index c08484e..f3e0048 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -9,10 +9,11 @@ import ( "time" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/internal/fragmentation" - "github.com/rsocket/rsocket-go/internal/framing" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/lease" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" @@ -38,10 +39,10 @@ func IsSocketClosedError(err error) bool { // DuplexRSocket represents a socket of RSocket which can be a requester or a responder. type DuplexRSocket struct { - counter *transport.Counter + counter *core.Counter tp *transport.Transport - outs chan framing.FrameSupport - outsPriority []framing.FrameSupport + outs chan core.FrameSupport + outsPriority []core.FrameSupport responder Responder messages common.U32Map sids StreamID @@ -49,7 +50,7 @@ type DuplexRSocket struct { fragments common.U32Map // key=streamID, value=Joiner closed *atomic.Bool done chan struct{} - keepaliver *keepaliver + keepaliver *Keepaliver cond *sync.Cond singleScheduler scheduler.Scheduler e error @@ -127,7 +128,7 @@ func (p *DuplexRSocket) Close() error { // FireAndForget start a request of FireAndForget. func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { data := sending.Data() - size := framing.HeaderLen + len(sending.Data()) + size := core.FrameHeaderLen + len(sending.Data()) m, ok := sending.Metadata() if ok { size += 3 + len(m) @@ -138,11 +139,11 @@ func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { return } p.doSplit(data, m, func(index int, result fragmentation.SplitResult) { - var f framing.FrameSupport + var f core.FrameSupport if index == 0 { f = framing.NewFireAndForgetFrameSupport(sid, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -179,11 +180,11 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { return } p.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { - var f framing.FrameSupport + var f core.FrameSupport if index == 0 { f = framing.NewRequestResponseFrameSupport(sid, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -235,11 +236,11 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { } p.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { - var f framing.FrameSupport + var f core.FrameSupport if index == 0 { f = framing.NewRequestStreamFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -287,7 +288,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { close(sndRequested) } if !newborn { - p.sendPayload(sid, item, framing.FlagNext) + p.sendPayload(sid, item, core.FlagNext) return } @@ -296,16 +297,16 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { size := framing.CalcPayloadFrameSize(d, m) + 4 if !p.shouldSplit(size) { metadata, _ := item.Metadata() - p.sendFrame(framing.NewRequestChannelFrameSupport(sid, n32, item.Data(), metadata, framing.FlagNext)) + p.sendFrame(framing.NewRequestChannelFrameSupport(sid, n32, item.Data(), metadata, core.FlagNext)) return } p.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { - var f framing.FrameSupport + var f core.FrameSupport if index == 0 { - f = framing.NewRequestChannelFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag|framing.FlagNext) + f = framing.NewRequestChannelFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag|core.FlagNext) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|framing.FlagNext) + f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -320,7 +321,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { // TODO: handle cancel or error switch sig { case rx.SignalComplete: - complete := framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete) + complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) p.sendFrame(complete) <-complete.DoneNotify() default: @@ -333,7 +334,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { return ret } -func (p *DuplexRSocket) onFrameRequestResponse(frame framing.Frame) error { +func (p *DuplexRSocket) onFrameRequestResponse(frame core.Frame) error { // fragment receiving, ok := p.doFragment(frame.(*framing.RequestResponseFrame)) if !ok { @@ -360,14 +361,14 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd } // 3. sending error with unsupported handler if sending == nil { - p.writeError(sid, framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestResponse)) + p.writeError(sid, framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) return nil } // 4. async subscribe publisher sub := rx.NewSubscriber( rx.OnNext(func(input payload.Payload) { - p.sendPayload(sid, input, framing.FlagNext|framing.FlagComplete) + p.sendPayload(sid, input, core.FlagNext|core.FlagComplete) }), rx.OnError(func(e error) { p.writeError(sid, e) @@ -386,7 +387,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd return nil } -func (p *DuplexRSocket) onFrameRequestChannel(input framing.Frame) error { +func (p *DuplexRSocket) onFrameRequestChannel(input core.Frame) error { receiving, ok := p.doFragment(input.(*framing.RequestChannelFrame)) if !ok { return nil @@ -441,7 +442,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) }() flux = p.responder.RequestChannel(receiving) if flux == nil { - err = framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestChannel) + err = framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) } return }() @@ -459,7 +460,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) p.writeError(sid, e) }), rx.OnComplete(func() { - complete := framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete) + complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) p.sendFrame(complete) <-complete.DoneNotify() }), @@ -469,7 +470,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) s.Request(initRequestN) }), rx.OnNext(func(elem payload.Payload) { - p.sendPayload(sid, elem, framing.FlagNext) + p.sendPayload(sid, elem, core.FlagNext) }), ) @@ -493,7 +494,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) return nil } -func (p *DuplexRSocket) respondMetadataPush(input framing.Frame) (err error) { +func (p *DuplexRSocket) respondMetadataPush(input core.Frame) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond METADATA_PUSH failed: %s\n", e) @@ -503,7 +504,7 @@ func (p *DuplexRSocket) respondMetadataPush(input framing.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameFNF(frame framing.Frame) error { +func (p *DuplexRSocket) onFrameFNF(frame core.Frame) error { receiving, ok := p.doFragment(frame.(*framing.FireAndForgetFrame)) if !ok { return nil @@ -521,7 +522,7 @@ func (p *DuplexRSocket) respondFNF(receiving fragmentation.HeaderAndPayload) (er return } -func (p *DuplexRSocket) onFrameRequestStream(frame framing.Frame) error { +func (p *DuplexRSocket) onFrameRequestStream(frame core.Frame) error { receiving, ok := p.doFragment(frame.(*framing.RequestStreamFrame)) if !ok { return nil @@ -539,7 +540,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa }() resp = p.responder.RequestStream(receiving) if resp == nil { - err = framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, unsupportedRequestStream) + err = framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestStream) } return }() @@ -563,7 +564,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa sub := rx.NewSubscriber( rx.OnNext(func(elem payload.Payload) { - p.sendPayload(sid, elem, framing.FlagNext) + p.sendPayload(sid, elem, core.FlagNext) }), rx.OnSubscribe(func(s rx.Subscription) { p.register(sid, requestStreamCallbackReverse{su: s}) @@ -573,7 +574,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa p.writeError(sid, e) }), rx.OnComplete(func() { - p.sendFrame(framing.NewPayloadFrame(sid, nil, nil, framing.FlagComplete)) + p.sendFrame(framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete)) }), ) @@ -595,10 +596,10 @@ func (p *DuplexRSocket) writeError(sid uint32, e error) { switch err := e.(type) { case *framing.ErrorFrame: p.sendFrame(err) - case common.CustomError: + case core.CustomError: p.sendFrame(framing.NewErrorFrameSupport(sid, err.ErrorCode(), err.ErrorData())) default: - p.sendFrame(framing.NewErrorFrameSupport(sid, common.ErrorCodeApplicationError, []byte(e.Error()))) + p.sendFrame(framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) } } @@ -607,9 +608,9 @@ func (p *DuplexRSocket) SetResponder(responder Responder) { p.responder = responder } -func (p *DuplexRSocket) onFrameKeepalive(frame framing.Frame) (err error) { +func (p *DuplexRSocket) onFrameKeepalive(frame core.Frame) (err error) { f := frame.(*framing.KeepaliveFrame) - if f.Header().Flag().Check(framing.FlagRespond) { + if f.Header().Flag().Check(core.FlagRespond) { k := framing.NewKeepaliveFrame(f.LastReceivedPosition(), f.Data(), false) //f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) p.sendFrame(k) @@ -617,7 +618,7 @@ func (p *DuplexRSocket) onFrameKeepalive(frame framing.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameCancel(frame framing.Frame) (err error) { +func (p *DuplexRSocket) onFrameCancel(frame core.Frame) (err error) { sid := frame.Header().StreamID() v, ok := p.messages.Load(sid) @@ -641,7 +642,7 @@ func (p *DuplexRSocket) onFrameCancel(frame framing.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { +func (p *DuplexRSocket) onFrameError(input core.Frame) (err error) { f := input.(*framing.ErrorFrame) logger.Errorf("handle error frame: %s\n", f) sid := f.Header().StreamID() @@ -665,7 +666,7 @@ func (p *DuplexRSocket) onFrameError(input framing.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameRequestN(input framing.Frame) (err error) { +func (p *DuplexRSocket) onFrameRequestN(input core.Frame) (err error) { f := input.(*framing.RequestNFrame) sid := f.Header().StreamID() v, ok := p.messages.Load(sid) @@ -702,7 +703,7 @@ func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fr } return } - ok = !h.Flag().Check(framing.FlagFollow) + ok = !h.Flag().Check(core.FlagFollow) if ok { out = input return @@ -711,23 +712,23 @@ func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fr return } -func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { +func (p *DuplexRSocket) onFramePayload(frame core.Frame) error { pl, ok := p.doFragment(frame.(*framing.PayloadFrame)) if !ok { return nil } h := pl.Header() t := h.Type() - if t == framing.FrameTypeRequestFNF { + if t == core.FrameTypeRequestFNF { return p.respondFNF(pl) } - if t == framing.FrameTypeRequestResponse { + if t == core.FrameTypeRequestResponse { return p.respondRequestResponse(pl) } - if t == framing.FrameTypeRequestStream { + if t == core.FrameTypeRequestStream { return p.respondRequestStream(pl) } - if t == framing.FrameTypeRequestChannel { + if t == core.FrameTypeRequestChannel { return p.respondRequestChannel(pl) } @@ -743,30 +744,30 @@ func (p *DuplexRSocket) onFramePayload(frame framing.Frame) error { vv.pc.Success(pl) case requestStreamCallback: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.pc.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { // Release pure complete payload vv.pc.Complete() } case requestChannelCallback: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.rcv.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { vv.rcv.Complete() } case requestChannelCallbackReverse: fg := h.Flag() - isNext := fg.Check(framing.FlagNext) + isNext := fg.Check(core.FlagNext) if isNext { vv.rcv.Next(pl) } - if fg.Check(framing.FlagComplete) { + if fg.Check(core.FlagComplete) { vv.rcv.Complete() } default: @@ -803,7 +804,7 @@ func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { p.cond.L.Unlock() } -func (p *DuplexRSocket) sendFrame(f framing.FrameSupport) { +func (p *DuplexRSocket) sendFrame(f core.FrameSupport) { defer func() { if e := recover(); e != nil { logger.Warnf("send frame failed: %s\n", e) @@ -815,7 +816,7 @@ func (p *DuplexRSocket) sendFrame(f framing.FrameSupport) { func (p *DuplexRSocket) sendPayload( sid uint32, sending payload.Payload, - frameFlag framing.FrameFlag, + frameFlag core.FrameFlag, ) { d := sending.Data() m, _ := sending.Metadata() @@ -830,7 +831,7 @@ func (p *DuplexRSocket) sendPayload( if index == 0 { flag |= frameFlag } else { - flag |= framing.FlagNext + flag |= core.FlagNext } p.sendFrame(framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, flag)) }) @@ -840,7 +841,7 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) if len(p.outs) > 0 { p.drain(nil) } - var out framing.FrameSupport + var out core.FrameSupport select { case <-p.keepaliver.C(): ok = true @@ -881,7 +882,7 @@ func (p *DuplexRSocket) drainWithKeepalive() (ok bool) { if len(p.outs) > 0 { p.drain(nil) } - var out framing.FrameSupport + var out core.FrameSupport select { case <-p.keepaliver.C(): @@ -939,7 +940,7 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { return true } -func (p *DuplexRSocket) drainOne(out framing.FrameSupport) (wrote bool) { +func (p *DuplexRSocket) drainOne(out core.FrameSupport) (wrote bool) { if p.tp == nil { p.outsPriority = append(p.outsPriority, out) return @@ -964,7 +965,7 @@ func (p *DuplexRSocket) drainOutBack() { if p.tp == nil { return } - var out framing.FrameSupport + var out core.FrameSupport for i := range p.outsPriority { out = p.outsPriority[i] if err := p.tp.Send(out, false); err != nil { @@ -1086,14 +1087,14 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { return &DuplexRSocket{ closed: atomic.NewBool(false), leases: leases, - outs: make(chan framing.FrameSupport, _outChanSize), + outs: make(chan core.FrameSupport, _outChanSize), mtu: mtu, messages: common.NewU32Map(), sids: &serverStreamIDs{}, fragments: common.NewU32MapLite(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), - counter: transport.NewCounter(), + counter: core.NewCounter(), singleScheduler: scheduler.NewSingle(64), } } @@ -1103,17 +1104,17 @@ func NewClientDuplexRSocket( mtu int, keepaliveInterval time.Duration, ) (s *DuplexRSocket) { - ka := newKeepaliver(keepaliveInterval) + ka := NewKeepaliver(keepaliveInterval) s = &DuplexRSocket{ closed: atomic.NewBool(false), - outs: make(chan framing.FrameSupport, _outChanSize), + outs: make(chan core.FrameSupport, _outChanSize), mtu: mtu, messages: common.NewU32Map(), sids: &clientStreamIDs{}, fragments: common.NewU32MapLite(), done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), - counter: transport.NewCounter(), + counter: core.NewCounter(), keepaliver: ka, singleScheduler: scheduler.NewSingle(64), } diff --git a/internal/socket/keepaliver.go b/internal/socket/keepaliver.go index a1a5f94..d7ade0c 100644 --- a/internal/socket/keepaliver.go +++ b/internal/socket/keepaliver.go @@ -4,31 +4,30 @@ import ( "time" ) -type keepaliver struct { - ticker *time.Ticker - interval time.Duration +type Keepaliver struct { + ticker *time.Ticker + done chan struct{} } -func (p *keepaliver) C() <-chan time.Time { +func (p Keepaliver) C() <-chan time.Time { return p.ticker.C } -func (p *keepaliver) Stop() { - if p.ticker != nil { - p.ticker.Stop() - } +func (p Keepaliver) Done() <-chan struct{} { + return p.done } -func (p *keepaliver) Reset(interval time.Duration) { - if interval != p.interval { - p.ticker.Stop() - p.ticker = time.NewTicker(interval) - } +func (p Keepaliver) Stop() { + defer func() { + _ = recover() + }() + p.ticker.Stop() + close(p.done) } -func newKeepaliver(interval time.Duration) *keepaliver { - return &keepaliver{ - interval: interval, - ticker: time.NewTicker(interval), +func NewKeepaliver(interval time.Duration) *Keepaliver { + return &Keepaliver{ + ticker: time.NewTicker(interval), + done: make(chan struct{}), } } diff --git a/internal/socket/keepaliver_test.go b/internal/socket/keepaliver_test.go new file mode 100644 index 0000000..3cb01eb --- /dev/null +++ b/internal/socket/keepaliver_test.go @@ -0,0 +1,33 @@ +package socket_test + +import ( + "fmt" + "testing" + "time" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +func TestKeepaliver(t *testing.T) { + k := socket.NewKeepaliver(100 * time.Millisecond) + + time.AfterFunc(time.Second+50*time.Millisecond, func() { + k.Stop() + // stop again + k.Stop() + }) + + beats := 0 +L: + for { + select { + case v := <-k.C(): + fmt.Println(v) + beats++ + case <-k.Done(): + break L + } + } + assert.Equal(t, 10, beats, "beats should be 10") +} diff --git a/internal/socket/misc.go b/internal/socket/misc.go index 3c258b1..c7a7311 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -4,15 +4,15 @@ import ( "time" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/common" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/rx" ) // SetupInfo represents basic info of setup. type SetupInfo struct { Lease bool - Version common.Version + Version core.Version KeepaliveInterval time.Duration KeepaliveLifetime time.Duration Token []byte @@ -22,7 +22,7 @@ type SetupInfo struct { Metadata []byte } -func (p *SetupInfo) toFrame() framing.FrameSupport { +func (p *SetupInfo) toFrame() core.FrameSupport { return framing.NewSetupFrameSupport( p.Version, p.KeepaliveInterval, diff --git a/internal/socket/server_default.go b/internal/socket/server_default.go index 74de7f8..84cbf7b 100644 --- a/internal/socket/server_default.go +++ b/internal/socket/server_default.go @@ -3,7 +3,7 @@ package socket import ( "context" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core/transport" ) type serverSocket struct { diff --git a/internal/socket/server_resume.go b/internal/socket/server_resume.go index decf673..9f876c7 100644 --- a/internal/socket/server_resume.go +++ b/internal/socket/server_resume.go @@ -3,7 +3,7 @@ package socket import ( "context" - "github.com/rsocket/rsocket-go/internal/transport" + "github.com/rsocket/rsocket-go/core/transport" ) type resumeServerSocket struct { diff --git a/internal/socket/smap_test.go b/internal/socket/smap_test.go deleted file mode 100644 index 3160c79..0000000 --- a/internal/socket/smap_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package socket - -import ( - "crypto/rand" - "encoding/binary" - "sync" - "testing" -) - -func nextID() uint32 { - b4 := make([]byte, 4) - _, _ = rand.Read(b4) - return uint32(maskStreamID) & binary.BigEndian.Uint32(b4) -} - -func BenchmarkLock(b *testing.B) { - var v requestChannelCallback - m := make(map[uint32]interface{}) - var lk sync.RWMutex - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - id := nextID() - switch id % 3 { - case 1: - lk.Lock() - m[id] = v - lk.Unlock() - case 2: - lk.RLock() - _ = m[id] - lk.RUnlock() - default: - lk.Lock() - delete(m, id) - lk.Unlock() - } - } - }) -} - -func BenchmarkSync(b *testing.B) { - var v requestChannelCallback - m := &sync.Map{} - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - id := nextID() - switch id % 3 { - case 1: - m.Store(id, v) - case 2: - m.Load(id) - default: - m.Delete(id) - } - } - }) - -} diff --git a/internal/socket/socket.go b/internal/socket/socket.go index f4fee83..42bd7d5 100644 --- a/internal/socket/socket.go +++ b/internal/socket/socket.go @@ -1,13 +1,10 @@ package socket import ( - "context" - "io" "sync" "time" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -23,51 +20,6 @@ var ( errUnimplementedRequestChannel = errors.New("REQUEST_CHANNEL is unimplemented") ) -// Closeable represents a closeable target. -type Closeable interface { - io.Closer - // OnClose bind a handler when closing. - OnClose(closer func(error)) -} - -// Responder is a contract providing different interaction models for RSocket protocol. -type Responder interface { - // FireAndForget is a single one-way message. - FireAndForget(message payload.Payload) - // MetadataPush sends asynchronous Metadata frame. - MetadataPush(message payload.Payload) - // RequestResponse request single response. - RequestResponse(message payload.Payload) mono.Mono - // RequestStream request a completable stream. - RequestStream(message payload.Payload) flux.Flux - // RequestChannel request a completable stream in both directions. - RequestChannel(messages rx.Publisher) flux.Flux -} - -// ClientSocket represents a client-side socket. -type ClientSocket interface { - Closeable - Responder - // Setup setups current socket. - Setup(ctx context.Context, setup *SetupInfo) (err error) -} - -// ServerSocket represents a server-side socket. -type ServerSocket interface { - Closeable - Responder - // SetResponder sets a responder for current socket. - SetResponder(responder Responder) - // SetTransport sets a transport for current socket. - SetTransport(tp *transport.Transport) - // Pause pause current socket. - Pause() bool - // Start starts current socket. - Start(ctx context.Context) error - // Token returns token of socket. - Token() (token []byte, ok bool) -} - // AbstractRSocket represents an abstract RSocket. type AbstractRSocket struct { FF func(payload.Payload) diff --git a/internal/socket/stream_id.go b/internal/socket/stream_id.go index 0282cc5..6ee3018 100644 --- a/internal/socket/stream_id.go +++ b/internal/socket/stream_id.go @@ -5,8 +5,8 @@ import ( ) const ( - maskStreamID uint64 = 0x7FFFFFFF - halfSeed uint64 = 0x40000000 + _maskStreamID uint64 = 0x7FFFFFFF + _halfSeed uint64 = 0x40000000 ) type StreamID interface { @@ -22,7 +22,7 @@ func (p *serverStreamIDs) Next() (uint32, bool) { seed := atomic.AddUint64(&p.cur, 1) v := 2 * seed if v != 0 { - return uint32(maskStreamID & v), seed <= halfSeed + return uint32(_maskStreamID & v), seed <= _halfSeed } return p.Next() } @@ -36,7 +36,7 @@ func (p *clientStreamIDs) Next() (uint32, bool) { seed := atomic.AddUint64(&p.cur, 1) v := 2*(seed-1) + 1 if v != 0 { - return uint32(maskStreamID & v), seed <= halfSeed + return uint32(_maskStreamID & v), seed <= _halfSeed } return p.Next() } diff --git a/internal/socket/types.go b/internal/socket/types.go new file mode 100644 index 0000000..c0cec5b --- /dev/null +++ b/internal/socket/types.go @@ -0,0 +1,57 @@ +package socket + +import ( + "context" + "io" + + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +// Closeable represents a closeable target. +type Closeable interface { + io.Closer + // OnClose bind a handler when closing. + OnClose(closer func(error)) +} + +// Responder is a contract providing different interaction models for RSocket protocol. +type Responder interface { + // FireAndForget is a single one-way message. + FireAndForget(message payload.Payload) + // MetadataPush sends asynchronous Metadata frame. + MetadataPush(message payload.Payload) + // RequestResponse request single response. + RequestResponse(message payload.Payload) mono.Mono + // RequestStream request a completable stream. + RequestStream(message payload.Payload) flux.Flux + // RequestChannel request a completable stream in both directions. + RequestChannel(messages rx.Publisher) flux.Flux +} + +// ClientSocket represents a client-side socket. +type ClientSocket interface { + Closeable + Responder + // Setup setups current socket. + Setup(ctx context.Context, setup *SetupInfo) (err error) +} + +// ServerSocket represents a server-side socket. +type ServerSocket interface { + Closeable + Responder + // SetResponder sets a responder for current socket. + SetResponder(responder Responder) + // SetTransport sets a transport for current socket. + SetTransport(tp *transport.Transport) + // Pause pause current socket. + Pause() bool + // Start starts current socket. + Start(ctx context.Context) error + // Token returns token of socket. + Token() (token []byte, ok bool) +} diff --git a/internal/transport/connection.go b/internal/transport/connection.go deleted file mode 100644 index d4badc4..0000000 --- a/internal/transport/connection.go +++ /dev/null @@ -1,24 +0,0 @@ -package transport - -import ( - "io" - "time" - - "github.com/rsocket/rsocket-go/internal/framing" -) - -// Conn is connection for RSocket. -type Conn interface { - io.Closer - // SetDeadline set deadline for current connection. - // After this deadline, connection will be closed. - SetDeadline(deadline time.Time) error - // SetCounter bind a counter which can count r/w bytes. - SetCounter(c *Counter) - // Read reads next frame from Conn. - Read() (framing.Frame, error) - // Write writes a frame to Conn. - Write(frames framing.FrameSupport) error - // Flush. - Flush() error -} diff --git a/lease/lease_test.go b/lease/lease_test.go new file mode 100644 index 0000000..50f54b1 --- /dev/null +++ b/lease/lease_test.go @@ -0,0 +1,81 @@ +package lease_test + +import ( + "context" + "fmt" + "log" + "testing" + "time" + + "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/lease" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx/mono" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +var _tp rsocket.Transporter + +func init() { + _tp = rsocket.Tcp().HostAndPort("127.0.0.1", 7979).Build() +} + +func Init(ctx context.Context, started chan<- struct{}) { + l, _ := lease.NewSimpleLease(10*time.Second, 7*time.Second, 1*time.Second, 5) + err := rsocket.Receive(). + Lease(l). + OnStart(func() { + close(started) + }). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { + return rsocket.NewAbstractSocket( + rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + return mono.Just(msg) + }), + ), nil + }). + Transport(_tp). + Serve(ctx) + if err != nil { + log.Fatal(err) + } +} + +func TestClientWithLease(t *testing.T) { + started := make(chan struct{}) + go Init(context.Background(), started) + <-started + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + cli, err := rsocket.Connect(). + Lease(). + Transport(_tp). + Start(ctx) + if err != nil { + require.NoError(t, err, "connect failed") + } + defer cli.Close() + + success := atomic.NewUint32(0) + +Loop: + for { + select { + case <-ctx.Done(): + break Loop + default: + time.Sleep(1 * time.Second) + v, err := cli.RequestResponse(payload.NewString("hello world", "go")).Block(context.Background()) + if err != nil { + fmt.Println("request failed:", err) + } else { + success.Inc() + fmt.Println("request success:", v) + } + } + } + assert.Equal(t, uint32(10), success.Load(), "bad requests") +} diff --git a/payload/payload.go b/payload/payload.go index e91259b..c899d51 100644 --- a/payload/payload.go +++ b/payload/payload.go @@ -4,7 +4,7 @@ import ( "io/ioutil" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" ) type ( @@ -36,7 +36,7 @@ type ( // MaxLifetime returns max lifetime of RSocket connection. MaxLifetime() time.Duration // Version return RSocket protocol version. - Version() common.Version + Version() core.Version } ) diff --git a/payload/payload_raw.go b/payload/payload_raw.go index 23c563c..e92e82c 100644 --- a/payload/payload_raw.go +++ b/payload/payload_raw.go @@ -1,46 +1,10 @@ package payload -import ( - "fmt" - "strings" - "unicode/utf8" -) - type rawPayload struct { data []byte metadata []byte } -func (p *rawPayload) String() string { - bu := strings.Builder{} - bu.WriteString("Payload{data=") - if utf8.Valid(p.data) { - bu.Write(p.data) - } else { - bu.WriteByte('[') - for _, b := range p.data { - bu.WriteString(fmt.Sprintf(" 0x%x", b)) - } - bu.WriteByte(' ') - bu.WriteByte(']') - } - bu.WriteString(",metadata=") - if len(p.metadata) > 0 { - if utf8.Valid(p.metadata) { - bu.Write(p.metadata) - } else { - bu.WriteByte('[') - for _, b := range p.metadata { - bu.WriteString(fmt.Sprintf(" 0x%x", b)) - } - bu.WriteByte(' ') - bu.WriteByte(']') - } - } - bu.WriteByte('}') - return bu.String() -} - func (p *rawPayload) Metadata() (metadata []byte, ok bool) { return p.metadata, len(p.metadata) > 0 } diff --git a/payload/payload_str.go b/payload/payload_str.go index 502c21b..4cb0fad 100644 --- a/payload/payload_str.go +++ b/payload/payload_str.go @@ -1,24 +1,10 @@ package payload -import ( - "strings" -) - type strPayload struct { data string metadata string } -func (p *strPayload) String() string { - bu := strings.Builder{} - bu.WriteString("Payload{data=") - bu.WriteString(p.data) - bu.WriteString("metadata=") - bu.WriteString(p.metadata) - bu.WriteByte('}') - return bu.String() -} - func (p *strPayload) Metadata() (metadata []byte, ok bool) { ok = len(p.metadata) > 0 if ok { diff --git a/payload/payload_test.go b/payload/payload_test.go index 771b5ef..e65664e 100644 --- a/payload/payload_test.go +++ b/payload/payload_test.go @@ -3,6 +3,7 @@ package payload_test import ( "fmt" "testing" + "unicode/utf8" "github.com/rsocket/rsocket-go/payload" "github.com/stretchr/testify/assert" @@ -40,9 +41,8 @@ func TestRawPayload(t *testing.T) { invalid := []byte{0xff, 0xfe, 0xfd} badPayload := payload.New(invalid, invalid) - s := badPayload.(fmt.Stringer).String() - fmt.Println("no utf8 payload:", s) - assert.NotEmpty(t, s) + s := badPayload.DataUTF8() + assert.False(t, utf8.Valid([]byte(s))) } func TestStrPayload(t *testing.T) { diff --git a/rsocket.go b/rsocket.go index 093b61b..c9b5a7f 100644 --- a/rsocket.go +++ b/rsocket.go @@ -1,7 +1,7 @@ package rsocket import ( - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/socket" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -11,33 +11,33 @@ import ( const ( // ErrorCodeInvalidSetup means the setup frame is invalid for the server. - ErrorCodeInvalidSetup = common.ErrorCodeInvalidSetup + ErrorCodeInvalidSetup = core.ErrorCodeInvalidSetup // ErrorCodeUnsupportedSetup means some (or all) of the parameters specified by the client are unsupported by the server. - ErrorCodeUnsupportedSetup = common.ErrorCodeUnsupportedSetup + ErrorCodeUnsupportedSetup = core.ErrorCodeUnsupportedSetup // ErrorCodeRejectedSetup means server rejected the setup, it can specify the reason in the payload. - ErrorCodeRejectedSetup = common.ErrorCodeRejectedSetup + ErrorCodeRejectedSetup = core.ErrorCodeRejectedSetup // ErrorCodeRejectedResume means server rejected the resume, it can specify the reason in the payload. - ErrorCodeRejectedResume = common.ErrorCodeRejectedResume + ErrorCodeRejectedResume = core.ErrorCodeRejectedResume // ErrorCodeConnectionError means the connection is being terminated. - ErrorCodeConnectionError = common.ErrorCodeConnectionError + ErrorCodeConnectionError = core.ErrorCodeConnectionError // ErrorCodeConnectionClose means the connection is being terminated. - ErrorCodeConnectionClose = common.ErrorCodeConnectionClose + ErrorCodeConnectionClose = core.ErrorCodeConnectionClose // ErrorCodeApplicationError means application layer logic generating a Reactive Streams onError event. - ErrorCodeApplicationError = common.ErrorCodeApplicationError + ErrorCodeApplicationError = core.ErrorCodeApplicationError // ErrorCodeRejected means Responder reject it. - ErrorCodeRejected = common.ErrorCodeRejected + ErrorCodeRejected = core.ErrorCodeRejected // ErrorCodeCanceled means the Responder canceled the request but may have started processing it (similar to REJECTED but doesn't guarantee lack of side-effects). - ErrorCodeCanceled = common.ErrorCodeCanceled + ErrorCodeCanceled = core.ErrorCodeCanceled // ErrorCodeInvalid means the request is invalid. - ErrorCodeInvalid = common.ErrorCodeInvalid + ErrorCodeInvalid = core.ErrorCodeInvalid ) // Aliases for Error defines. type ( // ErrorCode is code for RSocket error. - ErrorCode = common.ErrorCode + ErrorCode = core.ErrorCode // Error provides a method of accessing code and data. - Error = common.CustomError + Error = core.CustomError ) type ( diff --git a/rsocket_example_test.go b/rsocket_example_test.go index 9d279c3..e3acc0b 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -16,6 +16,7 @@ import ( func Example() { // Serve a server + tp := rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() err := rsocket.Receive(). Resume(). // Enable RESUME //Lease(). @@ -27,7 +28,7 @@ func Example() { }), ), nil }). - Transport("tcp://127.0.0.1:7878"). + Transport(tp). Serve(context.Background()) if err != nil { panic(err) @@ -36,7 +37,7 @@ func Example() { // Connect to a server. cli, err := rsocket.Connect(). SetupPayload(payload.NewString("Hello World", "From Golang")). - Transport("tcp://127.0.0.1:7878"). + Transport(tp). Start(context.Background()) if err != nil { panic(err) @@ -92,7 +93,7 @@ func ExampleReceive() { }), ), nil }). - Transport("tcp://0.0.0.0:7878"). + Transport(rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build()). Serve(context.Background()) panic(err) } @@ -100,7 +101,7 @@ func ExampleReceive() { func ExampleConnect() { cli, err := rsocket.Connect(). Resume(). // Enable RESUME. - Lease(). // Enable LEASE. + Lease(). // Enable LEASE. Fragment(4096). SetupPayload(payload.NewString("Hello", "World")). Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { @@ -110,7 +111,7 @@ func ExampleConnect() { }), ) }). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.Tcp().Addr("127.0.0.1:7878").Build()). Start(context.Background()) if err != nil { panic(err) diff --git a/rsocket_test.go b/rsocket_test.go index aa1ff6c..4d34011 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -40,17 +40,16 @@ func init() { var testData = "Hello World!" func TestSuite(t *testing.T) { - addresses := map[string]string{ - //"unix": "unix:///tmp/rsocket.test.sock", - "tcp": "tcp://localhost:7878", - "websocket": "ws://localhost:8080/test", + transports := map[string]Transporter{ + "tcp": Tcp().Addr("127.0.0.1:7878").Build(), + "websocket": Websocket().Url("ws://127.0.0.1:8080/test").Build(), } - for k, v := range addresses { - testAll(k, v, t) + for k, v := range transports { + testAll(t, k, v) } } -func testAll(proto string, addr string, t *testing.T) { +func testAll(t *testing.T, proto string, tp Transporter) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -112,7 +111,7 @@ func testAll(proto string, addr string, t *testing.T) { }), ), nil }). - Transport(addr). + Transport(tp). Serve(ctx) fmt.Println("SERVER STOPPED!!!!!") if err != nil { @@ -126,7 +125,7 @@ func testAll(proto string, addr string, t *testing.T) { cli, err := Connect(). Fragment(192). SetupPayload(NewString(setupData, setupMetadata)). - Transport(addr). + Transport(tp). Start(context.Background()) assert.NoError(t, err, "connect failed") defer func() { diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index e7b3553..44cc770 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -7,7 +7,7 @@ import ( "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/internal/framing" + "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" ) @@ -86,7 +86,7 @@ func (p proxy) ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e }). Subscribe(ctx, rx.OnNext(func(v payload.Payload) { - if _, ok := v.(framing.Frame); ok { + if _, ok := v.(core.Frame); ok { ch <- payload.Clone(v) } else { ch <- v diff --git a/server.go b/server.go index dc18036..f200cbd 100644 --- a/server.go +++ b/server.go @@ -2,15 +2,14 @@ package rsocket import ( "context" - "crypto/tls" "time" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/fragmentation" - "github.com/rsocket/rsocket-go/internal/framing" "github.com/rsocket/rsocket-go/internal/session" "github.com/rsocket/rsocket-go/internal/socket" - "github.com/rsocket/rsocket-go/internal/transport" "github.com/rsocket/rsocket-go/lease" "github.com/rsocket/rsocket-go/logger" ) @@ -38,46 +37,21 @@ type ( // Resume enable resume for current server. Resume(opts ...OpServerResume) ServerBuilder // Acceptor register server acceptor which is used to handle incoming RSockets. - Acceptor(acceptor ServerAcceptor) ServerTransportBuilder + Acceptor(acceptor ServerAcceptor) ToServerStarter // OnStart register a handler when serve success. OnStart(onStart func()) ServerBuilder } - // ServerTransportBuilder is used to build a RSocket server with custom Transport string. - ServerTransportBuilder interface { + // ToServerStarter is used to build a RSocket server with custom Transport string. + ToServerStarter interface { // Transport specify transport string. - Transport(transport string) Start + Transport(t Transporter) Start } // Start start a RSocket server. Start interface { // Serve serve RSocket server. Serve(ctx context.Context) error - // Serve serve RSocket server with TLS. - // - // You can generate cert.pem and key.pem for local testing: - // - // go run $GOROOT/src/crypto/tls/generate_cert.go --host localhost - // - // Load X509 - // cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") - // if err != nil { - // panic(err) - // } - // // Init TLS configuration. - // tc := &tls.Config{ - // MinVersion: tls.VersionTLS12, - // CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - // PreferServerCipherSuites: true, - // CipherSuites: []uint16{ - // tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - // tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - // tls.TLS_RSA_WITH_AES_256_CBC_SHA, - // }, - // Certificates: []tls.Certificate{cert}, - // } - ServeTLS(ctx context.Context, c *tls.Config) error } ) @@ -99,9 +73,9 @@ type serverResumeOptions struct { } type server struct { + tp transport.ToServerTransport resumeOpts *serverResumeOptions fragment int - addr string acc ServerAcceptor sm *session.Manager done chan struct{} @@ -134,34 +108,23 @@ func (p *server) Fragment(mtu int) ServerBuilder { return p } -func (p *server) Acceptor(acceptor ServerAcceptor) ServerTransportBuilder { +func (p *server) Acceptor(acceptor ServerAcceptor) ToServerStarter { p.acc = acceptor return p } -func (p *server) Transport(transport string) Start { - p.addr = transport +func (p *server) Transport(t Transporter) Start { + p.tp = t.Server() return p } -func (p *server) ServeTLS(ctx context.Context, c *tls.Config) error { - return p.serve(ctx, c) -} - func (p *server) Serve(ctx context.Context) error { - return p.serve(ctx, nil) -} - -func (p *server) serve(ctx context.Context, tc *tls.Config) error { - u, err := transport.ParseURI(p.addr) + err := fragmentation.IsValidFragment(p.fragment) if err != nil { return err } - err = fragmentation.IsValidFragment(p.fragment) - if err != nil { - return err - } - t, err := u.MakeServerTransport(tc) + + t, err := p.tp(ctx) if err != nil { return err } @@ -173,8 +136,8 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { go func(ctx context.Context) { _ = p.loopCleanSession(ctx) }(ctx) - - t.Accept(func(ctx context.Context, tp *transport.Transport) { + t.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) socketChan := make(chan socket.ServerSocket, 1) defer func() { select { @@ -222,7 +185,7 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { } }(ctx, sendingSocket) default: - err := framing.NewErrorFrameSupport(0, common.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) + err := framing.NewErrorFrameSupport(0, core.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) _ = tp.Send(err, true) _ = tp.Close() return @@ -244,16 +207,16 @@ func (p *server) serve(ctx context.Context, tc *tls.Config) error { func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) (sendingSocket socket.ServerSocket, err *framing.ErrorFrameSupport) { - if frame.Header().Flag().Check(framing.FlagLease) && p.leases == nil { - err = framing.NewErrorFrameSupport(0, common.ErrorCodeUnsupportedSetup, errUnavailableLease) + if frame.Header().Flag().Check(core.FlagLease) && p.leases == nil { + err = framing.NewErrorFrameSupport(0, core.ErrorCodeUnsupportedSetup, errUnavailableLease) return } - isResume := frame.Header().Flag().Check(framing.FlagResume) + isResume := frame.Header().Flag().Check(core.FlagResume) // 1. receive a token but server doesn't support resume. if isResume && !p.resumeOpts.enable { - err = framing.NewErrorFrameSupport(0, common.ErrorCodeUnsupportedSetup, errUnavailableResume) + err = framing.NewErrorFrameSupport(0, core.ErrorCodeUnsupportedSetup, errUnavailableResume) return } @@ -263,7 +226,7 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc if !isResume { sendingSocket = socket.NewServer(rawSocket) if responder, e := p.acc(frame, sendingSocket); e != nil { - err = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedSetup, []byte(e.Error())) + err = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedSetup, []byte(e.Error())) } else { sendingSocket.SetResponder(responder) sendingSocket.SetTransport(tp) @@ -276,7 +239,7 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc // 3. resume reject because of duplicated token. if _, ok := p.sm.Load(token); ok { - err = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedSetup, errDuplicatedSetupToken) + err = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedSetup, errDuplicatedSetupToken) return } @@ -288,7 +251,7 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc case *framing.ErrorFrame: err = framing.NewErrorFrameSupport(0, vv.ErrorCode(), vv.ErrorData()) default: - err = framing.NewErrorFrameSupport(0, common.ErrorCodeInvalidSetup, []byte(e.Error())) + err = framing.NewErrorFrameSupport(0, core.ErrorCodeInvalidSetup, []byte(e.Error())) } } else { sendingSocket.SetResponder(responder) @@ -299,9 +262,9 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc } func (p *server) doResume(frame *framing.ResumeFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { - var sending framing.FrameSupport + var sending core.FrameSupport if !p.resumeOpts.enable { - sending = framing.NewErrorFrameSupport(0, common.ErrorCodeRejectedResume, errUnavailableResume) + sending = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedResume, errUnavailableResume) } else if s, ok := p.sm.Load(frame.Token()); ok { sending = framing.NewResumeOKFrameSupport(0) s.Socket().SetTransport(tp) @@ -312,7 +275,7 @@ func (p *server) doResume(frame *framing.ResumeFrame, tp *transport.Transport, s } else { sending = framing.NewErrorFrameSupport( 0, - common.ErrorCodeRejectedResume, + core.ErrorCodeRejectedResume, []byte("no such session"), ) } diff --git a/transporter.go b/transporter.go new file mode 100644 index 0000000..f8c3371 --- /dev/null +++ b/transporter.go @@ -0,0 +1,181 @@ +package rsocket + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "net/url" + "os" + + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" +) + +type Transporter interface { + Client() transport.ToClientTransport + Server() transport.ToServerTransport +} + +type tcpTransporter struct { + addr string + tls *tls.Config +} + +type TcpTransporterBuilder struct { + opts []func(*tcpTransporter) +} + +func (t *tcpTransporter) Server() transport.ToServerTransport { + return func(ctx context.Context) (transport.ServerTransport, error) { + return transport.NewTcpServerTransport("tcp", t.addr, t.tls), nil + } +} + +func (t *tcpTransporter) Client() transport.ToClientTransport { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewTcpClientTransport("tcp", t.addr, t.tls) + } +} + +func (t *TcpTransporterBuilder) Addr(addr string) *TcpTransporterBuilder { + t.opts = append(t.opts, func(transporter *tcpTransporter) { + transporter.addr = addr + }) + return t +} + +func (t *TcpTransporterBuilder) HostAndPort(host string, port int) *TcpTransporterBuilder { + return t.Addr(fmt.Sprintf("%s:%d", host, port)) +} + +func (t *TcpTransporterBuilder) TLS(config *tls.Config) *TcpTransporterBuilder { + t.opts = append(t.opts, func(transporter *tcpTransporter) { + transporter.tls = config + }) + return t +} + +func (t *TcpTransporterBuilder) Build() Transporter { + tp := &tcpTransporter{ + addr: ":7878", + tls: nil, + } + for _, opt := range t.opts { + opt(tp) + } + return tp +} + +type wsTransporter struct { + url string + tls *tls.Config + header http.Header +} + +type WebsocketTransporterBuilder struct { + opts []func(*wsTransporter) +} + +func (w *WebsocketTransporterBuilder) Header(header http.Header) *WebsocketTransporterBuilder { + w.opts = append(w.opts, func(transporter *wsTransporter) { + transporter.header = header + }) + return w +} + +func (w *WebsocketTransporterBuilder) Url(url string) *WebsocketTransporterBuilder { + w.opts = append(w.opts, func(transporter *wsTransporter) { + transporter.url = url + }) + return w +} + +func (w *WebsocketTransporterBuilder) TLS(config *tls.Config) *WebsocketTransporterBuilder { + w.opts = append(w.opts, func(transporter *wsTransporter) { + transporter.tls = config + }) + return w +} + +func (w *WebsocketTransporterBuilder) Build() Transporter { + ws := &wsTransporter{ + url: "", + } + for _, opt := range w.opts { + opt(ws) + } + return ws +} + +func (w *wsTransporter) Server() transport.ToServerTransport { + return func(ctx context.Context) (transport.ServerTransport, error) { + u, err := url.Parse(w.url) + if err != nil { + return nil, err + } + port := u.Port() + if len(port) < 1 { + return nil, errors.New("missing websocket port") + } + return transport.NewWebsocketServerTransport(fmt.Sprintf("%s:%s", u.Hostname(), port), u.Path, w.tls), nil + } +} + +func (w *wsTransporter) Client() transport.ToClientTransport { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewWebsocketClientTransport(w.url, w.tls, w.header) + } +} + +type UnixTransporter struct { + path string +} + +type UnixTransporterBuilder struct { + opts []func(*UnixTransporter) +} + +func (u *UnixTransporter) Server() transport.ToServerTransport { + return func(ctx context.Context) (transport.ServerTransport, error) { + if _, err := os.Stat(u.path); !os.IsNotExist(err) { + return nil, err + } + return transport.NewTcpServerTransport("unix", u.path, nil), nil + } +} + +func (u *UnixTransporter) Client() transport.ToClientTransport { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewTcpClientTransport("unix", u.path, nil) + } +} + +func (u *UnixTransporterBuilder) Path(path string) *UnixTransporterBuilder { + u.opts = append(u.opts, func(transporter *UnixTransporter) { + transporter.path = path + }) + return u +} + +func (u *UnixTransporterBuilder) Build() Transporter { + tp := &UnixTransporter{ + path: "/var/run/rsocket.sock", + } + for _, opt := range u.opts { + opt(tp) + } + return tp +} + +func Tcp() *TcpTransporterBuilder { + return &TcpTransporterBuilder{} +} + +func Websocket() *WebsocketTransporterBuilder { + return &WebsocketTransporterBuilder{} +} + +func Unix() *UnixTransporterBuilder { + return &UnixTransporterBuilder{} +} diff --git a/transporter_test.go b/transporter_test.go new file mode 100644 index 0000000..a485b21 --- /dev/null +++ b/transporter_test.go @@ -0,0 +1,30 @@ +package rsocket_test + +import ( + "context" + "fmt" + "os" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/rsocket/rsocket-go" + "github.com/stretchr/testify/assert" +) + +func TestUnix(t *testing.T) { + sockFile := fmt.Sprintf("%s/test-rsocket-%s.sock", strings.TrimRight(os.TempDir(), "/"), uuid.New().String()) + defer os.Remove(sockFile) + u := rsocket.Unix().Path(sockFile).Build() + assert.NotNil(t, u) + _, err := u.Server()(context.Background()) + assert.NoError(t, err) +} + +func TestTcp(t *testing.T) { + rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() +} + +func TestWebsocket(t *testing.T) { + rsocket.Websocket() +} From 198429aa6f8072f94993e213a1c4dd6f1ce1dbf1 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 12 Jul 2020 00:20:02 +0800 Subject: [PATCH 09/26] Upgrade reactor-go and switch default scheduler to Parallel. --- README.md | 6 +- balancer/round_robin_test.go | 2 +- examples/echo/echo.go | 4 +- examples/echo_bench/echo_bench.go | 2 +- go.mod | 3 +- go.sum | 96 ++++++++++++++++++++++++++----- internal/socket/duplex.go | 8 +-- rsocket_example_test.go | 2 +- rx/flux/flux_test.go | 2 +- rx/mono/mono_test.go | 2 +- 10 files changed, 100 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 821ae84..2924282 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ func main() { }), ), nil }). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.Tcp().Addr(":7878").Build()). Serve(context.Background()) panic(err) } @@ -72,7 +72,7 @@ func main() { Resume(). Fragment(1024). SetupPayload(payload.NewString("Hello", "World")). - Transport("tcp://127.0.0.1:7878"). + Transport(rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build()). Start(context.Background()) if err != nil { panic(err) @@ -137,7 +137,7 @@ func main() { // Do something here... fmt.Println("bingo:", input) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) <-done diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index d21c8b8..bd9df35 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -79,7 +79,7 @@ func TestRoundRobin(t *testing.T) { DoOnError(func(e error) { assert.Fail(t, "should never run here") }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) } wg.Wait() diff --git a/examples/echo/echo.go b/examples/echo/echo.go index b04b166..d8e61ec 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -26,6 +27,7 @@ func init() { func main() { go func() { + http.Handle("/metrics", promhttp.Handler()) log.Println(http.ListenAndServe(":4444", nil)) }() //logger.SetLevel(logger.LevelDebug) @@ -140,7 +142,7 @@ func responder() rsocket.RSocket { //return payloads.(flux.Flux) payloads.(flux.Flux). //LimitRate(1). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). DoOnNext(func(elem payload.Payload) { log.Println("receiving:", elem) }). diff --git a/examples/echo_bench/echo_bench.go b/examples/echo_bench/echo_bench.go index bcbb64a..9b07cb2 100644 --- a/examples/echo_bench/echo_bench.go +++ b/examples/echo_bench/echo_bench.go @@ -57,7 +57,7 @@ func main() { ) for i := 0; i < n; i++ { - client.RequestResponse(payload.New(data, nil)).SubscribeOn(scheduler.Elastic()).SubscribeWith(ctx, sub) + client.RequestResponse(payload.New(data, nil)).SubscribeOn(scheduler.Parallel()).SubscribeWith(ctx, sub) } wg.Wait() cost := time.Since(now) diff --git a/go.mod b/go.mod index 8a1bf19..c3e426e 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,9 @@ go 1.12 require ( github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 - github.com/jjeffcaii/reactor-go v0.1.3 + github.com/jjeffcaii/reactor-go v0.1.4 github.com/pkg/errors v0.9.1 + github.com/prometheus/client_golang v1.7.1 github.com/stretchr/testify v1.4.0 github.com/urfave/cli/v2 v2.1.1 go.uber.org/atomic v1.5.1 diff --git a/go.sum b/go.sum index 7a23336..ba42ec0 100644 --- a/go.sum +++ b/go.sum @@ -1,47 +1,117 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jjeffcaii/reactor-go v0.1.3 h1:HPvOkeoH1Z11t0TlWIyYuQkbSG/9/e3LgTN4QuLvPFs= -github.com/jjeffcaii/reactor-go v0.1.3/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/jjeffcaii/reactor-go v0.1.4 h1:/M2Mjy72u+4Q9PQpq/i4bxFpXjaR1pxUh1GfMXUZa1A= +github.com/jjeffcaii/reactor-go v0.1.4/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/panjf2000/ants/v2 v2.4.1 h1:7RtUqj5lGOw0WnZhSKDZ2zzJhaX5490ZW1sUolRXCxY= github.com/panjf2000/ants/v2 v2.4.1/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index f3e0048..2cd8d96 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -328,7 +328,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { panic(fmt.Errorf("unsupported sending channel signal: %s", sig)) } }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) }) return ret @@ -382,7 +382,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd DoFinally(func(sig rx.SignalType) { p.unregister(sid) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } @@ -487,7 +487,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) default: } }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) <-mustSub @@ -583,7 +583,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa DoFinally(func(s rx.SignalType) { p.unregister(sid) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } diff --git a/rsocket_example_test.go b/rsocket_example_test.go index e3acc0b..f5aa4d7 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -70,7 +70,7 @@ func ExampleReceive() { DoOnSuccess(func(elem payload.Payload) { log.Println("response of Ping from client:", elem) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) // Return responser which just echo. return rsocket.NewAbstractSocket( diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 8921ed7..674f4cb 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -271,7 +271,7 @@ func TestFluxProcessorWithRequest(t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) <-done } diff --git a/rx/mono/mono_test.go b/rx/mono/mono_test.go index 008d383..ee8507c 100644 --- a/rx/mono/mono_test.go +++ b/rx/mono/mono_test.go @@ -66,7 +66,7 @@ func TestProxy_SubscribeOn(t *testing.T) { sink.Success(payload.NewString("foo", "bar")) }) }). - SubscribeOn(scheduler.Elastic()). + SubscribeOn(scheduler.Parallel()). DoOnSuccess(func(i payload.Payload) { log.Println("success:", i) }). From 6159d3adcd127143f9f62a125ef3f0ef5dd470c0 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Mon, 13 Jul 2020 23:22:02 +0800 Subject: [PATCH 10/26] Simplify transport struct. --- .../{connection_tcp.go => tcp_conn.go} | 0 .../{transport_tcp.go => tcp_transport.go} | 0 core/transport/transport.go | 160 +++++------------- core/transport/transport_test.go | 61 +++++++ .../{connection_ws.go => websocket_conn.go} | 0 ...transport_ws.go => websocket_transport.go} | 0 go.sum | 18 ++ internal/common/u32map.go | 126 -------------- internal/common/u32map_test.go | 63 ------- internal/socket/client_default.go | 4 +- internal/socket/client_resume.go | 7 +- internal/socket/duplex.go | 43 ++--- 12 files changed, 148 insertions(+), 334 deletions(-) rename core/transport/{connection_tcp.go => tcp_conn.go} (100%) rename core/transport/{transport_tcp.go => tcp_transport.go} (100%) create mode 100644 core/transport/transport_test.go rename core/transport/{connection_ws.go => websocket_conn.go} (100%) rename core/transport/{transport_ws.go => websocket_transport.go} (100%) delete mode 100644 internal/common/u32map.go delete mode 100644 internal/common/u32map_test.go diff --git a/core/transport/connection_tcp.go b/core/transport/tcp_conn.go similarity index 100% rename from core/transport/connection_tcp.go rename to core/transport/tcp_conn.go diff --git a/core/transport/transport_tcp.go b/core/transport/tcp_transport.go similarity index 100% rename from core/transport/transport_tcp.go rename to core/transport/tcp_transport.go diff --git a/core/transport/transport.go b/core/transport/transport.go index e223ed9..77348f0 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -3,7 +3,6 @@ package transport import ( "context" "io" - "log" "sync" "time" @@ -14,15 +13,14 @@ import ( "github.com/rsocket/rsocket-go/logger" ) -type ( - // FrameHandler is an alias of frame handler. - FrameHandler = func(frame core.Frame) (err error) - // ServerTransportAcceptor is an alias of server transport handler. - ServerTransportAcceptor = func(ctx context.Context, tp *Transport, onClose func(*Transport)) -) - var errTransportClosed = errors.New("transport closed") +// FrameHandler is an alias of frame handler. +type FrameHandler = func(frame core.Frame) (err error) + +// ServerTransportAcceptor is an alias of server transport handler. +type ServerTransportAcceptor = func(ctx context.Context, tp *Transport, onClose func(*Transport)) + // ServerTransport is server-side RSocket transport. type ServerTransport interface { io.Closer @@ -34,33 +32,39 @@ type ServerTransport interface { Listen(ctx context.Context, notifier chan<- struct{}) error } +type EventType int + +const ( + OnSetup EventType = iota + OnResume + OnLease + OnResumeOK + OnFireAndForget + OnMetadataPush + OnRequestResponse + OnRequestStream + OnRequestChannel + OnPayload + OnRequestN + OnError + OnErrorWithZeroStreamID + OnCancel + OnKeepalive + + handlerLen = int(OnKeepalive) + 1 +) + // Transport is RSocket transport which is used to carry RSocket frames. type Transport struct { conn core.Conn maxLifetime time.Duration lastRcvPos uint64 once sync.Once - - hSetup FrameHandler - hResume FrameHandler - hLease FrameHandler - hResumeOK FrameHandler - hFireAndForget FrameHandler - hMetadataPush FrameHandler - hRequestResponse FrameHandler - hRequestStream FrameHandler - hRequestChannel FrameHandler - hPayload FrameHandler - hRequestN FrameHandler - hError FrameHandler - hError0 FrameHandler - hCancel FrameHandler - hKeepalive FrameHandler + handlers [handlerLen]FrameHandler } -// HandleDisaster registers handler when receiving frame of DISASTER Error with zero StreamID. -func (p *Transport) HandleDisaster(handler FrameHandler) { - p.hError0 = handler +func (p *Transport) RegisterHandler(event EventType, handler FrameHandler) { + p.handlers[int(event)] = handler } // Connection returns current connection. @@ -141,7 +145,6 @@ L: for { select { case <-ctx.Done(): - log.Println("ctx end") err = ctx.Err() return default: @@ -165,75 +168,6 @@ L: return } -// HandleSetup registers handler when receiving a frame of Setup. -func (p *Transport) HandleSetup(handler FrameHandler) { - p.hSetup = handler -} - -// HandleResume registers handler when receiving a frame of Resume. -func (p *Transport) HandleResume(handler FrameHandler) { - p.hResume = handler -} - -func (p *Transport) HandleLease(handler FrameHandler) { - p.hLease = handler -} - -// HandleResumeOK registers handler when receiving a frame of ResumeOK. -func (p *Transport) HandleResumeOK(handler FrameHandler) { - p.hResumeOK = handler -} - -// HandleFNF registers handler when receiving a frame of FireAndForget. -func (p *Transport) HandleFNF(handler FrameHandler) { - p.hFireAndForget = handler -} - -// HandleMetadataPush registers handler when receiving a frame of MetadataPush. -func (p *Transport) HandleMetadataPush(handler FrameHandler) { - p.hMetadataPush = handler -} - -// HandleRequestResponse registers handler when receiving a frame of RequestResponse. -func (p *Transport) HandleRequestResponse(handler FrameHandler) { - p.hRequestResponse = handler -} - -// HandleRequestStream registers handler when receiving a frame of RequestStream. -func (p *Transport) HandleRequestStream(handler FrameHandler) { - p.hRequestStream = handler -} - -// HandleRequestChannel registers handler when receiving a frame of RequestChannel. -func (p *Transport) HandleRequestChannel(handler FrameHandler) { - p.hRequestChannel = handler -} - -// HandlePayload registers handler when receiving a frame of Payload. -func (p *Transport) HandlePayload(handler FrameHandler) { - p.hPayload = handler -} - -// HandleRequestN registers handler when receiving a frame of RequestN. -func (p *Transport) HandleRequestN(handler FrameHandler) { - p.hRequestN = handler -} - -// HandleError registers handler when receiving a frame of Error. -func (p *Transport) HandleError(handler FrameHandler) { - p.hError = handler -} - -// HandleCancel registers handler when receiving a frame of Cancel. -func (p *Transport) HandleCancel(handler FrameHandler) { - p.hCancel = handler -} - -// HandleKeepalive registers handler when receiving a frame of Keepalive. -func (p *Transport) HandleKeepalive(handler FrameHandler) { - p.hKeepalive = handler -} - // DispatchFrame delivery incoming frames. func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err error) { header := frame.Header() @@ -245,48 +179,48 @@ func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err erro switch t { case core.FrameTypeSetup: p.maxLifetime = frame.(*framing.SetupFrame).MaxLifetime() - handler = p.hSetup + handler = p.handlers[OnSetup] case core.FrameTypeResume: - handler = p.hResume + handler = p.handlers[OnResume] case core.FrameTypeResumeOK: p.lastRcvPos = frame.(*framing.ResumeOKFrame).LastReceivedClientPosition() - handler = p.hResumeOK + handler = p.handlers[OnResumeOK] case core.FrameTypeRequestFNF: - handler = p.hFireAndForget + handler = p.handlers[OnFireAndForget] case core.FrameTypeMetadataPush: if sid != 0 { // skip invalid metadata push logger.Warnf("rsocket.Transport: omit MetadataPush with non-zero stream id %d\n", sid) return } - handler = p.hMetadataPush + handler = p.handlers[OnMetadataPush] case core.FrameTypeRequestResponse: - handler = p.hRequestResponse + handler = p.handlers[OnRequestResponse] case core.FrameTypeRequestStream: - handler = p.hRequestStream + handler = p.handlers[OnRequestStream] case core.FrameTypeRequestChannel: - handler = p.hRequestChannel + handler = p.handlers[OnRequestChannel] case core.FrameTypePayload: - handler = p.hPayload + handler = p.handlers[OnPayload] case core.FrameTypeRequestN: - handler = p.hRequestN + handler = p.handlers[OnRequestN] case core.FrameTypeError: if sid == 0 { err = errors.New(frame.(*framing.ErrorFrame).Error()) - if p.hError0 != nil { - _ = p.hError0(frame) + if call := p.handlers[OnErrorWithZeroStreamID]; call != nil { + _ = call(frame) } return } - handler = p.hError + handler = p.handlers[OnError] case core.FrameTypeCancel: - handler = p.hCancel + handler = p.handlers[OnCancel] case core.FrameTypeKeepalive: ka := frame.(*framing.KeepaliveFrame) p.lastRcvPos = ka.LastReceivedPosition() - handler = p.hKeepalive + handler = p.handlers[OnKeepalive] case core.FrameTypeLease: - handler = p.hLease + handler = p.handlers[OnLease] } // Set deadline. diff --git a/core/transport/transport_test.go b/core/transport/transport_test.go new file mode 100644 index 0000000..30049cb --- /dev/null +++ b/core/transport/transport_test.go @@ -0,0 +1,61 @@ +package transport_test + +import ( + "bytes" + "time" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/common" +) + +type mockConn struct { + spy map[string]int + c chan core.FrameSupport +} + +func (m *mockConn) call(fn string) { + m.spy[fn] = m.spy[fn] + 1 +} + +func (m *mockConn) Close() error { + m.call("Close") + return nil +} + +func (m *mockConn) SetDeadline(deadline time.Time) error { + m.call("SetDeadline") + return nil +} + +func (m *mockConn) SetCounter(c *core.Counter) { + m.call("SetCounter") +} + +func (m *mockConn) Read() (next core.Frame, err error) { + f := <-m.c + bf := &bytes.Buffer{} + _, err = f.WriteTo(bf) + if err != nil { + return + } + bs := bf.Bytes() + header := core.ParseFrameHeader(bs) + bb := common.NewByteBuff() + _, err = bb.Write(bs[core.FrameHeaderLen:]) + if err != nil { + return + } + next, err = framing.FromRawFrame(framing.NewRawFrame(header, bb)) + return +} + +func (m *mockConn) Write(support core.FrameSupport) (err error) { + m.c <- support + return +} + +func (m *mockConn) Flush() (err error) { + m.call("Flush") + return +} diff --git a/core/transport/connection_ws.go b/core/transport/websocket_conn.go similarity index 100% rename from core/transport/connection_ws.go rename to core/transport/websocket_conn.go diff --git a/core/transport/transport_ws.go b/core/transport/websocket_transport.go similarity index 100% rename from core/transport/transport_ws.go rename to core/transport/websocket_transport.go diff --git a/go.sum b/go.sum index ba42ec0..27fc68f 100644 --- a/go.sum +++ b/go.sum @@ -5,10 +5,14 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -24,6 +28,7 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -43,6 +48,7 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -55,19 +61,26 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1 h1:NTGy1Ja9pByO+xAeH/qiWnLrKtr3hJPNjaVUwnjpdpA= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -75,7 +88,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= @@ -95,6 +110,7 @@ golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -106,6 +122,7 @@ google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -114,4 +131,5 @@ gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/common/u32map.go b/internal/common/u32map.go deleted file mode 100644 index a4ea297..0000000 --- a/internal/common/u32map.go +++ /dev/null @@ -1,126 +0,0 @@ -package common - -import ( - "sync" -) - -const _slots = 2 * 2 * 2 * 2 - -type U32Map interface { - Clear() - Range(fn func(k uint32, v interface{}) bool) - Load(key uint32) (v interface{}, ok bool) - Store(key uint32, value interface{}) - Delete(key uint32) -} - -type u32map struct { - slots [_slots]*u32slot -} - -func (u *u32map) Clear() { - for _, slot := range u.slots { - slot.Clear() - } -} - -func (u *u32map) Range(fn func(k uint32, v interface{}) bool) { - for _, slot := range u.slots { - if !slot.innerRange(fn) { - return - } - } -} - -func (u *u32map) Load(key uint32) (v interface{}, ok bool) { - return u.seek(key).Load(key) -} - -func (u *u32map) Store(key uint32, value interface{}) { - u.seek(key).Store(key, value) -} - -func (u *u32map) Delete(key uint32) { - u.seek(key).Delete(key) -} - -func (u *u32map) seek(key uint32) *u32slot { - k := key & (_slots - 1) - return u.slots[k] -} - -type u32slot struct { - k sync.RWMutex - m map[uint32]interface{} -} - -func (u *u32slot) Clear() { - if u == nil || u.m == nil { - return - } - u.k.Lock() - u.m = nil - u.k.Unlock() -} - -func (u *u32slot) Range(fn func(k uint32, v interface{}) bool) { - u.innerRange(fn) -} - -func (u *u32slot) Load(key uint32) (v interface{}, ok bool) { - if u == nil || u.m == nil { - return - } - u.k.RLock() - v, ok = u.m[key] - u.k.RUnlock() - return -} - -func (u *u32slot) Store(key uint32, value interface{}) { - if u == nil || u.m == nil { - return - } - u.k.Lock() - u.m[key] = value - u.k.Unlock() -} - -func (u *u32slot) Delete(key uint32) { - if u == nil || u.m == nil { - return - } - u.k.Lock() - delete(u.m, key) - u.k.Unlock() -} - -func (u *u32slot) innerRange(fn func(k uint32, v interface{}) bool) bool { - if u == nil || u.m == nil { - return false - } - u.k.RLock() - defer u.k.RUnlock() - for key, value := range u.m { - if !fn(key, value) { - return false - } - } - return true -} - -func NewU32MapLite() U32Map { - return &u32slot{ - m: make(map[uint32]interface{}), - } -} - -func NewU32Map() U32Map { - var slots [_slots]*u32slot - for i := 0; i < len(slots); i++ { - slots[i] = &u32slot{ - m: make(map[uint32]interface{}), - } - } - return &u32map{slots: slots} -} diff --git a/internal/common/u32map_test.go b/internal/common/u32map_test.go deleted file mode 100644 index a49a458..0000000 --- a/internal/common/u32map_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package common_test - -import ( - "sort" - "sync/atomic" - "testing" - - "github.com/rsocket/rsocket-go/internal/common" - "github.com/stretchr/testify/assert" -) - -func TestU32map(t *testing.T) { - var keys []int - value := common.RandAlphanumeric(10) - m := common.NewU32Map() - for i := uint32(0); i < 10; i++ { - m.Store(i, value) - keys = append(keys, int(i)) - } - v, ok := m.Load(1) - assert.True(t, ok, "key not found") - assert.Equal(t, value, v, "value doesn't match") - - _, ok = m.Load(10) - assert.False(t, ok, "key should not exist") - - var keys2 []int - m.Range(func(k uint32, _ interface{}) bool { - keys2 = append(keys2, int(k)) - return true - }) - sort.Ints(keys) - sort.Ints(keys2) - assert.Equal(t, keys, keys2, "keys doesn't match") - - m.Delete(1) - _, ok = m.Load(1) - assert.False(t, ok, "key should be deleted") - - var c int - m.Range(func(k uint32, v interface{}) bool { - c++ - return false - }) - assert.Equal(t, 1, c, "should be 1") - - m.Clear() - _, ok = m.Load(2) - assert.False(t, ok, "should be closed already") -} - -func BenchmarkU32Map(b *testing.B) { - const value = "foobar" - m := common.NewU32Map() - next := uint32(0) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - m.Store(atomic.AddUint32(&next, 1)-1, value) - } - }) -} diff --git a/internal/socket/client_default.go b/internal/socket/client_default.go index 0969bb6..4812682 100644 --- a/internal/socket/client_default.go +++ b/internal/socket/client_default.go @@ -26,7 +26,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err if setup.Lease { p.refreshLease(0, 0) - tp.HandleLease(func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnLease, func(frame core.Frame) (err error) { lease := frame.(*framing.LeaseFrame) p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) logger.Infof(">>>>> refresh lease: %v\n", lease) @@ -34,7 +34,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err }) } - tp.HandleDisaster(func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { p.socket.SetError(frame.(*framing.ErrorFrame)) return }) diff --git a/internal/socket/client_resume.go b/internal/socket/client_resume.go index 8636336..f13359a 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/client_resume.go @@ -79,12 +79,11 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { // connect first time. if len(p.setup.Token) < 1 || connects == 1 { - tp.HandleDisaster(func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { p.socket.SetError(frame.(*framing.ErrorFrame)) p.markClosing() return }) - f = p.setup.toFrame() err = tp.Send(f, true) p.socket.SetTransport(tp) @@ -100,12 +99,12 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { resumeErr := make(chan string) - tp.HandleResumeOK(func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnResumeOK, func(frame core.Frame) (err error) { close(resumeErr) return }) - tp.HandleDisaster(func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { // TODO: process other error with zero StreamID f := frame.(*framing.ErrorFrame) if f.ErrorCode() == core.ErrorCodeRejectedResume { diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 2cd8d96..749b0c4 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -12,7 +12,6 @@ import ( "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/core/transport" - "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/internal/fragmentation" "github.com/rsocket/rsocket-go/lease" "github.com/rsocket/rsocket-go/logger" @@ -44,10 +43,10 @@ type DuplexRSocket struct { outs chan core.FrameSupport outsPriority []core.FrameSupport responder Responder - messages common.U32Map + messages *sync.Map sids StreamID mtu int - fragments common.U32Map // key=streamID, value=Joiner + fragments *sync.Map // common.U32Map // key=streamID, value=Joiner closed *atomic.Bool done chan struct{} keepaliver *Keepaliver @@ -91,7 +90,6 @@ func (p *DuplexRSocket) Close() error { p.cond.Broadcast() p.cond.L.Unlock() - p.fragments.Clear() <-p.done if p.tp != nil { @@ -101,13 +99,7 @@ func (p *DuplexRSocket) Close() error { _ = p.tp.Close() } } - - p.fragments.Range(func(key uint32, value interface{}) bool { - return true - }) - p.fragments.Clear() - - p.messages.Range(func(key uint32, value interface{}) bool { + p.messages.Range(func(key, value interface{}) bool { if cc, ok := value.(callback); ok { if p.e == nil { go func() { @@ -121,7 +113,6 @@ func (p *DuplexRSocket) Close() error { } return true }) - p.messages.Clear() return p.e } @@ -784,18 +775,18 @@ func (p *DuplexRSocket) clearTransport() { // SetTransport sets a transport for current socket. func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { - tp.HandleCancel(p.onFrameCancel) - tp.HandleError(p.onFrameError) - tp.HandleRequestN(p.onFrameRequestN) - tp.HandlePayload(p.onFramePayload) - tp.HandleKeepalive(p.onFrameKeepalive) + tp.RegisterHandler(transport.OnCancel, p.onFrameCancel) + tp.RegisterHandler(transport.OnError, p.onFrameError) + tp.RegisterHandler(transport.OnRequestN, p.onFrameRequestN) + tp.RegisterHandler(transport.OnPayload, p.onFramePayload) + tp.RegisterHandler(transport.OnKeepalive, p.onFrameKeepalive) if p.responder != nil { - tp.HandleRequestResponse(p.onFrameRequestResponse) - tp.HandleMetadataPush(p.respondMetadataPush) - tp.HandleFNF(p.onFrameFNF) - tp.HandleRequestStream(p.onFrameRequestStream) - tp.HandleRequestChannel(p.onFrameRequestChannel) + tp.RegisterHandler(transport.OnRequestResponse, p.onFrameRequestResponse) + tp.RegisterHandler(transport.OnMetadataPush, p.respondMetadataPush) + tp.RegisterHandler(transport.OnFireAndForget, p.onFrameFNF) + tp.RegisterHandler(transport.OnRequestStream, p.onFrameRequestStream) + tp.RegisterHandler(transport.OnRequestChannel, p.onFrameRequestChannel) } p.cond.L.Lock() @@ -1089,9 +1080,9 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { leases: leases, outs: make(chan core.FrameSupport, _outChanSize), mtu: mtu, - messages: common.NewU32Map(), + messages: &sync.Map{}, sids: &serverStreamIDs{}, - fragments: common.NewU32MapLite(), + fragments: &sync.Map{}, done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: core.NewCounter(), @@ -1109,9 +1100,9 @@ func NewClientDuplexRSocket( closed: atomic.NewBool(false), outs: make(chan core.FrameSupport, _outChanSize), mtu: mtu, - messages: common.NewU32Map(), + messages: &sync.Map{}, sids: &clientStreamIDs{}, - fragments: common.NewU32MapLite(), + fragments: &sync.Map{}, done: make(chan struct{}), cond: sync.NewCond(&sync.Mutex{}), counter: core.NewCounter(), From 2b268440b3f1e5b2c9666fdc900e805c2e0c07f0 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 14 Jul 2020 22:17:21 +0800 Subject: [PATCH 11/26] Add unit test for transport. --- client.go | 2 +- core/transport/misc.go | 3 - core/transport/transport.go | 44 +++--- core/transport/transport_test.go | 254 +++++++++++++++++++++++++----- core/transport/types.go | 30 ++++ core/transport/types_mock.go | 118 ++++++++++++++ core/types.go | 17 -- go.mod | 1 + go.sum | 6 + internal/socket/client_default.go | 4 +- internal/socket/client_resume.go | 4 +- server.go | 2 +- transporter.go | 16 +- 13 files changed, 408 insertions(+), 93 deletions(-) create mode 100644 core/transport/types.go create mode 100644 core/transport/types_mock.go diff --git a/client.go b/client.go index e5d71ae..f0a2885 100644 --- a/client.go +++ b/client.go @@ -82,7 +82,7 @@ type setupClientSocket interface { type clientBuilder struct { resume *resumeOpts fragment int - tpGen transport.ToClientTransport + tpGen transport.ClientTransportFunc setup *socket.SetupInfo acceptor ClientSocketAcceptor onCloses []func(error) diff --git a/core/transport/misc.go b/core/transport/misc.go index 6b28984..cd47eb9 100644 --- a/core/transport/misc.go +++ b/core/transport/misc.go @@ -1,7 +1,6 @@ package transport import ( - "context" "net/http" "strings" ) @@ -19,5 +18,3 @@ func isClosedErr(err error) bool { return false } -type ToClientTransport = func(context.Context) (*Transport, error) -type ToServerTransport = func(context.Context) (ServerTransport, error) diff --git a/core/transport/transport.go b/core/transport/transport.go index 77348f0..075fb50 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -13,7 +13,10 @@ import ( "github.com/rsocket/rsocket-go/logger" ) -var errTransportClosed = errors.New("transport closed") +var ( + errTransportClosed = errors.New("transport closed") + errNoHandler = errors.New("you must register a handler") +) // FrameHandler is an alias of frame handler. type FrameHandler = func(frame core.Frame) (err error) @@ -56,7 +59,7 @@ const ( // Transport is RSocket transport which is used to carry RSocket frames. type Transport struct { - conn core.Conn + conn Conn maxLifetime time.Duration lastRcvPos uint64 once sync.Once @@ -68,7 +71,7 @@ func (p *Transport) RegisterHandler(event EventType, handler FrameHandler) { } // Connection returns current connection. -func (p *Transport) Connection() core.Conn { +func (p *Transport) Connection() Conn { return p.conn } @@ -139,33 +142,26 @@ func (p *Transport) ReadFirst(ctx context.Context) (frame core.Frame, err error) } // Start start transport. -func (p *Transport) Start(ctx context.Context) (err error) { +func (p *Transport) Start(ctx context.Context) error { defer p.Close() -L: for { select { case <-ctx.Done(): - err = ctx.Err() - return + return ctx.Err() default: f, err := p.conn.Read() - if err != nil { - break L + if err == nil { + err = p.DispatchFrame(ctx, f) + } + if err == nil { + continue } - err = p.DispatchFrame(ctx, f) - if err != nil { - break L + if errors.Is(err, io.EOF) { + return nil } + return errors.Wrap(err, "read and delivery frame failed") } } - if err == io.EOF { - err = nil - return - } - if err != nil { - err = errors.Wrap(err, "read and delivery frame failed") - } - return } // DispatchFrame delivery incoming frames. @@ -232,7 +228,7 @@ func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err erro // missing handler if handler == nil { - err = errors.Errorf("missing frame handler: type=%s", t) + err = errNoHandler return } @@ -244,9 +240,13 @@ func (p *Transport) DispatchFrame(_ context.Context, frame core.Frame) (err erro return } -func NewTransport(c core.Conn) *Transport { +func NewTransport(c Conn) *Transport { return &Transport{ conn: c, maxLifetime: common.DefaultKeepaliveMaxLifetime, } } + +func IsNoHandlerError(err error) bool { + return err == errNoHandler +} diff --git a/core/transport/transport_test.go b/core/transport/transport_test.go index 30049cb..4c68370 100644 --- a/core/transport/transport_test.go +++ b/core/transport/transport_test.go @@ -1,61 +1,241 @@ package transport_test import ( - "bytes" + "context" + "io" + "testing" "time" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" - "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) -type mockConn struct { - spy map[string]int - c chan core.FrameSupport -} +var fakeErr = errors.New("fake error") -func (m *mockConn) call(fn string) { - m.spy[fn] = m.spy[fn] + 1 +func Init(t *testing.T) (*gomock.Controller, *transport.MockConn, *transport.Transport) { + ctrl := gomock.NewController(t) + conn := transport.NewMockConn(ctrl) + tp := transport.NewTransport(conn) + return ctrl, conn, tp } -func (m *mockConn) Close() error { - m.call("Close") - return nil -} +func TestTransport_Start(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() -func (m *mockConn) SetDeadline(deadline time.Time) error { - m.call("SetDeadline") - return nil -} + conn.EXPECT().Close().Return(nil).Times(1) -func (m *mockConn) SetCounter(c *core.Counter) { - m.call("SetCounter") + conn.EXPECT().Read().Return(nil, fakeErr).Times(1) + err := tp.Start(context.Background()) + assert.Error(t, err, "should be an error") + assert.True(t, errors.Cause(err) == fakeErr, "should be the fake error") + + conn.EXPECT().Read().Return(nil, io.EOF).Times(1) + err = tp.Start(context.Background()) + assert.NoError(t, err, "there should be no error here if io.EOF occurred") } -func (m *mockConn) Read() (next core.Frame, err error) { - f := <-m.c - bf := &bytes.Buffer{} - _, err = f.WriteTo(bf) - if err != nil { - return +func TestTransport_RegisterHandler(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + var cursor int + fakeToken := []byte("fake-token") + fakeDada := []byte("fake-data") + fakeMetadata := []byte("fake-metadata") + fakeMimeType := []byte("fake-mime-type") + fakeFlag := core.FrameFlag(0) + frames := []core.Frame{ + framing.NewSetupFrame( + core.NewVersion(1, 0), + 30*time.Second, + 90*time.Second, + nil, + fakeMimeType, + fakeMimeType, + fakeDada, + fakeMetadata, + false), + framing.NewMetadataPushFrame(fakeDada), + framing.NewFireAndForgetFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestResponseFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestStreamFrame(1, 1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestChannelFrame(1, 1, fakeDada, fakeMetadata, fakeFlag), + framing.NewRequestNFrame(1, 1, fakeFlag), + framing.NewKeepaliveFrame(1, fakeDada, true), + framing.NewCancelFrame(1), + framing.NewErrorFrame(1, core.ErrorCodeApplicationError, fakeDada), + framing.NewLeaseFrame(30*time.Second, 1, fakeMetadata), + framing.NewPayloadFrame(1, fakeDada, fakeMetadata, fakeFlag), + framing.NewResumeFrame(core.DefaultVersion, fakeToken, 1, 1), + framing.NewResumeOKFrame(1), + framing.NewErrorFrame(0, core.ErrorCodeRejected, fakeDada), + } + conn.EXPECT(). + Read(). + DoAndReturn(func() (core.Frame, error) { + defer func() { + cursor++ + }() + if cursor >= len(frames) { + return nil, io.EOF + } + return frames[cursor], nil + }). + AnyTimes() + + calls := make(map[core.FrameType]int) + fakeHandler := func(frame core.Frame) (err error) { + typ := frame.Header().Type() + calls[typ] = calls[typ] + 1 + return nil } - bs := bf.Bytes() - header := core.ParseFrameHeader(bs) - bb := common.NewByteBuff() - _, err = bb.Write(bs[core.FrameHeaderLen:]) - if err != nil { + + callsErrorWithZeroStreamID := atomic.NewInt32(0) + + tp.RegisterHandler(transport.OnSetup, fakeHandler) + tp.RegisterHandler(transport.OnRequestResponse, fakeHandler) + tp.RegisterHandler(transport.OnFireAndForget, fakeHandler) + tp.RegisterHandler(transport.OnMetadataPush, fakeHandler) + tp.RegisterHandler(transport.OnRequestStream, fakeHandler) + tp.RegisterHandler(transport.OnRequestChannel, fakeHandler) + tp.RegisterHandler(transport.OnKeepalive, fakeHandler) + tp.RegisterHandler(transport.OnRequestN, fakeHandler) + tp.RegisterHandler(transport.OnPayload, fakeHandler) + tp.RegisterHandler(transport.OnError, fakeHandler) + tp.RegisterHandler(transport.OnCancel, fakeHandler) + tp.RegisterHandler(transport.OnResumeOK, fakeHandler) + tp.RegisterHandler(transport.OnResume, fakeHandler) + tp.RegisterHandler(transport.OnLease, fakeHandler) + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { + callsErrorWithZeroStreamID.Inc() return + }) + + toHaveBeenCalled := func(typ core.FrameType, n int) { + called, ok := calls[typ] + assert.True(t, ok, "%s have not been called", typ) + assert.Equal(t, n, called, "%s have not been called %d times", n) } - next, err = framing.FromRawFrame(framing.NewRawFrame(header, bb)) - return + + err := tp.Start(context.Background()) + assert.Error(t, err, "should be no error") + + for _, typ := range []core.FrameType{ + core.FrameTypeSetup, + core.FrameTypeRequestResponse, + core.FrameTypeRequestFNF, + core.FrameTypeMetadataPush, + core.FrameTypeRequestStream, + core.FrameTypeRequestChannel, + core.FrameTypeKeepalive, + core.FrameTypeRequestN, + core.FrameTypePayload, + core.FrameTypeError, + core.FrameTypeCancel, + core.FrameTypeResume, + core.FrameTypeResumeOK, + core.FrameTypeLease, + } { + toHaveBeenCalled(typ, 1) + } + assert.Equal(t, int32(1), callsErrorWithZeroStreamID.Load(), "error frame with zero stream id has not been called") } -func (m *mockConn) Write(support core.FrameSupport) (err error) { - m.c <- support - return +func TestTransport_ReadFirst(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().AnyTimes() + + conn.EXPECT().Read().Return(nil, fakeErr).Times(1) + _, err := tp.ReadFirst(context.Background()) + assert.Error(t, err, "should be error") + + expect := framing.NewCancelFrame(1) + conn.EXPECT().Read().Return(expect, nil).Times(1) + actual, err := tp.ReadFirst(context.Background()) + assert.NoError(t, err, "should not be error") + assert.Equal(t, expect, actual, "not match") } -func (m *mockConn) Flush() (err error) { - m.call("Flush") - return +func TestTransport_Send(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Write(gomock.Any()).Times(2) + conn.EXPECT().Flush().Times(1) + + var err error + + err = tp.Send(framing.NewCancelFrame(1), false) + assert.NoError(t, err, "send failed") + + err = tp.Send(framing.NewCancelFrame(1), true) + assert.NoError(t, err, "send failed") +} + +func TestTransport_Connection(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + c := tp.Connection() + assert.Equal(t, conn, c) +} + +func TestTransport_Flush(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Flush().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + + err := tp.Flush() + assert.NoError(t, err, "flush failed") + conn.SetCounter(core.NewCounter()) +} + +func TestTransport_Close(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().Times(1) + + err := tp.Close() + assert.NoError(t, err, "close transport failed") +} + +func TestTransport_HandlerReturnsError(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + conn.EXPECT().Close().Times(1) + conn.EXPECT().Read().Return(framing.NewCancelFrame(1), nil).Times(1) + + tp.RegisterHandler(transport.OnCancel, func(_ core.Frame) error { + return fakeErr + }) + err := tp.Start(context.Background()) + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fakeError") +} + +func TestTransport_EmptyHandler(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + conn.EXPECT().Close().Times(1) + conn.EXPECT().Read().Return(framing.NewCancelFrame(1), nil).Times(1) + + err := tp.Start(context.Background()) + assert.True(t, transport.IsNoHandlerError(errors.Cause(err)), "should be no handler error") } diff --git a/core/transport/types.go b/core/transport/types.go new file mode 100644 index 0000000..98a871b --- /dev/null +++ b/core/transport/types.go @@ -0,0 +1,30 @@ +package transport + +import ( + "context" + "io" + "time" + + "github.com/rsocket/rsocket-go/core" +) + +type ( + ClientTransportFunc = func(context.Context) (*Transport, error) + ServerTransportFunc = func(context.Context) (ServerTransport, error) +) + +// Conn is connection for RSocket. +type Conn interface { + io.Closer + // SetDeadline set deadline for current connection. + // After this deadline, connection will be closed. + SetDeadline(deadline time.Time) error + // SetCounter bind a counter which can count r/w bytes. + SetCounter(c *core.Counter) + // Read reads next frame from Conn. + Read() (core.Frame, error) + // Write writes a frame to Conn. + Write(core.FrameSupport) error + // Flush. + Flush() error +} diff --git a/core/transport/types_mock.go b/core/transport/types_mock.go new file mode 100644 index 0000000..a7360b7 --- /dev/null +++ b/core/transport/types_mock.go @@ -0,0 +1,118 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/types.go + +// Package transport is a generated GoMock package. +package transport + +import ( + gomock "github.com/golang/mock/gomock" + core "github.com/rsocket/rsocket-go/core" + reflect "reflect" + time "time" +) + +// MockConn is a mock of Conn interface +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder +} + +// MockConnMockRecorder is the mock recorder for MockConn +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// SetDeadline mocks base method +func (m *MockConn) SetDeadline(deadline time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", deadline) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), deadline) +} + +// SetCounter mocks base method +func (m *MockConn) SetCounter(c *core.Counter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCounter", c) +} + +// SetCounter indicates an expected call of SetCounter +func (mr *MockConnMockRecorder) SetCounter(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCounter", reflect.TypeOf((*MockConn)(nil).SetCounter), c) +} + +// Read mocks base method +func (m *MockConn) Read() (core.Frame, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read") + ret0, _ := ret[0].(core.Frame) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockConnMockRecorder) Read() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read)) +} + +// Write mocks base method +func (m *MockConn) Write(arg0 core.FrameSupport) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write +func (mr *MockConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), arg0) +} + +// Flush mocks base method +func (m *MockConn) Flush() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +// Flush indicates an expected call of Flush +func (mr *MockConnMockRecorder) Flush() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockConn)(nil).Flush)) +} diff --git a/core/types.go b/core/types.go index a81ab49..25947ba 100644 --- a/core/types.go +++ b/core/types.go @@ -3,7 +3,6 @@ package core import ( "io" "strings" - "time" ) // FrameType is type of frame. @@ -127,19 +126,3 @@ type Frame interface { // Validate returns error if frame is invalid. Validate() error } - -// Conn is connection for RSocket. -type Conn interface { - io.Closer - // SetDeadline set deadline for current connection. - // After this deadline, connection will be closed. - SetDeadline(deadline time.Time) error - // SetCounter bind a counter which can count r/w bytes. - SetCounter(c *Counter) - // Read reads next frame from Conn. - Read() (Frame, error) - // Write writes a frame to Conn. - Write(FrameSupport) error - // Flush. - Flush() error -} diff --git a/go.mod b/go.mod index c3e426e..c21fcae 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/rsocket/rsocket-go go 1.12 require ( + github.com/golang/mock v1.4.3 github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 github.com/jjeffcaii/reactor-go v0.1.4 diff --git a/go.sum b/go.sum index 27fc68f..28aec31 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -112,8 +114,10 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -133,3 +137,5 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/socket/client_default.go b/internal/socket/client_default.go index 4812682..84647a9 100644 --- a/internal/socket/client_default.go +++ b/internal/socket/client_default.go @@ -11,7 +11,7 @@ import ( type defaultClientSocket struct { *baseSocket - tp transport.ToClientTransport + tp transport.ClientTransportFunc } func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { @@ -55,7 +55,7 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err } // NewClient create a simple client-side socket. -func NewClient(tp transport.ToClientTransport, socket *DuplexRSocket) ClientSocket { +func NewClient(tp transport.ClientTransportFunc, socket *DuplexRSocket) ClientSocket { return &defaultClientSocket{ baseSocket: newBaseSocket(socket), tp: tp, diff --git a/internal/socket/client_resume.go b/internal/socket/client_resume.go index f13359a..8e94104 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/client_resume.go @@ -19,7 +19,7 @@ type resumeClientSocket struct { *baseSocket connects *atomic.Int32 setup *SetupInfo - tp transport.ToClientTransport + tp transport.ClientTransportFunc } func (p *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { @@ -145,7 +145,7 @@ func (p *resumeClientSocket) isClosed() bool { } // NewClientResume creates a client-side socket with resume support. -func NewClientResume(tp transport.ToClientTransport, socket *DuplexRSocket) ClientSocket { +func NewClientResume(tp transport.ClientTransportFunc, socket *DuplexRSocket) ClientSocket { return &resumeClientSocket{ baseSocket: newBaseSocket(socket), connects: atomic.NewInt32(0), diff --git a/server.go b/server.go index f200cbd..d5c305a 100644 --- a/server.go +++ b/server.go @@ -73,7 +73,7 @@ type serverResumeOptions struct { } type server struct { - tp transport.ToServerTransport + tp transport.ServerTransportFunc resumeOpts *serverResumeOptions fragment int acc ServerAcceptor diff --git a/transporter.go b/transporter.go index f8c3371..2562162 100644 --- a/transporter.go +++ b/transporter.go @@ -13,8 +13,8 @@ import ( ) type Transporter interface { - Client() transport.ToClientTransport - Server() transport.ToServerTransport + Client() transport.ClientTransportFunc + Server() transport.ServerTransportFunc } type tcpTransporter struct { @@ -26,13 +26,13 @@ type TcpTransporterBuilder struct { opts []func(*tcpTransporter) } -func (t *tcpTransporter) Server() transport.ToServerTransport { +func (t *tcpTransporter) Server() transport.ServerTransportFunc { return func(ctx context.Context) (transport.ServerTransport, error) { return transport.NewTcpServerTransport("tcp", t.addr, t.tls), nil } } -func (t *tcpTransporter) Client() transport.ToClientTransport { +func (t *tcpTransporter) Client() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { return transport.NewTcpClientTransport("tcp", t.addr, t.tls) } @@ -108,7 +108,7 @@ func (w *WebsocketTransporterBuilder) Build() Transporter { return ws } -func (w *wsTransporter) Server() transport.ToServerTransport { +func (w *wsTransporter) Server() transport.ServerTransportFunc { return func(ctx context.Context) (transport.ServerTransport, error) { u, err := url.Parse(w.url) if err != nil { @@ -122,7 +122,7 @@ func (w *wsTransporter) Server() transport.ToServerTransport { } } -func (w *wsTransporter) Client() transport.ToClientTransport { +func (w *wsTransporter) Client() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { return transport.NewWebsocketClientTransport(w.url, w.tls, w.header) } @@ -136,7 +136,7 @@ type UnixTransporterBuilder struct { opts []func(*UnixTransporter) } -func (u *UnixTransporter) Server() transport.ToServerTransport { +func (u *UnixTransporter) Server() transport.ServerTransportFunc { return func(ctx context.Context) (transport.ServerTransport, error) { if _, err := os.Stat(u.path); !os.IsNotExist(err) { return nil, err @@ -145,7 +145,7 @@ func (u *UnixTransporter) Server() transport.ToServerTransport { } } -func (u *UnixTransporter) Client() transport.ToClientTransport { +func (u *UnixTransporter) Client() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { return transport.NewTcpClientTransport("unix", u.path, nil) } From e6c79eb1a640ccaaf1b214e5f15ef2fdd8f63dfe Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 19 Jul 2020 22:30:15 +0800 Subject: [PATCH 12/26] more ut for transports and sockets. --- .gitignore | 2 + .travis.yml | 2 +- README.md | 6 +- client.go | 2 +- cmd/rsocket-cli/rsocket-cli.go | 29 ++- cmd/rsocket-cli/uri.go | 8 +- core/framing/frame.go | 17 +- core/framing/frame_cancel.go | 10 +- core/framing/frame_error.go | 18 +- core/framing/frame_fnf.go | 10 +- core/framing/frame_keepalive.go | 10 +- core/framing/frame_lease.go | 22 +- core/framing/frame_metadata_push.go | 10 +- core/framing/frame_payload.go | 59 +++--- core/framing/frame_request_channel.go | 10 +- core/framing/frame_request_n.go | 10 +- core/framing/frame_request_response.go | 12 +- core/framing/frame_request_stream.go | 10 +- core/framing/frame_resume.go | 26 +-- core/framing/frame_resume_ok.go | 10 +- core/framing/frame_setup.go | 34 +-- core/framing/frame_test.go | 49 +++-- core/transport/base_test.go | 9 + core/transport/decoder_test.go | 89 ++++++-- core/transport/misc_test.go | 16 ++ .../{types_mock.go => mock_conn_test.go} | 9 +- core/transport/tcp_conn.go | 44 ++-- core/transport/tcp_conn_mock_test.go | 149 ++++++++++++++ core/transport/tcp_conn_test.go | 158 ++++++++++++++ core/transport/tcp_transport.go | 55 ++--- core/transport/tcp_transport_mock_test.go | 77 +++++++ core/transport/tcp_transport_test.go | 165 +++++++++++++++ core/transport/transport.go | 2 +- core/transport/transport_test.go | 6 +- core/transport/types.go | 2 +- core/transport/websocket_conn.go | 54 ++--- core/transport/websocket_conn_mock_test.go | 92 +++++++++ core/transport/websocket_conn_test.go | 169 +++++++++++++++ core/transport/websocket_transport.go | 4 +- core/types.go | 4 +- internal/fragmentation/splitter_test.go | 4 +- .../socket/{socket.go => abstract_socket.go} | 80 -------- internal/socket/abstract_socket_test.go | 105 ++++++++++ internal/socket/base_socket.go | 89 ++++++++ internal/socket/base_socket_test.go | 60 ++++++ internal/socket/callback.go | 2 +- internal/socket/duplex.go | 193 +++++++++--------- internal/socket/misc.go | 11 +- internal/socket/misc_test.go | 24 ++- internal/socket/mock_conn_test.go | 119 +++++++++++ ...t_resume.go => resumable_client_socket.go} | 12 +- ...r_resume.go => resumable_server_socket.go} | 8 +- internal/socket/server_default.go | 41 ---- ...ent_default.go => simple_client_socket.go} | 17 +- internal/socket/simple_client_socket_test.go | 143 +++++++++++++ internal/socket/simple_server_socket.go | 41 ++++ internal/socket/simple_server_socket_test.go | 81 ++++++++ internal/socket/socket_test.go | 42 ++++ internal/socket/stream_id.go | 4 +- internal/socket/types.go | 2 +- justfile | 17 +- lease/lease_test.go | 76 +------ logger/logger.go | 138 ++++++------- rsocket_test.go | 17 -- server.go | 28 +-- transporter.go | 8 +- 66 files changed, 2134 insertions(+), 698 deletions(-) create mode 100644 core/transport/base_test.go create mode 100644 core/transport/misc_test.go rename core/transport/{types_mock.go => mock_conn_test.go} (97%) create mode 100644 core/transport/tcp_conn_mock_test.go create mode 100644 core/transport/tcp_conn_test.go create mode 100644 core/transport/tcp_transport_mock_test.go create mode 100644 core/transport/tcp_transport_test.go create mode 100644 core/transport/websocket_conn_mock_test.go create mode 100644 core/transport/websocket_conn_test.go rename internal/socket/{socket.go => abstract_socket.go} (53%) create mode 100644 internal/socket/abstract_socket_test.go create mode 100644 internal/socket/base_socket.go create mode 100644 internal/socket/base_socket_test.go create mode 100644 internal/socket/mock_conn_test.go rename internal/socket/{client_resume.go => resumable_client_socket.go} (94%) rename internal/socket/{server_resume.go => resumable_server_socket.go} (82%) delete mode 100644 internal/socket/server_default.go rename internal/socket/{client_default.go => simple_client_socket.go} (73%) create mode 100644 internal/socket/simple_client_socket_test.go create mode 100644 internal/socket/simple_server_socket.go create mode 100644 internal/socket/simple_server_socket_test.go create mode 100644 internal/socket/socket_test.go diff --git a/.gitignore b/.gitignore index aa6673e..4055cb2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ suppressions/ .idea cmd/rsocket-cli/rsocket-cli + +coverage.out diff --git a/.travis.yml b/.travis.yml index 7ac7c9f..2081ba1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,5 @@ install: script: - golangci-lint run ./... - - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 ./rx/... ./internal/... ./extension/... ./payload/... . + - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 ./core/... ./balancer/... ./rx/... ./internal/... ./extension/... ./payload/... . - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/README.md b/README.md index 2924282..b8247b5 100644 --- a/README.md +++ b/README.md @@ -263,11 +263,7 @@ import ( ) func init() { - logger.SetFunc(logger.LevelDebug|logger.LevelInfo|logger.LevelWarn|logger.LevelError, func(template string, args ...interface{}) { - // Implement your own logger here... - log.Printf(template, args...) - }) - logger.SetLevel(logger.LevelInfo) + logger.SetLevel(logger.LevelDebug) } ``` diff --git a/client.go b/client.go index f0a2885..c046a89 100644 --- a/client.go +++ b/client.go @@ -165,7 +165,7 @@ func (p *clientBuilder) Start(ctx context.Context) (client Client, err error) { return nil, err } - sk := socket.NewClientDuplexRSocket( + sk := socket.NewClientDuplexConnection( p.fragment, p.setup.KeepaliveInterval, ) diff --git a/cmd/rsocket-cli/rsocket-cli.go b/cmd/rsocket-cli/rsocket-cli.go index e5fade4..6d2c6cb 100644 --- a/cmd/rsocket-cli/rsocket-cli.go +++ b/cmd/rsocket-cli/rsocket-cli.go @@ -11,16 +11,27 @@ import ( "github.com/urfave/cli/v2" ) +type fmtLogger struct { +} + +func (f fmtLogger) Debugf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Infof(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Warnf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func (f fmtLogger) Errorf(format string, args ...interface{}) { + _, _ = os.Stderr.WriteString(fmt.Sprintf(format, args...)) +} + func init() { - logger.DisablePrefix() - fn := func(s string, i ...interface{}) { - fmt.Printf(s, i...) - } - logger.SetFunc(logger.LevelDebug, fn) - logger.SetFunc(logger.LevelInfo, fn) - logger.SetFunc(logger.LevelError, func(s string, i ...interface{}) { - _, _ = os.Stderr.WriteString(fmt.Sprintf(s, i...)) - }) + logger.SetLogger(fmtLogger{}) } func main() { diff --git a/cmd/rsocket-cli/uri.go b/cmd/rsocket-cli/uri.go index b955023..3e063ee 100644 --- a/cmd/rsocket-cli/uri.go +++ b/cmd/rsocket-cli/uri.go @@ -37,7 +37,7 @@ func (p *URI) IsWebsocket() bool { func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) (*transport.Transport, error) { switch strings.ToLower(p.Scheme) { case schemaTCP: - return transport.NewTcpClientTransport(schemaTCP, p.Host, tc) + return transport.NewTcpClientTransportWithAddr(schemaTCP, p.Host, tc) case schemaWebsocket: if tc == nil { return transport.NewWebsocketClientTransport(p.pp().String(), nil, headers) @@ -51,7 +51,7 @@ func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) ( } return transport.NewWebsocketClientTransport(p.pp().String(), tc, headers) case schemaUNIX: - return transport.NewTcpClientTransport(schemaUNIX, p.Path, tc) + return transport.NewTcpClientTransportWithAddr(schemaUNIX, p.Path, tc) default: return nil, errors.Errorf("unsupported transport url: %s", p.pp().String()) } @@ -61,7 +61,7 @@ func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) ( func (p *URI) MakeServerTransport(c *tls.Config) (tp transport.ServerTransport, err error) { switch strings.ToLower(p.Scheme) { case schemaTCP: - tp = transport.NewTcpServerTransport(schemaTCP, p.Host, c) + tp = transport.NewTcpServerTransportWithAddr(schemaTCP, p.Host, c) case schemaWebsocket: tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) case schemaWebsocketSecure: @@ -71,7 +71,7 @@ func (p *URI) MakeServerTransport(c *tls.Config) (tp transport.ServerTransport, } tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) case schemaUNIX: - tp = transport.NewTcpServerTransport(schemaUNIX, p.Path, c) + tp = transport.NewTcpServerTransportWithAddr(schemaUNIX, p.Path, c) default: err = errors.Errorf("unsupported transport url: %s", p.pp().String()) } diff --git a/core/framing/frame.go b/core/framing/frame.go index 4e5af5f..32b7214 100644 --- a/core/framing/frame.go +++ b/core/framing/frame.go @@ -124,7 +124,22 @@ func NewRawFrame(header core.FrameHeader, body *common.ByteBuff) *RawFrame { } } -func PrintFrame(f core.FrameSupport) string { +// FromBytes creates frame from a byte slice. +func FromBytes(b []byte) (core.Frame, error) { + if len(b) < core.FrameHeaderLen { + return nil, errIncompleteFrame + } + header := core.ParseFrameHeader(b[:core.FrameHeaderLen]) + bb := common.NewByteBuff() + _, err := bb.Write(b[core.FrameHeaderLen:]) + if err != nil { + return nil, err + } + raw := NewRawFrame(header, bb) + return FromRawFrame(raw) +} + +func PrintFrame(f core.WriteableFrame) string { // TODO: print frame return fmt.Sprintf("%+v", f) } diff --git a/core/framing/frame_cancel.go b/core/framing/frame_cancel.go index 9dcbf5e..cbef6fa 100644 --- a/core/framing/frame_cancel.go +++ b/core/framing/frame_cancel.go @@ -11,11 +11,11 @@ type CancelFrame struct { *RawFrame } -type CancelFrameSupport struct { +type WriteableCancelFrame struct { *tinyFrame } -func (c CancelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (c WriteableCancelFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = c.header.WriteTo(w) if err != nil { @@ -25,7 +25,7 @@ func (c CancelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (c CancelFrameSupport) Len() int { +func (c WriteableCancelFrame) Len() int { return core.FrameHeaderLen } @@ -38,9 +38,9 @@ func (f *CancelFrame) Validate() (err error) { return } -func NewCancelFrameSupport(id uint32) *CancelFrameSupport { +func NewWriteableCancelFrame(id uint32) *WriteableCancelFrame { h := core.NewFrameHeader(id, core.FrameTypeCancel, 0) - return &CancelFrameSupport{ + return &WriteableCancelFrame{ tinyFrame: newTinyFrame(h), } } diff --git a/core/framing/frame_error.go b/core/framing/frame_error.go index 5e45b42..c5f8ae5 100644 --- a/core/framing/frame_error.go +++ b/core/framing/frame_error.go @@ -20,17 +20,17 @@ type ErrorFrame struct { *RawFrame } -type ErrorFrameSupport struct { +type WriteableErrorFrame struct { *tinyFrame code core.ErrorCode data []byte } -func (e ErrorFrameSupport) Error() string { +func (e WriteableErrorFrame) Error() string { return makeErrorString(e.code, e.data) } -func (e ErrorFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (e WriteableErrorFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = e.header.WriteTo(w) if err != nil { @@ -43,10 +43,16 @@ func (e ErrorFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } n += 4 + + l, err := w.Write(e.data) + if err != nil { + return + } + n += int64(l) return } -func (e ErrorFrameSupport) Len() int { +func (e WriteableErrorFrame) Len() int { return core.FrameHeaderLen + 4 + len(e.data) } @@ -73,10 +79,10 @@ func (p *ErrorFrame) ErrorData() []byte { return p.body.Bytes()[errDataOff:] } -func NewErrorFrameSupport(id uint32, code core.ErrorCode, data []byte) *ErrorFrameSupport { +func NewWriteableErrorFrame(id uint32, code core.ErrorCode, data []byte) *WriteableErrorFrame { h := core.NewFrameHeader(id, core.FrameTypeError, 0) t := newTinyFrame(h) - return &ErrorFrameSupport{ + return &WriteableErrorFrame{ tinyFrame: t, code: code, data: data, diff --git a/core/framing/frame_fnf.go b/core/framing/frame_fnf.go index e1b748f..f432cd8 100644 --- a/core/framing/frame_fnf.go +++ b/core/framing/frame_fnf.go @@ -12,13 +12,13 @@ type FireAndForgetFrame struct { *RawFrame } -type FireAndForgetFrameSupport struct { +type WriteableFireAndForgetFrame struct { *tinyFrame metadata []byte data []byte } -func (f FireAndForgetFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (f WriteableFireAndForgetFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = f.header.WriteTo(w) if err != nil { @@ -34,7 +34,7 @@ func (f FireAndForgetFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (f FireAndForgetFrameSupport) Len() int { +func (f WriteableFireAndForgetFrame) Len() int { return CalcPayloadFrameSize(f.data, f.metadata) } @@ -70,13 +70,13 @@ func (f *FireAndForgetFrame) DataUTF8() string { return string(f.Data()) } -func NewFireAndForgetFrameSupport(sid uint32, data, metadata []byte, flag core.FrameFlag) *FireAndForgetFrameSupport { +func NewWriteableFireAndForgetFrame(sid uint32, data, metadata []byte, flag core.FrameFlag) *WriteableFireAndForgetFrame { if len(metadata) > 0 { flag |= core.FlagMetadata } h := core.NewFrameHeader(sid, core.FrameTypeRequestFNF, flag) t := newTinyFrame(h) - return &FireAndForgetFrameSupport{ + return &WriteableFireAndForgetFrame{ tinyFrame: t, metadata: metadata, data: data, diff --git a/core/framing/frame_keepalive.go b/core/framing/frame_keepalive.go index ec03da5..c53ca3d 100644 --- a/core/framing/frame_keepalive.go +++ b/core/framing/frame_keepalive.go @@ -18,13 +18,13 @@ type KeepaliveFrame struct { *RawFrame } -type KeepaliveFrameSupport struct { +type WriteableKeepaliveFrame struct { *tinyFrame pos [8]byte data []byte } -func (k KeepaliveFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (k WriteableKeepaliveFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = k.header.WriteTo(w) if err != nil { @@ -48,7 +48,7 @@ func (k KeepaliveFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (k KeepaliveFrameSupport) Len() int { +func (k WriteableKeepaliveFrame) Len() int { return core.FrameHeaderLen + 8 + len(k.data) } @@ -70,7 +70,7 @@ func (k *KeepaliveFrame) Data() []byte { return k.body.Bytes()[lastRecvPosLen:] } -func NewKeepaliveFrameSupport(position uint64, data []byte, respond bool) *KeepaliveFrameSupport { +func NewWriteableKeepaliveFrame(position uint64, data []byte, respond bool) *WriteableKeepaliveFrame { var flag core.FrameFlag if respond { flag |= core.FlagRespond @@ -82,7 +82,7 @@ func NewKeepaliveFrameSupport(position uint64, data []byte, respond bool) *Keepa h := core.NewFrameHeader(0, core.FrameTypeKeepalive, flag) t := newTinyFrame(h) - return &KeepaliveFrameSupport{ + return &WriteableKeepaliveFrame{ tinyFrame: t, pos: b, data: data, diff --git a/core/framing/frame_lease.go b/core/framing/frame_lease.go index b258491..a9aec95 100644 --- a/core/framing/frame_lease.go +++ b/core/framing/frame_lease.go @@ -21,6 +21,13 @@ type LeaseFrame struct { *RawFrame } +type WriteableLeaseFrame struct { + *tinyFrame + ttl [4]byte + n [4]byte + metadata []byte +} + // Validate returns error if frame is invalid. func (l *LeaseFrame) Validate() (err error) { if l.body.Len() < minLeaseFrame { @@ -48,14 +55,7 @@ func (l *LeaseFrame) Metadata() []byte { return l.body.Bytes()[8:] } -type LeaseFrameSupport struct { - *tinyFrame - ttl [4]byte - n [4]byte - metadata []byte -} - -func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (l WriteableLeaseFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = l.header.WriteTo(w) if err != nil { @@ -87,7 +87,7 @@ func (l LeaseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (l LeaseFrameSupport) Len() int { +func (l WriteableLeaseFrame) Len() int { n := core.FrameHeaderLen + 8 if l.header.Flag().Check(core.FlagMetadata) { n += len(l.metadata) @@ -95,7 +95,7 @@ func (l LeaseFrameSupport) Len() int { return n } -func NewLeaseFrameSupport(ttl time.Duration, n uint32, metadata []byte) *LeaseFrameSupport { +func NewWriteableLeaseFrame(ttl time.Duration, n uint32, metadata []byte) *WriteableLeaseFrame { var a, b [4]byte binary.BigEndian.PutUint32(a[:], uint32(ttl.Milliseconds())) binary.BigEndian.PutUint32(b[:], n) @@ -106,7 +106,7 @@ func NewLeaseFrameSupport(ttl time.Duration, n uint32, metadata []byte) *LeaseFr } h := core.NewFrameHeader(0, core.FrameTypeLease, flag) t := newTinyFrame(h) - return &LeaseFrameSupport{ + return &WriteableLeaseFrame{ tinyFrame: t, ttl: a, n: b, diff --git a/core/framing/frame_metadata_push.go b/core/framing/frame_metadata_push.go index 713e8ed..f83f05d 100644 --- a/core/framing/frame_metadata_push.go +++ b/core/framing/frame_metadata_push.go @@ -13,7 +13,7 @@ var _metadataPushHeader = core.NewFrameHeader(0, core.FrameTypeMetadataPush, cor type MetadataPushFrame struct { *RawFrame } -type MetadataPushFrameSupport struct { +type WriteableMetadataPushFrame struct { *tinyFrame metadata []byte } @@ -42,7 +42,7 @@ func (m *MetadataPushFrame) MetadataUTF8() (metadata string, ok bool) { return } -func (m MetadataPushFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (m WriteableMetadataPushFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = m.header.WriteTo(w) if err != nil { @@ -59,7 +59,7 @@ func (m MetadataPushFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (m MetadataPushFrameSupport) Len() int { +func (m WriteableMetadataPushFrame) Len() int { return core.FrameHeaderLen + len(m.metadata) } @@ -68,9 +68,9 @@ func (m *MetadataPushFrame) DataUTF8() (data string) { return } -func NewMetadataPushFrameSupport(metadata []byte) *MetadataPushFrameSupport { +func NewWriteableMetadataPushFrame(metadata []byte) *WriteableMetadataPushFrame { t := newTinyFrame(_metadataPushHeader) - return &MetadataPushFrameSupport{ + return &WriteableMetadataPushFrame{ tinyFrame: t, metadata: metadata, } diff --git a/core/framing/frame_payload.go b/core/framing/frame_payload.go index 1b4e92e..2eb83cd 100644 --- a/core/framing/frame_payload.go +++ b/core/framing/frame_payload.go @@ -12,6 +12,12 @@ type PayloadFrame struct { *RawFrame } +type WriteablePayloadFrame struct { + *tinyFrame + metadata []byte + data []byte +} + // Validate returns error if frame is invalid. func (p *PayloadFrame) Validate() (err error) { // Minimal length should be 3 if metadata exists. @@ -31,7 +37,6 @@ func (p *PayloadFrame) Data() []byte { return p.trySliceData(0) } -// MetadataUTF8 returns metadata as UTF8 string. func (p *PayloadFrame) MetadataUTF8() (metadata string, ok bool) { raw, ok := p.Metadata() if ok { @@ -40,46 +45,38 @@ func (p *PayloadFrame) MetadataUTF8() (metadata string, ok bool) { return } -func (p *PayloadFrame) MustMetadataUTF8() string { - s, ok := p.MetadataUTF8() - if !ok { - panic("cannot convert metadata to utf8") - } - return s -} - -// DataUTF8 returns data as UTF8 string. func (p *PayloadFrame) DataUTF8() string { return string(p.Data()) } -type PayloadFrameSupport struct { - *tinyFrame - metadata []byte - data []byte -} - -func (p PayloadFrameSupport) DataUTF8() string { - return string(p.data) +func (p WriteablePayloadFrame) Data() []byte { + return p.data } -func (p PayloadFrameSupport) MetadataUTF8() (metadata string, ok bool) { - if p.header.Flag().Check(core.FlagMetadata) { - metadata = string(p.metadata) - ok = true +func (p WriteablePayloadFrame) Metadata() (metadata []byte, ok bool) { + ok = p.header.Flag().Check(core.FlagMetadata) + if ok { + metadata = p.metadata } return } -func (p PayloadFrameSupport) Data() []byte { - return p.data +func (p WriteablePayloadFrame) DataUTF8() (data string) { + if p.data != nil { + data = string(p.data) + } + return } -func (p PayloadFrameSupport) Metadata() ([]byte, bool) { - return p.metadata, p.header.Flag().Check(core.FlagMetadata) +func (p WriteablePayloadFrame) MetadataUTF8() (metadata string, ok bool) { + ok = p.header.Flag().Check(core.FlagMetadata) + if ok { + metadata = string(p.metadata) + } + return } -func (p PayloadFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (p WriteablePayloadFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = p.header.WriteTo(w) if err != nil { @@ -93,18 +90,18 @@ func (p PayloadFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (p PayloadFrameSupport) Len() int { +func (p WriteablePayloadFrame) Len() int { return CalcPayloadFrameSize(p.data, p.metadata) } -// NewPayloadFrameSupport returns a new payload frame. -func NewPayloadFrameSupport(id uint32, data, metadata []byte, flag core.FrameFlag) *PayloadFrameSupport { +// NewWriteablePayloadFrame returns a new payload frame. +func NewWriteablePayloadFrame(id uint32, data, metadata []byte, flag core.FrameFlag) *WriteablePayloadFrame { if len(metadata) > 0 { flag |= core.FlagMetadata } h := core.NewFrameHeader(id, core.FrameTypePayload, flag) t := newTinyFrame(h) - return &PayloadFrameSupport{ + return &WriteablePayloadFrame{ tinyFrame: t, metadata: metadata, data: data, diff --git a/core/framing/frame_request_channel.go b/core/framing/frame_request_channel.go index 466844e..f344ee7 100644 --- a/core/framing/frame_request_channel.go +++ b/core/framing/frame_request_channel.go @@ -18,7 +18,7 @@ type RequestChannelFrame struct { *RawFrame } -type RequestChannelFrameSupport struct { +type WriteableRequestChannelFrame struct { *tinyFrame n [4]byte metadata []byte @@ -66,7 +66,7 @@ func (r *RequestChannelFrame) DataUTF8() string { return string(r.Data()) } -func (r RequestChannelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableRequestChannelFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -90,11 +90,11 @@ func (r RequestChannelFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r RequestChannelFrameSupport) Len() int { +func (r WriteableRequestChannelFrame) Len() int { return CalcPayloadFrameSize(r.data, r.metadata) + 4 } -func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *RequestChannelFrameSupport { +func NewWriteableRequestChannelFrame(sid uint32, n uint32, data, metadata []byte, flag core.FrameFlag) *WriteableRequestChannelFrame { var b [4]byte binary.BigEndian.PutUint32(b[:], n) if len(metadata) > 0 { @@ -102,7 +102,7 @@ func NewRequestChannelFrameSupport(sid uint32, n uint32, data, metadata []byte, } h := core.NewFrameHeader(sid, core.FrameTypeRequestChannel, flag) t := newTinyFrame(h) - return &RequestChannelFrameSupport{ + return &WriteableRequestChannelFrame{ tinyFrame: t, n: b, metadata: metadata, diff --git a/core/framing/frame_request_n.go b/core/framing/frame_request_n.go index 394ad8f..6a54f96 100644 --- a/core/framing/frame_request_n.go +++ b/core/framing/frame_request_n.go @@ -13,7 +13,7 @@ type RequestNFrame struct { *RawFrame } -type RequestNFrameSupport struct { +type WriteableRequestNFrame struct { *tinyFrame n [4]byte } @@ -31,7 +31,7 @@ func (r *RequestNFrame) N() uint32 { return binary.BigEndian.Uint32(r.body.Bytes()) } -func (r RequestNFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableRequestNFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -45,14 +45,14 @@ func (r RequestNFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r RequestNFrameSupport) Len() int { +func (r WriteableRequestNFrame) Len() int { return core.FrameHeaderLen + 4 } -func NewRequestNFrameSupport(id uint32, n uint32, fg core.FrameFlag) *RequestNFrameSupport { +func NewWriteableRequestNFrame(id uint32, n uint32, fg core.FrameFlag) *WriteableRequestNFrame { var b4 [4]byte binary.BigEndian.PutUint32(b4[:], n) - return &RequestNFrameSupport{ + return &WriteableRequestNFrame{ tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestN, fg)), n: b4, } diff --git a/core/framing/frame_request_response.go b/core/framing/frame_request_response.go index 1eb716a..2e74e40 100644 --- a/core/framing/frame_request_response.go +++ b/core/framing/frame_request_response.go @@ -12,7 +12,7 @@ type RequestResponseFrame struct { *RawFrame } -type RequestResponseFrameSupport struct { +type WriteableRequestResponseFrame struct { *tinyFrame metadata []byte data []byte @@ -50,7 +50,7 @@ func (r *RequestResponseFrame) DataUTF8() string { return string(r.Data()) } -func (r RequestResponseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableRequestResponseFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -64,16 +64,16 @@ func (r RequestResponseFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r RequestResponseFrameSupport) Len() int { +func (r WriteableRequestResponseFrame) Len() int { return CalcPayloadFrameSize(r.data, r.metadata) } -// NewRequestResponseFrameSupport returns a new RequestResponse frame support. -func NewRequestResponseFrameSupport(id uint32, data, metadata []byte, fg core.FrameFlag) core.FrameSupport { +// NewWriteableRequestResponseFrame returns a new RequestResponse frame support. +func NewWriteableRequestResponseFrame(id uint32, data, metadata []byte, fg core.FrameFlag) core.WriteableFrame { if len(metadata) > 0 { fg |= core.FlagMetadata } - return &RequestResponseFrameSupport{ + return &WriteableRequestResponseFrame{ tinyFrame: newTinyFrame(core.NewFrameHeader(id, core.FrameTypeRequestResponse, fg)), metadata: metadata, data: data, diff --git a/core/framing/frame_request_stream.go b/core/framing/frame_request_stream.go index 50d3db0..1221869 100644 --- a/core/framing/frame_request_stream.go +++ b/core/framing/frame_request_stream.go @@ -17,7 +17,7 @@ type RequestStreamFrame struct { *RawFrame } -type RequestStreamFrameSupport struct { +type WriteableRequestStreamFrame struct { *tinyFrame n [4]byte metadata []byte @@ -65,7 +65,7 @@ func (r *RequestStreamFrame) DataUTF8() string { return string(r.Data()) } -func (r RequestStreamFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableRequestStreamFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -88,11 +88,11 @@ func (r RequestStreamFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r RequestStreamFrameSupport) Len() int { +func (r WriteableRequestStreamFrame) Len() int { return 4 + CalcPayloadFrameSize(r.data, r.metadata) } -func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) core.FrameSupport { +func NewWriteableRequestStreamFrame(id uint32, n uint32, data, metadata []byte, flag core.FrameFlag) core.WriteableFrame { if len(metadata) > 0 { flag |= core.FlagMetadata } @@ -100,7 +100,7 @@ func NewRequestStreamFrameSupport(id uint32, n uint32, data, metadata []byte, fl binary.BigEndian.PutUint32(b[:], n) h := core.NewFrameHeader(id, core.FrameTypeRequestStream, flag) t := newTinyFrame(h) - return &RequestStreamFrameSupport{ + return &WriteableRequestStreamFrame{ tinyFrame: t, n: b, metadata: metadata, diff --git a/core/framing/frame_resume.go b/core/framing/frame_resume.go index c505651..67eb4ec 100644 --- a/core/framing/frame_resume.go +++ b/core/framing/frame_resume.go @@ -25,6 +25,14 @@ type ResumeFrame struct { *RawFrame } +type WriteableResumeFrame struct { + *tinyFrame + version core.Version + token []byte + posFirst [8]byte + posLast [8]byte +} + // Validate validate current frame. func (r *ResumeFrame) Validate() (err error) { if r.body.Len() < _minResumeLength { @@ -62,15 +70,7 @@ func (r *ResumeFrame) FirstAvailableClientPosition() uint64 { return binary.BigEndian.Uint64(raw[offset:]) } -type ResumeFrameSupport struct { - *tinyFrame - version core.Version - token []byte - posFirst [8]byte - posLast [8]byte -} - -func (r ResumeFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableResumeFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -114,19 +114,19 @@ func (r ResumeFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r ResumeFrameSupport) Len() int { +func (r WriteableResumeFrame) Len() int { return core.FrameHeaderLen + _lenTokenLength + _lenFirstPos + _lenLastRecvPos + _lenVersion + len(r.token) } -// NewResumeFrameSupport creates a new frame support of Resume. -func NewResumeFrameSupport(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *ResumeFrameSupport { +// NewWriteableResumeFrame creates a new frame support of Resume. +func NewWriteableResumeFrame(version core.Version, token []byte, firstAvailableClientPosition, lastReceivedServerPosition uint64) *WriteableResumeFrame { h := core.NewFrameHeader(0, core.FrameTypeResume, 0) t := newTinyFrame(h) var a, b [8]byte binary.BigEndian.PutUint64(a[:], firstAvailableClientPosition) binary.BigEndian.PutUint64(b[:], lastReceivedServerPosition) - return &ResumeFrameSupport{ + return &WriteableResumeFrame{ tinyFrame: t, version: version, token: token, diff --git a/core/framing/frame_resume_ok.go b/core/framing/frame_resume_ok.go index f4d5401..21903f8 100644 --- a/core/framing/frame_resume_ok.go +++ b/core/framing/frame_resume_ok.go @@ -13,7 +13,7 @@ type ResumeOKFrame struct { *RawFrame } -type ResumeOKFrameSupport struct { +type WriteableResumeOKFrame struct { *tinyFrame pos [8]byte } @@ -33,7 +33,7 @@ func (r *ResumeOKFrame) LastReceivedClientPosition() uint64 { return binary.BigEndian.Uint64(raw) } -func (r ResumeOKFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (r WriteableResumeOKFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = r.header.WriteTo(w) if err != nil { @@ -49,16 +49,16 @@ func (r ResumeOKFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (r ResumeOKFrameSupport) Len() int { +func (r WriteableResumeOKFrame) Len() int { return core.FrameHeaderLen + 8 } -func NewResumeOKFrameSupport(position uint64) *ResumeOKFrameSupport { +func NewWriteableResumeOKFrame(position uint64) *WriteableResumeOKFrame { h := core.NewFrameHeader(0, core.FrameTypeResumeOK, 0) t := newTinyFrame(h) var b [8]byte binary.BigEndian.PutUint64(b[:], position) - return &ResumeOKFrameSupport{ + return &WriteableResumeOKFrame{ tinyFrame: t, pos: b, } diff --git a/core/framing/frame_setup.go b/core/framing/frame_setup.go index cf9b86d..8f84941 100644 --- a/core/framing/frame_setup.go +++ b/core/framing/frame_setup.go @@ -22,6 +22,18 @@ type SetupFrame struct { *RawFrame } +type WriteableSetupFrame struct { + *tinyFrame + version core.Version + keepalive [4]byte + lifetime [4]byte + token []byte + mimeMetadata []byte + mimeData []byte + metadata []byte + data []byte +} + // Validate returns error if frame is invalid. func (p *SetupFrame) Validate() (err error) { if p.Len() < _minSetupFrameLen { @@ -126,19 +138,7 @@ func (p *SetupFrame) seekMIME() int { return 14 + int(l) } -type SetupFrameSupport struct { - *tinyFrame - version core.Version - keepalive [4]byte - lifetime [4]byte - token []byte - mimeMetadata []byte - mimeData []byte - metadata []byte - data []byte -} - -func (s SetupFrameSupport) WriteTo(w io.Writer) (n int64, err error) { +func (s WriteableSetupFrame) WriteTo(w io.Writer) (n int64, err error) { var wrote int64 wrote, err = s.header.WriteTo(w) if err != nil { @@ -211,7 +211,7 @@ func (s SetupFrameSupport) WriteTo(w io.Writer) (n int64, err error) { return } -func (s SetupFrameSupport) Len() int { +func (s WriteableSetupFrame) Len() int { n := _minSetupFrameLen + CalcPayloadFrameSize(s.data, s.metadata) n += len(s.mimeData) + len(s.mimeMetadata) if l := len(s.token); l > 0 { @@ -220,7 +220,7 @@ func (s SetupFrameSupport) Len() int { return n } -func NewSetupFrameSupport( +func NewWriteableSetupFrame( version core.Version, timeBetweenKeepalive, maxLifetime time.Duration, @@ -230,7 +230,7 @@ func NewSetupFrameSupport( data []byte, metadata []byte, lease bool, -) *SetupFrameSupport { +) *WriteableSetupFrame { var flag core.FrameFlag if l := len(token); l > 0 { flag |= core.FlagResume @@ -247,7 +247,7 @@ func NewSetupFrameSupport( var a, b [4]byte binary.BigEndian.PutUint32(a[:], uint32(timeBetweenKeepalive.Nanoseconds()/1e6)) binary.BigEndian.PutUint32(b[:], uint32(maxLifetime.Nanoseconds()/1e6)) - return &SetupFrameSupport{ + return &WriteableSetupFrame{ tinyFrame: t, version: version, keepalive: a, diff --git a/core/framing/frame_test.go b/core/framing/frame_test.go index ba42b60..b5221db 100644 --- a/core/framing/frame_test.go +++ b/core/framing/frame_test.go @@ -14,21 +14,35 @@ import ( const _sid uint32 = 1 +func TestFromBytes(t *testing.T) { + // empty + _, err := FromBytes([]byte{}) + assert.Error(t, err, "should be error") + + b := &bytes.Buffer{} + frame := NewWriteableRequestResponseFrame(42, []byte("fake-data"), []byte("fake-metadata"), 0) + _, _ = frame.WriteTo(b) + frameActual, err := FromBytes(b.Bytes()) + assert.NoError(t, err, "should not be error") + assert.Equal(t, frame.Header(), frameActual.Header(), "header does not match") + assert.Equal(t, frame.Len(), frameActual.Len()) +} + func TestFrameCancel(t *testing.T) { f := NewCancelFrame(_sid) checkBasic(t, f, core.FrameTypeCancel) - f2 := NewCancelFrameSupport(_sid) + f2 := NewWriteableCancelFrame(_sid) checkBytes(t, f, f2) } func TestFrameError(t *testing.T) { - errData := []byte(common.RandAlphanumeric(100)) + errData := []byte(common.RandAlphanumeric(10)) f := NewErrorFrame(_sid, core.ErrorCodeApplicationError, errData) checkBasic(t, f, core.FrameTypeError) assert.Equal(t, core.ErrorCodeApplicationError, f.ErrorCode()) assert.Equal(t, errData, f.ErrorData()) assert.NotEmpty(t, f.Error()) - f2 := NewErrorFrame(_sid, core.ErrorCodeApplicationError, errData) + f2 := NewWriteableErrorFrame(_sid, core.ErrorCodeApplicationError, errData) checkBytes(t, f, f2) } @@ -43,7 +57,7 @@ func TestFrameFNF(t *testing.T) { assert.Nil(t, metadata) assert.True(t, f.Header().Flag().Check(core.FlagNext)) assert.False(t, f.Header().Flag().Check(core.FlagMetadata)) - f2 := NewFireAndForgetFrameSupport(_sid, b, nil, core.FlagNext) + f2 := NewWriteableFireAndForgetFrame(_sid, b, nil, core.FlagNext) checkBytes(t, f, f2) // With Metadata @@ -55,7 +69,7 @@ func TestFrameFNF(t *testing.T) { assert.Equal(t, b, metadata) assert.True(t, f.Header().Flag().Check(core.FlagNext)) assert.True(t, f.Header().Flag().Check(core.FlagMetadata)) - f2 = NewFireAndForgetFrameSupport(_sid, nil, b, core.FlagNext) + f2 = NewWriteableFireAndForgetFrame(_sid, nil, b, core.FlagNext) checkBytes(t, f, f2) } @@ -67,7 +81,7 @@ func TestFrameKeepalive(t *testing.T) { assert.Equal(t, d, f.Data()) assert.Equal(t, pos, f.LastReceivedPosition()) assert.True(t, f.Header().Flag().Check(core.FlagRespond)) - f2 := NewKeepaliveFrameSupport(pos, d, true) + f2 := NewWriteableKeepaliveFrame(pos, d, true) checkBytes(t, f, f2) } @@ -79,7 +93,7 @@ func TestFrameLease(t *testing.T) { assert.Equal(t, time.Second, f.TimeToLive()) assert.Equal(t, n, f.NumberOfRequests()) assert.Equal(t, metadata, f.Metadata()) - f2 := NewLeaseFrameSupport(time.Second, n, metadata) + f2 := NewWriteableLeaseFrame(time.Second, n, metadata) checkBytes(t, f, f2) } @@ -90,7 +104,7 @@ func TestFrameMetadataPush(t *testing.T) { metadata2, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, metadata, metadata2) - f2 := NewMetadataPushFrameSupport(metadata) + f2 := NewWriteableMetadataPushFrame(metadata) checkBytes(t, f, f2) } @@ -103,7 +117,7 @@ func TestPayloadFrame(t *testing.T) { assert.Equal(t, b, f.Data()) assert.Equal(t, b, m) assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) - f2 := NewPayloadFrameSupport(_sid, b, b, core.FlagNext) + f2 := NewWriteablePayloadFrame(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -117,7 +131,7 @@ func TestFrameRequestChannel(t *testing.T) { m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) - f2 := NewRequestChannelFrameSupport(_sid, n, b, b, core.FlagNext) + f2 := NewWriteableRequestChannelFrame(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -126,7 +140,7 @@ func TestFrameRequestN(t *testing.T) { f := NewRequestNFrame(_sid, n, 0) checkBasic(t, f, core.FrameTypeRequestN) assert.Equal(t, n, f.N()) - f2 := NewRequestNFrameSupport(_sid, n, 0) + f2 := NewWriteableRequestNFrame(_sid, n, 0) checkBytes(t, f, f2) } @@ -139,7 +153,7 @@ func TestFrameRequestResponse(t *testing.T) { assert.True(t, ok) assert.Equal(t, b, m) assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) - f2 := NewRequestResponseFrameSupport(_sid, b, b, core.FlagNext) + f2 := NewWriteableRequestResponseFrame(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -153,7 +167,7 @@ func TestFrameRequestStream(t *testing.T) { m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) - f2 := NewRequestStreamFrameSupport(_sid, n, b, b, core.FlagNext) + f2 := NewWriteableRequestStreamFrame(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -169,7 +183,7 @@ func TestFrameResume(t *testing.T) { assert.Equal(t, p2, f.LastReceivedServerPosition()) assert.Equal(t, v.Major(), f.Version().Major()) assert.Equal(t, v.Minor(), f.Version().Minor()) - f2 := NewResumeFrameSupport(v, token, p1, p2) + f2 := NewWriteableResumeFrame(v, token, p1, p2) checkBytes(t, f, f2) } @@ -178,7 +192,7 @@ func TestFrameResumeOK(t *testing.T) { f := NewResumeOKFrame(pos) checkBasic(t, f, core.FrameTypeResumeOK) assert.Equal(t, pos, f.LastReceivedClientPosition()) - f2 := NewResumeOKFrameSupport(pos) + f2 := NewWriteableResumeOKFrame(pos) checkBytes(t, f, f2) } @@ -205,7 +219,7 @@ func TestFrameSetup(t *testing.T) { assert.True(t, ok) assert.Equal(t, m, m2) - fs := NewSetupFrameSupport(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) + fs := NewWriteableSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) checkBytes(t, f, fs) } @@ -226,7 +240,8 @@ func checkBasic(t *testing.T, f core.Frame, typ core.FrameType) { <-f.DoneNotify() } -func checkBytes(t *testing.T, a core.Frame, b core.FrameSupport) { +func checkBytes(t *testing.T, a core.Frame, b core.WriteableFrame) { + assert.NoError(t, a.Validate()) assert.Equal(t, a.Len(), b.Len()) bf1, bf2 := &bytes.Buffer{}, &bytes.Buffer{} _, err := a.WriteTo(bf1) diff --git a/core/transport/base_test.go b/core/transport/base_test.go new file mode 100644 index 0000000..3c1a24c --- /dev/null +++ b/core/transport/base_test.go @@ -0,0 +1,9 @@ +package transport_test + +import "github.com/pkg/errors" + +var ( + fakeErr = errors.New("fake-error") + fakeData = []byte("fake-data") + fakeMetadata = []byte("fake-metadata") +) diff --git a/core/transport/decoder_test.go b/core/transport/decoder_test.go index 30432d6..8203b6d 100644 --- a/core/transport/decoder_test.go +++ b/core/transport/decoder_test.go @@ -1,35 +1,90 @@ -package transport +package transport_test import ( "bytes" - "encoding/hex" - "fmt" + "io" "testing" + "time" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/common" + "github.com/stretchr/testify/assert" ) -func TestDecoder(t *testing.T) { - bs, _ := hex.DecodeString("000012000000012920000003797979776f726c6432000006000000012840") - r := bytes.NewBuffer(bs) +func TestLengthBasedFrameDecoder_ReadBroken(t *testing.T) { + b := &bytes.Buffer{} - d := NewLengthBasedFrameDecoder(r) + _, _ = common.MustNewUint24(5).WriteTo(b) + _, _ = b.Write([]byte{'_', 'f', 'a', 'k', 'e'}) + decoder := transport.NewLengthBasedFrameDecoder(b) + _, err := decoder.Read() + assert.Equal(t, transport.ErrIncompleteHeader, err, "should be incomplete header error") + b.Reset() + _, _ = b.Write([]byte{0, 0, 0, 'f', 'a', 'k', 'e'}) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, core.ErrInvalidFrameLength, err, "should be invalid length error") + + b.Reset() + b.Write([]byte{0, 0}) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, io.EOF, err, "should read nothing") + + b.Reset() + _, _ = common.MustNewUint24(10).WriteTo(b) + var notEnough [7]byte + _, _ = b.Write(notEnough[:]) + decoder = transport.NewLengthBasedFrameDecoder(b) + _, err = decoder.Read() + assert.Equal(t, io.EOF, err, "should read nothing") +} + +func TestLengthBasedFrameDecoder_Read(t *testing.T) { + b := &bytes.Buffer{} + frames := []core.Frame{ + framing.NewSetupFrame( + core.DefaultVersion, + 30*time.Second, + 90*time.Second, + nil, + []byte("text/plain"), + []byte("text/plain"), + []byte("fake-data"), + []byte("fake-metadata"), + false, + ), + framing.NewKeepaliveFrame(0, []byte("fake-data"), true), + framing.NewRequestResponseFrame(1, []byte("fake-data"), []byte("fake-metadata"), 0), + } + + for _, it := range frames { + n, err := common.NewUint24(it.Len()) + assert.NoError(t, err, "convert to uint24 failed") + _, err = n.WriteTo(b) + assert.NoError(t, err, "write length failed") + _, err = it.WriteTo(b) + assert.NoError(t, err, "write frame failed") + } + + var results []core.Frame + + decoder := transport.NewLengthBasedFrameDecoder(b) for { - raw, err := d.Read() - if err != nil { + next, err := decoder.Read() + if err == io.EOF { break } - h := core.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, _ = bf.Write(raw[core.FrameHeaderLen:]) - f, err := framing.FromRawFrame(framing.NewRawFrame(h, bf)) - if err != nil { - panic(err) - } - fmt.Println(f) + assert.NoError(t, err, "decode next frame failed") + frame, err := framing.FromBytes(next) + assert.NoError(t, err, "read next frame failed") + results = append(results, frame) } + for i := 0; i < len(results); i++ { + assert.Equal(t, frames[i].Header(), results[i].Header()) + } } diff --git a/core/transport/misc_test.go b/core/transport/misc_test.go new file mode 100644 index 0000000..b0a1991 --- /dev/null +++ b/core/transport/misc_test.go @@ -0,0 +1,16 @@ +package transport + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsClosedErr(t *testing.T) { + assert.False(t, isClosedErr(nil)) + assert.False(t, isClosedErr(errors.New("fake error"))) + assert.True(t, isClosedErr(http.ErrServerClosed)) + assert.True(t, isClosedErr(errors.New("use of closed network connection"))) +} diff --git a/core/transport/types_mock.go b/core/transport/mock_conn_test.go similarity index 97% rename from core/transport/types_mock.go rename to core/transport/mock_conn_test.go index a7360b7..d1ea299 100644 --- a/core/transport/types_mock.go +++ b/core/transport/mock_conn_test.go @@ -2,13 +2,14 @@ // Source: core/transport/types.go // Package transport is a generated GoMock package. -package transport +package transport_test import ( - gomock "github.com/golang/mock/gomock" - core "github.com/rsocket/rsocket-go/core" reflect "reflect" time "time" + + gomock "github.com/golang/mock/gomock" + core "github.com/rsocket/rsocket-go/core" ) // MockConn is a mock of Conn interface @@ -90,7 +91,7 @@ func (mr *MockConnMockRecorder) Read() *gomock.Call { } // Write mocks base method -func (m *MockConn) Write(arg0 core.FrameSupport) error { +func (m *MockConn) Write(arg0 core.WriteableFrame) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Write", arg0) ret0, _ := ret[0].(error) diff --git a/core/transport/tcp_conn.go b/core/transport/tcp_conn.go index c56c63e..bce1f23 100644 --- a/core/transport/tcp_conn.go +++ b/core/transport/tcp_conn.go @@ -13,22 +13,22 @@ import ( "github.com/rsocket/rsocket-go/logger" ) -type tcpConn struct { - rawConn net.Conn +type TcpConn struct { + conn net.Conn writer *bufio.Writer decoder *LengthBasedFrameDecoder counter *core.Counter } -func (p *tcpConn) SetCounter(c *core.Counter) { +func (p *TcpConn) SetCounter(c *core.Counter) { p.counter = c } -func (p *tcpConn) SetDeadline(deadline time.Time) error { - return p.rawConn.SetReadDeadline(deadline) +func (p *TcpConn) SetDeadline(deadline time.Time) error { + return p.conn.SetReadDeadline(deadline) } -func (p *tcpConn) Read() (f core.Frame, err error) { +func (p *TcpConn) Read() (f core.Frame, err error) { raw, err := p.decoder.Read() if err == io.EOF { return @@ -37,21 +37,13 @@ func (p *tcpConn) Read() (f core.Frame, err error) { err = errors.Wrap(err, "read frame failed") return } - h := core.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, err = bf.Write(raw[core.FrameHeaderLen:]) + f, err = framing.FromBytes(raw) if err != nil { err = errors.Wrap(err, "read frame failed") return } - base := framing.NewRawFrame(h, bf) - if p.counter != nil && base.Header().Resumable() { - p.counter.IncReadBytes(base.Len()) - } - f, err = framing.FromRawFrame(base) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return + if p.counter != nil && f.Header().Resumable() { + p.counter.IncReadBytes(f.Len()) } err = f.Validate() if err != nil { @@ -64,7 +56,7 @@ func (p *tcpConn) Read() (f core.Frame, err error) { return } -func (p *tcpConn) Flush() (err error) { +func (p *TcpConn) Flush() (err error) { err = p.writer.Flush() if err != nil { err = errors.Wrap(err, "flush failed") @@ -72,7 +64,7 @@ func (p *tcpConn) Flush() (err error) { return } -func (p *tcpConn) Write(frame core.FrameSupport) (err error) { +func (p *TcpConn) Write(frame core.WriteableFrame) (err error) { size := frame.Len() if p.counter != nil && frame.Header().Resumable() { p.counter.IncWriteBytes(size) @@ -97,14 +89,14 @@ func (p *tcpConn) Write(frame core.FrameSupport) (err error) { return } -func (p *tcpConn) Close() error { - return p.rawConn.Close() +func (p *TcpConn) Close() error { + return p.conn.Close() } -func newTCPRConnection(rawConn net.Conn) *tcpConn { - return &tcpConn{ - rawConn: rawConn, - writer: bufio.NewWriter(rawConn), - decoder: NewLengthBasedFrameDecoder(rawConn), +func NewTcpConn(conn net.Conn) *TcpConn { + return &TcpConn{ + conn: conn, + writer: bufio.NewWriter(conn), + decoder: NewLengthBasedFrameDecoder(conn), } } diff --git a/core/transport/tcp_conn_mock_test.go b/core/transport/tcp_conn_mock_test.go new file mode 100644 index 0000000..b73a6f9 --- /dev/null +++ b/core/transport/tcp_conn_mock_test.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Conn) + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + net "net" + reflect "reflect" + time "time" +) + +// mockNetConn is a mock of Conn interface +type mockNetConn struct { + ctrl *gomock.Controller + recorder *mockNetConnMockRecorder +} + +// mockNetConnMockRecorder is the mock recorder for mockNetConn +type mockNetConnMockRecorder struct { + mock *mockNetConn +} + +// newMockNetConn creates a new mock instance +func newMockNetConn(ctrl *gomock.Controller) *mockNetConn { + mock := &mockNetConn{ctrl: ctrl} + mock.recorder = &mockNetConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockNetConn) EXPECT() *mockNetConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *mockNetConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockNetConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockNetConn)(nil).Close)) +} + +// LocalAddr mocks base method +func (m *mockNetConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *mockNetConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*mockNetConn)(nil).LocalAddr)) +} + +// Read mocks base method +func (m *mockNetConn) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *mockNetConnMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*mockNetConn)(nil).Read), arg0) +} + +// RemoteAddr mocks base method +func (m *mockNetConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr +func (mr *mockNetConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*mockNetConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method +func (m *mockNetConn) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *mockNetConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*mockNetConn)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method +func (m *mockNetConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *mockNetConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*mockNetConn)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method +func (m *mockNetConn) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *mockNetConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*mockNetConn)(nil).SetWriteDeadline), arg0) +} + +// Write mocks base method +func (m *mockNetConn) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *mockNetConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*mockNetConn)(nil).Write), arg0) +} diff --git a/core/transport/tcp_conn_test.go b/core/transport/tcp_conn_test.go new file mode 100644 index 0000000..daeed1f --- /dev/null +++ b/core/transport/tcp_conn_test.go @@ -0,0 +1,158 @@ +package transport_test + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/common" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +func InitMockTcpConn(t *testing.T) (*gomock.Controller, *mockNetConn, *transport.TcpConn) { + ctrl := gomock.NewController(t) + nc := newMockNetConn(ctrl) + tc := transport.NewTcpConn(nc) + return ctrl, nc, tc +} + +func TestTcpConn_Read_Empty(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT().Read(gomock.Any()).Return(0, fakeErr).AnyTimes() + _, err := tc.Read() + assert.Error(t, err, "should read failed") +} + +func TestTcpConn_Read(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + bf := &bytes.Buffer{} + c := core.NewCounter() + tc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var writtenBytes int + + for _, frame := range toBeWritten { + n := frame.Len() + if frame.Header().Resumable() { + writtenBytes += n + } + _, _ = common.MustNewUint24(n).WriteTo(bf) + _, _ = frame.WriteTo(bf) + } + + nc.EXPECT(). + Read(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Read(b) + }). + AnyTimes() + var results []core.Frame + for { + next, err := tc.Read() + if err == io.EOF { + break + } + assert.NoError(t, err, "read next frame failed") + results = append(results, next) + } + assert.Equal(t, len(toBeWritten), len(results), "result amount does not match") + for i := 0; i < len(results); i++ { + assert.Equal(t, toBeWritten[i].Header(), results[i].Header(), "header does not match") + } + assert.Equal(t, writtenBytes, int(c.ReadBytes()), "read bytes doesn't match") +} + +func TestTcpConn_SetDeadline(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + nc.EXPECT().SetReadDeadline(gomock.Any()).Times(1) + err := tc.SetDeadline(time.Now()) + assert.NoError(t, err, "call setDeadline failed") +} + +func TestTcpConn_Flush_Nothing(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + c := core.NewCounter() + tc.SetCounter(c) + + nc.EXPECT().Write(gomock.Any()).Times(0) + + err := tc.Flush() + assert.NoError(t, err, "flush failed") + assert.Equal(t, 0, int(c.WriteBytes()), "bytes written should be zero") +} + +func TestTcpConn_WriteWithBrokenConn(t *testing.T) { + logger.SetLogger(nil) + logger.SetLevel(logger.LevelDebug) + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT(). + Write(gomock.Any()). + Return(0, fakeErr). + AnyTimes() + _ = tc.Write(framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0)) + err := tc.Flush() + assert.Equal(t, fakeErr, errors.Cause(err), "should be fake error") +} + +func TestTcpConn_WriteAndFlush(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + + c := core.NewCounter() + tc.SetCounter(c) + + nc.EXPECT(). + Write(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return len(b), nil + }). + Times(1) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var bytesWritten uint64 + for _, frame := range toBeWritten { + n := frame.Len() + err := tc.Write(frame) + assert.NoError(t, err, "write failed") + if frame.Header().Resumable() { + bytesWritten += uint64(n) + } + } + err := tc.Flush() + assert.NoError(t, err) + assert.Equal(t, bytesWritten, c.WriteBytes(), "write bytes doesn't match") +} + +func TestTcpConn_Close(t *testing.T) { + ctrl, nc, tc := InitMockTcpConn(t) + defer ctrl.Finish() + nc.EXPECT().Close().Return(fakeErr).Times(1) + err := tc.Close() + assert.Equal(t, fakeErr, err, "should return fake error") +} diff --git a/core/transport/tcp_transport.go b/core/transport/tcp_transport.go index 12afa9b..1619f8a 100644 --- a/core/transport/tcp_transport.go +++ b/core/transport/tcp_transport.go @@ -11,12 +11,11 @@ import ( ) type tcpServerTransport struct { - network, addr string - acceptor ServerTransportAcceptor - listener net.Listener - onceClose sync.Once - tls *tls.Config - transports *sync.Map + listenerFn func() (net.Listener, error) + acceptor ServerTransportAcceptor + listener net.Listener + onceClose sync.Once + transports *sync.Map } func (p *tcpServerTransport) Accept(acceptor ServerTransportAcceptor) { @@ -40,20 +39,13 @@ func (p *tcpServerTransport) Close() (err error) { } func (p *tcpServerTransport) Listen(ctx context.Context, notifier chan<- struct{}) (err error) { - if p.tls == nil { - p.listener, err = net.Listen(p.network, p.addr) - if err != nil { - err = errors.Wrap(err, "server listen failed") - return - } - } else { - p.listener, err = tls.Listen(p.network, p.addr, p.tls) - if err != nil { - err = errors.Wrap(err, "server listen failed") - return - } + p.listener, err = p.listenerFn() + if err != nil { + close(notifier) + return } notifier <- struct{}{} + close(notifier) return p.listen(ctx) } @@ -90,7 +82,7 @@ func (p *tcpServerTransport) listen(ctx context.Context) (err error) { break } // Dispatch raw conn. - tp := NewTransport(newTCPRConnection(c)) + tp := NewTransport(NewTcpConn(c)) p.transports.Store(tp, struct{}{}) go p.acceptor(ctx, tp, func(t *Transport) { p.transports.Delete(t) @@ -99,16 +91,29 @@ func (p *tcpServerTransport) listen(ctx context.Context) (err error) { return } -func NewTcpServerTransport(network, addr string, c *tls.Config) *tcpServerTransport { +func NewTcpServerTransport(gen func() (net.Listener, error)) ServerTransport { return &tcpServerTransport{ - network: network, - addr: addr, - tls: c, + listenerFn: gen, transports: &sync.Map{}, } } -func NewTcpClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { +func NewTcpServerTransportWithAddr(network, addr string, c *tls.Config) ServerTransport { + gen := func() (net.Listener, error) { + if c == nil { + return net.Listen(network, addr) + } else { + return tls.Listen(network, addr, c) + } + } + return NewTcpServerTransport(gen) +} + +func NewTcpClientTransport(rawConn net.Conn) *Transport { + return NewTransport(NewTcpConn(rawConn)) +} + +func NewTcpClientTransportWithAddr(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { var rawConn net.Conn if tlsConfig == nil { rawConn, err = net.Dial(network, addr) @@ -118,6 +123,6 @@ func NewTcpClientTransport(network, addr string, tlsConfig *tls.Config) (tp *Tra if err != nil { return } - tp = NewTransport(newTCPRConnection(rawConn)) + tp = NewTcpClientTransport(rawConn) return } diff --git a/core/transport/tcp_transport_mock_test.go b/core/transport/tcp_transport_mock_test.go new file mode 100644 index 0000000..3ac1eff --- /dev/null +++ b/core/transport/tcp_transport_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Listener) + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + net "net" + reflect "reflect" +) + +// mockNetListener is a mock of Listener interface +type mockNetListener struct { + ctrl *gomock.Controller + recorder *mockNetListenerMockRecorder +} + +// mockNetListenerMockRecorder is the mock recorder for mockNetListener +type mockNetListenerMockRecorder struct { + mock *mockNetListener +} + +// newMockNetListener creates a new mock instance +func newMockNetListener(ctrl *gomock.Controller) *mockNetListener { + mock := &mockNetListener{ctrl: ctrl} + mock.recorder = &mockNetListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockNetListener) EXPECT() *mockNetListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method +func (m *mockNetListener) Accept() (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept") + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept +func (mr *mockNetListenerMockRecorder) Accept() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*mockNetListener)(nil).Accept)) +} + +// Addr mocks base method +func (m *mockNetListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr +func (mr *mockNetListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*mockNetListener)(nil).Addr)) +} + +// Close mocks base method +func (m *mockNetListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockNetListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockNetListener)(nil).Close)) +} diff --git a/core/transport/tcp_transport_test.go b/core/transport/tcp_transport_test.go new file mode 100644 index 0000000..2b76faa --- /dev/null +++ b/core/transport/tcp_transport_test.go @@ -0,0 +1,165 @@ +package transport_test + +import ( + "context" + "crypto/tls" + "io" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" +) + +func InitTcpServerTransport(t *testing.T) (*gomock.Controller, *mockNetListener, transport.ServerTransport) { + ctrl := gomock.NewController(t) + listener := newMockNetListener(ctrl) + tp := transport.NewTcpServerTransport(func() (net.Listener, error) { + return listener, nil + }) + return ctrl, listener, tp +} + +func TestTcpServerTransport_ListenBroken(t *testing.T) { + tp := transport.NewTcpServerTransport(func() (net.Listener, error) { + return nil, fakeErr + }) + + defer tp.Close() + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fake error") + }() + _, ok := <-notifier + assert.False(t, ok) + + <-done +} + +func TestTcpServerTransport_Listen(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + + listener.EXPECT().Accept().Return(nil, io.EOF).AnyTimes() + listener.EXPECT().Close().Times(1) + + done := make(chan struct{}) + + ctx, cancel := context.WithCancel(context.Background()) + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(ctx, notifier) + assert.True(t, err == nil || err == io.EOF) + }() + _, ok := <-notifier + assert.True(t, ok) + + time.Sleep(100 * time.Millisecond) + cancel() + + <-done +} + +func TestTcpServerTransport_Accept(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + defer tp.Close() + + connChan := make(chan net.Conn, 1) + listener.EXPECT(). + Accept(). + DoAndReturn(func() (net.Conn, error) { + c, ok := <-connChan + if !ok { + return nil, io.EOF + } + return c, nil + }). + AnyTimes() + listener.EXPECT().Close().Times(1) + + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) + err := tp.Start(ctx) + assert.True(t, err == nil || err == io.EOF) + }) + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.True(t, err == nil || err == io.EOF) + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier failed") + + c := newMockNetConn(ctrl) + + c.EXPECT().Read(gomock.Any()).Return(0, io.EOF).AnyTimes() + c.EXPECT().Close().Times(1) + + connChan <- c + + time.Sleep(100 * time.Millisecond) + close(connChan) + + <-done +} + +func TestTcpServerTransport_AcceptBroken(t *testing.T) { + ctrl, listener, tp := InitTcpServerTransport(t) + defer ctrl.Finish() + + listener.EXPECT(). + Accept(). + Return(nil, fakeErr). + AnyTimes() + listener.EXPECT().Close().Times(1) + + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + defer onClose(tp) + err := tp.Start(ctx) + assert.True(t, err == nil || err == io.EOF) + }) + + done := make(chan struct{}) + + notifier := make(chan struct{}) + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Error(t, err, "should be error") + assert.Equal(t, fakeErr, errors.Cause(err), "should caused by fake error") + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier failed") + + <-done +} + +func TestNewTcpServerTransportWithAddr(t *testing.T) { + assert.NotPanics(t, func() { + tp := transport.NewTcpServerTransportWithAddr("tcp", ":9999", nil) + assert.NotNil(t, tp) + }) + assert.NotPanics(t, func() { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tp := transport.NewTcpServerTransportWithAddr("tcp", ":9999", tlsConfig) + assert.NotNil(t, tp) + }) +} diff --git a/core/transport/transport.go b/core/transport/transport.go index 075fb50..36126bb 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -84,7 +84,7 @@ func (p *Transport) SetLifetime(lifetime time.Duration) { } // Send send a frame. -func (p *Transport) Send(frame core.FrameSupport, flush bool) (err error) { +func (p *Transport) Send(frame core.WriteableFrame, flush bool) (err error) { defer func() { // ensure frame done when send success. if err == nil { diff --git a/core/transport/transport_test.go b/core/transport/transport_test.go index 4c68370..d491e95 100644 --- a/core/transport/transport_test.go +++ b/core/transport/transport_test.go @@ -15,11 +15,9 @@ import ( "go.uber.org/atomic" ) -var fakeErr = errors.New("fake error") - -func Init(t *testing.T) (*gomock.Controller, *transport.MockConn, *transport.Transport) { +func Init(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { ctrl := gomock.NewController(t) - conn := transport.NewMockConn(ctrl) + conn := NewMockConn(ctrl) tp := transport.NewTransport(conn) return ctrl, conn, tp } diff --git a/core/transport/types.go b/core/transport/types.go index 98a871b..5a60a9c 100644 --- a/core/transport/types.go +++ b/core/transport/types.go @@ -24,7 +24,7 @@ type Conn interface { // Read reads next frame from Conn. Read() (core.Frame, error) // Write writes a frame to Conn. - Write(core.FrameSupport) error + Write(core.WriteableFrame) error // Flush. Flush() error } diff --git a/core/transport/websocket_conn.go b/core/transport/websocket_conn.go index 933daa2..e90335d 100644 --- a/core/transport/websocket_conn.go +++ b/core/transport/websocket_conn.go @@ -10,7 +10,6 @@ import ( "github.com/pkg/errors" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" - "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/logger" ) @@ -18,21 +17,31 @@ var _buffPool = sync.Pool{ New: func() interface{} { return &bytes.Buffer{} }, } -type wsConn struct { - c *websocket.Conn +type RawWsConn interface { + io.Closer + SetReadDeadline(time.Time) error + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error +} + +type WsConn struct { + c RawWsConn counter *core.Counter } -func (p *wsConn) SetCounter(c *core.Counter) { +func (p *WsConn) SetCounter(c *core.Counter) { p.counter = c } -func (p *wsConn) SetDeadline(deadline time.Time) error { +func (p *WsConn) SetDeadline(deadline time.Time) error { return p.c.SetReadDeadline(deadline) } -func (p *wsConn) Read() (f core.Frame, err error) { +func (p *WsConn) Read() (f core.Frame, err error) { t, raw, err := p.c.ReadMessage() + if err == io.EOF { + return + } if err != nil { err = errors.Wrap(err, "read frame failed") return @@ -41,24 +50,17 @@ func (p *wsConn) Read() (f core.Frame, err error) { logger.Warnf("omit non-binary message %d\n", t) return p.Read() } - // validate min length - if len(raw) < core.FrameHeaderLen { - err = errors.Wrap(ErrIncompleteHeader, "read frame failed") - return - } - header := core.ParseFrameHeader(raw) - bf := common.NewByteBuff() - _, err = bf.Write(raw[core.FrameHeaderLen:]) + + f, err = framing.FromBytes(raw) if err != nil { err = errors.Wrap(err, "read frame failed") return } - base := framing.NewRawFrame(header, bf) - f, err = framing.FromRawFrame(base) - if err != nil { - err = errors.Wrap(err, "read frame failed") - return + + if p.counter != nil && f.Header().Resumable() { + p.counter.IncReadBytes(f.Len()) } + err = f.Validate() if err != nil { err = errors.Wrap(err, "read frame failed") @@ -70,11 +72,12 @@ func (p *wsConn) Read() (f core.Frame, err error) { return } -func (p *wsConn) Flush() (err error) { +func (p *WsConn) Flush() (err error) { return } -func (p *wsConn) Write(frame core.FrameSupport) (err error) { +func (p *WsConn) Write(frame core.WriteableFrame) (err error) { + size := frame.Len() bf := _buffPool.Get().(*bytes.Buffer) defer func() { bf.Reset() @@ -92,18 +95,21 @@ func (p *wsConn) Write(frame core.FrameSupport) (err error) { err = errors.Wrap(err, "write frame failed") return } + if p.counter != nil && frame.Header().Resumable() { + p.counter.IncWriteBytes(size) + } if logger.IsDebugEnabled() { logger.Debugf("---> snd: %s\n", frame) } return } -func (p *wsConn) Close() error { +func (p *WsConn) Close() error { return p.c.Close() } -func newWebsocketConnection(rawConn *websocket.Conn) *wsConn { - return &wsConn{ +func NewWebsocketConnection(rawConn RawWsConn) *WsConn { + return &WsConn{ c: rawConn, } } diff --git a/core/transport/websocket_conn_mock_test.go b/core/transport/websocket_conn_mock_test.go new file mode 100644 index 0000000..7a73349 --- /dev/null +++ b/core/transport/websocket_conn_mock_test.go @@ -0,0 +1,92 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/websocket_conn.go + +// Package transport is a generated GoMock package. +package transport_test + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" + time "time" +) + +// mockRawWsConn is a mock of RawWsConn interface +type mockRawWsConn struct { + ctrl *gomock.Controller + recorder *mockRawWsConnMockRecorder +} + +// mockRawWsConnMockRecorder is the mock recorder for mockRawWsConn +type mockRawWsConnMockRecorder struct { + mock *mockRawWsConn +} + +// newMockRawWsConn creates a new mock instance +func newMockRawWsConn(ctrl *gomock.Controller) *mockRawWsConn { + mock := &mockRawWsConn{ctrl: ctrl} + mock.recorder = &mockRawWsConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *mockRawWsConn) EXPECT() *mockRawWsConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *mockRawWsConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *mockRawWsConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*mockRawWsConn)(nil).Close)) +} + +// SetReadDeadline mocks base method +func (m *mockRawWsConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *mockRawWsConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*mockRawWsConn)(nil).SetReadDeadline), arg0) +} + +// ReadMessage mocks base method +func (m *mockRawWsConn) ReadMessage() (int, []byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadMessage") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadMessage indicates an expected call of ReadMessage +func (mr *mockRawWsConnMockRecorder) ReadMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadMessage", reflect.TypeOf((*mockRawWsConn)(nil).ReadMessage)) +} + +// WriteMessage mocks base method +func (m *mockRawWsConn) WriteMessage(messageType int, data []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", messageType, data) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage +func (mr *mockRawWsConnMockRecorder) WriteMessage(messageType, data interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*mockRawWsConn)(nil).WriteMessage), messageType, data) +} diff --git a/core/transport/websocket_conn_test.go b/core/transport/websocket_conn_test.go new file mode 100644 index 0000000..b80c2fb --- /dev/null +++ b/core/transport/websocket_conn_test.go @@ -0,0 +1,169 @@ +package transport_test + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +func InitMockWsConn(t *testing.T) (*gomock.Controller, *mockRawWsConn, *transport.WsConn) { + ctrl := gomock.NewController(t) + rawConn := newMockRawWsConn(ctrl) + conn := transport.NewWebsocketConnection(rawConn) + return ctrl, rawConn, conn +} + +func TestWsConn_Read_Empty(t *testing.T) { + ctrl, rawConn, conn := InitMockWsConn(t) + defer ctrl.Finish() + + rawConn.EXPECT().ReadMessage().Return(0, nil, fakeErr).AnyTimes() + _, err := conn.Read() + assert.Error(t, err, "should read failed") +} + +func TestWsConn_Read(t *testing.T) { + ctrl, rc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewCounter() + wc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + var ( + writtenBytesSlice [][]byte + writtenBytes int + ) + + for _, frame := range toBeWritten { + n := frame.Len() + if frame.Header().Resumable() { + writtenBytes += n + } + b := &bytes.Buffer{} + _, err := frame.WriteTo(b) + assert.NoError(t, err, "write frame failed") + writtenBytesSlice = append(writtenBytesSlice, b.Bytes()) + } + + cursor := 0 + rc.EXPECT(). + ReadMessage(). + DoAndReturn(func() (int, []byte, error) { + defer func() { + cursor++ + }() + if cursor >= len(writtenBytesSlice) { + return 0, nil, io.EOF + } + return websocket.BinaryMessage, writtenBytesSlice[cursor], nil + }). + AnyTimes() + + var results []core.Frame + for { + next, err := wc.Read() + if err == io.EOF { + break + } + assert.NoError(t, err, "read next frame failed") + results = append(results, next) + } + + assert.Equal(t, len(toBeWritten), len(results), "result amount does not match") + for i := 0; i < len(results); i++ { + assert.Equal(t, toBeWritten[i].Header(), results[i].Header(), "header does not match") + } + assert.Equal(t, writtenBytes, int(c.ReadBytes()), "read bytes doesn't match") +} + +func TestWsConn_SetDeadline(t *testing.T) { + ctrl, mc, c := InitMockWsConn(t) + defer ctrl.Finish() + + mc.EXPECT().SetReadDeadline(gomock.Any()).Times(1) + err := c.SetDeadline(time.Now()) + assert.NoError(t, err, "call setDeadline failed") +} + +func TestWsConn_Flush_Nothing(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewCounter() + wc.SetCounter(c) + + mc.EXPECT().WriteMessage(websocket.BinaryMessage, gomock.Any()).Times(0) + + err := wc.Flush() + assert.NoError(t, err, "flush failed") + assert.Equal(t, 0, int(c.WriteBytes()), "bytes written should be zero") +} + +func TestWsConn_WriteWithBrokenConn(t *testing.T) { + logger.SetLogger(nil) + logger.SetLevel(logger.LevelDebug) + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + mc.EXPECT(). + WriteMessage(websocket.BinaryMessage, gomock.Any()). + Return(fakeErr). + AnyTimes() + err := wc.Write(framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0)) + assert.Equal(t, fakeErr, errors.Cause(err), "should be fake error") +} + +func TestWsConn_Write(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + + c := core.NewCounter() + wc.SetCounter(c) + + toBeWritten := []core.WriteableFrame{ + framing.NewWriteablePayloadFrame(1, fakeData, fakeMetadata, 0), + framing.NewWriteableKeepaliveFrame(0, fakeData, true), + framing.NewWriteableRequestResponseFrame(2, fakeData, fakeMetadata, 0), + } + + mc.EXPECT(). + WriteMessage(websocket.BinaryMessage, gomock.Any()). + Return(nil). + Times(len(toBeWritten)) + + var bytesWritten uint64 + for _, frame := range toBeWritten { + n := frame.Len() + err := wc.Write(frame) + assert.NoError(t, err, "write failed") + if frame.Header().Resumable() { + bytesWritten += uint64(n) + } + } + err := wc.Flush() + assert.NoError(t, err) + assert.Equal(t, bytesWritten, c.WriteBytes(), "write bytes doesn't match") +} + +func TestWsConn_Close(t *testing.T) { + ctrl, mc, wc := InitMockWsConn(t) + defer ctrl.Finish() + mc.EXPECT().Close().Return(fakeErr).Times(1) + err := wc.Close() + assert.Equal(t, fakeErr, err, "should return fake error") +} diff --git a/core/transport/websocket_transport.go b/core/transport/websocket_transport.go index a7ff7a1..5b9f9fe 100644 --- a/core/transport/websocket_transport.go +++ b/core/transport/websocket_transport.go @@ -70,7 +70,7 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} return } - tp := NewTransport(newWebsocketConnection(c)) + tp := NewTransport(NewWebsocketConnection(c)) p.transports.Store(tp, struct{}{}) go p.acceptor(ctx, tp, func(tp *Transport) { p.transports.Delete(tp) @@ -139,5 +139,5 @@ func NewWebsocketClientTransport(url string, tc *tls.Config, header http.Header) if err != nil { return nil, errors.Wrap(err, "dial websocket failed") } - return NewTransport(newWebsocketConnection(wsConn)), nil + return NewTransport(NewWebsocketConnection(wsConn)), nil } diff --git a/core/types.go b/core/types.go index 25947ba..a646d8b 100644 --- a/core/types.go +++ b/core/types.go @@ -108,7 +108,7 @@ func (f FrameFlag) Check(flag FrameFlag) bool { return flag&f == flag } -type FrameSupport interface { +type WriteableFrame interface { io.WriterTo // FrameHeader returns frame FrameHeader. Header() FrameHeader @@ -122,7 +122,7 @@ type FrameSupport interface { // Frame is a single message containing a request, response, or protocol processing. type Frame interface { - FrameSupport + WriteableFrame // Validate returns error if frame is invalid. Validate() error } diff --git a/internal/fragmentation/splitter_test.go b/internal/fragmentation/splitter_test.go index 443b62a..783fc5f 100644 --- a/internal/fragmentation/splitter_test.go +++ b/internal/fragmentation/splitter_test.go @@ -27,10 +27,10 @@ func split2joiner(mtu int, data, metadata []byte) (joiner Joiner, err error) { fn := func(idx int, result SplitResult) { sid := uint32(77778888) if idx == 0 { - f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, core.FlagComplete|result.Flag) + f := framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, core.FlagComplete|result.Flag) joiner = NewJoiner(f) } else { - f := framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag) + f := framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag) joiner.Push(f) } } diff --git a/internal/socket/socket.go b/internal/socket/abstract_socket.go similarity index 53% rename from internal/socket/socket.go rename to internal/socket/abstract_socket.go index 42bd7d5..b31749a 100644 --- a/internal/socket/socket.go +++ b/internal/socket/abstract_socket.go @@ -1,9 +1,6 @@ package socket import ( - "sync" - "time" - "github.com/pkg/errors" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" @@ -70,80 +67,3 @@ func (p AbstractRSocket) RequestChannel(messages rx.Publisher) flux.Flux { } return p.RC(messages) } - -type baseSocket struct { - socket *DuplexRSocket - closers []func(error) - once sync.Once - reqLease *leaser -} - -func (p *baseSocket) refreshLease(ttl time.Duration, n int64) { - deadline := time.Now().Add(ttl) - if p.reqLease == nil { - p.reqLease = newLeaser(deadline, n) - } else { - p.reqLease.refresh(deadline, n) - } -} - -func (p *baseSocket) FireAndForget(message payload.Payload) { - if err := p.reqLease.allow(); err != nil { - logger.Warnf("request FireAndForget failed: %v\n", err) - } - p.socket.FireAndForget(message) -} - -func (p *baseSocket) MetadataPush(message payload.Payload) { - p.socket.MetadataPush(message) -} - -func (p *baseSocket) RequestResponse(message payload.Payload) mono.Mono { - if err := p.reqLease.allow(); err != nil { - return mono.Error(err) - } - return p.socket.RequestResponse(message) -} - -func (p *baseSocket) RequestStream(message payload.Payload) flux.Flux { - if err := p.reqLease.allow(); err != nil { - return flux.Error(err) - } - return p.socket.RequestStream(message) -} - -func (p *baseSocket) RequestChannel(messages rx.Publisher) flux.Flux { - if err := p.reqLease.allow(); err != nil { - return flux.Error(err) - } - return p.socket.RequestChannel(messages) -} - -func (p *baseSocket) OnClose(fn func(error)) { - if fn != nil { - p.closers = append(p.closers, fn) - } -} - -func (p *baseSocket) Close() (err error) { - p.once.Do(func() { - err = p.socket.Close() - for i, l := 0, len(p.closers); i < l; i++ { - func(fn func(error)) { - defer func() { - if e := tryRecover(recover()); e != nil { - logger.Errorf("handle socket closer failed: %s\n", e) - } - }() - fn(err) - }(p.closers[l-i-1]) - } - }) - return -} - -func newBaseSocket(rawSocket *DuplexRSocket) *baseSocket { - return &baseSocket{ - socket: rawSocket, - } -} diff --git a/internal/socket/abstract_socket_test.go b/internal/socket/abstract_socket_test.go new file mode 100644 index 0000000..3896ac2 --- /dev/null +++ b/internal/socket/abstract_socket_test.go @@ -0,0 +1,105 @@ +package socket_test + +import ( + "context" + "testing" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +var emptyAbstractRSocket = &socket.AbstractRSocket{} +var fakeRequest = payload.New(fakeData, fakeMetadata) + +func TestAbstractRSocket_FireAndForget(t *testing.T) { + called := atomic.NewBool(false) + s := &socket.AbstractRSocket{ + FF: func(payload payload.Payload) { + called.CAS(false, true) + }, + } + s.FireAndForget(fakeRequest) + assert.True(t, called.Load()) + + assert.NotPanics(t, func() { + emptyAbstractRSocket.FireAndForget(fakeRequest) + }) +} + +func TestAbstractRSocket_RequestResponse(t *testing.T) { + s := &socket.AbstractRSocket{ + RR: func(p payload.Payload) mono.Mono { + return mono.Just(p) + }, + } + res, err := s.RequestResponse(fakeRequest).Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fakeData, res.Data()) + assert.Equal(t, fakeMetadata, extractMetadata(res)) + + _, err = emptyAbstractRSocket.RequestResponse(fakeRequest).Block(context.Background()) + assert.Error(t, err, "should return an error") +} + +func TestAbstractRSocket_MetadataPush(t *testing.T) { + called := atomic.NewBool(false) + s := &socket.AbstractRSocket{ + MP: func(p payload.Payload) { + called.CAS(false, true) + }, + } + + assert.NotPanics(t, func() { + s.MetadataPush(fakeRequest) + }) + assert.NotPanics(t, func() { + emptyAbstractRSocket.MetadataPush(fakeRequest) + }) +} + +func TestAbstractRSocket_RequestStream(t *testing.T) { + s := &socket.AbstractRSocket{ + RS: func(p payload.Payload) flux.Flux { + return flux.Just(p) + }, + } + + var res []payload.Payload + + _, err := s.RequestStream(fakeRequest). + DoOnNext(func(input payload.Payload) { + res = append(res, input) + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, fakeRequest, res[0]) + + _, err = emptyAbstractRSocket.RequestStream(fakeRequest).BlockLast(context.Background()) + assert.Error(t, err, "should return an error") +} + +func TestAbstractRSocket_RequestChannel(t *testing.T) { + s := &socket.AbstractRSocket{ + RC: func(publisher rx.Publisher) flux.Flux { + return flux.Clone(publisher) + }, + } + var res []payload.Payload + _, err := s.RequestChannel(flux.Just(fakeRequest)). + DoOnNext(func(input payload.Payload) { + res = append(res, input) + }). + BlockLast(context.Background()) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, fakeRequest, res[0]) + + _, err = emptyAbstractRSocket.RequestChannel(flux.Just(fakeRequest)).BlockFirst(context.Background()) + assert.Error(t, err, "should return an error") +} diff --git a/internal/socket/base_socket.go b/internal/socket/base_socket.go new file mode 100644 index 0000000..126a989 --- /dev/null +++ b/internal/socket/base_socket.go @@ -0,0 +1,89 @@ +package socket + +import ( + "sync" + "time" + + "github.com/rsocket/rsocket-go/logger" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/rsocket/rsocket-go/rx/mono" +) + +type BaseSocket struct { + socket *DuplexConnection + closers []func(error) + once sync.Once + reqLease *leaser +} + +func (p *BaseSocket) refreshLease(ttl time.Duration, n int64) { + deadline := time.Now().Add(ttl) + if p.reqLease == nil { + p.reqLease = newLeaser(deadline, n) + } else { + p.reqLease.refresh(deadline, n) + } +} + +func (p *BaseSocket) FireAndForget(message payload.Payload) { + if err := p.reqLease.allow(); err != nil { + logger.Warnf("request FireAndForget failed: %v\n", err) + } + p.socket.FireAndForget(message) +} + +func (p *BaseSocket) MetadataPush(message payload.Payload) { + p.socket.MetadataPush(message) +} + +func (p *BaseSocket) RequestResponse(message payload.Payload) mono.Mono { + if err := p.reqLease.allow(); err != nil { + return mono.Error(err) + } + return p.socket.RequestResponse(message) +} + +func (p *BaseSocket) RequestStream(message payload.Payload) flux.Flux { + if err := p.reqLease.allow(); err != nil { + return flux.Error(err) + } + return p.socket.RequestStream(message) +} + +func (p *BaseSocket) RequestChannel(messages rx.Publisher) flux.Flux { + if err := p.reqLease.allow(); err != nil { + return flux.Error(err) + } + return p.socket.RequestChannel(messages) +} + +func (p *BaseSocket) OnClose(fn func(error)) { + if fn != nil { + p.closers = append(p.closers, fn) + } +} + +func (p *BaseSocket) Close() (err error) { + p.once.Do(func() { + err = p.socket.Close() + for i, l := 0, len(p.closers); i < l; i++ { + func(fn func(error)) { + defer func() { + if e := tryRecover(recover()); e != nil { + logger.Errorf("handle socket closer failed: %s\n", e) + } + }() + fn(err) + }(p.closers[l-i-1]) + } + }) + return +} + +func NewBaseSocket(rawSocket *DuplexConnection) *BaseSocket { + return &BaseSocket{ + socket: rawSocket, + } +} diff --git a/internal/socket/base_socket_test.go b/internal/socket/base_socket_test.go new file mode 100644 index 0000000..7be55f1 --- /dev/null +++ b/internal/socket/base_socket_test.go @@ -0,0 +1,60 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/rx/flux" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestBaseSocket(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + conn.EXPECT().Close().Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().Return(nil, io.EOF).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + duplex := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + duplex.SetTransport(tp) + + go func() { + _ = duplex.LoopWrite(context.Background()) + }() + + s := socket.NewBaseSocket(duplex) + + onClosedCalled := atomic.NewBool(false) + + s.OnClose(func(err error) { + onClosedCalled.CAS(false, true) + }) + + done := make(chan struct{}) + go func() { + defer close(done) + _ = tp.Start(context.Background()) + }() + assert.NotPanics(t, func() { + s.MetadataPush(fakeRequest) + s.FireAndForget(fakeRequest) + s.RequestResponse(fakeRequest) + s.RequestStream(fakeRequest) + s.RequestChannel(flux.Just(fakeRequest)) + }) + + <-done + + _ = s.Close() + + assert.Equal(t, true, onClosedCalled.Load()) +} diff --git a/internal/socket/callback.go b/internal/socket/callback.go index 6277cd2..c97aabb 100644 --- a/internal/socket/callback.go +++ b/internal/socket/callback.go @@ -1,7 +1,7 @@ package socket import ( - rs "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 749b0c4..e895144 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -36,12 +36,12 @@ func IsSocketClosedError(err error) bool { return err == errSocketClosed } -// DuplexRSocket represents a socket of RSocket which can be a requester or a responder. -type DuplexRSocket struct { +// DuplexConnection represents a socket of RSocket which can be a requester or a responder. +type DuplexConnection struct { counter *core.Counter tp *transport.Transport - outs chan core.FrameSupport - outsPriority []core.FrameSupport + outs chan core.WriteableFrame + outsPriority []core.WriteableFrame responder Responder messages *sync.Map sids StreamID @@ -57,11 +57,16 @@ type DuplexRSocket struct { } // SetError sets error for current socket. -func (p *DuplexRSocket) SetError(e error) { +func (p *DuplexConnection) SetError(e error) { p.e = e } -func (p *DuplexRSocket) nextStreamID() (sid uint32) { +// GetError get the error set. +func (p *DuplexConnection) GetError() error { + return p.e +} + +func (p *DuplexConnection) nextStreamID() (sid uint32) { var lap1st bool for { // There's no required to check StreamID conflicts. @@ -77,7 +82,7 @@ func (p *DuplexRSocket) nextStreamID() (sid uint32) { } // Close close current socket. -func (p *DuplexRSocket) Close() error { +func (p *DuplexConnection) Close() error { if !p.closed.CAS(false, true) { return nil } @@ -117,7 +122,7 @@ func (p *DuplexRSocket) Close() error { } // FireAndForget start a request of FireAndForget. -func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { +func (p *DuplexConnection) FireAndForget(sending payload.Payload) { data := sending.Data() size := core.FrameHeaderLen + len(sending.Data()) m, ok := sending.Metadata() @@ -126,28 +131,28 @@ func (p *DuplexRSocket) FireAndForget(sending payload.Payload) { } sid := p.nextStreamID() if !p.shouldSplit(size) { - p.sendFrame(framing.NewFireAndForgetFrameSupport(sid, data, m, 0)) + p.sendFrame(framing.NewWriteableFireAndForgetFrame(sid, data, m, 0)) return } p.doSplit(data, m, func(index int, result fragmentation.SplitResult) { - var f core.FrameSupport + var f core.WriteableFrame if index == 0 { - f = framing.NewFireAndForgetFrameSupport(sid, result.Data, result.Metadata, result.Flag) + f = framing.NewWriteableFireAndForgetFrame(sid, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) } // MetadataPush start a request of MetadataPush. -func (p *DuplexRSocket) MetadataPush(payload payload.Payload) { +func (p *DuplexConnection) MetadataPush(payload payload.Payload) { metadata, _ := payload.Metadata() - p.sendFrame(framing.NewMetadataPushFrameSupport(metadata)) + p.sendFrame(framing.NewWriteableMetadataPushFrame(metadata)) } // RequestResponse start a request of RequestResponse. -func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { +func (p *DuplexConnection) RequestResponse(pl payload.Payload) (mo mono.Mono) { sid := p.nextStreamID() resp := mono.CreateProcessor() @@ -158,7 +163,7 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { mo = resp. DoFinally(func(s rx.SignalType) { if s == rx.SignalCancel { - p.sendFrame(framing.NewCancelFrameSupport(sid)) + p.sendFrame(framing.NewWriteableCancelFrame(sid)) } p.unregister(sid) }) @@ -167,15 +172,15 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { // sending... size := framing.CalcPayloadFrameSize(data, metadata) if !p.shouldSplit(size) { - p.sendFrame(framing.NewRequestResponseFrameSupport(sid, data, metadata, 0)) + p.sendFrame(framing.NewWriteableRequestResponseFrame(sid, data, metadata, 0)) return } p.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { - var f core.FrameSupport + var f core.WriteableFrame if index == 0 { - f = framing.NewRequestResponseFrameSupport(sid, result.Data, result.Metadata, result.Flag) + f = framing.NewWriteableRequestResponseFrame(sid, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -184,7 +189,7 @@ func (p *DuplexRSocket) RequestResponse(pl payload.Payload) (mo mono.Mono) { } // RequestStream start a request of RequestStream. -func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { +func (p *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux) { sid := p.nextStreamID() pc := flux.CreateProcessor() @@ -195,12 +200,12 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { ret = pc. DoFinally(func(sig rx.SignalType) { if sig == rx.SignalCancel { - p.sendFrame(framing.NewCancelFrameSupport(sid)) + p.sendFrame(framing.NewWriteableCancelFrame(sid)) } p.unregister(sid) }). DoOnRequest(func(n int) { - n32 := toU32N(n) + n32 := ToUint32RequestN(n) var newborn bool select { @@ -211,7 +216,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { } if !newborn { - frameN := framing.NewRequestNFrameSupport(sid, n32, 0) + frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) p.sendFrame(frameN) <-frameN.DoneNotify() return @@ -222,16 +227,16 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { size := framing.CalcPayloadFrameSize(data, metadata) + 4 if !p.shouldSplit(size) { - p.sendFrame(framing.NewRequestStreamFrameSupport(sid, n32, data, metadata, 0)) + p.sendFrame(framing.NewWriteableRequestStreamFrame(sid, n32, data, metadata, 0)) return } p.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { - var f core.FrameSupport + var f core.WriteableFrame if index == 0 { - f = framing.NewRequestStreamFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag) + f = framing.NewWriteableRequestStreamFrame(sid, n32, result.Data, result.Metadata, result.Flag) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -240,7 +245,7 @@ func (p *DuplexRSocket) RequestStream(sending payload.Payload) (ret flux.Flux) { } // RequestChannel start a request of RequestChannel. -func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { +func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { sid := p.nextStreamID() sending := publisher.(flux.Flux) @@ -253,7 +258,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { p.unregister(sid) }). DoOnRequest(func(n int) { - n32 := toU32N(n) + n32 := ToUint32RequestN(n) var newborn bool select { case <-rcvRequested: @@ -262,7 +267,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { close(rcvRequested) } if !newborn { - frameN := framing.NewRequestNFrameSupport(sid, n32, 0) + frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) p.sendFrame(frameN) <-frameN.DoneNotify() return @@ -288,16 +293,16 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { size := framing.CalcPayloadFrameSize(d, m) + 4 if !p.shouldSplit(size) { metadata, _ := item.Metadata() - p.sendFrame(framing.NewRequestChannelFrameSupport(sid, n32, item.Data(), metadata, core.FlagNext)) + p.sendFrame(framing.NewWriteableRequestChannelFrame(sid, n32, item.Data(), metadata, core.FlagNext)) return } p.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { - var f core.FrameSupport + var f core.WriteableFrame if index == 0 { - f = framing.NewRequestChannelFrameSupport(sid, n32, result.Data, result.Metadata, result.Flag|core.FlagNext) + f = framing.NewWriteableRequestChannelFrame(sid, n32, result.Data, result.Metadata, result.Flag|core.FlagNext) } else { - f = framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } p.sendFrame(f) }) @@ -325,7 +330,7 @@ func (p *DuplexRSocket) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { return ret } -func (p *DuplexRSocket) onFrameRequestResponse(frame core.Frame) error { +func (p *DuplexConnection) onFrameRequestResponse(frame core.Frame) error { // fragment receiving, ok := p.doFragment(frame.(*framing.RequestResponseFrame)) if !ok { @@ -334,7 +339,7 @@ func (p *DuplexRSocket) onFrameRequestResponse(frame core.Frame) error { return p.respondRequestResponse(receiving) } -func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { +func (p *DuplexConnection) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // 1. execute socket handler @@ -352,7 +357,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd } // 3. sending error with unsupported handler if sending == nil { - p.writeError(sid, framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) + p.writeError(sid, framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) return nil } @@ -378,7 +383,7 @@ func (p *DuplexRSocket) respondRequestResponse(receiving fragmentation.HeaderAnd return nil } -func (p *DuplexRSocket) onFrameRequestChannel(input core.Frame) error { +func (p *DuplexConnection) onFrameRequestChannel(input core.Frame) error { receiving, ok := p.doFragment(input.(*framing.RequestChannelFrame)) if !ok { return nil @@ -386,14 +391,14 @@ func (p *DuplexRSocket) onFrameRequestChannel(input core.Frame) error { return p.respondRequestChannel(receiving) } -func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { +func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { // seek initRequestN var initRequestN int switch v := pl.(type) { case *framing.RequestChannelFrame: - initRequestN = toIntN(v.InitialRequestN()) + initRequestN = ToIntRequestN(v.InitialRequestN()) case fragmentation.Joiner: - initRequestN = toIntN(v.First().(*framing.RequestChannelFrame).InitialRequestN()) + initRequestN = ToIntRequestN(v.First().(*framing.RequestChannelFrame).InitialRequestN()) default: panic("unreachable") } @@ -417,7 +422,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) } }). DoOnRequest(func(n int) { - frameN := framing.NewRequestNFrameSupport(sid, toU32N(n), 0) + frameN := framing.NewWriteableRequestNFrame(sid, ToUint32RequestN(n), 0) p.sendFrame(frameN) <-frameN.DoneNotify() }) @@ -433,7 +438,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) }() flux = p.responder.RequestChannel(receiving) if flux == nil { - err = framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) + err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) } return }() @@ -485,7 +490,7 @@ func (p *DuplexRSocket) respondRequestChannel(pl fragmentation.HeaderAndPayload) return nil } -func (p *DuplexRSocket) respondMetadataPush(input core.Frame) (err error) { +func (p *DuplexConnection) respondMetadataPush(input core.Frame) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond METADATA_PUSH failed: %s\n", e) @@ -495,7 +500,7 @@ func (p *DuplexRSocket) respondMetadataPush(input core.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameFNF(frame core.Frame) error { +func (p *DuplexConnection) onFrameFNF(frame core.Frame) error { receiving, ok := p.doFragment(frame.(*framing.FireAndForgetFrame)) if !ok { return nil @@ -503,7 +508,7 @@ func (p *DuplexRSocket) onFrameFNF(frame core.Frame) error { return p.respondFNF(receiving) } -func (p *DuplexRSocket) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { +func (p *DuplexConnection) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond FireAndForget failed: %s\n", e) @@ -513,7 +518,7 @@ func (p *DuplexRSocket) respondFNF(receiving fragmentation.HeaderAndPayload) (er return } -func (p *DuplexRSocket) onFrameRequestStream(frame core.Frame) error { +func (p *DuplexConnection) onFrameRequestStream(frame core.Frame) error { receiving, ok := p.doFragment(frame.(*framing.RequestStreamFrame)) if !ok { return nil @@ -521,7 +526,7 @@ func (p *DuplexRSocket) onFrameRequestStream(frame core.Frame) error { return p.respondRequestStream(receiving) } -func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { +func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // execute request stream handler @@ -531,7 +536,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa }() resp = p.responder.RequestStream(receiving) if resp == nil { - err = framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, unsupportedRequestStream) + err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestStream) } return }() @@ -579,7 +584,7 @@ func (p *DuplexRSocket) respondRequestStream(receiving fragmentation.HeaderAndPa return nil } -func (p *DuplexRSocket) writeError(sid uint32, e error) { +func (p *DuplexConnection) writeError(sid uint32, e error) { // ignore sending error because current socket has been closed. if IsSocketClosedError(e) { return @@ -588,18 +593,18 @@ func (p *DuplexRSocket) writeError(sid uint32, e error) { case *framing.ErrorFrame: p.sendFrame(err) case core.CustomError: - p.sendFrame(framing.NewErrorFrameSupport(sid, err.ErrorCode(), err.ErrorData())) + p.sendFrame(framing.NewWriteableErrorFrame(sid, err.ErrorCode(), err.ErrorData())) default: - p.sendFrame(framing.NewErrorFrameSupport(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) + p.sendFrame(framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) } } // SetResponder sets a responder for current socket. -func (p *DuplexRSocket) SetResponder(responder Responder) { +func (p *DuplexConnection) SetResponder(responder Responder) { p.responder = responder } -func (p *DuplexRSocket) onFrameKeepalive(frame core.Frame) (err error) { +func (p *DuplexConnection) onFrameKeepalive(frame core.Frame) (err error) { f := frame.(*framing.KeepaliveFrame) if f.Header().Flag().Check(core.FlagRespond) { k := framing.NewKeepaliveFrame(f.LastReceivedPosition(), f.Data(), false) @@ -609,7 +614,7 @@ func (p *DuplexRSocket) onFrameKeepalive(frame core.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameCancel(frame core.Frame) (err error) { +func (p *DuplexConnection) onFrameCancel(frame core.Frame) (err error) { sid := frame.Header().StreamID() v, ok := p.messages.Load(sid) @@ -633,7 +638,7 @@ func (p *DuplexRSocket) onFrameCancel(frame core.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameError(input core.Frame) (err error) { +func (p *DuplexConnection) onFrameError(input core.Frame) (err error) { f := input.(*framing.ErrorFrame) logger.Errorf("handle error frame: %s\n", f) sid := f.Header().StreamID() @@ -657,7 +662,7 @@ func (p *DuplexRSocket) onFrameError(input core.Frame) (err error) { return } -func (p *DuplexRSocket) onFrameRequestN(input core.Frame) (err error) { +func (p *DuplexConnection) onFrameRequestN(input core.Frame) (err error) { f := input.(*framing.RequestNFrame) sid := f.Header().StreamID() v, ok := p.messages.Load(sid) @@ -667,7 +672,7 @@ func (p *DuplexRSocket) onFrameRequestN(input core.Frame) (err error) { } return } - n := toIntN(f.N()) + n := ToIntRequestN(f.N()) switch vv := v.(type) { case requestStreamCallbackReverse: vv.su.Request(n) @@ -681,7 +686,7 @@ func (p *DuplexRSocket) onFrameRequestN(input core.Frame) (err error) { return } -func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { +func (p *DuplexConnection) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { h := input.Header() sid := h.StreamID() v, exist := p.fragments.Load(sid) @@ -703,7 +708,7 @@ func (p *DuplexRSocket) doFragment(input fragmentation.HeaderAndPayload) (out fr return } -func (p *DuplexRSocket) onFramePayload(frame core.Frame) error { +func (p *DuplexConnection) onFramePayload(frame core.Frame) error { pl, ok := p.doFragment(frame.(*framing.PayloadFrame)) if !ok { return nil @@ -767,14 +772,14 @@ func (p *DuplexRSocket) onFramePayload(frame core.Frame) error { return nil } -func (p *DuplexRSocket) clearTransport() { +func (p *DuplexConnection) clearTransport() { p.cond.L.Lock() p.tp = nil p.cond.L.Unlock() } // SetTransport sets a transport for current socket. -func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { +func (p *DuplexConnection) SetTransport(tp *transport.Transport) { tp.RegisterHandler(transport.OnCancel, p.onFrameCancel) tp.RegisterHandler(transport.OnError, p.onFrameError) tp.RegisterHandler(transport.OnRequestN, p.onFrameRequestN) @@ -795,7 +800,7 @@ func (p *DuplexRSocket) SetTransport(tp *transport.Transport) { p.cond.L.Unlock() } -func (p *DuplexRSocket) sendFrame(f core.FrameSupport) { +func (p *DuplexConnection) sendFrame(f core.WriteableFrame) { defer func() { if e := recover(); e != nil { logger.Warnf("send frame failed: %s\n", e) @@ -804,7 +809,7 @@ func (p *DuplexRSocket) sendFrame(f core.FrameSupport) { p.outs <- f } -func (p *DuplexRSocket) sendPayload( +func (p *DuplexConnection) sendPayload( sid uint32, sending payload.Payload, frameFlag core.FrameFlag, @@ -814,7 +819,7 @@ func (p *DuplexRSocket) sendPayload( size := framing.CalcPayloadFrameSize(d, m) if !p.shouldSplit(size) { - p.sendFrame(framing.NewPayloadFrameSupport(sid, d, m, frameFlag)) + p.sendFrame(framing.NewWriteablePayloadFrame(sid, d, m, frameFlag)) return } p.doSplit(d, m, func(index int, result fragmentation.SplitResult) { @@ -824,15 +829,15 @@ func (p *DuplexRSocket) sendPayload( } else { flag |= core.FlagNext } - p.sendFrame(framing.NewPayloadFrameSupport(sid, result.Data, result.Metadata, flag)) + p.sendFrame(framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, flag)) }) } -func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { +func (p *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { if len(p.outs) > 0 { p.drain(nil) } - var out core.FrameSupport + var out core.WriteableFrame select { case <-p.keepaliver.C(): ok = true @@ -848,7 +853,7 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) if !ok { return } - out = framing.NewLeaseFrameSupport(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) + out = framing.NewWriteableLeaseFrame(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) if p.tp == nil { p.outsPriority = append(p.outsPriority, out) } else if err := p.tp.Send(out, true); err != nil { @@ -869,11 +874,11 @@ func (p *DuplexRSocket) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) return } -func (p *DuplexRSocket) drainWithKeepalive() (ok bool) { +func (p *DuplexConnection) drainWithKeepalive() (ok bool) { if len(p.outs) > 0 { p.drain(nil) } - var out core.FrameSupport + var out core.WriteableFrame select { case <-p.keepaliver.C(): @@ -899,7 +904,7 @@ func (p *DuplexRSocket) drainWithKeepalive() (ok bool) { return } -func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { +func (p *DuplexConnection) drain(leaseChan <-chan lease.Lease) bool { var flush bool cycle := len(p.outs) if cycle < 1 { @@ -911,7 +916,7 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { if !ok { return false } - if p.drainOne(framing.NewLeaseFrameSupport(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { + if p.drainOne(framing.NewWriteableLeaseFrame(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { flush = true } case out, ok := <-p.outs: @@ -931,7 +936,7 @@ func (p *DuplexRSocket) drain(leaseChan <-chan lease.Lease) bool { return true } -func (p *DuplexRSocket) drainOne(out core.FrameSupport) (wrote bool) { +func (p *DuplexConnection) drainOne(out core.WriteableFrame) (wrote bool) { if p.tp == nil { p.outsPriority = append(p.outsPriority, out) return @@ -946,7 +951,7 @@ func (p *DuplexRSocket) drainOne(out core.FrameSupport) (wrote bool) { return } -func (p *DuplexRSocket) drainOutBack() { +func (p *DuplexConnection) drainOutBack() { if len(p.outsPriority) < 1 { return } @@ -956,7 +961,7 @@ func (p *DuplexRSocket) drainOutBack() { if p.tp == nil { return } - var out core.FrameSupport + var out core.WriteableFrame for i := range p.outsPriority { out = p.outsPriority[i] if err := p.tp.Send(out, false); err != nil { @@ -969,7 +974,7 @@ func (p *DuplexRSocket) drainOutBack() { } } -func (p *DuplexRSocket) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { +func (p *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { for { if p.tp == nil { p.cond.L.Lock() @@ -1008,11 +1013,11 @@ func (p *DuplexRSocket) loopWriteWithKeepaliver(ctx context.Context, leaseChan < return nil } -func (p *DuplexRSocket) cleanOuts() { +func (p *DuplexConnection) cleanOuts() { p.outsPriority = nil } -func (p *DuplexRSocket) loopWrite(ctx context.Context) error { +func (p *DuplexConnection) LoopWrite(ctx context.Context) error { defer close(p.done) var leaseChan chan lease.Lease @@ -1052,33 +1057,33 @@ func (p *DuplexRSocket) loopWrite(ctx context.Context) error { return nil } -func (p *DuplexRSocket) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { +func (p *DuplexConnection) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { fragmentation.Split(p.mtu, data, metadata, handler) } -func (p *DuplexRSocket) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { +func (p *DuplexConnection) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { fragmentation.SplitSkip(p.mtu, skip, data, metadata, handler) } -func (p *DuplexRSocket) shouldSplit(size int) bool { +func (p *DuplexConnection) shouldSplit(size int) bool { return size > p.mtu } -func (p *DuplexRSocket) register(sid uint32, msg interface{}) { +func (p *DuplexConnection) register(sid uint32, msg interface{}) { p.messages.Store(sid, msg) } -func (p *DuplexRSocket) unregister(sid uint32) { +func (p *DuplexConnection) unregister(sid uint32) { p.messages.Delete(sid) p.fragments.Delete(sid) } -// NewServerDuplexRSocket creates a new server-side DuplexRSocket. -func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { - return &DuplexRSocket{ +// NewServerDuplexConnection creates a new server-side DuplexConnection. +func NewServerDuplexConnection(mtu int, leases lease.Leases) *DuplexConnection { + return &DuplexConnection{ closed: atomic.NewBool(false), leases: leases, - outs: make(chan core.FrameSupport, _outChanSize), + outs: make(chan core.WriteableFrame, _outChanSize), mtu: mtu, messages: &sync.Map{}, sids: &serverStreamIDs{}, @@ -1090,15 +1095,15 @@ func NewServerDuplexRSocket(mtu int, leases lease.Leases) *DuplexRSocket { } } -// NewClientDuplexRSocket creates a new client-side DuplexRSocket. -func NewClientDuplexRSocket( +// NewClientDuplexConnection creates a new client-side DuplexConnection. +func NewClientDuplexConnection( mtu int, keepaliveInterval time.Duration, -) (s *DuplexRSocket) { +) (s *DuplexConnection) { ka := NewKeepaliver(keepaliveInterval) - s = &DuplexRSocket{ + s = &DuplexConnection{ closed: atomic.NewBool(false), - outs: make(chan core.FrameSupport, _outChanSize), + outs: make(chan core.WriteableFrame, _outChanSize), mtu: mtu, messages: &sync.Map{}, sids: &clientStreamIDs{}, diff --git a/internal/socket/misc.go b/internal/socket/misc.go index c7a7311..d387396 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -22,8 +22,8 @@ type SetupInfo struct { Metadata []byte } -func (p *SetupInfo) toFrame() core.FrameSupport { - return framing.NewSetupFrameSupport( +func (p *SetupInfo) toFrame() core.WriteableFrame { + return framing.NewWriteableSetupFrame( p.Version, p.KeepaliveInterval, p.KeepaliveLifetime, @@ -51,14 +51,17 @@ func tryRecover(e interface{}) (err error) { return } -func toIntN(n uint32) int { +func ToIntRequestN(n uint32) int { if n > rx.RequestMax { return rx.RequestMax } return int(n) } -func toU32N(n int) uint32 { +func ToUint32RequestN(n int) uint32 { + if n < 0 { + panic("invalid negative int") + } if n > rx.RequestMax { return rx.RequestMax } diff --git a/internal/socket/misc_test.go b/internal/socket/misc_test.go index 69513dc..f448896 100644 --- a/internal/socket/misc_test.go +++ b/internal/socket/misc_test.go @@ -1 +1,23 @@ -package socket +package socket_test + +import ( + "math" + "testing" + + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/rx" + "github.com/stretchr/testify/assert" +) + +func TestToUint32RequestN(t *testing.T) { + assert.Equal(t, uint32(1), socket.ToUint32RequestN(1)) + assert.Panics(t, func() { + socket.ToUint32RequestN(-1) + }, "should panic") + assert.Equal(t, uint32(rx.RequestMax), socket.ToUint32RequestN(math.MaxInt64)) +} + +func TestToIntRequestN(t *testing.T) { + assert.Equal(t, 1, socket.ToIntRequestN(1)) + assert.Equal(t, rx.RequestMax, socket.ToIntRequestN(math.MaxUint32)) +} diff --git a/internal/socket/mock_conn_test.go b/internal/socket/mock_conn_test.go new file mode 100644 index 0000000..204109d --- /dev/null +++ b/internal/socket/mock_conn_test.go @@ -0,0 +1,119 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: core/transport/types.go + +// Package socket_test is a generated GoMock package. +package socket_test + +import ( + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + core "github.com/rsocket/rsocket-go/core" +) + +// MockConn is a mock of Conn interface +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder +} + +// MockConnMockRecorder is the mock recorder for MockConn +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// SetDeadline mocks base method +func (m *MockConn) SetDeadline(deadline time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", deadline) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), deadline) +} + +// SetCounter mocks base method +func (m *MockConn) SetCounter(c *core.Counter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetCounter", c) +} + +// SetCounter indicates an expected call of SetCounter +func (mr *MockConnMockRecorder) SetCounter(c interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCounter", reflect.TypeOf((*MockConn)(nil).SetCounter), c) +} + +// Read mocks base method +func (m *MockConn) Read() (core.Frame, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read") + ret0, _ := ret[0].(core.Frame) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockConnMockRecorder) Read() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read)) +} + +// Write mocks base method +func (m *MockConn) Write(arg0 core.WriteableFrame) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Write indicates an expected call of Write +func (mr *MockConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), arg0) +} + +// Flush mocks base method +func (m *MockConn) Flush() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Flush") + ret0, _ := ret[0].(error) + return ret0 +} + +// Flush indicates an expected call of Flush +func (mr *MockConnMockRecorder) Flush() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockConn)(nil).Flush)) +} diff --git a/internal/socket/client_resume.go b/internal/socket/resumable_client_socket.go similarity index 94% rename from internal/socket/client_resume.go rename to internal/socket/resumable_client_socket.go index 8e94104..8dc7fb4 100644 --- a/internal/socket/client_resume.go +++ b/internal/socket/resumable_client_socket.go @@ -16,7 +16,7 @@ import ( const reconnectDelay = 1 * time.Second type resumeClientSocket struct { - *baseSocket + *BaseSocket connects *atomic.Int32 setup *SetupInfo tp transport.ClientTransportFunc @@ -25,7 +25,7 @@ type resumeClientSocket struct { func (p *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { p.setup = setup go func(ctx context.Context) { - _ = p.socket.loopWrite(ctx) + _ = p.socket.LoopWrite(ctx) }(ctx) return p.connect(ctx) } @@ -75,7 +75,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { } }(ctx, tp) - var f core.FrameSupport + var f core.WriteableFrame // connect first time. if len(p.setup.Token) < 1 || connects == 1 { @@ -90,7 +90,7 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { return } - f = framing.NewResumeFrameSupport( + f = framing.NewWriteableResumeFrame( core.DefaultVersion, p.setup.Token, p.socket.counter.WriteBytes(), @@ -145,9 +145,9 @@ func (p *resumeClientSocket) isClosed() bool { } // NewClientResume creates a client-side socket with resume support. -func NewClientResume(tp transport.ClientTransportFunc, socket *DuplexRSocket) ClientSocket { +func NewClientResume(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { return &resumeClientSocket{ - baseSocket: newBaseSocket(socket), + BaseSocket: NewBaseSocket(socket), connects: atomic.NewInt32(0), tp: tp, } diff --git a/internal/socket/server_resume.go b/internal/socket/resumable_server_socket.go similarity index 82% rename from internal/socket/server_resume.go rename to internal/socket/resumable_server_socket.go index 9f876c7..e9a580c 100644 --- a/internal/socket/server_resume.go +++ b/internal/socket/resumable_server_socket.go @@ -7,7 +7,7 @@ import ( ) type resumeServerSocket struct { - *baseSocket + *BaseSocket token []byte } @@ -33,13 +33,13 @@ func (p *resumeServerSocket) Start(ctx context.Context) error { defer func() { _ = p.Close() }() - return p.socket.loopWrite(ctx) + return p.socket.LoopWrite(ctx) } // NewServerResume creates a new server-side socket with resume support. -func NewServerResume(socket *DuplexRSocket, token []byte) ServerSocket { +func NewServerResume(socket *DuplexConnection, token []byte) ServerSocket { return &resumeServerSocket{ - baseSocket: newBaseSocket(socket), + BaseSocket: NewBaseSocket(socket), token: token, } } diff --git a/internal/socket/server_default.go b/internal/socket/server_default.go deleted file mode 100644 index 84cbf7b..0000000 --- a/internal/socket/server_default.go +++ /dev/null @@ -1,41 +0,0 @@ -package socket - -import ( - "context" - - "github.com/rsocket/rsocket-go/core/transport" -) - -type serverSocket struct { - *baseSocket -} - -func (p *serverSocket) Pause() bool { - return false -} - -func (p *serverSocket) SetResponder(responder Responder) { - p.socket.SetResponder(responder) -} - -func (p *serverSocket) SetTransport(tp *transport.Transport) { - p.socket.SetTransport(tp) -} - -func (p *serverSocket) Token() (token []byte, ok bool) { - return -} - -func (p *serverSocket) Start(ctx context.Context) error { - defer func() { - _ = p.Close() - }() - return p.socket.loopWrite(ctx) -} - -// NewServer creates a new server-side socket. -func NewServer(socket *DuplexRSocket) ServerSocket { - return &serverSocket{ - baseSocket: newBaseSocket(socket), - } -} diff --git a/internal/socket/client_default.go b/internal/socket/simple_client_socket.go similarity index 73% rename from internal/socket/client_default.go rename to internal/socket/simple_client_socket.go index 84647a9..40d0b91 100644 --- a/internal/socket/client_default.go +++ b/internal/socket/simple_client_socket.go @@ -9,12 +9,12 @@ import ( "github.com/rsocket/rsocket-go/logger" ) -type defaultClientSocket struct { - *baseSocket +type simpleClientSocket struct { + *BaseSocket tp transport.ClientTransportFunc } -func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { +func (p *simpleClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err error) { tp, err := p.tp(ctx) if err != nil { return @@ -29,7 +29,6 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err tp.RegisterHandler(transport.OnLease, func(frame core.Frame) (err error) { lease := frame.(*framing.LeaseFrame) p.refreshLease(lease.TimeToLive(), int64(lease.NumberOfRequests())) - logger.Infof(">>>>> refresh lease: %v\n", lease) return }) } @@ -46,18 +45,16 @@ func (p *defaultClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err _ = p.Close() }(ctx, tp) - go func(ctx context.Context) { - _ = p.socket.loopWrite(ctx) - }(ctx) + go p.socket.LoopWrite(ctx) setupFrame := setup.toFrame() err = p.socket.tp.Send(setupFrame, true) return } // NewClient create a simple client-side socket. -func NewClient(tp transport.ClientTransportFunc, socket *DuplexRSocket) ClientSocket { - return &defaultClientSocket{ - baseSocket: newBaseSocket(socket), +func NewClient(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { + return &simpleClientSocket{ + BaseSocket: NewBaseSocket(socket), tp: tp, } } diff --git a/internal/socket/simple_client_socket_test.go b/internal/socket/simple_client_socket_test.go new file mode 100644 index 0000000..b26c85b --- /dev/null +++ b/internal/socket/simple_client_socket_test.go @@ -0,0 +1,143 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestNewClientWithBrokenTransporter(t *testing.T) { + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + // Must failed transporter + transporter := func(ctx context.Context) (*transport.Transport, error) { + return nil, fakeErr + } + cli := socket.NewClient(transporter, ds) + err := cli.Setup(context.Background(), fakeSetup) + assert.Equal(t, fakeErr, err, "should be fake error") +} + +func TestNewClient(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + cli := socket.NewClient(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer func() { + err := cli.Close() + assert.NoError(t, err, "close client failed") + }() + + err := cli.Setup(context.Background(), fakeSetup) + assert.NoError(t, err, "setup client failed") + + requestId := atomic.NewUint32(1) + nextRequestId := func() uint32 { + return requestId.Add(2) - 2 + } + + result, err := cli.RequestResponse(payload.New(fakeData, fakeMetadata)). + DoOnSubscribe(func(s rx.Subscription) { + readChan <- framing.NewPayloadFrame(nextRequestId(), fakeData, fakeMetadata, core.FlagComplete) + }). + Block(context.Background()) + assert.NoError(t, err, "request response failed") + assert.Equal(t, fakeData, result.Data(), "response data doesn't match") + assert.Equal(t, fakeMetadata, extractMetadata(result), "response metadata doesn't match") + + var stream []payload.Payload + _, err = cli.RequestStream(payload.New(fakeData, fakeMetadata)). + DoOnNext(func(input payload.Payload) { + stream = append(stream, input) + }). + DoOnSubscribe(func(s rx.Subscription) { + nextId := nextRequestId() + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext|core.FlagComplete) + }). + BlockLast(context.Background()) + assert.NoError(t, err, "request stream failed") + + // When a fatal error occurred, client should be stopped immediately. + fatalErr := []byte("fatal error") + readChan <- framing.NewErrorFrame(0, core.ErrorCodeRejected, fatalErr) + time.Sleep(100 * time.Millisecond) + err = ds.GetError() + assert.Error(t, err, "should get error") + assert.Equal(t, fatalErr, err.(core.CustomError).ErrorData()) +} + +func TestLease(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + cli := socket.NewClient(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer func() { + err := cli.Close() + assert.NoError(t, err, "close client failed") + }() + + setup := *fakeSetup + setup.Lease = true + err := cli.Setup(context.Background(), &setup) + assert.NoError(t, err, "setup client failed") + readChan <- framing.NewLeaseFrame(10*time.Second, 10, fakeMetadata) + time.Sleep(3 * time.Second) +} + +func extractMetadata(p payload.Payload) []byte { + m, _ := p.Metadata() + return m +} diff --git a/internal/socket/simple_server_socket.go b/internal/socket/simple_server_socket.go new file mode 100644 index 0000000..6247b07 --- /dev/null +++ b/internal/socket/simple_server_socket.go @@ -0,0 +1,41 @@ +package socket + +import ( + "context" + + "github.com/rsocket/rsocket-go/core/transport" +) + +type simpleServerSocket struct { + *BaseSocket +} + +func (p *simpleServerSocket) Pause() bool { + return false +} + +func (p *simpleServerSocket) SetResponder(responder Responder) { + p.socket.SetResponder(responder) +} + +func (p *simpleServerSocket) SetTransport(tp *transport.Transport) { + p.socket.SetTransport(tp) +} + +func (p *simpleServerSocket) Token() (token []byte, ok bool) { + return +} + +func (p *simpleServerSocket) Start(ctx context.Context) error { + defer func() { + _ = p.Close() + }() + return p.socket.LoopWrite(ctx) +} + +// NewSimpleServerSocket creates a new server-side socket. +func NewSimpleServerSocket(socket *DuplexConnection) ServerSocket { + return &simpleServerSocket{ + BaseSocket: NewBaseSocket(socket), + } +} diff --git a/internal/socket/simple_server_socket_test.go b/internal/socket/simple_server_socket_test.go new file mode 100644 index 0000000..d49d4cc --- /dev/null +++ b/internal/socket/simple_server_socket_test.go @@ -0,0 +1,81 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +var fakeResponder = rsocket.NewAbstractSocket() + +func TestSimpleServerSocket_Start(t *testing.T) { + ctrl, conn, tp := Init(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + setupFrame := framing.NewSetupFrame( + core.DefaultVersion, + 30*time.Second, + 90*time.Second, + nil, + fakeMimeType, + fakeMimeType, + fakeData, + fakeMetadata, + false, + ) + readChan <- setupFrame + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).AnyTimes() + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + firstFrame, err := tp.ReadFirst(context.Background()) + assert.NoError(t, err, "read first frame failed") + assert.Equal(t, setupFrame, firstFrame, "first should be setup frame") + + close(readChan) + + ds := socket.NewServerDuplexConnection(fragmentation.MaxFragment, nil) + ss := socket.NewSimpleServerSocket(ds) + ss.SetResponder(fakeResponder) + ss.SetTransport(tp) + + assert.Equal(t, false, ss.Pause(), "should always returns false") + token, ok := ss.Token() + assert.False(t, ok) + assert.Nil(t, token, "token should be nil") + + done := make(chan struct{}) + go func() { + defer close(done) + err := ss.Start(context.Background()) + assert.NoError(t, err, "start server socket failed") + }() + + err = tp.Start(context.Background()) + assert.NoError(t, err, "start transport failed") + + _ = ds.Close() + + <-done +} diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go new file mode 100644 index 0000000..29e0551 --- /dev/null +++ b/internal/socket/socket_test.go @@ -0,0 +1,42 @@ +package socket_test + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/logger" +) + +var ( + fakeErr = errors.New("fake error") + + fakeMetadata = []byte("fake-metadata") + fakeData = []byte("fake-data") + fakeMimeType = []byte("fake-mime-type") + + fakeSetup = &socket.SetupInfo{ + Version: core.DefaultVersion, + MetadataMimeType: fakeMimeType, + DataMimeType: fakeMimeType, + Metadata: fakeMetadata, + Data: fakeData, + KeepaliveLifetime: 90 * time.Second, + KeepaliveInterval: 30 * time.Second, + } +) + +func Init(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { + ctrl := gomock.NewController(t) + conn := NewMockConn(ctrl) + tp := transport.NewTransport(conn) + return ctrl, conn, tp +} + +func init() { + logger.SetLevel(logger.LevelError) +} diff --git a/internal/socket/stream_id.go b/internal/socket/stream_id.go index 6ee3018..771e2a8 100644 --- a/internal/socket/stream_id.go +++ b/internal/socket/stream_id.go @@ -9,8 +9,10 @@ const ( _halfSeed uint64 = 0x40000000 ) +// StreamID can be used to generate stream ids. type StreamID interface { - Next() (id uint32, lap1st bool) + // Next returns next stream id. + Next() (id uint32, firstLoop bool) } type serverStreamIDs struct { diff --git a/internal/socket/types.go b/internal/socket/types.go index c0cec5b..6c109c2 100644 --- a/internal/socket/types.go +++ b/internal/socket/types.go @@ -37,7 +37,7 @@ type ClientSocket interface { Closeable Responder // Setup setups current socket. - Setup(ctx context.Context, setup *SetupInfo) (err error) + Setup(ctx context.Context, setup *SetupInfo) error } // ServerSocket represents a server-side socket. diff --git a/justfile b/justfile index f3ec1e7..e0efec4 100644 --- a/justfile +++ b/justfile @@ -3,6 +3,21 @@ default: lint: golangci-lint run ./... test: - go test -race -count=1 . -v + go test -count=1 -coverprofile=coverage.out \ + ./balancer/... \ + ./core/... \ + ./extension/... \ + ./internal/... \ + ./lease/... \ + ./logger/... \ + ./payload/... \ + ./rx/... \ + . +test-no-cover: + go test -count=1 ./... -v +test-race: + go test -race -count=1 ./... -v fmt: @go fmt ./... +cover: + @go tool cover -html=coverage.out diff --git a/lease/lease_test.go b/lease/lease_test.go index 50f54b1..fd2ccf7 100644 --- a/lease/lease_test.go +++ b/lease/lease_test.go @@ -3,79 +3,19 @@ package lease_test import ( "context" "fmt" - "log" "testing" "time" - "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/lease" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx/mono" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/atomic" ) -var _tp rsocket.Transporter - -func init() { - _tp = rsocket.Tcp().HostAndPort("127.0.0.1", 7979).Build() -} - -func Init(ctx context.Context, started chan<- struct{}) { - l, _ := lease.NewSimpleLease(10*time.Second, 7*time.Second, 1*time.Second, 5) - err := rsocket.Receive(). - Lease(l). - OnStart(func() { - close(started) - }). - Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { - return rsocket.NewAbstractSocket( - rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - return mono.Just(msg) - }), - ), nil - }). - Transport(_tp). - Serve(ctx) - if err != nil { - log.Fatal(err) - } -} - -func TestClientWithLease(t *testing.T) { - started := make(chan struct{}) - go Init(context.Background(), started) - <-started - - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - cli, err := rsocket.Connect(). - Lease(). - Transport(_tp). - Start(ctx) - if err != nil { - require.NoError(t, err, "connect failed") - } - defer cli.Close() - - success := atomic.NewUint32(0) - -Loop: - for { - select { - case <-ctx.Done(): - break Loop - default: - time.Sleep(1 * time.Second) - v, err := cli.RequestResponse(payload.NewString("hello world", "go")).Block(context.Background()) - if err != nil { - fmt.Println("request failed:", err) - } else { - success.Inc() - fmt.Println("request success:", v) - } - } - } - assert.Equal(t, uint32(10), success.Load(), "bad requests") +func TestSimpleLease_Next(t *testing.T) { + l, err := lease.NewSimpleLease(3*time.Second, 1*time.Second, 1*time.Second, 1) + assert.NoError(t, err, "create simple lease failed") + lease, ok := l.Next(context.Background()) + assert.True(t, ok, "get next lease chan failed") + next, ok := <-lease + assert.True(t, ok, "get lease failed") + fmt.Println(next) } diff --git a/logger/logger.go b/logger/logger.go index 24d486e..df24833 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -1,30 +1,9 @@ package logger -import ( - "fmt" - "log" -) +import "log" -// Func is alias of logger function. -type Func = func(string, ...interface{}) - -// Level is level of logger. -type Level int8 - -func (s Level) String() string { - switch s { - case LevelDebug: - return "DEBUG" - case LevelInfo: - return "INFO" - case LevelWarn: - return "WARN" - case LevelError: - return "ERROR" - default: - return "UNKNOWN" - } -} +var _level = LevelInfo +var _logger Logger = simpleLogger{} const ( // LevelDebug is DEBUG level. @@ -37,93 +16,90 @@ const ( LevelError ) -var ( - lvl = LevelInfo - i, d, w, e = log.Printf, log.Printf, log.Printf, log.Printf - prefix = true -) +// Logger is used to print logs. +type Logger interface { + // Debugf print to the debug level logs. + Debugf(format string, args ...interface{}) + // Infof print to the info level logs. + Infof(format string, args ...interface{}) + // Warnf print to the info level logs. + Warnf(format string, args ...interface{}) + // Errorf print to the info level logs. + Errorf(format string, args ...interface{}) +} + +// Level is level of logger. +type Level int8 // SetLevel set global RSocket log level. -// Available levels are `LogLevelDebug`, `LogLevelInfo`, `LogLevelWarn` and `LogLevelError`. +// Available levels are `LevelDebug`, `LevelInfo`, `LevelWarn` and `LevelError`. func SetLevel(level Level) { - lvl = level + _level = level } -// DisablePrefix disable print level prefix. -func DisablePrefix() { - prefix = false +// SetLogger customize the global logger. +// A standard log implementation will be used by default. +func SetLogger(logger Logger) { + _logger = logger } // GetLevel returns current logger level. func GetLevel() Level { - return lvl -} - -// SetFunc set logger func for custom level. -func SetFunc(level Level, fn Func) { - if fn == nil { - return - } - if level&LevelDebug != 0 { - d = fn - } - if level&LevelInfo != 0 { - i = fn - } - if level&LevelWarn != 0 { - w = fn - } - if level&LevelError != 0 { - e = fn - } + return _level } // IsDebugEnabled returns true if debug level is open. func IsDebugEnabled() bool { - return lvl <= LevelDebug + return _level <= LevelDebug } // Debugf prints debug level log. -func Debugf(format string, v ...interface{}) { - if lvl > LevelDebug { +func Debugf(format string, args ...interface{}) { + if _logger == nil || _level > LevelDebug { return } - if prefix { - d(fmt.Sprintf("[%s] %s", LevelDebug, format), v...) - } else { - d(format, v...) - } + _logger.Debugf(format, args...) } // Infof prints info level log. -func Infof(format string, v ...interface{}) { - if lvl > LevelInfo { +func Infof(format string, args ...interface{}) { + if _logger == nil || _level > LevelInfo { return } - if prefix { - i(fmt.Sprintf("[%s] %s", LevelInfo, format), v...) - } else { - i(format, v...) - } + _logger.Infof(format, args...) } // Warnf prints warn level log. -func Warnf(format string, v ...interface{}) { - if lvl > LevelWarn { +func Warnf(format string, args ...interface{}) { + if _logger == nil || _level > LevelWarn { return } - if prefix { - w(fmt.Sprintf("[%s] %s", LevelWarn, format), v...) - } else { - w(format, v...) - } + _logger.Warnf(format, args...) } // Errorf prints error level log. -func Errorf(format string, v ...interface{}) { - if prefix { - e(fmt.Sprintf("[%s] %s", LevelError, format), v...) - } else { - e(format, v...) +func Errorf(format string, args ...interface{}) { + if _logger == nil || _level > LevelError { + return } + _logger.Errorf(format, args...) +} + +type simpleLogger struct { +} + +func (s simpleLogger) Debugf(format string, args ...interface{}) { + log.Printf("[DEBUG] "+format, args...) +} + +func (s simpleLogger) Infof(format string, args ...interface{}) { + log.Printf("[INFO] "+format, args...) +} + +func (s simpleLogger) Warnf(format string, args ...interface{}) { + log.Printf("[WARN] "+format, args...) +} + +func (s simpleLogger) Errorf(format string, args ...interface{}) { + log.Printf("[ERROR] "+format, args...) } diff --git a/rsocket_test.go b/rsocket_test.go index 4d34011..2bfe511 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -7,7 +7,6 @@ import ( "testing" . "github.com/rsocket/rsocket-go" - "github.com/rsocket/rsocket-go/logger" . "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" @@ -21,22 +20,6 @@ const ( channelElements = int32(2) ) -func init() { - //logger.SetLevel(logger.LevelDebug) - logger.SetFunc(logger.LevelInfo, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelDebug, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelWarn, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) - logger.SetFunc(logger.LevelError, func(s string, i ...interface{}) { - fmt.Printf(s, i...) - }) -} - var testData = "Hello World!" func TestSuite(t *testing.T) { diff --git a/server.go b/server.go index d5c305a..9a2c68a 100644 --- a/server.go +++ b/server.go @@ -185,7 +185,7 @@ func (p *server) Serve(ctx context.Context) error { } }(ctx, sendingSocket) default: - err := framing.NewErrorFrameSupport(0, core.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) + err := framing.NewWriteableErrorFrame(0, core.ErrorCodeConnectionError, []byte("first frame must be setup or resume")) _ = tp.Send(err, true) _ = tp.Close() return @@ -205,10 +205,10 @@ func (p *server) Serve(ctx context.Context) error { return t.Listen(ctx, serveNotifier) } -func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) (sendingSocket socket.ServerSocket, err *framing.ErrorFrameSupport) { +func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) (sendingSocket socket.ServerSocket, err *framing.WriteableErrorFrame) { if frame.Header().Flag().Check(core.FlagLease) && p.leases == nil { - err = framing.NewErrorFrameSupport(0, core.ErrorCodeUnsupportedSetup, errUnavailableLease) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeUnsupportedSetup, errUnavailableLease) return } @@ -216,17 +216,17 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc // 1. receive a token but server doesn't support resume. if isResume && !p.resumeOpts.enable { - err = framing.NewErrorFrameSupport(0, core.ErrorCodeUnsupportedSetup, errUnavailableResume) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeUnsupportedSetup, errUnavailableResume) return } - rawSocket := socket.NewServerDuplexRSocket(p.fragment, p.leases) + rawSocket := socket.NewServerDuplexConnection(p.fragment, p.leases) // 2. no resume if !isResume { - sendingSocket = socket.NewServer(rawSocket) + sendingSocket = socket.NewSimpleServerSocket(rawSocket) if responder, e := p.acc(frame, sendingSocket); e != nil { - err = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedSetup, []byte(e.Error())) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedSetup, []byte(e.Error())) } else { sendingSocket.SetResponder(responder) sendingSocket.SetTransport(tp) @@ -239,7 +239,7 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc // 3. resume reject because of duplicated token. if _, ok := p.sm.Load(token); ok { - err = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedSetup, errDuplicatedSetupToken) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedSetup, errDuplicatedSetupToken) return } @@ -249,9 +249,9 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc if responder, e := p.acc(frame, sendingSocket); e != nil { switch vv := e.(type) { case *framing.ErrorFrame: - err = framing.NewErrorFrameSupport(0, vv.ErrorCode(), vv.ErrorData()) + err = framing.NewWriteableErrorFrame(0, vv.ErrorCode(), vv.ErrorData()) default: - err = framing.NewErrorFrameSupport(0, core.ErrorCodeInvalidSetup, []byte(e.Error())) + err = framing.NewWriteableErrorFrame(0, core.ErrorCodeInvalidSetup, []byte(e.Error())) } } else { sendingSocket.SetResponder(responder) @@ -262,18 +262,18 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc } func (p *server) doResume(frame *framing.ResumeFrame, tp *transport.Transport, socketChan chan<- socket.ServerSocket) { - var sending core.FrameSupport + var sending core.WriteableFrame if !p.resumeOpts.enable { - sending = framing.NewErrorFrameSupport(0, core.ErrorCodeRejectedResume, errUnavailableResume) + sending = framing.NewWriteableErrorFrame(0, core.ErrorCodeRejectedResume, errUnavailableResume) } else if s, ok := p.sm.Load(frame.Token()); ok { - sending = framing.NewResumeOKFrameSupport(0) + sending = framing.NewWriteableResumeOKFrame(0) s.Socket().SetTransport(tp) socketChan <- s.Socket() if logger.IsDebugEnabled() { logger.Debugf("recover session: %s\n", s) } } else { - sending = framing.NewErrorFrameSupport( + sending = framing.NewWriteableErrorFrame( 0, core.ErrorCodeRejectedResume, []byte("no such session"), diff --git a/transporter.go b/transporter.go index 2562162..1560cb5 100644 --- a/transporter.go +++ b/transporter.go @@ -28,13 +28,13 @@ type TcpTransporterBuilder struct { func (t *tcpTransporter) Server() transport.ServerTransportFunc { return func(ctx context.Context) (transport.ServerTransport, error) { - return transport.NewTcpServerTransport("tcp", t.addr, t.tls), nil + return transport.NewTcpServerTransportWithAddr("tcp", t.addr, t.tls), nil } } func (t *tcpTransporter) Client() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTcpClientTransport("tcp", t.addr, t.tls) + return transport.NewTcpClientTransportWithAddr("tcp", t.addr, t.tls) } } @@ -141,13 +141,13 @@ func (u *UnixTransporter) Server() transport.ServerTransportFunc { if _, err := os.Stat(u.path); !os.IsNotExist(err) { return nil, err } - return transport.NewTcpServerTransport("unix", u.path, nil), nil + return transport.NewTcpServerTransportWithAddr("unix", u.path, nil), nil } } func (u *UnixTransporter) Client() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTcpClientTransport("unix", u.path, nil) + return transport.NewTcpClientTransportWithAddr("unix", u.path, nil) } } From a62b8b901799f01b7c926be9d510526bd6d84d7e Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sun, 26 Jul 2020 23:02:18 +0800 Subject: [PATCH 13/26] fix lint. --- internal/socket/simple_client_socket.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/socket/simple_client_socket.go b/internal/socket/simple_client_socket.go index 40d0b91..e1770f3 100644 --- a/internal/socket/simple_client_socket.go +++ b/internal/socket/simple_client_socket.go @@ -45,7 +45,9 @@ func (p *simpleClientSocket) Setup(ctx context.Context, setup *SetupInfo) (err e _ = p.Close() }(ctx, tp) - go p.socket.LoopWrite(ctx) + go func() { + _ = p.socket.LoopWrite(ctx) + }() setupFrame := setup.toFrame() err = p.socket.tp.Send(setupFrame, true) return From 5f3afcf5cd483e709c4abd65b0e5c2ec4734b442 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 28 Jul 2020 22:24:37 +0800 Subject: [PATCH 14/26] ut. --- cmd/rsocket-cli/uri.go | 4 +- core/framing/frame_test.go | 26 ++++ ...onn_mock_test.go => net_conn_mock_test.go} | 0 ...mock_test.go => net_listener_mock_test.go} | 0 core/transport/tcp_transport.go | 14 +- core/transport/websocket_transport.go | 39 +++--- core/transport/websocket_transport_test.go | 121 ++++++++++++++++++ core/types_test.go | 13 ++ logger/logger.go | 6 +- logger/logger_mock_test.go | 101 +++++++++++++++ logger/logger_test.go | 49 +++++++ rx/flux/flux.go | 2 + rx/flux/proxy.go | 22 ++++ transporter.go | 2 +- 14 files changed, 364 insertions(+), 35 deletions(-) rename core/transport/{tcp_conn_mock_test.go => net_conn_mock_test.go} (100%) rename core/transport/{tcp_transport_mock_test.go => net_listener_mock_test.go} (100%) create mode 100644 core/transport/websocket_transport_test.go create mode 100644 core/types_test.go create mode 100644 logger/logger_mock_test.go create mode 100644 logger/logger_test.go diff --git a/cmd/rsocket-cli/uri.go b/cmd/rsocket-cli/uri.go index 3e063ee..7d63410 100644 --- a/cmd/rsocket-cli/uri.go +++ b/cmd/rsocket-cli/uri.go @@ -63,13 +63,13 @@ func (p *URI) MakeServerTransport(c *tls.Config) (tp transport.ServerTransport, case schemaTCP: tp = transport.NewTcpServerTransportWithAddr(schemaTCP, p.Host, c) case schemaWebsocket: - tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) + tp = transport.NewWebsocketServerTransportWithAddr(p.Host, p.Path, c) case schemaWebsocketSecure: if c == nil { err = errors.Errorf("missing TLS Config for proto %s", schemaWebsocketSecure) return } - tp = transport.NewWebsocketServerTransport(p.Host, p.Path, c) + tp = transport.NewWebsocketServerTransportWithAddr(p.Host, p.Path, c) case schemaUNIX: tp = transport.NewTcpServerTransportWithAddr(schemaUNIX, p.Path, c) default: diff --git a/core/framing/frame_test.go b/core/framing/frame_test.go index b5221db..a9320d2 100644 --- a/core/framing/frame_test.go +++ b/core/framing/frame_test.go @@ -52,6 +52,8 @@ func TestFrameFNF(t *testing.T) { f := NewFireAndForgetFrame(_sid, b, nil, core.FlagNext) checkBasic(t, f, core.FrameTypeRequestFNF) assert.Equal(t, b, f.Data()) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() metadata, ok := f.Metadata() assert.False(t, ok) assert.Nil(t, metadata) @@ -100,10 +102,14 @@ func TestFrameLease(t *testing.T) { func TestFrameMetadataPush(t *testing.T) { metadata := []byte("foobar") f := NewMetadataPushFrame(metadata) + assert.Nil(t, f.Data(), "should not be nil") + assert.Equal(t, "", f.DataUTF8(), "should be zero string") checkBasic(t, f, core.FrameTypeMetadataPush) metadata2, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, metadata, metadata2) + _, _ = f.MetadataUTF8() + f2 := NewWriteableMetadataPushFrame(metadata) checkBytes(t, f, f2) } @@ -116,9 +122,18 @@ func TestPayloadFrame(t *testing.T) { assert.True(t, ok) assert.Equal(t, b, f.Data()) assert.Equal(t, b, m) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) f2 := NewWriteablePayloadFrame(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) + + assert.Equal(t, b, f2.Data()) + _ = f2.DataUTF8() + m2, ok := f2.Metadata() + assert.True(t, ok) + assert.Equal(t, b, m2) + _, _ = f2.MetadataUTF8() } func TestFrameRequestChannel(t *testing.T) { @@ -131,6 +146,10 @@ func TestFrameRequestChannel(t *testing.T) { m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) + + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + f2 := NewWriteableRequestChannelFrame(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -153,6 +172,8 @@ func TestFrameRequestResponse(t *testing.T) { assert.True(t, ok) assert.Equal(t, b, m) assert.Equal(t, core.FlagNext|core.FlagMetadata, f.Header().Flag()) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() f2 := NewWriteableRequestResponseFrame(_sid, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -167,6 +188,8 @@ func TestFrameRequestStream(t *testing.T) { m, ok := f.Metadata() assert.True(t, ok) assert.Equal(t, b, m) + _, _ = f.MetadataUTF8() + _ = f.DataUTF8() f2 := NewWriteableRequestStreamFrame(_sid, n, b, b, core.FlagNext) checkBytes(t, f, f2) } @@ -219,6 +242,9 @@ func TestFrameSetup(t *testing.T) { assert.True(t, ok) assert.Equal(t, m, m2) + _ = f.DataUTF8() + _, _ = f.MetadataUTF8() + fs := NewWriteableSetupFrame(v, timeKeepalive, maxLifetime, token, mimeMetadata, mimeData, d, m, false) checkBytes(t, f, fs) diff --git a/core/transport/tcp_conn_mock_test.go b/core/transport/net_conn_mock_test.go similarity index 100% rename from core/transport/tcp_conn_mock_test.go rename to core/transport/net_conn_mock_test.go diff --git a/core/transport/tcp_transport_mock_test.go b/core/transport/net_listener_mock_test.go similarity index 100% rename from core/transport/tcp_transport_mock_test.go rename to core/transport/net_listener_mock_test.go diff --git a/core/transport/tcp_transport.go b/core/transport/tcp_transport.go index 1619f8a..5ef0a07 100644 --- a/core/transport/tcp_transport.go +++ b/core/transport/tcp_transport.go @@ -98,21 +98,21 @@ func NewTcpServerTransport(gen func() (net.Listener, error)) ServerTransport { } } -func NewTcpServerTransportWithAddr(network, addr string, c *tls.Config) ServerTransport { +func NewTcpClientTransport(c net.Conn) *Transport { + return NewTransport(NewTcpConn(c)) +} + +func NewTcpServerTransportWithAddr(network, addr string, tlsConfig *tls.Config) ServerTransport { gen := func() (net.Listener, error) { - if c == nil { + if tlsConfig == nil { return net.Listen(network, addr) } else { - return tls.Listen(network, addr, c) + return tls.Listen(network, addr, tlsConfig) } } return NewTcpServerTransport(gen) } -func NewTcpClientTransport(rawConn net.Conn) *Transport { - return NewTransport(NewTcpConn(rawConn)) -} - func NewTcpClientTransportWithAddr(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { var rawConn net.Conn if tlsConfig == nil { diff --git a/core/transport/websocket_transport.go b/core/transport/websocket_transport.go index 5b9f9fe..681f1d4 100644 --- a/core/transport/websocket_transport.go +++ b/core/transport/websocket_transport.go @@ -6,8 +6,6 @@ import ( "io" "net" "net/http" - "os" - "strings" "sync" "time" @@ -21,30 +19,21 @@ const defaultWebsocketPath = "/" var upgrader websocket.Upgrader func init() { - // Default allow CORS. - cors := true - if v, ok := os.LookupEnv("RSOCKET_WS_CORS"); ok { - v = strings.TrimSpace(strings.ToLower(v)) - cors = v == "yes" || v == "on" || v == "1" || v == "true" - } upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - } - if cors { - upgrader.CheckOrigin = func(r *http.Request) bool { + CheckOrigin: func(r *http.Request) bool { return true - } + }, } } type wsServerTransport struct { - addr string path string acceptor ServerTransportAcceptor onceClose sync.Once + listenerFn func() (net.Listener, error) listener net.Listener - tls *tls.Config transports *sync.Map } @@ -69,7 +58,6 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} logger.Errorf("create websocket conn failed: %s\n", err.Error()) return } - tp := NewTransport(NewWebsocketConnection(c)) p.transports.Store(tp, struct{}{}) go p.acceptor(ctx, tp, func(tp *Transport) { @@ -77,14 +65,11 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} }) }) - if p.tls == nil { - p.listener, err = net.Listen("tcp", p.addr) - } else { - p.listener, err = tls.Listen("tcp", p.addr, p.tls) - } + p.listener, err = p.listenerFn() if err != nil { err = errors.Wrap(err, "server listen failed") + close(notifier) return } @@ -112,18 +97,26 @@ func (p *wsServerTransport) Listen(ctx context.Context, notifier chan<- struct{} return } -func NewWebsocketServerTransport(addr string, path string, c *tls.Config) *wsServerTransport { +func NewWebsocketServerTransport(gen func() (net.Listener, error), path string) ServerTransport { if path == "" { path = defaultWebsocketPath } return &wsServerTransport{ - addr: addr, path: path, - tls: c, + listenerFn: gen, transports: &sync.Map{}, } } +func NewWebsocketServerTransportWithAddr(addr string, path string, c *tls.Config) ServerTransport { + return NewWebsocketServerTransport(func() (net.Listener, error) { + if c == nil { + return net.Listen("tcp", addr) + } + return tls.Listen("tcp", addr, c) + }, path) +} + func NewWebsocketClientTransport(url string, tc *tls.Config, header http.Header) (*Transport, error) { var d *websocket.Dialer if tc == nil { diff --git a/core/transport/websocket_transport_test.go b/core/transport/websocket_transport_test.go new file mode 100644 index 0000000..6c3711c --- /dev/null +++ b/core/transport/websocket_transport_test.go @@ -0,0 +1,121 @@ +package transport_test + +import ( + "bytes" + "context" + "io" + "net" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/stretchr/testify/assert" +) + +func TestNewWebsocketServerTransport(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + listener := newMockNetListener(ctrl) + + fakeConnChan := make(chan net.Conn, 1) + + conn := newMockNetConn(ctrl) + + conn.EXPECT(). + RemoteAddr(). + DoAndReturn(func() net.Addr { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") + assert.NoError(t, err, "bad addr") + return addr + }). + AnyTimes() + conn.EXPECT(). + LocalAddr(). + DoAndReturn(func() net.Addr { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:18080") + assert.NoError(t, err, "bad addr") + return addr + }). + AnyTimes() + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Close().Return(nil).Times(1) + bf := &bytes.Buffer{} + conn.EXPECT(). + Read(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Read(b) + }). + AnyTimes() + conn.EXPECT(). + Write(gomock.Any()). + DoAndReturn(func(b []byte) (int, error) { + return bf.Write(b) + }). + AnyTimes() + + fakeConnChan <- conn + + listener.EXPECT(). + Accept(). + DoAndReturn(func() (net.Conn, error) { + c, ok := <-fakeConnChan + if !ok { + return nil, io.EOF + } + return c, nil + }). + AnyTimes() + listener.EXPECT().Close().AnyTimes() + + tp := transport.NewWebsocketServerTransport(func() (net.Listener, error) { + return listener, nil + }, "") + + notifier := make(chan struct{}) + + done := make(chan struct{}) + + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.True(t, err == nil || err == io.EOF) + }() + + _, ok := <-notifier + assert.True(t, ok, "notifier should return ok=true") + + time.Sleep(100 * time.Millisecond) + + close(fakeConnChan) + + <-done +} + +func TestNewWebsocketServerTransport_Broken(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tp := transport.NewWebsocketServerTransport(func() (net.Listener, error) { + return nil, fakeErr + }, "") + tp.Accept(func(ctx context.Context, tp *transport.Transport, onClose func(*transport.Transport)) { + }) + + notifier := make(chan struct{}) + + done := make(chan struct{}) + + go func() { + defer close(done) + err := tp.Listen(context.Background(), notifier) + assert.Equal(t, fakeErr, errors.Cause(err)) + }() + + _, ok := <-notifier + assert.False(t, ok, "notifier should return ok=false") + + <-done +} diff --git a/core/types_test.go b/core/types_test.go new file mode 100644 index 0000000..f40db3d --- /dev/null +++ b/core/types_test.go @@ -0,0 +1,13 @@ +package core_test + +import ( + "testing" + + "github.com/rsocket/rsocket-go/core" + "github.com/stretchr/testify/assert" +) + +func TestFrameFlag_String(t *testing.T) { + f := core.FlagNext | core.FlagComplete | core.FlagFollow | core.FlagMetadata | core.FlagIgnore + assert.True(t, f.String() != "") +} diff --git a/logger/logger.go b/logger/logger.go index df24833..b88c7e0 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,8 +2,10 @@ package logger import "log" -var _level = LevelInfo -var _logger Logger = simpleLogger{} +var ( + _level = LevelInfo + _logger Logger = simpleLogger{} +) const ( // LevelDebug is DEBUG level. diff --git a/logger/logger_mock_test.go b/logger/logger_mock_test.go new file mode 100644 index 0000000..05b4380 --- /dev/null +++ b/logger/logger_mock_test.go @@ -0,0 +1,101 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: logger/logger.go + +// Package logger_test is a generated GoMock package. +package logger_test + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) + +// MockLogger is a mock of Logger interface +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debugf mocks base method +func (m *MockLogger) Debugf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf +func (mr *MockLoggerMockRecorder) Debugf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Infof mocks base method +func (m *MockLogger) Infof(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Infof", varargs...) +} + +// Infof indicates an expected call of Infof +func (mr *MockLoggerMockRecorder) Infof(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...) +} + +// Warnf mocks base method +func (m *MockLogger) Warnf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf +func (mr *MockLoggerMockRecorder) Warnf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...) +} + +// Errorf mocks base method +func (m *MockLogger) Errorf(format string, args ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{format} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf +func (mr *MockLoggerMockRecorder) Errorf(format interface{}, args ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{format}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} diff --git a/logger/logger_test.go b/logger/logger_test.go new file mode 100644 index 0000000..ac26835 --- /dev/null +++ b/logger/logger_test.go @@ -0,0 +1,49 @@ +package logger_test + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/logger" + "github.com/stretchr/testify/assert" +) + +var ( + fakeFormat = "fake format" + fakeArgs = []interface{}{"fake args"} +) + +func TestSetLogger(t *testing.T) { + logger.SetLevel(logger.LevelDebug) + + call := func() { + logger.Debugf(fakeFormat, fakeArgs...) + logger.Infof(fakeFormat, fakeArgs...) + logger.Warnf(fakeFormat, fakeArgs...) + logger.Errorf(fakeFormat, fakeArgs...) + } + + call() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := NewMockLogger(ctrl) + l.EXPECT().Debugf(gomock.Any(), gomock.Any()).Times(1) + l.EXPECT().Infof(gomock.Any(), gomock.Any()).Times(2) + l.EXPECT().Warnf(gomock.Any(), gomock.Any()).Times(2) + l.EXPECT().Errorf(gomock.Any(), gomock.Any()).Times(2) + + logger.SetLogger(l) + assert.Equal(t, logger.LevelDebug, logger.GetLevel(), "wrong logger level") + assert.True(t, logger.IsDebugEnabled(), "should be enabled") + + call() + + logger.SetLevel(logger.LevelInfo) + call() + + logger.SetLevel(logger.LevelDebug) + logger.SetLogger(nil) + call() +} diff --git a/rx/flux/flux.go b/rx/flux/flux.go index 91607e6..59b9e67 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -60,6 +60,8 @@ type Flux interface { // ToChan subscribe to this Flux and puts items into a chan. // It also puts errors into another chan. ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e <-chan error) + // BlockSlice subscribe to this Flux and convert to payload slice. + BlockSlice(context.Context) ([]payload.Payload, error) } // Processor represent a base processor that exposes Flux API for Processor. diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index 44cc770..cf3d8a4 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -121,6 +121,28 @@ func (p proxy) BlockLast(ctx context.Context) (last payload.Payload, err error) return } +func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err error) { + done := make(chan struct{}) + p.Flux. + DoFinally(func(s reactor.SignalType) { + close(done) + }). + DoOnCancel(func() { + err = reactor.ErrSubscribeCancelled + }). + Subscribe( + ctx, + reactor.OnNext(func(v interface{}) { + results = append(results, v.(payload.Payload)) + }), + reactor.OnError(func(e error) { + err = e + }), + ) + <-done + return +} + func (p proxy) DoOnSubscribe(fn rx.FnOnSubscribe) Flux { return newProxy(p.Flux.DoOnSubscribe(func(su reactor.Subscription) { fn(su) diff --git a/transporter.go b/transporter.go index 1560cb5..e8632f15 100644 --- a/transporter.go +++ b/transporter.go @@ -118,7 +118,7 @@ func (w *wsTransporter) Server() transport.ServerTransportFunc { if len(port) < 1 { return nil, errors.New("missing websocket port") } - return transport.NewWebsocketServerTransport(fmt.Sprintf("%s:%s", u.Hostname(), port), u.Path, w.tls), nil + return transport.NewWebsocketServerTransportWithAddr(fmt.Sprintf("%s:%s", u.Hostname(), port), u.Path, w.tls), nil } } From 30c9d1dac0b3150dc4729f3921872c7a7b27bee1 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Wed, 29 Jul 2020 22:23:04 +0800 Subject: [PATCH 15/26] remove race for ut. --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 2081ba1..5ad217c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,5 @@ install: script: - golangci-lint run ./... - - go test -v -covermode=atomic -coverprofile=coverage.out -race -count=1 ./core/... ./balancer/... ./rx/... ./internal/... ./extension/... ./payload/... . + - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 ./logger/... ./lease/... ./core/... ./balancer/... ./rx/... ./internal/... ./extension/... ./payload/... . - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN From 297cdba4b65af69b58cb5c5a4d264c2fb636f813 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Thu, 30 Jul 2020 22:48:19 +0800 Subject: [PATCH 16/26] foobar --- client.go | 2 +- cmd/rsocket-cli/rsocket-cli.go | 2 +- core/transport/websocket_conn.go | 10 +++- internal/socket/base_socket_test.go | 2 +- internal/socket/resumable_client_socket.go | 4 +- .../socket/resumable_client_socket_test.go | 47 +++++++++++++++++++ internal/socket/simple_client_socket_test.go | 4 +- internal/socket/simple_server_socket_test.go | 2 +- internal/socket/socket_test.go | 2 +- rx/flux/flux_test.go | 14 ++++++ rx/mono/utils_test.go | 9 ++++ server.go | 2 +- 12 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 internal/socket/resumable_client_socket_test.go diff --git a/client.go b/client.go index c046a89..f0d7cce 100644 --- a/client.go +++ b/client.go @@ -173,7 +173,7 @@ func (p *clientBuilder) Start(ctx context.Context) (client Client, err error) { var cs setupClientSocket if p.resume != nil { p.setup.Token = p.resume.tokenGen() - cs = socket.NewClientResume(p.tpGen, sk) + cs = socket.NewResumableClientSocket(p.tpGen, sk) } else { cs = socket.NewClient(p.tpGen, sk) } diff --git a/cmd/rsocket-cli/rsocket-cli.go b/cmd/rsocket-cli/rsocket-cli.go index 6d2c6cb..f4fb39c 100644 --- a/cmd/rsocket-cli/rsocket-cli.go +++ b/cmd/rsocket-cli/rsocket-cli.go @@ -41,7 +41,7 @@ func main() { app.UsageText = "rsocket-cli [global options] [URI]" app.Name = "rsocket-cli" app.Usage = "CLI for RSocket." - app.Version = "v0.5" + app.Version = "v0.6" app.Flags = newFlags(conf) app.ArgsUsage = "[URI]" app.Action = func(c *cli.Context) (err error) { diff --git a/core/transport/websocket_conn.go b/core/transport/websocket_conn.go index e90335d..23f32a7 100644 --- a/core/transport/websocket_conn.go +++ b/core/transport/websocket_conn.go @@ -39,15 +39,23 @@ func (p *WsConn) SetDeadline(deadline time.Time) error { func (p *WsConn) Read() (f core.Frame, err error) { t, raw, err := p.c.ReadMessage() + if err == io.EOF { return } + + if websocket.IsCloseError(err) || websocket.IsUnexpectedCloseError(err) || isClosedErr(err) { + err = io.EOF + return + } + if err != nil { err = errors.Wrap(err, "read frame failed") return } + + // Skip non-binary message if t != websocket.BinaryMessage { - logger.Warnf("omit non-binary message %d\n", t) return p.Read() } diff --git a/internal/socket/base_socket_test.go b/internal/socket/base_socket_test.go index 7be55f1..520674f 100644 --- a/internal/socket/base_socket_test.go +++ b/internal/socket/base_socket_test.go @@ -15,7 +15,7 @@ import ( ) func TestBaseSocket(t *testing.T) { - ctrl, conn, tp := Init(t) + ctrl, conn, tp := InitTransport(t) defer ctrl.Finish() conn.EXPECT().Close().Times(1) diff --git a/internal/socket/resumable_client_socket.go b/internal/socket/resumable_client_socket.go index 8dc7fb4..60d08b7 100644 --- a/internal/socket/resumable_client_socket.go +++ b/internal/socket/resumable_client_socket.go @@ -144,8 +144,8 @@ func (p *resumeClientSocket) isClosed() bool { return p.connects.Load() < 0 } -// NewClientResume creates a client-side socket with resume support. -func NewClientResume(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { +// NewResumableClientSocket creates a client-side socket with resume support. +func NewResumableClientSocket(tp transport.ClientTransportFunc, socket *DuplexConnection) ClientSocket { return &resumeClientSocket{ BaseSocket: NewBaseSocket(socket), connects: atomic.NewInt32(0), diff --git a/internal/socket/resumable_client_socket_test.go b/internal/socket/resumable_client_socket_test.go new file mode 100644 index 0000000..565230a --- /dev/null +++ b/internal/socket/resumable_client_socket_test.go @@ -0,0 +1,47 @@ +package socket_test + +import ( + "context" + "io" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/transport" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +func TestNewResumableClientSocket(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + + rcs := socket.NewResumableClientSocket(func(ctx context.Context) (*transport.Transport, error) { + return tp, nil + }, ds) + + defer rcs.Close() + + err := rcs.Setup(context.Background(), fakeSetup) + assert.NoError(t, err) +} diff --git a/internal/socket/simple_client_socket_test.go b/internal/socket/simple_client_socket_test.go index b26c85b..755e5e0 100644 --- a/internal/socket/simple_client_socket_test.go +++ b/internal/socket/simple_client_socket_test.go @@ -30,7 +30,7 @@ func TestNewClientWithBrokenTransporter(t *testing.T) { } func TestNewClient(t *testing.T) { - ctrl, conn, tp := Init(t) + ctrl, conn, tp := InitTransport(t) defer ctrl.Finish() // For test @@ -100,7 +100,7 @@ func TestNewClient(t *testing.T) { } func TestLease(t *testing.T) { - ctrl, conn, tp := Init(t) + ctrl, conn, tp := InitTransport(t) defer ctrl.Finish() // For test diff --git a/internal/socket/simple_server_socket_test.go b/internal/socket/simple_server_socket_test.go index d49d4cc..74551e2 100644 --- a/internal/socket/simple_server_socket_test.go +++ b/internal/socket/simple_server_socket_test.go @@ -18,7 +18,7 @@ import ( var fakeResponder = rsocket.NewAbstractSocket() func TestSimpleServerSocket_Start(t *testing.T) { - ctrl, conn, tp := Init(t) + ctrl, conn, tp := InitTransport(t) defer ctrl.Finish() // For test diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go index 29e0551..aee5693 100644 --- a/internal/socket/socket_test.go +++ b/internal/socket/socket_test.go @@ -30,7 +30,7 @@ var ( } ) -func Init(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { +func InitTransport(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { ctrl := gomock.NewController(t) conn := NewMockConn(ctrl) tp := transport.NewTransport(conn) diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 674f4cb..0c88b66 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -415,3 +415,17 @@ loop: } } + +func TestFlux_BlockSlice(t *testing.T) { + const n = 10 + arr, err := flux. + Create(func(ctx context.Context, s flux.Sink) { + for i := 0; i < n; i++ { + s.Next(payload.NewString("hello", strconv.Itoa(i))) + } + s.Complete() + }). + BlockSlice(context.Background()) + assert.NoError(t, err) + assert.Len(t, arr, n) +} diff --git a/rx/mono/utils_test.go b/rx/mono/utils_test.go index 04127ce..390bff5 100644 --- a/rx/mono/utils_test.go +++ b/rx/mono/utils_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + rsMono "github.com/jjeffcaii/reactor-go/mono" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx/mono" @@ -126,3 +128,10 @@ loop: } } + +func TestRaw(t *testing.T) { + fakePayload := payload.NewString("fake", "payload") + res, err := mono.Raw(rsMono.Just(fakePayload)).Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fakePayload, res) +} diff --git a/server.go b/server.go index 9a2c68a..473c9b7 100644 --- a/server.go +++ b/server.go @@ -191,7 +191,7 @@ func (p *server) Serve(ctx context.Context) error { return } if err := tp.Start(ctx); err != nil { - logger.Warnf("transport exit: %s\n", err.Error()) + logger.Warnf("transport exit: %+v\n", err) } }) From a76d11127fb56efe4f2087da79992d8a35ec1ec0 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 4 Aug 2020 23:15:21 +0800 Subject: [PATCH 17/26] Bump reactor-go to v0.2.0. --- balancer/group_test.go | 10 ++--- go.mod | 2 +- go.sum | 11 ++++- internal/socket/abstract_socket_test.go | 6 ++- internal/socket/callback.go | 2 +- internal/socket/duplex.go | 12 ++++-- internal/socket/simple_client_socket_test.go | 3 +- rsocket_example_test.go | 15 ++++--- rsocket_test.go | 18 ++++++--- rx/flux/flux.go | 2 +- rx/flux/flux_test.go | 42 +++++++++++++------- rx/flux/proxy.go | 20 +++++----- rx/flux/utils.go | 3 +- rx/mono/mono_test.go | 18 ++++++--- rx/mono/proxy.go | 25 ++++++------ rx/rx.go | 2 +- rx/subscriber.go | 14 ++++--- 17 files changed, 127 insertions(+), 78 deletions(-) diff --git a/balancer/group_test.go b/balancer/group_test.go index 9ec4cca..5dcd6f2 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -5,7 +5,6 @@ import ( "crypto/md5" "errors" "fmt" - "log" "testing" "time" @@ -43,7 +42,7 @@ func ExampleNewGroup() { if !ok { panic(errors.New("missing service ID in metadata")) } - log.Println("[broker] redirect request to service", requestServiceID) + fmt.Println("[broker] redirect request to service", requestServiceID) return group.Get(requestServiceID).MustNext(context.Background()).RequestResponse(msg) })), nil }). @@ -72,7 +71,7 @@ func TestServiceSubscribe(t *testing.T) { Acceptor(func(socket RSocket) RSocket { return NewAbstractSocket(RequestResponse(func(msg payload.Payload) mono.Mono { result := payload.NewString(fmt.Sprintf("%02x", md5.Sum(msg.Data())), "MD5 RESULT") - log.Println("[publisher] accept MD5 request:", msg.DataUTF8()) + fmt.Println("[publisher] accept MD5 request:", msg.DataUTF8()) return mono.Just(result) })) }). @@ -98,9 +97,10 @@ func TestServiceSubscribe(t *testing.T) { time.Sleep(200 * time.Millisecond) }() _, err = cli.RequestResponse(payload.NewString("Hello World!", "md5")). - DoOnSuccess(func(elem payload.Payload) { - log.Println("[subscriber] receive MD5 response:", elem.DataUTF8()) + DoOnSuccess(func(elem payload.Payload) error { + fmt.Println("[subscriber] receive MD5 response:", elem.DataUTF8()) require.Equal(t, "ed076287532e86365e841e92bfc50d8c", elem.DataUTF8(), "bad md5") + return nil }). Block(context.Background()) require.NoError(t, err, "request failed") diff --git a/go.mod b/go.mod index c21fcae..73bece1 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang/mock v1.4.3 github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 - github.com/jjeffcaii/reactor-go v0.1.4 + github.com/jjeffcaii/reactor-go v0.2.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.7.1 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 28aec31..913b615 100644 --- a/go.sum +++ b/go.sum @@ -34,21 +34,24 @@ github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0 github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jjeffcaii/reactor-go v0.1.4 h1:/M2Mjy72u+4Q9PQpq/i4bxFpXjaR1pxUh1GfMXUZa1A= -github.com/jjeffcaii/reactor-go v0.1.4/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/jjeffcaii/reactor-go v0.2.0 h1:sIiEfclB65HH4Ne+Cz1Q8EaKn88/7le5hYHzjHhrRvA= +github.com/jjeffcaii/reactor-go v0.2.0/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -98,6 +101,7 @@ go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -118,8 +122,10 @@ golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fq golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -130,6 +136,7 @@ google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyz google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/socket/abstract_socket_test.go b/internal/socket/abstract_socket_test.go index 3896ac2..f042281 100644 --- a/internal/socket/abstract_socket_test.go +++ b/internal/socket/abstract_socket_test.go @@ -72,8 +72,9 @@ func TestAbstractRSocket_RequestStream(t *testing.T) { var res []payload.Payload _, err := s.RequestStream(fakeRequest). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { res = append(res, input) + return nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -92,8 +93,9 @@ func TestAbstractRSocket_RequestChannel(t *testing.T) { } var res []payload.Payload _, err := s.RequestChannel(flux.Just(fakeRequest)). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { res = append(res, input) + return nil }). BlockLast(context.Background()) assert.NoError(t, err) diff --git a/internal/socket/callback.go b/internal/socket/callback.go index c97aabb..6afea4b 100644 --- a/internal/socket/callback.go +++ b/internal/socket/callback.go @@ -38,7 +38,7 @@ func (s requestChannelCallback) Close(err error) { } type requestResponseCallbackReverse struct { - su rs.Subscription + su reactor.Subscription } func (s requestResponseCallbackReverse) Close(err error) { diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index e895144..410f449 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -275,7 +275,7 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux sndRequested := make(chan struct{}) sub := rx.NewSubscriber( - rx.OnNext(func(item payload.Payload) { + rx.OnNext(func(item payload.Payload) (err error) { var newborn bool select { case <-sndRequested: @@ -306,6 +306,7 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux } p.sendFrame(f) }) + return }), rx.OnSubscribe(func(s rx.Subscription) { p.register(sid, requestChannelCallback{rcv: receiving, snd: s}) @@ -363,8 +364,9 @@ func (p *DuplexConnection) respondRequestResponse(receiving fragmentation.Header // 4. async subscribe publisher sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { p.sendPayload(sid, input, core.FlagNext|core.FlagComplete) + return nil }), rx.OnError(func(e error) { p.writeError(sid, e) @@ -465,8 +467,9 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo close(mustSub) s.Request(initRequestN) }), - rx.OnNext(func(elem payload.Payload) { + rx.OnNext(func(elem payload.Payload) error { p.sendPayload(sid, elem, core.FlagNext) + return nil }), ) @@ -559,8 +562,9 @@ func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAn } sub := rx.NewSubscriber( - rx.OnNext(func(elem payload.Payload) { + rx.OnNext(func(elem payload.Payload) error { p.sendPayload(sid, elem, core.FlagNext) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { p.register(sid, requestStreamCallbackReverse{su: s}) diff --git a/internal/socket/simple_client_socket_test.go b/internal/socket/simple_client_socket_test.go index 755e5e0..43fb872 100644 --- a/internal/socket/simple_client_socket_test.go +++ b/internal/socket/simple_client_socket_test.go @@ -78,8 +78,9 @@ func TestNewClient(t *testing.T) { var stream []payload.Payload _, err = cli.RequestStream(payload.New(fakeData, fakeMetadata)). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { stream = append(stream, input) + return nil }). DoOnSubscribe(func(s rx.Subscription) { nextId := nextRequestId() diff --git a/rsocket_example_test.go b/rsocket_example_test.go index f5aa4d7..e926924 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -46,8 +46,9 @@ func Example() { _ = cli.Close() }() cli.RequestResponse(payload.NewString("Ping", time.Now().String())). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("incoming response:", elem) + return nil }). Subscribe(context.Background()) } @@ -67,8 +68,9 @@ func ExampleReceive() { // Request to client. sendingSocket.RequestResponse(payload.NewString("Ping", time.Now().String())). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("response of Ping from client:", elem) + return nil }). SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) @@ -123,16 +125,18 @@ func ExampleConnect() { cli.FireAndForget(payload.NewString("This is a FNF message.", "")) // Simple RequestResponse. cli.RequestResponse(payload.NewString("This is a RequestResponse message.", "")). - DoOnSuccess(func(elem payload.Payload) { + DoOnSuccess(func(elem payload.Payload) error { log.Println("response:", elem) + return nil }). Subscribe(context.Background()) var s rx.Subscription // RequestStream with backpressure. (one by one) cli.RequestStream(payload.NewString("This is a RequestStream message.", "")). - DoOnNext(func(elem payload.Payload) { + DoOnNext(func(elem payload.Payload) error { log.Println("next element in stream:", elem) s.Request(1) + return nil }). Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { s.Request(1) @@ -145,8 +149,9 @@ func ExampleConnect() { s.Complete() }) cli.RequestChannel(sendFlux). - DoOnNext(func(elem payload.Payload) { + DoOnNext(func(elem payload.Payload) error { log.Println("next element in channel:", elem) + return nil }). Subscribe(context.Background()) } diff --git a/rsocket_test.go b/rsocket_test.go index 2bfe511..ab46bce 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -80,9 +80,10 @@ func testAll(t *testing.T, proto string, tp Transporter) { inputs.(flux.Flux).DoFinally(func(s rx.SignalType) { close(receives) - }).Subscribe(context.Background(), rx.OnNext(func(input Payload) { + }).Subscribe(context.Background(), rx.OnNext(func(input Payload) error { //fmt.Println("rcv from channel:", input) receives <- input + return nil })) return flux.Create(func(ctx context.Context, s flux.Sink) { @@ -149,11 +150,12 @@ func testRequestStream(ctx context.Context, cli Client, t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d", atomic.LoadInt32(&seq)), m, "bad stream metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad stream data") atomic.AddInt32(&seq, 1) + return nil }). BlockLast(ctx) <-done @@ -169,12 +171,13 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d", atomic.LoadInt32(&seq)), m, "bad stream metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad stream data") atomic.AddInt32(&seq, 1) su.Request(1) + return nil }). Subscribe(ctx, rx.OnSubscribe(func(s rx.Subscription) { su = s @@ -196,12 +199,13 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) { var seq int _, err := cli.RequestChannel(send). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { //fmt.Println(elem) m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d_from_server", seq), m, "bad channel metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad channel data") seq++ + return nil }). BlockLast(ctx) assert.NoError(t, err, "block last failed") @@ -227,15 +231,17 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) { assert.Equal(t, rx.SignalComplete, s, "bad signal type") close(done) }). - DoOnNext(func(elem Payload) { + DoOnNext(func(elem Payload) error { fmt.Println(elem) m, _ := elem.MetadataUTF8() assert.Equal(t, fmt.Sprintf("%d_from_server", seq), m, "bad channel metadata") assert.Equal(t, testData, elem.DataUTF8(), "bad channel data") seq++ + return nil }). - Subscribe(ctx, rx.OnNext(func(elem Payload) { + Subscribe(ctx, rx.OnNext(func(elem Payload) error { su.Request(1) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { su = s su.Request(1) diff --git a/rx/flux/flux.go b/rx/flux/flux.go index 59b9e67..f066fd0 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -44,7 +44,7 @@ type Flux interface { // DoOnSubscribe add behavior triggered when the Flux is done being subscribed. DoOnSubscribe(rx.FnOnSubscribe) Flux // Map transform the items emitted by this Flux by applying a synchronous function to each item. - Map(func(payload.Payload) payload.Payload) Flux + Map(func(payload.Payload) (payload.Payload, error)) Flux // SwitchOnFirst transform the current Flux once it emits its first element, making a conditional transformation possible. SwitchOnFirst(FnSwitchOnFirst) Flux // SubscribeOn run subscribe, onSubscribe and request on a specified scheduler. diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 0c88b66..6d00419 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/jjeffcaii/reactor-go" nativeFlux "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go/payload" @@ -19,8 +20,9 @@ import ( func TestEmpty(t *testing.T) { last, err := flux.Empty(). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { assert.FailNow(t, "unreachable") + return nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -33,8 +35,9 @@ func TestEmpty(t *testing.T) { func TestError(t *testing.T) { err := errors.New("boom") _, _ = flux.Error(err). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { assert.FailNow(t, "unreachable") + return nil }). DoOnError(func(e error) { assert.Equal(t, err, e) @@ -54,8 +57,9 @@ func TestClone(t *testing.T) { c := atomic.NewInt32(0) last, err := clone. - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { c.Inc() + return nil }). DoOnError(func(e error) { assert.FailNow(t, "unreachable") @@ -70,12 +74,13 @@ func TestRaw(t *testing.T) { const total = 10 c := atomic.NewInt32(0) f := flux. - Raw(nativeFlux.Range(0, total).Map(func(v interface{}) interface{} { - return payload.NewString(fmt.Sprintf("data_%d", v.(int)), "") + Raw(nativeFlux.Range(0, total).Map(func(v reactor.Any) (reactor.Any, error) { + return payload.NewString(fmt.Sprintf("data_%d", v.(int)), ""), nil })) last, err := f. - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { c.Inc() + return nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -85,8 +90,9 @@ func TestRaw(t *testing.T) { c.Store(0) const take = 3 last, err = f.Take(take). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { c.Inc() + return nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -102,8 +108,9 @@ func TestJust(t *testing.T) { payload.NewString("bar", ""), payload.NewString("qux", ""), ). - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { c.Inc() + return nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -126,9 +133,10 @@ func TestCreate(t *testing.T) { nextRequests := atomic.NewInt32(0) f. - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { fmt.Println("next:", input) su.Request(1) + return nil }). DoOnRequest(func(n int) { fmt.Println("request:", n) @@ -152,8 +160,8 @@ func TestCreate(t *testing.T) { func TestMap(t *testing.T) { last, err := flux. Just(payload.NewString("hello", "")). - Map(func(p payload.Payload) payload.Payload { - return payload.NewString(p.DataUTF8()+" world", "") + Map(func(p payload.Payload) (payload.Payload, error) { + return payload.NewString(p.DataUTF8()+" world", ""), nil }). BlockLast(context.Background()) assert.NoError(t, err) @@ -173,8 +181,9 @@ func TestProcessor(t *testing.T) { done := make(chan struct{}) processor. - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { fmt.Println("next:", input) + return nil }). DoFinally(func(s rx.SignalType) { close(done) @@ -200,8 +209,9 @@ func TestSwitchOnFirst(t *testing.T) { n, _ := strconv.Atoi(input.DataUTF8()) return n > first }) - }).Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) { + }).Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) error { fmt.Println("next:", input.DataUTF8()) + return nil })) } @@ -216,9 +226,10 @@ func TestFluxRequest(t *testing.T) { var su rx.Subscription sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { fmt.Println("onNext:", input) su.Request(1) + return nil }), rx.OnComplete(func() { fmt.Println("complete") @@ -256,8 +267,9 @@ func TestFluxProcessorWithRequest(t *testing.T) { var su rx.Subscription sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { su.Request(1) + return nil }), rx.OnSubscribe(func(s rx.Subscription) { su = s diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index cf3d8a4..c0be448 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -3,7 +3,7 @@ package flux import ( "context" - reactor "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go" "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/pkg/errors" @@ -32,8 +32,8 @@ func (p proxy) Next(v payload.Payload) { p.mustProcessor().Next(v) } -func (p proxy) Map(fn func(in payload.Payload) payload.Payload) Flux { - return newProxy(p.Flux.Map(func(i interface{}) interface{} { +func (p proxy) Map(fn func(in payload.Payload) (payload.Payload, error)) Flux { + return newProxy(p.Flux.Map(func(i reactor.Any) (reactor.Any, error) { return fn(i.(payload.Payload)) })) } @@ -65,8 +65,8 @@ func (p proxy) DoOnError(fn rx.FnOnError) Flux { } func (p proxy) DoOnNext(fn rx.FnOnNext) Flux { - return newProxy(p.Flux.DoOnNext(func(v interface{}) { - fn(v.(payload.Payload)) + return newProxy(p.Flux.DoOnNext(func(v reactor.Any) error { + return fn(v.(payload.Payload)) })) } @@ -85,12 +85,13 @@ func (p proxy) ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e close(err) }). Subscribe(ctx, - rx.OnNext(func(v payload.Payload) { + rx.OnNext(func(v payload.Payload) error { if _, ok := v.(core.Frame); ok { ch <- payload.Clone(v) } else { ch <- v } + return nil }), rx.OnError(func(e error) { err <- e @@ -132,8 +133,9 @@ func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err e }). Subscribe( ctx, - reactor.OnNext(func(v interface{}) { + reactor.OnNext(func(v reactor.Any) error { results = append(results, v.(payload.Payload)) + return nil }), reactor.OnError(func(e error) { err = e @@ -179,8 +181,8 @@ func (p proxy) SubscribeWith(ctx context.Context, s rx.Subscriber) { sub = rx.EmptyRawSubscriber } else { sub = reactor.NewSubscriber( - reactor.OnNext(func(v interface{}) { - s.OnNext(v.(payload.Payload)) + reactor.OnNext(func(v reactor.Any) error { + return s.OnNext(v.(payload.Payload)) }), reactor.OnError(func(e error) { s.OnError(e) diff --git a/rx/flux/utils.go b/rx/flux/utils.go index 4ad7333..a5fed25 100644 --- a/rx/flux/utils.go +++ b/rx/flux/utils.go @@ -61,8 +61,9 @@ func CreateProcessor() Processor { func Clone(source rx.Publisher) Flux { return Create(func(ctx context.Context, s Sink) { source.Subscribe(ctx, - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { s.Next(input) + return nil }), rx.OnComplete(func() { s.Complete() diff --git a/rx/mono/mono_test.go b/rx/mono/mono_test.go index ee8507c..bd910a0 100644 --- a/rx/mono/mono_test.go +++ b/rx/mono/mono_test.go @@ -49,8 +49,9 @@ func TestJustOrEmpty(t *testing.T) { func TestJust(t *testing.T) { Just(payload.NewString("hello", "world")). - Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) { + Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) error { log.Println("next:", i) + return nil })) } @@ -67,8 +68,9 @@ func TestProxy_SubscribeOn(t *testing.T) { }) }). SubscribeOn(scheduler.Parallel()). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { log.Println("success:", i) + return nil }). Block(context.Background()) assert.NoError(t, err) @@ -101,8 +103,9 @@ func TestProxy_Filter(t *testing.T) { Filter(func(i payload.Payload) bool { return strings.EqualFold("hello_no", i.DataUTF8()) }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { assert.Fail(t, "should never run here") + return nil }). DoFinally(func(i rx.SignalType) { log.Println("finally:", i) @@ -114,14 +117,16 @@ func TestCreate(t *testing.T) { Create(func(i context.Context, sink Sink) { sink.Success(payload.NewString("hello", "world")) }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { log.Println("doOnNext:", i) + return nil }). DoFinally(func(s rx.SignalType) { log.Println("doFinally:", s) }). - Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) { + Subscribe(context.Background(), rx.OnNext(func(i payload.Payload) error { log.Println("next:", i) + return nil })) Create(func(i context.Context, sink Sink) { @@ -130,8 +135,9 @@ func TestCreate(t *testing.T) { DoOnError(func(e error) { assert.Equal(t, "foobar", e.Error(), "bad error") }). - DoOnSuccess(func(i payload.Payload) { + DoOnSuccess(func(i payload.Payload) error { assert.Fail(t, "should never run here") + return nil }). Subscribe(context.Background()) } diff --git a/rx/mono/proxy.go b/rx/mono/proxy.go index 10aafe2..47560d1 100644 --- a/rx/mono/proxy.go +++ b/rx/mono/proxy.go @@ -39,8 +39,9 @@ func (p proxy) ToChan(ctx context.Context) (c <-chan payload.Payload, e <-chan e errorChannel := make(chan error, 1) payloadChannel := make(chan payload.Payload, 1) p. - DoOnSuccess(func(input payload.Payload) { + DoOnSuccess(func(input payload.Payload) error { payloadChannel <- input + return nil }). DoOnError(func(e error) { errorChannel <- e @@ -75,7 +76,7 @@ func (p proxy) Filter(fn rx.FnPredicate) Mono { } func (p proxy) DoFinally(fn rx.FnFinally) Mono { - return newProxy(p.Mono.DoFinally(func(signal rs.SignalType) { + return newProxy(p.Mono.DoFinally(func(signal reactor.SignalType) { fn(rx.SignalType(signal)) })) } @@ -86,13 +87,13 @@ func (p proxy) DoOnError(fn rx.FnOnError) Mono { })) } func (p proxy) DoOnSuccess(next rx.FnOnNext) Mono { - return newProxy(p.Mono.DoOnNext(func(v interface{}) { - next(v.(payload.Payload)) + return newProxy(p.Mono.DoOnNext(func(v reactor.Any) error { + return next(v.(payload.Payload)) })) } func (p proxy) DoOnSubscribe(fn rx.FnOnSubscribe) Mono { - return newProxy(p.Mono.DoOnSubscribe(func(su rs.Subscription) { + return newProxy(p.Mono.DoOnSubscribe(func(su reactor.Subscription) { fn(su) })) } @@ -110,21 +111,21 @@ func (p proxy) Subscribe(ctx context.Context, options ...rx.SubscriberOption) { } func (p proxy) SubscribeWith(ctx context.Context, actual rx.Subscriber) { - var sub rs.Subscriber + var sub reactor.Subscriber if actual == rx.EmptySubscriber { sub = rx.EmptyRawSubscriber } else { - sub = rs.NewSubscriber( - rs.OnNext(func(v interface{}) { - actual.OnNext(v.(payload.Payload)) + sub = reactor.NewSubscriber( + reactor.OnNext(func(v reactor.Any) error { + return actual.OnNext(v.(payload.Payload)) }), - rs.OnComplete(func() { + reactor.OnComplete(func() { actual.OnComplete() }), - rs.OnSubscribe(func(su rs.Subscription) { + reactor.OnSubscribe(func(su reactor.Subscription) { actual.OnSubscribe(su) }), - rs.OnError(func(e error) { + reactor.OnError(func(e error) { actual.OnError(e) }), ) diff --git a/rx/rx.go b/rx/rx.go index 7173749..f06f490 100644 --- a/rx/rx.go +++ b/rx/rx.go @@ -23,7 +23,7 @@ type ( // FnOnComplete is alias of function for signal when no more elements are available FnOnComplete = func() // FnOnNext is alias of function for signal when next element arrived. - FnOnNext = func(input payload.Payload) + FnOnNext = func(input payload.Payload) error // FnOnSubscribe is alias of function for signal when subscribe begin. FnOnSubscribe = func(s Subscription) // FnOnError is alias of function for signal when an error occurred. diff --git a/rx/subscriber.go b/rx/subscriber.go index c805c18..4a2ec85 100644 --- a/rx/subscriber.go +++ b/rx/subscriber.go @@ -1,7 +1,7 @@ package rx import ( - reactor "github.com/jjeffcaii/reactor-go" + "github.com/jjeffcaii/reactor-go" "github.com/rsocket/rsocket-go/payload" ) @@ -9,7 +9,8 @@ var ( // EmptySubscriber is a blank Subscriber. EmptySubscriber Subscriber = &subscriber{} // EmptyRawSubscriber is a blank native Subscriber in reactor-go. - EmptyRawSubscriber = reactor.NewSubscriber(reactor.OnNext(func(v interface{}) { + EmptyRawSubscriber = reactor.NewSubscriber(reactor.OnNext(func(v reactor.Any) error { + return nil })) ) @@ -19,7 +20,7 @@ type Subscription reactor.Subscription // Subscriber will receive call to OnSubscribe(Subscription) once after passing an instance of Subscriber to Publisher#SubscribeWith type Subscriber interface { // OnNext represents data notification sent by the Publisher in response to requests to Subscription#Request. - OnNext(payload payload.Payload) + OnNext(payload payload.Payload) error // OnError represents failed terminal state. OnError(error) // OnComplete represents successful terminal state. @@ -36,10 +37,11 @@ type subscriber struct { fnOnError FnOnError } -func (s *subscriber) OnNext(payload payload.Payload) { - if s != nil && s.fnOnNext != nil { - s.fnOnNext(payload) +func (s *subscriber) OnNext(payload payload.Payload) error { + if s == nil || s.fnOnNext == nil { + return nil } + return s.fnOnNext(payload) } func (s *subscriber) OnError(err error) { From 607353e6bcfdd4b8667849a20cc9ccaf00a6c4e0 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 4 Aug 2020 23:20:45 +0800 Subject: [PATCH 18/26] Fix compile error. --- cmd/rsocket-cli/runner.go | 6 ++++-- examples/echo/echo.go | 3 ++- examples/echo_bench/echo_bench.go | 3 ++- examples/fibonacci/main.go | 5 +++-- examples/word_counter/main.go | 6 ++++-- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index 2576842..0f66efa 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -195,8 +195,9 @@ func (p *Runner) runServerMode(ctx context.Context) error { return sendingPayloads })) options = append(options, rsocket.RequestChannel(func(messages rx.Publisher) flux.Flux { - messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) { + messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) error { p.showPayload(input) + return nil })) return sendingPayloads })) @@ -270,8 +271,9 @@ func (p *Runner) execRequestStream(ctx context.Context, c rsocket.Client, send p func (p *Runner) printFlux(ctx context.Context, f flux.Flux) (err error) { _, err = f. - DoOnNext(func(input payload.Payload) { + DoOnNext(func(input payload.Payload) error { p.showPayload(input) + return nil }). BlockLast(ctx) return diff --git a/examples/echo/echo.go b/examples/echo/echo.go index d8e61ec..89bfe53 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -143,8 +143,9 @@ func responder() rsocket.RSocket { payloads.(flux.Flux). //LimitRate(1). SubscribeOn(scheduler.Parallel()). - DoOnNext(func(elem payload.Payload) { + DoOnNext(func(elem payload.Payload) error { log.Println("receiving:", elem) + return nil }). Subscribe(context.Background()) return flux.Create(func(i context.Context, sink flux.Sink) { diff --git a/examples/echo_bench/echo_bench.go b/examples/echo_bench/echo_bench.go index 9b07cb2..8c9041b 100644 --- a/examples/echo_bench/echo_bench.go +++ b/examples/echo_bench/echo_bench.go @@ -49,10 +49,11 @@ func main() { ctx := context.Background() sub := rx.NewSubscriber( - rx.OnNext(func(input payload.Payload) { + rx.OnNext(func(input payload.Payload) error { //m2, _ := elem.MetadataUTF8() //assert.Equal(t, m1, m2, "metadata doesn't match") wg.Done() + return nil }), ) diff --git a/examples/fibonacci/main.go b/examples/fibonacci/main.go index 0e97c5b..d3ac482 100644 --- a/examples/fibonacci/main.go +++ b/examples/fibonacci/main.go @@ -108,11 +108,12 @@ func client() { wg := sync.WaitGroup{} wg.Add(1) - f.DoOnNext(func(input payload.Payload) { + f.DoOnNext(func(input payload.Payload) error { // print each number in a stream fmt.Println(input.DataUTF8()) + return nil }).DoOnComplete(func() { - // will be called on successfull completion of the stream + // will be called on successful completion of the stream fmt.Println("Fibonacci sequence done") }).DoOnError(func(err error) { // will be called if a error occurs diff --git a/examples/word_counter/main.go b/examples/word_counter/main.go index 5fb50d8..55b6003 100644 --- a/examples/word_counter/main.go +++ b/examples/word_counter/main.go @@ -39,9 +39,10 @@ func server(readyCh chan struct{}) { // create a handler that will be called when the server receives the RequestChannel frame (FrameTypeRequestChannel - 0x07) requestChannelHandler := rsocket.RequestChannel(func(msgs rx.Publisher) flux.Flux { return flux.Create(func(ctx context.Context, s flux.Sink) { - msgs.(flux.Flux).DoOnNext(func(elem payload.Payload) { + msgs.(flux.Flux).DoOnNext(func(elem payload.Payload) error { // for each payload in a flux stream respond with a word count s.Next(payload.NewString(fmt.Sprintf("%d", wordCount(elem.DataUTF8())), "")) + return nil }).DoOnComplete(func() { // signal completion of the response stream s.Complete() @@ -91,10 +92,11 @@ func client() { counter := 0 // register handler for RequestChannel - client.RequestChannel(f).DoOnNext(func(input payload.Payload) { + client.RequestChannel(f).DoOnNext(func(input payload.Payload) error { // print word count fmt.Println(strings[counter].DataUTF8(), ":", input.DataUTF8()) counter = counter + 1 + return nil }).DoOnComplete(func() { // will be called on successfull completion of the stream fmt.Println("Word counter ended.") From f7cae7194e909bcc4b4ea4b0a5e8a60e98f68e55 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Wed, 5 Aug 2020 23:33:16 +0800 Subject: [PATCH 19/26] add more functions for rx api. --- payload/payload.go | 18 +++++++++ rx/flux/flux.go | 4 ++ rx/flux/flux_test.go | 91 ++++++++++++++++++++++++++++++------------- rx/flux/proxy.go | 51 +++++++++++------------- rx/mono/mono.go | 2 + rx/mono/proxy.go | 49 +++++++++++++++-------- rx/mono/utils_test.go | 33 +++++++++++++--- 7 files changed, 169 insertions(+), 79 deletions(-) diff --git a/payload/payload.go b/payload/payload.go index c899d51..4a60531 100644 --- a/payload/payload.go +++ b/payload/payload.go @@ -1,6 +1,7 @@ package payload import ( + "bytes" "io/ioutil" "time" @@ -110,3 +111,20 @@ func MustNewFile(filename string, metadata []byte) Payload { } return foo } + +func Equal(a Payload, b Payload) bool { + if a == b { + return true + } + if !bytes.Equal(a.Data(), b.Data()) { + return false + } + + m1, ok1 := a.Metadata() + m2, ok2 := b.Metadata() + if ok1 != ok2 { + return false + } + + return bytes.Equal(m1, m2) +} diff --git a/rx/flux/flux.go b/rx/flux/flux.go index f066fd0..2749c1a 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -49,6 +49,10 @@ type Flux interface { SwitchOnFirst(FnSwitchOnFirst) Flux // SubscribeOn run subscribe, onSubscribe and request on a specified scheduler. SubscribeOn(scheduler.Scheduler) Flux + // SubscribeWithChan subscribe to this Flux and puts items/error into a chan. + SubscribeWithChan(ctx context.Context, values chan<- payload.Payload, err chan<- error) + // BlockToSlice subscribe Flux and save values into slice. + BlockToSlice(ctx context.Context, results *[]payload.Payload) error // Raw returns Native Flux in reactor-go. Raw() flux.Flux // BlockFirst subscribe to this Flux and block indefinitely until the upstream signals its first value or completes. diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index 6d00419..bf7e6a8 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -2,15 +2,15 @@ package flux_test import ( "context" - "errors" "fmt" "strconv" "testing" "time" "github.com/jjeffcaii/reactor-go" - nativeFlux "github.com/jjeffcaii/reactor-go/flux" + reactorFlux "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" @@ -74,7 +74,7 @@ func TestRaw(t *testing.T) { const total = 10 c := atomic.NewInt32(0) f := flux. - Raw(nativeFlux.Range(0, total).Map(func(v reactor.Any) (reactor.Any, error) { + Raw(reactorFlux.Range(0, total).Map(func(v reactor.Any) (reactor.Any, error) { return payload.NewString(fmt.Sprintf("data_%d", v.(int)), ""), nil })) last, err := f. @@ -368,23 +368,22 @@ func TestToChannel(t *testing.T) { f := flux.CreateFromChannel(payloads, err) - channel, chanerrors := f.ToChan(context.Background(), 0) + valueChan, errChan := f.ToChan(context.Background(), 0) var count int loop: for { select { - case _, o := <-channel: - if o { - count++ - } else { + case _, ok := <-valueChan: + if !ok { break loop } - case err := <-chanerrors: + count++ + case err := <-errChan: if err != nil { - t.Error(err) - break loop + assert.NoError(t, err) } + break loop } } @@ -400,29 +399,25 @@ func TestToChannelEmitError(t *testing.T) { defer close(err) for i := 1; i <= 10; i++ { - err <- errors.New("boom!") + err <- errors.New("boom") } }() f := flux.CreateFromChannel(payloads, err) - channel, chanerrors := f.ToChan(context.Background(), 0) + valChan, errChan := f.ToChan(context.Background(), 0) loop: for { select { - case _, o := <-channel: - if o { - t.Fail() - } else { - break loop - } - case err := <-chanerrors: - if err != nil { + case _, ok := <-valChan: + if !ok { break loop - } else { - t.Fail() } + assert.Fail(t, "should be unreachable") + case err := <-errChan: + assert.Error(t, err, "should return error") + break loop } } @@ -430,14 +425,54 @@ loop: func TestFlux_BlockSlice(t *testing.T) { const n = 10 - arr, err := flux. + arr, err := genRandomFlux(n).BlockSlice(context.Background()) + assert.NoError(t, err) + assert.Len(t, arr, n) +} + +func TestFlux_BlockToSlice(t *testing.T) { + results := make([]payload.Payload, 0) + const n = 10 + err := genRandomFlux(n).BlockToSlice(context.Background(), &results) + assert.NoError(t, err) + assert.Len(t, results, n) +} + +func TestFlux_SubscribeWithChan(t *testing.T) { + ch := make(chan payload.Payload) + err := make(chan error) + done := make(chan struct{}) + + const n = 10 + genRandomFlux(n). + DoFinally(func(s rx.SignalType) { + close(done) + }). + SubscribeOn(scheduler.Parallel()). + SubscribeWithChan(context.Background(), ch, err) + + var results []payload.Payload + +L: + for { + select { + case v := <-ch: + results = append(results, v) + case e := <-err: + assert.NoError(t, e) + case <-done: + break L + } + } + assert.Len(t, results, n) +} + +func genRandomFlux(n int) flux.Flux { + return flux. Create(func(ctx context.Context, s flux.Sink) { for i := 0; i < n; i++ { s.Next(payload.NewString("hello", strconv.Itoa(i))) } s.Complete() - }). - BlockSlice(context.Background()) - assert.NoError(t, err) - assert.Len(t, arr, n) + }) } diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index c0be448..19a1c57 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -7,7 +7,6 @@ import ( "github.com/jjeffcaii/reactor-go/flux" "github.com/jjeffcaii/reactor-go/scheduler" "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" ) @@ -20,14 +19,6 @@ func (p proxy) Raw() flux.Flux { return p.Flux } -func (p proxy) mustProcessor() flux.Processor { - processor, ok := p.Flux.(flux.Processor) - if !ok { - panic(errors.New("require flux.Processor")) - } - return processor -} - func (p proxy) Next(v payload.Payload) { p.mustProcessor().Next(v) } @@ -70,33 +61,21 @@ func (p proxy) DoOnNext(fn rx.FnOnNext) Flux { })) } -func (p proxy) ToChan(ctx context.Context, cap int) (c <-chan payload.Payload, e <-chan error) { +func (p proxy) ToChan(ctx context.Context, cap int) (<-chan payload.Payload, <-chan error) { if cap < 1 { cap = 1 } ch := make(chan payload.Payload, cap) err := make(chan error, 1) - p. - DoFinally(func(s rx.SignalType) { - if s == rx.SignalCancel { + p.Flux. + DoFinally(func(s reactor.SignalType) { + defer close(ch) + defer close(err) + if s == reactor.SignalTypeCancel { err <- reactor.ErrSubscribeCancelled } - close(ch) - close(err) }). - Subscribe(ctx, - rx.OnNext(func(v payload.Payload) error { - if _, ok := v.(core.Frame); ok { - ch <- payload.Clone(v) - } else { - ch <- v - } - return nil - }), - rx.OnError(func(e error) { - err <- e - }), - ) + SubscribeWithChan(ctx, ch, err) return ch, err } @@ -122,6 +101,14 @@ func (p proxy) BlockLast(ctx context.Context) (last payload.Payload, err error) return } +func (p proxy) SubscribeWithChan(ctx context.Context, payloads chan<- payload.Payload, err chan<- error) { + p.Flux.SubscribeWithChan(ctx, payloads, err) +} + +func (p proxy) BlockToSlice(ctx context.Context, results *[]payload.Payload) error { + return p.Flux.BlockToSlice(ctx, results) +} + func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err error) { done := make(chan struct{}) p.Flux. @@ -198,6 +185,14 @@ func (p proxy) SubscribeWith(ctx context.Context, s rx.Subscriber) { p.Flux.SubscribeWith(ctx, sub) } +func (p proxy) mustProcessor() flux.Processor { + processor, ok := p.Flux.(flux.Processor) + if !ok { + panic(errors.New("require flux.Processor")) + } + return processor +} + type sinkProxy struct { flux.Sink } diff --git a/rx/mono/mono.go b/rx/mono/mono.go index b4e74f4..e1052b0 100644 --- a/rx/mono/mono.go +++ b/rx/mono/mono.go @@ -27,6 +27,8 @@ type Mono interface { DoOnSubscribe(rx.FnOnSubscribe) Mono // SubscribeOn customize a Scheduler running Subscribe, OnSubscribe and Request. SubscribeOn(scheduler.Scheduler) Mono + // SubscribeWithChan subscribe to this Mono and puts item/error into channels. + SubscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error) // Block blocks Mono and returns data and error. Block(context.Context) (payload.Payload, error) //SwitchIfEmpty switch to an alternative Publisher if this Mono is completed without any data. diff --git a/rx/mono/proxy.go b/rx/mono/proxy.go index 47560d1..3cdafd2 100644 --- a/rx/mono/proxy.go +++ b/rx/mono/proxy.go @@ -35,29 +35,44 @@ func (p proxy) Error(e error) { p.mustProcessor().Error(e) } -func (p proxy) ToChan(ctx context.Context) (c <-chan payload.Payload, e <-chan error) { - errorChannel := make(chan error, 1) - payloadChannel := make(chan payload.Payload, 1) - p. - DoOnSuccess(func(input payload.Payload) error { - payloadChannel <- input - return nil - }). - DoOnError(func(e error) { - errorChannel <- e - }). - DoFinally(func(s rx.SignalType) { - close(payloadChannel) - close(errorChannel) - }). - Subscribe(ctx) - return payloadChannel, errorChannel +func (p proxy) ToChan(ctx context.Context) (<-chan payload.Payload, <-chan error) { + value := make(chan payload.Payload, 1) + err := make(chan error, 1) + p.subscribeWithChan(ctx, value, err, true) + return value, err } func (p proxy) SubscribeOn(sc scheduler.Scheduler) Mono { return newProxy(p.Mono.SubscribeOn(sc)) } +func (p proxy) subscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error, autoClose bool) { + p.Mono. + DoFinally(func(s reactor.SignalType) { + if autoClose { + defer close(valueChan) + defer close(errChan) + } + if s == reactor.SignalTypeCancel { + errChan <- reactor.ErrSubscribeCancelled + } + }). + Subscribe( + ctx, + reactor.OnNext(func(v reactor.Any) error { + valueChan <- v.(payload.Payload) + return nil + }), + reactor.OnError(func(e error) { + errChan <- e + }), + ) +} + +func (p proxy) SubscribeWithChan(ctx context.Context, valueChan chan<- payload.Payload, errChan chan<- error) { + p.subscribeWithChan(ctx, valueChan, errChan, false) +} + func (p proxy) Block(ctx context.Context) (pa payload.Payload, err error) { v, err := p.Mono.Block(ctx) if err != nil { diff --git a/rx/mono/utils_test.go b/rx/mono/utils_test.go index 390bff5..5f0d326 100644 --- a/rx/mono/utils_test.go +++ b/rx/mono/utils_test.go @@ -5,6 +5,7 @@ import ( "testing" rsMono "github.com/jjeffcaii/reactor-go/mono" + "github.com/jjeffcaii/reactor-go/scheduler" "github.com/pkg/errors" "github.com/rsocket/rsocket-go/payload" @@ -74,19 +75,19 @@ func TestToChannel(t *testing.T) { payloads <- p }() - channel, chanerrors := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) + valueChan, errChan := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) loop: for { select { - case p, ok := <-channel: + case p, ok := <-valueChan: if !ok { break loop } assert.Equal(t, "data", p.DataUTF8()) md, _ := p.MetadataUTF8() assert.Equal(t, "metadata", md) - case err := <-chanerrors: + case err := <-errChan: if err != nil { assert.NoError(t, err) } @@ -109,17 +110,17 @@ func TestToChannelEmitError(t *testing.T) { } }() - channel, chanerrors := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) + valueChan, errChan := mono.CreateFromChannel(payloads, err).ToChan(context.Background()) loop: for { select { - case _, ok := <-channel: + case _, ok := <-valueChan: if !ok { break loop } assert.Fail(t, "should never receive anything") - case err := <-chanerrors: + case err := <-errChan: if err != nil { break loop } @@ -135,3 +136,23 @@ func TestRaw(t *testing.T) { assert.NoError(t, err) assert.Equal(t, fakePayload, res) } + +func TestSubscribeWithChan(t *testing.T) { + valueChan := make(chan payload.Payload) + errChan := make(chan error) + + fakePayload := payload.NewString("fake data", "fake metadata") + + mono. + Create(func(ctx context.Context, sink mono.Sink) { + sink.Success(fakePayload) + }). + SubscribeOn(scheduler.Parallel()). + SubscribeWithChan(context.Background(), valueChan, errChan) + select { + case next := <-valueChan: + assert.True(t, payload.Equal(fakePayload, next), "result doesn't match") + case err := <-errChan: + assert.NoError(t, err, "should not return error") + } +} From cf980e5482ff941c5db3bd5c034984aa37ef30f8 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Thu, 6 Aug 2020 23:08:04 +0800 Subject: [PATCH 20/26] Fix. --- balancer/group_test.go | 41 ++++--- balancer/round_robin_test.go | 4 +- client.go | 9 +- cmd/rsocket-cli/runner.go | 185 +++++++++++++++------------- cmd/rsocket-cli/uri.go | 96 --------------- cmd/rsocket-cli/uri_test.go | 25 ---- justfile | 1 - rsocket_example_test.go | 15 +-- rsocket_test.go | 27 +++-- rx/flux/flux.go | 4 +- rx/flux/flux_test.go | 16 +-- rx/flux/proxy.go | 23 ++-- rx/mono/mono.go | 2 +- server.go | 8 +- transporter.go | 225 ++++++++++++++++++----------------- transporter_test.go | 59 +++++++-- 16 files changed, 340 insertions(+), 400 deletions(-) delete mode 100644 cmd/rsocket-cli/uri.go delete mode 100644 cmd/rsocket-cli/uri_test.go diff --git a/balancer/group_test.go b/balancer/group_test.go index 5dcd6f2..207aff3 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -8,19 +8,13 @@ import ( "testing" "time" - . "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go" . "github.com/rsocket/rsocket-go/balancer" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx/mono" "github.com/stretchr/testify/require" ) -var tp Transporter - -func init() { - tp = Tcp().HostAndPort("127.0.0.1", 7878).Build() -} - func ExampleNewGroup() { group := NewGroup(func() Balancer { return NewRoundRobinBalancer() @@ -29,24 +23,26 @@ func ExampleNewGroup() { _ = group.Close() }() // Create a broker with resume. - err := Receive(). - Resume(WithServerResumeSessionDuration(10 * time.Second)). - Acceptor(func(setup payload.SetupPayload, sendingSocket CloseableRSocket) (RSocket, error) { + err := rsocket.Receive(). + Resume(rsocket.WithServerResumeSessionDuration(10 * time.Second)). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { // Register service using Setup Metadata as service ID. if serviceID, ok := setup.MetadataUTF8(); ok { group.Get(serviceID).Put(sendingSocket) } // Proxy requests by group. - return NewAbstractSocket(RequestResponse(func(msg payload.Payload) mono.Mono { + return rsocket.NewAbstractSocket(rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { requestServiceID, ok := msg.MetadataUTF8() if !ok { panic(errors.New("missing service ID in metadata")) } fmt.Println("[broker] redirect request to service", requestServiceID) - return group.Get(requestServiceID).MustNext(context.Background()).RequestResponse(msg) + upstream := group.Get(requestServiceID).MustNext(context.Background()) + fmt.Println("[broker] choose upstream:", upstream) + return upstream.RequestResponse(msg) })), nil }). - Transport(tp). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). Serve(context.Background()) if err != nil { panic(err) @@ -60,20 +56,23 @@ func TestServiceSubscribe(t *testing.T) { // Waiting broker up by sleeping 200 ms. time.Sleep(200 * time.Millisecond) + tp := rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() + // Deploy MD5 service. go func() { done := make(chan struct{}) - cli, err := Connect(). + cli, err := rsocket.Connect(). OnClose(func(err error) { close(done) }). SetupPayload(payload.NewString("This is a Service Publisher!", "md5")). - Acceptor(func(socket RSocket) RSocket { - return NewAbstractSocket(RequestResponse(func(msg payload.Payload) mono.Mono { - result := payload.NewString(fmt.Sprintf("%02x", md5.Sum(msg.Data())), "MD5 RESULT") - fmt.Println("[publisher] accept MD5 request:", msg.DataUTF8()) - return mono.Just(result) - })) + Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { + return rsocket.NewAbstractSocket( + rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + result := payload.NewString(fmt.Sprintf("%02x", md5.Sum(msg.Data())), "MD5 RESULT") + fmt.Println("[publisher] accept MD5 request:", msg.DataUTF8()) + return mono.Just(result) + })) }). Transport(tp). Start(context.Background()) @@ -87,7 +86,7 @@ func TestServiceSubscribe(t *testing.T) { }() // Create a client and request md5 service. - cli, err := Connect(). + cli, err := rsocket.Connect(). SetupPayload(payload.NewString("This is a Subscriber", "")). Transport(tp). Start(context.Background()) diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index bd9df35..f00b229 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -29,7 +29,7 @@ func startServer(ctx context.Context, port int, counter *sync.Map) { }), ), nil }). - Transport(rsocket.Tcp().HostAndPort("127.0.0.1", port).Build()). + Transport(rsocket.TcpServer().SetHostAndPort("127.0.0.1", port).Build()). Serve(ctx) } @@ -60,7 +60,7 @@ func TestRoundRobin(t *testing.T) { for i := 0; i < len(ports); i++ { client, err := rsocket.Connect(). - Transport(rsocket.Tcp().HostAndPort("127.0.0.1", ports[i]).Build()). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", ports[i]).Build()). Start(context.Background()) assert.NoError(t, err) b.PutLabel(fmt.Sprintf("test-client-%d", ports[i]), client) diff --git a/client.go b/client.go index f0d7cce..d08866d 100644 --- a/client.go +++ b/client.go @@ -64,13 +64,12 @@ type ClientBuilder interface { } type ToClientStarter interface { - // Transport set Transport for current RSocket client. - // URI is used to create RSocket Transport: + // Transport set generator func for current RSocket client. // Example: // "tcp://127.0.0.1:7878" means a TCP RSocket transport. // "ws://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport. // "wss://127.0.0.1:8080/a/b/c" means a Websocket RSocket transport with HTTPS. - Transport(Transporter) ClientStarter + Transport(transport.ClientTransportFunc) ClientStarter } // ToClientStarter is used to build a RSocket client with custom Transport string. @@ -153,8 +152,8 @@ func (p *clientBuilder) Acceptor(acceptor ClientSocketAcceptor) ToClientStarter return p } -func (p *clientBuilder) Transport(support Transporter) ClientStarter { - p.tpGen = support.Client() +func (p *clientBuilder) Transport(t transport.ClientTransportFunc) ClientStarter { + p.tpGen = t return p } diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index 0f66efa..e7f97fb 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -4,16 +4,18 @@ import ( "bufio" "context" "encoding/json" - "errors" "fmt" "io/ioutil" + "net/http" "net/url" "os" "strconv" "strings" "time" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/logger" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" @@ -45,17 +47,17 @@ type Runner struct { N int Resume bool URI string - wsHeaders map[string][]string + wsHeaders http.Header } -func (p *Runner) preflight() (err error) { - if p.Debug { +func (r *Runner) preflight() (err error) { + if r.Debug { logger.SetLevel(logger.LevelDebug) } - headers := p.Headers.Value() + headers := r.Headers.Value() - if len(headers) > 0 && len(p.Metadata) > 0 { + if len(headers) > 0 && len(r.Metadata) > 0 { return errConflictHeadersAndMetadata } if len(headers) > 0 { @@ -70,65 +72,55 @@ func (p *Runner) preflight() (err error) { headers[strings.TrimSpace(k)] = strings.TrimSpace(v) } bs, _ := json.Marshal(headers) - p.Metadata = string(bs) + r.Metadata = string(bs) } - - tpHeaders := p.TransportHeaders.Value() - if len(tpHeaders) > 0 { - headers := make(map[string][]string) - for _, it := range tpHeaders { + if v := r.TransportHeaders.Value(); len(v) > 0 { + r.wsHeaders = make(http.Header) + for _, it := range v { idx := strings.Index(it, ":") if idx < 0 { return fmt.Errorf("invalid transport header: %s", it) } k := strings.TrimSpace(it[:idx]) v := strings.TrimSpace(it[idx+1:]) - headers[k] = append(headers[k], v) + r.wsHeaders.Add(k, v) } - p.wsHeaders = headers } return } -func (p *Runner) Run() error { - if err := p.preflight(); err != nil { +func (r *Runner) Run() error { + if err := r.preflight(); err != nil { return err } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if p.ServerMode { - return p.runServerMode(ctx) + if r.ServerMode { + return r.runServerMode(ctx) } - return p.runClientMode(ctx) + return r.runClientMode(ctx) } -func (p *Runner) runClientMode(ctx context.Context) (err error) { +func (r *Runner) runClientMode(ctx context.Context) (err error) { cb := rsocket.Connect() - if p.Resume { + if r.Resume { cb = cb.Resume() } - setupData, err := p.readData(p.Setup) + setupData, err := r.readData(r.Setup) if err != nil { return } setupPayload := payload.New(setupData, nil) - sendingPayloads := p.createPayload() + sendingPayloads := r.createPayload() - tp, err := makeTransport(p.URI) + tp, err := r.newClientTransport() if err != nil { return } - - // TODO: - - //if ws, ok := tp.(*rsocket.wsTransporter); ok { - // ws.Header(p.wsHeaders) - //} - c, err := cb. - DataMimeType(p.DataFormat). - MetadataMimeType(p.MetadataFormat). + DataMimeType(r.DataFormat). + MetadataMimeType(r.MetadataFormat). SetupPayload(setupPayload). Transport(tp). Start(ctx) @@ -139,30 +131,30 @@ func (p *Runner) runClientMode(ctx context.Context) (err error) { _ = c.Close() }() - for i := 0; i < p.Ops; i++ { + for i := 0; i < r.Ops; i++ { if i > 0 { logger.Infof("\n") } var first payload.Payload - if !p.Channel { + if !r.Channel { first, err = sendingPayloads.BlockFirst(ctx) if err != nil { return } } - if p.Request { - err = p.execRequestResponse(ctx, c, first) - } else if p.FNF { - err = p.execFireAndForget(ctx, c, first) - } else if p.Stream { - err = p.execRequestStream(ctx, c, first) - } else if p.Channel { - err = p.execRequestChannel(ctx, c, sendingPayloads) - } else if p.MetadataPush { - err = p.execMetadataPush(ctx, c, first) + if r.Request { + err = r.execRequestResponse(ctx, c, first) + } else if r.FNF { + err = r.execFireAndForget(ctx, c, first) + } else if r.Stream { + err = r.execRequestStream(ctx, c, first) + } else if r.Channel { + err = r.execRequestChannel(ctx, c, sendingPayloads) + } else if r.MetadataPush { + err = r.execMetadataPush(ctx, c, first) } else { - err = p.execRequestResponse(ctx, c, first) + err = r.execRequestResponse(ctx, c, first) } if err != nil { break @@ -171,38 +163,38 @@ func (p *Runner) runClientMode(ctx context.Context) (err error) { return } -func (p *Runner) runServerMode(ctx context.Context) error { +func (r *Runner) runServerMode(ctx context.Context) error { var sb rsocket.ServerBuilder - if p.Resume { + if r.Resume { sb = rsocket.Receive().Resume() } else { sb = rsocket.Receive() } ch := make(chan error) - tp, err := makeTransport(p.URI) + tp, err := r.newServerTransport() if err != nil { return err } go func() { - sendingPayloads := p.createPayload() + sendingPayloads := r.createPayload() ch <- sb. Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { var options []rsocket.OptAbstractSocket options = append(options, rsocket.RequestStream(func(message payload.Payload) flux.Flux { - p.showPayload(message) + r.showPayload(message) return sendingPayloads })) options = append(options, rsocket.RequestChannel(func(messages rx.Publisher) flux.Flux { messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) error { - p.showPayload(input) + r.showPayload(input) return nil })) return sendingPayloads })) options = append(options, rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - p.showPayload(msg) + r.showPayload(msg) return mono.Create(func(i context.Context, sink mono.Sink) { first, err := sendingPayloads.BlockFirst(i) if err != nil { @@ -213,7 +205,7 @@ func (p *Runner) runServerMode(ctx context.Context) error { }) })) options = append(options, rsocket.FireAndForget(func(msg payload.Payload) { - p.showPayload(msg) + r.showPayload(msg) })) options = append(options, rsocket.MetadataPush(func(msg payload.Payload) { metadata, _ := msg.MetadataUTF8() @@ -228,86 +220,86 @@ func (p *Runner) runServerMode(ctx context.Context) error { return <-ch } -func (p *Runner) execMetadataPush(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execMetadataPush(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { c.MetadataPush(send) m, _ := send.MetadataUTF8() logger.Infof("%s\n", m) return } -func (p *Runner) execFireAndForget(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execFireAndForget(_ context.Context, c rsocket.Client, send payload.Payload) (err error) { c.FireAndForget(send) return } -func (p *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send payload.Payload) (err error) { +func (r *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send payload.Payload) (err error) { res, err := c.RequestResponse(send).Block(ctx) if err != nil { return } - p.showPayload(res) + r.showPayload(res) return } -func (p *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error { +func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error { var f flux.Flux - if p.N < rx.RequestMax { - f = c.RequestChannel(send).Take(p.N) + if r.N < rx.RequestMax { + f = c.RequestChannel(send).Take(r.N) } else { f = c.RequestChannel(send) } - return p.printFlux(ctx, f) + return r.printFlux(ctx, f) } -func (p *Runner) execRequestStream(ctx context.Context, c rsocket.Client, send payload.Payload) error { +func (r *Runner) execRequestStream(ctx context.Context, c rsocket.Client, send payload.Payload) error { var f flux.Flux - if p.N < rx.RequestMax { - f = c.RequestStream(send).Take(p.N) + if r.N < rx.RequestMax { + f = c.RequestStream(send).Take(r.N) } else { f = c.RequestStream(send) } - return p.printFlux(ctx, f) + return r.printFlux(ctx, f) } -func (p *Runner) printFlux(ctx context.Context, f flux.Flux) (err error) { +func (r *Runner) printFlux(ctx context.Context, f flux.Flux) (err error) { _, err = f. DoOnNext(func(input payload.Payload) error { - p.showPayload(input) + r.showPayload(input) return nil }). BlockLast(ctx) return } -func (p *Runner) showPayload(pa payload.Payload) { +func (r *Runner) showPayload(pa payload.Payload) { logger.Infof("%s\n", pa.DataUTF8()) } -func (p *Runner) createPayload() flux.Flux { +func (r *Runner) createPayload() flux.Flux { var md []byte - if strings.HasPrefix(p.Metadata, "@") { + if strings.HasPrefix(r.Metadata, "@") { var err error - md, err = ioutil.ReadFile(p.Metadata[1:]) + md, err = ioutil.ReadFile(r.Metadata[1:]) if err != nil { return flux.Error(err) } } else { - md = []byte(p.Metadata) + md = []byte(r.Metadata) } - if p.Input == "-" { + if r.Input == "-" { fmt.Println("Type commands to send to the server......") reader := bufio.NewReader(os.Stdin) text, _ := reader.ReadString('\n') return flux.Just(payload.New([]byte(strings.Trim(text, "\n")), md)) } - if !strings.HasPrefix(p.Input, "@") { - return flux.Just(payload.New([]byte(p.Input), md)) + if !strings.HasPrefix(r.Input, "@") { + return flux.Just(payload.New([]byte(r.Input), md)) } return flux.Create(func(ctx context.Context, s flux.Sink) { - f, err := os.Open(p.Input[1:]) + f, err := os.Open(r.Input[1:]) if err != nil { fmt.Println("error:", err) s.Error(err) @@ -331,7 +323,7 @@ func (p *Runner) createPayload() flux.Flux { }) } -func (p *Runner) readData(input string) (data []byte, err error) { +func (r *Runner) readData(input string) (data []byte, err error) { switch { case strings.HasPrefix(input, "@"): data, err = ioutil.ReadFile(input[1:]) @@ -341,8 +333,8 @@ func (p *Runner) readData(input string) (data []byte, err error) { return } -func makeTransport(s string) (rsocket.Transporter, error) { - u, err := url.Parse(s) +func (r *Runner) newClientTransport() (transport.ClientTransportFunc, error) { + u, err := url.Parse(r.URI) if err != nil { return nil, err } @@ -352,13 +344,40 @@ func makeTransport(s string) (rsocket.Transporter, error) { if err != nil { return nil, err } - return rsocket.Tcp().HostAndPort(u.Hostname(), port).Build(), nil + return rsocket.TcpClient().SetHostAndPort(u.Hostname(), port).Build(), nil case "unix": - return rsocket.Unix().Path(u.Hostname()).Build(), nil + return rsocket.UnixClient().SetPath(u.Hostname()).Build(), nil case "ws", "wss": - return rsocket.Websocket().Url(s).Build(), nil + return rsocket.WebsocketClient().SetUrl(r.URI).SetHeader(r.wsHeaders).Build(), nil default: return nil, fmt.Errorf("invalid transport %s", u.Scheme) } +} +func (r *Runner) newServerTransport() (t transport.ServerTransportFunc, err error) { + u, err := url.Parse(r.URI) + if err != nil { + return + } + switch u.Scheme { + case "tcp": + port, err := strconv.Atoi(u.Port()) + if err != nil { + return + } + t = rsocket.TcpServer().SetHostAndPort(u.Hostname(), port).Build() + case "unix": + t = rsocket.UnixServer().SetPath(u.Hostname()).Build() + case "ws", "wss": + var addr string + if port := u.Port(); port != "" { + addr = fmt.Sprintf("%s:%s", u.Hostname(), port) + } else { + addr = fmt.Sprintf("%s:%d", u.Hostname(), rsocket.DefaultPort) + } + t = rsocket.WebsocketServer().SetAddr(addr).SetPath(u.EscapedPath()).Build() + default: + err = errors.Errorf("invalid transport %s", u.Scheme) + } + return } diff --git a/cmd/rsocket-cli/uri.go b/cmd/rsocket-cli/uri.go deleted file mode 100644 index 7d63410..0000000 --- a/cmd/rsocket-cli/uri.go +++ /dev/null @@ -1,96 +0,0 @@ -package main - -import ( - "crypto/tls" - "net/url" - "strings" - - "github.com/pkg/errors" - "github.com/rsocket/rsocket-go/core/transport" -) - -const ( - schemaUNIX = "unix" - schemaTCP = "tcp" - schemaWebsocket = "ws" - schemaWebsocketSecure = "wss" -) - -// URI represents a URI of RSocket transport. -type URI url.URL - -var tlsInsecure = &tls.Config{ - InsecureSkipVerify: true, -} - -// IsWebsocket returns true if current uri is websocket. -func (p *URI) IsWebsocket() bool { - switch strings.ToLower(p.Scheme) { - case schemaWebsocket, schemaWebsocketSecure: - return true - default: - return false - } -} - -// MakeClientTransport creates a new client-side transport. -func (p *URI) MakeClientTransport(tc *tls.Config, headers map[string][]string) (*transport.Transport, error) { - switch strings.ToLower(p.Scheme) { - case schemaTCP: - return transport.NewTcpClientTransportWithAddr(schemaTCP, p.Host, tc) - case schemaWebsocket: - if tc == nil { - return transport.NewWebsocketClientTransport(p.pp().String(), nil, headers) - } - var clone = (url.URL)(*p) - clone.Scheme = "wss" - return transport.NewWebsocketClientTransport(clone.String(), tc, headers) - case schemaWebsocketSecure: - if tc == nil { - tc = tlsInsecure - } - return transport.NewWebsocketClientTransport(p.pp().String(), tc, headers) - case schemaUNIX: - return transport.NewTcpClientTransportWithAddr(schemaUNIX, p.Path, tc) - default: - return nil, errors.Errorf("unsupported transport url: %s", p.pp().String()) - } -} - -// MakeServerTransport creates a new server-side transport. -func (p *URI) MakeServerTransport(c *tls.Config) (tp transport.ServerTransport, err error) { - switch strings.ToLower(p.Scheme) { - case schemaTCP: - tp = transport.NewTcpServerTransportWithAddr(schemaTCP, p.Host, c) - case schemaWebsocket: - tp = transport.NewWebsocketServerTransportWithAddr(p.Host, p.Path, c) - case schemaWebsocketSecure: - if c == nil { - err = errors.Errorf("missing TLS Config for proto %s", schemaWebsocketSecure) - return - } - tp = transport.NewWebsocketServerTransportWithAddr(p.Host, p.Path, c) - case schemaUNIX: - tp = transport.NewTcpServerTransportWithAddr(schemaUNIX, p.Path, c) - default: - err = errors.Errorf("unsupported transport url: %s", p.pp().String()) - } - return -} - -func (p *URI) String() string { - return p.pp().String() -} - -func (p *URI) pp() *url.URL { - return (*url.URL)(p) -} - -// ParseURI parse URI string and returns a URI. -func ParseURI(rawUrl string) (*URI, error) { - u, err := url.Parse(rawUrl) - if err != nil { - return nil, errors.Wrapf(err, "parse url failed: %s", rawUrl) - } - return (*URI)(u), nil -} diff --git a/cmd/rsocket-cli/uri_test.go b/cmd/rsocket-cli/uri_test.go deleted file mode 100644 index 3fa9312..0000000 --- a/cmd/rsocket-cli/uri_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package main - -import ( - "log" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseUTI(t *testing.T) { - uri, err := ParseURI("tcp://127.0.0.1:8080") - assert.NoError(t, err, "bad URI") - log.Println(uri) -} - -func TestName(t *testing.T) { - //u, err := url.Parse("unix:///tmp/rsocket.sock") - u, err := url.Parse("tcp://127.0.0.1:8080") - require.NoError(t, err, "bad parse") - log.Println("schema:", u.Scheme) - log.Println("host:", u.Host) - log.Println("path:", u.Path) -} diff --git a/justfile b/justfile index e0efec4..e742371 100644 --- a/justfile +++ b/justfile @@ -4,7 +4,6 @@ lint: golangci-lint run ./... test: go test -count=1 -coverprofile=coverage.out \ - ./balancer/... \ ./core/... \ ./extension/... \ ./internal/... \ diff --git a/rsocket_example_test.go b/rsocket_example_test.go index e926924..4b3bbae 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -16,10 +16,7 @@ import ( func Example() { // Serve a server - tp := rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() err := rsocket.Receive(). - Resume(). // Enable RESUME - //Lease(). Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { return rsocket.NewAbstractSocket( rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { @@ -28,7 +25,7 @@ func Example() { }), ), nil }). - Transport(tp). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). Serve(context.Background()) if err != nil { panic(err) @@ -37,14 +34,12 @@ func Example() { // Connect to a server. cli, err := rsocket.Connect(). SetupPayload(payload.NewString("Hello World", "From Golang")). - Transport(tp). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()). Start(context.Background()) if err != nil { panic(err) } - defer func() { - _ = cli.Close() - }() + defer cli.Close() cli.RequestResponse(payload.NewString("Ping", time.Now().String())). DoOnSuccess(func(elem payload.Payload) error { log.Println("incoming response:", elem) @@ -95,7 +90,7 @@ func ExampleReceive() { }), ), nil }). - Transport(rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build()). + Transport(rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).Build()). Serve(context.Background()) panic(err) } @@ -113,7 +108,7 @@ func ExampleConnect() { }), ) }). - Transport(rsocket.Tcp().Addr("127.0.0.1:7878").Build()). + Transport(rsocket.TcpClient().SetAddr("127.0.0.1:7878").Build()). Start(context.Background()) if err != nil { panic(err) diff --git a/rsocket_test.go b/rsocket_test.go index ab46bce..a8cb397 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -7,6 +7,7 @@ import ( "testing" . "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" . "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" @@ -23,16 +24,26 @@ const ( var testData = "Hello World!" func TestSuite(t *testing.T) { - transports := map[string]Transporter{ - "tcp": Tcp().Addr("127.0.0.1:7878").Build(), - "websocket": Websocket().Url("ws://127.0.0.1:8080/test").Build(), + m := []string{ + "tcp", + "websocket", } - for k, v := range transports { - testAll(t, k, v) + c := []transport.ClientTransportFunc{ + TcpClient().SetHostAndPort("127.0.0.1", 7878).Build(), + WebsocketClient().SetUrl("ws://127.0.0.1:8080/test").Build(), } + s := []transport.ServerTransportFunc{ + TcpServer().SetAddr(":7878").Build(), + WebsocketServer().SetAddr("127.0.0.1:8080").SetPath("/test").Build(), + } + + for i := 0; i < len(m); i++ { + testAll(t, m[i], c[i], s[i]) + } + } -func testAll(t *testing.T, proto string, tp Transporter) { +func testAll(t *testing.T, proto string, clientTp transport.ClientTransportFunc, serverTp transport.ServerTransportFunc) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -95,7 +106,7 @@ func testAll(t *testing.T, proto string, tp Transporter) { }), ), nil }). - Transport(tp). + Transport(serverTp). Serve(ctx) fmt.Println("SERVER STOPPED!!!!!") if err != nil { @@ -109,7 +120,7 @@ func testAll(t *testing.T, proto string, tp Transporter) { cli, err := Connect(). Fragment(192). SetupPayload(NewString(setupData, setupMetadata)). - Transport(tp). + Transport(clientTp). Start(context.Background()) assert.NoError(t, err, "connect failed") defer func() { diff --git a/rx/flux/flux.go b/rx/flux/flux.go index 2749c1a..9f5a10f 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -51,9 +51,7 @@ type Flux interface { SubscribeOn(scheduler.Scheduler) Flux // SubscribeWithChan subscribe to this Flux and puts items/error into a chan. SubscribeWithChan(ctx context.Context, values chan<- payload.Payload, err chan<- error) - // BlockToSlice subscribe Flux and save values into slice. - BlockToSlice(ctx context.Context, results *[]payload.Payload) error - // Raw returns Native Flux in reactor-go. + // Raw returns low-level reactor.Flux which defined in reactor-go library. Raw() flux.Flux // BlockFirst subscribe to this Flux and block indefinitely until the upstream signals its first value or completes. // Returns that value, error if Flux completes error, or nil if the Flux completes empty. diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index bf7e6a8..f9f8b55 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -379,11 +379,11 @@ loop: break loop } count++ - case err := <-errChan: - if err != nil { - assert.NoError(t, err) + case err, ok := <-errChan: + if !ok { + break loop } - break loop + assert.NoError(t, err) } } @@ -430,14 +430,6 @@ func TestFlux_BlockSlice(t *testing.T) { assert.Len(t, arr, n) } -func TestFlux_BlockToSlice(t *testing.T) { - results := make([]payload.Payload, 0) - const n = 10 - err := genRandomFlux(n).BlockToSlice(context.Background(), &results) - assert.NoError(t, err) - assert.Len(t, results, n) -} - func TestFlux_SubscribeWithChan(t *testing.T) { ch := make(chan payload.Payload) err := make(chan error) diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index 19a1c57..b05ed09 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -69,8 +69,10 @@ func (p proxy) ToChan(ctx context.Context, cap int) (<-chan payload.Payload, <-c err := make(chan error, 1) p.Flux. DoFinally(func(s reactor.SignalType) { - defer close(ch) - defer close(err) + defer func() { + close(ch) + close(err) + }() if s == reactor.SignalTypeCancel { err <- reactor.ErrSubscribeCancelled } @@ -95,9 +97,10 @@ func (p proxy) BlockLast(ctx context.Context) (last payload.Payload, err error) if err != nil { return } - if v != nil { - last = v.(payload.Payload) + if v == nil { + return } + last = v.(payload.Payload) return } @@ -105,18 +108,14 @@ func (p proxy) SubscribeWithChan(ctx context.Context, payloads chan<- payload.Pa p.Flux.SubscribeWithChan(ctx, payloads, err) } -func (p proxy) BlockToSlice(ctx context.Context, results *[]payload.Payload) error { - return p.Flux.BlockToSlice(ctx, results) -} - func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err error) { done := make(chan struct{}) p.Flux. DoFinally(func(s reactor.SignalType) { - close(done) - }). - DoOnCancel(func() { - err = reactor.ErrSubscribeCancelled + defer close(done) + if s == reactor.SignalTypeCancel { + err = reactor.ErrSubscribeCancelled + } }). Subscribe( ctx, diff --git a/rx/mono/mono.go b/rx/mono/mono.go index e1052b0..a6b9abf 100644 --- a/rx/mono/mono.go +++ b/rx/mono/mono.go @@ -33,7 +33,7 @@ type Mono interface { Block(context.Context) (payload.Payload, error) //SwitchIfEmpty switch to an alternative Publisher if this Mono is completed without any data. SwitchIfEmpty(alternative Mono) Mono - // Raw returns low-level Mono which defined in upstream reactor library. + // Raw returns low-level reactor.Mono which defined in reactor-go library. Raw() mono.Mono // ToChan subscribe Mono and puts items into a chan. // It also puts errors into another chan. diff --git a/server.go b/server.go index 473c9b7..fc0e65b 100644 --- a/server.go +++ b/server.go @@ -44,8 +44,8 @@ type ( // ToServerStarter is used to build a RSocket server with custom Transport string. ToServerStarter interface { - // Transport specify transport string. - Transport(t Transporter) Start + // Transport specify transport generator func. + Transport(t transport.ServerTransportFunc) Start } // Start start a RSocket server. @@ -113,8 +113,8 @@ func (p *server) Acceptor(acceptor ServerAcceptor) ToServerStarter { return p } -func (p *server) Transport(t Transporter) Start { - p.tp = t.Server() +func (p *server) Transport(t transport.ServerTransportFunc) Start { + p.tp = t return p } diff --git a/transporter.go b/transporter.go index e8632f15..9d1510b 100644 --- a/transporter.go +++ b/transporter.go @@ -5,177 +5,186 @@ import ( "crypto/tls" "fmt" "net/http" - "net/url" "os" - "github.com/pkg/errors" "github.com/rsocket/rsocket-go/core/transport" ) -type Transporter interface { - Client() transport.ClientTransportFunc - Server() transport.ServerTransportFunc +const DefaultUnixSockPath = "/var/run/rsocket.sock" +const DefaultPort = 7878 + +type TcpClientBuilder struct { + addr string + tlsCfg *tls.Config +} + +type TcpServerBuilder struct { + addr string + tlsCfg *tls.Config +} + +type WebsocketClientBuilder struct { + url string + tlsCfg *tls.Config + header http.Header +} + +type WebsocketServerBuilder struct { + addr string + path string + tlsConfig *tls.Config +} + +type UnixClientBuilder struct { + path string } -type tcpTransporter struct { - addr string - tls *tls.Config +type UnixServerBuilder struct { + path string } -type TcpTransporterBuilder struct { - opts []func(*tcpTransporter) +func (us *UnixServerBuilder) SetPath(path string) *UnixServerBuilder { + us.path = path + return us } -func (t *tcpTransporter) Server() transport.ServerTransportFunc { +func (us *UnixServerBuilder) Build() transport.ServerTransportFunc { return func(ctx context.Context) (transport.ServerTransport, error) { - return transport.NewTcpServerTransportWithAddr("tcp", t.addr, t.tls), nil + if _, err := os.Stat(us.path); !os.IsNotExist(err) { + return nil, err + } + return transport.NewTcpServerTransportWithAddr("unix", us.path, nil), nil } } -func (t *tcpTransporter) Client() transport.ClientTransportFunc { +func (uc *UnixClientBuilder) SetPath(path string) *UnixClientBuilder { + uc.path = path + return uc +} + +func (uc UnixClientBuilder) Build() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTcpClientTransportWithAddr("tcp", t.addr, t.tls) + return transport.NewTcpClientTransportWithAddr("unix", uc.path, nil) } } -func (t *TcpTransporterBuilder) Addr(addr string) *TcpTransporterBuilder { - t.opts = append(t.opts, func(transporter *tcpTransporter) { - transporter.addr = addr - }) - return t +func (ws *WebsocketServerBuilder) SetAddr(addr string) *WebsocketServerBuilder { + ws.addr = addr + return ws } -func (t *TcpTransporterBuilder) HostAndPort(host string, port int) *TcpTransporterBuilder { - return t.Addr(fmt.Sprintf("%s:%d", host, port)) +func (ws *WebsocketServerBuilder) SetPath(path string) *WebsocketServerBuilder { + ws.path = path + return ws } -func (t *TcpTransporterBuilder) TLS(config *tls.Config) *TcpTransporterBuilder { - t.opts = append(t.opts, func(transporter *tcpTransporter) { - transporter.tls = config - }) - return t +func (ws *WebsocketServerBuilder) SetTlsConfig(c *tls.Config) *WebsocketServerBuilder { + ws.tlsConfig = c + return ws } -func (t *TcpTransporterBuilder) Build() Transporter { - tp := &tcpTransporter{ - addr: ":7878", - tls: nil, - } - for _, opt := range t.opts { - opt(tp) +func (ws *WebsocketServerBuilder) Build() transport.ServerTransportFunc { + return func(ctx context.Context) (transport.ServerTransport, error) { + return transport.NewWebsocketServerTransportWithAddr(ws.addr, ws.path, ws.tlsConfig), nil } - return tp } -type wsTransporter struct { - url string - tls *tls.Config - header http.Header +func (wc *WebsocketClientBuilder) SetTlsConfig(c *tls.Config) *WebsocketClientBuilder { + wc.tlsCfg = c + return wc } -type WebsocketTransporterBuilder struct { - opts []func(*wsTransporter) +func (wc *WebsocketClientBuilder) SetUrl(url string) *WebsocketClientBuilder { + wc.url = url + return wc } -func (w *WebsocketTransporterBuilder) Header(header http.Header) *WebsocketTransporterBuilder { - w.opts = append(w.opts, func(transporter *wsTransporter) { - transporter.header = header - }) - return w +func (wc *WebsocketClientBuilder) SetHeader(h http.Header) *WebsocketClientBuilder { + wc.header = h + return wc } -func (w *WebsocketTransporterBuilder) Url(url string) *WebsocketTransporterBuilder { - w.opts = append(w.opts, func(transporter *wsTransporter) { - transporter.url = url - }) - return w +func (wc *WebsocketClientBuilder) Build() transport.ClientTransportFunc { + return func(ctx context.Context) (*transport.Transport, error) { + return transport.NewWebsocketClientTransport(wc.url, wc.tlsCfg, wc.header) + } } -func (w *WebsocketTransporterBuilder) TLS(config *tls.Config) *WebsocketTransporterBuilder { - w.opts = append(w.opts, func(transporter *wsTransporter) { - transporter.tls = config - }) - return w +func (ts *TcpServerBuilder) SetHostAndPort(host string, port int) *TcpServerBuilder { + ts.addr = fmt.Sprintf("%s:%d", host, port) + return ts } -func (w *WebsocketTransporterBuilder) Build() Transporter { - ws := &wsTransporter{ - url: "", - } - for _, opt := range w.opts { - opt(ws) - } - return ws +func (ts *TcpServerBuilder) SetAddr(addr string) *TcpServerBuilder { + ts.addr = addr + return ts } -func (w *wsTransporter) Server() transport.ServerTransportFunc { - return func(ctx context.Context) (transport.ServerTransport, error) { - u, err := url.Parse(w.url) - if err != nil { - return nil, err - } - port := u.Port() - if len(port) < 1 { - return nil, errors.New("missing websocket port") - } - return transport.NewWebsocketServerTransportWithAddr(fmt.Sprintf("%s:%s", u.Hostname(), port), u.Path, w.tls), nil - } +func (ts *TcpServerBuilder) SetTlsConfig(c *tls.Config) *TcpServerBuilder { + ts.tlsCfg = c + return ts } -func (w *wsTransporter) Client() transport.ClientTransportFunc { - return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewWebsocketClientTransport(w.url, w.tls, w.header) +func (ts *TcpServerBuilder) Build() transport.ServerTransportFunc { + return func(ctx context.Context) (transport.ServerTransport, error) { + return transport.NewTcpServerTransportWithAddr("tcp", ts.addr, ts.tlsCfg), nil } } -type UnixTransporter struct { - path string +func (tc *TcpClientBuilder) SetHostAndPort(host string, port int) *TcpClientBuilder { + tc.addr = fmt.Sprintf("%s:%d", host, port) + return tc } -type UnixTransporterBuilder struct { - opts []func(*UnixTransporter) +func (tc *TcpClientBuilder) SetAddr(addr string) *TcpClientBuilder { + tc.addr = addr + return tc } -func (u *UnixTransporter) Server() transport.ServerTransportFunc { - return func(ctx context.Context) (transport.ServerTransport, error) { - if _, err := os.Stat(u.path); !os.IsNotExist(err) { - return nil, err - } - return transport.NewTcpServerTransportWithAddr("unix", u.path, nil), nil - } +func (tc *TcpClientBuilder) SetTlsConfig(c *tls.Config) *TcpClientBuilder { + tc.tlsCfg = c + return tc } -func (u *UnixTransporter) Client() transport.ClientTransportFunc { +func (tc *TcpClientBuilder) Build() transport.ClientTransportFunc { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTcpClientTransportWithAddr("unix", u.path, nil) + return transport.NewTcpClientTransportWithAddr("tcp", tc.addr, tc.tlsCfg) } } -func (u *UnixTransporterBuilder) Path(path string) *UnixTransporterBuilder { - u.opts = append(u.opts, func(transporter *UnixTransporter) { - transporter.path = path - }) - return u +func TcpClient() *TcpClientBuilder { + return &TcpClientBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), + } } -func (u *UnixTransporterBuilder) Build() Transporter { - tp := &UnixTransporter{ - path: "/var/run/rsocket.sock", +func TcpServer() *TcpServerBuilder { + return &TcpServerBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), } - for _, opt := range u.opts { - opt(tp) +} + +func WebsocketClient() *WebsocketClientBuilder { + return &WebsocketClientBuilder{ + url: fmt.Sprintf("ws://127.0.0.1:%d", DefaultPort), } - return tp } -func Tcp() *TcpTransporterBuilder { - return &TcpTransporterBuilder{} +func WebsocketServer() *WebsocketServerBuilder { + return &WebsocketServerBuilder{ + addr: fmt.Sprintf(":%d", DefaultPort), + path: "/", + } } -func Websocket() *WebsocketTransporterBuilder { - return &WebsocketTransporterBuilder{} +func UnixClient() *UnixClientBuilder { + return &UnixClientBuilder{ + path: DefaultUnixSockPath, + } } -func Unix() *UnixTransporterBuilder { - return &UnixTransporterBuilder{} +func UnixServer() *UnixServerBuilder { + return &UnixServerBuilder{ + path: DefaultUnixSockPath, + } } diff --git a/transporter_test.go b/transporter_test.go index a485b21..77d9a7d 100644 --- a/transporter_test.go +++ b/transporter_test.go @@ -3,6 +3,7 @@ package rsocket_test import ( "context" "fmt" + "net/http" "os" "strings" "testing" @@ -12,19 +13,59 @@ import ( "github.com/stretchr/testify/assert" ) -func TestUnix(t *testing.T) { - sockFile := fmt.Sprintf("%s/test-rsocket-%s.sock", strings.TrimRight(os.TempDir(), "/"), uuid.New().String()) - defer os.Remove(sockFile) - u := rsocket.Unix().Path(sockFile).Build() +var fakeSockFile string + +func init() { + fmt.Println(os.TempDir()) + fakeSockFile = fmt.Sprintf("%s/test-rsocket-%s.sock", strings.TrimRight(os.TempDir(), "/"), uuid.New().String()) +} + +func TestUnixServer(t *testing.T) { + defer os.Remove(fakeSockFile) + u := rsocket.UnixServer().SetPath(fakeSockFile).Build() assert.NotNil(t, u) - _, err := u.Server()(context.Background()) + _, err := u(context.Background()) assert.NoError(t, err) } -func TestTcp(t *testing.T) { - rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() +func TestUnixClient(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.UnixClient().SetPath(fakeSockFile).Build() + }) +} + +func TestTcpClient(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.TcpClient(). + SetAddr(":7878"). + SetHostAndPort("127.0.0.1", 7878). + Build() + }) +} + +func TestTcpServerBuilder(t *testing.T) { + assert.NotPanics(t, func() { + rsocket.TcpServer().SetAddr(":7878").Build() + }) +} + +func TestWebsocketClient(t *testing.T) { + assert.NotPanics(t, func() { + h := make(http.Header, 0) + h.Set("x-foo-bar", "qux") + rsocket.WebsocketClient(). + SetUrl("ws://127.0.0.1:8080/fake/path"). + SetHeader(h). + Build() + }) } -func TestWebsocket(t *testing.T) { - rsocket.Websocket() +func TestWebsocketServer(t *testing.T) { + assert.NotPanics(t, func() { + tp := rsocket.WebsocketServer(). + SetAddr(":7878"). + SetPath("/fake"). + Build() + assert.NotNil(t, tp) + }) } From 650adc2172a0bb5a53523d01e4d125c56272a945 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Thu, 6 Aug 2020 23:16:59 +0800 Subject: [PATCH 21/26] Fix. --- cmd/rsocket-cli/runner.go | 15 +++++++-------- core/transport/types.go | 4 ++-- examples/echo/echo.go | 7 ++++--- examples/echo_bench/echo_bench.go | 6 ++++-- examples/fibonacci/main.go | 9 ++------- examples/word_counter/main.go | 18 ++++++------------ transporter_test.go | 2 +- 7 files changed, 26 insertions(+), 35 deletions(-) diff --git a/cmd/rsocket-cli/runner.go b/cmd/rsocket-cli/runner.go index e7f97fb..ab1d276 100644 --- a/cmd/rsocket-cli/runner.go +++ b/cmd/rsocket-cli/runner.go @@ -354,20 +354,20 @@ func (r *Runner) newClientTransport() (transport.ClientTransportFunc, error) { } } -func (r *Runner) newServerTransport() (t transport.ServerTransportFunc, err error) { +func (r *Runner) newServerTransport() (transport.ServerTransportFunc, error) { u, err := url.Parse(r.URI) if err != nil { - return + return nil, err } switch u.Scheme { case "tcp": port, err := strconv.Atoi(u.Port()) if err != nil { - return + return nil, err } - t = rsocket.TcpServer().SetHostAndPort(u.Hostname(), port).Build() + return rsocket.TcpServer().SetHostAndPort(u.Hostname(), port).Build(), nil case "unix": - t = rsocket.UnixServer().SetPath(u.Hostname()).Build() + return rsocket.UnixServer().SetPath(u.Hostname()).Build(), nil case "ws", "wss": var addr string if port := u.Port(); port != "" { @@ -375,9 +375,8 @@ func (r *Runner) newServerTransport() (t transport.ServerTransportFunc, err erro } else { addr = fmt.Sprintf("%s:%d", u.Hostname(), rsocket.DefaultPort) } - t = rsocket.WebsocketServer().SetAddr(addr).SetPath(u.EscapedPath()).Build() + return rsocket.WebsocketServer().SetAddr(addr).SetPath(u.EscapedPath()).Build(), nil default: - err = errors.Errorf("invalid transport %s", u.Scheme) + return nil, errors.Errorf("invalid transport %s", u.Scheme) } - return } diff --git a/core/transport/types.go b/core/transport/types.go index 5a60a9c..1bd4e40 100644 --- a/core/transport/types.go +++ b/core/transport/types.go @@ -9,8 +9,8 @@ import ( ) type ( - ClientTransportFunc = func(context.Context) (*Transport, error) - ServerTransportFunc = func(context.Context) (ServerTransport, error) + ClientTransportFunc func(context.Context) (*Transport, error) + ServerTransportFunc func(context.Context) (ServerTransport, error) ) // Conn is connection for RSocket. diff --git a/examples/echo/echo.go b/examples/echo/echo.go index 89bfe53..823af98 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -13,16 +13,17 @@ import ( "github.com/jjeffcaii/reactor-go/scheduler" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" ) -var MyTransporter rsocket.Transporter +var tp transport.ServerTransportFunc func init() { - MyTransporter = rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() + tp = rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).Build() } func main() { @@ -64,7 +65,7 @@ func main() { } return responder(), nil }). - Transport(MyTransporter). + Transport(tp). Serve(context.Background()) if err != nil { panic(err) diff --git a/examples/echo_bench/echo_bench.go b/examples/echo_bench/echo_bench.go index 8c9041b..647f7a1 100644 --- a/examples/echo_bench/echo_bench.go +++ b/examples/echo_bench/echo_bench.go @@ -11,17 +11,18 @@ import ( "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/mono" ) -var tp rsocket.Transporter +var tp transport.ClientTransportFunc func init() { flag.Parse() - tp = rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build() rand.Seed(time.Now().UnixNano()) + tp = rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() } func main() { @@ -67,6 +68,7 @@ func main() { } func createClient(mtu int) (rsocket.Client, error) { + return rsocket.Connect(). Fragment(mtu). SetupPayload(payload.NewString("你好", "世界")). diff --git a/examples/fibonacci/main.go b/examples/fibonacci/main.go index d3ac482..5b32465 100644 --- a/examples/fibonacci/main.go +++ b/examples/fibonacci/main.go @@ -14,12 +14,6 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -var tp rsocket.Transporter - -func init() { - tp = rsocket.Tcp().Addr("127.0.0.1:7878").Build() -} - const number = 13 func main() { @@ -82,7 +76,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestStreamHandler), nil }). // specify transport - Transport(tp). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). // serve will block execution unless an error occurred Serve(context.Background()) @@ -91,6 +85,7 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection + tp := rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() client, err := rsocket.Connect().Transport(tp).Start(context.Background()) if err != nil { panic(err) diff --git a/examples/word_counter/main.go b/examples/word_counter/main.go index 55b6003..f3c421c 100644 --- a/examples/word_counter/main.go +++ b/examples/word_counter/main.go @@ -13,12 +13,6 @@ import ( "github.com/rsocket/rsocket-go/rx/flux" ) -var tp rsocket.Transporter - -func init() { - tp = rsocket.Tcp().Addr("127.0.0.1:7878").Build() -} - const number = 13 func main() { @@ -60,7 +54,7 @@ func server(readyCh chan struct{}) { return rsocket.NewAbstractSocket(requestChannelHandler), nil }). // specify transport - Transport(tp). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). // serve will block execution unless an error occurred Serve(context.Background()) @@ -69,21 +63,21 @@ func server(readyCh chan struct{}) { func client() { // Start a client connection - client, err := rsocket.Connect().Transport(tp).Start(context.Background()) + client, err := rsocket.Connect().Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()).Start(context.Background()) if err != nil { panic(err) } defer client.Close() // strings to count the words - strings := []payload.Payload{ + sentences := []payload.Payload{ payload.NewString("", extension.TextPlain.String()), payload.NewString("qux", extension.TextPlain.String()), payload.NewString("The quick brown fox jumps over the lazy dog", extension.TextPlain.String()), payload.NewString("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", extension.TextPlain.String()), } - f := flux.FromSlice(strings) + f := flux.FromSlice(sentences) // create a wait group so that the function does not return until the stream completes wg := sync.WaitGroup{} @@ -94,11 +88,11 @@ func client() { // register handler for RequestChannel client.RequestChannel(f).DoOnNext(func(input payload.Payload) error { // print word count - fmt.Println(strings[counter].DataUTF8(), ":", input.DataUTF8()) + fmt.Println(sentences[counter].DataUTF8(), ":", input.DataUTF8()) counter = counter + 1 return nil }).DoOnComplete(func() { - // will be called on successfull completion of the stream + // will be called on successful completion of the stream fmt.Println("Word counter ended.") }).DoOnError(func(err error) { // will be called if a error occurs diff --git a/transporter_test.go b/transporter_test.go index 77d9a7d..8e39299 100644 --- a/transporter_test.go +++ b/transporter_test.go @@ -51,7 +51,7 @@ func TestTcpServerBuilder(t *testing.T) { func TestWebsocketClient(t *testing.T) { assert.NotPanics(t, func() { - h := make(http.Header, 0) + h := make(http.Header) h.Set("x-foo-bar", "qux") rsocket.WebsocketClient(). SetUrl("ws://127.0.0.1:8080/fake/path"). From 084617fc3eb67c7f8213fa53843528c13379891f Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 11 Aug 2020 11:21:41 +0800 Subject: [PATCH 22/26] Fix race. --- .travis.yml | 2 +- balancer/balancer.go | 6 +- balancer/group.go | 28 +- balancer/group_example_test.go | 46 ++ balancer/group_test.go | 109 +--- balancer/round_robin.go | 139 ++-- balancer/round_robin_test.go | 16 +- core/counter.go | 36 -- core/traffic_counter.go | 36 ++ ...ounter_test.go => traffic_counter_test.go} | 4 +- core/transport/mock_conn_test.go | 2 +- core/transport/tcp_conn.go | 4 +- core/transport/tcp_conn_test.go | 6 +- core/transport/transport_test.go | 2 +- core/transport/types.go | 2 +- core/transport/websocket_conn.go | 4 +- core/transport/websocket_conn_test.go | 6 +- go.mod | 2 +- go.sum | 4 +- internal/common/cond.go | 47 ++ internal/common/cond_test.go | 67 ++ internal/socket/base_socket_test.go | 2 +- internal/socket/duplex.go | 600 +++++++++--------- internal/socket/mock_conn_test.go | 2 +- justfile | 14 +- 25 files changed, 618 insertions(+), 568 deletions(-) create mode 100644 balancer/group_example_test.go delete mode 100644 core/counter.go create mode 100644 core/traffic_counter.go rename core/{counter_test.go => traffic_counter_test.go} (87%) create mode 100644 internal/common/cond.go create mode 100644 internal/common/cond_test.go diff --git a/.travis.yml b/.travis.yml index 5ad217c..0d5592b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,5 @@ install: script: - golangci-lint run ./... - - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 ./logger/... ./lease/... ./core/... ./balancer/... ./rx/... ./internal/... ./extension/... ./payload/... . + - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 -race ./... - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/balancer/balancer.go b/balancer/balancer.go index fcc7692..8c9b83b 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -12,13 +12,11 @@ import ( type Balancer interface { io.Closer // Put puts a new client. - Put(client rsocket.Client) + Put(client rsocket.Client) error // PutLabel puts a new client with a label. - PutLabel(label string, client rsocket.Client) + PutLabel(label string, client rsocket.Client) error // Next returns next balanced RSocket client. Next(context.Context) (rsocket.Client, bool) - // MustNext returns next balanced RSocket client. - MustNext(context.Context) rsocket.Client // OnLeave handle events when a client exit. OnLeave(fn func(label string)) } diff --git a/balancer/group.go b/balancer/group.go index dcb5ef8..f447dbb 100644 --- a/balancer/group.go +++ b/balancer/group.go @@ -13,7 +13,8 @@ var errGroupClosed = errors.New("balancer group has been closed") // Group can be used to create a simple RSocket Broker. type Group struct { g func() Balancer - m *sync.Map + l sync.Mutex + m map[string]Balancer } // Close close current RSocket group. @@ -33,10 +34,12 @@ func (p *Group) Close() (err error) { } } }(all, done) - p.m.Range(func(key, value interface{}) bool { - all <- value.(Balancer) - return true - }) + + p.l.Lock() + defer p.l.Unlock() + for _, b := range p.m { + all <- b + } p.m = nil close(all) <-done @@ -45,24 +48,23 @@ func (p *Group) Close() (err error) { // Get returns a Balancer with custom id. func (p *Group) Get(id string) Balancer { + p.l.Lock() + defer p.l.Unlock() if p.m == nil { panic(errGroupClosed) } - if actual, ok := p.m.Load(id); ok { - return actual.(Balancer) + if actual, ok := p.m[id]; ok { + return actual } newborn := p.g() - actual, loaded := p.m.LoadOrStore(id, newborn) - if loaded { - _ = newborn.Close() - } - return actual.(Balancer) + p.m[id] = newborn + return newborn } // NewGroup returns a new Group. func NewGroup(gen func() Balancer) *Group { return &Group{ g: gen, - m: &sync.Map{}, + m: make(map[string]Balancer), } } diff --git a/balancer/group_example_test.go b/balancer/group_example_test.go new file mode 100644 index 0000000..4c2dab2 --- /dev/null +++ b/balancer/group_example_test.go @@ -0,0 +1,46 @@ +package balancer + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx/mono" +) + +func ExampleNewGroup() { + group := NewGroup(func() Balancer { + return NewRoundRobinBalancer() + }) + defer func() { + _ = group.Close() + }() + // Create a broker with resume. + err := rsocket.Receive(). + Resume(rsocket.WithServerResumeSessionDuration(10 * time.Second)). + Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { + // Register service using Setup Metadata as service ID. + if serviceID, ok := setup.MetadataUTF8(); ok { + group.Get(serviceID).Put(sendingSocket) + } + // Proxy requests by group. + return rsocket.NewAbstractSocket(rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { + requestServiceID, ok := msg.MetadataUTF8() + if !ok { + panic(errors.New("missing service ID in metadata")) + } + fmt.Println("[broker] redirect request to service", requestServiceID) + upstream, _ := group.Get(requestServiceID).Next(context.Background()) + fmt.Println("[broker] choose upstream:", upstream) + return upstream.RequestResponse(msg) + })), nil + }). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). + Serve(context.Background()) + if err != nil { + panic(err) + } +} diff --git a/balancer/group_test.go b/balancer/group_test.go index 207aff3..6b45f98 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -1,106 +1,23 @@ package balancer_test import ( - "context" - "crypto/md5" - "errors" - "fmt" "testing" - "time" - "github.com/rsocket/rsocket-go" - . "github.com/rsocket/rsocket-go/balancer" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx/mono" - "github.com/stretchr/testify/require" + "github.com/rsocket/rsocket-go/balancer" + "github.com/stretchr/testify/assert" ) -func ExampleNewGroup() { - group := NewGroup(func() Balancer { - return NewRoundRobinBalancer() +var fakeGroupId = "fakeGroupId" + +func TestGroup_Get(t *testing.T) { + called := 0 + g := balancer.NewGroup(func() balancer.Balancer { + called++ + return balancer.NewRoundRobinBalancer() }) - defer func() { - _ = group.Close() - }() - // Create a broker with resume. - err := rsocket.Receive(). - Resume(rsocket.WithServerResumeSessionDuration(10 * time.Second)). - Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { - // Register service using Setup Metadata as service ID. - if serviceID, ok := setup.MetadataUTF8(); ok { - group.Get(serviceID).Put(sendingSocket) - } - // Proxy requests by group. - return rsocket.NewAbstractSocket(rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - requestServiceID, ok := msg.MetadataUTF8() - if !ok { - panic(errors.New("missing service ID in metadata")) - } - fmt.Println("[broker] redirect request to service", requestServiceID) - upstream := group.Get(requestServiceID).MustNext(context.Background()) - fmt.Println("[broker] choose upstream:", upstream) - return upstream.RequestResponse(msg) - })), nil - }). - Transport(rsocket.TcpServer().SetAddr(":7878").Build()). - Serve(context.Background()) - if err != nil { - panic(err) + for range [2]struct{}{} { + b := g.Get(fakeGroupId) + assert.NotNil(t, b) + assert.Equal(t, 1, called) } } - -func TestServiceSubscribe(t *testing.T) { - // Init broker and service. - go ExampleNewGroup() - - // Waiting broker up by sleeping 200 ms. - time.Sleep(200 * time.Millisecond) - - tp := rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build() - - // Deploy MD5 service. - go func() { - done := make(chan struct{}) - cli, err := rsocket.Connect(). - OnClose(func(err error) { - close(done) - }). - SetupPayload(payload.NewString("This is a Service Publisher!", "md5")). - Acceptor(func(socket rsocket.RSocket) rsocket.RSocket { - return rsocket.NewAbstractSocket( - rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { - result := payload.NewString(fmt.Sprintf("%02x", md5.Sum(msg.Data())), "MD5 RESULT") - fmt.Println("[publisher] accept MD5 request:", msg.DataUTF8()) - return mono.Just(result) - })) - }). - Transport(tp). - Start(context.Background()) - if err != nil { - panic(err) - } - defer func() { - _ = cli.Close() - }() - <-done - }() - - // Create a client and request md5 service. - cli, err := rsocket.Connect(). - SetupPayload(payload.NewString("This is a Subscriber", "")). - Transport(tp). - Start(context.Background()) - require.NoError(t, err, "create client failed") - defer func() { - _ = cli.Close() - time.Sleep(200 * time.Millisecond) - }() - _, err = cli.RequestResponse(payload.NewString("Hello World!", "md5")). - DoOnSuccess(func(elem payload.Payload) error { - fmt.Println("[subscriber] receive MD5 response:", elem.DataUTF8()) - require.Equal(t, "ed076287532e86365e841e92bfc50d8c", elem.DataUTF8(), "bad md5") - return nil - }). - Block(context.Background()) - require.NoError(t, err, "request failed") -} diff --git a/balancer/round_robin.go b/balancer/round_robin.go index f098755..8c9bf55 100644 --- a/balancer/round_robin.go +++ b/balancer/round_robin.go @@ -2,139 +2,120 @@ package balancer import ( "context" + "runtime" "sync" "github.com/google/uuid" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go" + "github.com/rsocket/rsocket-go/internal/common" "github.com/rsocket/rsocket-go/logger" "go.uber.org/atomic" ) -type labelClient struct { - l string - c rsocket.Client -} +var errConflictSocket = errors.New("socket exists already") type balancerRoundRobin struct { seq *atomic.Uint32 - mutex sync.RWMutex - clients []*labelClient + keys []string + sockets []rsocket.Client done chan struct{} once sync.Once onLeave []func(string) - cond *sync.Cond + c *common.Cond } -func (p *balancerRoundRobin) OnLeave(fn func(label string)) { +func (b *balancerRoundRobin) OnLeave(fn func(label string)) { if fn != nil { - p.onLeave = append(p.onLeave, fn) + b.onLeave = append(b.onLeave, fn) } } -func (p *balancerRoundRobin) Put(client rsocket.Client) { - label := uuid.New().String() - p.PutLabel(label, client) +func (b *balancerRoundRobin) Put(client rsocket.Client) error { + return b.PutLabel(uuid.New().String(), client) } -func (p *balancerRoundRobin) PutLabel(label string, client rsocket.Client) { - p.mutex.Lock() - defer p.mutex.Unlock() - p.clients = append(p.clients, &labelClient{ - l: label, - c: client, - }) - client.OnClose(func(error) { - p.remove(client) - }) - if len(p.clients) == 1 { - p.cond.Broadcast() +func (b *balancerRoundRobin) PutLabel(label string, client rsocket.Client) error { + b.c.L.Lock() + defer b.c.L.Unlock() + for _, k := range b.keys { + if k == label { + return errConflictSocket + } } -} - -func (p *balancerRoundRobin) MustNext(ctx context.Context) rsocket.Client { - c, ok := p.Next(ctx) - if !ok { - panic("cannot get next client from current balancer") + b.keys = append(b.keys, label) + b.sockets = append(b.sockets, client) + if n := len(b.sockets); n == 1 { + b.c.Broadcast() } - return c + client.OnClose(func(err error) { + b.remove(client) + }) + return nil } -func (p *balancerRoundRobin) Next(ctx context.Context) (rsocket.Client, bool) { - p.mutex.RLock() - defer p.mutex.RUnlock() - if n := len(p.clients); n > 0 { - idx := int(p.seq.Inc() % uint32(n)) - return p.clients[idx].c, true - } - - ch := make(chan rsocket.Client, 1) - closed := atomic.NewBool(false) - - go func() { - p.cond.L.Lock() - for len(p.clients) < 1 && !closed.Load() { - p.cond.Wait() - } - p.cond.L.Unlock() - if n := len(p.clients); n > 0 { - idx := int(p.seq.Inc() % uint32(n)) - ch <- p.clients[idx].c +func (b *balancerRoundRobin) Next(ctx context.Context) (client rsocket.Client, ok bool) { + b.c.L.Lock() + for { + n := len(b.keys) + if n > 0 { + idx := int(b.seq.Inc() % uint32(n)) + client = b.sockets[idx] + ok = true + break } - }() - - select { - case <-ctx.Done(): - closed.Store(true) - p.cond.Broadcast() - return nil, false - case c, ok := <-ch: - if !ok { - return nil, false + if b.c.Wait(ctx) { + break } - return c, true + b.c.L.Unlock() + runtime.Gosched() + b.c.L.Lock() } + b.c.L.Unlock() + return } -func (p *balancerRoundRobin) Close() (err error) { - p.once.Do(func() { - if len(p.clients) < 1 { +func (b *balancerRoundRobin) Close() (err error) { + b.once.Do(func() { + if len(b.sockets) < 1 { return } - clone := append([]*labelClient(nil), p.clients...) - close(p.done) + clone := append([]rsocket.Client(nil), b.sockets...) + close(b.done) wg := &sync.WaitGroup{} wg.Add(len(clone)) - for _, value := range clone { + for i := 0; i < len(clone); i++ { go func(c rsocket.Client, wg *sync.WaitGroup) { defer wg.Done() if err := c.Close(); err != nil { logger.Warnf("close client failed: %s\n", err) } - }(value.c, wg) + }(clone[i], wg) } wg.Wait() }) return } -func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok bool) { - p.mutex.Lock() +func (b *balancerRoundRobin) remove(client rsocket.Client) (label string, ok bool) { + b.c.L.Lock() j := -1 - for i, l := 0, len(p.clients); i < l; i++ { - if p.clients[i].c == client { + for i, l := 0, len(b.sockets); i < l; i++ { + if b.sockets[i] == client { j = i break } } ok = j > -1 if ok { - label = p.clients[j].l - p.clients = append(p.clients[:j], p.clients[j+1:]...) + label = b.keys[j] + b.keys = append(b.keys[:j], b.keys[j+1:]...) + b.sockets = append(b.sockets[:j], b.sockets[j+1:]...) } - p.mutex.Unlock() - if ok && len(p.onLeave) > 0 { + b.c.L.Unlock() + if ok && len(b.onLeave) > 0 { go func(label string) { - for _, fn := range p.onLeave { + for _, fn := range b.onLeave { fn(label) } }(label) @@ -145,8 +126,8 @@ func (p *balancerRoundRobin) remove(client rsocket.Client) (label string, ok boo // NewRoundRobinBalancer returns a new Round-Robin Balancer. func NewRoundRobinBalancer() Balancer { return &balancerRoundRobin{ - cond: sync.NewCond(new(sync.Mutex)), seq: atomic.NewUint32(0), done: make(chan struct{}), + c: common.NewCond(&sync.Mutex{}), } } diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index f00b229..fcccc15 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -72,7 +72,9 @@ func TestRoundRobin(t *testing.T) { wg := sync.WaitGroup{} wg.Add(n * len(ports)) for i := 0; i < n*len(ports); i++ { - b.MustNext(context.Background()).RequestResponse(req). + c, ok := b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + c.RequestResponse(req). DoFinally(func(s rx.SignalType) { wg.Done() }). @@ -100,7 +102,9 @@ func TestRoundRobin(t *testing.T) { time.Sleep(100 * time.Millisecond) // then send a request - _, err := b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + c, ok := b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err := c.RequestResponse(req).Block(context.Background()) assert.NoError(t, err) var total int @@ -111,7 +115,9 @@ func TestRoundRobin(t *testing.T) { assert.Equal(t, n*len(ports)+1, total) assert.Equal(t, int32(0), ac0.(*atomic.Int32).Load()-amount0) - _, err = b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + c, ok = b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err = c.RequestResponse(req).Block(context.Background()) assert.NoError(t, err) total++ @@ -122,7 +128,9 @@ func TestRoundRobin(t *testing.T) { const extra = 10 for i := 0; i < extra; i++ { - _, err = b.MustNext(context.Background()).RequestResponse(req).Block(context.Background()) + c, ok = b.Next(context.Background()) + assert.True(t, ok, "get next client failed") + _, err = c.RequestResponse(req).Block(context.Background()) assert.NoError(t, err) } total += 10 diff --git a/core/counter.go b/core/counter.go deleted file mode 100644 index 4d54e38..0000000 --- a/core/counter.go +++ /dev/null @@ -1,36 +0,0 @@ -package core - -import ( - "go.uber.org/atomic" -) - -// Counter represents a counter of read/write bytes. -type Counter struct { - r, w *atomic.Uint64 -} - -// ReadBytes returns the number of bytes that have been read. -func (p Counter) ReadBytes() uint64 { - return p.r.Load() -} - -// WriteBytes returns the number of bytes that have been written. -func (p Counter) WriteBytes() uint64 { - return p.w.Load() -} - -func (p Counter) IncWriteBytes(n int) { - p.w.Add(uint64(n)) -} - -func (p Counter) IncReadBytes(n int) { - p.r.Add(uint64(n)) -} - -// NewCounter returns a new counter. -func NewCounter() *Counter { - return &Counter{ - r: atomic.NewUint64(0), - w: atomic.NewUint64(0), - } -} diff --git a/core/traffic_counter.go b/core/traffic_counter.go new file mode 100644 index 0000000..89e7791 --- /dev/null +++ b/core/traffic_counter.go @@ -0,0 +1,36 @@ +package core + +import ( + "go.uber.org/atomic" +) + +// TrafficCounter represents a counter of read/write bytes. +type TrafficCounter struct { + r, w *atomic.Uint64 +} + +// ReadBytes returns the number of bytes that have been read. +func (p TrafficCounter) ReadBytes() uint64 { + return p.r.Load() +} + +// WriteBytes returns the number of bytes that have been written. +func (p TrafficCounter) WriteBytes() uint64 { + return p.w.Load() +} + +func (p TrafficCounter) IncWriteBytes(n int) { + p.w.Add(uint64(n)) +} + +func (p TrafficCounter) IncReadBytes(n int) { + p.r.Add(uint64(n)) +} + +// NewTrafficCounter returns a new counter. +func NewTrafficCounter() *TrafficCounter { + return &TrafficCounter{ + r: atomic.NewUint64(0), + w: atomic.NewUint64(0), + } +} diff --git a/core/counter_test.go b/core/traffic_counter_test.go similarity index 87% rename from core/counter_test.go rename to core/traffic_counter_test.go index bdefc3e..2cd27b8 100644 --- a/core/counter_test.go +++ b/core/traffic_counter_test.go @@ -8,12 +8,12 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCounter(t *testing.T) { +func TestTrafficCounter(t *testing.T) { const cycle = 1000 const amount = 1000 wg := sync.WaitGroup{} wg.Add(amount) - c := core.NewCounter() + c := core.NewTrafficCounter() for range [amount]struct{}{} { go func() { for range [cycle]struct{}{} { diff --git a/core/transport/mock_conn_test.go b/core/transport/mock_conn_test.go index d1ea299..14198f9 100644 --- a/core/transport/mock_conn_test.go +++ b/core/transport/mock_conn_test.go @@ -64,7 +64,7 @@ func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { } // SetCounter mocks base method -func (m *MockConn) SetCounter(c *core.Counter) { +func (m *MockConn) SetCounter(c *core.TrafficCounter) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetCounter", c) } diff --git a/core/transport/tcp_conn.go b/core/transport/tcp_conn.go index bce1f23..ac4cfda 100644 --- a/core/transport/tcp_conn.go +++ b/core/transport/tcp_conn.go @@ -17,10 +17,10 @@ type TcpConn struct { conn net.Conn writer *bufio.Writer decoder *LengthBasedFrameDecoder - counter *core.Counter + counter *core.TrafficCounter } -func (p *TcpConn) SetCounter(c *core.Counter) { +func (p *TcpConn) SetCounter(c *core.TrafficCounter) { p.counter = c } diff --git a/core/transport/tcp_conn_test.go b/core/transport/tcp_conn_test.go index daeed1f..71eff6a 100644 --- a/core/transport/tcp_conn_test.go +++ b/core/transport/tcp_conn_test.go @@ -36,7 +36,7 @@ func TestTcpConn_Read(t *testing.T) { defer ctrl.Finish() bf := &bytes.Buffer{} - c := core.NewCounter() + c := core.NewTrafficCounter() tc.SetCounter(c) toBeWritten := []core.WriteableFrame{ @@ -91,7 +91,7 @@ func TestTcpConn_Flush_Nothing(t *testing.T) { ctrl, nc, tc := InitMockTcpConn(t) defer ctrl.Finish() - c := core.NewCounter() + c := core.NewTrafficCounter() tc.SetCounter(c) nc.EXPECT().Write(gomock.Any()).Times(0) @@ -119,7 +119,7 @@ func TestTcpConn_WriteAndFlush(t *testing.T) { ctrl, nc, tc := InitMockTcpConn(t) defer ctrl.Finish() - c := core.NewCounter() + c := core.NewTrafficCounter() tc.SetCounter(c) nc.EXPECT(). diff --git a/core/transport/transport_test.go b/core/transport/transport_test.go index d491e95..a34e77a 100644 --- a/core/transport/transport_test.go +++ b/core/transport/transport_test.go @@ -198,7 +198,7 @@ func TestTransport_Flush(t *testing.T) { err := tp.Flush() assert.NoError(t, err, "flush failed") - conn.SetCounter(core.NewCounter()) + conn.SetCounter(core.NewTrafficCounter()) } func TestTransport_Close(t *testing.T) { diff --git a/core/transport/types.go b/core/transport/types.go index 1bd4e40..4d76c43 100644 --- a/core/transport/types.go +++ b/core/transport/types.go @@ -20,7 +20,7 @@ type Conn interface { // After this deadline, connection will be closed. SetDeadline(deadline time.Time) error // SetCounter bind a counter which can count r/w bytes. - SetCounter(c *core.Counter) + SetCounter(c *core.TrafficCounter) // Read reads next frame from Conn. Read() (core.Frame, error) // Write writes a frame to Conn. diff --git a/core/transport/websocket_conn.go b/core/transport/websocket_conn.go index 23f32a7..e0d28db 100644 --- a/core/transport/websocket_conn.go +++ b/core/transport/websocket_conn.go @@ -26,10 +26,10 @@ type RawWsConn interface { type WsConn struct { c RawWsConn - counter *core.Counter + counter *core.TrafficCounter } -func (p *WsConn) SetCounter(c *core.Counter) { +func (p *WsConn) SetCounter(c *core.TrafficCounter) { p.counter = c } diff --git a/core/transport/websocket_conn_test.go b/core/transport/websocket_conn_test.go index b80c2fb..703aa6b 100644 --- a/core/transport/websocket_conn_test.go +++ b/core/transport/websocket_conn_test.go @@ -36,7 +36,7 @@ func TestWsConn_Read(t *testing.T) { ctrl, rc, wc := InitMockWsConn(t) defer ctrl.Finish() - c := core.NewCounter() + c := core.NewTrafficCounter() wc.SetCounter(c) toBeWritten := []core.WriteableFrame{ @@ -105,7 +105,7 @@ func TestWsConn_Flush_Nothing(t *testing.T) { ctrl, mc, wc := InitMockWsConn(t) defer ctrl.Finish() - c := core.NewCounter() + c := core.NewTrafficCounter() wc.SetCounter(c) mc.EXPECT().WriteMessage(websocket.BinaryMessage, gomock.Any()).Times(0) @@ -132,7 +132,7 @@ func TestWsConn_Write(t *testing.T) { ctrl, mc, wc := InitMockWsConn(t) defer ctrl.Finish() - c := core.NewCounter() + c := core.NewTrafficCounter() wc.SetCounter(c) toBeWritten := []core.WriteableFrame{ diff --git a/go.mod b/go.mod index 73bece1..6446eae 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang/mock v1.4.3 github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 - github.com/jjeffcaii/reactor-go v0.2.0 + github.com/jjeffcaii/reactor-go v0.2.1 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.7.1 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index 913b615..8538a77 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,8 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jjeffcaii/reactor-go v0.2.0 h1:sIiEfclB65HH4Ne+Cz1Q8EaKn88/7le5hYHzjHhrRvA= -github.com/jjeffcaii/reactor-go v0.2.0/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/jjeffcaii/reactor-go v0.2.1 h1:Wb33QstbwZdgFFS3BqWeBweGp2KQJigfts7NQLHTxqE= +github.com/jjeffcaii/reactor-go v0.2.1/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= diff --git a/internal/common/cond.go b/internal/common/cond.go new file mode 100644 index 0000000..9c8bc73 --- /dev/null +++ b/internal/common/cond.go @@ -0,0 +1,47 @@ +package common + +import ( + "context" + "sync" + "sync/atomic" + "unsafe" +) + +// Cond +// see https://gist.github.com/zviadm/c234426882bfc8acba88f3503edaaa36#file-cond2-go +type Cond struct { + L sync.Locker + n unsafe.Pointer +} + +func NewCond(l sync.Locker) *Cond { + c := &Cond{ + L: l, + } + n := make(chan struct{}) + c.n = unsafe.Pointer(&n) + return c +} + +func (c *Cond) NotifyChan() <-chan struct{} { + ptr := atomic.LoadPointer(&c.n) + return *((*chan struct{})(ptr)) +} + +func (c *Cond) Broadcast() { + n := make(chan struct{}) + ptrOld := atomic.SwapPointer(&c.n, unsafe.Pointer(&n)) + close(*(*chan struct{})(ptrOld)) +} + +func (c *Cond) Wait(ctx context.Context) (isCtx bool) { + n := c.NotifyChan() + c.L.Unlock() + select { + case <-n: + case <-ctx.Done(): + isCtx = true + } + c.L.Lock() + return +} diff --git a/internal/common/cond_test.go b/internal/common/cond_test.go new file mode 100644 index 0000000..84422a3 --- /dev/null +++ b/internal/common/cond_test.go @@ -0,0 +1,67 @@ +package common_test + +import ( + "context" + "log" + "runtime" + "sync" + "testing" + + "github.com/rsocket/rsocket-go/internal/common" +) + +func TestNewCond(t *testing.T) { + x := 0 + c := common.NewCond(&sync.Mutex{}) + done := make(chan bool) + go func() { + c.L.Lock() + x = 1 + c.Wait(context.Background()) + if x != 2 { + log.Fatal("want 2") + } + x = 3 + c.Broadcast() + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 1 { + x = 2 + c.Broadcast() + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + go func() { + c.L.Lock() + for { + if x == 2 { + c.Wait(context.Background()) + if x != 3 { + log.Fatal("want 3") + } + break + } + if x == 3 { + break + } + c.L.Unlock() + runtime.Gosched() + c.L.Lock() + } + c.L.Unlock() + done <- true + }() + <-done + <-done + <-done +} diff --git a/internal/socket/base_socket_test.go b/internal/socket/base_socket_test.go index 520674f..d427495 100644 --- a/internal/socket/base_socket_test.go +++ b/internal/socket/base_socket_test.go @@ -44,6 +44,7 @@ func TestBaseSocket(t *testing.T) { defer close(done) _ = tp.Start(context.Background()) }() + assert.NotPanics(t, func() { s.MetadataPush(fakeRequest) s.FireAndForget(fakeRequest) @@ -55,6 +56,5 @@ func TestBaseSocket(t *testing.T) { <-done _ = s.Close() - assert.Equal(t, true, onClosedCalled.Load()) } diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 410f449..a704ca2 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -2,13 +2,12 @@ package socket import ( "context" - "errors" "fmt" - "io" "sync" "time" "github.com/jjeffcaii/reactor-go/scheduler" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/core/transport" @@ -19,13 +18,13 @@ import ( "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" - "go.uber.org/atomic" ) const _outChanSize = 64 +var errSocketClosed = errors.New("socket closed already") + var ( - errSocketClosed = errors.New("socket closed already") unsupportedRequestStream = []byte("Request-Stream not implemented.") unsupportedRequestResponse = []byte("Request-Response not implemented.") unsupportedRequestChannel = []byte("Request-Channel not implemented.") @@ -38,7 +37,8 @@ func IsSocketClosedError(err error) bool { // DuplexConnection represents a socket of RSocket which can be a requester or a responder. type DuplexConnection struct { - counter *core.Counter + l *sync.RWMutex + counter *core.TrafficCounter tp *transport.Transport outs chan core.WriteableFrame outsPriority []core.WriteableFrame @@ -47,34 +47,38 @@ type DuplexConnection struct { sids StreamID mtu int fragments *sync.Map // common.U32Map // key=streamID, value=Joiner - closed *atomic.Bool - done chan struct{} + writeDone chan struct{} keepaliver *Keepaliver cond *sync.Cond singleScheduler scheduler.Scheduler e error leases lease.Leases + closeOnce sync.Once } // SetError sets error for current socket. -func (p *DuplexConnection) SetError(e error) { - p.e = e +func (dc *DuplexConnection) SetError(err error) { + dc.l.Lock() + defer dc.l.Unlock() + dc.e = err } // GetError get the error set. -func (p *DuplexConnection) GetError() error { - return p.e +func (dc *DuplexConnection) GetError() error { + dc.l.RLock() + defer dc.l.RUnlock() + return dc.e } -func (p *DuplexConnection) nextStreamID() (sid uint32) { - var lap1st bool +func (dc *DuplexConnection) nextStreamID() (sid uint32) { + var firstLap bool for { // There's no required to check StreamID conflicts. - sid, lap1st = p.sids.Next() - if lap1st { + sid, firstLap = dc.sids.Next() + if firstLap { return } - _, ok := p.messages.Load(sid) + _, ok := dc.messages.Load(sid) if !ok { return } @@ -82,127 +86,127 @@ func (p *DuplexConnection) nextStreamID() (sid uint32) { } // Close close current socket. -func (p *DuplexConnection) Close() error { - if !p.closed.CAS(false, true) { - return nil - } - if p.keepaliver != nil { - p.keepaliver.Stop() +func (dc *DuplexConnection) Close() (err error) { + dc.closeOnce.Do(func() { + err = dc.innerClose() + }) + return +} + +func (dc *DuplexConnection) innerClose() error { + if dc.keepaliver != nil { + dc.keepaliver.Stop() } - _ = p.singleScheduler.(io.Closer).Close() - close(p.outs) - p.cond.L.Lock() - p.cond.Broadcast() - p.cond.L.Unlock() + _ = dc.singleScheduler.Close() + close(dc.outs) + dc.cond.L.Lock() + dc.cond.Broadcast() + dc.cond.L.Unlock() - <-p.done + <-dc.writeDone - if p.tp != nil { - if p.e == nil { - p.e = p.tp.Close() + if dc.tp != nil { + if dc.e == nil { + dc.e = dc.tp.Close() } else { - _ = p.tp.Close() + _ = dc.tp.Close() } } - p.messages.Range(func(key, value interface{}) bool { - if cc, ok := value.(callback); ok { - if p.e == nil { - go func() { - cc.Close(errSocketClosed) - }() - } else { - go func(e error) { - cc.Close(e) - }(p.e) + dc.messages.Range(func(_, v interface{}) bool { + if cb, ok := v.(callback); ok { + err := dc.e + if err == nil { + err = errSocketClosed } + go cb.Close(err) } return true }) - return p.e + return dc.e } // FireAndForget start a request of FireAndForget. -func (p *DuplexConnection) FireAndForget(sending payload.Payload) { +func (dc *DuplexConnection) FireAndForget(sending payload.Payload) { data := sending.Data() size := core.FrameHeaderLen + len(sending.Data()) m, ok := sending.Metadata() if ok { size += 3 + len(m) } - sid := p.nextStreamID() - if !p.shouldSplit(size) { - p.sendFrame(framing.NewWriteableFireAndForgetFrame(sid, data, m, 0)) + sid := dc.nextStreamID() + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableFireAndForgetFrame(sid, data, m, 0)) return } - p.doSplit(data, m, func(index int, result fragmentation.SplitResult) { + dc.doSplit(data, m, func(index int, result fragmentation.SplitResult) { var f core.WriteableFrame if index == 0 { f = framing.NewWriteableFireAndForgetFrame(sid, result.Data, result.Metadata, result.Flag) } else { f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) } // MetadataPush start a request of MetadataPush. -func (p *DuplexConnection) MetadataPush(payload payload.Payload) { +func (dc *DuplexConnection) MetadataPush(payload payload.Payload) { metadata, _ := payload.Metadata() - p.sendFrame(framing.NewWriteableMetadataPushFrame(metadata)) + dc.sendFrame(framing.NewWriteableMetadataPushFrame(metadata)) } // RequestResponse start a request of RequestResponse. -func (p *DuplexConnection) RequestResponse(pl payload.Payload) (mo mono.Mono) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestResponse(pl payload.Payload) (mo mono.Mono) { + sid := dc.nextStreamID() resp := mono.CreateProcessor() - p.register(sid, requestResponseCallback{pc: resp}) + dc.register(sid, requestResponseCallback{pc: resp}) data := pl.Data() metadata, _ := pl.Metadata() + mo = resp. DoFinally(func(s rx.SignalType) { if s == rx.SignalCancel { - p.sendFrame(framing.NewWriteableCancelFrame(sid)) + dc.sendFrame(framing.NewWriteableCancelFrame(sid)) } - p.unregister(sid) + dc.unregister(sid) }) - p.singleScheduler.Worker().Do(func() { - // sending... - size := framing.CalcPayloadFrameSize(data, metadata) - if !p.shouldSplit(size) { - p.sendFrame(framing.NewWriteableRequestResponseFrame(sid, data, metadata, 0)) - return + // sending... + size := framing.CalcPayloadFrameSize(data, metadata) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableRequestResponseFrame(sid, data, metadata, 0)) + return + } + dc.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { + var f core.WriteableFrame + if index == 0 { + f = framing.NewWriteableRequestResponseFrame(sid, result.Data, result.Metadata, result.Flag) + } else { + f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.doSplit(data, metadata, func(index int, result fragmentation.SplitResult) { - var f core.WriteableFrame - if index == 0 { - f = framing.NewWriteableRequestResponseFrame(sid, result.Data, result.Metadata, result.Flag) - } else { - f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) - } - p.sendFrame(f) - }) + dc.sendFrame(f) }) + return } // RequestStream start a request of RequestStream. -func (p *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux) { + sid := dc.nextStreamID() pc := flux.CreateProcessor() - p.register(sid, requestStreamCallback{pc: pc}) + dc.register(sid, requestStreamCallback{pc: pc}) requested := make(chan struct{}) ret = pc. DoFinally(func(sig rx.SignalType) { if sig == rx.SignalCancel { - p.sendFrame(framing.NewWriteableCancelFrame(sid)) + dc.sendFrame(framing.NewWriteableCancelFrame(sid)) } - p.unregister(sid) + dc.unregister(sid) }). DoOnRequest(func(n int) { n32 := ToUint32RequestN(n) @@ -217,7 +221,7 @@ func (p *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux if !newborn { frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) - p.sendFrame(frameN) + dc.sendFrame(frameN) <-frameN.DoneNotify() return } @@ -226,27 +230,27 @@ func (p *DuplexConnection) RequestStream(sending payload.Payload) (ret flux.Flux metadata, _ := sending.Metadata() size := framing.CalcPayloadFrameSize(data, metadata) + 4 - if !p.shouldSplit(size) { - p.sendFrame(framing.NewWriteableRequestStreamFrame(sid, n32, data, metadata, 0)) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteableRequestStreamFrame(sid, n32, data, metadata, 0)) return } - p.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { + dc.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) { var f core.WriteableFrame if index == 0 { f = framing.NewWriteableRequestStreamFrame(sid, n32, result.Data, result.Metadata, result.Flag) } else { f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) }) return } // RequestChannel start a request of RequestChannel. -func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { - sid := p.nextStreamID() +func (dc *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux) { + sid := dc.nextStreamID() sending := publisher.(flux.Flux) receiving := flux.CreateProcessor() @@ -255,7 +259,7 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux ret = receiving. DoFinally(func(sig rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). DoOnRequest(func(n int) { n32 := ToUint32RequestN(n) @@ -268,7 +272,7 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux } if !newborn { frameN := framing.NewWriteableRequestNFrame(sid, n32, 0) - p.sendFrame(frameN) + dc.sendFrame(frameN) <-frameN.DoneNotify() return } @@ -284,32 +288,32 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux close(sndRequested) } if !newborn { - p.sendPayload(sid, item, core.FlagNext) + dc.sendPayload(sid, item, core.FlagNext) return } d := item.Data() m, _ := item.Metadata() size := framing.CalcPayloadFrameSize(d, m) + 4 - if !p.shouldSplit(size) { + if !dc.shouldSplit(size) { metadata, _ := item.Metadata() - p.sendFrame(framing.NewWriteableRequestChannelFrame(sid, n32, item.Data(), metadata, core.FlagNext)) + dc.sendFrame(framing.NewWriteableRequestChannelFrame(sid, n32, item.Data(), metadata, core.FlagNext)) return } - p.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { + dc.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) { var f core.WriteableFrame if index == 0 { f = framing.NewWriteableRequestChannelFrame(sid, n32, result.Data, result.Metadata, result.Flag|core.FlagNext) } else { f = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext) } - p.sendFrame(f) + dc.sendFrame(f) }) return }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, requestChannelCallback{rcv: receiving, snd: s}) + dc.register(sid, requestChannelCallback{rcv: receiving, snd: s}) s.Request(1) }), ) @@ -319,7 +323,7 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux switch sig { case rx.SignalComplete: complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) - p.sendFrame(complete) + dc.sendFrame(complete) <-complete.DoneNotify() default: panic(fmt.Errorf("unsupported sending channel signal: %s", sig)) @@ -331,16 +335,16 @@ func (p *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flux return ret } -func (p *DuplexConnection) onFrameRequestResponse(frame core.Frame) error { +func (dc *DuplexConnection) onFrameRequestResponse(frame core.Frame) error { // fragment - receiving, ok := p.doFragment(frame.(*framing.RequestResponseFrame)) + receiving, ok := dc.doFragment(frame.(*framing.RequestResponseFrame)) if !ok { return nil } - return p.respondRequestResponse(receiving) + return dc.respondRequestResponse(receiving) } -func (p *DuplexConnection) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestResponse(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // 1. execute socket handler @@ -348,52 +352,52 @@ func (p *DuplexConnection) respondRequestResponse(receiving fragmentation.Header defer func() { err = tryRecover(recover()) }() - mono = p.responder.RequestResponse(receiving) + mono = dc.responder.RequestResponse(receiving) return }() // 2. sending error with panic if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } // 3. sending error with unsupported handler if sending == nil { - p.writeError(sid, framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) + dc.writeError(sid, framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestResponse)) return nil } // 4. async subscribe publisher sub := rx.NewSubscriber( rx.OnNext(func(input payload.Payload) error { - p.sendPayload(sid, input, core.FlagNext|core.FlagComplete) + dc.sendPayload(sid, input, core.FlagNext|core.FlagComplete) return nil }), rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, requestResponseCallbackReverse{su: s}) + dc.register(sid, requestResponseCallbackReverse{su: s}) s.Request(rx.RequestMax) }), ) sending. DoFinally(func(sig rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } -func (p *DuplexConnection) onFrameRequestChannel(input core.Frame) error { - receiving, ok := p.doFragment(input.(*framing.RequestChannelFrame)) +func (dc *DuplexConnection) onFrameRequestChannel(input core.Frame) error { + receiving, ok := dc.doFragment(input.(*framing.RequestChannelFrame)) if !ok { return nil } - return p.respondRequestChannel(receiving) + return dc.respondRequestChannel(receiving) } -func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayload) error { // seek initRequestN var initRequestN int switch v := pl.(type) { @@ -418,18 +422,18 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo case _, ok := <-ch: if ok { close(ch) - p.unregister(sid) + dc.unregister(sid) } default: } }). DoOnRequest(func(n int) { frameN := framing.NewWriteableRequestNFrame(sid, ToUint32RequestN(n), 0) - p.sendFrame(frameN) + dc.sendFrame(frameN) <-frameN.DoneNotify() }) - p.singleScheduler.Worker().Do(func() { + _ = dc.singleScheduler.Worker().Do(func() { receivingProcessor.Next(pl) }) @@ -438,7 +442,7 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo defer func() { err = tryRecover(recover()) }() - flux = p.responder.RequestChannel(receiving) + flux = dc.responder.RequestChannel(receiving) if flux == nil { err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel) } @@ -446,7 +450,7 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo }() if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } @@ -455,20 +459,20 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo sub := rx.NewSubscriber( rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnComplete(func() { complete := framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete) - p.sendFrame(complete) + dc.sendFrame(complete) <-complete.DoneNotify() }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, requestChannelCallbackReverse{rcv: receivingProcessor, snd: s}) + dc.register(sid, requestChannelCallbackReverse{rcv: receivingProcessor, snd: s}) close(mustSub) s.Request(initRequestN) }), rx.OnNext(func(elem payload.Payload) error { - p.sendPayload(sid, elem, core.FlagNext) + dc.sendPayload(sid, elem, core.FlagNext) return nil }), ) @@ -481,7 +485,7 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo case _, ok := <-ch: if ok { close(ch) - p.unregister(sid) + dc.unregister(sid) } default: } @@ -493,43 +497,43 @@ func (p *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPaylo return nil } -func (p *DuplexConnection) respondMetadataPush(input core.Frame) (err error) { +func (dc *DuplexConnection) respondMetadataPush(input core.Frame) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond METADATA_PUSH failed: %s\n", e) } }() - p.responder.MetadataPush(input.(*framing.MetadataPushFrame)) + dc.responder.MetadataPush(input.(*framing.MetadataPushFrame)) return } -func (p *DuplexConnection) onFrameFNF(frame core.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.FireAndForgetFrame)) +func (dc *DuplexConnection) onFrameFNF(frame core.Frame) error { + receiving, ok := dc.doFragment(frame.(*framing.FireAndForgetFrame)) if !ok { return nil } - return p.respondFNF(receiving) + return dc.respondFNF(receiving) } -func (p *DuplexConnection) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { +func (dc *DuplexConnection) respondFNF(receiving fragmentation.HeaderAndPayload) (err error) { defer func() { if e := recover(); e != nil { logger.Errorf("respond FireAndForget failed: %s\n", e) } }() - p.responder.FireAndForget(receiving) + dc.responder.FireAndForget(receiving) return } -func (p *DuplexConnection) onFrameRequestStream(frame core.Frame) error { - receiving, ok := p.doFragment(frame.(*framing.RequestStreamFrame)) +func (dc *DuplexConnection) onFrameRequestStream(frame core.Frame) error { + receiving, ok := dc.doFragment(frame.(*framing.RequestStreamFrame)) if !ok { return nil } - return p.respondRequestStream(receiving) + return dc.respondRequestStream(receiving) } -func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { +func (dc *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAndPayload) error { sid := receiving.Header().StreamID() // execute request stream handler @@ -537,7 +541,7 @@ func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAn defer func() { err = tryRecover(recover()) }() - resp = p.responder.RequestStream(receiving) + resp = dc.responder.RequestStream(receiving) if resp == nil { err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestStream) } @@ -546,7 +550,7 @@ func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAn // send error with panic if err != nil { - p.writeError(sid, err) + dc.writeError(sid, err) return nil } @@ -563,65 +567,65 @@ func (p *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderAn sub := rx.NewSubscriber( rx.OnNext(func(elem payload.Payload) error { - p.sendPayload(sid, elem, core.FlagNext) + dc.sendPayload(sid, elem, core.FlagNext) return nil }), rx.OnSubscribe(func(s rx.Subscription) { - p.register(sid, requestStreamCallbackReverse{su: s}) + dc.register(sid, requestStreamCallbackReverse{su: s}) s.Request(n32) }), rx.OnError(func(e error) { - p.writeError(sid, e) + dc.writeError(sid, e) }), rx.OnComplete(func() { - p.sendFrame(framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete)) + dc.sendFrame(framing.NewPayloadFrame(sid, nil, nil, core.FlagComplete)) }), ) // async subscribe publisher sending. DoFinally(func(s rx.SignalType) { - p.unregister(sid) + dc.unregister(sid) }). SubscribeOn(scheduler.Parallel()). SubscribeWith(context.Background(), sub) return nil } -func (p *DuplexConnection) writeError(sid uint32, e error) { +func (dc *DuplexConnection) writeError(sid uint32, e error) { // ignore sending error because current socket has been closed. if IsSocketClosedError(e) { return } switch err := e.(type) { case *framing.ErrorFrame: - p.sendFrame(err) + dc.sendFrame(err) case core.CustomError: - p.sendFrame(framing.NewWriteableErrorFrame(sid, err.ErrorCode(), err.ErrorData())) + dc.sendFrame(framing.NewWriteableErrorFrame(sid, err.ErrorCode(), err.ErrorData())) default: - p.sendFrame(framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) + dc.sendFrame(framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, []byte(e.Error()))) } } // SetResponder sets a responder for current socket. -func (p *DuplexConnection) SetResponder(responder Responder) { - p.responder = responder +func (dc *DuplexConnection) SetResponder(responder Responder) { + dc.responder = responder } -func (p *DuplexConnection) onFrameKeepalive(frame core.Frame) (err error) { +func (dc *DuplexConnection) onFrameKeepalive(frame core.Frame) (err error) { f := frame.(*framing.KeepaliveFrame) if f.Header().Flag().Check(core.FlagRespond) { k := framing.NewKeepaliveFrame(f.LastReceivedPosition(), f.Data(), false) //f.SetHeader(framing.NewFrameHeader(0, framing.FrameTypeKeepalive)) - p.sendFrame(k) + dc.sendFrame(k) } return } -func (p *DuplexConnection) onFrameCancel(frame core.Frame) (err error) { +func (dc *DuplexConnection) onFrameCancel(frame core.Frame) (err error) { sid := frame.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { logger.Warnf("nothing cancelled: sid=%d\n", sid) return @@ -636,18 +640,18 @@ func (p *DuplexConnection) onFrameCancel(frame core.Frame) (err error) { panic(fmt.Errorf("illegal cancel target: %v", vv)) } - if _, ok := p.fragments.Load(sid); ok { - p.fragments.Delete(sid) + if _, ok := dc.fragments.Load(sid); ok { + dc.fragments.Delete(sid) } return } -func (p *DuplexConnection) onFrameError(input core.Frame) (err error) { +func (dc *DuplexConnection) onFrameError(input core.Frame) (err error) { f := input.(*framing.ErrorFrame) logger.Errorf("handle error frame: %s\n", f) sid := f.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { err = fmt.Errorf("invalid stream id: %d", sid) return @@ -666,10 +670,10 @@ func (p *DuplexConnection) onFrameError(input core.Frame) (err error) { return } -func (p *DuplexConnection) onFrameRequestN(input core.Frame) (err error) { +func (dc *DuplexConnection) onFrameRequestN(input core.Frame) (err error) { f := input.(*framing.RequestNFrame) sid := f.Header().StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { if logger.IsDebugEnabled() { logger.Debugf("ignore non-exists RequestN: id=%d\n", sid) @@ -690,15 +694,15 @@ func (p *DuplexConnection) onFrameRequestN(input core.Frame) (err error) { return } -func (p *DuplexConnection) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { +func (dc *DuplexConnection) doFragment(input fragmentation.HeaderAndPayload) (out fragmentation.HeaderAndPayload, ok bool) { h := input.Header() sid := h.StreamID() - v, exist := p.fragments.Load(sid) + v, exist := dc.fragments.Load(sid) if exist { joiner := v.(fragmentation.Joiner) ok = joiner.Push(input) if ok { - p.fragments.Delete(sid) + dc.fragments.Delete(sid) out = joiner } return @@ -708,34 +712,34 @@ func (p *DuplexConnection) doFragment(input fragmentation.HeaderAndPayload) (out out = input return } - p.fragments.Store(sid, fragmentation.NewJoiner(input)) + dc.fragments.Store(sid, fragmentation.NewJoiner(input)) return } -func (p *DuplexConnection) onFramePayload(frame core.Frame) error { - pl, ok := p.doFragment(frame.(*framing.PayloadFrame)) +func (dc *DuplexConnection) onFramePayload(frame core.Frame) error { + pl, ok := dc.doFragment(frame.(*framing.PayloadFrame)) if !ok { return nil } h := pl.Header() t := h.Type() if t == core.FrameTypeRequestFNF { - return p.respondFNF(pl) + return dc.respondFNF(pl) } if t == core.FrameTypeRequestResponse { - return p.respondRequestResponse(pl) + return dc.respondRequestResponse(pl) } if t == core.FrameTypeRequestStream { - return p.respondRequestStream(pl) + return dc.respondRequestStream(pl) } if t == core.FrameTypeRequestChannel { - return p.respondRequestChannel(pl) + return dc.respondRequestChannel(pl) } sid := h.StreamID() - v, ok := p.messages.Load(sid) + v, ok := dc.messages.Load(sid) if !ok { - logger.Warnf("unoccupied Payload(id=%d), maybe it has been canceled(server=%T)\n", sid, p.sids) + logger.Warnf("unoccupied Payload(id=%d), maybe it has been canceled(server=%T)\n", sid, dc.sids) return nil } @@ -776,44 +780,44 @@ func (p *DuplexConnection) onFramePayload(frame core.Frame) error { return nil } -func (p *DuplexConnection) clearTransport() { - p.cond.L.Lock() - p.tp = nil - p.cond.L.Unlock() +func (dc *DuplexConnection) clearTransport() { + dc.cond.L.Lock() + dc.tp = nil + dc.cond.L.Unlock() } // SetTransport sets a transport for current socket. -func (p *DuplexConnection) SetTransport(tp *transport.Transport) { - tp.RegisterHandler(transport.OnCancel, p.onFrameCancel) - tp.RegisterHandler(transport.OnError, p.onFrameError) - tp.RegisterHandler(transport.OnRequestN, p.onFrameRequestN) - tp.RegisterHandler(transport.OnPayload, p.onFramePayload) - tp.RegisterHandler(transport.OnKeepalive, p.onFrameKeepalive) - - if p.responder != nil { - tp.RegisterHandler(transport.OnRequestResponse, p.onFrameRequestResponse) - tp.RegisterHandler(transport.OnMetadataPush, p.respondMetadataPush) - tp.RegisterHandler(transport.OnFireAndForget, p.onFrameFNF) - tp.RegisterHandler(transport.OnRequestStream, p.onFrameRequestStream) - tp.RegisterHandler(transport.OnRequestChannel, p.onFrameRequestChannel) +func (dc *DuplexConnection) SetTransport(tp *transport.Transport) { + tp.RegisterHandler(transport.OnCancel, dc.onFrameCancel) + tp.RegisterHandler(transport.OnError, dc.onFrameError) + tp.RegisterHandler(transport.OnRequestN, dc.onFrameRequestN) + tp.RegisterHandler(transport.OnPayload, dc.onFramePayload) + tp.RegisterHandler(transport.OnKeepalive, dc.onFrameKeepalive) + + if dc.responder != nil { + tp.RegisterHandler(transport.OnRequestResponse, dc.onFrameRequestResponse) + tp.RegisterHandler(transport.OnMetadataPush, dc.respondMetadataPush) + tp.RegisterHandler(transport.OnFireAndForget, dc.onFrameFNF) + tp.RegisterHandler(transport.OnRequestStream, dc.onFrameRequestStream) + tp.RegisterHandler(transport.OnRequestChannel, dc.onFrameRequestChannel) } - p.cond.L.Lock() - p.tp = tp - p.cond.Signal() - p.cond.L.Unlock() + dc.cond.L.Lock() + dc.tp = tp + dc.cond.Signal() + dc.cond.L.Unlock() } -func (p *DuplexConnection) sendFrame(f core.WriteableFrame) { +func (dc *DuplexConnection) sendFrame(f core.WriteableFrame) { defer func() { if e := recover(); e != nil { logger.Warnf("send frame failed: %s\n", e) } }() - p.outs <- f + dc.outs <- f } -func (p *DuplexConnection) sendPayload( +func (dc *DuplexConnection) sendPayload( sid uint32, sending payload.Payload, frameFlag core.FrameFlag, @@ -822,32 +826,33 @@ func (p *DuplexConnection) sendPayload( m, _ := sending.Metadata() size := framing.CalcPayloadFrameSize(d, m) - if !p.shouldSplit(size) { - p.sendFrame(framing.NewWriteablePayloadFrame(sid, d, m, frameFlag)) + if !dc.shouldSplit(size) { + dc.sendFrame(framing.NewWriteablePayloadFrame(sid, d, m, frameFlag)) return } - p.doSplit(d, m, func(index int, result fragmentation.SplitResult) { + dc.doSplit(d, m, func(index int, result fragmentation.SplitResult) { flag := result.Flag if index == 0 { flag |= frameFlag } else { flag |= core.FlagNext } - p.sendFrame(framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, flag)) + dc.sendFrame(framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, flag)) }) + return } -func (p *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { - if len(p.outs) > 0 { - p.drain(nil) +func (dc *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { + if len(dc.outs) > 0 { + dc.drain(nil) } var out core.WriteableFrame select { - case <-p.keepaliver.C(): + case <-dc.keepaliver.C(): ok = true - out = framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(out, true) + out = framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(out, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } @@ -858,59 +863,59 @@ func (p *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lea return } out = framing.NewWriteableLeaseFrame(ls.TimeToLive, ls.NumberOfRequests, ls.Metadata) - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } - case out, ok = <-p.outs: + case out, ok = <-dc.outs: if !ok { return } - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } } return } -func (p *DuplexConnection) drainWithKeepalive() (ok bool) { - if len(p.outs) > 0 { - p.drain(nil) +func (dc *DuplexConnection) drainWithKeepalive() (ok bool) { + if len(dc.outs) > 0 { + dc.drain(nil) } var out core.WriteableFrame select { - case <-p.keepaliver.C(): + case <-dc.keepaliver.C(): ok = true - out = framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(out, true) + out = framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(out, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } } - case out, ok = <-p.outs: + case out, ok = <-dc.outs: if !ok { return } - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) - } else if err := p.tp.Send(out, true); err != nil { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) + } else if err := dc.tp.Send(out, true); err != nil { logger.Errorf("send frame failed: %s\n", err.Error()) - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) } } return } -func (p *DuplexConnection) drain(leaseChan <-chan lease.Lease) bool { +func (dc *DuplexConnection) drain(leaseChan <-chan lease.Lease) bool { var flush bool - cycle := len(p.outs) + cycle := len(dc.outs) if cycle < 1 { cycle = 1 } @@ -920,34 +925,34 @@ func (p *DuplexConnection) drain(leaseChan <-chan lease.Lease) bool { if !ok { return false } - if p.drainOne(framing.NewWriteableLeaseFrame(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { + if dc.drainOne(framing.NewWriteableLeaseFrame(next.TimeToLive, next.NumberOfRequests, next.Metadata)) { flush = true } - case out, ok := <-p.outs: + case out, ok := <-dc.outs: if !ok { return false } - if p.drainOne(out) { + if dc.drainOne(out) { flush = true } } } if flush { - if err := p.tp.Flush(); err != nil { + if err := dc.tp.Flush(); err != nil { logger.Errorf("flush failed: %v\n", err) } } return true } -func (p *DuplexConnection) drainOne(out core.WriteableFrame) (wrote bool) { - if p.tp == nil { - p.outsPriority = append(p.outsPriority, out) +func (dc *DuplexConnection) drainOne(out core.WriteableFrame) (wrote bool) { + if dc.tp == nil { + dc.outsPriority = append(dc.outsPriority, out) return } - err := p.tp.Send(out, false) + err := dc.tp.Send(out, false) if err != nil { - p.outsPriority = append(p.outsPriority, out) + dc.outsPriority = append(dc.outsPriority, out) logger.Errorf("send frame failed: %s\n", err.Error()) return } @@ -955,50 +960,50 @@ func (p *DuplexConnection) drainOne(out core.WriteableFrame) (wrote bool) { return } -func (p *DuplexConnection) drainOutBack() { - if len(p.outsPriority) < 1 { +func (dc *DuplexConnection) drainOutBack() { + if len(dc.outsPriority) < 1 { return } defer func() { - p.outsPriority = p.outsPriority[:0] + dc.outsPriority = dc.outsPriority[:0] }() - if p.tp == nil { + if dc.tp == nil { return } var out core.WriteableFrame - for i := range p.outsPriority { - out = p.outsPriority[i] - if err := p.tp.Send(out, false); err != nil { + for i := range dc.outsPriority { + out = dc.outsPriority[i] + if err := dc.tp.Send(out, false); err != nil { out.Done() logger.Errorf("send frame failed: %v\n", err) } } - if err := p.tp.Flush(); err != nil { + if err := dc.tp.Flush(); err != nil { logger.Errorf("flush failed: %v\n", err) } } -func (p *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { +func (dc *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { for { - if p.tp == nil { - p.cond.L.Lock() - p.cond.Wait() - p.cond.L.Unlock() + if dc.tp == nil { + dc.cond.L.Lock() + dc.cond.Wait() + dc.cond.L.Unlock() } select { case <-ctx.Done(): - p.cleanOuts() + dc.cleanOuts() return ctx.Err() default: // ignore } select { - case <-p.keepaliver.C(): - kf := framing.NewKeepaliveFrame(p.counter.ReadBytes(), nil, true) - if p.tp != nil { - err := p.tp.Send(kf, true) + case <-dc.keepaliver.C(): + kf := framing.NewKeepaliveFrame(dc.counter.ReadBytes(), nil, true) + if dc.tp != nil { + err := dc.tp.Send(kf, true) if err != nil { logger.Errorf("send keepalive frame failed: %s\n", err.Error()) } @@ -1006,117 +1011,106 @@ func (p *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseCha default: } - p.drainOutBack() - if leaseChan == nil && !p.drainWithKeepalive() { + dc.drainOutBack() + if leaseChan == nil && !dc.drainWithKeepalive() { break } - if leaseChan != nil && !p.drainWithKeepaliveAndLease(leaseChan) { + if leaseChan != nil && !dc.drainWithKeepaliveAndLease(leaseChan) { break } } return nil } -func (p *DuplexConnection) cleanOuts() { - p.outsPriority = nil +func (dc *DuplexConnection) cleanOuts() { + dc.outsPriority = nil } -func (p *DuplexConnection) LoopWrite(ctx context.Context) error { - defer close(p.done) +func (dc *DuplexConnection) LoopWrite(ctx context.Context) error { + defer close(dc.writeDone) var leaseChan chan lease.Lease - if p.leases != nil { + if dc.leases != nil { leaseCtx, cancel := context.WithCancel(ctx) defer func() { cancel() }() - if c, ok := p.leases.Next(leaseCtx); ok { + if c, ok := dc.leases.Next(leaseCtx); ok { leaseChan = c } } - if p.keepaliver != nil { - defer p.keepaliver.Stop() - return p.loopWriteWithKeepaliver(ctx, leaseChan) + if dc.keepaliver != nil { + defer dc.keepaliver.Stop() + return dc.loopWriteWithKeepaliver(ctx, leaseChan) } for { - if p.tp == nil { - p.cond.L.Lock() - p.cond.Wait() - p.cond.L.Unlock() + if dc.tp == nil { + dc.cond.L.Lock() + dc.cond.Wait() + dc.cond.L.Unlock() } select { case <-ctx.Done(): - p.cleanOuts() + dc.cleanOuts() return ctx.Err() default: } - p.drainOutBack() - if !p.drain(leaseChan) { + dc.drainOutBack() + if !dc.drain(leaseChan) { break } } return nil } -func (p *DuplexConnection) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { - fragmentation.Split(p.mtu, data, metadata, handler) +func (dc *DuplexConnection) doSplit(data, metadata []byte, handler fragmentation.HandleSplitResult) { + fragmentation.Split(dc.mtu, data, metadata, handler) } -func (p *DuplexConnection) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { - fragmentation.SplitSkip(p.mtu, skip, data, metadata, handler) +func (dc *DuplexConnection) doSplitSkip(skip int, data, metadata []byte, handler fragmentation.HandleSplitResult) { + fragmentation.SplitSkip(dc.mtu, skip, data, metadata, handler) } -func (p *DuplexConnection) shouldSplit(size int) bool { - return size > p.mtu +func (dc *DuplexConnection) shouldSplit(size int) bool { + return size > dc.mtu } -func (p *DuplexConnection) register(sid uint32, msg interface{}) { - p.messages.Store(sid, msg) +func (dc *DuplexConnection) register(sid uint32, msg interface{}) { + dc.messages.Store(sid, msg) } -func (p *DuplexConnection) unregister(sid uint32) { - p.messages.Delete(sid) - p.fragments.Delete(sid) +func (dc *DuplexConnection) unregister(sid uint32) { + dc.messages.Delete(sid) + dc.fragments.Delete(sid) } // NewServerDuplexConnection creates a new server-side DuplexConnection. func NewServerDuplexConnection(mtu int, leases lease.Leases) *DuplexConnection { - return &DuplexConnection{ - closed: atomic.NewBool(false), - leases: leases, - outs: make(chan core.WriteableFrame, _outChanSize), - mtu: mtu, - messages: &sync.Map{}, - sids: &serverStreamIDs{}, - fragments: &sync.Map{}, - done: make(chan struct{}), - cond: sync.NewCond(&sync.Mutex{}), - counter: core.NewCounter(), - singleScheduler: scheduler.NewSingle(64), - } + return newDuplexConnection(mtu, nil, &serverStreamIDs{}, leases) } // NewClientDuplexConnection creates a new client-side DuplexConnection. -func NewClientDuplexConnection( - mtu int, - keepaliveInterval time.Duration, -) (s *DuplexConnection) { - ka := NewKeepaliver(keepaliveInterval) - s = &DuplexConnection{ - closed: atomic.NewBool(false), +func NewClientDuplexConnection(mtu int, keepaliveInterval time.Duration) *DuplexConnection { + return newDuplexConnection(mtu, NewKeepaliver(keepaliveInterval), &clientStreamIDs{}, nil) +} + +func newDuplexConnection(mtu int, ka *Keepaliver, sids StreamID, leases lease.Leases) *DuplexConnection { + l := &sync.RWMutex{} + return &DuplexConnection{ + l: l, + leases: leases, outs: make(chan core.WriteableFrame, _outChanSize), mtu: mtu, messages: &sync.Map{}, - sids: &clientStreamIDs{}, + sids: sids, fragments: &sync.Map{}, - done: make(chan struct{}), - cond: sync.NewCond(&sync.Mutex{}), - counter: core.NewCounter(), + writeDone: make(chan struct{}), + cond: sync.NewCond(l), + counter: core.NewTrafficCounter(), keepaliver: ka, singleScheduler: scheduler.NewSingle(64), } - return } diff --git a/internal/socket/mock_conn_test.go b/internal/socket/mock_conn_test.go index 204109d..8442e68 100644 --- a/internal/socket/mock_conn_test.go +++ b/internal/socket/mock_conn_test.go @@ -64,7 +64,7 @@ func (mr *MockConnMockRecorder) SetDeadline(deadline interface{}) *gomock.Call { } // SetCounter mocks base method -func (m *MockConn) SetCounter(c *core.Counter) { +func (m *MockConn) SetCounter(c *core.TrafficCounter) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetCounter", c) } diff --git a/justfile b/justfile index e742371..52d01f0 100644 --- a/justfile +++ b/justfile @@ -3,19 +3,9 @@ default: lint: golangci-lint run ./... test: - go test -count=1 -coverprofile=coverage.out \ - ./core/... \ - ./extension/... \ - ./internal/... \ - ./lease/... \ - ./logger/... \ - ./payload/... \ - ./rx/... \ - . + go test -count=1 -race -coverprofile=coverage.out ./... test-no-cover: - go test -count=1 ./... -v -test-race: - go test -race -count=1 ./... -v + go test -count=1 -race ./... fmt: @go fmt ./... cover: From 40f48690b5f080fab361b447e09512766dc06b44 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 11 Aug 2020 11:25:52 +0800 Subject: [PATCH 23/26] Fix lint. --- balancer/group_example_test.go | 2 +- balancer/round_robin_test.go | 2 +- internal/socket/duplex.go | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/balancer/group_example_test.go b/balancer/group_example_test.go index 4c2dab2..01b4950 100644 --- a/balancer/group_example_test.go +++ b/balancer/group_example_test.go @@ -24,7 +24,7 @@ func ExampleNewGroup() { Acceptor(func(setup payload.SetupPayload, sendingSocket rsocket.CloseableRSocket) (rsocket.RSocket, error) { // Register service using Setup Metadata as service ID. if serviceID, ok := setup.MetadataUTF8(); ok { - group.Get(serviceID).Put(sendingSocket) + _ = group.Get(serviceID).Put(sendingSocket) } // Proxy requests by group. return rsocket.NewAbstractSocket(rsocket.RequestResponse(func(msg payload.Payload) mono.Mono { diff --git a/balancer/round_robin_test.go b/balancer/round_robin_test.go index fcccc15..f4a2d03 100644 --- a/balancer/round_robin_test.go +++ b/balancer/round_robin_test.go @@ -63,7 +63,7 @@ func TestRoundRobin(t *testing.T) { Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", ports[i]).Build()). Start(context.Background()) assert.NoError(t, err) - b.PutLabel(fmt.Sprintf("test-client-%d", ports[i]), client) + _ = b.PutLabel(fmt.Sprintf("test-client-%d", ports[i]), client) } req := payload.NewString("foo", "bar") diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index a704ca2..46a4ebf 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -839,7 +839,6 @@ func (dc *DuplexConnection) sendPayload( } dc.sendFrame(framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, flag)) }) - return } func (dc *DuplexConnection) drainWithKeepaliveAndLease(leaseChan <-chan lease.Lease) (ok bool) { From 73788e596f34f75c83f222fd76b871a0376b26d5 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 11 Aug 2020 22:30:14 +0800 Subject: [PATCH 24/26] Fix --- .travis.yml | 2 +- README.md | 85 ++++++++++++------------------------------- examples/main.go | 42 +++++++++++++++++++++ logger/logger_test.go | 2 +- 4 files changed, 68 insertions(+), 63 deletions(-) create mode 100644 examples/main.go diff --git a/.travis.yml b/.travis.yml index 0d5592b..87e0d0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,5 +11,5 @@ install: script: - golangci-lint run ./... - - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 -race ./... + - go test -v -covermode=atomic -coverprofile=coverage.out -count=1 ./... - goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff --git a/README.md b/README.md index b8247b5..16db495 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ package main import ( "context" + "log" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/payload" @@ -47,10 +48,11 @@ func main() { }), ), nil }). - Transport(rsocket.Tcp().Addr(":7878").Build()). + Transport(rsocket.TcpServer().SetAddr(":7878").Build()). Serve(context.Background()) - panic(err) + log.Fatalln(err) } + ``` > Connect to echo server @@ -72,7 +74,7 @@ func main() { Resume(). Fragment(1024). SetupPayload(payload.NewString("Hello", "World")). - Transport(rsocket.Tcp().HostAndPort("127.0.0.1", 7878).Build()). + Transport(rsocket.TcpClient().SetHostAndPort("127.0.0.1", 7878).Build()). Start(context.Background()) if err != nil { panic(err) @@ -132,10 +134,11 @@ func main() { DoFinally(func(s rx.SignalType) { close(done) }). - DoOnSuccess(func(input payload.Payload) { + DoOnSuccess(func(input payload.Payload) error { // Handle and consume payload. // Do something here... fmt.Println("bingo:", input) + return nil }). SubscribeOn(scheduler.Parallel()). Subscribe(context.Background()) @@ -156,9 +159,9 @@ import ( "context" "fmt" - flxx "github.com/jjeffcaii/reactor-go/flux" "github.com/rsocket/rsocket-go/extension" "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" ) @@ -178,24 +181,27 @@ func main() { // payload.NewString("qux", extension.TextPlain.String()), //) - f. - DoOnNext(func(elem payload.Payload) { + // Block + _, _ = f. + DoOnNext(func(elem payload.Payload) error { // Handle and consume elements // Do something here... fmt.Println("bingo:", elem) + return nil }). - Subscribe(context.Background()) + BlockLast(context.Background()) - // Or you can use Raw reactor-go API. :-D - f2 := flux.Raw(flxx.Range(0, 10).Map(func(i interface{}) interface{} { - return payload.NewString(fmt.Sprintf("Hello@%d", i.(int)), extension.TextPlain.String()) + // Subscribe + f.Subscribe(context.Background(), rx.OnNext(func(input payload.Payload) error { + fmt.Println("bingo:", input) + return nil })) - f2. - DoOnNext(func(input payload.Payload) { - fmt.Println("bingo:", input) - }). - BlockLast(context.Background()) + + // Or implement your own subscriber + var s rx.Subscriber + f.SubscribeWith(context.Background(), s) } + ``` #### Backpressure & RequestN @@ -238,61 +244,18 @@ func main() { su = s su.Request(1) }), - rx.OnNext(func(elem payload.Payload) { + rx.OnNext(func(elem payload.Payload) error { // Consume element, do something... fmt.Println("bingo:", elem) // Request for next one manually. su.Request(1) + return nil }), ) } - -``` - -#### Logging - -We do not use a specific log implementation. You can register your own log implementation. For example: - -```go -package main - -import ( - "log" - - "github.com/rsocket/rsocket-go/logger" -) - -func init() { - logger.SetLevel(logger.LevelDebug) -} - ``` #### Dependencies - [reactor-go](https://github.com/jjeffcaii/reactor-go) - [testify](https://github.com/stretchr/testify) - [websocket](https://github.com/gorilla/websocket) - -### TODO - -#### Transport - - [x] TCP - - [x] Websocket - -#### Duplex Socket - - [x] MetadataPush - - [x] RequestFNF - - [x] RequestResponse - - [x] RequestStream - - [x] RequestChannel - -##### Others - - [x] Resume - - [x] Keepalive - - [x] Fragmentation - - [x] Thin Reactor - - [x] Cancel - - [x] Error - - [x] Flow Control: RequestN - - [x] Flow Control: Lease - - [x] Load Balance diff --git a/examples/main.go b/examples/main.go new file mode 100644 index 0000000..df8da1c --- /dev/null +++ b/examples/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "context" + "fmt" + + "github.com/rsocket/rsocket-go/extension" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" + "github.com/rsocket/rsocket-go/rx/flux" +) + +func main() { + // Here is an example which consume Payload one by one. + f := flux.Create(func(ctx context.Context, s flux.Sink) { + for i := 0; i < 5; i++ { + s.Next(payload.NewString(fmt.Sprintf("Hello@%d", i), extension.TextPlain.String())) + } + s.Complete() + }) + + var su rx.Subscription + f. + DoOnRequest(func(n int) { + fmt.Printf("requesting next %d element......\n", n) + }). + Subscribe( + context.Background(), + rx.OnSubscribe(func(s rx.Subscription) { + // Init Request 1 element. + su = s + su.Request(1) + }), + rx.OnNext(func(elem payload.Payload) error { + // Consume element, do something... + fmt.Println("bingo:", elem) + // Request for next one manually. + su.Request(1) + return nil + }), + ) +} diff --git a/logger/logger_test.go b/logger/logger_test.go index ac26835..154ee96 100644 --- a/logger/logger_test.go +++ b/logger/logger_test.go @@ -9,7 +9,7 @@ import ( ) var ( - fakeFormat = "fake format" + fakeFormat = "fake format: %v" fakeArgs = []interface{}{"fake args"} ) From d6dce50c840569205ba73c135556f678a80b8b6b Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 11 Aug 2020 23:04:38 +0800 Subject: [PATCH 25/26] Fix --- rx/flux/flux_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index f9f8b55..c3eb818 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -379,14 +379,10 @@ loop: break loop } count++ - case err, ok := <-errChan: - if !ok { - break loop - } + case err := <-errChan: assert.NoError(t, err) } } - assert.Equal(t, 10, count) } From 039522371709bdb5ef6f0e4df13b0273c850410d Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Tue, 11 Aug 2020 23:06:25 +0800 Subject: [PATCH 26/26] remove unused files. --- examples/main.go | 42 ------------------------------------------ 1 file changed, 42 deletions(-) delete mode 100644 examples/main.go diff --git a/examples/main.go b/examples/main.go deleted file mode 100644 index df8da1c..0000000 --- a/examples/main.go +++ /dev/null @@ -1,42 +0,0 @@ -package main - -import ( - "context" - "fmt" - - "github.com/rsocket/rsocket-go/extension" - "github.com/rsocket/rsocket-go/payload" - "github.com/rsocket/rsocket-go/rx" - "github.com/rsocket/rsocket-go/rx/flux" -) - -func main() { - // Here is an example which consume Payload one by one. - f := flux.Create(func(ctx context.Context, s flux.Sink) { - for i := 0; i < 5; i++ { - s.Next(payload.NewString(fmt.Sprintf("Hello@%d", i), extension.TextPlain.String())) - } - s.Complete() - }) - - var su rx.Subscription - f. - DoOnRequest(func(n int) { - fmt.Printf("requesting next %d element......\n", n) - }). - Subscribe( - context.Background(), - rx.OnSubscribe(func(s rx.Subscription) { - // Init Request 1 element. - su = s - su.Request(1) - }), - rx.OnNext(func(elem payload.Payload) error { - // Consume element, do something... - fmt.Println("bingo:", elem) - // Request for next one manually. - su.Request(1) - return nil - }), - ) -}