Skip to content

Commit

Permalink
Implement Azure AD token provider.
Browse files Browse the repository at this point in the history
  • Loading branch information
wrosenuance committed Dec 17, 2020
1 parent a5f79d7 commit 1d891d2
Show file tree
Hide file tree
Showing 7 changed files with 910 additions and 3 deletions.
152 changes: 152 additions & 0 deletions azuread/adal_tokens.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package azuread

import (
"context"
"errors"
"fmt"
"os"

"github.com/Azure/go-autorest/autorest/adal"
)

// When the security token library is used, the token is obtained without input
// from the server, so the AD endpoint and Azure SQL resource URI are provided
// from the constants below.
var (
// activeDirectoryEndpoint is the security token service URL to use when
// the server does not provide the URL.
activeDirectoryEndpoint = "https://login.microsoftonline.com/"
)

func init() {
endpoint := os.Getenv("AZURE_AD_STS_URL")
if endpoint != "" {
activeDirectoryEndpoint = endpoint
}
}

const (
// azureSQLResource is the AD resource to use when the server does not
// provide the resource.
azureSQLResource = "https://database.windows.net/"

// driverClientID is the AD client ID to use when performing a username
// and password login.
driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db"
)

func retrieveToken(ctx context.Context, token *adal.ServicePrincipalToken) (string, error) {
err := token.RefreshWithContext(ctx)
if err != nil {
err = fmt.Errorf("Failed to refresh token: %v", err)
return "", err
}

return token.Token().AccessToken, nil
}

func (p *azureFedAuthProvider) ProvideSecurityToken(ctx context.Context) (string, error) {
switch {
case p.certificate != nil && p.privateKey != nil:
return p.securityTokenFromCertificate(ctx)
case p.clientSecret != "":
return p.securityTokenFromSecret(ctx)
}

return "", errors.New("Client certificate and key, or client secret, required for service principal login")
}

func (p *azureFedAuthProvider) securityTokenFromCertificate(ctx context.Context) (string, error) {
// The activeDirectoryEndpoint URL is used as a base against which the
// tenant ID is resolved.
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, p.tenantID)
if err != nil {
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v",
activeDirectoryEndpoint, p.tenantID, err)
return "", err
}

token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, p.clientID, p.certificate, p.privateKey, azureSQLResource)
if err != nil {
err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", p.clientID, p.tenantID, err)
return "", err
}

return retrieveToken(ctx, token)
}

func (p *azureFedAuthProvider) securityTokenFromSecret(ctx context.Context) (string, error) {
// The activeDirectoryEndpoint URL is used as a base against which the
// tenant ID is resolved.
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, p.tenantID)
if err != nil {
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v",
activeDirectoryEndpoint, p.tenantID, err)
return "", err
}

token, err := adal.NewServicePrincipalToken(*oauthConfig, p.clientID, p.clientSecret, azureSQLResource)

if err != nil {
err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", p.clientID, p.tenantID, err)
return "", err
}

return retrieveToken(ctx, token)
}

func (p *azureFedAuthProvider) ProvideActiveDirectoryToken(ctx context.Context, serverSPN, stsURL string) (string, error) {
switch p.adalWorkflow {
case fedAuthADALWorkflowPassword:
return p.activeDirectoryTokenFromPassword(ctx, serverSPN, stsURL)
case fedAuthADALWorkflowMSI:
return p.activeDirectoryTokenFromIdentity(ctx, serverSPN, stsURL)
}

return "", fmt.Errorf("ADAL workflow id %d not supported", p.adalWorkflow)
}

func (p *azureFedAuthProvider) activeDirectoryTokenFromPassword(ctx context.Context, serverSPN, stsURL string) (string, error) {
// The activeDirectoryEndpoint URL is used as a base against which the
// STS URL is resolved. However, the STS URL is normally absolute and
// the activeDirectoryEndpoint URL is completely ignored.
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, stsURL)
if err != nil {
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v",
activeDirectoryEndpoint, stsURL, err)
return "", err
}

token, err := adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, p.user, p.password, serverSPN)

if err != nil {
err = fmt.Errorf("Failed to obtain token for user %s for resource %s from service %s: %v", p.user, serverSPN, stsURL, err)
return "", err
}

return retrieveToken(ctx, token)
}

func (p *azureFedAuthProvider) activeDirectoryTokenFromIdentity(ctx context.Context, serverSPN, stsURL string) (string, error) {
msiEndpoint, err := adal.GetMSIEndpoint()
if err != nil {
return "", err
}

var token *adal.ServicePrincipalToken
var access string
if p.clientID == "" {
access = "system identity"
token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, serverSPN)
} else {
access = "user-assigned identity " + p.clientID
token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, serverSPN, p.clientID)
}

if err != nil {
err = fmt.Errorf("Failed to obtain token for %s for resource %s from service %s: %v", access, serverSPN, stsURL, err)
return "", err
}

return retrieveToken(ctx, token)
}
203 changes: 203 additions & 0 deletions azuread/adal_tokens_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package azuread

import (
"context"
"database/sql"
"net/url"
"os"
"strings"
"testing"

mssql "github.com/denisenkom/go-mssqldb"
)

type testLogger struct {
t *testing.T
}

func (l testLogger) Printf(format string, v ...interface{}) {
l.t.Logf(format, v...)
}

func (l testLogger) Println(v ...interface{}) {
l.t.Log(v...)
}

func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) {
u := &url.URL{
Scheme: "sqlserver",
Host: os.Getenv("SQL_SERVER"),
}

if u.Host == "" {
t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable")
}

database := os.Getenv("SQL_DATABASE")
if database == "" {
t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable")
}

tenantID := os.Getenv("AZURE_TENANT_ID")
if tenantID == "" {
t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable")
}

query := u.Query()

query.Add("database", database)
query.Add("encrypt", "true")
query.Add("fedauth", fedAuth)

u.RawQuery = query.Encode()

return u, tenantID
}

func checkFedAuthUserPassword(t *testing.T) *url.URL {
u, _ := checkAzureSQLEnvironment("ActiveDirectoryPassword", t)

username := os.Getenv("SQL_AD_ADMIN_USER")
password := os.Getenv("SQL_AD_ADMIN_PASSWORD")

if username == "" || password == "" {
t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables")
}

u.User = url.UserPassword(username, password)

return u
}

func checkFedAuthAppPassword(t *testing.T) *url.URL {
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryApplication", t)

appClientID := os.Getenv("APP_SP_CLIENT_ID")
appPassword := os.Getenv("APP_SP_CLIENT_SECRET")

if appClientID == "" || appPassword == "" {
t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables")
}

u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword)

return u
}

func checkFedAuthAppCertPath(t *testing.T) *url.URL {
u := checkFedAuthAppPassword(t)

appCertPath := os.Getenv("APP_SP_CLIENT_CERT")
if appCertPath == "" {
t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate")
}

query := u.Query()
query.Add("clientcertpath", appCertPath)
u.RawQuery = query.Encode()

return u
}

func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) {
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t)

vmClientID := os.Getenv("VM_CLIENT_ID")
if vmClientID == "" {
t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable")
}

return u, vmClientID + "@" + tenantID
}

func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) {
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t)

uaClientID := os.Getenv("UA_CLIENT_ID")
if uaClientID == "" {
t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable")
}

u.User = url.User(uaClientID)

return u, uaClientID + "@" + tenantID
}

func checkLoggedInUser(expected string, u *url.URL, t *testing.T) {
db, err := sql.Open(DriverName, u.String())
if err != nil {
t.Fatalf("Failed to open URL %v: %v", u, err)
}

defer db.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sql := "SELECT SUSER_NAME()"

stmt, err := db.PrepareContext(ctx, sql)
if err != nil {
t.Fatalf("Failed to prepare query %s: %v", sql, err)
}

defer stmt.Close()

rows, err := stmt.QueryContext(ctx)
if err != nil {
t.Fatalf("Failed to fetch query result for %s: %v", sql, err)
}

defer rows.Close()

var username string
if !rows.Next() {
t.Fatalf("Empty result set for query %s", sql)
}

err = rows.Scan(&username)
if err != nil {
t.Fatalf("Failed to fetch first row for %s: %v", sql, err)
}

if !strings.EqualFold(username, expected) {
t.Fatalf("Expected username %s: actual: %s", expected, username)
}

t.Logf("Logged in username %s matches expected %s", username, expected)
}

func TestFedAuthWithUserAndPassword(t *testing.T) {
mssql.SetLogger(testLogger{t})
u := checkFedAuthUserPassword(t)

checkLoggedInUser(u.User.Username(), u, t)
}

func TestFedAuthWithApplicationUsingPassword(t *testing.T) {
mssql.SetLogger(testLogger{t})
u := checkFedAuthAppPassword(t)

checkLoggedInUser(u.User.Username(), u, t)
}

func TestFedAuthWithApplicationUsingCertificate(t *testing.T) {
mssql.SetLogger(testLogger{t})
u := checkFedAuthAppCertPath(t)

checkLoggedInUser(u.User.Username(), u, t)
}

func TestFedAuthWithSystemAssignedIdentity(t *testing.T) {
u, vmName := checkFedAuthVMSystemID(t)
mssql.SetLogger(testLogger{t})

checkLoggedInUser(vmName, u, t)
}

func TestFedAuthWithUserAssignedIdentity(t *testing.T) {
mssql.SetLogger(testLogger{t})
u, uaName := checkFedAuthVMUserAssignedID(t)

checkLoggedInUser(uaName, u, t)
}

0 comments on commit 1d891d2

Please sign in to comment.