Skip to content

Commit

Permalink
use ctx on routingFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
henrod committed Sep 11, 2018
1 parent 5bcf279 commit cd5d4e9
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
6 changes: 3 additions & 3 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,19 @@ func TestAddRoute(t *testing.T) {
initApp()
Configure(true, "testtype", Cluster, map[string]string{}, viper.New())
app.router = nil
err := AddRoute("somesv", func(session *session.Session, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
err := AddRoute("somesv", func(ctx context.Context, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
return nil, nil
})
assert.EqualError(t, constants.ErrRouterNotInitialized, err.Error())

app.router = router.New()
err = AddRoute("somesv", func(session *session.Session, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
err = AddRoute("somesv", func(ctx context.Context, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
return nil, nil
})
assert.NoError(t, err)

app.running = true
err = AddRoute("somesv", func(session *session.Session, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
err = AddRoute("somesv", func(ctx context.Context, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) {
return nil, nil
})
assert.EqualError(t, constants.ErrChangeRouteWhileRunning, err.Error())
Expand Down
1 change: 1 addition & 0 deletions mocks/net.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package router

import (
"context"
"math/rand"
"time"

Expand All @@ -30,7 +31,6 @@ import (
"github.com/topfreegames/pitaya/logger"
"github.com/topfreegames/pitaya/protos"
"github.com/topfreegames/pitaya/route"
"github.com/topfreegames/pitaya/session"
)

// Router struct
Expand All @@ -41,7 +41,7 @@ type Router struct {

// RoutingFunc defines a routing function
type RoutingFunc func(
session *session.Session,
ctx context.Context,
route *route.Route,
payload []byte,
servers map[string]*cluster.Server,
Expand Down Expand Up @@ -74,9 +74,9 @@ func (r *Router) defaultRoute(

// Route gets the right server to use in the call
func (r *Router) Route(
ctx context.Context,
rpcType protos.RPCType,
svType string,
session *session.Session,
route *route.Route,
msg *message.Message,
) (*cluster.Server, error) {
Expand All @@ -97,7 +97,7 @@ func (r *Router) Route(
server := r.defaultRoute(serversOfType)
return server, nil
}
return routeFunc(session, route, msg.Data, serversOfType)
return routeFunc(ctx, route, msg.Data, serversOfType)
}

// AddRoute adds a routing function to a server type
Expand Down
8 changes: 4 additions & 4 deletions router/router_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package router

import (
"context"
"errors"
"testing"

Expand All @@ -11,7 +12,6 @@ import (
"github.com/topfreegames/pitaya/internal/message"
"github.com/topfreegames/pitaya/protos"
"github.com/topfreegames/pitaya/route"
"github.com/topfreegames/pitaya/session"
)

var (
Expand All @@ -24,7 +24,7 @@ var (
}

routingFunction = func(
session *session.Session,
ctx context.Context,
route *route.Route,
payload []byte,
servers map[string]*cluster.Server,
Expand Down Expand Up @@ -70,7 +70,7 @@ func TestDefaultRoute(t *testing.T) {
func TestRoute(t *testing.T) {
t.Parallel()

session := &session.Session{}
ctx := context.Background()
route := route.NewRoute(serverType, "service", "method")

for name, table := range routerTables {
Expand All @@ -86,7 +86,7 @@ func TestRoute(t *testing.T) {
router.AddRoute(serverType, routingFunction)
router.SetServiceDiscovery(mockServiceDiscovery)

retServer, err := router.Route(table.rpcType, table.serverType, session, route, &message.Message{
retServer, err := router.Route(ctx, table.rpcType, table.serverType, route, &message.Message{
Data: []byte{0x01},
})
assert.Equal(t, table.server, retServer)
Expand Down
2 changes: 1 addition & 1 deletion service/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func (r *RemoteService) remoteCall(
target := server

if target == nil {
target, err = r.router.Route(rpcType, svType, session, route, msg)
target, err = r.router.Route(ctx, rpcType, svType, route, msg)
if err != nil {
return nil, e.NewError(err, e.ErrInternalCode)
}
Expand Down

0 comments on commit cd5d4e9

Please sign in to comment.