diff --git a/config/configuration.go b/config/configuration.go index 1425cde7b..75408c7a8 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -534,7 +534,7 @@ const ( // This allows for alternate socket hosts for connecting to a session for failover. // (i.e.) SocketConnectHost1, SocketConnectHost2... must be consecutive and have a matching SocketConnectPort. // - // Required: Yes for initiators + // Required: Yes for initiators on socket connections // // Default: None // @@ -547,7 +547,7 @@ const ( // This allows for alternate socket ports for connecting to a session for failover. // (i.e.) SocketConnectPort1, SocketConnectPort2... must be consecutive and have a matching SocketConnectHost. // - // Required: Yes for initiators + // Required: Yes for initiators on socket connections // // Default: None // @@ -624,6 +624,27 @@ const ( // Valid Values: // - Any string ProxyPassword string = "ProxyPassword" + + // WebsocketLocation sets the websocket endpoint to attempt to connect to. + // Setting this would override any SocketConnectHost and SocketConnectPort settings and connect using websocket + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - A websocket endpoint - eg. wss://example.com/ws + WebsocketLocation string = "WebsocketLocation" + + // WebsocketOrigin sets the websocket origin to attempt to connect from. + // + // Required: No + // + // Default: N/A + // + // Valid Values: + // - url - eg. http://localhost/ + WebsocketOrigin string = "WebsocketOrigin" ) const ( diff --git a/dialer.go b/dialer.go index de4419ff8..7168e8f7a 100644 --- a/dialer.go +++ b/dialer.go @@ -16,23 +16,106 @@ package quickfix import ( + "context" + "crypto/tls" "fmt" "net" + "strings" "time" "golang.org/x/net/proxy" + "golang.org/x/net/websocket" "github.com/quickfixgo/quickfix/config" ) -func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) { +type Dialer interface { + Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (net.Conn, error) +} + +type TcpDialer struct { + ctxDialer proxy.ContextDialer +} + +func (d *TcpDialer) Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (conn net.Conn, err error) { + address := session.SocketConnectAddress[attempt%len(session.SocketConnectAddress)] + session.log.OnEventf("Connecting to: %v", address) + + conn, err = d.ctxDialer.DialContext(ctx, "tcp", address) + + if err != nil { + return + } else if tlsConfig != nil { + // Unless InsecureSkipVerify is true, server name config is required for TLS + // to verify the received certificate + if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 { + serverName := address + if c := strings.LastIndex(serverName, ":"); c > 0 { + serverName = serverName[:c] + } + tlsConfig.ServerName = serverName + } + tlsConn := tls.Client(conn, tlsConfig) + if err = tlsConn.Handshake(); err != nil { + + session.log.OnEventf("Failed handshake: %v", err) + return + } + conn = tlsConn + } + + return +} + +type WebsocketDialer struct { + wsConfig *websocket.Config +} + +func (d *WebsocketDialer) Dial(ctx context.Context, session *session, _ int, tlsConfig *tls.Config) (conn net.Conn, err error) { + session.log.OnEventf("Connecting to: %v", d.wsConfig.Location) + + d.wsConfig.TlsConfig = tlsConfig + conn, err = d.wsConfig.DialContext(ctx) + return +} + +func loadDialerConfig(settings *SessionSettings) (dialer Dialer, err error) { + + if settings.HasSetting(config.WebsocketLocation) { + var location string + location, err = settings.Setting(config.WebsocketLocation) + if err != nil { + return nil, err + } + + var origin string + origin, err = settings.Setting(config.WebsocketOrigin) + if err != nil { + return nil, err + } + + var wsConfig *websocket.Config + wsConfig, err = websocket.NewConfig(location, origin) + if err != nil { + return nil, err + } + + dialer = &WebsocketDialer{ + wsConfig: wsConfig, + } + return + } + stdDialer := &net.Dialer{} + dialer = &TcpDialer{ + ctxDialer: stdDialer, + } if settings.HasSetting(config.SocketTimeout) { timeout, err := settings.DurationSetting(config.SocketTimeout) if err != nil { timeoutInt, err := settings.IntSetting(config.SocketTimeout) if err != nil { - return stdDialer, err + return nil, err } stdDialer.Timeout = time.Duration(timeoutInt) * time.Second @@ -40,7 +123,6 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, er stdDialer.Timeout = timeout } } - dialer = stdDialer if !settings.HasSetting(config.ProxyType) { return @@ -81,7 +163,9 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, er } if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok { - dialer = contextDialer + dialer = &TcpDialer{ + ctxDialer: contextDialer, + } } else { err = fmt.Errorf("proxy does not support context dialer") return diff --git a/dialer_test.go b/dialer_test.go index 5f8599c1e..8c2d1899e 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -42,7 +42,7 @@ func (s *DialerTestSuite) TestLoadDialerNoSettings() { dialer, err := loadDialerConfig(s.settings.GlobalSettings()) s.Require().Nil(err) - stdDialer, ok := dialer.(*net.Dialer) + stdDialer, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer) s.Require().True(ok) s.Require().NotNil(stdDialer) s.Zero(stdDialer.Timeout) @@ -53,7 +53,7 @@ func (s *DialerTestSuite) TestLoadDialerWithTimeout() { dialer, err := loadDialerConfig(s.settings.GlobalSettings()) s.Require().Nil(err) - stdDialer, ok := dialer.(*net.Dialer) + stdDialer, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer) s.Require().True(ok) s.Require().NotNil(stdDialer) s.EqualValues(10*time.Second, stdDialer.Timeout) @@ -73,7 +73,7 @@ func (s *DialerTestSuite) TestLoadDialerSocksProxy() { s.Require().Nil(err) s.Require().NotNil(dialer) - _, ok := dialer.(*net.Dialer) + _, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer) s.Require().False(ok) } diff --git a/initiator.go b/initiator.go index 18451477e..358db9079 100644 --- a/initiator.go +++ b/initiator.go @@ -19,11 +19,8 @@ import ( "bufio" "context" "crypto/tls" - "strings" "sync" "time" - - "golang.org/x/net/proxy" ) // Initiator initiates connections and processes messages for all sessions. @@ -48,12 +45,12 @@ func (i *Initiator) Start() (err error) { // TODO: move into session factory. var tlsConfig *tls.Config if tlsConfig, err = loadTLSConfig(settings); err != nil { - return + return err } - var dialer proxy.ContextDialer + var dialer Dialer if dialer, err = loadDialerConfig(settings); err != nil { - return + return err } i.wg.Add(1) @@ -143,7 +140,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo return true } -func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) { +func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer Dialer) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -180,29 +177,13 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di var msgIn chan fixIn var msgOut chan []byte - address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] - session.log.OnEventf("Connecting to: %v", address) - - netConn, err := dialer.DialContext(ctx, "tcp", address) + netConn, err := dialer.Dial(ctx, session, connectionAttempt, tlsConfig) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect - } else if tlsConfig != nil { - // Unless InsecureSkipVerify is true, server name config is required for TLS - // to verify the received certificate - if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 { - serverName := address - if c := strings.LastIndex(serverName, ":"); c > 0 { - serverName = serverName[:c] - } - tlsConfig.ServerName = serverName - } - tlsConn := tls.Client(netConn, tlsConfig) - if err = tlsConn.Handshake(); err != nil { - session.log.OnEventf("Failed handshake: %v", err) - goto reconnect - } - netConn = tlsConn + } else { + address := netConn.RemoteAddr().String() + session.log.OnEventf("connected to remote address: %v", address) } msgIn = make(chan fixIn) diff --git a/session_factory.go b/session_factory.go index a2291cad7..af046edb3 100644 --- a/session_factory.go +++ b/session_factory.go @@ -523,6 +523,13 @@ func (f sessionFactory) buildInitiatorSettings(session *session, settings *Sessi func (f sessionFactory) configureSocketConnectAddress(session *session, settings *SessionSettings) (err error) { session.SocketConnectAddress = []string{} + if !settings.HasSetting(config.SocketConnectHost) { + if !settings.HasSetting(config.WebsocketLocation) { + err = errors.New("SocketConnectHost must be specified if WebsocketLocation is not specified") + } + return + } + var socketConnectHost, socketConnectPort string for i := 0; ; {