Skip to content

Commit

Permalink
save client's ip version on session
Browse files Browse the repository at this point in the history
  • Loading branch information
henrod committed Oct 5, 2018
1 parent 33c6262 commit bf89664
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 2 deletions.
9 changes: 9 additions & 0 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion acceptor/tcp_acceptor.go
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/topfreegames/pitaya/constants"
"github.com/topfreegames/pitaya/logger"
"github.com/topfreegames/viaproxy"
)

// TCPAcceptor struct
Expand Down Expand Up @@ -120,6 +121,13 @@ func (a *TCPAcceptor) serve() {
logger.Log.Error(err.Error())
continue
}
a.connChan <- conn

pcn, err := viaproxy.Wrap(conn)
if err != nil {
logger.Log.Errorf("conn wrap error: %q\n", err)
continue
}

a.connChan <- pcn
}
}
18 changes: 18 additions & 0 deletions agent/agent.go
Expand Up @@ -26,6 +26,7 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -268,6 +269,23 @@ func (a *Agent) Handle() {
}
}

// IPVersion returns the remote address ip version.
// net.TCPAddr and net.UDPAddr implementations of String()
// always construct result as <ip>:<port> on both
// ipv4 and ipv6. Also, to see if the ip is ipv6 they both
// check if there is a colon on the string.
// So checking if there are more than one colon here is safe.
func (a *Agent) IPVersion() string {
version := constants.IPv4

ipPort := a.RemoteAddr().String()
if strings.Count(ipPort, ":") > 1 {
version = constants.IPv6
}

return version
}

func (a *Agent) heartbeat() {
ticker := time.NewTicker(a.heartbeatTimeout)

Expand Down
31 changes: 31 additions & 0 deletions agent/agent_test.go
Expand Up @@ -920,3 +920,34 @@ func TestNatsRPCServerReportMetrics(t *testing.T) {
mockMetricsReporter.EXPECT().ReportGauge(metrics.ChannelCapacity, gomock.Any(), float64(-1)) // because buffersize is 0 and chan sz is 1
ag.reportChannelSize()
}

type customMockAddr struct{ network, str string }

func (m *customMockAddr) Network() string { return m.network }
func (m *customMockAddr) String() string { return m.str }

func TestIPVersion(t *testing.T) {
tables := []struct {
addr string
ipVersion string
}{
{"127.0.0.1:80", constants.IPv4},
{"1.2.3.4:3333", constants.IPv4},
{"::1:3333", constants.IPv6},
{"2001:db8:0000:1:1:1:1:1:3333", constants.IPv6},
}

for _, table := range tables {
t.Run("test_"+table.addr, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockConn := mocks.NewMockConn(ctrl)
mockAddr := &customMockAddr{str: table.addr}

mockConn.EXPECT().RemoteAddr().Return(mockAddr)
a := &Agent{conn: mockConn}

assert.Equal(t, table.ipVersion, a.IPVersion())
})
}
}
7 changes: 7 additions & 0 deletions constants/const.go
Expand Up @@ -89,3 +89,10 @@ var GRPCExternalPortKey = "grpc-external-port"

// RegionKey is the key to save the region server is on
var RegionKey = "region"

// IP constants
const (
IPVersionKey = "ipversion"
IPv4 = "ipv4"
IPv6 = "ipv6"
)
5 changes: 4 additions & 1 deletion service/handler.go
Expand Up @@ -219,14 +219,17 @@ func (h *HandlerService) processPacket(a *agent.Agent, p *packet.Packet) error {
// Parse the json sent with the handshake by the client
handshakeData := &session.HandshakeData{}
err := json.Unmarshal(p.Data, handshakeData)

if err != nil {
a.SetStatus(constants.StatusClosed)
return fmt.Errorf("Invalid handshake data. Id=%d", a.Session.ID())
}

a.Session.SetHandshakeData(handshakeData)
a.SetStatus(constants.StatusHandshake)
err = a.Session.Set(constants.IPVersionKey, a.IPVersion())
if err != nil {
logger.Log.Warnf("failed to save ip version on session: %q\n", err)
}

case packet.HandshakeAck:
a.SetStatus(constants.StatusWorking)
Expand Down
3 changes: 3 additions & 0 deletions service/handler_test.go
Expand Up @@ -270,6 +270,9 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) {
mockConn.EXPECT().Write(gomock.Any()).Do(func(d []byte) {
assert.Contains(t, string(d), "heartbeat")
})
if table.errStr == "" {
mockConn.EXPECT().RemoteAddr().Return(&mockAddr{})
}

ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil)

Expand Down

0 comments on commit bf89664

Please sign in to comment.