forked from denisenkom/go-mssqldb
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
af4fee2
commit d5f257b
Showing
8 changed files
with
1,007 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
package azuread | ||
|
||
import ( | ||
"context" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"fmt" | ||
"os" | ||
|
||
"github.com/Azure/go-autorest/autorest/adal" | ||
) | ||
|
||
// When the security token library is used, the token is obtained without input | ||
// from the server, so the AD endpoint and Azure SQL resource URI are provided | ||
// from the constants below. | ||
var ( | ||
// activeDirectoryEndpoint is the security token service URL to use when | ||
// the server does not provide the URL. | ||
activeDirectoryEndpoint = "https://login.microsoftonline.com/" | ||
) | ||
|
||
func init() { | ||
endpoint := os.Getenv("AZURE_AD_STS_URL") | ||
if endpoint != "" { | ||
activeDirectoryEndpoint = endpoint | ||
} | ||
} | ||
|
||
const ( | ||
// azureSQLResource is the AD resource to use when the server does not | ||
// provide the resource. | ||
azureSQLResource = "https://database.windows.net/" | ||
|
||
// driverClientID is the AD client ID to use when performing a username | ||
// and password login. | ||
driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db" | ||
) | ||
|
||
func retrieveToken(ctx context.Context, token *adal.ServicePrincipalToken) (string, error) { | ||
err := token.RefreshWithContext(ctx) | ||
if err != nil { | ||
err = fmt.Errorf("Failed to refresh token: %v", err) | ||
return "", err | ||
} | ||
|
||
return token.Token().AccessToken, nil | ||
} | ||
|
||
// SecurityTokenFromCertificate obtains a security token using a certificate and RSA private key. | ||
func SecurityTokenFromCertificate(ctx context.Context, clientID, tenantID string, certificate *x509.Certificate, privateKey *rsa.PrivateKey) (string, error) { | ||
// The activeDirectoryEndpoint URL is used as a base against which the | ||
// tenant ID is resolved. | ||
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) | ||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", | ||
activeDirectoryEndpoint, tenantID, err) | ||
return "", err | ||
} | ||
|
||
token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, privateKey, azureSQLResource) | ||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) | ||
return "", err | ||
} | ||
|
||
return retrieveToken(ctx, token) | ||
} | ||
|
||
// SecurityTokenFromSecret obtains a security token using a client ID and secret. | ||
func SecurityTokenFromSecret(ctx context.Context, clientID, tenantID, clientSecret string) (string, error) { | ||
// The activeDirectoryEndpoint URL is used as a base against which the | ||
// tenant ID is resolved. | ||
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) | ||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", | ||
activeDirectoryEndpoint, tenantID, err) | ||
return "", err | ||
} | ||
|
||
token, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, azureSQLResource) | ||
|
||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) | ||
return "", err | ||
} | ||
|
||
return retrieveToken(ctx, token) | ||
} | ||
|
||
// ActiveDirectoryTokenFromPassword obtains a security token using an Active Directory username and password. | ||
func ActiveDirectoryTokenFromPassword(ctx context.Context, serverSPN, stsURL, user, password string) (string, error) { | ||
// The activeDirectoryEndpoint URL is used as a base against which the | ||
// STS URL is resolved. However, the STS URL is normally absolute and | ||
// the activeDirectoryEndpoint URL is completely ignored. | ||
oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, stsURL) | ||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", | ||
activeDirectoryEndpoint, stsURL, err) | ||
return "", err | ||
} | ||
|
||
token, err := adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, user, password, serverSPN) | ||
|
||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain token for user %s for resource %s from service %s: %v", user, serverSPN, stsURL, err) | ||
return "", err | ||
} | ||
|
||
return retrieveToken(ctx, token) | ||
} | ||
|
||
// ActiveDirectoryTokenFromIdentity obtains a security token the managed identity service. | ||
func ActiveDirectoryTokenFromIdentity(ctx context.Context, serverSPN, stsURL, clientID string) (string, error) { | ||
msiEndpoint, err := adal.GetMSIEndpoint() | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
var token *adal.ServicePrincipalToken | ||
var access string | ||
if clientID == "" { | ||
access = "system identity" | ||
token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, serverSPN) | ||
} else { | ||
access = "user-assigned identity " + clientID | ||
token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, serverSPN, clientID) | ||
} | ||
|
||
if err != nil { | ||
err = fmt.Errorf("Failed to obtain token for %s for resource %s from service %s: %v", access, serverSPN, stsURL, err) | ||
return "", err | ||
} | ||
|
||
return retrieveToken(ctx, token) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
package azuread | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"net/url" | ||
"os" | ||
"strings" | ||
"testing" | ||
|
||
mssql "github.com/denisenkom/go-mssqldb" | ||
) | ||
|
||
type testLogger struct { | ||
t *testing.T | ||
} | ||
|
||
func (l testLogger) Printf(format string, v ...interface{}) { | ||
l.t.Logf(format, v...) | ||
} | ||
|
||
func (l testLogger) Println(v ...interface{}) { | ||
l.t.Log(v...) | ||
} | ||
|
||
func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) { | ||
u := &url.URL{ | ||
Scheme: "sqlserver", | ||
Host: os.Getenv("SQL_SERVER"), | ||
} | ||
|
||
if u.Host == "" { | ||
t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable") | ||
} | ||
|
||
database := os.Getenv("SQL_DATABASE") | ||
if database == "" { | ||
t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable") | ||
} | ||
|
||
tenantID := os.Getenv("AZURE_TENANT_ID") | ||
if tenantID == "" { | ||
t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable") | ||
} | ||
|
||
query := u.Query() | ||
|
||
query.Add("database", database) | ||
query.Add("encrypt", "true") | ||
query.Add("fedauth", fedAuth) | ||
|
||
u.RawQuery = query.Encode() | ||
|
||
return u, tenantID | ||
} | ||
|
||
func checkFedAuthUserPassword(t *testing.T) *url.URL { | ||
u, _ := checkAzureSQLEnvironment("ActiveDirectoryPassword", t) | ||
|
||
username := os.Getenv("SQL_AD_ADMIN_USER") | ||
password := os.Getenv("SQL_AD_ADMIN_PASSWORD") | ||
|
||
if username == "" || password == "" { | ||
t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables") | ||
} | ||
|
||
u.User = url.UserPassword(username, password) | ||
|
||
return u | ||
} | ||
|
||
func checkFedAuthAppPassword(t *testing.T) *url.URL { | ||
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryApplication", t) | ||
|
||
appClientID := os.Getenv("APP_SP_CLIENT_ID") | ||
appPassword := os.Getenv("APP_SP_CLIENT_SECRET") | ||
|
||
if appClientID == "" || appPassword == "" { | ||
t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables") | ||
} | ||
|
||
u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword) | ||
|
||
return u | ||
} | ||
|
||
func checkFedAuthAppCertPath(t *testing.T) *url.URL { | ||
u := checkFedAuthAppPassword(t) | ||
|
||
appCertPath := os.Getenv("APP_SP_CLIENT_CERT") | ||
if appCertPath == "" { | ||
t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate") | ||
} | ||
|
||
query := u.Query() | ||
query.Add("clientcertpath", appCertPath) | ||
u.RawQuery = query.Encode() | ||
|
||
return u | ||
} | ||
|
||
func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) { | ||
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) | ||
|
||
vmClientID := os.Getenv("VM_CLIENT_ID") | ||
if vmClientID == "" { | ||
t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable") | ||
} | ||
|
||
return u, vmClientID + "@" + tenantID | ||
} | ||
|
||
func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) { | ||
u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) | ||
|
||
uaClientID := os.Getenv("UA_CLIENT_ID") | ||
if uaClientID == "" { | ||
t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable") | ||
} | ||
|
||
u.User = url.User(uaClientID) | ||
|
||
return u, uaClientID + "@" + tenantID | ||
} | ||
|
||
func checkLoggedInUser(expected string, u *url.URL, t *testing.T) { | ||
db, err := sql.Open(DriverName, u.String()) | ||
if err != nil { | ||
t.Fatalf("Failed to open URL %v: %v", u, err) | ||
} | ||
|
||
defer db.Close() | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
sql := "SELECT SUSER_NAME()" | ||
|
||
stmt, err := db.PrepareContext(ctx, sql) | ||
if err != nil { | ||
t.Fatalf("Failed to prepare query %s: %v", sql, err) | ||
} | ||
|
||
defer stmt.Close() | ||
|
||
rows, err := stmt.QueryContext(ctx) | ||
if err != nil { | ||
t.Fatalf("Failed to fetch query result for %s: %v", sql, err) | ||
} | ||
|
||
defer rows.Close() | ||
|
||
var username string | ||
if !rows.Next() { | ||
t.Fatalf("Empty result set for query %s", sql) | ||
} | ||
|
||
err = rows.Scan(&username) | ||
if err != nil { | ||
t.Fatalf("Failed to fetch first row for %s: %v", sql, err) | ||
} | ||
|
||
if !strings.EqualFold(username, expected) { | ||
t.Fatalf("Expected username %s: actual: %s", expected, username) | ||
} | ||
|
||
t.Logf("Logged in username %s matches expected %s", username, expected) | ||
} | ||
|
||
func TestFedAuthWithUserAndPassword(t *testing.T) { | ||
mssql.SetLogger(testLogger{t}) | ||
u := checkFedAuthUserPassword(t) | ||
|
||
checkLoggedInUser(u.User.Username(), u, t) | ||
} | ||
|
||
func TestFedAuthWithApplicationUsingPassword(t *testing.T) { | ||
mssql.SetLogger(testLogger{t}) | ||
u := checkFedAuthAppPassword(t) | ||
|
||
checkLoggedInUser(u.User.Username(), u, t) | ||
} | ||
|
||
func TestFedAuthWithApplicationUsingCertificate(t *testing.T) { | ||
mssql.SetLogger(testLogger{t}) | ||
u := checkFedAuthAppCertPath(t) | ||
|
||
checkLoggedInUser(u.User.Username(), u, t) | ||
} | ||
|
||
func TestFedAuthWithSystemAssignedIdentity(t *testing.T) { | ||
u, vmName := checkFedAuthVMSystemID(t) | ||
mssql.SetLogger(testLogger{t}) | ||
|
||
checkLoggedInUser(vmName, u, t) | ||
} | ||
|
||
func TestFedAuthWithUserAssignedIdentity(t *testing.T) { | ||
mssql.SetLogger(testLogger{t}) | ||
u, uaName := checkFedAuthVMUserAssignedID(t) | ||
|
||
checkLoggedInUser(uaName, u, t) | ||
} |
Oops, something went wrong.