Skip to content

Commit

Permalink
using dieChan to close application
Browse files Browse the repository at this point in the history
  • Loading branch information
henrod committed Jun 5, 2018
1 parent fdc512d commit 77a68f4
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func startDefaultSD() {

func startDefaultRPCServer() {
// initialize default rpc server
rpcServer, err := cluster.NewNatsRPCServer(app.config, app.server, app.metricsReporters)
rpcServer, err := cluster.NewNatsRPCServer(app.config, app.server, app.metricsReporters, app.dieChan)
if err != nil {
logger.Log.Fatalf("error starting cluster rpc server component: %s", err.Error())
}
Expand All @@ -256,7 +256,7 @@ func startDefaultRPCServer() {

func startDefaultRPCClient() {
// initialize default rpc client
rpcClient, err := cluster.NewNatsRPCClient(app.config, app.server, app.metricsReporters)
rpcClient, err := cluster.NewNatsRPCClient(app.config, app.server, app.metricsReporters, app.dieChan)
if err != nil {
logger.Log.Fatalf("error starting cluster rpc client component: %s", err.Error())
}
Expand Down
8 changes: 4 additions & 4 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ func setup() {
etcdSD, _ := cluster.NewEtcdServiceDiscovery(app.config, app.server)
typeOfetcdSD = reflect.TypeOf(etcdSD)

natsRPCServer, _ := cluster.NewNatsRPCServer(app.config, app.server, nil)
natsRPCServer, _ := cluster.NewNatsRPCServer(app.config, app.server, nil, app.dieChan)
typeOfNatsRPCServer = reflect.TypeOf(natsRPCServer)

natsRPCClient, _ := cluster.NewNatsRPCClient(app.config, app.server, nil)
natsRPCClient, _ := cluster.NewNatsRPCClient(app.config, app.server, nil, app.dieChan)
typeOfNatsRPCClient = reflect.TypeOf(natsRPCClient)
}

Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSetHeartbeatInterval(t *testing.T) {
func TestSetRPCServer(t *testing.T) {
initApp()
Configure(true, "testtype", Cluster, map[string]string{}, viper.New())
r, err := cluster.NewNatsRPCServer(app.config, app.server, nil)
r, err := cluster.NewNatsRPCServer(app.config, app.server, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, r)

Expand All @@ -188,7 +188,7 @@ func TestSetRPCServer(t *testing.T) {
func TestSetRPCClient(t *testing.T) {
initApp()
Configure(true, "testtype", Cluster, map[string]string{}, viper.New())
r, err := cluster.NewNatsRPCClient(app.config, app.server, nil)
r, err := cluster.NewNatsRPCClient(app.config, app.server, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, r)
SetRPCClient(r)
Expand Down
11 changes: 9 additions & 2 deletions cluster/nats_rpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,22 @@ type NatsRPCClient struct {
running bool
server *Server
metricsReporters []metrics.Reporter
appDieChan chan bool
}

// NewNatsRPCClient ctor
func NewNatsRPCClient(config *config.Config, server *Server, metricsReporters []metrics.Reporter) (*NatsRPCClient, error) {
func NewNatsRPCClient(
config *config.Config,
server *Server,
metricsReporters []metrics.Reporter,
appDieChan chan bool,
) (*NatsRPCClient, error) {
ns := &NatsRPCClient{
config: config,
server: server,
running: false,
metricsReporters: metricsReporters,
appDieChan: appDieChan,
}
if err := ns.configure(); err != nil {
return nil, err
Expand Down Expand Up @@ -222,7 +229,7 @@ func (ns *NatsRPCClient) Call(
// Init inits nats rpc client
func (ns *NatsRPCClient) Init() error {
ns.running = true
conn, err := setupNatsConn(ns.connString, nats.MaxReconnects(ns.maxReconnectionRetries))
conn, err := setupNatsConn(ns.connString, ns.appDieChan, nats.MaxReconnects(ns.maxReconnectionRetries))
if err != nil {
return err
}
Expand Down
24 changes: 12 additions & 12 deletions cluster/nats_rpc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestNewNatsRPCClient(t *testing.T) {

cfg := getConfig()
sv := getServer()
n, err := NewNatsRPCClient(cfg, sv, mockMetricsReporters)
n, err := NewNatsRPCClient(cfg, sv, mockMetricsReporters, nil)
assert.NoError(t, err)
assert.NotNil(t, n)
assert.Equal(t, sv, n.server)
Expand All @@ -80,7 +80,7 @@ func TestNatsRPCClientConfigure(t *testing.T) {
cfg.Set("pitaya.cluster.rpc.client.nats.connect", table.natsConnect)
cfg.Set("pitaya.cluster.rpc.client.nats.requesttimeout", table.reqTimeout)
conf := getConfig(cfg)
_, err := NewNatsRPCClient(conf, getServer(), nil)
_, err := NewNatsRPCClient(conf, getServer(), nil, nil)
assert.Equal(t, table.err, err)
})
}
Expand All @@ -90,15 +90,15 @@ func TestNatsRPCClientGetSubscribeChannel(t *testing.T) {
t.Parallel()
cfg := getConfig()
sv := getServer()
n, _ := NewNatsRPCClient(cfg, sv, nil)
n, _ := NewNatsRPCClient(cfg, sv, nil, nil)
assert.Equal(t, fmt.Sprintf("pitaya/servers/%s/%s", n.server.Type, n.server.ID), n.getSubscribeChannel())
}

func TestNatsRPCClientStop(t *testing.T) {
t.Parallel()
cfg := getConfig()
sv := getServer()
n, _ := NewNatsRPCClient(cfg, sv, nil)
n, _ := NewNatsRPCClient(cfg, sv, nil, nil)
// change it to true to ensure it goes to false
n.running = true
n.stop()
Expand All @@ -111,7 +111,7 @@ func TestNatsRPCClientInitShouldFailIfConnFails(t *testing.T) {
cfg := viper.New()
cfg.Set("pitaya.cluster.rpc.client.nats.connect", "nats://localhost:1")
config := getConfig(cfg)
rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
err := rpcClient.Init()
assert.Error(t, err)
}
Expand All @@ -124,7 +124,7 @@ func TestNatsRPCClientInit(t *testing.T) {
config := getConfig(cfg)
sv := getServer()

rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
err := rpcClient.Init()
assert.NoError(t, err)
assert.True(t, rpcClient.running)
Expand All @@ -137,7 +137,7 @@ func TestNatsRPCClientInit(t *testing.T) {
func TestNatsRPCClientSendShouldFailIfNotRunning(t *testing.T) {
config := getConfig()
sv := getServer()
rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
err := rpcClient.Send("topic", []byte("data"))
assert.Equal(t, constants.ErrRPCClientNotInitialized, err)
}
Expand All @@ -150,7 +150,7 @@ func TestNatsRPCClientSend(t *testing.T) {
config := getConfig(cfg)
sv := getServer()

rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
rpcClient.Init()

tables := []struct {
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestNatsRPCClientSend(t *testing.T) {
func TestNatsRPCClientBuildRequest(t *testing.T) {
config := getConfig()
sv := getServer()
rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)

rt := route.NewRoute("sv", "svc", "method")
ss := session.New(nil, true, "uid")
Expand Down Expand Up @@ -283,7 +283,7 @@ func TestNatsRPCClientBuildRequest(t *testing.T) {
func TestNatsRPCClientCallShouldFailIfNotRunning(t *testing.T) {
config := getConfig()
sv := getServer()
rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
res, err := rpcClient.Call(context.Background(), protos.RPCType_Sys, nil, nil, nil, sv)
assert.Equal(t, constants.ErrRPCClientNotInitialized, err)
assert.Nil(t, res)
Expand All @@ -297,7 +297,7 @@ func TestNatsRPCClientCall(t *testing.T) {
cfg.Set("pitaya.cluster.rpc.client.nats.connect", fmt.Sprintf("nats://%s", s.Addr()))
cfg.Set("pitaya.cluster.rpc.client.nats.requesttimeout", "300ms")
config := getConfig(cfg)
rpcClient, _ := NewNatsRPCClient(config, sv, nil)
rpcClient, _ := NewNatsRPCClient(config, sv, nil, nil)
rpcClient.Init()

rt := route.NewRoute("sv", "svc", "method")
Expand All @@ -324,7 +324,7 @@ func TestNatsRPCClientCall(t *testing.T) {

for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()))
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
assert.NoError(t, err)

sv2 := getServer()
Expand Down
7 changes: 4 additions & 3 deletions cluster/nats_rpc_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package cluster

import (
"fmt"
"syscall"

nats "github.com/nats-io/go-nats"
"github.com/topfreegames/pitaya/logger"
Expand All @@ -32,7 +31,7 @@ func getChannel(serverType, serverID string) string {
return fmt.Sprintf("pitaya/servers/%s/%s", serverType, serverID)
}

func setupNatsConn(connectString string, options ...nats.Option) (*nats.Conn, error) {
func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
natsOptions := append(
options,
nats.DisconnectHandler(func(_ *nats.Conn) {
Expand All @@ -48,8 +47,10 @@ func setupNatsConn(connectString string, options ...nats.Option) (*nats.Conn, er
return
}

syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
logger.Log.Errorf("nats connection closed. reason: %q", nc.LastError())
if appDieChan != nil {
close(appDieChan)
}
}),
)

Expand Down
23 changes: 21 additions & 2 deletions cluster/nats_rpc_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ package cluster
import (
"fmt"
"testing"
"time"

"github.com/nats-io/go-nats"
"github.com/stretchr/testify/assert"
"github.com/topfreegames/pitaya/helpers"
)
Expand All @@ -46,14 +48,31 @@ func TestNatsRPCCommonSetupNatsConn(t *testing.T) {
t.Parallel()
s := helpers.GetTestNatsServer(t)
defer s.Shutdown()
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()))
conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil)
assert.NoError(t, err)
assert.NotNil(t, conn)
}

func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) {
t.Parallel()
conn, err := setupNatsConn("nats://localhost:1234")
conn, err := setupNatsConn("nats://localhost:1234", nil)
assert.Error(t, err)
assert.Nil(t, conn)
}

func TestNatsRPCCommonCloseHandler(t *testing.T) {
t.Parallel()
s := helpers.GetTestNatsServer(t)

dieChan := make(chan bool)

conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1),
nats.ReconnectWait(1*time.Millisecond))
assert.NoError(t, err)
assert.NotNil(t, conn)

s.Shutdown()

_, ok := <-dieChan
assert.False(t, ok)
}
11 changes: 9 additions & 2 deletions cluster/nats_rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,24 @@ type NatsRPCServer struct {
sub *nats.Subscription
dropped int
metricsReporters []metrics.Reporter
appDieChan chan bool
}

// NewNatsRPCServer ctor
func NewNatsRPCServer(config *config.Config, server *Server, metricsReporters []metrics.Reporter) (*NatsRPCServer, error) {
func NewNatsRPCServer(
config *config.Config,
server *Server,
metricsReporters []metrics.Reporter,
appDieChan chan bool,
) (*NatsRPCServer, error) {
ns := &NatsRPCServer{
config: config,
server: server,
stopChan: make(chan bool),
unhandledReqCh: make(chan *protos.Request),
dropped: 0,
metricsReporters: metricsReporters,
appDieChan: appDieChan,
}
if err := ns.configure(); err != nil {
return nil, err
Expand Down Expand Up @@ -180,7 +187,7 @@ func (ns *NatsRPCServer) GetUserPushChannel() chan *protos.Push {
func (ns *NatsRPCServer) Init() error {
// TODO should we have concurrency here? it feels like we should
go ns.handleMessages()
conn, err := setupNatsConn(ns.connString, nats.MaxReconnects(ns.maxReconnectionRetries))
conn, err := setupNatsConn(ns.connString, ns.appDieChan, nats.MaxReconnects(ns.maxReconnectionRetries))
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 77a68f4

Please sign in to comment.