Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
"os/signal"
"strconv"
"sync"
"sync/atomic"
"syscall"
"unsafe"

"github.com/scmhub/ibapi/protobuf"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -340,9 +342,9 @@ func (c *EClient) reset() {
}

func (c *EClient) setConnState(state ConnState) {
cs := c.connState
c.connState = state
log.Debug().Stringer("from", cs).Stringer("to", c.connState).Msg("connection state changed")
cs := ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState))))
atomic.StoreInt32((*int32)(unsafe.Pointer(&c.connState)), int32(state))
log.Debug().Stringer("from", cs).Stringer("to", state).Msg("connection state changed")
}

// request is a goroutine that will get the req from reqChan and send it to TWS.
Expand Down Expand Up @@ -611,7 +613,7 @@ func (c *EClient) Ctx() context.Context {

// IsConnected checks connection to TWS or GateWay.
func (c *EClient) IsConnected() bool {
return c.conn.IsConnected() && c.connState == CONNECTED
return c.conn.IsConnected() && ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState)))) == CONNECTED
}

// OptionalCapabilities returns the Optional Capabilities.
Expand Down Expand Up @@ -1117,7 +1119,7 @@ func (c *EClient) CancelCalculateOptionPrice(reqID int64) {
//
// exerciseQuantity is the quantity you want to exercise.
// account is the destination account.
// override specifies whether your setting will override the system's natural action.
// override specifies whether your setting will override the system's natural action.
// For example, if your action is "exercise" and the option is not in-the-money, by natural action the option would not exercise.
// If you have override set to "yes" the natural action would be overridden and the out-of-the money option would be exercised.
// Values: 0 = no, 1 = yes.
Expand Down
177 changes: 133 additions & 44 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package ibapi
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -11,91 +13,151 @@ const (
reconnectDelay = 500 * time.Millisecond
)

// Connection is a TCPConn wrapper.
// Connection is a TCPConn wrapper with lock-free statistics and minimal contention.
type Connection struct {
*net.TCPConn
wrapper EWrapper
host string
port int
isConnected bool
numBytesSent int
numMsgSent int
numBytesRecv int
numMsgRecv int
// Connection state - protected by mutex for host/port coordination only
mu sync.RWMutex
tcpConn atomic.Pointer[net.TCPConn] // Lock-free pointer for maximum performance
wrapper EWrapper
host string
port int
isConnected int32 // atomic: 0=disconnected, 1=connected

// Statistics - lock-free atomic counters for maximum performance
numBytesSent int64 // atomic
numMsgSent int64 // atomic
numBytesRecv int64 // atomic
numMsgRecv int64 // atomic

// Reconnection control - prevents multiple concurrent reconnections
reconnecting int32 // atomic: 0=not reconnecting, 1=reconnecting
}

func (c *Connection) Write(bs []byte) (int, error) {
// first attempt
n, err := c.TCPConn.Write(bs)
if err == nil {
c.numBytesSent += n
c.numMsgSent++
log.Trace().Int("nBytes", n).Msg("conn write")
return n, nil
// Fast path: try write with current connection
conn := c.getConn()
if conn != nil {
n, err := conn.Write(bs)
if err == nil {
// Lock-free atomic statistics update
atomic.AddInt64(&c.numBytesSent, int64(n))
atomic.AddInt64(&c.numMsgSent, 1)
log.Trace().Int("nBytes", n).Msg("conn write")
return n, nil
}

// Write failed, try to reconnect
log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...")
}
// write failed, try to reconnect
log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...")

// Slow path: reconnect and retry
if err := c.reconnect(); err != nil {
return 0, fmt.Errorf("write failed and reconnection failed: %w", err)
}

// second attempt
n, err = c.TCPConn.Write(bs)
// Retry write after reconnection
conn = c.getConn()
if conn == nil {
return 0, fmt.Errorf("connection still not available after reconnect")
}

n, err := conn.Write(bs)
if err != nil {
return 0, fmt.Errorf("write retry after reconnect failed: %w", err)
}

c.numBytesSent += n
c.numMsgSent++
// Lock-free atomic statistics update
atomic.AddInt64(&c.numBytesSent, int64(n))
atomic.AddInt64(&c.numMsgSent, 1)
log.Trace().Int("nBytes", n).Msg("conn write (after reconnect)")
return n, nil
}

func (c *Connection) Read(bs []byte) (int, error) {
n, err := c.TCPConn.Read(bs)
conn := c.getConn()
if conn == nil {
return 0, fmt.Errorf("connection not available")
}

c.numBytesRecv += n
c.numMsgRecv++
n, err := conn.Read(bs)

// Lock-free atomic statistics update
atomic.AddInt64(&c.numBytesRecv, int64(n))
atomic.AddInt64(&c.numMsgRecv, 1)

log.Trace().Int("nBytes", n).Msg("conn read")

return n, err
}

// getConn returns the current TCP connection in a lock-free way
func (c *Connection) getConn() *net.TCPConn {
return c.tcpConn.Load()
}

// setConn sets the TCP connection in a lock-free way
func (c *Connection) setConn(conn *net.TCPConn) {
c.tcpConn.Store(conn)
}

func (c *Connection) reset() {
c.numBytesSent = 0
c.numBytesRecv = 0
c.numMsgSent = 0
c.numMsgRecv = 0
// Lock-free atomic reset of statistics
atomic.StoreInt64(&c.numBytesSent, 0)
atomic.StoreInt64(&c.numBytesRecv, 0)
atomic.StoreInt64(&c.numMsgSent, 0)
atomic.StoreInt64(&c.numMsgRecv, 0)
}

func (c *Connection) connect(host string, port int) error {
// Protect host/port assignment with mutex to prevent races
c.mu.Lock()
c.host = host
c.port = port
c.mu.Unlock()

c.reset()

address := fmt.Sprintf("%v:%v", c.host, c.port)
// Use the parameters directly instead of reading from struct to avoid races
address := fmt.Sprintf("%v:%v", host, port)
addr, err := net.ResolveTCPAddr("tcp4", address)
if err != nil {
log.Error().Err(err).Str("host", address).Msg("failed to resove tcp address")
c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), FAIL_CREATE_SOCK.Code, FAIL_CREATE_SOCK.Msg, "")
return err
}

c.TCPConn, err = net.DialTCP("tcp4", nil, addr)
newConn, err := net.DialTCP("tcp4", nil, addr)
if err != nil {
log.Error().Err(err).Any("address", addr).Msg("failed to dial tcp")
c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), FAIL_CREATE_SOCK.Code, FAIL_CREATE_SOCK.Msg, "")
return err
}

log.Debug().Any("address", c.TCPConn.RemoteAddr()).Msg("tcp socket connected")
c.isConnected = true
// Atomically update connection state
c.setConn(newConn)
atomic.StoreInt32(&c.isConnected, 1)

log.Debug().Any("address", newConn.RemoteAddr()).Msg("tcp socket connected")
return nil
}

func (c *Connection) reconnect() error {
// Use atomic CAS to prevent multiple concurrent reconnections
if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
// Another goroutine is already reconnecting, wait for it
for atomic.LoadInt32(&c.reconnecting) == 1 {
time.Sleep(10 * time.Millisecond)
}
// Check if the other goroutine succeeded
if atomic.LoadInt32(&c.isConnected) == 1 {
return nil
}
return fmt.Errorf("concurrent reconnection failed")
}

// Ensure we clear the reconnecting flag when done
defer atomic.StoreInt32(&c.reconnecting, 0)

var err error
backoff := reconnectDelay // Start with base delay

Expand All @@ -106,35 +168,62 @@ func (c *Connection) reconnect() error {
Int("maxAttempts", maxReconnectAttempts).
Msg("Attempting to reconnect")

err = c.connect(c.host, c.port)
// Read host/port atomically to avoid race
c.mu.RLock()
host, port := c.host, c.port
c.mu.RUnlock()

err = c.connect(host, port)
if err == nil {
log.Info().Msg("Reconnection successful")
c.isConnected = true
atomic.StoreInt32(&c.isConnected, 1)
return nil
}

// if this isnt our last try, wait and then loop again
// if this isn't our last try, wait and then loop again
if attempt < maxReconnectAttempts {
time.Sleep(backoff)
backoff *= 2
}
}

// if we get here, all attempts failed
c.isConnected = false
atomic.StoreInt32(&c.isConnected, 0)
return fmt.Errorf("failed to reconnect after %d attempts: %w", maxReconnectAttempts, err)

}

func (c *Connection) disconnect() error {
// Load statistics atomically for logging
msgSent := atomic.LoadInt64(&c.numMsgSent)
bytesSent := atomic.LoadInt64(&c.numBytesSent)
msgRecv := atomic.LoadInt64(&c.numMsgRecv)
bytesRecv := atomic.LoadInt64(&c.numBytesRecv)

log.Trace().
Int("nMsgSent", c.numMsgSent).Int("nBytesSent", c.numBytesSent).
Int("nMsgRecv", c.numMsgRecv).Int("nBytesRecv", c.numBytesRecv).
Int64("nMsgSent", msgSent).Int64("nBytesSent", bytesSent).
Int64("nMsgRecv", msgRecv).Int64("nBytesRecv", bytesRecv).
Msg("conn disconnect")
c.isConnected = false
return c.Close()

// Atomically mark as disconnected
atomic.StoreInt32(&c.isConnected, 0)

// Close the connection
conn := c.getConn()
if conn != nil {
c.setConn(nil)
return conn.Close()
}
return nil
}

func (c *Connection) IsConnected() bool {
return c.isConnected
return atomic.LoadInt32(&c.isConnected) == 1
}

// GetStatistics returns current connection statistics atomically
func (c *Connection) GetStatistics() (bytesSent, msgSent, bytesRecv, msgRecv int64) {
return atomic.LoadInt64(&c.numBytesSent),
atomic.LoadInt64(&c.numMsgSent),
atomic.LoadInt64(&c.numBytesRecv),
atomic.LoadInt64(&c.numMsgRecv)
}
Loading