Skip to content

Commit

Permalink
unique session module
Browse files Browse the repository at this point in the history
  • Loading branch information
felipejfc committed May 4, 2018
1 parent ef204b2 commit 3442b55
Show file tree
Hide file tree
Showing 17 changed files with 429 additions and 74 deletions.
6 changes: 5 additions & 1 deletion agent/agent_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,13 @@ func (a *Remote) Push(route string, v interface{}) error {
a.Session.ID(), a.Session.UID(), route, v)
}

sv, err := a.serviceDiscovery.GetServer(a.frontendID)
if err != nil {
return err
}
return a.sendPush(
pendingMessage{typ: message.Push, route: route, payload: v},
cluster.GetUserMessagesTopic(a.Session.UID()),
cluster.GetUserMessagesTopic(a.Session.UID(), sv.Type),
)
}

Expand Down
9 changes: 7 additions & 2 deletions agent/agent_remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,17 @@ func TestAgentRemotePush(t *testing.T) {
if table.rpcClient == nil {
table.rpcClient = clustermocks.NewMockRPCClient(ctrl)
}
fSvID := "123id"
ss := &protos.Session{Uid: table.uid}
mockSerializer := serializemocks.NewMockSerializer(ctrl)
remote, err := NewRemote(ss, "", table.rpcClient, nil, mockSerializer, nil, "", nil)
mockSD := clustermocks.NewMockServiceDiscovery(ctrl)
remote, err := NewRemote(ss, "", table.rpcClient, nil, mockSerializer, mockSD, fSvID, nil)
assert.NoError(t, err)
assert.NotNil(t, remote)

if table.uid != "" {
expectedData := []byte("done")
topic := cluster.GetUserMessagesTopic(table.uid)
topic := cluster.GetUserMessagesTopic(table.uid, "connector")

if reflect.TypeOf(table.data) == reflect.TypeOf(([]byte)(nil)) {
expectedData = table.data.([]byte)
Expand All @@ -152,6 +154,9 @@ func TestAgentRemotePush(t *testing.T) {
}
}

if table.err != constants.ErrNoUIDBind {
mockSD.EXPECT().GetServer(fSvID).Return(cluster.NewServer(fSvID, "connector", true), nil)
}
err = remote.Push(route, table.data)
assert.Equal(t, table.err, err)
})
Expand Down
15 changes: 12 additions & 3 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/topfreegames/pitaya/internal/codec"
"github.com/topfreegames/pitaya/internal/message"
"github.com/topfreegames/pitaya/logger"
mods "github.com/topfreegames/pitaya/modules"
"github.com/topfreegames/pitaya/remote"
"github.com/topfreegames/pitaya/route"
"github.com/topfreegames/pitaya/router"
Expand Down Expand Up @@ -132,9 +133,9 @@ func Configure(
app.server.Frontend = isFrontend
app.server.Type = serverType
app.serverMode = serverMode
app.configured = true
app.server.Metadata = serverMetadata
app.messageEncoder = message.NewEncoder(app.config.GetBool("pitaya.dataCompression"))
app.configured = true
}

// AddAcceptor adds a new acceptor to app
Expand Down Expand Up @@ -182,9 +183,9 @@ func SetRPCServer(s cluster.RPCServer) {
if reflect.TypeOf(s) == reflect.TypeOf(&cluster.NatsRPCServer{}) {
// When using nats rpc server the server must start listening to messages
// destined to the userID that's binding
session.SetOnSessionBind(func(s *session.Session) error {
session.OnSessionBind(func(ctx context.Context, s *session.Session) error {
if app.server.Frontend && app.rpcServer != nil {
subs, err := app.rpcServer.(*cluster.NatsRPCServer).SubscribeToUserMessages(s.UID())
subs, err := app.rpcServer.(*cluster.NatsRPCServer).SubscribeToUserMessages(s.UID(), app.server.Type)
if err != nil {
return err
}
Expand Down Expand Up @@ -362,9 +363,17 @@ func listen() {

logger.Log.Infof("listening with acceptor %s on addr %s", reflect.TypeOf(a), a.GetAddr())
}
if app.serverMode == Cluster && app.server.Frontend && reflect.TypeOf(app.rpcServer) == reflect.TypeOf(&cluster.NatsRPCServer{}) {
if app.config.GetBool("pitaya.session.unique") {
unique := mods.NewUniqueSession(app.server, app.rpcServer.(*cluster.NatsRPCServer), app.rpcClient.(*cluster.NatsRPCClient))
RegisterModule(unique, "uniqueSession")
}
}

startModules()

logger.Log.Info("all modules started!")

// this handles remote messages
if app.rpcServer != nil {
for i := 0; i < app.config.GetInt("pitaya.concurrency.remote.service"); i++ {
Expand Down
31 changes: 25 additions & 6 deletions cluster/nats_rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ type NatsRPCServer struct {
config *config.Config
stopChan chan bool
subChan chan *nats.Msg // subChan is the channel used by the server to receive network messages addressed to itself
bindingsChan chan *nats.Msg // bindingsChan receives notify from other servers on every user bind to session
unhandledReqCh chan *protos.Request
userPushCh chan *protos.Push
sub *nats.Subscription
sub *nats.Subscription // TODO monitor its size?
dropped int
}

Expand Down Expand Up @@ -78,20 +79,37 @@ func (ns *NatsRPCServer) configure() error {
return constants.ErrNatsPushBufferSizeZero
}
ns.subChan = make(chan *nats.Msg, ns.messagesBufferSize)
ns.bindingsChan = make(chan *nats.Msg, ns.messagesBufferSize)
// the reason this channel is buffered is that we can achieve more performance by not
// blocking producers on a massive push
ns.userPushCh = make(chan *protos.Push, ns.pushBufferSize)
return nil
}

// GetBindingsChannel gets the channel that will receive all bindings
func (ns *NatsRPCServer) GetBindingsChannel() chan *nats.Msg {
return ns.bindingsChan
}

// GetUserMessagesTopic get the topic for user
func GetUserMessagesTopic(uid string) string {
return fmt.Sprintf("pitaya/user/%s/push", uid)
func GetUserMessagesTopic(uid string, svType string) string {
return fmt.Sprintf("pitaya/%s/user/%s/push", svType, uid)
}

// GetBindBroadcastTopic gets the topic on which bind events will be broadcasted
func GetBindBroadcastTopic(svType string) string {
return fmt.Sprintf("pitaya/%s/bindings", svType)
}

// SubscribeToBindingsChannel subscribes to the channel that will receive binding notifications from other servers
func (ns *NatsRPCServer) SubscribeToBindingsChannel() error {
_, err := ns.conn.ChanSubscribe(GetBindBroadcastTopic(ns.server.Type), ns.bindingsChan)
return err
}

// SubscribeToUserMessages subscribes to user msg channel
func (ns *NatsRPCServer) SubscribeToUserMessages(uid string) (*nats.Subscription, error) {
subs, err := ns.conn.Subscribe(GetUserMessagesTopic(uid), func(msg *nats.Msg) {
func (ns *NatsRPCServer) SubscribeToUserMessages(uid string, svType string) (*nats.Subscription, error) {
subs, err := ns.conn.Subscribe(GetUserMessagesTopic(uid, svType), func(msg *nats.Msg) {
push := &protos.Push{}
err := proto.Unmarshal(msg.Data, push)
if err != nil {
Expand All @@ -109,6 +127,7 @@ func (ns *NatsRPCServer) handleMessages() {
defer (func() {
close(ns.unhandledReqCh)
close(ns.subChan)
close(ns.bindingsChan)
})()
maxPending := float64(0)
for {
Expand Down Expand Up @@ -162,7 +181,7 @@ func (ns *NatsRPCServer) Init() error {
if ns.sub, err = ns.subscribe(getChannel(ns.server.Type, ns.server.ID)); err != nil {
return err
}
return nil
return ns.SubscribeToBindingsChannel()
}

// AfterInit runs after initialization
Expand Down
47 changes: 37 additions & 10 deletions cluster/nats_rpc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func TestNatsRPCServerConfigure(t *testing.T) {

func TestNatsRPCServerGetUserMessagesTopic(t *testing.T) {
t.Parallel()
assert.Equal(t, "pitaya/user/bla/push", GetUserMessagesTopic("bla"))
assert.Equal(t, "pitaya/user/123bla/push", GetUserMessagesTopic("123bla"))
assert.Equal(t, "pitaya/user/1/push", GetUserMessagesTopic("1"))
assert.Equal(t, "pitaya/connector/user/bla/push", GetUserMessagesTopic("bla", "connector"))
assert.Equal(t, "pitaya/game/user/123bla/push", GetUserMessagesTopic("123bla", "game"))
assert.Equal(t, "pitaya/connector/user/1/push", GetUserMessagesTopic("1", "connector"))
}

func TestNatsRPCServerGetUnhandledRequestsChannel(t *testing.T) {
Expand All @@ -87,6 +87,32 @@ func TestNatsRPCServerGetUnhandledRequestsChannel(t *testing.T) {
assert.IsType(t, make(chan *protos.Request), n.GetUnhandledRequestsChannel())
}

func TestNatsRPCServerGetBindingsChannel(t *testing.T) {
t.Parallel()
cfg := getConfig()
sv := getServer()
n, _ := NewNatsRPCServer(cfg, sv)
assert.Equal(t, n.bindingsChan, n.GetBindingsChannel())
}

func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) {
t.Parallel()
cfg := getConfig()
sv := getServer()
rpcServer, _ := NewNatsRPCServer(cfg, sv)
s := helpers.GetTestNatsServer(t)
defer s.Shutdown()
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()))
assert.NoError(t, err)
rpcServer.conn = conn
err = rpcServer.SubscribeToBindingsChannel()
assert.NoError(t, err)
dt := []byte("somedata")
conn.Publish(GetBindBroadcastTopic(sv.Type), dt)
msg := helpers.ShouldEventuallyReceive(t, rpcServer.GetBindingsChannel()).(*nats.Msg)
assert.Equal(t, msg.Data, dt)
}

func TestNatsRPCServerGetUserPushChannel(t *testing.T) {
t.Parallel()
cfg := getConfig()
Expand All @@ -106,20 +132,21 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) {
assert.NoError(t, err)
rpcServer.conn = conn
tables := []struct {
uid string
msg []byte
uid string
svType string
msg []byte
}{
{"user1", []byte("msg1")},
{"user2", []byte("")},
{"u", []byte("000")},
{"user1", "conn", []byte("msg1")},
{"user2", "game", []byte("")},
{"u", "conn", []byte("000")},
}

for _, table := range tables {
t.Run(table.uid, func(t *testing.T) {
subs, err := rpcServer.SubscribeToUserMessages(table.uid)
subs, err := rpcServer.SubscribeToUserMessages(table.uid, table.svType)
assert.NoError(t, err)
assert.Equal(t, true, subs.IsValid())
conn.Publish(GetUserMessagesTopic(table.uid), table.msg)
conn.Publish(GetUserMessagesTopic(table.uid, table.svType), table.msg)
helpers.ShouldEventuallyReceive(t, rpcServer.userPushCh)
})
}
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (c *Config) fillDefaultValues() {
"pitaya.cluster.sd.etcd.syncservers.interval": "120s",
"pitaya.dataCompression": true,
"pitaya.heartbeat.interval": "30s",
"pitaya.session.unique": true,
}

for param := range defaultsMap {
Expand Down
1 change: 1 addition & 0 deletions constants/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ var (
ErrWrongValueType = errors.New("protobuf: convert on wrong type value")
ErrInvalidCertificates = errors.New("certificates must be exactly two")
ErrTimeoutTerminatingBinaryModule = errors.New("timeout waiting to binary module to die")
ErrFrontendTypeNotSpecified = errors.New("for using SendPushToUsers from a backend server you have to specify a valid frontendType")
)
Loading

0 comments on commit 3442b55

Please sign in to comment.