diff --git a/balancer/group_test.go b/balancer/group_test.go index 6b45f98..78426c2 100644 --- a/balancer/group_test.go +++ b/balancer/group_test.go @@ -15,6 +15,7 @@ func TestGroup_Get(t *testing.T) { called++ return balancer.NewRoundRobinBalancer() }) + defer g.Close() for range [2]struct{}{} { b := g.Get(fakeGroupId) assert.NotNil(t, b) diff --git a/core/transport/tcp_transport.go b/core/transport/tcp_transport.go index 7cbe1b1..febffad 100644 --- a/core/transport/tcp_transport.go +++ b/core/transport/tcp_transport.go @@ -151,16 +151,20 @@ func NewTCPClientTransport(c net.Conn) *Transport { } // NewTCPClientTransportWithAddr creates a new transport. -func NewTCPClientTransportWithAddr(network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { - var rawConn net.Conn +func NewTCPClientTransportWithAddr(ctx context.Context, network, addr string, tlsConfig *tls.Config) (tp *Transport, err error) { + var conn net.Conn if tlsConfig == nil { - rawConn, err = net.Dial(network, addr) + var dial net.Dialer + conn, err = dial.DialContext(ctx, network, addr) } else { - rawConn, err = tls.Dial(network, addr, tlsConfig) + dial := tls.Dialer{ + Config: tlsConfig, + } + conn, err = dial.DialContext(ctx, network, addr) } if err != nil { return } - tp = NewTCPClientTransport(rawConn) + tp = NewTCPClientTransport(conn) return } diff --git a/core/transport/websocket_transport.go b/core/transport/websocket_transport.go index 18ac1c3..86eb917 100644 --- a/core/transport/websocket_transport.go +++ b/core/transport/websocket_transport.go @@ -177,20 +177,20 @@ func NewWebsocketServerTransportWithAddr(addr string, path string, upgrader *web } // NewWebsocketClientTransport creates a new client-side transport. -func NewWebsocketClientTransport(url string, config *tls.Config, header http.Header) (*Transport, error) { - var d *websocket.Dialer +func NewWebsocketClientTransport(ctx context.Context, url string, config *tls.Config, header http.Header) (*Transport, error) { + var dial *websocket.Dialer if config == nil { - d = websocket.DefaultDialer + dial = websocket.DefaultDialer } else { - d = &websocket.Dialer{ + dial = &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: 45 * time.Second, TLSClientConfig: config, } } - wsConn, _, err := d.Dial(url, header) + conn, _, err := dial.DialContext(ctx, url, header) if err != nil { return nil, errors.Wrap(err, "dial websocket failed") } - return NewTransport(NewWebsocketConnection(wsConn)), nil + return NewTransport(NewWebsocketConnection(conn)), nil } diff --git a/examples/echo/echo b/examples/echo/echo deleted file mode 100755 index ec43c13..0000000 Binary files a/examples/echo/echo and /dev/null differ diff --git a/examples/echo_bench/echo_bench.go b/examples/echo_bench/echo_bench.go index efdf182..7e200bb 100644 --- a/examples/echo_bench/echo_bench.go +++ b/examples/echo_bench/echo_bench.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/jjeffcaii/reactor-go/scheduler" "github.com/rsocket/rsocket-go" "github.com/rsocket/rsocket-go/core/transport" "github.com/rsocket/rsocket-go/payload" @@ -47,7 +46,6 @@ func main() { rand.Read(data) now := time.Now() - ctx := context.Background() sub := rx.NewSubscriber( rx.OnNext(func(input payload.Payload) error { @@ -57,9 +55,8 @@ func main() { return nil }), ) - for i := 0; i < n; i++ { - client.RequestResponse(payload.New(data, nil)).SubscribeOn(scheduler.Parallel()).SubscribeWith(ctx, sub) + client.RequestResponse(payload.New(data, nil)).SubscribeWith(context.Background(), sub) } wg.Wait() cost := time.Since(now) diff --git a/go.mod b/go.mod index 7e75be5..8c0173f 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/golang/mock v1.4.3 github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 - github.com/jjeffcaii/reactor-go v0.2.3 + github.com/jjeffcaii/reactor-go v0.2.4 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.4.0 github.com/urfave/cli/v2 v2.1.1 diff --git a/go.sum b/go.sum index a4ffe23..83f9b8e 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jjeffcaii/reactor-go v0.2.3 h1:qkcnnNyJ241hq15EMjGXP/mtqxRQlITy7eed2qzZdYQ= -github.com/jjeffcaii/reactor-go v0.2.3/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= +github.com/jjeffcaii/reactor-go v0.2.4 h1:Q3N/0Ngt1Ywi7ezye2LQ+mU1vNdHxyG5ZRk3W2EWmYA= +github.com/jjeffcaii/reactor-go v0.2.4/go.mod h1:I4qZrpZcsqjzo3pjq0XWGBTpdFXB95XeYinrPYETNL4= github.com/panjf2000/ants/v2 v2.4.1 h1:7RtUqj5lGOw0WnZhSKDZ2zzJhaX5490ZW1sUolRXCxY= github.com/panjf2000/ants/v2 v2.4.1/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -31,7 +31,6 @@ github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2 go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -41,10 +40,8 @@ golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fq golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c h1:IGkKhmfzcztjm6gYkykvu/NiS8kaqbCWAEWWAyf8J5U= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= diff --git a/internal/socket/duplex.go b/internal/socket/duplex.go index edbc932..5c9fd99 100644 --- a/internal/socket/duplex.go +++ b/internal/socket/duplex.go @@ -355,7 +355,7 @@ func (dc *DuplexConnection) RequestChannel(publisher rx.Publisher) (ret flux.Flu }) return }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { dc.register(sid, requestChannelCallback{rcv: receiving, snd: s}) s.Request(1) }), @@ -419,7 +419,7 @@ func (dc *DuplexConnection) respondRequestResponse(receiving fragmentation.Heade rx.OnError(func(e error) { dc.writeError(sid, e) }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { dc.register(sid, requestResponseCallbackReverse{su: s}) s.Request(rx.RequestMax) }), @@ -505,7 +505,7 @@ func (dc *DuplexConnection) respondRequestChannel(pl fragmentation.HeaderAndPayl <-complete.DoneNotify() } }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { dc.register(sid, requestChannelCallbackReverse{rcv: receivingProcessor, snd: s}) close(mustSub) s.Request(initRequestN) @@ -602,7 +602,7 @@ func (dc *DuplexConnection) respondRequestStream(receiving fragmentation.HeaderA dc.sendPayload(sid, elem, core.FlagNext) return nil }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { dc.register(sid, requestStreamCallbackReverse{su: s}) s.Request(n32) }), diff --git a/internal/socket/resumable_client_socket_test.go b/internal/socket/resumable_client_socket_test.go index 9cf02d1..c1b20c9 100644 --- a/internal/socket/resumable_client_socket_test.go +++ b/internal/socket/resumable_client_socket_test.go @@ -76,7 +76,7 @@ func TestNewResumableClientSocket(t *testing.T) { } result, err := rcs.RequestResponse(payload.New(fakeData, fakeMetadata)). - DoOnSubscribe(func(s rx.Subscription) { + DoOnSubscribe(func(ctx context.Context, s rx.Subscription) { readChan <- framing.NewPayloadFrame(nextRequestId(), fakeData, fakeMetadata, core.FlagComplete) }). Block(context.Background()) @@ -90,7 +90,7 @@ func TestNewResumableClientSocket(t *testing.T) { stream = append(stream, input) return nil }). - DoOnSubscribe(func(s rx.Subscription) { + DoOnSubscribe(func(ctx context.Context, s rx.Subscription) { nextId := nextRequestId() readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) diff --git a/internal/socket/simple_client_socket_test.go b/internal/socket/simple_client_socket_test.go index 43fb872..2f33884 100644 --- a/internal/socket/simple_client_socket_test.go +++ b/internal/socket/simple_client_socket_test.go @@ -68,7 +68,7 @@ func TestNewClient(t *testing.T) { } result, err := cli.RequestResponse(payload.New(fakeData, fakeMetadata)). - DoOnSubscribe(func(s rx.Subscription) { + DoOnSubscribe(func(ctx context.Context, s rx.Subscription) { readChan <- framing.NewPayloadFrame(nextRequestId(), fakeData, fakeMetadata, core.FlagComplete) }). Block(context.Background()) @@ -82,7 +82,7 @@ func TestNewClient(t *testing.T) { stream = append(stream, input) return nil }). - DoOnSubscribe(func(s rx.Subscription) { + DoOnSubscribe(func(ctx context.Context, s rx.Subscription) { nextId := nextRequestId() readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) readChan <- framing.NewPayloadFrame(nextId, fakeData, fakeMetadata, core.FlagNext) diff --git a/rsocket_example_test.go b/rsocket_example_test.go index 93052bf..43c8184 100644 --- a/rsocket_example_test.go +++ b/rsocket_example_test.go @@ -133,7 +133,7 @@ func ExampleConnect() { s.Request(1) return nil }). - Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { + Subscribe(context.Background(), rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { s.Request(1) })) // Simple RequestChannel. diff --git a/rsocket_test.go b/rsocket_test.go index 8108886..6bb0f33 100644 --- a/rsocket_test.go +++ b/rsocket_test.go @@ -52,6 +52,13 @@ func TestResume(t *testing.T) { }() go func(ctx context.Context) { + defer func() { + select { + case <-started: + default: + close(started) + } + }() _ = Receive(). OnStart(func() { close(started) @@ -158,6 +165,13 @@ func TestConnectBroken(t *testing.T) { port := 8787 go func(ctx context.Context) { + defer func() { + select { + case <-started: + default: + close(started) + } + }() _ = Receive(). OnStart(func() { close(started) @@ -206,6 +220,15 @@ func TestBiDirection(t *testing.T) { defer cancel() go func(ctx context.Context) { + + defer func() { + select { + case <-started: + default: + close(started) + } + }() + l, _ := lease.NewSimpleFactory(3*time.Second, 1*time.Second, 1*time.Second, 10) _ = Receive(). Lease(l). @@ -313,6 +336,15 @@ func testAll(t *testing.T, proto string, clientTp transport.ClientTransporter, s serving := make(chan struct{}) go func(ctx context.Context) { + + defer func() { + select { + case <-serving: + default: + close(serving) + } + }() + err := Receive(). Fragment(128). OnStart(func() { @@ -453,7 +485,7 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) { su.Request(1) return nil }). - Subscribe(ctx, rx.OnSubscribe(func(s rx.Subscription) { + Subscribe(ctx, rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { su = s su.Request(1) })) @@ -515,7 +547,7 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) { Subscribe(ctx, rx.OnNext(func(elem payload.Payload) error { su.Request(1) return nil - }), rx.OnSubscribe(func(s rx.Subscription) { + }), rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { su = s su.Request(1) })) @@ -567,3 +599,74 @@ func startProxy(addr string, ch chan net.Listener, upstreamAddr string) { } } + +type delayedRSocket struct { +} + +func (d delayedRSocket) FireAndForget(message payload.Payload) { + panic("implement me") +} + +func (d delayedRSocket) MetadataPush(message payload.Payload) { + panic("implement me") +} + +func (d delayedRSocket) RequestResponse(message payload.Payload) mono.Mono { + return mono.Create(func(ctx context.Context, sink mono.Sink) { + time.AfterFunc(300*time.Millisecond, func() { + sink.Success(message) + }) + }) +} + +func (d delayedRSocket) RequestStream(message payload.Payload) flux.Flux { + panic("implement me") +} + +func (d delayedRSocket) RequestChannel(messages rx.Publisher) flux.Flux { + panic("implement me") +} + +func TestContextTimeout(t *testing.T) { + var responder delayedRSocket + started := make(chan struct{}) + go func() { + defer func() { + select { + case <-started: + default: + close(started) + } + }() + + _ = Receive(). + OnStart(func() { + close(started) + }). + Acceptor(func(setup payload.SetupPayload, sendingSocket CloseableRSocket) (RSocket, error) { + return responder, nil + }). + Transport(TCPServer().SetAddr(":8088").Build()). + Serve(context.Background()) + }() + + <-started + + tp := TCPClient().SetAddr("127.0.0.1:8088").Build() + + // simulate timeout + ctxMustTimeout, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + _, err := Connect().Transport(tp).Start(ctxMustTimeout) + assert.Error(t, err, "should connect timeout") + + cli, err := Connect().Transport(tp).Start(context.Background()) + assert.NoError(t, err, "should connect success") + defer cli.Close() + + ctx, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel2() + + _, err = cli.RequestResponse(fakeRequest).Block(ctx) + assert.Error(t, err, "should return error") +} diff --git a/rx/flux/flux.go b/rx/flux/flux.go index 9f5a10f..9236bc0 100644 --- a/rx/flux/flux.go +++ b/rx/flux/flux.go @@ -44,7 +44,7 @@ type Flux interface { // DoOnSubscribe add behavior triggered when the Flux is done being subscribed. DoOnSubscribe(rx.FnOnSubscribe) Flux // Map transform the items emitted by this Flux by applying a synchronous function to each item. - Map(func(payload.Payload) (payload.Payload, error)) Flux + Map(rx.FnTransform) Flux // SwitchOnFirst transform the current Flux once it emits its first element, making a conditional transformation possible. SwitchOnFirst(FnSwitchOnFirst) Flux // SubscribeOn run subscribe, onSubscribe and request on a specified scheduler. diff --git a/rx/flux/flux_test.go b/rx/flux/flux_test.go index c3eb818..3e817c7 100644 --- a/rx/flux/flux_test.go +++ b/rx/flux/flux_test.go @@ -149,7 +149,7 @@ func TestCreate(t *testing.T) { DoOnComplete(func() { fmt.Println("complete") }). - Subscribe(context.Background(), rx.OnSubscribe(func(s rx.Subscription) { + Subscribe(context.Background(), rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { su = s su.Request(1) })) @@ -234,7 +234,7 @@ func TestFluxRequest(t *testing.T) { rx.OnComplete(func() { fmt.Println("complete") }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { su = s su.Request(1) fmt.Println("request:", 1) @@ -271,7 +271,7 @@ func TestFluxProcessorWithRequest(t *testing.T) { su.Request(1) return nil }), - rx.OnSubscribe(func(s rx.Subscription) { + rx.OnSubscribe(func(ctx context.Context, s rx.Subscription) { su = s su.Request(1) }), diff --git a/rx/flux/proxy.go b/rx/flux/proxy.go index b05ed09..5dff4f5 100644 --- a/rx/flux/proxy.go +++ b/rx/flux/proxy.go @@ -132,8 +132,8 @@ func (p proxy) BlockSlice(ctx context.Context) (results []payload.Payload, err e } func (p proxy) DoOnSubscribe(fn rx.FnOnSubscribe) Flux { - return newProxy(p.Flux.DoOnSubscribe(func(su reactor.Subscription) { - fn(su) + return newProxy(p.Flux.DoOnSubscribe(func(ctx context.Context, su reactor.Subscription) { + fn(ctx, su) })) } @@ -176,8 +176,8 @@ func (p proxy) SubscribeWith(ctx context.Context, s rx.Subscriber) { reactor.OnComplete(func() { s.OnComplete() }), - reactor.OnSubscribe(func(su reactor.Subscription) { - s.OnSubscribe(su) + reactor.OnSubscribe(func(ctx context.Context, su reactor.Subscription) { + s.OnSubscribe(ctx, su) }), ) } diff --git a/rx/mono/mono.go b/rx/mono/mono.go index a6b9abf..cf9feab 100644 --- a/rx/mono/mono.go +++ b/rx/mono/mono.go @@ -15,6 +15,10 @@ type Mono interface { // Filter evaluate each source value against the given Predicate. // If the predicate test succeeds, the value is emitted. Filter(rx.FnPredicate) Mono + // Map transform the item emitted by this Mono by applying a synchronous function to another. + Map(rx.FnTransform) Mono + // FlatMap Transform the item emitted by this Mono asynchronously, returning the value emitted by another Mono. + FlatMap(func(payload.Payload) Mono) Mono // DoFinally adds behavior (side-effect) triggered after the Mono terminates for any reason, including cancellation. DoFinally(rx.FnFinally) Mono // DoOnError adds behavior (side-effect) triggered when the Mono completes with an error. diff --git a/rx/mono/mono_test.go b/rx/mono/mono_test.go index bd910a0..1c90bcd 100644 --- a/rx/mono/mono_test.go +++ b/rx/mono/mono_test.go @@ -113,6 +113,47 @@ func TestProxy_Filter(t *testing.T) { Subscribe(context.Background()) } +func TestMap(t *testing.T) { + fakeMono := Just(payload.NewString("hello", "world")) + value, err := fakeMono. + Map(func(p payload.Payload) (payload.Payload, error) { + data := strings.ToUpper(p.DataUTF8()) + metadata, _ := p.MetadataUTF8() + return payload.NewString(data, metadata), nil + }). + Map(func(p payload.Payload) (payload.Payload, error) { + metadata, _ := p.MetadataUTF8() + return payload.NewString(p.DataUTF8(), strings.ToUpper(metadata)), nil + }). + Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "HELLO", value.DataUTF8()) + metadata, _ := value.MetadataUTF8() + assert.Equal(t, "WORLD", metadata) + + _, err = fakeMono.Map(func(p payload.Payload) (payload.Payload, error) { + return nil, errors.New("fake error") + }).Block(context.Background()) + assert.Error(t, err) +} + +func TestFlatMap(t *testing.T) { + res, err := Just(payload.NewString("foo", "")). + FlatMap(func(p payload.Payload) Mono { + return Create(func(ctx context.Context, sink Sink) { + select { + case <-ctx.Done(): + sink.Error(errors.New("cancelled")) + case <-time.After(100 * time.Millisecond): + sink.Success(payload.NewString("bar", "")) + } + }).SubscribeOn(scheduler.Parallel()) + }). + Block(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "bar", res.DataUTF8()) +} + func TestCreate(t *testing.T) { Create(func(i context.Context, sink Sink) { sink.Success(payload.NewString("hello", "world")) diff --git a/rx/mono/proxy.go b/rx/mono/proxy.go index 3cdafd2..048f965 100644 --- a/rx/mono/proxy.go +++ b/rx/mono/proxy.go @@ -85,11 +85,23 @@ func (p proxy) Block(ctx context.Context) (pa payload.Payload, err error) { } func (p proxy) Filter(fn rx.FnPredicate) Mono { - return newProxy(p.Mono.Filter(func(i interface{}) bool { + return newProxy(p.Mono.Filter(func(i reactor.Any) bool { return fn(i.(payload.Payload)) })) } +func (p proxy) Map(transform rx.FnTransform) Mono { + return newProxy(p.Mono.Map(func(any reactor.Any) (reactor.Any, error) { + return transform(any.(payload.Payload)) + })) +} + +func (p proxy) FlatMap(transform func(payload.Payload) Mono) Mono { + return newProxy(p.Mono.FlatMap(func(any reactor.Any) mono.Mono { + return transform(any.(payload.Payload)).Raw() + })) +} + func (p proxy) DoFinally(fn rx.FnFinally) Mono { return newProxy(p.Mono.DoFinally(func(signal reactor.SignalType) { fn(rx.SignalType(signal)) @@ -108,8 +120,8 @@ func (p proxy) DoOnSuccess(next rx.FnOnNext) Mono { } func (p proxy) DoOnSubscribe(fn rx.FnOnSubscribe) Mono { - return newProxy(p.Mono.DoOnSubscribe(func(su reactor.Subscription) { - fn(su) + return newProxy(p.Mono.DoOnSubscribe(func(ctx context.Context, su reactor.Subscription) { + fn(ctx, su) })) } @@ -137,8 +149,8 @@ func (p proxy) SubscribeWith(ctx context.Context, actual rx.Subscriber) { reactor.OnComplete(func() { actual.OnComplete() }), - reactor.OnSubscribe(func(su reactor.Subscription) { - actual.OnSubscribe(su) + reactor.OnSubscribe(func(ctx context.Context, su reactor.Subscription) { + actual.OnSubscribe(ctx, su) }), reactor.OnError(func(e error) { actual.OnError(e) diff --git a/rx/rx.go b/rx/rx.go index f06f490..0485eae 100644 --- a/rx/rx.go +++ b/rx/rx.go @@ -25,7 +25,7 @@ type ( // FnOnNext is alias of function for signal when next element arrived. FnOnNext = func(input payload.Payload) error // FnOnSubscribe is alias of function for signal when subscribe begin. - FnOnSubscribe = func(s Subscription) + FnOnSubscribe = func(ctx context.Context, s Subscription) // FnOnError is alias of function for signal when an error occurred. FnOnError = func(e error) // FnOnCancel is alias of function for signal when subscription canceled. @@ -36,6 +36,8 @@ type ( FnPredicate = func(input payload.Payload) bool // FnOnRequest is alias of function for signal when requesting next element. FnOnRequest = func(n int) + // FnTransform is alias of function to transform a payload to another. + FnTransform = func(payload.Payload) (payload.Payload, error) ) // RawPublisher represents a basic Publisher which can be subscribed by a Subscriber. diff --git a/rx/subscriber.go b/rx/subscriber.go index 4a2ec85..d75ab05 100644 --- a/rx/subscriber.go +++ b/rx/subscriber.go @@ -1,6 +1,8 @@ package rx import ( + "context" + "github.com/jjeffcaii/reactor-go" "github.com/rsocket/rsocket-go/payload" ) @@ -27,7 +29,7 @@ type Subscriber interface { OnComplete() // OnSubscribe invoked after Publisher subscribed. // No data will start flowing until Subscription#Request is invoked. - OnSubscribe(Subscription) + OnSubscribe(context.Context, Subscription) } type subscriber struct { @@ -56,9 +58,9 @@ func (s *subscriber) OnComplete() { } } -func (s *subscriber) OnSubscribe(su Subscription) { +func (s *subscriber) OnSubscribe(ctx context.Context, su Subscription) { if s != nil && s.fnOnSubscribe != nil { - s.fnOnSubscribe(su) + s.fnOnSubscribe(ctx, su) } else { su.Request(RequestMax) } diff --git a/transporter_builder.go b/transporter_builder.go index 3f0860b..7104fc8 100644 --- a/transporter_builder.go +++ b/transporter_builder.go @@ -80,7 +80,7 @@ func (uc *UnixClientBuilder) SetPath(path string) *UnixClientBuilder { // Build builds and returns a new ClientTransporter. func (uc UnixClientBuilder) Build() transport.ClientTransporter { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTCPClientTransportWithAddr("unix", uc.path, nil) + return transport.NewTCPClientTransportWithAddr(ctx, "unix", uc.path, nil) } } @@ -176,7 +176,7 @@ func (wc *WebsocketClientBuilder) SetHeader(header http.Header) *WebsocketClient // Build builds and returns a new websocket ClientTransporter func (wc *WebsocketClientBuilder) Build() transport.ClientTransporter { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewWebsocketClientTransport(wc.url, wc.tlsCfg, wc.header) + return transport.NewWebsocketClientTransport(ctx, wc.url, wc.tlsCfg, wc.header) } } @@ -255,7 +255,7 @@ func (tc *TCPClientBuilder) SetTLSConfig(c *tls.Config) *TCPClientBuilder { // Build builds and returns a new TCP ClientTransporter. func (tc *TCPClientBuilder) Build() transport.ClientTransporter { return func(ctx context.Context) (*transport.Transport, error) { - return transport.NewTCPClientTransportWithAddr("tcp", tc.addr, tc.tlsCfg) + return transport.NewTCPClientTransportWithAddr(ctx, "tcp", tc.addr, tc.tlsCfg) } }