From 1adfbd9955cc9c164e940eaa828679ffa77ab068 Mon Sep 17 00:00:00 2001 From: Charles Muchogo <48381664+Muchogoc@users.noreply.github.com> Date: Fri, 30 Jul 2021 14:13:42 +0300 Subject: [PATCH] feat: assigning and revoking roles for users (#31) --- .../infrastructure/database/fb/firebase.go | 64 ++- .../database/fb/firebase_integration_test.go | 312 +++++++++-- .../database/fb/firebase_test.go | 114 +++- .../presentation/graph/generated/generated.go | 225 +++++++- .../presentation/graph/inputs.graphql | 2 +- .../presentation/graph/profile.graphql | 4 + .../presentation/graph/profile.resolvers.go | 18 + .../presentation/graph/types.graphql | 3 +- pkg/onboarding/repository/mock/onboarding.go | 6 + pkg/onboarding/repository/onboarding.go | 1 + pkg/onboarding/usecases/roles.go | 112 ++++ pkg/onboarding/usecases/roles_unit_test.go | 500 ++++++++++++++++++ 12 files changed, 1293 insertions(+), 68 deletions(-) diff --git a/pkg/onboarding/infrastructure/database/fb/firebase.go b/pkg/onboarding/infrastructure/database/fb/firebase.go index 3c47c12f..2fbaddb3 100644 --- a/pkg/onboarding/infrastructure/database/fb/firebase.go +++ b/pkg/onboarding/infrastructure/database/fb/firebase.go @@ -1014,6 +1014,54 @@ func (fr *Repository) UpdatePrimaryPhoneNumber( return nil } +// UpdateUserRoleIDs updates the roles for a user +func (fr Repository) UpdateUserRoleIDs(ctx context.Context, id string, roleIDs []string) error { + ctx, span := tracer.Start(ctx, "UpdateUserRoleIDs") + defer span.End() + + profile, err := fr.GetUserProfileByID(ctx, id, false) + if err != nil { + utils.RecordSpanError(span, err) + return err + } + + // Add the roles + profile.Roles = roleIDs + + query := &GetAllQuery{ + CollectionName: fr.GetUserProfileCollectionName(), + FieldName: "id", + Value: profile.ID, + Operator: "==", + } + + docs, err := fr.FirestoreClient.GetAll(ctx, query) + if err != nil { + utils.RecordSpanError(span, err) + return exceptions.InternalServerError(err) + } + + if len(docs) == 0 { + return exceptions.InternalServerError(fmt.Errorf("user profile not found")) + } + + updateCommand := &UpdateCommand{ + CollectionName: fr.GetUserProfileCollectionName(), + ID: docs[0].Ref.ID, + Data: profile, + } + + err = fr.FirestoreClient.Update(ctx, updateCommand) + if err != nil { + utils.RecordSpanError(span, err) + return exceptions.InternalServerError( + fmt.Errorf("unable to update user profile primary email address: %v", err), + ) + } + + return nil +} + // UpdatePrimaryEmailAddress the primary email addresses of the profile that matches the id // this method should be called after asserting the emailAddress is unique and not associated with another userProfile func (fr *Repository) UpdatePrimaryEmailAddress( @@ -3379,7 +3427,10 @@ func (fr *Repository) SaveCoverAutolinkingEvents( } // AddAITSessionDetails saves diallers session details in the database -func (fr *Repository) AddAITSessionDetails(ctx context.Context, input *dto.SessionDetails) (*domain.USSDLeadDetails, error) { +func (fr *Repository) AddAITSessionDetails( + ctx context.Context, + input *dto.SessionDetails, +) (*domain.USSDLeadDetails, error) { ctx, span := tracer.Start(ctx, "AddAITSessionDetails") defer span.End() @@ -3585,7 +3636,10 @@ func (fr *Repository) UpdateSessionPIN( } // GetAITDetails retrieves session details from the database -func (fr *Repository) GetAITDetails(ctx context.Context, phoneNumber string) (*domain.USSDLeadDetails, error) { +func (fr *Repository) GetAITDetails( + ctx context.Context, + phoneNumber string, +) (*domain.USSDLeadDetails, error) { ctx, span := tracer.Start(ctx, "GetAITDetails") defer span.End() @@ -3624,7 +3678,11 @@ func (fr *Repository) GetAITDetails(ctx context.Context, phoneNumber string) (*d } // UpdateAITSessionDetails updates session details using phone number -func (fr *Repository) UpdateAITSessionDetails(ctx context.Context, phoneNumber string, contactLead *domain.USSDLeadDetails) error { +func (fr *Repository) UpdateAITSessionDetails( + ctx context.Context, + phoneNumber string, + contactLead *domain.USSDLeadDetails, +) error { ctx, span := tracer.Start(ctx, "UpdateAITSessionDetails") defer span.End() diff --git a/pkg/onboarding/infrastructure/database/fb/firebase_integration_test.go b/pkg/onboarding/infrastructure/database/fb/firebase_integration_test.go index b779d785..875d1ca3 100644 --- a/pkg/onboarding/infrastructure/database/fb/firebase_integration_test.go +++ b/pkg/onboarding/infrastructure/database/fb/firebase_integration_test.go @@ -93,6 +93,7 @@ func TestMain(m *testing.M) { r.GetProfileNudgesCollectionName(), r.GetSMSCollectionName(), r.GetUSSDDataCollectionName(), + r.GetRolesCollectionName(), } for _, collection := range collections { ref := fsc.Collection(collection) @@ -304,7 +305,10 @@ func TestRemoveKYCProcessingRequest(t *testing.T) { assert.Nil(t, err) // clean up - _ = s.Signup.RemoveUserByPhoneNumber(context.Background(), interserviceclient.TestUserPhoneNumber) + _ = s.Signup.RemoveUserByPhoneNumber( + context.Background(), + interserviceclient.TestUserPhoneNumber, + ) ctx, auth, err := GetTestAuthenticatedContext(t) assert.Nil(t, err) @@ -390,7 +394,10 @@ func TestPurgeUserByPhoneNumber(t *testing.T) { s, err := InitializeTestService(context.Background()) assert.Nil(t, err) // clean up - _ = s.Signup.RemoveUserByPhoneNumber(context.Background(), interserviceclient.TestUserPhoneNumber) + _ = s.Signup.RemoveUserByPhoneNumber( + context.Background(), + interserviceclient.TestUserPhoneNumber, + ) ctx, auth, err := GetTestAuthenticatedContext(t) assert.Nil(t, err) assert.NotNil(t, auth) @@ -409,7 +416,11 @@ func TestPurgeUserByPhoneNumber(t *testing.T) { assert.Equal(t, interserviceclient.TestUserPhoneNumber, *profile.PrimaryPhone) // fetch the same profile but now using the primary phone number - profile, err = fr.GetUserProfileByPrimaryPhoneNumber(ctx, interserviceclient.TestUserPhoneNumber, false) + profile, err = fr.GetUserProfileByPrimaryPhoneNumber( + ctx, + interserviceclient.TestUserPhoneNumber, + false, + ) assert.Nil(t, err) assert.NotNil(t, profile) assert.Equal(t, interserviceclient.TestUserPhoneNumber, *profile.PrimaryPhone) @@ -424,7 +435,11 @@ func TestPurgeUserByPhoneNumber(t *testing.T) { // create an invalid user profile fakeUID := uuid.New().String() - invalidpr1, err := fr.CreateUserProfile(context.Background(), interserviceclient.TestUserPhoneNumber, fakeUID) + invalidpr1, err := fr.CreateUserProfile( + context.Background(), + interserviceclient.TestUserPhoneNumber, + fakeUID, + ) assert.Nil(t, err) assert.NotNil(t, invalidpr1) @@ -666,7 +681,8 @@ func TestRepository_GetCustomerOrSupplierProfileByProfileID(t *testing.T) { tt.args.profileID, ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetCustomerOrSupplierProfileByProfileID() error = %v, wantErr %v", + t.Errorf( + "Repository.GetCustomerOrSupplierProfileByProfileID() error = %v, wantErr %v", err, tt.wantErr, ) @@ -729,7 +745,11 @@ func TestRepository_GetCustomerProfileByID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { customerProfile, err := fr.GetCustomerProfileByID(tt.args.ctx, tt.args.id) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetCustomerProfileByID() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetCustomerProfileByID() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if serverutils.IsDebug() { @@ -762,7 +782,11 @@ func TestRepository_ExchangeRefreshTokenForIDToken(t *testing.T) { return } - user, err := fr.GenerateAuthCredentials(ctx, interserviceclient.TestUserPhoneNumber, userProfile) + user, err := fr.GenerateAuthCredentials( + ctx, + interserviceclient.TestUserPhoneNumber, + userProfile, + ) if err != nil { t.Errorf("failed to generate auth credentials: %v", err) return @@ -808,7 +832,11 @@ func TestRepository_ExchangeRefreshTokenForIDToken(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.ExchangeRefreshTokenForIDToken(tt.args.ctx, tt.args.refreshToken) if (err != nil) != tt.wantErr { - t.Errorf("Repository.ExchangeRefreshTokenForIDToken() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.ExchangeRefreshTokenForIDToken() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } @@ -820,7 +848,11 @@ func TestRepository_ExchangeRefreshTokenForIDToken(t *testing.T) { return } if auth.UID != tt.want.UID { - t.Errorf("Repository.ExchangeRefreshTokenForIDToken() = %v, want %v", got.UID, tt.want.UID) + t.Errorf( + "Repository.ExchangeRefreshTokenForIDToken() = %v, want %v", + got.UID, + tt.want.UID, + ) } } }) @@ -874,7 +906,11 @@ func TestRepository_GetUserProfileByPhoneNumber(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.GetUserProfileByPhoneNumber(tt.args.ctx, tt.args.phoneNumber, false) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetUserProfileByPhoneNumber() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetUserProfileByPhoneNumber() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && got == nil { @@ -938,9 +974,17 @@ func TestRepository_GetUserProfileByPrimaryPhoneNumber(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fr.GetUserProfileByPrimaryPhoneNumber(tt.args.ctx, tt.args.phoneNumber, false) + got, err := fr.GetUserProfileByPrimaryPhoneNumber( + tt.args.ctx, + tt.args.phoneNumber, + false, + ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetUserProfileByPrimaryPhoneNumber() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetUserProfileByPrimaryPhoneNumber() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && got == nil { @@ -1003,7 +1047,11 @@ func TestRepository_GetSupplierProfileByProfileID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.GetSupplierProfileByProfileID(tt.args.ctx, tt.args.profileID) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetSupplierProfileByProfileID() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetSupplierProfileByProfileID() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -1076,7 +1124,11 @@ func TestRepository_GetSupplierProfileByID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.GetSupplierProfileByID(tt.args.ctx, tt.args.id) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetSupplierProfileByID() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetSupplierProfileByID() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -1282,7 +1334,11 @@ func TestRepository_CheckIfPhoneNumberExists(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.CheckIfPhoneNumberExists(tt.args.ctx, tt.args.phoneNumber) if (err != nil) != tt.wantErr { - t.Errorf("Repository.CheckIfPhoneNumberExists() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.CheckIfPhoneNumberExists() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if got != tt.want { @@ -1362,7 +1418,11 @@ func TestRepository_CheckIfUsernameExists(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.CheckIfUsernameExists(tt.args.ctx, tt.args.userName) if (err != nil) != tt.wantErr { - t.Errorf("Repository.CheckIfUsernameExists() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.CheckIfUsernameExists() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if got != tt.want { @@ -1698,9 +1758,17 @@ func TestRepository_ActivateSupplierProfile(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - supp, err := fr.ActivateSupplierProfile(tt.args.ctx, tt.args.profileID, tt.args.supplier) + supp, err := fr.ActivateSupplierProfile( + tt.args.ctx, + tt.args.profileID, + tt.args.supplier, + ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.ActivateSupplierProfile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.ActivateSupplierProfile() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if supp != nil { @@ -1799,7 +1867,12 @@ func TestRepository_AddPartnerType(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := fr.AddPartnerType(tt.args.ctx, tt.args.profileID, tt.args.name, tt.args.partnerType) + got, err := fr.AddPartnerType( + tt.args.ctx, + tt.args.profileID, + tt.args.name, + tt.args.partnerType, + ) if (err != nil) != tt.wantErr { t.Errorf("Repository.AddPartnerType() error = %v, wantErr %v", err, tt.wantErr) return @@ -1882,7 +1955,11 @@ func TestRepository_RecordPostVisitSurvey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.RecordPostVisitSurvey(tt.args.ctx, tt.args.input, tt.args.UID); (err != nil) != tt.wantErr { - t.Errorf("Repository.RecordPostVisitSurvey() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.RecordPostVisitSurvey() error = %v, wantErr %v", + err, + tt.wantErr, + ) } }) } @@ -1994,18 +2071,24 @@ func TestRepository_UpdateVerifiedUIDS(t *testing.T) { { name: "Happy Case - Successfully update profile UIDs", args: args{ - ctx: ctx, - id: user.ID, - uids: []string{"f4f39af7-5b64-4c2f-91bd-42b3af315a4e", "5d46d3bd-a482-4787-9b87-3c94510c8b53"}, + ctx: ctx, + id: user.ID, + uids: []string{ + "f4f39af7-5b64-4c2f-91bd-42b3af315a4e", + "5d46d3bd-a482-4787-9b87-3c94510c8b53", + }, }, wantErr: false, }, { name: "Sad Case - Invalid ID", args: args{ - ctx: ctx, - id: "invalidid", - uids: []string{"f4f39af7-5b64-4c2f-91bd-42b3af315a4e", "5d46d3bd-a482-4787-9b87-3c94510c8b53"}, + ctx: ctx, + id: "invalidid", + uids: []string{ + "f4f39af7-5b64-4c2f-91bd-42b3af315a4e", + "5d46d3bd-a482-4787-9b87-3c94510c8b53", + }, }, wantErr: true, }, @@ -2114,7 +2197,11 @@ func TestRepository_UpdateVerifiedIdentifiers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.UpdateVerifiedIdentifiers(tt.args.ctx, tt.args.id, tt.args.identifiers); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateVerifiedIdentifiers() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateVerifiedIdentifiers() error = %v, wantErr %v", + err, + tt.wantErr, + ) } }) } @@ -2253,7 +2340,11 @@ func TestRepository_UpdateSecondaryEmailAddresses(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := fr.UpdateSecondaryEmailAddresses(tt.args.ctx, tt.args.id, tt.args.emailAddresses) if (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateSecondaryEmailAddresses() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateSecondaryEmailAddresses() error = %v, wantErr %v", + err, + tt.wantErr, + ) } }) } @@ -2321,7 +2412,11 @@ func TestRepository_UpdateSupplierProfile(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := fr.UpdateSupplierProfile(tt.args.ctx, *tt.args.data.ProfileID, tt.args.data) if (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateSupplierProfile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateSupplierProfile() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } }) @@ -2382,7 +2477,11 @@ func TestRepositoryFetchKYCProcessingRequests(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.FetchKYCProcessingRequests(tt.args.ctx) if (err != nil) != tt.wantErr { - t.Errorf("Repository.FetchKYCProcessingRequests() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.FetchKYCProcessingRequests() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -2449,7 +2548,11 @@ func TestRepository_UpdatePrimaryEmailAddress(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.UpdatePrimaryEmailAddress(tt.args.ctx, tt.args.id, tt.args.emailAddress); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdatePrimaryEmailAddress() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdatePrimaryEmailAddress() error = %v, wantErr %v", + err, + tt.wantErr, + ) } if !tt.wantErr { @@ -2518,7 +2621,11 @@ func TestRepository_UpdatePrimaryPhoneNumber(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.UpdatePrimaryPhoneNumber(tt.args.ctx, tt.args.id, tt.args.phoneNumber); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdatePrimaryPhoneNumber() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdatePrimaryPhoneNumber() error = %v, wantErr %v", + err, + tt.wantErr, + ) } if !tt.wantErr { @@ -2588,7 +2695,11 @@ func TestRepository_UpdateSecondaryPhoneNumbers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.UpdateSecondaryPhoneNumbers(tt.args.ctx, tt.args.id, tt.args.phoneNumbers); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateSecondaryPhoneNumbers() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateSecondaryPhoneNumbers() error = %v, wantErr %v", + err, + tt.wantErr, + ) } if !tt.wantErr { @@ -2789,7 +2900,11 @@ func TestRepositoryFetchKYCProcessingRequestByID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := fr.FetchKYCProcessingRequestByID(tt.args.ctx, tt.args.id) if (err != nil) != tt.wantErr { - t.Errorf("Repository.FetchKYCProcessingRequestByID() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.FetchKYCProcessingRequestByID() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if tt.wantErr { @@ -2878,7 +2993,11 @@ func TestRepositoryUpdateKYCProcessingRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := fr.UpdateKYCProcessingRequest(tt.args.ctx, tt.args.kycRequest); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateKYCProcessingRequest() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateKYCProcessingRequest() error = %v, wantErr %v", + err, + tt.wantErr, + ) } }) if tt.wantErr { @@ -2951,7 +3070,11 @@ func TestRepositoryGenerateAuthCredentialsForAnonymousUser(t *testing.T) { t.Run(tt.name, func(t *testing.T) { authResponse, err := fr.GenerateAuthCredentialsForAnonymousUser(tt.args.ctx) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GenerateAuthCredentialsForAnonymousUser() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GenerateAuthCredentialsForAnonymousUser() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } @@ -3063,9 +3186,17 @@ func TestRepositoryGenerateAuthCredentials(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - authResponse, err := fr.GenerateAuthCredentials(tt.args.ctx, tt.args.phone, tt.args.profile) + authResponse, err := fr.GenerateAuthCredentials( + tt.args.ctx, + tt.args.phone, + tt.args.profile, + ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GenerateAuthCredentials() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GenerateAuthCredentials() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr { @@ -3391,7 +3522,11 @@ func TestGetNHIFDetailsByProfileID(t *testing.T) { t.Run(tt.name, func(t *testing.T) { nhif, err := fr.GetNHIFDetailsByProfileID(tt.args.ctx, tt.args.profileID) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetNHIFDetailsByProfileID() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetNHIFDetailsByProfileID() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if tt.wantErr && nhif != nil { @@ -3463,7 +3598,11 @@ func TestUpdateCustomerProfile(t *testing.T) { t.Run(tt.name, func(t *testing.T) { customer, err := fr.UpdateCustomerProfile(tt.args.ctx, tt.args.profileID, tt.args.cus) if (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateCustomerProfile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateCustomerProfile() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if customer != nil { @@ -3544,7 +3683,11 @@ func TestRepository_PersistIncomingSMSData(t *testing.T) { t.Run(tt.name, func(t *testing.T) { err := firestoreDB.PersistIncomingSMSData(tt.args.ctx, &tt.args.input) if (err != nil) != tt.wantErr { - t.Errorf("Repository.PersistIncomingSMSData() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.PersistIncomingSMSData() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && err != nil { @@ -3634,7 +3777,11 @@ func TestRepository_AddAITSessionDetails(t *testing.T) { got, err := firestoreDB.AddAITSessionDetails(tt.args.ctx, tt.args.input) if (err != nil) != tt.wantErr { - t.Errorf("Repository.AddAITSessionDetails() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.AddAITSessionDetails() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if tt.wantErr && got != nil { @@ -3691,7 +3838,11 @@ func TestRepository_GetAITSessionDetailss(t *testing.T) { t.Run(tt.name, func(t *testing.T) { _, err := firestoreDB.GetAITSessionDetails(tt.args.ctx, tt.args.sessionID) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetAITSessionDetails() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetAITSessionDetails() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && err != nil { @@ -3856,7 +4007,11 @@ func TestRepository_UpdateSessionLevel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := firestoreDB.UpdateSessionLevel(tt.args.ctx, tt.args.sessionID, tt.args.level) + got, err := firestoreDB.UpdateSessionLevel( + tt.args.ctx, + tt.args.sessionID, + tt.args.level, + ) if (err != nil) != tt.wantErr { t.Errorf("Repository.UpdateSessionLevel() error = %v, wantErr %v", err, tt.wantErr) return @@ -3999,11 +4154,19 @@ func TestRepository_SaveCoverAutolinkingEvents_Integration_Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := firestoreDB.SaveCoverAutolinkingEvents(tt.args.ctx, tt.args.input) if (err != nil) != tt.wantErr { - t.Errorf("Repository.SaveCoverAutolinkingEvents() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.SaveCoverAutolinkingEvents() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && got == nil { - t.Errorf("Repository.SaveCoverAutolinkingEvents() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.SaveCoverAutolinkingEvents() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } }) @@ -4154,7 +4317,62 @@ func TestRepository_UpdateAITSessionDetails_Integration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := firestoreDB.UpdateAITSessionDetails(tt.args.ctx, tt.args.phoneNumber, tt.args.contactLead); (err != nil) != tt.wantErr { - t.Errorf("Repository.UpdateAITSessionDetails() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.UpdateAITSessionDetails() error = %v, wantErr %v", + err, + tt.wantErr, + ) + } + }) + } +} + +func TestRepository_UpdateUserRoleIDs_Integration(t *testing.T) { + ctx, token, err := GetTestAuthenticatedContext(t) + if err != nil { + t.Errorf("failed to get test authenticated context: %v", err) + return + } + + fsc, fbc := InitializeTestFirebaseClient(ctx) + if fsc == nil { + log.Panicf("failed to initialize test FireStore client") + } + if fbc == nil { + log.Panicf("failed to initialize test FireBase client") + } + firestoreExtension := fb.NewFirestoreClientExtension(fsc) + fr := fb.NewFirebaseRepository(firestoreExtension, fbc) + + userProfile, err := fr.GetUserProfileByUID(ctx, token.UID, false) + if err != nil { + t.Errorf("failed to get a user profile") + return + } + type args struct { + ctx context.Context + id string + roleIDs []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "pass:success update user profile IDs", + args: args{ + ctx: ctx, + id: userProfile.ID, + roleIDs: []string{uuid.NewString()}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := fr.UpdateUserRoleIDs(tt.args.ctx, tt.args.id, tt.args.roleIDs); (err != nil) != tt.wantErr { + t.Errorf("Repository.UpdateUserRoleIDs() error = %v, wantErr %v", err, tt.wantErr) } }) } diff --git a/pkg/onboarding/infrastructure/database/fb/firebase_test.go b/pkg/onboarding/infrastructure/database/fb/firebase_test.go index c93a7dee..4c380942 100644 --- a/pkg/onboarding/infrastructure/database/fb/firebase_test.go +++ b/pkg/onboarding/infrastructure/database/fb/firebase_test.go @@ -279,7 +279,9 @@ func TestRepository_AddUserAsExperimentParticipant(t *testing.T) { if tt.name == "invalid:throws_internal_server_error_while_checking_existence" { fakeFireStoreClientExt.GetAllFn = func(ctx context.Context, query *fb.GetAllQuery) ([]*firestore.DocumentSnapshot, error) { - return nil, exceptions.InternalServerError(fmt.Errorf("unable to parse user profile as firebase snapshot")) + return nil, exceptions.InternalServerError( + fmt.Errorf("unable to parse user profile as firebase snapshot"), + ) } } @@ -290,7 +292,9 @@ func TestRepository_AddUserAsExperimentParticipant(t *testing.T) { } fakeFireStoreClientExt.CreateFn = func(ctx context.Context, command *fb.CreateCommand) (*firestore.DocumentRef, error) { - return nil, exceptions.InternalServerError(fmt.Errorf("unable to add user profile of ID in experiment_participant")) + return nil, exceptions.InternalServerError( + fmt.Errorf("unable to add user profile of ID in experiment_participant"), + ) } } @@ -364,7 +368,11 @@ func TestRepository_RemoveUserAsExperimentParticipant(t *testing.T) { } if tt.name == "invalid:throws_internal_server_error_while_removing" { fakeFireStoreClientExt.DeleteFn = func(ctx context.Context, command *fb.DeleteCommand) error { - return exceptions.InternalServerError(fmt.Errorf("unable to remove user profile of ID from experiment_participant")) + return exceptions.InternalServerError( + fmt.Errorf( + "unable to remove user profile of ID from experiment_participant", + ), + ) } } @@ -686,7 +694,9 @@ func TestRepository_UpdateFavNavActions(t *testing.T) { fakeFireStoreClientExt.GetAllFn = func(ctx context.Context, query *fb.GetAllQuery) ([]*firestore.DocumentSnapshot, error) { docs := []*firestore.DocumentSnapshot{ { - Ref: &firestore.DocumentRef{ID: "c9d62c7e-93e5-44a6-b503-6fc159c1782f"}, + Ref: &firestore.DocumentRef{ + ID: "c9d62c7e-93e5-44a6-b503-6fc159c1782f", + }, CreateTime: time.Time{}, UpdateTime: time.Time{}, ReadTime: time.Time{}, @@ -808,9 +818,17 @@ func TestRepository_CreateDetailedSupplierProfile(t *testing.T) { } } - got, err := repo.CreateDetailedSupplierProfile(tt.args.ctx, tt.args.profileID, tt.args.supplier) + got, err := repo.CreateDetailedSupplierProfile( + tt.args.ctx, + tt.args.profileID, + tt.args.supplier, + ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.CreateDetailedSupplierProfile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.CreateDetailedSupplierProfile() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && got == nil { @@ -1051,9 +1069,17 @@ func TestRepository_CreateDetailedUserProfile(t *testing.T) { } } - got, err := repo.CreateDetailedUserProfile(tt.args.ctx, tt.args.phoneNumber, tt.args.profile) + got, err := repo.CreateDetailedUserProfile( + tt.args.ctx, + tt.args.phoneNumber, + tt.args.profile, + ) if (err != nil) != tt.wantErr { - t.Errorf("Repository.CreateDetailedUserProfile() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.CreateDetailedUserProfile() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !tt.wantErr && got == nil { @@ -1115,7 +1141,11 @@ func TestRepository_ListAgentUserProfiles(t *testing.T) { got, err := repo.ListUserProfiles(tt.args.ctx, tt.args.role) if (err != nil) != tt.wantErr { - t.Errorf("Repository.ListAgentUserProfiles() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.ListAgentUserProfiles() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -1200,7 +1230,11 @@ func TestRepository_AddAITSessionDetails_Unittest(t *testing.T) { got, err := repo.AddAITSessionDetails(tt.args.ctx, tt.args.input) if (err != nil) != tt.wantErr { - t.Errorf("Repository.AddAITSessionDetails() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.AddAITSessionDetails() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -1263,7 +1297,11 @@ func TestRepository_GetAITSessionDetails_Unittests(t *testing.T) { got, err := repo.GetAITSessionDetails(tt.args.ctx, tt.args.sessionID) if (err != nil) != tt.wantErr { - t.Errorf("Repository.GetAITSessionDetails() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.GetAITSessionDetails() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if !reflect.DeepEqual(got, tt.want) { @@ -1713,7 +1751,11 @@ func TestRepository_CheckIfRoleNameExists(t *testing.T) { got, err := repo.CheckIfRoleNameExists(tt.args.ctx, tt.args.name) if (err != nil) != tt.wantErr { - t.Errorf("Repository.CheckIfRoleNameExists() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf( + "Repository.CheckIfRoleNameExists() error = %v, wantErr %v", + err, + tt.wantErr, + ) return } if got != tt.want { @@ -1775,6 +1817,54 @@ func TestRepository_GetRolesByIDs(t *testing.T) { } } +func TestRepository_UpdateUserRoleIDs(t *testing.T) { + ctx := context.Background() + var fireStoreClientExt fb.FirestoreClientExtension = &fakeFireStoreClientExt + repo := fb.NewFirebaseRepository(fireStoreClientExt, fireBaseClientExt) + + type args struct { + ctx context.Context + id string + roleIDs []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "fail:cannot retrieve user profile", + args: args{ + ctx: ctx, + id: uuid.NewString(), + roleIDs: []string{uuid.NewString()}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.name == "fail:cannot retrieve user profile" { + + fakeFireStoreClientExt.GetAllFn = func(ctx context.Context, query *fb.GetAllQuery) ([]*firestore.DocumentSnapshot, error) { + ref := firestore.DocumentRef{ID: "123"} + docs := []*firestore.DocumentSnapshot{{Ref: &ref}} + return docs, nil + } + + fakeFireStoreClientExt.UpdateFn = func(ctx context.Context, command *fb.UpdateCommand) error { + return nil + } + } + + if err := repo.UpdateUserRoleIDs(tt.args.ctx, tt.args.id, tt.args.roleIDs); (err != nil) != tt.wantErr { + t.Errorf("Repository.UpdateUserRoleIDs() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestRepository_GetAllRoles(t *testing.T) { ctx := context.Background() var fireStoreClientExt fb.FirestoreClientExtension = &fakeFireStoreClientExt diff --git a/pkg/onboarding/presentation/graph/generated/generated.go b/pkg/onboarding/presentation/graph/generated/generated.go index f757ec29..4be37f0d 100644 --- a/pkg/onboarding/presentation/graph/generated/generated.go +++ b/pkg/onboarding/presentation/graph/generated/generated.go @@ -269,6 +269,7 @@ type ComplexityRoot struct { AddPermissionsToRole func(childComplexity int, input dto.RolePermissionInput) int AddSecondaryEmailAddress func(childComplexity int, email []string) int AddSecondaryPhoneNumber func(childComplexity int, phone []string) int + AssignRole func(childComplexity int, userID string, roleID string) int CompleteSignup func(childComplexity int, flavour feedlib.Flavour) int CreateRole func(childComplexity int, input dto.RoleInput) int DeactivateAgent func(childComplexity int, agentID string) int @@ -284,6 +285,7 @@ type ComplexityRoot struct { RetireKYCProcessingRequest func(childComplexity int) int RetireSecondaryEmailAddresses func(childComplexity int, emails []string) int RetireSecondaryPhoneNumbers func(childComplexity int, phones []string) int + RevokeRole func(childComplexity int, userID string, roleID string) int SaveFavoriteNavAction func(childComplexity int, title string) int SetPrimaryEmailAddress func(childComplexity int, email string, otp string) int SetPrimaryPhoneNumber func(childComplexity int, phone string, otp string) int @@ -539,6 +541,7 @@ type ComplexityRoot struct { PrimaryEmailAddress func(childComplexity int) int PrimaryPhone func(childComplexity int) int PushTokens func(childComplexity int) int + Roles func(childComplexity int) int SecondaryEmailAddresses func(childComplexity int) int SecondaryPhoneNumbers func(childComplexity int) int Suspended func(childComplexity int) int @@ -611,6 +614,8 @@ type MutationResolver interface { DeregisterAllMicroservices(ctx context.Context) (bool, error) CreateRole(ctx context.Context, input dto.RoleInput) (*dto.RoleOutput, error) AddPermissionsToRole(ctx context.Context, input dto.RolePermissionInput) (*dto.RoleOutput, error) + AssignRole(ctx context.Context, userID string, roleID string) (bool, error) + RevokeRole(ctx context.Context, userID string, roleID string) (bool, error) } type QueryResolver interface { DummyQuery(ctx context.Context) (*bool, error) @@ -1744,6 +1749,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.AddSecondaryPhoneNumber(childComplexity, args["phone"].([]string)), true + case "Mutation.assignRole": + if e.complexity.Mutation.AssignRole == nil { + break + } + + args, err := ec.field_Mutation_assignRole_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.AssignRole(childComplexity, args["userID"].(string), args["roleID"].(string)), true + case "Mutation.completeSignup": if e.complexity.Mutation.CompleteSignup == nil { break @@ -1914,6 +1931,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.RetireSecondaryPhoneNumbers(childComplexity, args["phones"].([]string)), true + case "Mutation.revokeRole": + if e.complexity.Mutation.RevokeRole == nil { + break + } + + args, err := ec.field_Mutation_revokeRole_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.RevokeRole(childComplexity, args["userID"].(string), args["roleID"].(string)), true + case "Mutation.saveFavoriteNavAction": if e.complexity.Mutation.SaveFavoriteNavAction == nil { break @@ -3280,6 +3309,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.UserProfile.PushTokens(childComplexity), true + case "UserProfile.roles": + if e.complexity.UserProfile.Roles == nil { + break + } + + return e.complexity.UserProfile.Roles(childComplexity), true + case "UserProfile.secondaryEmailAddresses": if e.complexity.UserProfile.SecondaryEmailAddresses == nil { break @@ -4053,7 +4089,7 @@ input RoleInput { } input RolePermissionInput { - roleID: String! + roleID: ID! scopes: [String!]! } `, BuiltIn: false}, @@ -4221,6 +4257,10 @@ extend type Mutation { createRole(input: RoleInput!): RoleOutput! addPermissionsToRole(input: RolePermissionInput!): RoleOutput! + + assignRole(userID: ID!, roleID: ID!): Boolean! + + revokeRole(userID: ID!, roleID: ID!): Boolean! } `, BuiltIn: false}, {Name: "pkg/onboarding/presentation/graph/types.graphql", Input: `scalar Date @@ -4267,6 +4307,7 @@ type UserProfile @key(fields: "id") { userBioData: BioData homeAddress: Address workAddress: Address + roles: [String] } type Customer { @@ -4708,7 +4749,7 @@ type Microservice @key(fields: "id") { } type RoleOutput { - id: String! + id: ID! name: String! description: String! active: Boolean! @@ -5094,6 +5135,30 @@ func (ec *executionContext) field_Mutation_addSecondaryPhoneNumber_args(ctx cont return args, nil } +func (ec *executionContext) field_Mutation_assignRole_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["userID"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("userID")) + arg0, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["userID"] = arg0 + var arg1 string + if tmp, ok := rawArgs["roleID"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roleID")) + arg1, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["roleID"] = arg1 + return args, nil +} + func (ec *executionContext) field_Mutation_completeSignup_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -5307,6 +5372,30 @@ func (ec *executionContext) field_Mutation_retireSecondaryPhoneNumbers_args(ctx return args, nil } +func (ec *executionContext) field_Mutation_revokeRole_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["userID"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("userID")) + arg0, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["userID"] = arg0 + var arg1 string + if tmp, ok := rawArgs["roleID"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roleID")) + arg1, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["roleID"] = arg1 + return args, nil +} + func (ec *executionContext) field_Mutation_saveFavoriteNavAction_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -11760,6 +11849,90 @@ func (ec *executionContext) _Mutation_addPermissionsToRole(ctx context.Context, return ec.marshalNRoleOutput2ᚖgithubᚗcomᚋsavannahghiᚋonboardingᚋpkgᚋonboardingᚋapplicationᚋdtoᚐRoleOutput(ctx, field.Selections, res) } +func (ec *executionContext) _Mutation_assignRole(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Mutation", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Mutation_assignRole_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().AssignRole(rctx, args["userID"].(string), args["roleID"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + +func (ec *executionContext) _Mutation_revokeRole(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Mutation", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Mutation_revokeRole_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + fc.Args = args + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().RevokeRole(rctx, args["userID"].(string), args["roleID"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + func (ec *executionContext) _NHIFDetails_id(ctx context.Context, field graphql.CollectedField, obj *domain.NHIFDetails) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -16056,7 +16229,7 @@ func (ec *executionContext) _RoleOutput_id(ctx context.Context, field graphql.Co } res := resTmp.(string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalNID2string(ctx, field.Selections, res) } func (ec *executionContext) _RoleOutput_name(ctx context.Context, field graphql.CollectedField, obj *dto.RoleOutput) (ret graphql.Marshaler) { @@ -17921,6 +18094,38 @@ func (ec *executionContext) _UserProfile_workAddress(ctx context.Context, field return ec.marshalOAddress2ᚖgithubᚗcomᚋsavannahghiᚋprofileutilsᚐAddress(ctx, field.Selections, res) } +func (ec *executionContext) _UserProfile_roles(ctx context.Context, field graphql.CollectedField, obj *profileutils.UserProfile) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "UserProfile", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Roles, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]string) + fc.Result = res + return ec.marshalOString2ᚕstring(ctx, field.Selections, res) +} + func (ec *executionContext) _VerifiedIdentifier_uid(ctx context.Context, field graphql.CollectedField, obj *profileutils.VerifiedIdentifier) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -20807,7 +21012,7 @@ func (ec *executionContext) unmarshalInputRolePermissionInput(ctx context.Contex var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roleID")) - it.RoleID, err = ec.unmarshalNString2string(ctx, v) + it.RoleID, err = ec.unmarshalNID2string(ctx, v) if err != nil { return it, err } @@ -22306,6 +22511,16 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) if out.Values[i] == graphql.Null { invalids++ } + case "assignRole": + out.Values[i] = ec._Mutation_assignRole(ctx, field) + if out.Values[i] == graphql.Null { + invalids++ + } + case "revokeRole": + out.Values[i] = ec._Mutation_revokeRole(ctx, field) + if out.Values[i] == graphql.Null { + invalids++ + } default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -23716,6 +23931,8 @@ func (ec *executionContext) _UserProfile(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._UserProfile_homeAddress(ctx, field, obj) case "workAddress": out.Values[i] = ec._UserProfile_workAddress(ctx, field, obj) + case "roles": + out.Values[i] = ec._UserProfile_roles(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } diff --git a/pkg/onboarding/presentation/graph/inputs.graphql b/pkg/onboarding/presentation/graph/inputs.graphql index 59e74416..8a504e04 100644 --- a/pkg/onboarding/presentation/graph/inputs.graphql +++ b/pkg/onboarding/presentation/graph/inputs.graphql @@ -332,6 +332,6 @@ input RoleInput { } input RolePermissionInput { - roleID: String! + roleID: ID! scopes: [String!]! } diff --git a/pkg/onboarding/presentation/graph/profile.graphql b/pkg/onboarding/presentation/graph/profile.graphql index 42d122b7..03cf6a50 100644 --- a/pkg/onboarding/presentation/graph/profile.graphql +++ b/pkg/onboarding/presentation/graph/profile.graphql @@ -162,4 +162,8 @@ extend type Mutation { createRole(input: RoleInput!): RoleOutput! addPermissionsToRole(input: RolePermissionInput!): RoleOutput! + + assignRole(userID: ID!, roleID: ID!): Boolean! + + revokeRole(userID: ID!, roleID: ID!): Boolean! } diff --git a/pkg/onboarding/presentation/graph/profile.resolvers.go b/pkg/onboarding/presentation/graph/profile.resolvers.go index 671df7a2..d331cd81 100644 --- a/pkg/onboarding/presentation/graph/profile.resolvers.go +++ b/pkg/onboarding/presentation/graph/profile.resolvers.go @@ -595,6 +595,24 @@ func (r *mutationResolver) AddPermissionsToRole(ctx context.Context, input dto.R return role, err } +func (r *mutationResolver) AssignRole(ctx context.Context, userID string, roleID string) (bool, error) { + startTime := time.Now() + + status, err := r.interactor.Role.AssignRole(ctx, userID, roleID) + defer serverutils.RecordGraphqlResolverMetrics(ctx, startTime, "assignRole", err) + + return status, err +} + +func (r *mutationResolver) RevokeRole(ctx context.Context, userID string, roleID string) (bool, error) { + startTime := time.Now() + + status, err := r.interactor.Role.RevokeRole(ctx, userID, roleID) + defer serverutils.RecordGraphqlResolverMetrics(ctx, startTime, "revokeRole", err) + + return status, err +} + func (r *queryResolver) DummyQuery(ctx context.Context) (*bool, error) { dummy := true return &dummy, nil diff --git a/pkg/onboarding/presentation/graph/types.graphql b/pkg/onboarding/presentation/graph/types.graphql index d503f66b..19af30b8 100644 --- a/pkg/onboarding/presentation/graph/types.graphql +++ b/pkg/onboarding/presentation/graph/types.graphql @@ -42,6 +42,7 @@ type UserProfile @key(fields: "id") { userBioData: BioData homeAddress: Address workAddress: Address + roles: [String] } type Customer { @@ -483,7 +484,7 @@ type Microservice @key(fields: "id") { } type RoleOutput { - id: String! + id: ID! name: String! description: String! active: Boolean! diff --git a/pkg/onboarding/repository/mock/onboarding.go b/pkg/onboarding/repository/mock/onboarding.go index 779786cc..40f5f228 100644 --- a/pkg/onboarding/repository/mock/onboarding.go +++ b/pkg/onboarding/repository/mock/onboarding.go @@ -162,6 +162,7 @@ type FakeOnboardingRepository struct { UpdatePrimaryEmailAddressFn func(ctx context.Context, id string, emailAddress string) error UpdateSecondaryPhoneNumbersFn func(ctx context.Context, id string, phoneNumbers []string) error UpdateSecondaryEmailAddressesFn func(ctx context.Context, id string, emailAddresses []string) error + UpdateUserRoleIDsFn func(ctx context.Context, id string, roleIDs []string) error UpdateSuspendedFn func(ctx context.Context, id string, status bool) error UpdatePhotoUploadIDFn func(ctx context.Context, id string, uploadID string) error UpdateCoversFn func(ctx context.Context, id string, covers []profileutils.Cover) error @@ -934,3 +935,8 @@ func (f *FakeOnboardingRepository) CheckIfRoleNameExists(ctx context.Context, na func (f *FakeOnboardingRepository) CheckIfUserHasPermission(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { return f.CheckIfUserHasPermissionFn(ctx, UID, requiredPermission) } + +// UpdateUserRoleIDs ... +func (f *FakeOnboardingRepository) UpdateUserRoleIDs(ctx context.Context, id string, roleIDs []string) error { + return f.UpdateUserRoleIDsFn(ctx, id, roleIDs) +} diff --git a/pkg/onboarding/repository/onboarding.go b/pkg/onboarding/repository/onboarding.go index 0cbc7835..6ebbe1a7 100644 --- a/pkg/onboarding/repository/onboarding.go +++ b/pkg/onboarding/repository/onboarding.go @@ -290,6 +290,7 @@ type UserProfileRepository interface { UpdatePushTokens(ctx context.Context, id string, pushToken []string) error UpdatePermissions(ctx context.Context, id string, perms []profileutils.PermissionType) error UpdateRole(ctx context.Context, id string, role profileutils.RoleType) error + UpdateUserRoleIDs(ctx context.Context, id string, roleIDs []string) error UpdateBioData(ctx context.Context, id string, data profileutils.BioData) error UpdateAddresses( ctx context.Context, diff --git a/pkg/onboarding/usecases/roles.go b/pkg/onboarding/usecases/roles.go index 462360da..9db56d12 100644 --- a/pkg/onboarding/usecases/roles.go +++ b/pkg/onboarding/usecases/roles.go @@ -28,6 +28,10 @@ type RoleUseCase interface { GetRole(ctx context.Context, ID string) (*dto.RoleOutput, error) GetUserPermissions(ctx context.Context, UID string) ([]profileutils.Permission, error) + + AssignRole(ctx context.Context, userID string, roleID string) (bool, error) + + RevokeRole(ctx context.Context, userID string, roleID string) (bool, error) } // RoleUseCaseImpl represents usecase implementation object @@ -353,3 +357,111 @@ func (r *RoleUseCaseImpl) GetUserPermissions( return permissions, nil } + +// AssignRole assigns a user a particular role +func (r *RoleUseCaseImpl) AssignRole( + ctx context.Context, + userID string, + roleID string, +) (bool, error) { + ctx, span := tracer.Start(ctx, "AssignRole") + defer span.End() + + role, err := r.repo.GetRoleByID(ctx, roleID) + if err != nil { + utils.RecordSpanError(span, err) + return false, err + } + + profile, err := r.repo.GetUserProfileByID(ctx, userID, false) + if err != nil { + return false, err + } + + for _, r := range profile.Roles { + // check if role exists first + if r == role.ID { + err := fmt.Errorf("role already exists: %v", role.Name) + return false, err + } + } + + updated := append(profile.Roles, roleID) + + err = r.repo.UpdateUserRoleIDs(ctx, profile.ID, updated) + if err != nil { + return false, err + } + + return true, nil +} + +// RevokeRole removes a role from the user +func (r *RoleUseCaseImpl) RevokeRole( + ctx context.Context, + userID string, + roleID string, +) (bool, error) { + ctx, span := tracer.Start(ctx, "RevokeRole") + defer span.End() + + // Check logged in user has permissions/role of employee + user, err := r.baseExt.GetLoggedInUser(ctx) + if err != nil { + utils.RecordSpanError(span, err) + return false, err + } + + // Check logged in user has the right permissions + allowed, err := r.repo.CheckIfUserHasPermission(ctx, user.UID, profileutils.CanAssignRole) + if err != nil { + utils.RecordSpanError(span, err) + return false, err + } + if !allowed { + return false, exceptions.RoleNotValid( + fmt.Errorf("error: logged in user does not have permissions to view roles"), + ) + } + + role, err := r.repo.GetRoleByID(ctx, roleID) + if err != nil { + utils.RecordSpanError(span, err) + return false, err + } + + profile, err := r.repo.GetUserProfileByID(ctx, userID, false) + if err != nil { + return false, err + } + + var exist bool + for _, r := range profile.Roles { + // check if role exists first + if r == role.ID { + exist = true + break + } + } + + if !exist { + err := fmt.Errorf("user doesn't have role: %v", role.Name) + utils.RecordSpanError(span, err) + return false, err + } + + // roles copy + updated := []string{} + for _, r := range profile.Roles { + if r != role.ID { + updated = append(updated, r) + } + } + + err = r.repo.UpdateUserRoleIDs(ctx, profile.ID, updated) + if err != nil { + return false, err + } + + return true, nil +} diff --git a/pkg/onboarding/usecases/roles_unit_test.go b/pkg/onboarding/usecases/roles_unit_test.go index 1bbcd3ff..0e146747 100644 --- a/pkg/onboarding/usecases/roles_unit_test.go +++ b/pkg/onboarding/usecases/roles_unit_test.go @@ -6,6 +6,7 @@ import ( "reflect" "testing" + "github.com/google/uuid" "github.com/savannahghi/onboarding/pkg/onboarding/application/dto" "github.com/savannahghi/profileutils" ) @@ -738,3 +739,502 @@ func TestRoleUseCaseImpl_GetRole(t *testing.T) { }) } } + +func TestRoleUseCaseImpl_AssignRole(t *testing.T) { + ctx := context.Background() + i, err := InitializeFakeOnboardingInteractor() + + if err != nil { + t.Errorf("failed to fake initialize onboarding interactor: %v", + err, + ) + return + } + + type args struct { + ctx context.Context + userID string + roleID string + } + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + name: "fail: cannot get logged in user", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: user doesn't have the permission", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: role ID doesn't exist", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "invalid id", + }, + want: false, + wantErr: true, + }, + { + name: "fail: cannot retrieve user profile", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: role already exists", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "0637333d-74b0-473d-95bd-0a03b1ae5e06", + }, + want: false, + wantErr: true, + }, + { + name: "fail: error updating user profile role", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + }, + want: false, + wantErr: true, + }, + { + name: "success: add a new role to user", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + }, + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.name == "fail: cannot get logged in user" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return nil, fmt.Errorf("cannot get logged in user") + } + + //remove + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return nil, fmt.Errorf("cannot get role ny id") + } + } + + if tt.name == "fail: user doesn't have the permission" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return false, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanRegisterAgent.Scope}, + }, nil + } + + //remove + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return nil, fmt.Errorf("cannot get role ny id") + } + } + + if tt.name == "fail: role ID doesn't exist" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return nil, fmt.Errorf("cannot get role ny id") + } + } + + if tt.name == "fail: cannot retrieve user profile" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return nil, fmt.Errorf("no user profile") + } + } + + if tt.name == "fail: role already exists" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "0637333d-74b0-473d-95bd-0a03b1ae5e06", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ + ID: "", + Roles: []string{"0637333d-74b0-473d-95bd-0a03b1ae5e06"}, + }, nil + } + } + + if tt.name == "fail: error updating user profile role" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ID: ""}, nil + } + + fakeRepo.UpdateUserRoleIDsFn = func(ctx context.Context, id string, roleIDs []string) error { + return fmt.Errorf("cannot update role ids") + } + } + + if tt.name == "success: add a new role to user" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ID: ""}, nil + } + + fakeRepo.UpdateUserRoleIDsFn = func(ctx context.Context, id string, roleIDs []string) error { + return nil + } + } + + got, err := i.Role.AssignRole(tt.args.ctx, tt.args.userID, tt.args.roleID) + if (err != nil) != tt.wantErr { + t.Errorf("RoleUseCaseImpl.AssignRole() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("RoleUseCaseImpl.AssignRole() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRoleUseCaseImpl_RevokeRole(t *testing.T) { + ctx := context.Background() + i, err := InitializeFakeOnboardingInteractor() + + if err != nil { + t.Errorf("failed to fake initialize onboarding interactor: %v", + err, + ) + return + } + + type args struct { + ctx context.Context + userID string + roleID string + } + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + name: "fail: cannot get logged in user", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: user doesn't have the permission", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: role ID doesn't exist", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "invalid id", + }, + want: false, + wantErr: true, + }, + { + name: "fail: cannot retrieve user profile", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: uuid.NewString(), + }, + want: false, + wantErr: true, + }, + { + name: "fail: user does not have the role", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "missing", + }, + want: false, + wantErr: true, + }, + { + name: "fail: error updating user profile roles", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + }, + want: false, + wantErr: true, + }, + { + name: "success: remove a role from a user", + args: args{ + ctx: ctx, + userID: uuid.NewString(), + roleID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + }, + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.name == "fail: cannot get logged in user" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return nil, fmt.Errorf("cannot get logged in user") + } + } + + if tt.name == "fail: user doesn't have the permission" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return false, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanRegisterAgent.Scope}, + }, nil + } + } + + if tt.name == "fail: role ID doesn't exist" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return nil, fmt.Errorf("cannot get role ny id") + } + } + + if tt.name == "fail: cannot retrieve user profile" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return nil, fmt.Errorf("no user profile") + } + } + + if tt.name == "fail: user does not have the role" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "", + Scopes: []string{"duplicate"}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ID: "", Roles: []string{"duplicate"}}, nil + } + } + + if tt.name == "fail: error updating user profile roles" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ + ID: "", + Roles: []string{ + "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + "56e5e987-2f02-4455-9dde-ae15162d8bce", + }, + }, nil + } + + fakeRepo.UpdateUserRoleIDsFn = func(ctx context.Context, id string, roleIDs []string) error { + return fmt.Errorf("cannot update user profile roles") + } + } + + if tt.name == "success: remove a role from a user" { + fakeBaseExt.GetLoggedInUserFn = func(ctx context.Context) (*dto.UserInfo, error) { + return &dto.UserInfo{UID: ""}, nil + } + + fakeRepo.CheckIfUserHasPermissionFn = func(ctx context.Context, UID string, requiredPermission profileutils.Permission) (bool, error) { + return true, nil + } + + fakeRepo.GetRoleByIDFn = func(ctx context.Context, roleID string) (*profileutils.Role, error) { + return &profileutils.Role{ + ID: "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + Scopes: []string{profileutils.CanAssignRole.Scope}, + }, nil + } + + fakeRepo.GetUserProfileByIDFn = func(ctx context.Context, id string, suspended bool) (*profileutils.UserProfile, error) { + return &profileutils.UserProfile{ + ID: "", + Roles: []string{ + "17e6ea18-7147-4bdb-ad0b-d9ce03a8c0ac", + "56e5e987-2f02-4455-9dde-ae15162d8bce", + }, + }, nil + } + + fakeRepo.UpdateUserRoleIDsFn = func(ctx context.Context, id string, roleIDs []string) error { + return nil + } + } + + got, err := i.Role.RevokeRole(tt.args.ctx, tt.args.userID, tt.args.roleID) + if (err != nil) != tt.wantErr { + t.Errorf("RoleUseCaseImpl.RevokeRole() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("RoleUseCaseImpl.RevokeRole() = %v, want %v", got, tt.want) + } + }) + } +}