diff --git a/README.md b/README.md index eacd4f3..f14af33 100644 --- a/README.md +++ b/README.md @@ -94,18 +94,13 @@ Currently supported query arguments are: | Query Argument | Description | Values | |----------------|-------------|--------| -| use_prepared_statements | whether to use client-side query interpolation or server-side argument binding | 1 = (default) use server-side bindings | -| | | 0 = user client side interpolation **(LESS SECURE)** | -| connection_load_balance | whether to enable connection load balancing on the client side | 0 = (default) disable load balancing | -| | | 1 = enable load balancing | -| tlsmode | the ssl/tls policy for this connection | 'none' (default) = don't use SSL/TLS for this connection | -| | | 'server' = server must support SSL/TLS, but skip verification **(INSECURE!)** | -| | | 'server-strict' = server must support SSL/TLS | -| | | {customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below) | -| backup_server_node | a list of backup hosts for the client to try to connect if the primary host is unreachable | a comma-seperated list of backup host-port pairs. E.g.
'host1:port1,host2:port2,host3:port3' | +| use_prepared_statements | Whether to use client-side query interpolation or server-side argument binding. | 1 = (default) use server-side bindings
0 = user client side interpolation **(LESS SECURE)** | +| connection_load_balance | Whether to enable connection load balancing on the client side. | 0 = (default) disable load balancing
1 = enable load balancing | +| tlsmode | The ssl/tls policy for this connection. |
  • 'none' (default) = don't use SSL/TLS for this connection
  • 'server' = server must support SSL/TLS, but skip verification **(INSECURE!)**
  • 'server-strict' = server must support SSL/TLS
  • {customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below)
  • | +| backup_server_node | A list of backup hosts for the client to try to connect if the primary host is unreachable. | a comma-seperated list of backup host-port pairs. E.g.
    'host1:port1,host2:port2,host3:port3' | | client_label | Sets a label for the connection on the server. This value appears in the `client_label` column of the SESSIONS system table. | (default) vertica-sql-go-{version}-{pid}-{timestamp} | -| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on | -| | | 0 = off +| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on
    0 = off| +| oauth_access_token | To authenticate via OAuth, provide an OAuth Access Token that authorizes a user to the database. | unspecified by default, if specified then *user* is optional | To ping the server and validate a connection (as the connection isn't necessarily created at that moment), simply call the *PingContext()* method. diff --git a/common/types.go b/common/types.go index a2d869a..f0ddaa7 100644 --- a/common/types.go +++ b/common/types.go @@ -60,6 +60,7 @@ const ( AuthenticationOK int32 = 0 AuthenticationCleartextPassword int32 = 3 AuthenticationMD5Password int32 = 5 + AuthenticationOAuth int32 = 12 AuthenticationSHA512Password int32 = 66048 ) diff --git a/connection.go b/connection.go index db77e80..0e64725 100644 --- a/connection.go +++ b/connection.go @@ -109,6 +109,7 @@ type connection struct { scratch [512]byte sessionID string autocommit string + oauthaccesstoken string serverTZOffset string dead bool // used if a ROLLBACK severity error is encountered sessMutex sync.Mutex @@ -236,6 +237,9 @@ func newConnection(connString string) (*connection, error) { result.autocommit = "off" } + // Read OAuth access token flag. + result.oauthaccesstoken = result.connURL.Query().Get("oauth_access_token") + // Read connection load balance flag. loadBalanceFlag := result.connURL.Query().Get("connection_load_balance") @@ -415,14 +419,14 @@ func min(a, b int) int { func (v *connection) handshake() error { - if v.connURL.User == nil { - return fmt.Errorf("connection string must include a user name") + if v.connURL.User == nil && len(v.oauthaccesstoken) == 0 { + return fmt.Errorf("connection string must include a user name or oauth_access_token") } userName := v.connURL.User.Username() - if len(userName) == 0 { - return fmt.Errorf("connection string must have a non-empty user name") + if len(userName) == 0 && len(v.oauthaccesstoken) == 0 { + return fmt.Errorf("connection string must have a non-empty user name or oauth_access_token") } dbName := "" @@ -431,14 +435,15 @@ func (v *connection) handshake() error { } msg := &msgs.FEStartupMsg{ - ProtocolVersion: protocolVersion, - DriverName: driverName, - DriverVersion: driverVersion, - Username: userName, - Database: dbName, - SessionID: v.sessionID, - ClientPID: v.clientPID, - Autocommit: v.autocommit, + ProtocolVersion: protocolVersion, + DriverName: driverName, + DriverVersion: driverVersion, + Username: userName, + Database: dbName, + SessionID: v.sessionID, + ClientPID: v.clientPID, + Autocommit: v.autocommit, + OAuthAccessToken: v.oauthaccesstoken, } if err := v.sendMessage(msg); err != nil { @@ -526,6 +531,8 @@ func (v *connection) defaultMessageHandler(bMsg msgs.BackEndMsg) (bool, error) { err = v.authSendMD5Password(msg.ExtraAuthData) case common.AuthenticationSHA512Password: err = v.authSendSHA512Password(msg.ExtraAuthData) + case common.AuthenticationOAuth: + err = v.authSendOAuthAccessToken() default: handled = false err = fmt.Errorf("unsupported authentication scheme: %d", msg.Response) @@ -715,6 +722,11 @@ func (v *connection) authSendSHA512Password(extraAuthData []byte) error { return v.sendMessage(msg) } +func (v *connection) authSendOAuthAccessToken() error { + msg := &msgs.FEPasswordMsg{PasswordData: v.oauthaccesstoken} + return v.sendMessage(msg) +} + func (v *connection) sync() error { err := v.sendMessage(&msgs.FESyncMsg{}) diff --git a/driver.go b/driver.go index 36057e8..83e1544 100644 --- a/driver.go +++ b/driver.go @@ -47,7 +47,7 @@ type Driver struct{} const ( driverName string = "vertica-sql-go" driverVersion string = "1.3.0" - protocolVersion uint32 = 0x00030009 + protocolVersion uint32 = 0x0003000C ) var driverLogger = logger.New("driver") diff --git a/msgs/festartupmsg.go b/msgs/festartupmsg.go index b6965f7..418082e 100644 --- a/msgs/festartupmsg.go +++ b/msgs/festartupmsg.go @@ -41,16 +41,17 @@ import ( // FEStartupMsg docs type FEStartupMsg struct { - ProtocolVersion uint32 - DriverName string - DriverVersion string - Username string - Database string - SessionID string - ClientPID int - ClientOS string - OSUsername string - Autocommit string + ProtocolVersion uint32 + DriverName string + DriverVersion string + Username string + Database string + SessionID string + ClientPID int + ClientOS string + OSUsername string + Autocommit string + OAuthAccessToken string } // Flatten docs @@ -78,14 +79,17 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) { buf.appendUint32(m.ProtocolVersion) buf.appendBytes([]byte{0}) - if len(m.Username) > 0 { - buf.appendLabeledString("user", m.Username) - } + buf.appendLabeledString("user", m.Username) if len(m.Database) > 0 { buf.appendLabeledString("database", m.Database) } + if len(m.OAuthAccessToken) > 0 { + buf.appendLabeledString("oauth_access_token", m.OAuthAccessToken) + buf.appendLabeledString("auth_category", "OAuth") + } + buf.appendLabeledString("client_type", m.DriverName) buf.appendLabeledString("client_version", m.DriverVersion) buf.appendLabeledString("client_label", m.SessionID) @@ -100,7 +104,7 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) { func (m *FEStartupMsg) String() string { return fmt.Sprintf( - "Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s'", + "Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s', OAuthAccessToken=", m.ProtocolVersion, m.DriverName, m.DriverVersion, @@ -110,5 +114,6 @@ func (m *FEStartupMsg) String() string { m.ClientPID, m.ClientOS, m.OSUsername, - m.Autocommit) + m.Autocommit, + len(m.OAuthAccessToken)) }