Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions core/project/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ func (s Service) Create(ctx context.Context, prj Project) (Project, error) {

func (s Service) List(ctx context.Context, f Filter) ([]Project, error) {
if f.Principal != nil {
if f.Principal.ID == "" || f.Principal.Type == "" {
return nil, fmt.Errorf("project: invalid principal filter")
if !utils.IsValidUUID(f.Principal.ID) {
return nil, ErrInvalidUUID
}
if f.Principal.Type == "" {
return nil, ErrInvalidPrincipalType
}
if s.membershipService == nil {
return nil, fmt.Errorf("project: membership service not wired")
Expand Down
39 changes: 28 additions & 11 deletions core/project/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,15 @@ func TestService_List(t *testing.T) {

func TestService_List_WithPrincipal(t *testing.T) {
ctx := context.Background()
userPrincipal := authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}
userPrincipal := authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c", Type: schema.UserPrincipal}

tests := []struct {
name string
setup func(*testing.T) *project.Service
filter project.Filter
want []project.Project
wantErr bool
name string
setup func(*testing.T) *project.Service
filter project.Filter
want []project.Project
wantErr bool
wantErrIs error
}{
{
name: "errors when membership service is not wired",
Expand All @@ -309,24 +310,37 @@ func TestService_List_WithPrincipal(t *testing.T) {
wantErr: true,
},
{
name: "errors when Principal has empty ID",
name: "returns ErrInvalidUUID when Principal has empty ID",
filter: project.Filter{Principal: &authenticate.Principal{Type: schema.UserPrincipal}},
setup: func(t *testing.T) *project.Service {
t.Helper()
repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t)
return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService)
},
wantErr: true,
wantErr: true,
wantErrIs: project.ErrInvalidUUID,
},
{
name: "errors when Principal has empty Type",
filter: project.Filter{Principal: &authenticate.Principal{ID: "user-id"}},
name: "returns ErrInvalidUUID when Principal ID is not a valid UUID",
filter: project.Filter{Principal: &authenticate.Principal{ID: "not-a-uuid", Type: schema.UserPrincipal}},
setup: func(t *testing.T) *project.Service {
t.Helper()
repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t)
return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService)
},
wantErr: true,
wantErr: true,
wantErrIs: project.ErrInvalidUUID,
},
{
name: "returns ErrInvalidPrincipalType when Principal has empty Type",
filter: project.Filter{Principal: &authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c"}},
setup: func(t *testing.T) *project.Service {
t.Helper()
repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t)
return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService)
},
wantErr: true,
wantErrIs: project.ErrInvalidPrincipalType,
},
{
name: "returns projects from the membership shim",
Expand Down Expand Up @@ -462,6 +476,9 @@ func TestService_List_WithPrincipal(t *testing.T) {
t.Errorf("List() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) {
t.Errorf("List() error = %v, want errors.Is(%v)", err, tt.wantErrIs)
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("List() mismatch (-want +got):\n%s", diff)
}
Expand Down
3 changes: 3 additions & 0 deletions core/serviceuser/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ func (s Service) Create(ctx context.Context, serviceUser ServiceUser) (ServiceUs
}

func (s Service) Get(ctx context.Context, id string) (ServiceUser, error) {
if !utils.IsValidUUID(id) {
return ServiceUser{}, ErrInvalidID
}
return s.repo.GetByID(ctx, id)
}

Expand Down
50 changes: 50 additions & 0 deletions core/serviceuser/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,53 @@ func TestService_Delete(t *testing.T) {
})
}
}

func TestService_Get(t *testing.T) {
ctx := context.Background()
const validID = "68f86fec-eb87-49f0-9be0-8d99b00a4a9c"

tests := []struct {
name string
id string
setup func(*mocks.Repository)
wantErrIs error
}{
{
name: "empty id returns ErrInvalidID without hitting the repo",
id: "",
setup: func(repo *mocks.Repository) {},
wantErrIs: serviceuser.ErrInvalidID,
},
{
name: "non-uuid id returns ErrInvalidID without hitting the repo",
id: "not-a-uuid",
setup: func(repo *mocks.Repository) {},
wantErrIs: serviceuser.ErrInvalidID,
},
{
name: "valid uuid delegates to the repo",
id: validID,
setup: func(repo *mocks.Repository) {
repo.On("GetByID", ctx, validID).Return(serviceuser.ServiceUser{ID: validID}, nil)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc, repo, _, _, _ := newTestService(t)
tt.setup(repo)

_, err := svc.Get(ctx, tt.id)
if tt.wantErrIs != nil {
if !errors.Is(err, tt.wantErrIs) {
t.Errorf("Get() error = %v, want errors.Is(%v)", err, tt.wantErrIs)
}
return
}
if err != nil {
t.Errorf("Get() unexpected error = %v", err)
}
})
}
}
65 changes: 48 additions & 17 deletions internal/api/v1beta1connect/permission_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,36 @@ func logAuditForCheck(ctx context.Context, result bool, objectID string, objectN
}

func (h *ConnectHandler) getPermissionName(ctx context.Context, ns, name string) (string, error) {
resolved, ok, err := h.resolvePermissionName(ctx, ns, name)
if err != nil {
return "", connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
if !ok {
return "", connect.NewError(connect.CodeNotFound, ErrNotFound)
}
return resolved, nil
}

// resolvePermissionName looks up the canonical permission name for (namespace, name).
// ok=false means the permission is not defined; err is reserved for genuine lookup
// failures. Callers that want to treat an unknown permission as "no result"
// should use this helper; callers that want to reject the request should use
// getPermissionName which maps unknown permissions to CodeNotFound.
func (h *ConnectHandler) resolvePermissionName(ctx context.Context, ns, name string) (string, bool, error) {
if ns == schema.PlatformNamespace && schema.IsPlatformPermission(name) {
return name, nil
return name, true, nil
}
perm, err := h.permissionService.Get(ctx, permission.AddNamespaceIfRequired(ns, name))
if err != nil {
switch {
case errors.Is(err, permission.ErrNotExist):
return "", connect.NewError(connect.CodeNotFound, ErrNotFound)
default:
return "", connect.NewError(connect.CodeInternal, ErrInternalServerError)
if errors.Is(err, permission.ErrNotExist) {
return "", false, nil
}
return "", false, err
}
// if the permission is on the same namespace as the object, use the name
if perm.NamespaceID == ns {
return perm.Name, nil
return perm.Name, true, nil
}
// else use fully qualified name(slug)
return perm.Slug, nil
return perm.Slug, true, nil
}

func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, req *connect.Request[frontierv1beta1.CheckFederatedResourcePermissionRequest]) (*connect.Response[frontierv1beta1.CheckFederatedResourcePermissionResponse], error) {
Expand Down Expand Up @@ -94,19 +106,38 @@ func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, r
}

func (h *ConnectHandler) fetchAccessPairsOnResource(ctx context.Context, objectNamespace string, ids, permissions []string) ([]relation.CheckPair, error) {
checks := make([]resource.Check, 0, len(ids)*len(permissions))
// Resolve each requested permission once, dropping unknown names and
// duplicate inputs. Unknown names produce an empty result rather than
// 4xx/5xx — see the contract on resolvePermissionName.
resolvedPerms := make([]string, 0, len(permissions))
seen := make(map[string]struct{}, len(permissions))
for _, p := range permissions {
resolved, ok, err := h.resolvePermissionName(ctx, objectNamespace, p)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
if !ok {
continue
}
if _, dup := seen[resolved]; dup {
continue
}
seen[resolved] = struct{}{}
resolvedPerms = append(resolvedPerms, resolved)
}
if len(resolvedPerms) == 0 || len(ids) == 0 {
return []relation.CheckPair{}, nil
}

checks := make([]resource.Check, 0, len(ids)*len(resolvedPerms))
for _, id := range ids {
for _, permission := range permissions {
permissionName, err := h.getPermissionName(ctx, objectNamespace, permission)
if err != nil {
return nil, err
}
for _, p := range resolvedPerms {
checks = append(checks, resource.Check{
Object: relation.Object{
ID: id,
Namespace: objectNamespace,
},
Permission: permissionName,
Permission: p,
})
}
}
Expand Down
40 changes: 25 additions & 15 deletions internal/api/v1beta1connect/serviceuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/raystack/frontier/core/authenticate"
"github.com/raystack/frontier/core/organization"
"github.com/raystack/frontier/core/project"
"github.com/raystack/frontier/core/relation"
"github.com/raystack/frontier/core/serviceuser"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/raystack/frontier/pkg/errors"
Expand Down Expand Up @@ -99,7 +98,9 @@ func (h *ConnectHandler) GetServiceUser(ctx context.Context, request *connect.Re
"service_user_id", serviceUserID)

switch {
case err == serviceuser.ErrNotExist:
case errors.Is(err, serviceuser.ErrInvalidID):
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
case errors.Is(err, serviceuser.ErrNotExist):
return nil, connect.NewError(connect.CodeNotFound, ErrServiceUserNotFound)
default:
errorLogger.LogUnexpectedError(ctx, request, "GetServiceUser", err,
Expand Down Expand Up @@ -475,7 +476,17 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c
errorLogger.LogServiceError(ctx, request, "ListServiceUserProjects", err,
"service_user_id", serviceUserID,
"org_id", orgID)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)

switch {
case errors.Is(err, project.ErrInvalidUUID),
errors.Is(err, project.ErrInvalidPrincipalType):
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest)
default:
errorLogger.LogUnexpectedError(ctx, request, "ListServiceUserProjects", err,
"service_user_id", serviceUserID,
"org_id", orgID)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
}

var projects []*frontierv1beta1.Project
Expand All @@ -501,20 +512,19 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c
"with_permissions", request.Msg.GetWithPermissions())
return nil, err
}
for _, successCheck := range successCheckPairs {
resID := successCheck.Relation.Object.ID

// find all permission checks on same resource
pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool {
return pair.Relation.Object.ID == resID
})
// fetch permissions
permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string {
return pair.Relation.RelationName
})
permsByProject := map[string][]string{}
projectOrder := make([]string, 0, len(projList))
for _, p := range successCheckPairs {
resID := p.Relation.Object.ID
if _, seen := permsByProject[resID]; !seen {
projectOrder = append(projectOrder, resID)
}
permsByProject[resID] = append(permsByProject[resID], p.Relation.RelationName)
}
for _, resID := range projectOrder {
accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{
ProjectId: resID,
Permissions: permissions,
Permissions: permsByProject[resID],
})
}
}
Expand Down
Loading
Loading