Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Also check sockets bind to tcp6 and fail on all closed sockets (#824)
Browse files Browse the repository at this point in the history
Also check sockets bind to tcp6 and fail on all closed sockets
  • Loading branch information
ajnavarro committed Sep 30, 2019
2 parents 2e82b0a + 604c376 commit 6ee998b
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 33 deletions.
4 changes: 2 additions & 2 deletions internal/sockstate/netstat_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
// (juanjux) TODO: not implemented
logrus.Info("Connection checking not implemented for Darwin")
return []sockTabEntry{}, nil
logrus.Warn("Connection checking not implemented for Darwin")
return nil, ErrSocketCheckNotImplemented.New()
}

func GetConnInode(c *net.TCPConn) (n uint64, err error) {
Expand Down
54 changes: 40 additions & 14 deletions internal/sockstate/netstat_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
)

const (
pathTCPTab = "/proc/net/tcp"
pathTCP4Tab = "/proc/net/tcp"
pathTCP6Tab = "/proc/net/tcp6"
ipv4StrLen = 8
ipv6StrLen = 32
)

type procFd struct {
Expand Down Expand Up @@ -120,6 +122,23 @@ func parseIPv4(s string) (net.IP, error) {
return ip, nil
}

func parseIPv6(s string) (net.IP, error) {
ip := make(net.IP, net.IPv6len)
const grpLen = 4
i, j := 0, 4
for len(s) != 0 {
grp := s[0:8]
u, err := strconv.ParseUint(grp, 16, 32)
binary.LittleEndian.PutUint32(ip[i:j], uint32(u))
if err != nil {
return nil, err
}
i, j = i+grpLen, j+grpLen
s = s[8:]
}
return ip, nil
}

func parseAddr(s string) (*sockAddr, error) {
fields := strings.Split(s, ":")
if len(fields) < 2 {
Expand All @@ -130,6 +149,8 @@ func parseAddr(s string) (*sockAddr, error) {
switch len(fields[0]) {
case ipv4StrLen:
ip, err = parseIPv4(fields[0])
case ipv6StrLen:
ip, err = parseIPv6(fields[0])
default:
log.Fatal("Badly formatted connection address:", s)
}
Expand Down Expand Up @@ -192,21 +213,26 @@ func parseSocktab(r io.Reader, accept AcceptFn) ([]sockTabEntry, error) {
// tcpSocks returns a slice of active TCP sockets containing only those
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
f, err := os.Open(pathTCPTab)
defer func() {
_ = f.Close()
}()
if err != nil {
return nil, err
}
paths := [2]string{pathTCP4Tab, pathTCP6Tab}
var allTabs []sockTabEntry
for _, p := range paths {
f, err := os.Open(p)
defer func() {
_ = f.Close()
}()
if err != nil {
return nil, err
}

tabs, err := parseSocktab(f, accept)
if err != nil {
return nil, err
}
t, err := parseSocktab(f, accept)
if err != nil {
return nil, err
}
allTabs = append(allTabs, t...)

extractProcInfo(tabs)
return tabs, nil
}
extractProcInfo(allTabs)
return allTabs, nil
}

// GetConnInode returns the Linux inode number of a TCP connection
Expand Down
4 changes: 2 additions & 2 deletions internal/sockstate/netstat_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
// elements that satisfy the accept function
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
// (juanjux) TODO: not implemented
logrus.Info("Connection checking not implemented for Windows")
return []sockTabEntry{}, nil
logrus.Warn("Connection checking not implemented for Windows")
return nil, ErrSocketCheckNotImplemented.New()
}

func GetConnInode(c *net.TCPConn) (n uint64, err error) {
Expand Down
21 changes: 16 additions & 5 deletions internal/sockstate/sockstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
type SockState uint8

const (
Finished = iota
Broken
Broken = iota
Other
Error
)
Expand Down Expand Up @@ -37,12 +36,24 @@ func GetInodeSockState(port int, inode uint64) (SockState, error) {

switch len(socks) {
case 0:
return Finished, nil
return Broken, nil
case 1:
if socks[0].State == CloseWait {
switch socks[0].State {
case CloseWait:
fallthrough
case TimeWait:
fallthrough
case FinWait1:
fallthrough
case FinWait2:
fallthrough
case Close:
fallthrough
case Closing:
return Broken, nil
default:
return Other, nil
}
return Other, nil
default: // more than one sock for inode, impossible?
return Error, ErrMultipleSocketsForInode.New()
}
Expand Down
8 changes: 2 additions & 6 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,12 @@ func (h *Handler) ComQuery(
for {
select {
case <-quit:
// timeout or other errors detected by the calling routine
return
default:
}

st, err := sockstate.GetInodeSockState(t.Port, inode)
switch st {
case sockstate.Finished:
// Not Linux OSs will also exit here
return
case sockstate.Broken:
errChan <- ErrConnectionWasClosed.New()
return
Expand All @@ -243,6 +239,7 @@ rowLoop:

if r.RowsAffected == rowsBatch {
if err := callback(r); err != nil {
close(quit)
return err
}

Expand Down Expand Up @@ -276,13 +273,12 @@ rowLoop:
}
timer.Reset(waitTime)
}
close(quit)

if err := rows.Close(); err != nil {
return err
}

close(quit)

// Even if r.RowsAffected = 0, the callback must be
// called to update the state in the go-vitess' listener
// and avoid returning errors when the query doesn't
Expand Down
8 changes: 4 additions & 4 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ func TestHandlerKill(t *testing.T) {
require.Len(handler.c, 2)
require.Equal(conntainer1, handler.c[1])
require.Equal(conntainer2, handler.c[2])

assertNoConnProcesses(t, e, conn2.ConnectionID)

ctx1 := handler.sm.NewContextWithQuery(conn1, "SELECT 1")
Expand Down Expand Up @@ -256,6 +255,7 @@ func TestHandlerTimeout(t *testing.T) {
})
require.NoError(err)
}

func TestOkClosedConnection(t *testing.T) {
require := require.New(t)
e := setupMemDB(require)
Expand All @@ -282,11 +282,11 @@ func TestOkClosedConnection(t *testing.T) {
0,
)
h.AddNetConnection(&conn)
c2 := newConn(2)
h.NewConnection(c2)
c := newConn(1)
h.NewConnection(c)

q := fmt.Sprintf("SELECT SLEEP(%d)", tcpCheckerSleepTime*4)
err = h.ComQuery(c2, q, func(res *sqltypes.Result) error {
err = h.ComQuery(c, q, func(res *sqltypes.Result) error {
return nil
})
require.NoError(err)
Expand Down

0 comments on commit 6ee998b

Please sign in to comment.