diff --git a/core/project/service.go b/core/project/service.go index deac0f329..c48c86c97 100644 --- a/core/project/service.go +++ b/core/project/service.go @@ -133,8 +133,11 @@ func (s Service) Create(ctx context.Context, prj Project) (Project, error) { func (s Service) List(ctx context.Context, f Filter) ([]Project, error) { if f.Principal != nil { - if f.Principal.ID == "" || f.Principal.Type == "" { - return nil, fmt.Errorf("project: invalid principal filter") + if !utils.IsValidUUID(f.Principal.ID) { + return nil, ErrInvalidUUID + } + if f.Principal.Type == "" { + return nil, ErrInvalidPrincipalType } if s.membershipService == nil { return nil, fmt.Errorf("project: membership service not wired") diff --git a/core/project/service_test.go b/core/project/service_test.go index 7289ff1b8..dfdad3c41 100644 --- a/core/project/service_test.go +++ b/core/project/service_test.go @@ -288,14 +288,15 @@ func TestService_List(t *testing.T) { func TestService_List_WithPrincipal(t *testing.T) { ctx := context.Background() - userPrincipal := authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal} + userPrincipal := authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c", Type: schema.UserPrincipal} tests := []struct { - name string - setup func(*testing.T) *project.Service - filter project.Filter - want []project.Project - wantErr bool + name string + setup func(*testing.T) *project.Service + filter project.Filter + want []project.Project + wantErr bool + wantErrIs error }{ { name: "errors when membership service is not wired", @@ -309,24 +310,37 @@ func TestService_List_WithPrincipal(t *testing.T) { wantErr: true, }, { - name: "errors when Principal has empty ID", + name: "returns ErrInvalidUUID when Principal has empty ID", filter: project.Filter{Principal: &authenticate.Principal{Type: schema.UserPrincipal}}, setup: func(t *testing.T) *project.Service { t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, - wantErr: true, + wantErr: true, + wantErrIs: project.ErrInvalidUUID, }, { - name: "errors when Principal has empty Type", - filter: project.Filter{Principal: &authenticate.Principal{ID: "user-id"}}, + name: "returns ErrInvalidUUID when Principal ID is not a valid UUID", + filter: project.Filter{Principal: &authenticate.Principal{ID: "not-a-uuid", Type: schema.UserPrincipal}}, setup: func(t *testing.T) *project.Service { t.Helper() repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) }, - wantErr: true, + wantErr: true, + wantErrIs: project.ErrInvalidUUID, + }, + { + name: "returns ErrInvalidPrincipalType when Principal has empty Type", + filter: project.Filter{Principal: &authenticate.Principal{ID: "68f86fec-eb87-49f0-9be0-8d99b00a4a9c"}}, + setup: func(t *testing.T) *project.Service { + t.Helper() + repo, userService, suserService, relationService, policyService, authnService, groupService, roleService := mockService(t) + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService, roleService) + }, + wantErr: true, + wantErrIs: project.ErrInvalidPrincipalType, }, { name: "returns projects from the membership shim", @@ -462,6 +476,9 @@ func TestService_List_WithPrincipal(t *testing.T) { t.Errorf("List() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Errorf("List() error = %v, want errors.Is(%v)", err, tt.wantErrIs) + } if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("List() mismatch (-want +got):\n%s", diff) } diff --git a/core/serviceuser/service.go b/core/serviceuser/service.go index 0ae9212ee..bebeb7e08 100644 --- a/core/serviceuser/service.go +++ b/core/serviceuser/service.go @@ -112,6 +112,9 @@ func (s Service) Create(ctx context.Context, serviceUser ServiceUser) (ServiceUs } func (s Service) Get(ctx context.Context, id string) (ServiceUser, error) { + if !utils.IsValidUUID(id) { + return ServiceUser{}, ErrInvalidID + } return s.repo.GetByID(ctx, id) } diff --git a/core/serviceuser/service_test.go b/core/serviceuser/service_test.go index 86d2cc289..d65f7f91a 100644 --- a/core/serviceuser/service_test.go +++ b/core/serviceuser/service_test.go @@ -113,3 +113,53 @@ func TestService_Delete(t *testing.T) { }) } } + +func TestService_Get(t *testing.T) { + ctx := context.Background() + const validID = "68f86fec-eb87-49f0-9be0-8d99b00a4a9c" + + tests := []struct { + name string + id string + setup func(*mocks.Repository) + wantErrIs error + }{ + { + name: "empty id returns ErrInvalidID without hitting the repo", + id: "", + setup: func(repo *mocks.Repository) {}, + wantErrIs: serviceuser.ErrInvalidID, + }, + { + name: "non-uuid id returns ErrInvalidID without hitting the repo", + id: "not-a-uuid", + setup: func(repo *mocks.Repository) {}, + wantErrIs: serviceuser.ErrInvalidID, + }, + { + name: "valid uuid delegates to the repo", + id: validID, + setup: func(repo *mocks.Repository) { + repo.On("GetByID", ctx, validID).Return(serviceuser.ServiceUser{ID: validID}, nil) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc, repo, _, _, _ := newTestService(t) + tt.setup(repo) + + _, err := svc.Get(ctx, tt.id) + if tt.wantErrIs != nil { + if !errors.Is(err, tt.wantErrIs) { + t.Errorf("Get() error = %v, want errors.Is(%v)", err, tt.wantErrIs) + } + return + } + if err != nil { + t.Errorf("Get() unexpected error = %v", err) + } + }) + } +} diff --git a/internal/api/v1beta1connect/permission_check.go b/internal/api/v1beta1connect/permission_check.go index c6ebe9eff..3af7e923b 100644 --- a/internal/api/v1beta1connect/permission_check.go +++ b/internal/api/v1beta1connect/permission_check.go @@ -28,24 +28,36 @@ func logAuditForCheck(ctx context.Context, result bool, objectID string, objectN } func (h *ConnectHandler) getPermissionName(ctx context.Context, ns, name string) (string, error) { + resolved, ok, err := h.resolvePermissionName(ctx, ns, name) + if err != nil { + return "", connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + if !ok { + return "", connect.NewError(connect.CodeNotFound, ErrNotFound) + } + return resolved, nil +} + +// resolvePermissionName looks up the canonical permission name for (namespace, name). +// ok=false means the permission is not defined; err is reserved for genuine lookup +// failures. Callers that want to treat an unknown permission as "no result" +// should use this helper; callers that want to reject the request should use +// getPermissionName which maps unknown permissions to CodeNotFound. +func (h *ConnectHandler) resolvePermissionName(ctx context.Context, ns, name string) (string, bool, error) { if ns == schema.PlatformNamespace && schema.IsPlatformPermission(name) { - return name, nil + return name, true, nil } perm, err := h.permissionService.Get(ctx, permission.AddNamespaceIfRequired(ns, name)) if err != nil { - switch { - case errors.Is(err, permission.ErrNotExist): - return "", connect.NewError(connect.CodeNotFound, ErrNotFound) - default: - return "", connect.NewError(connect.CodeInternal, ErrInternalServerError) + if errors.Is(err, permission.ErrNotExist) { + return "", false, nil } + return "", false, err } - // if the permission is on the same namespace as the object, use the name if perm.NamespaceID == ns { - return perm.Name, nil + return perm.Name, true, nil } - // else use fully qualified name(slug) - return perm.Slug, nil + return perm.Slug, true, nil } func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, req *connect.Request[frontierv1beta1.CheckFederatedResourcePermissionRequest]) (*connect.Response[frontierv1beta1.CheckFederatedResourcePermissionResponse], error) { @@ -94,19 +106,38 @@ func (h *ConnectHandler) CheckFederatedResourcePermission(ctx context.Context, r } func (h *ConnectHandler) fetchAccessPairsOnResource(ctx context.Context, objectNamespace string, ids, permissions []string) ([]relation.CheckPair, error) { - checks := make([]resource.Check, 0, len(ids)*len(permissions)) + // Resolve each requested permission once, dropping unknown names and + // duplicate inputs. Unknown names produce an empty result rather than + // 4xx/5xx — see the contract on resolvePermissionName. + resolvedPerms := make([]string, 0, len(permissions)) + seen := make(map[string]struct{}, len(permissions)) + for _, p := range permissions { + resolved, ok, err := h.resolvePermissionName(ctx, objectNamespace, p) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + if !ok { + continue + } + if _, dup := seen[resolved]; dup { + continue + } + seen[resolved] = struct{}{} + resolvedPerms = append(resolvedPerms, resolved) + } + if len(resolvedPerms) == 0 || len(ids) == 0 { + return []relation.CheckPair{}, nil + } + + checks := make([]resource.Check, 0, len(ids)*len(resolvedPerms)) for _, id := range ids { - for _, permission := range permissions { - permissionName, err := h.getPermissionName(ctx, objectNamespace, permission) - if err != nil { - return nil, err - } + for _, p := range resolvedPerms { checks = append(checks, resource.Check{ Object: relation.Object{ ID: id, Namespace: objectNamespace, }, - Permission: permissionName, + Permission: p, }) } } diff --git a/internal/api/v1beta1connect/serviceuser.go b/internal/api/v1beta1connect/serviceuser.go index 76daf229d..30bd3a5b3 100644 --- a/internal/api/v1beta1connect/serviceuser.go +++ b/internal/api/v1beta1connect/serviceuser.go @@ -10,7 +10,6 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" - "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/errors" @@ -99,7 +98,9 @@ func (h *ConnectHandler) GetServiceUser(ctx context.Context, request *connect.Re "service_user_id", serviceUserID) switch { - case err == serviceuser.ErrNotExist: + case errors.Is(err, serviceuser.ErrInvalidID): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + case errors.Is(err, serviceuser.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrServiceUserNotFound) default: errorLogger.LogUnexpectedError(ctx, request, "GetServiceUser", err, @@ -475,7 +476,17 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c errorLogger.LogServiceError(ctx, request, "ListServiceUserProjects", err, "service_user_id", serviceUserID, "org_id", orgID) - return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + + switch { + case errors.Is(err, project.ErrInvalidUUID), + errors.Is(err, project.ErrInvalidPrincipalType): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) + default: + errorLogger.LogUnexpectedError(ctx, request, "ListServiceUserProjects", err, + "service_user_id", serviceUserID, + "org_id", orgID) + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } } var projects []*frontierv1beta1.Project @@ -501,20 +512,19 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c "with_permissions", request.Msg.GetWithPermissions()) return nil, err } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + permsByProject := map[string][]string{} + projectOrder := make([]string, 0, len(projList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByProject[resID]; !seen { + projectOrder = append(projectOrder, resID) + } + permsByProject[resID] = append(permsByProject[resID], p.Relation.RelationName) + } + for _, resID := range projectOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ ProjectId: resID, - Permissions: permissions, + Permissions: permsByProject[resID], }) } } diff --git a/internal/api/v1beta1connect/serviceuser_test.go b/internal/api/v1beta1connect/serviceuser_test.go index 1f15a2059..88b1337d7 100644 --- a/internal/api/v1beta1connect/serviceuser_test.go +++ b/internal/api/v1beta1connect/serviceuser_test.go @@ -1296,6 +1296,56 @@ func TestHandler_DeleteServiceUserToken(t *testing.T) { } } +func TestHandler_GetServiceUser(t *testing.T) { + tests := []struct { + name string + setup func(*mocks.ServiceUserService) + request *connect.Request[frontierv1beta1.GetServiceUserRequest] + errCode connect.Code + wantErr error + }{ + { + name: "maps ErrInvalidID to InvalidArgument", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, "not-a-uuid").Return(serviceuser.ServiceUser{}, serviceuser.ErrInvalidID) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: "not-a-uuid"}), + errCode: connect.CodeInvalidArgument, + wantErr: ErrBadRequest, + }, + { + name: "maps ErrNotExist to NotFound", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, testServiceUserID).Return(serviceuser.ServiceUser{}, serviceuser.ErrNotExist) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: testServiceUserID}), + errCode: connect.CodeNotFound, + wantErr: ErrServiceUserNotFound, + }, + { + name: "maps unexpected error to Internal", + setup: func(svc *mocks.ServiceUserService) { + svc.EXPECT().Get(mock.Anything, testServiceUserID).Return(serviceuser.ServiceUser{}, errors.New("boom")) + }, + request: connect.NewRequest(&frontierv1beta1.GetServiceUserRequest{Id: testServiceUserID}), + errCode: connect.CodeInternal, + wantErr: ErrInternalServerError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &mocks.ServiceUserService{} + tt.setup(svc) + h := &ConnectHandler{serviceUserService: svc} + resp, err := h.GetServiceUser(context.Background(), tt.request) + assert.Nil(t, resp) + assert.Error(t, err) + assert.Equal(t, tt.errCode, connect.CodeOf(err)) + assert.Contains(t, err.Error(), tt.wantErr.Error()) + }) + } +} + func TestHandler_ListServiceUserProjects(t *testing.T) { testProjectMap := map[string]project.Project{ "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71": { @@ -1398,6 +1448,20 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { wantErr: ErrBadRequest, errCode: connect.CodeInvalidArgument, }, + { + name: "should return invalid argument when project service returns ErrInvalidUUID", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "not-a-uuid", + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + projSvc.EXPECT().List(mock.Anything, project.Filter{ + Principal: &authenticate.Principal{ID: "not-a-uuid", Type: schema.ServiceUserPrincipal}, + }).Return(nil, project.ErrInvalidUUID) + }, + want: nil, + wantErr: ErrBadRequest, + errCode: connect.CodeInvalidArgument, + }, { name: "should forward org_id to project.Filter when set", request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ @@ -1517,6 +1581,110 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { wantErr: nil, errCode: connect.Code(0), }, + { + name: "emits one access pair per project when multiple permissions succeed", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "1", + WithPermissions: []string{"update", "delete"}, + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + var projects []project.Project + for _, projectID := range testProjectIDList { + projects = append(projects, testProjectMap[projectID]) + } + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) + + permSvc.EXPECT().Get(mock.Anything, "app/project:update").Return( + permission.Permission{Name: "update", NamespaceID: "app/project"}, nil) + permSvc.EXPECT().Get(mock.Anything, "app/project:delete").Return( + permission.Permission{Name: "delete", NamespaceID: "app/project"}, nil) + + resourceSvc.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "update"}, + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "delete"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "update"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "delete"}, Status: true}, + }, nil) + }, + want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ + Projects: []*frontierv1beta1.Project{{ + Id: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", + Name: "prj-1", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org1.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }, { + Id: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", + Name: "prj-2", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org2.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }}, + AccessPairs: []*frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ + {ProjectId: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Permissions: []string{"update", "delete"}}, + {ProjectId: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Permissions: []string{"update", "delete"}}, + }, + }), + wantErr: nil, + errCode: connect.Code(0), + }, + { + name: "drops unknown permissions from withPermissions", + request: connect.NewRequest(&frontierv1beta1.ListServiceUserProjectsRequest{ + Id: "1", + WithPermissions: []string{"get", "bogus"}, + }), + setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { + var projects []project.Project + for _, projectID := range testProjectIDList { + projects = append(projects, testProjectMap[projectID]) + } + projSvc.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}}).Return(projects, nil) + + permSvc.EXPECT().Get(mock.Anything, "app/project:get").Return( + permission.Permission{Name: "get", NamespaceID: "app/project"}, nil) + permSvc.EXPECT().Get(mock.Anything, "app/project:bogus").Return( + permission.Permission{}, permission.ErrNotExist) + + resourceSvc.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, Permission: "get"}, + {Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, Permission: "get"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Namespace: "app/project"}, RelationName: "get"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Namespace: "app/project"}, RelationName: "get"}, Status: true}, + }, nil) + }, + want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ + Projects: []*frontierv1beta1.Project{{ + Id: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", + Name: "prj-1", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org1.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }, { + Id: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", + Name: "prj-2", + Metadata: &structpb.Struct{Fields: map[string]*structpb.Value{"email": structpb.NewStringValue("org1@org2.com")}}, + OrgId: "9f256f86-31a3-11ec-8d3d-0242ac130003", + CreatedAt: timestamppb.New(time.Time{}), + UpdatedAt: timestamppb.New(time.Time{}), + }}, + AccessPairs: []*frontierv1beta1.ListServiceUserProjectsResponse_AccessPair{ + {ProjectId: "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71", Permissions: []string{"get"}}, + {ProjectId: "c7772c63-fca4-4c7c-bf93-c8f85115de4b", Permissions: []string{"get"}}, + }, + }), + wantErr: nil, + errCode: connect.Code(0), + }, } for _, tt := range tests { diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index 45ff7e43b..5a1298e58 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -3,13 +3,12 @@ package v1beta1connect import ( "context" "fmt" + "log/slog" "net/mail" "strings" "connectrpc.com/connect" - "log/slog" - "github.com/pkg/errors" "github.com/raystack/frontier/core/audit" "github.com/raystack/frontier/core/authenticate" @@ -17,7 +16,6 @@ import ( "github.com/raystack/frontier/core/membership" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/project" - "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/internal/store/postgres" @@ -553,20 +551,19 @@ func (h *ConnectHandler) ListCurrentUserGroups(ctx context.Context, request *con "org_id", request.Msg.GetOrgId()) return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + permsByGroup := map[string][]string{} + groupOrder := make([]string, 0, len(groupsList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByGroup[resID]; !seen { + groupOrder = append(groupOrder, resID) + } + permsByGroup[resID] = append(permsByGroup[resID], p.Relation.RelationName) + } + for _, resID := range groupOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ GroupId: resID, - Permissions: permissions, + Permissions: permsByGroup[resID], }) } } @@ -876,10 +873,11 @@ func (h *ConnectHandler) ListProjectsByUser(ctx context.Context, request *connec "user_id", userID) switch { + case errors.Is(err, project.ErrInvalidUUID), + errors.Is(err, project.ErrInvalidPrincipalType): + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) case errors.Is(err, user.ErrNotExist): return nil, connect.NewError(connect.CodeNotFound, ErrNotFound) - case errors.Is(err, user.ErrInvalidUUID): - return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) default: errorLogger.LogUnexpectedError(ctx, request, "ListProjectsByUser", err, "user_id", userID) @@ -949,20 +947,22 @@ func (h *ConnectHandler) ListProjectsByCurrentUser(ctx context.Context, request "org_id", request.Msg.GetOrgId()) return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } - for _, successCheck := range successCheckPairs { - resID := successCheck.Relation.Object.ID - - // find all permission checks on same resource - pairsForCurrentGroup := utils.Filter(successCheckPairs, func(pair relation.CheckPair) bool { - return pair.Relation.Object.ID == resID - }) - // fetch permissions - permissions := utils.Map(pairsForCurrentGroup, func(pair relation.CheckPair) string { - return pair.Relation.RelationName - }) + // Group permissions by project id, emit one access pair per project in + // first-seen order. successCheckPairs is unique by (resID, permName) so + // no per-permission dedup is needed here. + permsByProject := map[string][]string{} + projectOrder := make([]string, 0, len(projList)) + for _, p := range successCheckPairs { + resID := p.Relation.Object.ID + if _, seen := permsByProject[resID]; !seen { + projectOrder = append(projectOrder, resID) + } + permsByProject[resID] = append(permsByProject[resID], p.Relation.RelationName) + } + for _, resID := range projectOrder { accessPairsPb = append(accessPairsPb, &frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ ProjectId: resID, - Permissions: permissions, + Permissions: permsByProject[resID], }) } } diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index 73e475170..e5801b888 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -10,7 +10,10 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/group" "github.com/raystack/frontier/core/organization" + "github.com/raystack/frontier/core/permission" "github.com/raystack/frontier/core/project" + "github.com/raystack/frontier/core/relation" + "github.com/raystack/frontier/core/resource" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/internal/api/v1beta1connect/mocks" @@ -1209,6 +1212,110 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { } } +func TestConnectHandler_ListCurrentUserGroups_AccessPairs(t *testing.T) { + const ( + groupA = "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71" + groupB = "c7772c63-fca4-4c7c-bf93-c8f85115de4b" + ) + principal := authenticate.Principal{ + ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", + Type: schema.UserPrincipal, + User: &user.User{ID: "9f256f86-31a3-11ec-8d3d-0242ac130003"}, + } + + resolvedPermission := func(name string) permission.Permission { + return permission.Permission{Name: name, NamespaceID: schema.GroupNamespace} + } + + tests := []struct { + title string + withPermissions []string + setup func(*mocks.PermissionService, *mocks.ResourceService) + wantAccessPairs []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair + }{ + { + title: "emits one access pair per group when multiple permissions succeed", + withPermissions: []string{"update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/group:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/group:delete").Return(resolvedPermission("delete"), nil) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ + {GroupId: groupA, Permissions: []string{"update", "delete"}}, + {GroupId: groupB, Permissions: []string{"update", "delete"}}, + }, + }, + { + title: "drops unknown permissions and returns access pairs only for the known ones", + withPermissions: []string{"update", "bogus"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/group:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/group:bogus").Return(permission.Permission{}, permission.ErrNotExist) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, Permission: "update"}, + {Object: relation.Object{ID: groupB, Namespace: schema.GroupNamespace}, Permission: "update"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: groupA, Namespace: schema.GroupNamespace}, RelationName: "update"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair{ + {GroupId: groupA, Permissions: []string{"update"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockGroupSrv := new(mocks.GroupService) + mockAuthnSrv := new(mocks.AuthnService) + mockPermissionSrv := new(mocks.PermissionService) + mockResourceSrv := new(mocks.ResourceService) + + mockAuthnSrv.EXPECT().GetPrincipal(mock.Anything).Return(principal, nil) + mockGroupSrv.EXPECT().List(mock.Anything, mock.MatchedBy(func(f group.Filter) bool { + return f.Principal != nil && *f.Principal == principal + })).Return([]group.Group{ + {ID: groupA, OrganizationID: "org-1"}, + {ID: groupB, OrganizationID: "org-1"}, + }, nil) + if tt.setup != nil { + tt.setup(mockPermissionSrv, mockResourceSrv) + } + + handler := &ConnectHandler{ + groupService: mockGroupSrv, + authnService: mockAuthnSrv, + permissionService: mockPermissionSrv, + resourceService: mockResourceSrv, + } + + req := connect.NewRequest(&frontierv1beta1.ListCurrentUserGroupsRequest{ + WithPermissions: tt.withPermissions, + }) + resp, err := handler.ListCurrentUserGroups(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.wantAccessPairs, resp.Msg.GetAccessPairs()) + + mockGroupSrv.AssertExpectations(t) + mockAuthnSrv.AssertExpectations(t) + mockPermissionSrv.AssertExpectations(t) + mockResourceSrv.AssertExpectations(t) + }) + } +} + func TestConnectHandler_ListOrganizationsByUser(t *testing.T) { userID := uuid.New().String() @@ -1634,14 +1741,23 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { err: connect.CodeNotFound, }, { - title: "should return bad request error for invalid user ID", + title: "should return bad request error when project service returns ErrInvalidUUID", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}}).Return(nil, user.ErrInvalidUUID) + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}}).Return(nil, project.ErrInvalidUUID) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "invalid-id"}, want: nil, err: connect.CodeInvalidArgument, }, + { + title: "should return bad request error when project service returns ErrInvalidPrincipalType", + setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { + ps.EXPECT().List(mock.Anything, project.Filter{Principal: &authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}}).Return(nil, project.ErrInvalidPrincipalType) + }, + req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, + want: nil, + err: connect.CodeInvalidArgument, + }, { title: "should return internal error for project service failure", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { @@ -1914,3 +2030,147 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { }) } } + +func TestConnectHandler_ListProjectsByCurrentUser_AccessPairs(t *testing.T) { + const ( + projA = "ab657ae7-8c9e-45eb-9862-dd9ceb6d5c71" + projB = "c7772c63-fca4-4c7c-bf93-c8f85115de4b" + ) + principal := authenticate.Principal{ + ID: "9f256f86-31a3-11ec-8d3d-0242ac130003", + Type: schema.UserPrincipal, + User: &user.User{ID: "9f256f86-31a3-11ec-8d3d-0242ac130003"}, + } + + resolvedPermission := func(name string) permission.Permission { + return permission.Permission{Name: name, NamespaceID: schema.ProjectNamespace} + } + + tests := []struct { + title string + withPermissions []string + setup func(*mocks.PermissionService, *mocks.ResourceService) + wantAccessPairs []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair + wantErr connect.Code + }{ + { + title: "emits one access pair per project when multiple permissions succeed", + withPermissions: []string{"update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/project:delete").Return(resolvedPermission("delete"), nil) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update", "delete"}}, + {ProjectId: projB, Permissions: []string{"update", "delete"}}, + }, + }, + { + title: "drops unknown permissions and returns access pairs only for the known ones", + withPermissions: []string{"update", "bogus"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil) + perm.EXPECT().Get(mock.Anything, "app/project:bogus").Return(permission.Permission{}, permission.ErrNotExist) + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update"}}, + }, + }, + { + title: "returns empty access pairs when every requested permission is unknown", + withPermissions: []string{"bogus1", "bogus2"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:bogus1").Return(permission.Permission{}, permission.ErrNotExist) + perm.EXPECT().Get(mock.Anything, "app/project:bogus2").Return(permission.Permission{}, permission.ErrNotExist) + // resourceService.BatchCheck must NOT be called. + }, + wantAccessPairs: nil, + }, + { + title: "deduplicates repeated permission inputs", + withPermissions: []string{"update", "update", "delete"}, + setup: func(perm *mocks.PermissionService, res *mocks.ResourceService) { + perm.EXPECT().Get(mock.Anything, "app/project:update").Return(resolvedPermission("update"), nil).Times(2) + perm.EXPECT().Get(mock.Anything, "app/project:delete").Return(resolvedPermission("delete"), nil) + // Each (project, permission) appears exactly once even though "update" was requested twice. + res.EXPECT().BatchCheck(mock.Anything, []resource.Check{ + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "update"}, + {Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, Permission: "delete"}, + }).Return([]relation.CheckPair{ + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projA, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "update"}, Status: true}, + {Relation: relation.Relation{Object: relation.Object{ID: projB, Namespace: schema.ProjectNamespace}, RelationName: "delete"}, Status: true}, + }, nil) + }, + wantAccessPairs: []*frontierv1beta1.ListProjectsByCurrentUserResponse_AccessPair{ + {ProjectId: projA, Permissions: []string{"update", "delete"}}, + {ProjectId: projB, Permissions: []string{"update", "delete"}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + mockProjectSrv := new(mocks.ProjectService) + mockAuthnSrv := new(mocks.AuthnService) + mockPermissionSrv := new(mocks.PermissionService) + mockResourceSrv := new(mocks.ResourceService) + + mockAuthnSrv.EXPECT().GetPrincipal(mock.Anything).Return(principal, nil) + mockProjectSrv.EXPECT().List(mock.Anything, mock.MatchedBy(func(f project.Filter) bool { + return f.Principal != nil && *f.Principal == principal + })).Return([]project.Project{ + {ID: projA, Organization: organization.Organization{ID: "org-1"}}, + {ID: projB, Organization: organization.Organization{ID: "org-1"}}, + }, nil) + if tt.setup != nil { + tt.setup(mockPermissionSrv, mockResourceSrv) + } + + handler := &ConnectHandler{ + projectService: mockProjectSrv, + authnService: mockAuthnSrv, + permissionService: mockPermissionSrv, + resourceService: mockResourceSrv, + } + + req := connect.NewRequest(&frontierv1beta1.ListProjectsByCurrentUserRequest{ + WithPermissions: tt.withPermissions, + }) + resp, err := handler.ListProjectsByCurrentUser(context.Background(), req) + if tt.wantErr != connect.Code(0) { + assert.Nil(t, resp) + assert.Equal(t, tt.wantErr, connect.CodeOf(err)) + return + } + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tt.wantAccessPairs, resp.Msg.GetAccessPairs()) + + mockProjectSrv.AssertExpectations(t) + mockAuthnSrv.AssertExpectations(t) + mockPermissionSrv.AssertExpectations(t) + mockResourceSrv.AssertExpectations(t) + }) + } +}