diff --git a/.travis.yml b/.travis.yml index 87e0d0e..bc2dca0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,7 @@ go: install: - go get golang.org/x/lint/golint - - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.27.0 + - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.30.0 - go get golang.org/x/tools/cmd/cover - go get github.com/mattn/goveralls diff --git a/client.go b/client.go index d08866d..c6ffe3c 100644 --- a/client.go +++ b/client.go @@ -87,110 +87,110 @@ type clientBuilder struct { onCloses []func(error) } -func (p *clientBuilder) Lease() ClientBuilder { - p.setup.Lease = true - return p +func (cb *clientBuilder) Lease() ClientBuilder { + cb.setup.Lease = true + return cb } -func (p *clientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { - if p.resume == nil { - p.resume = newResumeOpts() +func (cb *clientBuilder) Resume(opts ...ClientResumeOptions) ClientBuilder { + if cb.resume == nil { + cb.resume = newResumeOpts() } for _, it := range opts { - it(p.resume) + it(cb.resume) } - return p + return cb } -func (p *clientBuilder) Fragment(mtu int) ClientBuilder { +func (cb *clientBuilder) Fragment(mtu int) ClientBuilder { if mtu == 0 { - p.fragment = fragmentation.MaxFragment + cb.fragment = fragmentation.MaxFragment } else { - p.fragment = mtu + cb.fragment = mtu } - return p + return cb } -func (p *clientBuilder) OnClose(fn func(error)) ClientBuilder { - p.onCloses = append(p.onCloses, fn) - return p +func (cb *clientBuilder) OnClose(fn func(error)) ClientBuilder { + cb.onCloses = append(cb.onCloses, fn) + return cb } -func (p *clientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { - p.setup.KeepaliveInterval = tickPeriod - p.setup.KeepaliveLifetime = time.Duration(missedAcks) * ackTimeout - return p +func (cb *clientBuilder) KeepAlive(tickPeriod, ackTimeout time.Duration, missedAcks int) ClientBuilder { + cb.setup.KeepaliveInterval = tickPeriod + cb.setup.KeepaliveLifetime = time.Duration(missedAcks) * ackTimeout + return cb } -func (p *clientBuilder) DataMimeType(mime string) ClientBuilder { - p.setup.DataMimeType = []byte(mime) - return p +func (cb *clientBuilder) DataMimeType(mime string) ClientBuilder { + cb.setup.DataMimeType = []byte(mime) + return cb } -func (p *clientBuilder) MetadataMimeType(mime string) ClientBuilder { - p.setup.MetadataMimeType = []byte(mime) - return p +func (cb *clientBuilder) MetadataMimeType(mime string) ClientBuilder { + cb.setup.MetadataMimeType = []byte(mime) + return cb } -func (p *clientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { - p.setup.Data = nil - p.setup.Metadata = nil +func (cb *clientBuilder) SetupPayload(setup payload.Payload) ClientBuilder { + cb.setup.Data = nil + cb.setup.Metadata = nil if data := setup.Data(); len(data) > 0 { - p.setup.Data = make([]byte, len(data)) - copy(p.setup.Data, data) + cb.setup.Data = make([]byte, len(data)) + copy(cb.setup.Data, data) } if metadata, ok := setup.Metadata(); ok { - p.setup.Metadata = make([]byte, len(metadata)) - copy(p.setup.Metadata, metadata) + cb.setup.Metadata = make([]byte, len(metadata)) + copy(cb.setup.Metadata, metadata) } - return p + return cb } -func (p *clientBuilder) Acceptor(acceptor ClientSocketAcceptor) ToClientStarter { - p.acceptor = acceptor - return p +func (cb *clientBuilder) Acceptor(acceptor ClientSocketAcceptor) ToClientStarter { + cb.acceptor = acceptor + return cb } -func (p *clientBuilder) Transport(t transport.ClientTransportFunc) ClientStarter { - p.tpGen = t - return p +func (cb *clientBuilder) Transport(t transport.ClientTransportFunc) ClientStarter { + cb.tpGen = t + return cb } -func (p *clientBuilder) Start(ctx context.Context) (client Client, err error) { +func (cb *clientBuilder) Start(ctx context.Context) (client Client, err error) { // create a blank socket. - err = fragmentation.IsValidFragment(p.fragment) + err = fragmentation.IsValidFragment(cb.fragment) if err != nil { return nil, err } - sk := socket.NewClientDuplexConnection( - p.fragment, - p.setup.KeepaliveInterval, + conn := socket.NewClientDuplexConnection( + cb.fragment, + cb.setup.KeepaliveInterval, ) // create a client. var cs setupClientSocket - if p.resume != nil { - p.setup.Token = p.resume.tokenGen() - cs = socket.NewResumableClientSocket(p.tpGen, sk) + if cb.resume != nil { + cb.setup.Token = cb.resume.tokenGen() + cs = socket.NewResumableClientSocket(cb.tpGen, conn) } else { - cs = socket.NewClient(p.tpGen, sk) + cs = socket.NewClient(cb.tpGen, conn) } - if p.acceptor != nil { - sk.SetResponder(p.acceptor(cs)) + if cb.acceptor != nil { + conn.SetResponder(cb.acceptor(cs)) } else { - sk.SetResponder(_noopSocket) + conn.SetResponder(_noopSocket) } // bind closers. - if len(p.onCloses) > 0 { - for _, closer := range p.onCloses { + if len(cb.onCloses) > 0 { + for _, closer := range cb.onCloses { cs.OnClose(closer) } } // setup client. - err = cs.Setup(ctx, p.setup) + err = cs.Setup(ctx, cb.setup) if err == nil { client = cs } diff --git a/core/framing/frame.go b/core/framing/frame.go index 32b7214..6b66f18 100644 --- a/core/framing/frame.go +++ b/core/framing/frame.go @@ -2,7 +2,6 @@ package framing import ( "errors" - "fmt" "io" "github.com/rsocket/rsocket-go/core" @@ -139,7 +138,3 @@ func FromBytes(b []byte) (core.Frame, error) { return FromRawFrame(raw) } -func PrintFrame(f core.WriteableFrame) string { - // TODO: print frame - return fmt.Sprintf("%+v", f) -} diff --git a/core/framing/frame_test.go b/core/framing/frame_test.go index a9320d2..7cc2597 100644 --- a/core/framing/frame_test.go +++ b/core/framing/frame_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" ) -const _sid uint32 = 1 +const _sid uint32 = 1234 func TestFromBytes(t *testing.T) { // empty diff --git a/core/framing/misc.go b/core/framing/misc.go index 81b930f..dcdd3f7 100644 --- a/core/framing/misc.go +++ b/core/framing/misc.go @@ -1,7 +1,11 @@ package framing import ( + "encoding/binary" + "fmt" "io" + "strconv" + "strings" "github.com/rsocket/rsocket-go/core" "github.com/rsocket/rsocket-go/internal/common" @@ -53,6 +57,88 @@ func FromRawFrame(f *RawFrame) (frame core.Frame, err error) { return } +// PrintFrame prints frame in bytes dump. +func PrintFrame(f core.WriteableFrame) string { + var initN, reqN uint32 + var metadata, data []byte + + switch it := f.(type) { + case *PayloadFrame: + metadata, _ = it.Metadata() + data = it.Data() + case *WriteablePayloadFrame: + metadata, data = it.metadata, it.data + case *MetadataPushFrame: + metadata, _ = it.Metadata() + case *FireAndForgetFrame: + metadata, _ = it.Metadata() + data = it.Data() + case *RequestResponseFrame: + metadata, _ = it.Metadata() + data = it.Data() + case *RequestStreamFrame: + metadata, _ = it.Metadata() + data = it.Data() + initN = it.InitialRequestN() + case *RequestChannelFrame: + metadata, _ = it.Metadata() + data = it.Data() + initN = it.InitialRequestN() + case *SetupFrame: + metadata, _ = it.Metadata() + data = it.Data() + case *RequestNFrame: + reqN = it.N() + case *WriteableMetadataPushFrame: + metadata = it.metadata + case *WriteableFireAndForgetFrame: + metadata, data = it.metadata, it.data + case *WriteableRequestResponseFrame: + metadata, data = it.metadata, it.data + case *WriteableRequestStreamFrame: + metadata, data = it.metadata, it.data + reqN = binary.BigEndian.Uint32(it.n[:]) + case *WriteableRequestChannelFrame: + metadata, data = it.metadata, it.data + reqN = binary.BigEndian.Uint32(it.n[:]) + case *WriteableSetupFrame: + metadata, data = it.metadata, it.data + case *WriteableRequestNFrame: + reqN = binary.BigEndian.Uint32(it.n[:]) + } + + b := &strings.Builder{} + b.WriteString("\nFrame => Stream ID: ") + h := f.Header() + b.WriteString(strconv.Itoa(int(h.StreamID()))) + b.WriteString(" Type: ") + b.WriteString(h.Type().String()) + b.WriteString(" Flags: 0b") + _, _ = fmt.Fprintf(b, "%010b", h.Flag()) + b.WriteString(" Length: ") + b.WriteString(strconv.Itoa(f.Len())) + if initN > 0 { + b.WriteString(" InitialRequestN: ") + _, _ = fmt.Fprintf(b, "%d", initN) + } + + if reqN > 0 { + b.WriteString(" RequestN: ") + _, _ = fmt.Fprintf(b, "%d", reqN) + } + + if metadata != nil { + b.WriteString("\nMetadata:\n") + common.AppendPrettyHexDump(b, metadata) + } + + if data != nil { + b.WriteString("\nData:\n") + common.AppendPrettyHexDump(b, data) + } + return b.String() +} + func writePayload(w io.Writer, data []byte, metadata []byte) (n int64, err error) { if l := len(metadata); l > 0 { var wrote int64 diff --git a/core/framing/misc_test.go b/core/framing/misc_test.go new file mode 100644 index 0000000..03adadc --- /dev/null +++ b/core/framing/misc_test.go @@ -0,0 +1,44 @@ +package framing_test + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/stretchr/testify/assert" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func TestPrintFrame(t *testing.T) { + mime := []byte("fake mime") + data := make([]byte, 100) + metadata := make([]byte, 50) + rand.Read(data) + rand.Read(metadata) + data = append([]byte("fake data"), data...) + metadata = append([]byte("fake metadata"), metadata...) + for _, f := range []core.WriteableFrame{ + framing.NewCancelFrame(_sid), + framing.NewPayloadFrame(_sid, data, metadata, core.FlagComplete|core.FlagNext|core.FlagFollow), + framing.NewRequestResponseFrame(_sid, data, metadata, core.FlagComplete|core.FlagNext|core.FlagFollow), + framing.NewMetadataPushFrame(metadata), + framing.NewFireAndForgetFrame(_sid, data, metadata, core.FlagComplete|core.FlagNext|core.FlagFollow), + framing.NewRequestStreamFrame(_sid, 1, data, metadata, core.FlagComplete|core.FlagNext|core.FlagFollow), + framing.NewRequestChannelFrame(_sid, 1, data, metadata, core.FlagComplete|core.FlagNext|core.FlagFollow), + framing.NewSetupFrame(core.DefaultVersion, 30*time.Second, 90*time.Second, nil, mime, mime, data, metadata, false), + } { + tryPrintFrame(t, f) + } +} + +func tryPrintFrame(t *testing.T, f core.WriteableFrame) { + s := framing.PrintFrame(f) + assert.True(t, len(s) > 0) + fmt.Println(s) +} diff --git a/core/transport/tcp_conn.go b/core/transport/tcp_conn.go index ac4cfda..97e05bb 100644 --- a/core/transport/tcp_conn.go +++ b/core/transport/tcp_conn.go @@ -51,7 +51,7 @@ func (p *TcpConn) Read() (f core.Frame, err error) { return } if logger.IsDebugEnabled() { - logger.Debugf("<--- rcv: %s\n", f) + logger.Debugf("%s\n", framing.PrintFrame(f)) } return } @@ -84,7 +84,7 @@ func (p *TcpConn) Write(frame core.WriteableFrame) (err error) { return } if logger.IsDebugEnabled() { - logger.Debugf("---> snd: %s\n", debugStr) + logger.Debugf("%s\n", debugStr) } return } diff --git a/core/transport/websocket_conn.go b/core/transport/websocket_conn.go index e0d28db..aa3fc7a 100644 --- a/core/transport/websocket_conn.go +++ b/core/transport/websocket_conn.go @@ -75,7 +75,7 @@ func (p *WsConn) Read() (f core.Frame, err error) { return } if logger.IsDebugEnabled() { - logger.Debugf("<--- rcv: %s\n", f) + logger.Debugf("%s\n", framing.PrintFrame(f)) } return } @@ -107,7 +107,7 @@ func (p *WsConn) Write(frame core.WriteableFrame) (err error) { p.counter.IncWriteBytes(size) } if logger.IsDebugEnabled() { - logger.Debugf("---> snd: %s\n", frame) + logger.Debugf("%s\n", framing.PrintFrame(frame)) } return } diff --git a/examples/echo/echo.go b/examples/echo/echo.go index 823af98..878ec80 100644 --- a/examples/echo/echo.go +++ b/examples/echo/echo.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/jjeffcaii/reactor-go/scheduler" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" @@ -24,13 +23,12 @@ var tp transport.ServerTransportFunc func init() { tp = rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).Build() -} - -func main() { go func() { - http.Handle("/metrics", promhttp.Handler()) log.Println(http.ListenAndServe(":4444", nil)) }() +} + +func main() { //logger.SetLevel(logger.LevelDebug) err := rsocket.Receive(). //Fragment(65535). diff --git a/extension/routing_test.go b/extension/routing_test.go index 58ba1fd..e162751 100644 --- a/extension/routing_test.go +++ b/extension/routing_test.go @@ -1,8 +1,10 @@ package extension import ( + "bytes" "testing" + "github.com/rsocket/rsocket-go/internal/common" "github.com/stretchr/testify/assert" ) @@ -18,3 +20,20 @@ func TestParseRoutingTags(t *testing.T) { assert.Equal(t, "/bar", tags[1]) assert.Equal(t, "/foo/bar", tags[2]) } + +func TestEncodeRouting_TooLarge(t *testing.T) { + _, err := EncodeRouting(common.RandAlphanumeric(256)) + assert.Error(t, err, "should return error") + _, err = EncodeRouting("foobar", common.RandAlphanumeric(256)) + assert.Error(t, err, "should return error") +} + +func TestParseRoutingTags_Broken(t *testing.T) { + bf := &bytes.Buffer{} + bf.WriteByte(0xFF) + bf.WriteString("brokenTag") + + b := bf.Bytes() + _, err := ParseRoutingTags(b) + assert.Error(t, err, "should return error") +} diff --git a/go.mod b/go.mod index 6446eae..5e27d5b 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/gorilla/websocket v1.4.1 github.com/jjeffcaii/reactor-go v0.2.1 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.7.1 github.com/stretchr/testify v1.4.0 github.com/urfave/cli/v2 v2.1.1 go.uber.org/atomic v1.5.1 diff --git a/go.sum b/go.sum index 8538a77..42b322e 100644 --- a/go.sum +++ b/go.sum @@ -1,97 +1,28 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jjeffcaii/reactor-go v0.2.1 h1:Wb33QstbwZdgFFS3BqWeBweGp2KQJigfts7NQLHTxqE= github.com/jjeffcaii/reactor-go v0.2.1/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= -github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= -github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= -github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/panjf2000/ants/v2 v2.4.1 h1:7RtUqj5lGOw0WnZhSKDZ2zzJhaX5490ZW1sUolRXCxY= github.com/panjf2000/ants/v2 v2.4.1/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= -github.com/prometheus/client_golang v1.7.1 h1:NTGy1Ja9pByO+xAeH/qiWnLrKtr3hJPNjaVUwnjpdpA= -github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= -github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lNawc= -github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= -github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -99,25 +30,13 @@ github.com/urfave/cli/v2 v2.1.1 h1:Qt8FeAtxE/vfdrLmR3rxR6JRE0RoVmbXu8+6kZtYU4k= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1 h1:ogLJMz+qpzav7lGMh10LMvAkM/fAoGlaiiHYiFYdm80= -golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -125,23 +44,9 @@ golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= diff --git a/internal/common/bytebuffer_test.go b/internal/common/bytebuffer_test.go index ea0ce94..c7add40 100644 --- a/internal/common/bytebuffer_test.go +++ b/internal/common/bytebuffer_test.go @@ -42,8 +42,8 @@ func TestByteBuff_WriteTo(t *testing.T) { f, err := os.OpenFile("/dev/null", os.O_WRONLY, os.ModeAppend) assert.NoError(t, err, "open /dev/null failed") defer f.Close() - // 16MB - s := common.RandAlphanumeric(16 * 1024 * 1024) + // 1MB + s := common.RandAlphanumeric(1 * 1024 * 1024) err = b.WriteString(s) assert.NoError(t, err) n, err := b.WriteTo(f) diff --git a/internal/common/bytedump.go b/internal/common/bytedump.go new file mode 100644 index 0000000..855b9ce --- /dev/null +++ b/internal/common/bytedump.go @@ -0,0 +1,118 @@ +package common + +import ( + "fmt" + "strconv" + "strings" +) + +var ( + _newLine = "\n" + _hexPadding [16]string + _bytePadding [16]string + _hexDumpRowPrefixes [4096]string +) + +func init() { + b := &strings.Builder{} + for i := 0; i < len(_hexPadding); i++ { + padding := len(_hexPadding) - i + for j := 0; j < padding; j++ { + b.WriteString(" ") + } + _hexPadding[i] = b.String() + b.Reset() + } + for i := 0; i < len(_bytePadding); i++ { + padding := len(_bytePadding) - i + for j := 0; j < padding; j++ { + b.WriteByte(' ') + } + _bytePadding[i] = b.String() + b.Reset() + } + for i := 0; i < len(_hexDumpRowPrefixes); i++ { + b.WriteString(_newLine) + n := i<<4&0xFFFFFFFF | 0x100000000 + b.WriteByte('|') + b.WriteString(leftPad(strconv.FormatInt(int64(n), 16), "0", 8)) + b.WriteByte('|') + _hexDumpRowPrefixes[i] = b.String() + b.Reset() + } +} + +func PrettyHexDump(b []byte) string { + sb := &strings.Builder{} + AppendPrettyHexDump(sb, b) + return sb.String() +} + +func AppendPrettyHexDump(dump *strings.Builder, b []byte) { + if len(b) < 1 { + return + } + dump.WriteString(" +-------------------------------------------------+") + dump.WriteString(_newLine) + dump.WriteString(" | 0 1 2 3 4 5 6 7 8 9 a b c d e f |") + dump.WriteString(_newLine) + dump.WriteString("+--------+-------------------------------------------------+----------------+") + length := len(b) + startIndex := 0 + fullRows := length >> 4 + remainder := length & 0xF + for row := 0; row < fullRows; row++ { + rowStartIndex := row<<4 + startIndex + appendHexDumpRowPrefix(dump, row, rowStartIndex) + rowEndIndex := rowStartIndex + 16 + for j := rowStartIndex; j < rowEndIndex; j++ { + _, _ = fmt.Fprintf(dump, " %02x", b[j]) + } + dump.WriteString(" |") + for j := rowStartIndex; j < rowEndIndex; j++ { + dump.WriteByte(byte2char(b[j])) + } + dump.WriteByte('|') + } + if remainder != 0 { + rowStartIndex := fullRows<<4 + startIndex + appendHexDumpRowPrefix(dump, fullRows, rowStartIndex) + rowEndIndex := rowStartIndex + remainder + for j := rowStartIndex; j < rowEndIndex; j++ { + _, _ = fmt.Fprintf(dump, " %02x", b[j]) + } + dump.WriteString(_hexPadding[remainder]) + dump.WriteString(" |") + for j := rowStartIndex; j < rowEndIndex; j++ { + dump.WriteByte(byte2char(b[j])) + } + dump.WriteString(_bytePadding[remainder]) + dump.WriteByte('|') + } + dump.WriteString(_newLine) + dump.WriteString("+--------+-------------------------------------------------+----------------+") +} + +func appendHexDumpRowPrefix(dump *strings.Builder, row int, rowStartIndex int) { + if row < len(_hexDumpRowPrefixes) { + dump.WriteString(_hexDumpRowPrefixes[row]) + return + } + dump.WriteString(_newLine) + n := rowStartIndex&0xFFFFFFFF | 0x100000000 + dump.WriteString(strconv.FormatInt(int64(n), 16)) + dump.WriteByte('|') +} + +func byte2char(b byte) byte { + if b <= 0x1f || b >= 0x7f { + return '.' + } + return b +} + +func leftPad(s string, padStr string, length int) string { + padCountInt := 1 + ((length - len(padStr)) / len(padStr)) + retStr := strings.Repeat(padStr, padCountInt) + s + return retStr[(len(retStr) - length):] +} diff --git a/internal/common/bytedump_test.go b/internal/common/bytedump_test.go new file mode 100644 index 0000000..2624ecc --- /dev/null +++ b/internal/common/bytedump_test.go @@ -0,0 +1,37 @@ +package common + +import ( + "math/rand" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func TestAppendPrettyHexDump(t *testing.T) { + b := make([]byte, 100) + rand.Read(b) + s := PrettyHexDump(b) + assert.NotEmpty(t, s, "should not return empty string") +} + +func TestAppendPrettyHexDump_Empty(t *testing.T) { + s := PrettyHexDump(nil) + assert.Empty(t, s, "should return empty string") + s = PrettyHexDump([]byte{}) + assert.Empty(t, s, "should return empty string") +} + +func TestAppendPrettyHexDump_Big(t *testing.T) { + // 512KB + b := make([]byte, 512*1024) + rand.Read(b) + sb := &strings.Builder{} + AppendPrettyHexDump(sb, b) + assert.NotEmpty(t, sb.String(), "should not return empty string") +} diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 300471e..f31597e 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -18,7 +18,7 @@ func TestSession(t *testing.T) { deadline := time.Now().Add(time.Duration(i+1) * time.Second) token := fmt.Sprintf("token_%d", i) tokens = append(tokens, token) - manager.Push(session.NewSession(deadline, socket.NewServerResume(nil, []byte(token)))) + manager.Push(session.NewSession(deadline, socket.NewResumableServerSocket(nil, []byte(token)))) } for _, token := range tokens { diff --git a/internal/socket/callback.go b/internal/socket/callback.go index 6afea4b..efbff00 100644 --- a/internal/socket/callback.go +++ b/internal/socket/callback.go @@ -8,14 +8,14 @@ import ( ) type callback interface { - Close(error) + stopWithError(error) } type requestStreamCallback struct { pc flux.Processor } -func (s requestStreamCallback) Close(err error) { +func (s requestStreamCallback) stopWithError(err error) { s.pc.Error(err) } @@ -23,7 +23,7 @@ type requestResponseCallback struct { pc mono.Processor } -func (s requestResponseCallback) Close(err error) { +func (s requestResponseCallback) stopWithError(err error) { s.pc.Error(err) } @@ -32,7 +32,7 @@ type requestChannelCallback struct { rcv flux.Processor } -func (s requestChannelCallback) Close(err error) { +func (s requestChannelCallback) stopWithError(err error) { s.snd.Cancel() s.rcv.Error(err) } @@ -41,7 +41,7 @@ type requestResponseCallbackReverse struct { su reactor.Subscription } -func (s requestResponseCallbackReverse) Close(err error) { +func (s requestResponseCallbackReverse) stopWithError(err error) { s.su.Cancel() // TODO: fill err } @@ -50,7 +50,7 @@ type requestStreamCallbackReverse struct { su rx.Subscription } -func (s requestStreamCallbackReverse) Close(err error) { +func (s requestStreamCallbackReverse) stopWithError(err error) { s.su.Cancel() // TODO: fill error } @@ -60,7 +60,7 @@ type requestChannelCallbackReverse struct { rcv flux.Processor } -func (s requestChannelCallbackReverse) Close(err error) { +func (s requestChannelCallbackReverse) stopWithError(err error) { s.rcv.Error(err) s.snd.Cancel() } diff --git a/internal/socket/callback_test.go b/internal/socket/callback_test.go new file mode 100644 index 0000000..c807e63 --- /dev/null +++ b/internal/socket/callback_test.go @@ -0,0 +1,6 @@ +package socket + +import "testing" + +func TestRequestResponseCallback_Close(t *testing.T) { +} diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index 46a4ebf..eae4a8b 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -18,9 +18,13 @@ import ( "github.com/rsocket/rsocket-go/rx" "github.com/rsocket/rsocket-go/rx/flux" "github.com/rsocket/rsocket-go/rx/mono" + "go.uber.org/atomic" ) -const _outChanSize = 64 +const ( + _outChanSize = 64 + _schedulerSize = 64 +) var errSocketClosed = errors.New("socket closed already") @@ -30,43 +34,39 @@ var ( unsupportedRequestChannel = []byte("Request-Channel not implemented.") ) -// IsSocketClosedError returns true if input error is for socket closed. -func IsSocketClosedError(err error) bool { - return err == errSocketClosed -} - // DuplexConnection represents a socket of RSocket which can be a requester or a responder. type DuplexConnection struct { - l *sync.RWMutex - counter *core.TrafficCounter - tp *transport.Transport - outs chan core.WriteableFrame - outsPriority []core.WriteableFrame - responder Responder - messages *sync.Map - sids StreamID - mtu int - fragments *sync.Map // common.U32Map // key=streamID, value=Joiner - writeDone chan struct{} - keepaliver *Keepaliver - cond *sync.Cond - singleScheduler scheduler.Scheduler - e error - leases lease.Leases - closeOnce sync.Once + locker sync.RWMutex + counter *core.TrafficCounter + tp *transport.Transport + outs chan core.WriteableFrame + outsPriority []core.WriteableFrame + responder Responder + messages *map32 + sids StreamID + mtu int + fragments *map32 // key=streamID, value=Joiner + writeDone chan struct{} + keepaliver *Keepaliver + cond sync.Cond + sc scheduler.Scheduler + e error + leases lease.Leases + closed *atomic.Bool + ready *atomic.Bool } // SetError sets error for current socket. func (dc *DuplexConnection) SetError(err error) { - dc.l.Lock() - defer dc.l.Unlock() + dc.locker.Lock() + defer dc.locker.Unlock() dc.e = err } // GetError get the error set. func (dc *DuplexConnection) GetError() error { - dc.l.RLock() - defer dc.l.RUnlock() + dc.locker.RLock() + defer dc.locker.RUnlock() return dc.e } @@ -86,18 +86,14 @@ func (dc *DuplexConnection) nextStreamID() (sid uint32) { } // Close close current socket. -func (dc *DuplexConnection) Close() (err error) { - dc.closeOnce.Do(func() { - err = dc.innerClose() - }) - return -} - -func (dc *DuplexConnection) innerClose() error { +func (dc *DuplexConnection) Close() error { + if !dc.closed.CAS(false, true) { + return nil + } if dc.keepaliver != nil { dc.keepaliver.Stop() } - _ = dc.singleScheduler.Close() + _ = dc.sc.Close() close(dc.outs) dc.cond.L.Lock() dc.cond.Broadcast() @@ -112,16 +108,18 @@ func (dc *DuplexConnection) innerClose() error { _ = dc.tp.Close() } } - dc.messages.Range(func(_, v interface{}) bool { + dc.messages.Range(func(_ uint32, v interface{}) bool { if cb, ok := v.(callback); ok { err := dc.e if err == nil { err = errSocketClosed } - go cb.Close(err) + go cb.stopWithError(err) } return true }) + dc.messages.Destroy() + dc.fragments.Destroy() return dc.e } @@ -433,7 +431,7 @@ func (dc *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayl <-frameN.DoneNotify() }) - _ = dc.singleScheduler.Worker().Do(func() { + _ = dc.sc.Worker().Do(func() { receivingProcessor.Next(pl) }) @@ -781,19 +779,19 @@ func (dc *DuplexConnection) onFramePayload(frame core.Frame) error { } func (dc *DuplexConnection) clearTransport() { - dc.cond.L.Lock() + dc.locker.Lock() + defer dc.locker.Unlock() dc.tp = nil - dc.cond.L.Unlock() + dc.ready.Store(false) } // SetTransport sets a transport for current socket. -func (dc *DuplexConnection) SetTransport(tp *transport.Transport) { +func (dc *DuplexConnection) SetTransport(tp *transport.Transport) (ok bool) { tp.RegisterHandler(transport.OnCancel, dc.onFrameCancel) tp.RegisterHandler(transport.OnError, dc.onFrameError) tp.RegisterHandler(transport.OnRequestN, dc.onFrameRequestN) tp.RegisterHandler(transport.OnPayload, dc.onFramePayload) tp.RegisterHandler(transport.OnKeepalive, dc.onFrameKeepalive) - if dc.responder != nil { tp.RegisterHandler(transport.OnRequestResponse, dc.onFrameRequestResponse) tp.RegisterHandler(transport.OnMetadataPush, dc.respondMetadataPush) @@ -802,10 +800,16 @@ func (dc *DuplexConnection) SetTransport(tp *transport.Transport) { tp.RegisterHandler(transport.OnRequestChannel, dc.onFrameRequestChannel) } - dc.cond.L.Lock() + ok = dc.ready.CAS(false, true) + if !ok { + return + } + + dc.locker.Lock() dc.tp = tp dc.cond.Signal() - dc.cond.L.Unlock() + dc.locker.Unlock() + return } func (dc *DuplexConnection) sendFrame(f core.WriteableFrame) { @@ -984,12 +988,14 @@ func (dc *DuplexConnection) drainOutBack() { func (dc *DuplexConnection) loopWriteWithKeepaliver(ctx context.Context, leaseChan <-chan lease.Lease) error { for { - if dc.tp == nil { - dc.cond.L.Lock() + if dc.closed.Load() { + break + } + if !dc.ready.Load() { + dc.locker.Lock() dc.cond.Wait() - dc.cond.L.Unlock() + dc.locker.Unlock() } - select { case <-ctx.Done(): dc.cleanOuts() @@ -1044,12 +1050,14 @@ func (dc *DuplexConnection) LoopWrite(ctx context.Context) error { return dc.loopWriteWithKeepaliver(ctx, leaseChan) } for { - if dc.tp == nil { + if dc.closed.Load() { + break + } + if !dc.ready.Load() { dc.cond.L.Lock() dc.cond.Wait() dc.cond.L.Unlock() } - select { case <-ctx.Done(): dc.cleanOuts() @@ -1086,6 +1094,11 @@ func (dc *DuplexConnection) unregister(sid uint32) { dc.fragments.Delete(sid) } +// IsSocketClosedError returns true if input error is for socket closed. +func IsSocketClosedError(err error) bool { + return err == errSocketClosed +} + // NewServerDuplexConnection creates a new server-side DuplexConnection. func NewServerDuplexConnection(mtu int, leases lease.Leases) *DuplexConnection { return newDuplexConnection(mtu, nil, &serverStreamIDs{}, leases) @@ -1097,19 +1110,20 @@ func NewClientDuplexConnection(mtu int, keepaliveInterval time.Duration) *Duplex } func newDuplexConnection(mtu int, ka *Keepaliver, sids StreamID, leases lease.Leases) *DuplexConnection { - l := &sync.RWMutex{} - return &DuplexConnection{ - l: l, - leases: leases, - outs: make(chan core.WriteableFrame, _outChanSize), - mtu: mtu, - messages: &sync.Map{}, - sids: sids, - fragments: &sync.Map{}, - writeDone: make(chan struct{}), - cond: sync.NewCond(l), - counter: core.NewTrafficCounter(), - keepaliver: ka, - singleScheduler: scheduler.NewSingle(64), + c := &DuplexConnection{ + leases: leases, + outs: make(chan core.WriteableFrame, _outChanSize), + mtu: mtu, + messages: newMap32(), + sids: sids, + fragments: newMap32(), + writeDone: make(chan struct{}), + counter: core.NewTrafficCounter(), + keepaliver: ka, + sc: scheduler.NewSingle(_schedulerSize), + closed: atomic.NewBool(false), + ready: atomic.NewBool(false), } + c.cond.L = &c.locker + return c } diff --git a/internal/socket/map32.go b/internal/socket/map32.go new file mode 100644 index 0000000..ab2aa4d --- /dev/null +++ b/internal/socket/map32.go @@ -0,0 +1,51 @@ +package socket + +import "sync" + +type map32 struct { + locker sync.RWMutex + store map[uint32]interface{} +} + +func (p *map32) Destroy() { + p.locker.Lock() + p.store = nil + p.locker.Unlock() +} + +func (p *map32) Range(fn func(uint32, interface{}) bool) { + p.locker.RLock() + defer p.locker.RUnlock() + for key, value := range p.store { + if !fn(key, value) { + break + } + } +} + +func (p *map32) Load(key uint32) (v interface{}, ok bool) { + p.locker.RLock() + v, ok = p.store[key] + p.locker.RUnlock() + return +} + +func (p *map32) Store(key uint32, value interface{}) { + p.locker.Lock() + if p.store != nil { + p.store[key] = value + } + p.locker.Unlock() +} + +func (p *map32) Delete(key uint32) { + p.locker.Lock() + delete(p.store, key) + p.locker.Unlock() +} + +func newMap32() *map32 { + return &map32{ + store: make(map[uint32]interface{}), + } +} diff --git a/internal/socket/misc.go b/internal/socket/misc.go index d387396..3639e20 100644 --- a/internal/socket/misc.go +++ b/internal/socket/misc.go @@ -36,21 +36,6 @@ func (p *SetupInfo) toFrame() core.WriteableFrame { ) } -func tryRecover(e interface{}) (err error) { - if e == nil { - return - } - switch v := e.(type) { - case error: - err = v - case string: - err = errors.New(v) - default: - err = errors.Errorf("error: %s", v) - } - return -} - func ToIntRequestN(n uint32) int { if n > rx.RequestMax { return rx.RequestMax @@ -67,3 +52,18 @@ func ToUint32RequestN(n int) uint32 { } return uint32(n) } + +func tryRecover(e interface{}) (err error) { + if e == nil { + return + } + switch v := e.(type) { + case error: + err = v + case string: + err = errors.New(v) + default: + err = errors.Errorf("error: %s", v) + } + return +} diff --git a/internal/socket/misc_test.go b/internal/socket/misc_test.go index f448896..1b64ccd 100644 --- a/internal/socket/misc_test.go +++ b/internal/socket/misc_test.go @@ -1,23 +1,31 @@ -package socket_test +package socket import ( "math" "testing" - "github.com/rsocket/rsocket-go/internal/socket" + "github.com/pkg/errors" "github.com/rsocket/rsocket-go/rx" "github.com/stretchr/testify/assert" ) func TestToUint32RequestN(t *testing.T) { - assert.Equal(t, uint32(1), socket.ToUint32RequestN(1)) + assert.Equal(t, uint32(1), ToUint32RequestN(1)) assert.Panics(t, func() { - socket.ToUint32RequestN(-1) + ToUint32RequestN(-1) }, "should panic") - assert.Equal(t, uint32(rx.RequestMax), socket.ToUint32RequestN(math.MaxInt64)) + assert.Equal(t, uint32(rx.RequestMax), ToUint32RequestN(math.MaxInt64)) } func TestToIntRequestN(t *testing.T) { - assert.Equal(t, 1, socket.ToIntRequestN(1)) - assert.Equal(t, rx.RequestMax, socket.ToIntRequestN(math.MaxUint32)) + assert.Equal(t, 1, ToIntRequestN(1)) + assert.Equal(t, rx.RequestMax, ToIntRequestN(math.MaxUint32)) +} + +func TestTryRecover(t *testing.T) { + assert.NoError(t, tryRecover(nil)) + e := errors.New("fake error") + assert.Equal(t, e, tryRecover(e)) + assert.Error(t, e, tryRecover("fake error")) + assert.Error(t, e, tryRecover(struct{}{})) } diff --git a/internal/socket/resumable_client_socket.go b/internal/socket/resumable_client_socket.go index 60d08b7..7745602 100644 --- a/internal/socket/resumable_client_socket.go +++ b/internal/socket/resumable_client_socket.go @@ -13,7 +13,8 @@ import ( "go.uber.org/atomic" ) -const reconnectDelay = 1 * time.Second +const _resumeReconnectDelay = 1 * time.Second +const _resumeTimeout = 10 * time.Second type resumeClientSocket struct { *BaseSocket @@ -22,52 +23,52 @@ type resumeClientSocket struct { tp transport.ClientTransportFunc } -func (p *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { - p.setup = setup +func (r *resumeClientSocket) Setup(ctx context.Context, setup *SetupInfo) error { + r.setup = setup go func(ctx context.Context) { - _ = p.socket.LoopWrite(ctx) + _ = r.socket.LoopWrite(ctx) }(ctx) - return p.connect(ctx) + return r.connect(ctx) } -func (p *resumeClientSocket) Close() (err error) { - p.once.Do(func() { - p.markClosing() - err = p.socket.Close() - for i, l := 0, len(p.closers); i < l; i++ { - p.closers[l-i-1](err) +func (r *resumeClientSocket) Close() (err error) { + r.once.Do(func() { + r.markAsClosing() + err = r.socket.Close() + for i, l := 0, len(r.closers); i < l; i++ { + r.closers[l-i-1](err) } }) return } -func (p *resumeClientSocket) connect(ctx context.Context) (err error) { - connects := p.connects.Inc() +func (r *resumeClientSocket) connect(ctx context.Context) (err error) { + connects := r.connects.Inc() if connects < 0 { - _ = p.Close() + _ = r.Close() return } - tp, err := p.tp(ctx) + tp, err := r.tp(ctx) if err != nil { if connects == 1 { return } - time.Sleep(reconnectDelay) - _ = p.connect(ctx) + time.Sleep(_resumeReconnectDelay) + _ = r.connect(ctx) return } - tp.Connection().SetCounter(p.socket.counter) - tp.SetLifetime(p.setup.KeepaliveLifetime) + tp.Connection().SetCounter(r.socket.counter) + tp.SetLifetime(r.setup.KeepaliveLifetime) go func(ctx context.Context, tp *transport.Transport) { defer func() { - p.socket.clearTransport() - if p.isClosed() { - _ = p.Close() + r.socket.clearTransport() + if r.isClosed() { + _ = r.Close() return } - time.Sleep(reconnectDelay) - _ = p.connect(ctx) + time.Sleep(_resumeReconnectDelay) + _ = r.connect(ctx) }() err := tp.Start(ctx) if err != nil { @@ -78,23 +79,23 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { var f core.WriteableFrame // connect first time. - if len(p.setup.Token) < 1 || connects == 1 { + if len(r.setup.Token) < 1 || connects == 1 { tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { - p.socket.SetError(frame.(*framing.ErrorFrame)) - p.markClosing() + r.socket.SetError(frame.(*framing.ErrorFrame)) + r.markAsClosing() return }) - f = p.setup.toFrame() + f = r.setup.toFrame() err = tp.Send(f, true) - p.socket.SetTransport(tp) + r.socket.SetTransport(tp) return } f = framing.NewWriteableResumeFrame( core.DefaultVersion, - p.setup.Token, - p.socket.counter.WriteBytes(), - p.socket.counter.ReadBytes(), + r.setup.Token, + r.socket.counter.WriteBytes(), + r.socket.counter.ReadBytes(), ) resumeErr := make(chan string) @@ -104,14 +105,19 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { return }) - tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) (err error) { + tp.RegisterHandler(transport.OnErrorWithZeroStreamID, func(frame core.Frame) error { // TODO: process other error with zero StreamID f := frame.(*framing.ErrorFrame) if f.ErrorCode() == core.ErrorCodeRejectedResume { + defer func() { + if err := recover(); err != nil { + logger.Warnf("handle reject resume failed: %s\n", err) + } + }() resumeErr <- f.Error() close(resumeErr) } - return + return nil }) err = tp.Send(f, true) @@ -119,29 +125,29 @@ func (p *resumeClientSocket) connect(ctx context.Context) (err error) { return err } - ctx2, cancel := context.WithTimeout(ctx, 10*time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, _resumeTimeout) defer cancel() select { - case <-ctx2.Done(): - err = ctx2.Err() + case <-timeoutCtx.Done(): + err = timeoutCtx.Err() case reject, ok := <-resumeErr: if ok { err = errors.New(reject) - p.markClosing() + r.markAsClosing() } else { - p.socket.SetTransport(tp) + r.socket.SetTransport(tp) } } return } -func (p *resumeClientSocket) markClosing() { - p.connects.Store(math.MinInt32) +func (r *resumeClientSocket) markAsClosing() { + r.connects.Store(math.MinInt32) } -func (p *resumeClientSocket) isClosed() bool { - return p.connects.Load() < 0 +func (r *resumeClientSocket) isClosed() bool { + return r.connects.Load() < 0 } // NewResumableClientSocket creates a client-side socket with resume support. diff --git a/internal/socket/resumable_client_socket_test.go b/internal/socket/resumable_client_socket_test.go index 565230a..9cf02d1 100644 --- a/internal/socket/resumable_client_socket_test.go +++ b/internal/socket/resumable_client_socket_test.go @@ -8,12 +8,27 @@ import ( "github.com/golang/mock/gomock" "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/internal/fragmentation" "github.com/rsocket/rsocket-go/internal/socket" + "github.com/rsocket/rsocket-go/payload" + "github.com/rsocket/rsocket-go/rx" "github.com/stretchr/testify/assert" + "go.uber.org/atomic" ) +var fakeResumableSetup = &socket.SetupInfo{ + Version: core.DefaultVersion, + MetadataMimeType: fakeMimeType, + DataMimeType: fakeMimeType, + Metadata: fakeMetadata, + Data: fakeData, + KeepaliveLifetime: 90 * time.Second, + KeepaliveInterval: 30 * time.Second, + Token: fakeToken, +} + func TestNewResumableClientSocket(t *testing.T) { ctrl, conn, tp := InitTransport(t) defer ctrl.Finish() @@ -40,8 +55,132 @@ func TestNewResumableClientSocket(t *testing.T) { return tp, nil }, ds) - defer rcs.Close() + onCloseCalled := atomic.NewBool(false) + rcs.OnClose(func(err error) { + onCloseCalled.CAS(false, true) + }) + + defer func() { + err := rcs.Close() + assert.NoError(t, err) + time.Sleep(500 * time.Millisecond) + assert.True(t, onCloseCalled.Load()) + }() - err := rcs.Setup(context.Background(), fakeSetup) + err := rcs.Setup(context.Background(), fakeResumableSetup) assert.NoError(t, err) + + requestId := atomic.NewUint32(1) + nextRequestId := func() uint32 { + return requestId.Add(2) - 2 + } + + result, err := rcs.RequestResponse(payload.New(fakeData, fakeMetadata)). + DoOnSubscribe(func(s rx.Subscription) { + readChan <- framing.NewPayloadFrame(nextRequestId(), fakeData, fakeMetadata, core.FlagComplete) + }). + Block(context.Background()) + assert.NoError(t, err, "request response failed") + assert.Equal(t, fakeData, result.Data(), "response data doesn't match") + assert.Equal(t, fakeMetadata, extractMetadata(result), "response metadata doesn't match") + + var stream []payload.Payload + _, err = rcs.RequestStream(payload.New(fakeData, fakeMetadata)). + DoOnNext(func(input payload.Payload) error { + stream = append(stream, input) + return nil + }). + DoOnSubscribe(func(s rx.Subscription) { + nextId := nextRequestId() + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) + readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext|core.FlagComplete) + }). + BlockLast(context.Background()) + assert.NoError(t, err, "request stream failed") + + // When a fatal error occurred, client should be stopped immediately. + fatalErr := []byte("fatal error") + readChan <- framing.NewErrorFrame(0, core.ErrorCodeRejected, fatalErr) + time.Sleep(100 * time.Millisecond) + err = ds.GetError() + assert.Error(t, err, "should get error") + assert.Equal(t, fatalErr, err.(core.CustomError).ErrorData()) +} + +func TestResumeClientSocket_Setup_Broken(t *testing.T) { + c := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + s := socket.NewResumableClientSocket(func(ctx context.Context) (*transport.Transport, error) { + return nil, fakeErr + }, c) + defer s.Close() + err := s.Setup(context.Background(), fakeResumableSetup) + assert.Error(t, err) +} + +func TestResumeClientSocket_Setup(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ds := socket.NewClientDuplexConnection(fragmentation.MaxFragment, 90*time.Second) + + readChanChan := make(chan chan core.Frame, 64) + + createTimes := atomic.NewInt32(0) + rcs := socket.NewResumableClientSocket(func(ctx context.Context) (*transport.Transport, error) { + if createTimes.Inc() >= 3 { + return nil, fakeErr + } + + conn, tp := InitTransportWithController(ctrl) + + // For test + readChan := make(chan core.Frame, 64) + + conn.EXPECT().Close().AnyTimes() + conn.EXPECT().SetCounter(gomock.Any()).Times(1) + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + readChanChan <- readChan + + return tp, nil + }, ds) + + onCloseCalled := atomic.NewBool(false) + rcs.OnClose(func(err error) { + onCloseCalled.CAS(false, true) + }) + + defer func() { + err := rcs.Close() + assert.NoError(t, err) + time.Sleep(500 * time.Millisecond) + assert.True(t, onCloseCalled.Load()) + }() + + err := rcs.Setup(context.Background(), fakeResumableSetup) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + readChan := <-readChanChan + close(readChan) + + time.Sleep(100 * time.Millisecond) + + readChan = <-readChanChan + readChan <- framing.NewResumeOKFrame(0) + time.Sleep(100 * time.Millisecond) + readChan <- framing.NewErrorFrame(0, core.ErrorCodeRejectedResume, []byte("fake reject error")) + close(readChan) + time.Sleep(100 * time.Millisecond) } diff --git a/internal/socket/resumable_server_socket.go b/internal/socket/resumable_server_socket.go index e9a580c..65c03c7 100644 --- a/internal/socket/resumable_server_socket.go +++ b/internal/socket/resumable_server_socket.go @@ -36,8 +36,8 @@ func (p *resumeServerSocket) Start(ctx context.Context) error { return p.socket.LoopWrite(ctx) } -// NewServerResume creates a new server-side socket with resume support. -func NewServerResume(socket *DuplexConnection, token []byte) ServerSocket { +// NewResumableServerSocket creates a new server-side socket with resume support. +func NewResumableServerSocket(socket *DuplexConnection, token []byte) ServerSocket { return &resumeServerSocket{ BaseSocket: NewBaseSocket(socket), token: token, diff --git a/internal/socket/resumable_server_socket_test.go b/internal/socket/resumable_server_socket_test.go new file mode 100644 index 0000000..b41fa48 --- /dev/null +++ b/internal/socket/resumable_server_socket_test.go @@ -0,0 +1,85 @@ +package socket_test + +import ( + "context" + "io" + "log" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/rsocket/rsocket-go/core" + "github.com/rsocket/rsocket-go/core/framing" + "github.com/rsocket/rsocket-go/internal/fragmentation" + "github.com/rsocket/rsocket-go/internal/socket" + "github.com/stretchr/testify/assert" +) + +var fakeToken = []byte("fakeToken") + +func TestResumableServerSocket_Start(t *testing.T) { + ctrl, conn, tp := InitTransport(t) + defer ctrl.Finish() + + // For test + readChan := make(chan core.Frame, 64) + setupFrame := framing.NewSetupFrame( + core.DefaultVersion, + 30*time.Second, + 90*time.Second, + nil, + fakeMimeType, + fakeMimeType, + fakeData, + fakeMetadata, + false, + ) + readChan <- setupFrame + + conn.EXPECT().Close().Times(1) + conn.EXPECT().SetCounter(gomock.Any()).AnyTimes() + conn.EXPECT().Write(gomock.Any()).Return(nil).AnyTimes() + conn.EXPECT().Flush().AnyTimes() + conn.EXPECT().Read().DoAndReturn(func() (core.Frame, error) { + next, ok := <-readChan + if !ok { + return nil, io.EOF + } + return next, nil + }).AnyTimes() + conn.EXPECT().SetDeadline(gomock.Any()).AnyTimes() + + firstFrame, err := tp.ReadFirst(context.Background()) + assert.NoError(t, err, "read first frame failed") + assert.Equal(t, setupFrame, firstFrame, "first should be setup frame") + + close(readChan) + + c := socket.NewServerDuplexConnection(fragmentation.MaxFragment, nil) + ss := socket.NewResumableServerSocket(c, fakeToken) + + ss.SetResponder(fakeResponder) + ss.SetTransport(tp) + + token, ok := ss.Token() + assert.True(t, ok) + assert.Equal(t, fakeToken, token, "token doesn't match") + + done := make(chan struct{}) + go func() { + defer close(done) + err := ss.Start(context.Background()) + assert.NoError(t, err, "start server socket failed") + }() + + err = tp.Start(context.Background()) + assert.NoError(t, err, "start transport failed") + + log.Println("closing") + _ = c.Close() + log.Println("closed") + + assert.Equal(t, true, ss.Pause(), "should return true") + + <-done +} diff --git a/internal/socket/simple_server_socket_test.go b/internal/socket/simple_server_socket_test.go index 74551e2..36d10e7 100644 --- a/internal/socket/simple_server_socket_test.go +++ b/internal/socket/simple_server_socket_test.go @@ -55,8 +55,8 @@ func TestSimpleServerSocket_Start(t *testing.T) { close(readChan) - ds := socket.NewServerDuplexConnection(fragmentation.MaxFragment, nil) - ss := socket.NewSimpleServerSocket(ds) + c := socket.NewServerDuplexConnection(fragmentation.MaxFragment, nil) + ss := socket.NewSimpleServerSocket(c) ss.SetResponder(fakeResponder) ss.SetTransport(tp) @@ -75,7 +75,7 @@ func TestSimpleServerSocket_Start(t *testing.T) { err = tp.Start(context.Background()) assert.NoError(t, err, "start transport failed") - _ = ds.Close() + _ = c.Close() <-done } diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go index aee5693..8401864 100644 --- a/internal/socket/socket_test.go +++ b/internal/socket/socket_test.go @@ -30,10 +30,15 @@ var ( } ) -func InitTransport(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { - ctrl := gomock.NewController(t) +func InitTransportWithController(ctrl *gomock.Controller) (*MockConn, *transport.Transport) { conn := NewMockConn(ctrl) tp := transport.NewTransport(conn) + return conn, tp +} + +func InitTransport(t *testing.T) (*gomock.Controller, *MockConn, *transport.Transport) { + ctrl := gomock.NewController(t) + conn, tp := InitTransportWithController(ctrl) return ctrl, conn, tp } diff --git a/lease/lease_test.go b/lease/lease_test.go index fd2ccf7..1470a94 100644 --- a/lease/lease_test.go +++ b/lease/lease_test.go @@ -11,7 +11,7 @@ import ( ) func TestSimpleLease_Next(t *testing.T) { - l, err := lease.NewSimpleLease(3*time.Second, 1*time.Second, 1*time.Second, 1) + l, err := lease.NewSimpleLease(300*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond, 1) assert.NoError(t, err, "create simple lease failed") lease, ok := l.Next(context.Background()) assert.True(t, ok, "get next lease chan failed") diff --git a/server.go b/server.go index fc0e65b..a711fe3 100644 --- a/server.go +++ b/server.go @@ -245,7 +245,7 @@ func (p *server) doSetup(frame *framing.SetupFrame, tp *transport.Transport, soc // 4. resume success copy(token, frame.Token()) - sendingSocket = socket.NewServerResume(rawSocket, token) + sendingSocket = socket.NewResumableServerSocket(rawSocket, token) if responder, e := p.acc(frame, sendingSocket); e != nil { switch vv := e.(type) { case *framing.ErrorFrame: diff --git a/transporter_test.go b/transporter_test.go index 8e39299..d810549 100644 --- a/transporter_test.go +++ b/transporter_test.go @@ -2,6 +2,7 @@ package rsocket_test import ( "context" + "crypto/tls" "fmt" "net/http" "os" @@ -14,9 +15,11 @@ import ( ) var fakeSockFile string +var fakeTlsConfig = &tls.Config{ + InsecureSkipVerify: true, +} func init() { - fmt.Println(os.TempDir()) fakeSockFile = fmt.Sprintf("%s/test-rsocket-%s.sock", strings.TrimRight(os.TempDir(), "/"), uuid.New().String()) } @@ -39,6 +42,7 @@ func TestTcpClient(t *testing.T) { rsocket.TcpClient(). SetAddr(":7878"). SetHostAndPort("127.0.0.1", 7878). + SetTlsConfig(fakeTlsConfig). Build() }) } @@ -46,6 +50,7 @@ func TestTcpClient(t *testing.T) { func TestTcpServerBuilder(t *testing.T) { assert.NotPanics(t, func() { rsocket.TcpServer().SetAddr(":7878").Build() + rsocket.TcpServer().SetHostAndPort("127.0.0.1", 7878).SetTlsConfig(fakeTlsConfig).Build() }) } @@ -56,6 +61,7 @@ func TestWebsocketClient(t *testing.T) { rsocket.WebsocketClient(). SetUrl("ws://127.0.0.1:8080/fake/path"). SetHeader(h). + SetTlsConfig(fakeTlsConfig). Build() }) } @@ -65,6 +71,7 @@ func TestWebsocketServer(t *testing.T) { tp := rsocket.WebsocketServer(). SetAddr(":7878"). SetPath("/fake"). + SetTlsConfig(fakeTlsConfig). Build() assert.NotNil(t, tp) })