Skip to content

Commit

Permalink
directory/azure: add paging support to user group members call (#2311) (
Browse files Browse the repository at this point in the history
#2312)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
  • Loading branch information
github-actions[bot] and calebdoxsey committed Jun 24, 2021
1 parent 88e1458 commit e5d4c82
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 51 deletions.
45 changes: 45 additions & 0 deletions internal/directory/azure/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package azure

import "strings"

type (
apiGetUserResponse struct {
apiUser
}
apiGetUserMembersResponse struct {
Context string `json:"@odata.context"`
NextLink string `json:"@odata.nextLink,omitempty"`
Value []apiGroup `json:"value"`
}

apiGroup struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
}
apiUser struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Mail string `json:"mail"`
UserPrincipalName string `json:"userPrincipalName"`
}
)

func (obj apiUser) getEmail() string {
if obj.Mail != "" {
return obj.Mail
}

// AD often doesn't have the email address returned, but we can parse it from the UPN

// UPN looks like:
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
email := obj.UserPrincipalName
if idx := strings.Index(email, "#EXT"); idx > 0 {
email = email[:idx]
}
// find the last _ and replace it with @
if idx := strings.LastIndex(email, "_"); idx > 0 {
email = email[:idx] + "@" + email[idx+1:]
}
return email
}
39 changes: 24 additions & 15 deletions internal/directory/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,30 +115,17 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc
Path: fmt.Sprintf("/v1.0/users/%s", userID),
}).String()

var u usersDeltaResponseUser
var u apiGetUserResponse
err := p.api(ctx, userURL, &u)
if err != nil {
return nil, err
}
du.DisplayName = u.DisplayName
du.Email = u.getEmail()

groupURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
}).String()

var res struct {
Value []usersDeltaResponseUser `json:"value"`
}
err = p.api(ctx, groupURL, &res)
du.GroupIds, err = p.transitiveMemberOf(ctx, userID)
if err != nil {
return nil, err
}
for _, g := range res.Value {
du.GroupIds = append(du.GroupIds, g.ID)
}

sort.Strings(du.GroupIds)

return du, nil
}
Expand Down Expand Up @@ -246,6 +233,28 @@ func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
return p.token, nil
}

func (p *Provider) transitiveMemberOf(ctx context.Context, userID string) (groupIDs []string, err error) {
apiURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
}).String()
for {
var res apiGetUserMembersResponse
err := p.api(ctx, apiURL, &res)
if err != nil {
return nil, err
}
for _, g := range res.Value {
groupIDs = append(groupIDs, g.ID)
}
if res.NextLink == "" {
break
}
apiURL = res.NextLink
}
sort.Strings(groupIDs)
return groupIDs, nil
}

// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
type ServiceAccount struct {
ClientID string `json:"client_id"`
Expand Down
47 changes: 41 additions & 6 deletions internal/directory/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"

"github.com/go-chi/chi"
Expand Down Expand Up @@ -86,11 +87,28 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user-1":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "admin"},
},
})
switch r.URL.Query().Get("page") {
case "":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "admin"},
},
"@odata.nextLink": getPageURL(r, 1),
})
case "1":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "group1"},
},
"@odata.nextLink": getPageURL(r, 2),
})
case "2":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "group2"},
},
})
}
default:
http.Error(w, "not found", http.StatusNotFound)
}
Expand Down Expand Up @@ -126,7 +144,7 @@ func TestProvider_User(t *testing.T) {
"id": "user-1",
"displayName": "User 1",
"email": "user1@example.com",
"groupIds": ["admin"]
"groupIds": ["admin", "group1", "group2"]
}`, du)
}

Expand Down Expand Up @@ -219,3 +237,20 @@ func mustParseURL(rawurl string) *url.URL {
}
return u
}

func getPageURL(r *http.Request, page int) string {
var u url.URL
u = *r.URL
if r.TLS == nil {
u.Scheme = "http"
} else {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = r.Host
}
q := u.Query()
q.Set("page", strconv.Itoa(page))
u.RawQuery = q.Encode()
return u.String()
}
35 changes: 5 additions & 30 deletions internal/directory/azure/delta.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"net/url"
"sort"
"strings"

"github.com/pomerium/pomerium/pkg/grpc/directory"
)
Expand Down Expand Up @@ -229,10 +228,9 @@ type (
Value []groupsDeltaResponseGroup `json:"value"`
}
groupsDeltaResponseGroup struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
apiGroup
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
}
groupsDeltaResponseGroupMember struct {
Type string `json:"@odata.type"`
Expand All @@ -247,30 +245,7 @@ type (
Value []usersDeltaResponseUser `json:"value"`
}
usersDeltaResponseUser struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Mail string `json:"mail"`
UserPrincipalName string `json:"userPrincipalName"`
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
apiUser
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
}
)

func (obj usersDeltaResponseUser) getEmail() string {
if obj.Mail != "" {
return obj.Mail
}

// AD often doesn't have the email address returned, but we can parse it from the UPN

// UPN looks like:
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
email := obj.UserPrincipalName
if idx := strings.Index(email, "#EXT"); idx > 0 {
email = email[:idx]
}
// find the last _ and replace it with @
if idx := strings.LastIndex(email, "_"); idx > 0 {
email = email[:idx] + "@" + email[idx+1:]
}
return email
}

0 comments on commit e5d4c82

Please sign in to comment.