Skip to content

Commit

Permalink
Add token (JWT) authentication support
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick authored and wendigo committed May 3, 2024
1 parent 47eee6e commit aa08ec4
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Trino, and receive the resulting data.

* Native Go implementation
* Connections over HTTP or HTTPS
* HTTP Basic, and Kerberos authentication
* HTTP Basic, Kerberos, and JSON web token (JWT) authentication
* Per-query user information for access control
* Support custom HTTP client (tunable conn pools, timeouts, TLS)
* Supports conversion from Trino to native Go data types
Expand Down Expand Up @@ -60,7 +60,7 @@ db, err := sql.Open("trino", dsn)

### Authentication

Both HTTP Basic, and Kerberos authentication are supported.
Both HTTP Basic, Kerberos, and JWT authentication are supported.

#### HTTP Basic authentication

Expand All @@ -81,6 +81,17 @@ Please refer to the [Coordinator Kerberos
Authentication](https://trino.io/docs/current/security/server.html) for
server-side configuration.

#### JSON web token authentication

This driver supports JWT authentication by setting up the `AccessToken` field
in the
[Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config)
struct.

Please refer to the [Coordinator JWT
Authentication](https://trino.io/docs/current/security/jwt.html) for
server-side configuration.

#### System access control and per-query user information

It's possible to pass user information to Trino, different from the principal
Expand Down
13 changes: 13 additions & 0 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ const (
trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare`
trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare`

authorizationHeader = "Authorization"

kerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
kerberosPrincipalConfig = "KerberosPrincipal"
Expand All @@ -137,6 +139,7 @@ const (
kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName"
sslCertPathConfig = "SSLCertPath"
sslCertConfig = "SSLCert"
accessTokenConfig = "accessToken"
)

var (
Expand Down Expand Up @@ -175,6 +178,7 @@ type Config struct {
KerberosConfigPath string // The krb5 config path (optional)
SSLCertPath string // The SSL cert path for TLS verification (optional)
SSLCert string // The SSL cert for TLS verification (optional)
AccessToken string // An access token (JWT) for authentication (optional)
}

// FormatDSN returns a DSN string from the configuration.
Expand Down Expand Up @@ -256,6 +260,7 @@ func (c *Config) FormatDSN() (string, error) {
"session_properties": strings.Join(sessionkv, ","),
"extra_credentials": strings.Join(credkv, ","),
"custom_client": c.CustomClientName,
accessTokenConfig: c.AccessToken,
} {
if v != "" {
query[k] = []string{v}
Expand Down Expand Up @@ -371,6 +376,7 @@ func newConn(dsn string) (*Conn, error) {
trinoSchemaHeader: query.Get("schema"),
trinoSessionHeader: query.Get("session_properties"),
trinoExtraCredentialHeader: query.Get("extra_credentials"),
authorizationHeader: getAuthorization(query.Get(accessTokenConfig)),
} {
if v != "" {
c.httpHeaders.Add(k, v)
Expand All @@ -380,6 +386,13 @@ func newConn(dsn string) (*Conn, error) {
return c, nil
}

func getAuthorization(token string) string {
if token == "" {
return ""
}
return fmt.Sprintf("Bearer %s", token)
}

// registry for custom http clients
var customClientRegistry = struct {
sync.RWMutex
Expand Down
34 changes: 34 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,20 @@ func TestInvalidKerberosConfig(t *testing.T) {
assert.Error(t, err, "dsn generated from invalid secure url, since kerberos enabled must has SSL enabled")
}

func TestAccessTokenConfig(t *testing.T) {
c := &Config{
ServerURI: "https://foobar@localhost:8090",
AccessToken: "token",
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8090?accessToken=token&source=trino-go-client"

assert.Equal(t, want, dsn)
}

func TestConfigWithMalformedURL(t *testing.T) {
_, err := (&Config{ServerURI: ":("}).FormatDSN()
assert.Error(t, err, "dsn generated from malformed url")
Expand Down Expand Up @@ -270,6 +284,26 @@ func TestAuthFailure(t *testing.T) {
assert.NoError(t, db.Close())
}

func TestTokenAuth(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer token" {
w.WriteHeader(http.StatusUnauthorized)
} else {
w.WriteHeader(http.StatusOK)
}
}))

t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL+"?accessToken=token")
require.NoError(t, err)

_, err = db.Query("SELECT 1")
require.Error(t, err, "trino: EOF")

assert.NoError(t, db.Close())
}

func TestQueryForUsername(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
Expand Down

0 comments on commit aa08ec4

Please sign in to comment.