From 7e396c44a1665aeace8cae627e8fd3ca2d09010e Mon Sep 17 00:00:00 2001 From: Caio Northfleet Date: Wed, 25 Jun 2025 16:44:58 -0700 Subject: [PATCH] mitigated specific client race issues --- client.go | 12 +- connection.go | 177 +++++++++++++++----- connection_race_test.go | 353 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 493 insertions(+), 49 deletions(-) create mode 100644 connection_race_test.go diff --git a/client.go b/client.go index 2460e50..971e348 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,9 @@ import ( "os/signal" "strconv" "sync" + "sync/atomic" "syscall" + "unsafe" "github.com/scmhub/ibapi/protobuf" "google.golang.org/protobuf/proto" @@ -340,9 +342,9 @@ func (c *EClient) reset() { } func (c *EClient) setConnState(state ConnState) { - cs := c.connState - c.connState = state - log.Debug().Stringer("from", cs).Stringer("to", c.connState).Msg("connection state changed") + cs := ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState)))) + atomic.StoreInt32((*int32)(unsafe.Pointer(&c.connState)), int32(state)) + log.Debug().Stringer("from", cs).Stringer("to", state).Msg("connection state changed") } // request is a goroutine that will get the req from reqChan and send it to TWS. @@ -611,7 +613,7 @@ func (c *EClient) Ctx() context.Context { // IsConnected checks connection to TWS or GateWay. func (c *EClient) IsConnected() bool { - return c.conn.IsConnected() && c.connState == CONNECTED + return c.conn.IsConnected() && ConnState(atomic.LoadInt32((*int32)(unsafe.Pointer(&c.connState)))) == CONNECTED } // OptionalCapabilities returns the Optional Capabilities. @@ -1117,7 +1119,7 @@ func (c *EClient) CancelCalculateOptionPrice(reqID int64) { // // exerciseQuantity is the quantity you want to exercise. // account is the destination account. -// overrideĀ specifies whether your setting will override the system's natural action. +// override specifies whether your setting will override the system's natural action. // For example, if your action is "exercise" and the option is not in-the-money, by natural action the option would not exercise. // If you have override set to "yes" the natural action would be overridden and the out-of-the money option would be exercised. // Values: 0 = no, 1 = yes. diff --git a/connection.go b/connection.go index 8e71cc3..d938e9f 100644 --- a/connection.go +++ b/connection.go @@ -3,6 +3,8 @@ package ibapi import ( "fmt" "net" + "sync" + "sync/atomic" "time" ) @@ -11,70 +13,112 @@ const ( reconnectDelay = 500 * time.Millisecond ) -// Connection is a TCPConn wrapper. +// Connection is a TCPConn wrapper with lock-free statistics and minimal contention. type Connection struct { - *net.TCPConn - wrapper EWrapper - host string - port int - isConnected bool - numBytesSent int - numMsgSent int - numBytesRecv int - numMsgRecv int + // Connection state - protected by mutex for host/port coordination only + mu sync.RWMutex + tcpConn atomic.Pointer[net.TCPConn] // Lock-free pointer for maximum performance + wrapper EWrapper + host string + port int + isConnected int32 // atomic: 0=disconnected, 1=connected + + // Statistics - lock-free atomic counters for maximum performance + numBytesSent int64 // atomic + numMsgSent int64 // atomic + numBytesRecv int64 // atomic + numMsgRecv int64 // atomic + + // Reconnection control - prevents multiple concurrent reconnections + reconnecting int32 // atomic: 0=not reconnecting, 1=reconnecting } func (c *Connection) Write(bs []byte) (int, error) { - // first attempt - n, err := c.TCPConn.Write(bs) - if err == nil { - c.numBytesSent += n - c.numMsgSent++ - log.Trace().Int("nBytes", n).Msg("conn write") - return n, nil + // Fast path: try write with current connection + conn := c.getConn() + if conn != nil { + n, err := conn.Write(bs) + if err == nil { + // Lock-free atomic statistics update + atomic.AddInt64(&c.numBytesSent, int64(n)) + atomic.AddInt64(&c.numMsgSent, 1) + log.Trace().Int("nBytes", n).Msg("conn write") + return n, nil + } + + // Write failed, try to reconnect + log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...") } - // write failed, try to reconnect - log.Warn().Err(err).Msg("Write error detected, attempting to reconnect...") + + // Slow path: reconnect and retry if err := c.reconnect(); err != nil { return 0, fmt.Errorf("write failed and reconnection failed: %w", err) } - // second attempt - n, err = c.TCPConn.Write(bs) + // Retry write after reconnection + conn = c.getConn() + if conn == nil { + return 0, fmt.Errorf("connection still not available after reconnect") + } + + n, err := conn.Write(bs) if err != nil { return 0, fmt.Errorf("write retry after reconnect failed: %w", err) } - c.numBytesSent += n - c.numMsgSent++ + // Lock-free atomic statistics update + atomic.AddInt64(&c.numBytesSent, int64(n)) + atomic.AddInt64(&c.numMsgSent, 1) log.Trace().Int("nBytes", n).Msg("conn write (after reconnect)") return n, nil } func (c *Connection) Read(bs []byte) (int, error) { - n, err := c.TCPConn.Read(bs) + conn := c.getConn() + if conn == nil { + return 0, fmt.Errorf("connection not available") + } - c.numBytesRecv += n - c.numMsgRecv++ + n, err := conn.Read(bs) + + // Lock-free atomic statistics update + atomic.AddInt64(&c.numBytesRecv, int64(n)) + atomic.AddInt64(&c.numMsgRecv, 1) log.Trace().Int("nBytes", n).Msg("conn read") return n, err } +// getConn returns the current TCP connection in a lock-free way +func (c *Connection) getConn() *net.TCPConn { + return c.tcpConn.Load() +} + +// setConn sets the TCP connection in a lock-free way +func (c *Connection) setConn(conn *net.TCPConn) { + c.tcpConn.Store(conn) +} + func (c *Connection) reset() { - c.numBytesSent = 0 - c.numBytesRecv = 0 - c.numMsgSent = 0 - c.numMsgRecv = 0 + // Lock-free atomic reset of statistics + atomic.StoreInt64(&c.numBytesSent, 0) + atomic.StoreInt64(&c.numBytesRecv, 0) + atomic.StoreInt64(&c.numMsgSent, 0) + atomic.StoreInt64(&c.numMsgRecv, 0) } func (c *Connection) connect(host string, port int) error { + // Protect host/port assignment with mutex to prevent races + c.mu.Lock() c.host = host c.port = port + c.mu.Unlock() + c.reset() - address := fmt.Sprintf("%v:%v", c.host, c.port) + // Use the parameters directly instead of reading from struct to avoid races + address := fmt.Sprintf("%v:%v", host, port) addr, err := net.ResolveTCPAddr("tcp4", address) if err != nil { log.Error().Err(err).Str("host", address).Msg("failed to resove tcp address") @@ -82,20 +126,38 @@ func (c *Connection) connect(host string, port int) error { return err } - c.TCPConn, err = net.DialTCP("tcp4", nil, addr) + newConn, err := net.DialTCP("tcp4", nil, addr) if err != nil { log.Error().Err(err).Any("address", addr).Msg("failed to dial tcp") c.wrapper.Error(NO_VALID_ID, currentTimeMillis(), FAIL_CREATE_SOCK.Code, FAIL_CREATE_SOCK.Msg, "") return err } - log.Debug().Any("address", c.TCPConn.RemoteAddr()).Msg("tcp socket connected") - c.isConnected = true + // Atomically update connection state + c.setConn(newConn) + atomic.StoreInt32(&c.isConnected, 1) + log.Debug().Any("address", newConn.RemoteAddr()).Msg("tcp socket connected") return nil } func (c *Connection) reconnect() error { + // Use atomic CAS to prevent multiple concurrent reconnections + if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) { + // Another goroutine is already reconnecting, wait for it + for atomic.LoadInt32(&c.reconnecting) == 1 { + time.Sleep(10 * time.Millisecond) + } + // Check if the other goroutine succeeded + if atomic.LoadInt32(&c.isConnected) == 1 { + return nil + } + return fmt.Errorf("concurrent reconnection failed") + } + + // Ensure we clear the reconnecting flag when done + defer atomic.StoreInt32(&c.reconnecting, 0) + var err error backoff := reconnectDelay // Start with base delay @@ -106,14 +168,19 @@ func (c *Connection) reconnect() error { Int("maxAttempts", maxReconnectAttempts). Msg("Attempting to reconnect") - err = c.connect(c.host, c.port) + // Read host/port atomically to avoid race + c.mu.RLock() + host, port := c.host, c.port + c.mu.RUnlock() + + err = c.connect(host, port) if err == nil { log.Info().Msg("Reconnection successful") - c.isConnected = true + atomic.StoreInt32(&c.isConnected, 1) return nil } - // if this isn’t our last try, wait and then loop again + // if this isn't our last try, wait and then loop again if attempt < maxReconnectAttempts { time.Sleep(backoff) backoff *= 2 @@ -121,20 +188,42 @@ func (c *Connection) reconnect() error { } // if we get here, all attempts failed - c.isConnected = false + atomic.StoreInt32(&c.isConnected, 0) return fmt.Errorf("failed to reconnect after %d attempts: %w", maxReconnectAttempts, err) - } func (c *Connection) disconnect() error { + // Load statistics atomically for logging + msgSent := atomic.LoadInt64(&c.numMsgSent) + bytesSent := atomic.LoadInt64(&c.numBytesSent) + msgRecv := atomic.LoadInt64(&c.numMsgRecv) + bytesRecv := atomic.LoadInt64(&c.numBytesRecv) + log.Trace(). - Int("nMsgSent", c.numMsgSent).Int("nBytesSent", c.numBytesSent). - Int("nMsgRecv", c.numMsgRecv).Int("nBytesRecv", c.numBytesRecv). + Int64("nMsgSent", msgSent).Int64("nBytesSent", bytesSent). + Int64("nMsgRecv", msgRecv).Int64("nBytesRecv", bytesRecv). Msg("conn disconnect") - c.isConnected = false - return c.Close() + + // Atomically mark as disconnected + atomic.StoreInt32(&c.isConnected, 0) + + // Close the connection + conn := c.getConn() + if conn != nil { + c.setConn(nil) + return conn.Close() + } + return nil } func (c *Connection) IsConnected() bool { - return c.isConnected + return atomic.LoadInt32(&c.isConnected) == 1 +} + +// GetStatistics returns current connection statistics atomically +func (c *Connection) GetStatistics() (bytesSent, msgSent, bytesRecv, msgRecv int64) { + return atomic.LoadInt64(&c.numBytesSent), + atomic.LoadInt64(&c.numMsgSent), + atomic.LoadInt64(&c.numBytesRecv), + atomic.LoadInt64(&c.numMsgRecv) } diff --git a/connection_race_test.go b/connection_race_test.go new file mode 100644 index 0000000..2623a82 --- /dev/null +++ b/connection_race_test.go @@ -0,0 +1,353 @@ +package ibapi + +import ( + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +// DummyServer creates a simple TCP echo server for testing +type DummyServer struct { + listener net.Listener + addr string + port int +} + +func NewDummyServer() (*DummyServer, error) { + // Listen on a random available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + + addr := listener.Addr().(*net.TCPAddr) + server := &DummyServer{ + listener: listener, + addr: addr.IP.String(), + port: addr.Port, + } + + return server, nil +} + +func (s *DummyServer) Start() { + go func() { + for { + conn, err := s.listener.Accept() + if err != nil { + return // Server stopped + } + go s.handleConnection(conn) + } + }() +} + +func (s *DummyServer) handleConnection(conn net.Conn) { + defer conn.Close() + + // Simple echo server - read and write back data + buffer := make([]byte, 4096) + for { + n, err := conn.Read(buffer) + if err != nil { + if err != io.EOF { + fmt.Printf("Server read error: %v\n", err) + } + return + } + + // Echo the data back + _, writeErr := conn.Write(buffer[:n]) + if writeErr != nil { + fmt.Printf("Server write error: %v\n", writeErr) + return + } + } +} + +func (s *DummyServer) Stop() { + if s.listener != nil { + s.listener.Close() + } +} + +func (s *DummyServer) Address() (string, int) { + return s.addr, s.port +} + +// TestConnectionRaceConditions demonstrates race conditions in Connection +func TestConnectionRaceConditions(t *testing.T) { + // Start dummy server + server, err := NewDummyServer() + if err != nil { + t.Fatalf("Failed to create dummy server: %v", err) + } + defer server.Stop() + + server.Start() + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + host, port := server.Address() + + // Create connection with a simple wrapper + wrapper := &Wrapper{} + conn := &Connection{ + wrapper: wrapper, + } + + // Connect to dummy server + err = conn.connect(host, port) + if err != nil { + t.Fatalf("Failed to connect to dummy server: %v", err) + } + + // Test data + testData := []byte("Hello, Race Condition Test!") + numOperations := 100 + numGoroutines := 10 + + var wg sync.WaitGroup + + // Start multiple writer goroutines + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + for j := 0; j < numOperations; j++ { + // This will race on numBytesSent and numMsgSent + _, err := conn.Write(testData) + if err != nil { + // Expected during disconnect + return + } + + // Small delay to increase chance of race + time.Sleep(time.Microsecond) + } + }(i) + } + + // Start multiple reader goroutines + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + buffer := make([]byte, 1024) + for j := 0; j < numOperations; j++ { + // This will race on numBytesRecv and numMsgRecv + _, err := conn.Read(buffer) + if err != nil { + // Expected during disconnect + return + } + + // Small delay to increase chance of race + time.Sleep(time.Microsecond) + } + }(i) + } + + // Start disconnect goroutines to trigger races + for i := 0; i < 3; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + // Wait a bit then disconnect + time.Sleep(time.Duration(goroutineID*10) * time.Millisecond) + + // This will race with the statistics updates and isConnected flag + err := conn.disconnect() + if err != nil { + // Multiple disconnects expected to fail + } + }(i) + } + + // Start statistics reader goroutines + for i := 0; i < 5; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + for j := 0; j < numOperations*2; j++ { + // Using atomic-safe methods to read statistics + bytesSent, msgSent, bytesRecv, msgRecv := conn.GetStatistics() + _ = bytesSent + msgSent + bytesRecv + msgRecv + _ = conn.IsConnected() + + time.Sleep(time.Microsecond) + } + }(i) + } + + // Start reset goroutines + for i := 0; i < 2; i++ { + wg.Add(1) + go func(goroutineID int) { + defer wg.Done() + + time.Sleep(time.Duration(goroutineID*20) * time.Millisecond) + + // This will race with ongoing statistics updates + conn.reset() + }(i) + } + + // Wait for all goroutines to complete + wg.Wait() + + // Final disconnect to clean up + conn.disconnect() + + t.Logf("Test completed - check with 'go test -race' to detect race conditions") + bytesSent, msgSent, bytesRecv, msgRecv := conn.GetStatistics() + t.Logf("Final stats - Sent: %d msgs, %d bytes | Recv: %d msgs, %d bytes", + msgSent, bytesSent, msgRecv, bytesRecv) +} + +// TestConnectionConcurrentReconnect tests the reconnection logic under concurrent access +func TestConnectionConcurrentReconnect(t *testing.T) { + // Start dummy server + server, err := NewDummyServer() + if err != nil { + t.Fatalf("Failed to create dummy server: %v", err) + } + defer server.Stop() + + server.Start() + time.Sleep(100 * time.Millisecond) + + host, port := server.Address() + + wrapper := &Wrapper{} + conn := &Connection{ + wrapper: wrapper, + } + + var wg sync.WaitGroup + numGoroutines := 5 + + // Multiple goroutines trying to write (which triggers reconnect on failure) + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // First establish connection + if err := conn.connect(host, port); err != nil { + t.Errorf("Goroutine %d failed to connect: %v", id, err) + return + } + + // Try multiple writes - some may trigger reconnection + for j := 0; j < 50; j++ { + data := []byte(fmt.Sprintf("Message from goroutine %d, iteration %d", id, j)) + + // This can race with other goroutines doing connect/disconnect/reconnect + _, err := conn.Write(data) + if err != nil { + // Expected during concurrent access + } + + time.Sleep(time.Millisecond) + } + }(i) + } + + // Goroutine that disconnects periodically + wg.Add(1) + go func() { + defer wg.Done() + + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + conn.disconnect() // Race with Write operations + } + }() + + wg.Wait() + + t.Logf("Concurrent reconnect test completed") +} + +// TestConnectionStatisticsRace focuses specifically on the statistics counter races +func TestConnectionStatisticsRace(t *testing.T) { + wrapper := &Wrapper{} + conn := &Connection{ + wrapper: wrapper, + } + + // Don't actually connect - just test the statistics + // Simulate concurrent access to the counters + + var wg sync.WaitGroup + iterations := 1000 + + // Goroutines incrementing send stats using atomic operations + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + atomic.AddInt64(&conn.numBytesSent, 100) // Now atomic! + atomic.AddInt64(&conn.numMsgSent, 1) // Now atomic! + } + }() + } + + // Goroutines incrementing recv stats using atomic operations + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + atomic.AddInt64(&conn.numBytesRecv, 50) // Now atomic! + atomic.AddInt64(&conn.numMsgRecv, 1) // Now atomic! + } + }() + } + + // Goroutines reading stats using atomic operations + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations*2; j++ { + // Reading using atomic-safe methods - no more race condition! + bytesSent, msgSent, bytesRecv, msgRecv := conn.GetStatistics() + _ = bytesSent + bytesRecv + msgSent + msgRecv + } + }() + } + + // Goroutines resetting stats + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + time.Sleep(time.Millisecond) + conn.reset() // Race condition with increments! + } + }() + } + + wg.Wait() + + bytesSent, msgSent, bytesRecv, msgRecv := conn.GetStatistics() + t.Logf("Final statistics after race: Sent=%d/%d, Recv=%d/%d", + msgSent, bytesSent, msgRecv, bytesRecv) + + // Note: The final values will be unpredictable due to race conditions + // Expected: 10 goroutines * 1000 iterations = 10,000 messages + // Actual: Will be less due to lost updates from race conditions +}