Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth 2.0 Authentication (basic) #160

Merged
merged 5 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,13 @@ Currently supported query arguments are:

| Query Argument | Description | Values |
|----------------|-------------|--------|
| use_prepared_statements | whether to use client-side query interpolation or server-side argument binding | 1 = (default) use server-side bindings |
| | | 0 = user client side interpolation **(LESS SECURE)** |
| connection_load_balance | whether to enable connection load balancing on the client side | 0 = (default) disable load balancing |
| | | 1 = enable load balancing |
| tlsmode | the ssl/tls policy for this connection | 'none' (default) = don't use SSL/TLS for this connection |
| | | 'server' = server must support SSL/TLS, but skip verification **(INSECURE!)** |
| | | 'server-strict' = server must support SSL/TLS |
| | | {customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below) |
| backup_server_node | a list of backup hosts for the client to try to connect if the primary host is unreachable | a comma-seperated list of backup host-port pairs. E.g.<br> 'host1:port1,host2:port2,host3:port3' |
| use_prepared_statements | Whether to use client-side query interpolation or server-side argument binding. | 1 = (default) use server-side bindings <br>0 = user client side interpolation **(LESS SECURE)** |
| connection_load_balance | Whether to enable connection load balancing on the client side. | 0 = (default) disable load balancing <br>1 = enable load balancing |
| tlsmode | The ssl/tls policy for this connection. | <li>'none' (default) = don't use SSL/TLS for this connection</li><li>'server' = server must support SSL/TLS, but skip verification **(INSECURE!)**</li><li>'server-strict' = server must support SSL/TLS</li><li>{customName} = use custom registered `tls.Config` (see "Using custom TLS config" section below)</li> |
| backup_server_node | A list of backup hosts for the client to try to connect if the primary host is unreachable. | a comma-seperated list of backup host-port pairs. E.g.<br> 'host1:port1,host2:port2,host3:port3' |
| client_label | Sets a label for the connection on the server. This value appears in the `client_label` column of the SESSIONS system table. | (default) vertica-sql-go-{version}-{pid}-{timestamp} |
| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on |
| | | 0 = off
| autocommit | Controls whether the connection automatically commits transactions. | 1 = (default) on <br>0 = off|
| oauth_access_token | To authenticate via OAuth, provide an OAuth Access Token that authorizes a user to the database. | unspecified by default, if specified then *user* is optional |

To ping the server and validate a connection (as the connection isn't necessarily created at that moment), simply call the *PingContext()* method.

Expand Down
1 change: 1 addition & 0 deletions common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ const (
AuthenticationOK int32 = 0
AuthenticationCleartextPassword int32 = 3
AuthenticationMD5Password int32 = 5
AuthenticationOAuth int32 = 12
AuthenticationSHA512Password int32 = 66048
)

Expand Down
36 changes: 24 additions & 12 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ type connection struct {
scratch [512]byte
sessionID string
autocommit string
oauthaccesstoken string
serverTZOffset string
dead bool // used if a ROLLBACK severity error is encountered
sessMutex sync.Mutex
Expand Down Expand Up @@ -236,6 +237,9 @@ func newConnection(connString string) (*connection, error) {
result.autocommit = "off"
}

// Read OAuth access token flag.
result.oauthaccesstoken = result.connURL.Query().Get("oauth_access_token")

// Read connection load balance flag.
loadBalanceFlag := result.connURL.Query().Get("connection_load_balance")

Expand Down Expand Up @@ -415,14 +419,14 @@ func min(a, b int) int {

func (v *connection) handshake() error {

if v.connURL.User == nil {
return fmt.Errorf("connection string must include a user name")
if v.connURL.User == nil && len(v.oauthaccesstoken) == 0 {
return fmt.Errorf("connection string must include a user name or oauth_access_token")
}

userName := v.connURL.User.Username()

if len(userName) == 0 {
return fmt.Errorf("connection string must have a non-empty user name")
if len(userName) == 0 && len(v.oauthaccesstoken) == 0 {
return fmt.Errorf("connection string must have a non-empty user name or oauth_access_token")
}

dbName := ""
Expand All @@ -431,14 +435,15 @@ func (v *connection) handshake() error {
}

msg := &msgs.FEStartupMsg{
ProtocolVersion: protocolVersion,
DriverName: driverName,
DriverVersion: driverVersion,
Username: userName,
Database: dbName,
SessionID: v.sessionID,
ClientPID: v.clientPID,
Autocommit: v.autocommit,
ProtocolVersion: protocolVersion,
DriverName: driverName,
DriverVersion: driverVersion,
Username: userName,
Database: dbName,
SessionID: v.sessionID,
ClientPID: v.clientPID,
Autocommit: v.autocommit,
OAuthAccessToken: v.oauthaccesstoken,
}

if err := v.sendMessage(msg); err != nil {
Expand Down Expand Up @@ -526,6 +531,8 @@ func (v *connection) defaultMessageHandler(bMsg msgs.BackEndMsg) (bool, error) {
err = v.authSendMD5Password(msg.ExtraAuthData)
case common.AuthenticationSHA512Password:
err = v.authSendSHA512Password(msg.ExtraAuthData)
case common.AuthenticationOAuth:
err = v.authSendOAuthAccessToken()
default:
handled = false
err = fmt.Errorf("unsupported authentication scheme: %d", msg.Response)
Expand Down Expand Up @@ -715,6 +722,11 @@ func (v *connection) authSendSHA512Password(extraAuthData []byte) error {
return v.sendMessage(msg)
}

func (v *connection) authSendOAuthAccessToken() error {
msg := &msgs.FEPasswordMsg{PasswordData: v.oauthaccesstoken}
return v.sendMessage(msg)
}

func (v *connection) sync() error {
err := v.sendMessage(&msgs.FESyncMsg{})

Expand Down
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Driver struct{}
const (
driverName string = "vertica-sql-go"
driverVersion string = "1.3.0"
protocolVersion uint32 = 0x00030009
protocolVersion uint32 = 0x0003000C
)

var driverLogger = logger.New("driver")
Expand Down
35 changes: 20 additions & 15 deletions msgs/festartupmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,17 @@ import (

// FEStartupMsg docs
type FEStartupMsg struct {
ProtocolVersion uint32
DriverName string
DriverVersion string
Username string
Database string
SessionID string
ClientPID int
ClientOS string
OSUsername string
Autocommit string
ProtocolVersion uint32
DriverName string
DriverVersion string
Username string
Database string
SessionID string
ClientPID int
ClientOS string
OSUsername string
Autocommit string
OAuthAccessToken string
}

// Flatten docs
Expand Down Expand Up @@ -78,14 +79,17 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) {
buf.appendUint32(m.ProtocolVersion)
buf.appendBytes([]byte{0})

if len(m.Username) > 0 {
buf.appendLabeledString("user", m.Username)
}
buf.appendLabeledString("user", m.Username)

if len(m.Database) > 0 {
buf.appendLabeledString("database", m.Database)
}

if len(m.OAuthAccessToken) > 0 {
buf.appendLabeledString("oauth_access_token", m.OAuthAccessToken)
buf.appendLabeledString("auth_category", "OAuth")
}

buf.appendLabeledString("client_type", m.DriverName)
buf.appendLabeledString("client_version", m.DriverVersion)
buf.appendLabeledString("client_label", m.SessionID)
Expand All @@ -100,7 +104,7 @@ func (m *FEStartupMsg) Flatten() ([]byte, byte) {

func (m *FEStartupMsg) String() string {
return fmt.Sprintf(
"Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s'",
"Startup (packet): ProtocolVersion:%08X, DriverName='%s', DriverVersion='%s', UserName='%s', Database='%s', SessionID='%s', ClientPID=%d, ClientOS='%s', ClientOSUserName='%s', Autocommit='%s', OAuthAccessToken=<length:%d>",
m.ProtocolVersion,
m.DriverName,
m.DriverVersion,
Expand All @@ -110,5 +114,6 @@ func (m *FEStartupMsg) String() string {
m.ClientPID,
m.ClientOS,
m.OSUsername,
m.Autocommit)
m.Autocommit,
len(m.OAuthAccessToken))
}