Skip to content

Commit

Permalink
Use ADAL for multitenant auth or when requested. #1566
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas11 committed Aug 24, 2022
1 parent ccbe54a commit 0c8a108
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 20 deletions.
42 changes: 30 additions & 12 deletions provider/pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ type azureNativeProvider struct {
converter *resources.SdkShapeConverter
customResources map[string]*resources.CustomResource
rgLocationMap map[string]string
adalTokenClient tokenGetter
msalTokenClient tokenGetter
}

func makeProvider(host *provider.HostClient, name, version string, schemaBytes []byte, schemaString string,
Expand All @@ -103,15 +105,17 @@ func makeProvider(host *provider.HostClient, name, version string, schemaBytes [

// Return the new provider
return &azureNativeProvider{
host: host,
name: name,
version: version,
client: client,
resourceMap: resourceMap,
config: map[string]string{},
schemaBytes: schemaBytes,
converter: &converter,
rgLocationMap: map[string]string{},
host: host,
name: name,
version: version,
client: client,
resourceMap: resourceMap,
config: map[string]string{},
schemaBytes: schemaBytes,
converter: &converter,
rgLocationMap: map[string]string{},
adalTokenClient: getOAuthTokenADAL,
msalTokenClient: getOAuthTokenMSAL,
}, nil
}

Expand Down Expand Up @@ -250,7 +254,7 @@ func (k *azureNativeProvider) Invoke(ctx context.Context, req *rpc.InvokeRequest
if endpointArg := args["endpoint"]; endpointArg.HasValue() && endpointArg.IsString() {
endpoint = endpointArg.StringValue()
}
token, err := k.getOAuthTokenNew(ctx, auth, endpoint)
token, err := k.getOAuthToken(ctx, auth, endpoint)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1748,6 +1752,20 @@ func (k *azureNativeProvider) getAuthorizers(authConfig *authentication.Config)
}

func (k *azureNativeProvider) getOAuthToken(ctx context.Context, auth *authentication.Config, endpoint string) (string, error) {
if len(auth.AuxiliaryTenantIDs) > 0 {
return k.adalTokenClient(ctx, k, auth, endpoint)
}
if k.getConfig("useLegacyADALAuth", "USE_LEGACY_ADAL_AUTH") == "true" {
// TODO print warning
return k.adalTokenClient(ctx, k, auth, endpoint)
}
return k.msalTokenClient(ctx, k, auth, endpoint)
}

type tokenGetter func(context.Context, *azureNativeProvider, *authentication.Config, string) (string, error)

// Obtain a token via the deprecated ADAL method, using the hashicorp/go-azure-helpers/authentication package
func getOAuthTokenADAL(ctx context.Context, k *azureNativeProvider, auth *authentication.Config, endpoint string) (string, error) {
buildSender := sender.BuildSender("AzureNative")
oauthConfig, err := auth.BuildOAuthConfig(k.environment.ActiveDirectoryEndpoint)
authorizer, err := auth.GetAuthorizationToken(buildSender, oauthConfig, endpoint)
Expand Down Expand Up @@ -1776,7 +1794,8 @@ func (k *azureNativeProvider) getOAuthToken(ctx context.Context, auth *authentic
return token, nil
}

func (k *azureNativeProvider) getOAuthTokenNew(ctx context.Context, auth *authentication.Config, endpoint string) (string, error) {
// Obtain a token via the new MSAL method, using the azidentity package
func getOAuthTokenMSAL(ctx context.Context, k *azureNativeProvider, auth *authentication.Config, endpoint string) (string, error) {
clientOpts := azcore.ClientOptions{Cloud: k.cloud()}

// There are several ways to obtain a token. We try, in order: client cert, client secret, managed service
Expand All @@ -1803,7 +1822,6 @@ func (k *azureNativeProvider) getOAuthTokenNew(ctx context.Context, auth *authen
if err != nil {
return "", err
}
// TODO,tkappler get rid of auth here
cred, err := azidentity.NewClientCertificateCredential(auth.TenantID, auth.ClientID, certs, key,
&azidentity.ClientCertificateCredentialOptions{ClientOptions: clientOpts})
if err != nil {
Expand Down
95 changes: 87 additions & 8 deletions provider/pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,104 @@
package provider

import (
"context"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/hashicorp/go-azure-helpers/authentication"
"github.com/stretchr/testify/assert"
)

func TestMapsAutorestCloudToAzureSdkCloud(t *testing.T) {
k := azureNativeProvider{}
p := azureNativeProvider{}

// default cloud
assert.Equal(t, k.cloud(), cloud.AzurePublic)
assert.Equal(t, p.cloud(), cloud.AzurePublic)

k.environment = azure.PublicCloud
assert.Equal(t, k.cloud(), cloud.AzurePublic)
p.environment = azure.PublicCloud
assert.Equal(t, p.cloud(), cloud.AzurePublic)

k.environment = azure.ChinaCloud
assert.Equal(t, k.cloud(), cloud.AzureChina)
p.environment = azure.ChinaCloud
assert.Equal(t, p.cloud(), cloud.AzureChina)

k.environment = azure.USGovernmentCloud
assert.Equal(t, k.cloud(), cloud.AzureGovernment)
p.environment = azure.USGovernmentCloud
assert.Equal(t, p.cloud(), cloud.AzureGovernment)
}

func TestUseMSALByDefault(t *testing.T) {
p, usedADAL, usedMSAL := setUpProviderWithMockTokenGetters()

callGetToken(t, p)
assert.False(t, *usedADAL)
assert.True(t, *usedMSAL)
}

func TestUseMSALForManagedIdentity(t *testing.T) {
p, usedADAL, usedMSAL := setUpProviderWithMockTokenGetters()
p.config["useMsi"] = "true"

callGetToken(t, p)
assert.False(t, *usedADAL)
assert.True(t, *usedMSAL)
}

func TestUseMSALForServicePrincipalSecret(t *testing.T) {
p, usedADAL, usedMSAL := setUpProviderWithMockTokenGetters()
p.config["clientSecret"] = "verysecret"
p.config["subscriptionId"] = "123"
p.config["clientId"] = "456"
p.config["tenantId"] = "789"

callGetToken(t, p)
assert.False(t, *usedADAL)
assert.True(t, *usedMSAL)
}

func TestUseADALForMultitenantAuth(t *testing.T) {
p, usedADAL, usedMSAL := setUpProviderWithMockTokenGetters()

p.config["auxiliaryTenantIds"] = "[\"1\"]"
callGetToken(t, p)
assert.True(t, *usedADAL)
assert.False(t, *usedMSAL)
}

func TestUseADALWhenRequested(t *testing.T) {
p, usedADAL, usedMSAL := setUpProviderWithMockTokenGetters()

p.config["useLegacyADALAuth"] = "true"
p.config["auxiliaryTenantIds"] = "[]" // even without multitenant auth
callGetToken(t, p)
assert.True(t, *usedADAL)
assert.False(t, *usedMSAL)
}

func callGetToken(t *testing.T, p *azureNativeProvider) {
auth, err := p.getAuthConfig()
assert.NoError(t, err)

p.getOAuthToken(context.Background(), auth, "")
}

// create a new provider where the func's to obtain auth tokens instead just update the bool return values if they were called
func setUpProviderWithMockTokenGetters() (p *azureNativeProvider, usedADAL *bool, usedMSAL *bool) {
p = &azureNativeProvider{}
p.config = map[string]string{}

var adal bool
p.adalTokenClient = func(ctx context.Context, k *azureNativeProvider, auth *authentication.Config, endpoint string) (string, error) {
adal = true
return "", nil
}

var msal bool
p.msalTokenClient = func(ctx context.Context, k *azureNativeProvider, auth *authentication.Config, endpoint string) (string, error) {
msal = true
return "", nil
}

usedADAL = &adal
usedMSAL = &msal
return
}

0 comments on commit 0c8a108

Please sign in to comment.