diff --git a/README.md b/README.md
index eacd4f3..f14af33 100644
--- a/README.md
+++ b/README.md
@@ -94,18 +94,13 @@ Currently supported query arguments are:
| Query Argument | Description | Values |
|----------------|-------------|--------|
-| use_prepared_statements | whether to use client-side query interpolation or server-side argument binding | 1 = (default) use server-side bindings |
-| | | 0 = user client side interpolation **(LESS SECURE)** |
-| connection_load_balance | whether to enable connection load balancing on the client side | 0 = (default) disable load balancing |
-| | | 1 = enable load balancing |
-| tlsmode | the ssl/tls policy for this connection | 'none' (default) = don't use SSL/TLS for this connection |
-| | | 'server' = server must support SSL/TLS, but skip verification **(INSECURE!)** |
-| | | 'server-strict' = server must support SSL/TLS |
-| | | {customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below) |
-| backup_server_node | a list of backup hosts for the client to try to connect if the primary host is unreachable | a comma-seperated list of backup host-port pairs. E.g.
'host1:port1,host2:port2,host3:port3' |
+| use_prepared_statements | Whether to use client-side query interpolation or server-side argument binding. | 1 = (default) use server-side bindings
0 = user client side interpolation **(LESS SECURE)** |
+| connection_load_balance | Whether to enable connection load balancing on the client side. | 0 = (default) disable load balancing
1 = enable load balancing |
+| tlsmode | The ssl/tls policy for this connection. |
'none' (default) = don't use SSL/TLS for this connection'server' = server must support SSL/TLS, but skip verification **(INSECURE!)**'server-strict' = server must support SSL/TLS{customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below) |
+| backup_server_node | A list of backup hosts for the client to try to connect if the primary host is unreachable. | a comma-seperated list of backup host-port pairs. E.g.
'host1:port1,host2:port2,host3:port3' |
| client_label | Sets a label for the connection on the server. This value appears in the `client_label` column of the SESSIONS system table. | (default) vertica-sql-go-{version}-{pid}-{timestamp} |
-| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on |
-| | | 0 = off
+| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on
0 = off|
+| oauth_access_token | To authenticate via OAuth, provide an OAuth Access Token that authorizes a user to the database. | unspecified by default, if specified then *user* is optional |
To ping the server and validate a connection (as the connection isn't necessarily created at that moment), simply call the *PingContext()* method.
diff --git a/common/types.go b/common/types.go
index a2d869a..f0ddaa7 100644
--- a/common/types.go
+++ b/common/types.go
@@ -60,6 +60,7 @@ const (
AuthenticationOK int32 = 0
AuthenticationCleartextPassword int32 = 3
AuthenticationMD5Password int32 = 5
+ AuthenticationOAuth int32 = 12
AuthenticationSHA512Password int32 = 66048
)
diff --git a/connection.go b/connection.go
index db77e80..0e64725 100644
--- a/connection.go
+++ b/connection.go
@@ -109,6 +109,7 @@ type connection struct {
scratch [512]byte
sessionID string
autocommit string
+ oauthaccesstoken string
serverTZOffset string
dead bool // used if a ROLLBACK severity error is encountered
sessMutex sync.Mutex
@@ -236,6 +237,9 @@ func newConnection(connString string) (*connection, error) {
result.autocommit = "off"
}
+ // Read OAuth access token flag.
+ result.oauthaccesstoken = result.connURL.Query().Get("oauth_access_token")
+
// Read connection load balance flag.
loadBalanceFlag := result.connURL.Query().Get("connection_load_balance")
@@ -415,14 +419,14 @@ func min(a, b int) int {
func (v *connection) handshake() error {
- if v.connURL.User == nil {
- return fmt.Errorf("connection string must include a user name")
+ if v.connURL.User == nil && len(v.oauthaccesstoken) == 0 {
+ return fmt.Errorf("connection string must include a user name or oauth_access_token")
}
userName := v.connURL.User.Username()
- if len(userName) == 0 {
- return fmt.Errorf("connection string must have a non-empty user name")
+ if len(userName) == 0 && len(v.oauthaccesstoken) == 0 {
+ return fmt.Errorf("connection string must have a non-empty user name or oauth_access_token")
}
dbName := ""
@@ -431,14 +435,15 @@ func (v *connection) handshake() error {
}
msg := &msgs.FEStartupMsg{
- ProtocolVersion: protocolVersion,
- DriverName: driverName,
- DriverVersion: driverVersion,
- Username: userName,
- Database: dbName,
- SessionID: v.sessionID,
- ClientPID: v.clientPID,
- Autocommit: v.autocommit,
+ ProtocolVersion: protocolVersion,
+ DriverName: driverName,
+ DriverVersion: driverVersion,
+ Username: userName,
+ Database: dbName,
+ SessionID: v.sessionID,
+ ClientPID: v.clientPID,
+ Autocommit: v.autocommit,
+ OAuthAccessToken: v.oauthaccesstoken,
}
if err := v.sendMessage(msg); err != nil {
@@ -526,6 +531,8 @@ func (v *connection) defaultMessageHandler(bMsg msgs.BackEndMsg) (bool, error) {
err = v.authSendMD5Password(msg.ExtraAuthData)
case common.AuthenticationSHA512Password:
err = v.authSendSHA512Password(msg.ExtraAuthData)
+ case common.AuthenticationOAuth:
+ err = v.authSendOAuthAccessToken()
default:
handled = false
err = fmt.Errorf("unsupported authentication scheme: %d", msg.Response)
@@ -715,6 +722,11 @@ func (v *connection) authSendSHA512Password(extraAuthData []byte) error {
return v.sendMessage(msg)
}
+func (v *connection) authSendOAuthAccessToken() error {
+ msg := &msgs.FEPasswordMsg{PasswordData: v.oauthaccesstoken}
+ return v.sendMessage(msg)
+}
+
func (v *connection) sync() error {
err := v.sendMessage(&msgs.FESyncMsg{})
diff --git a/driver.go b/driver.go
index 36057e8..83e1544 100644
--- a/driver.go
+++ b/driver.go
@@ -47,7 +47,7 @@ type Driver struct{}
const (
driverName string = "vertica-sql-go"
driverVersion string = "1.3.0"
- protocolVersion uint32 = 0x00030009
+ protocolVersion uint32 = 0x0003000C
)
var driverLogger = logger.New("driver")
diff --git a/msgs/festartupmsg.go b/msgs/festartupmsg.go
index b6965f7..418082e 100644
--- a/msgs/festartupmsg.go
+++ b/msgs/festartupmsg.go
@@ -41,16 +41,17 @@ import (
// FEStartupMsg docs
type FEStartupMsg struct {
- ProtocolVersion uint32
- DriverName string
- DriverVersion string
- Username string
- Database string
- SessionID string
- ClientPID int
- ClientOS string
- OSUsername string
- Autocommit string
+ ProtocolVersion uint32
+ DriverName string
+ DriverVersion string
+ Username string
+ Database string
+ SessionID string
+ ClientPID int
+ ClientOS string
+ OSUsername string
+ Autocommit string
+ OAuthAccessToken string
}
// Flatten docs
@@ -78,14 +79,17 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) {
buf.appendUint32(m.ProtocolVersion)
buf.appendBytes([]byte{0})
- if len(m.Username) > 0 {
- buf.appendLabeledString("user", m.Username)
- }
+ buf.appendLabeledString("user", m.Username)
if len(m.Database) > 0 {
buf.appendLabeledString("database", m.Database)
}
+ if len(m.OAuthAccessToken) > 0 {
+ buf.appendLabeledString("oauth_access_token", m.OAuthAccessToken)
+ buf.appendLabeledString("auth_category", "OAuth")
+ }
+
buf.appendLabeledString("client_type", m.DriverName)
buf.appendLabeledString("client_version", m.DriverVersion)
buf.appendLabeledString("client_label", m.SessionID)
@@ -100,7 +104,7 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) {
func (m *FEStartupMsg) String() string {
return fmt.Sprintf(
- "Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s'",
+ "Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s', OAuthAccessToken=",
m.ProtocolVersion,
m.DriverName,
m.DriverVersion,
@@ -110,5 +114,6 @@ func (m *FEStartupMsg) String() string {
m.ClientPID,
m.ClientOS,
m.OSUsername,
- m.Autocommit)
+ m.Autocommit,
+ len(m.OAuthAccessToken))
}