From 4831b72b030ec040179e23aaffd470d0e32f1eea Mon Sep 17 00:00:00 2001 From: Idan Varsano Date: Tue, 22 Aug 2023 18:06:13 -0400 Subject: [PATCH 1/4] graphql mutation for configuring code host rate limit --- cmd/frontend/graphqlbackend/BUILD.bazel | 1 + cmd/frontend/graphqlbackend/code_host.go | 47 ++++ cmd/frontend/graphqlbackend/code_host_test.go | 242 ++++++++++++++++++ cmd/frontend/graphqlbackend/schema.graphql | 37 +++ 4 files changed, 327 insertions(+) create mode 100644 cmd/frontend/graphqlbackend/code_host_test.go diff --git a/cmd/frontend/graphqlbackend/BUILD.bazel b/cmd/frontend/graphqlbackend/BUILD.bazel index 5fd770451773..27445a0b3bcd 100644 --- a/cmd/frontend/graphqlbackend/BUILD.bazel +++ b/cmd/frontend/graphqlbackend/BUILD.bazel @@ -388,6 +388,7 @@ go_test( "access_requests_test.go", "access_tokens_test.go", "client_configuration_test.go", + "code_host_test.go", "code_hosts_test.go", "event_log_test.go", "event_logs_test.go", diff --git a/cmd/frontend/graphqlbackend/code_host.go b/cmd/frontend/graphqlbackend/code_host.go index 07d3f07d805b..67afa9ea8632 100644 --- a/cmd/frontend/graphqlbackend/code_host.go +++ b/cmd/frontend/graphqlbackend/code_host.go @@ -1,9 +1,13 @@ package graphqlbackend import ( + "context" + "github.com/graph-gophers/graphql-go" + "github.com/sourcegraph/sourcegraph/internal/auth" "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" ) type codeHostResolver struct { @@ -11,6 +15,18 @@ type codeHostResolver struct { db database.DB } +type SetCodeHostRateLimitsArgs struct { + Input SetCodeHostRateLimitsInput +} + +type SetCodeHostRateLimitsInput struct { + CodeHostID graphql.ID + APIQuota int32 + APIReplenishmentIntervalSeconds int32 + GitQuota int32 + GitReplenishmentIntervalSeconds int32 +} + func (r *codeHostResolver) ID() graphql.ID { return MarshalCodeHostID(r.ch.ID) } @@ -65,3 +81,34 @@ func (r *codeHostResolver) ExternalServices(args *CodeHostExternalServicesArgs) } return &externalServiceConnectionResolver{db: r.db, opt: opt}, nil } + +func (r *schemaResolver) SetCodeHostRateLimits(ctx context.Context, args SetCodeHostRateLimitsArgs) (*EmptyResponse, error) { + // Security 🚨: Code Hosts may only be updated by site admins. + if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.db); err != nil { + return nil, err + } + + if args.Input.APIQuota < 0 || args.Input.GitQuota < 0 || args.Input.APIReplenishmentIntervalSeconds < 0 || args.Input.GitReplenishmentIntervalSeconds < 0 { + return nil, errors.New("rate limit settings must be positive integers") + } + + codeHostID, err := UnmarshalCodeHostID(args.Input.CodeHostID) + if err != nil { + return nil, errors.Wrap(err, "invalid code host id") + } + codeHostIDInt32 := int32(codeHostID) + codeHost, err := r.db.CodeHosts().GetByID(ctx, codeHostIDInt32) + if err != nil { + return nil, err + } + codeHost.APIRateLimitQuota = &args.Input.APIQuota + codeHost.APIRateLimitIntervalSeconds = &args.Input.APIReplenishmentIntervalSeconds + codeHost.GitRateLimitQuota = &args.Input.GitQuota + codeHost.GitRateLimitIntervalSeconds = &args.Input.GitReplenishmentIntervalSeconds + + err = r.db.CodeHosts().Update(ctx, codeHost) + if err != nil { + return nil, err + } + return &EmptyResponse{}, err +} diff --git a/cmd/frontend/graphqlbackend/code_host_test.go b/cmd/frontend/graphqlbackend/code_host_test.go new file mode 100644 index 000000000000..b35e77a00628 --- /dev/null +++ b/cmd/frontend/graphqlbackend/code_host_test.go @@ -0,0 +1,242 @@ +package graphqlbackend + +import ( + "context" + "testing" + + "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/sourcegraph/internal/database/dbmocks" + "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/stretchr/testify/assert" +) + +func TestSchemaResolver_SetCodeHostRateLimits_NotASiteAdmin(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: false}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{}, + }) + assert.NotNil(t, err) + assert.Equal(t, "must be site admin", err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + ctx := context.Background() + wantErr := errors.New("rate limit settings must be positive integers") + + tests := []struct { + name string + args SetCodeHostRateLimitsArgs + wantErr error + }{ + { + name: "Negative APIQuota", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: -1, + APIReplenishmentIntervalSeconds: 1, + GitQuota: 1, + GitReplenishmentIntervalSeconds: 1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative APIReplenishmentIntervalSeconds", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: 1, + APIReplenishmentIntervalSeconds: -1, + GitQuota: 1, + GitReplenishmentIntervalSeconds: 1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative GitQuota", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: 1, + APIReplenishmentIntervalSeconds: 1, + GitQuota: -1, + GitReplenishmentIntervalSeconds: 1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative GitReplenishmentIntervalSeconds", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: 1, + APIReplenishmentIntervalSeconds: 1, + GitQuota: 1, + GitReplenishmentIntervalSeconds: -1, + }, + }, + wantErr: wantErr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + _, err := r.SetCodeHostRateLimits(ctx, test.args) + assert.NotNil(t, err) + assert.Equal(t, "rate limit settings must be positive integers", err.Error()) + }) + } +} + +func TestSchemaResolver_SetCodeHostRateLimits_InvalidCodeHostID(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + return nil, wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: ""}, + }) + assert.NotNil(t, err) + assert.Equal(t, "invalid code host id: invalid graphql.ID", err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_GetCodeHostByIDError(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + return nil, wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, + }) + assert.NotNil(t, err) + assert.Equal(t, wantErr.Error(), err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + assert.Equal(t, int32(1), id) + return &types.CodeHost{ID: 1}, nil + }) + codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { + return wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, + }) + assert.NotNil(t, err) + assert.Equal(t, wantErr.Error(), err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + ctx := context.Background() + setCodeHostRateLimitsInput := SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: 1, + APIReplenishmentIntervalSeconds: 2, + GitQuota: 3, + GitReplenishmentIntervalSeconds: 4, + } + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + assert.Equal(t, int32(1), id) + return &types.CodeHost{ID: 1}, nil + }) + codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { + assert.Equal(t, setCodeHostRateLimitsInput.APIQuota, *(host.APIRateLimitQuota)) + assert.Equal(t, setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, *(host.APIRateLimitIntervalSeconds)) + assert.Equal(t, setCodeHostRateLimitsInput.GitQuota, *(host.GitRateLimitQuota)) + assert.Equal(t, setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, *(host.GitRateLimitIntervalSeconds)) + return nil + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + variables := map[string]any{ + "input": map[string]any{ + "codeHostID": "Q29kZUhvc3Q6MQ==", + "apiQuota": setCodeHostRateLimitsInput.APIQuota, + "apiReplenishmentIntervalSeconds": setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, + "gitQuota": setCodeHostRateLimitsInput.GitQuota, + "gitReplenishmentIntervalSeconds": setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, + }, + } + RunTest(t, &Test{ + Context: ctx, + Schema: mustParseGraphQLSchema(t, db), + Variables: variables, + + Query: `mutation setCodeHostRateLimits($input:SetCodeHostRateLimitsInput!) { + setCodeHostRateLimits(input:$input) { + alwaysNil + } + }`, + ExpectedResult: `{ + "setCodeHostRateLimits": { + "alwaysNil": null + } + }`, + }) + _, err := r.SetCodeHostRateLimits(ctx, SetCodeHostRateLimitsArgs{ + Input: setCodeHostRateLimitsInput, + }) + assert.Nil(t, err) +} diff --git a/cmd/frontend/graphqlbackend/schema.graphql b/cmd/frontend/graphqlbackend/schema.graphql index bf7436df7bf2..6d08bf082168 100755 --- a/cmd/frontend/graphqlbackend/schema.graphql +++ b/cmd/frontend/graphqlbackend/schema.graphql @@ -10604,6 +10604,43 @@ type PerforceChangelist { commit: GitCommit! } +extend type Mutation { + """ + Updates a code host's rate limit configurations. All rate limit values must be positive integers. + """ + setCodeHostRateLimits(input: SetCodeHostRateLimitsInput!): EmptyResponse +} + +""" +SetCodeHostRateLimitsInput represents the input for configuring rate limits for a code host. +""" +input SetCodeHostRateLimitsInput { + """ + ID of the code host for which rate limits are being set. + """ + codeHostID: ID! + + """ + The maximum number of API requests allowed per time window defined by apiReplenishmentIntervalSeconds. + """ + apiQuota: Int! + + """ + The time interval at which the apiQuota's worth of API requests are replenished. + """ + apiReplenishmentIntervalSeconds: Int! + + """ + The maximum number of Git requests allowed per time window defined by gitReplenishmentIntervalSeconds. + """ + gitQuota: Int! + + """ + The time interval at which the gitQuota's worth of Git requests are replenished. + """ + gitReplenishmentIntervalSeconds: Int! +} + extend type Query { """ List of all configured code hosts on this instance. From a3ba483d95395481705ce88c6aa06ff419867159 Mon Sep 17 00:00:00 2001 From: Idan Varsano Date: Tue, 22 Aug 2023 18:50:10 -0400 Subject: [PATCH 2/4] remove unneeded casting --- cmd/frontend/graphqlbackend/code_host.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/frontend/graphqlbackend/code_host.go b/cmd/frontend/graphqlbackend/code_host.go index 67afa9ea8632..458159912c40 100644 --- a/cmd/frontend/graphqlbackend/code_host.go +++ b/cmd/frontend/graphqlbackend/code_host.go @@ -96,8 +96,8 @@ func (r *schemaResolver) SetCodeHostRateLimits(ctx context.Context, args SetCode if err != nil { return nil, errors.Wrap(err, "invalid code host id") } - codeHostIDInt32 := int32(codeHostID) - codeHost, err := r.db.CodeHosts().GetByID(ctx, codeHostIDInt32) + + codeHost, err := r.db.CodeHosts().GetByID(ctx, codeHostID) if err != nil { return nil, err } From 82f639c4684a7694a72f8634c612981315ef6ef9 Mon Sep 17 00:00:00 2001 From: Idan Varsano Date: Tue, 22 Aug 2023 18:58:33 -0400 Subject: [PATCH 3/4] cleanups --- cmd/frontend/graphqlbackend/code_host_test.go | 29 ++++--------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/cmd/frontend/graphqlbackend/code_host_test.go b/cmd/frontend/graphqlbackend/code_host_test.go index b35e77a00628..b7263716b040 100644 --- a/cmd/frontend/graphqlbackend/code_host_test.go +++ b/cmd/frontend/graphqlbackend/code_host_test.go @@ -43,11 +43,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { name: "Negative APIQuota", args: SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{ - CodeHostID: "Q29kZUhvc3Q6MQ==", - APIQuota: -1, - APIReplenishmentIntervalSeconds: 1, - GitQuota: 1, - GitReplenishmentIntervalSeconds: 1, + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: -1, }, }, wantErr: wantErr, @@ -57,10 +54,7 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { args: SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{ CodeHostID: "Q29kZUhvc3Q6MQ==", - APIQuota: 1, APIReplenishmentIntervalSeconds: -1, - GitQuota: 1, - GitReplenishmentIntervalSeconds: 1, }, }, wantErr: wantErr, @@ -69,11 +63,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { name: "Negative GitQuota", args: SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{ - CodeHostID: "Q29kZUhvc3Q6MQ==", - APIQuota: 1, - APIReplenishmentIntervalSeconds: 1, - GitQuota: -1, - GitReplenishmentIntervalSeconds: 1, + CodeHostID: "Q29kZUhvc3Q6MQ==", + GitQuota: -1, }, }, wantErr: wantErr, @@ -83,9 +74,6 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { args: SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{ CodeHostID: "Q29kZUhvc3Q6MQ==", - APIQuota: 1, - APIReplenishmentIntervalSeconds: 1, - GitQuota: 1, GitReplenishmentIntervalSeconds: -1, }, }, @@ -180,9 +168,7 @@ func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) } func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { - logger := logtest.Scoped(t) db := dbmocks.NewMockDB() - r := &schemaResolver{logger: logger, db: db} ctx := context.Background() setCodeHostRateLimitsInput := SetCodeHostRateLimitsInput{ CodeHostID: "Q29kZUhvc3Q6MQ==", @@ -212,7 +198,7 @@ func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { variables := map[string]any{ "input": map[string]any{ - "codeHostID": "Q29kZUhvc3Q6MQ==", + "codeHostID": string(setCodeHostRateLimitsInput.CodeHostID), "apiQuota": setCodeHostRateLimitsInput.APIQuota, "apiReplenishmentIntervalSeconds": setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, "gitQuota": setCodeHostRateLimitsInput.GitQuota, @@ -223,7 +209,6 @@ func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { Context: ctx, Schema: mustParseGraphQLSchema(t, db), Variables: variables, - Query: `mutation setCodeHostRateLimits($input:SetCodeHostRateLimitsInput!) { setCodeHostRateLimits(input:$input) { alwaysNil @@ -235,8 +220,4 @@ func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { } }`, }) - _, err := r.SetCodeHostRateLimits(ctx, SetCodeHostRateLimitsArgs{ - Input: setCodeHostRateLimitsInput, - }) - assert.Nil(t, err) } From 1529df04e4f0a5226278d326dce1ece4e2502fa5 Mon Sep 17 00:00:00 2001 From: Idan Varsano Date: Wed, 23 Aug 2023 08:35:03 -0400 Subject: [PATCH 4/4] address pr comments --- cmd/frontend/graphqlbackend/code_host.go | 35 ++++++++----- cmd/frontend/graphqlbackend/code_host_test.go | 52 +++++++++++++------ 2 files changed, 56 insertions(+), 31 deletions(-) diff --git a/cmd/frontend/graphqlbackend/code_host.go b/cmd/frontend/graphqlbackend/code_host.go index 458159912c40..58ee8525c76f 100644 --- a/cmd/frontend/graphqlbackend/code_host.go +++ b/cmd/frontend/graphqlbackend/code_host.go @@ -10,6 +10,8 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" ) +var errCodeHostRateLimitsMustBePositiveIntegers = errors.New("rate limit settings must be positive integers") + type codeHostResolver struct { ch *types.CodeHost db database.DB @@ -88,8 +90,9 @@ func (r *schemaResolver) SetCodeHostRateLimits(ctx context.Context, args SetCode return nil, err } - if args.Input.APIQuota < 0 || args.Input.GitQuota < 0 || args.Input.APIReplenishmentIntervalSeconds < 0 || args.Input.GitReplenishmentIntervalSeconds < 0 { - return nil, errors.New("rate limit settings must be positive integers") + input := args.Input + if input.APIQuota < 0 || input.GitQuota < 0 || input.APIReplenishmentIntervalSeconds < 0 || input.GitReplenishmentIntervalSeconds < 0 { + return nil, errCodeHostRateLimitsMustBePositiveIntegers } codeHostID, err := UnmarshalCodeHostID(args.Input.CodeHostID) @@ -97,18 +100,22 @@ func (r *schemaResolver) SetCodeHostRateLimits(ctx context.Context, args SetCode return nil, errors.Wrap(err, "invalid code host id") } - codeHost, err := r.db.CodeHosts().GetByID(ctx, codeHostID) - if err != nil { - return nil, err - } - codeHost.APIRateLimitQuota = &args.Input.APIQuota - codeHost.APIRateLimitIntervalSeconds = &args.Input.APIReplenishmentIntervalSeconds - codeHost.GitRateLimitQuota = &args.Input.GitQuota - codeHost.GitRateLimitIntervalSeconds = &args.Input.GitReplenishmentIntervalSeconds + err = r.db.WithTransact(ctx, func(tx database.DB) (err error) { + codeHost, err := tx.CodeHosts().GetByID(ctx, codeHostID) + if err != nil { + return err + } + codeHost.APIRateLimitQuota = &input.APIQuota + codeHost.APIRateLimitIntervalSeconds = &input.APIReplenishmentIntervalSeconds + codeHost.GitRateLimitQuota = &input.GitQuota + codeHost.GitRateLimitIntervalSeconds = &input.GitReplenishmentIntervalSeconds + + err = tx.CodeHosts().Update(ctx, codeHost) + if err != nil { + return err + } + return nil + }) - err = r.db.CodeHosts().Update(ctx, codeHost) - if err != nil { - return nil, err - } return &EmptyResponse{}, err } diff --git a/cmd/frontend/graphqlbackend/code_host_test.go b/cmd/frontend/graphqlbackend/code_host_test.go index b7263716b040..fb356d3a0066 100644 --- a/cmd/frontend/graphqlbackend/code_host_test.go +++ b/cmd/frontend/graphqlbackend/code_host_test.go @@ -5,10 +5,12 @@ import ( "testing" "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/sourcegraph/internal/auth" + "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/database/dbmocks" "github.com/sourcegraph/sourcegraph/internal/types" "github.com/sourcegraph/sourcegraph/lib/errors" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSchemaResolver_SetCodeHostRateLimits_NotASiteAdmin(t *testing.T) { @@ -23,8 +25,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_NotASiteAdmin(t *testing.T) { _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{}, }) - assert.NotNil(t, err) - assert.Equal(t, "must be site admin", err.Error()) + require.NotNil(t, err) + require.Equal(t, auth.ErrMustBeSiteAdmin, err) } func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { @@ -88,8 +90,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { db.UsersFunc.SetDefaultReturn(usersStore) _, err := r.SetCodeHostRateLimits(ctx, test.args) - assert.NotNil(t, err) - assert.Equal(t, "rate limit settings must be positive integers", err.Error()) + require.NotNil(t, err) + require.Equal(t, errCodeHostRateLimitsMustBePositiveIntegers, err) }) } } @@ -100,6 +102,10 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidCodeHostID(t *testing.T) { r := &schemaResolver{logger: logger, db: db} wantErr := errors.New("test error") + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + usersStore := dbmocks.NewMockUserStore() usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) db.UsersFunc.SetDefaultReturn(usersStore) @@ -113,8 +119,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_InvalidCodeHostID(t *testing.T) { _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{CodeHostID: ""}, }) - assert.NotNil(t, err) - assert.Equal(t, "invalid code host id: invalid graphql.ID", err.Error()) + require.NotNil(t, err) + require.Equal(t, "invalid code host id: invalid graphql.ID", err.Error()) } func TestSchemaResolver_SetCodeHostRateLimits_GetCodeHostByIDError(t *testing.T) { @@ -123,6 +129,10 @@ func TestSchemaResolver_SetCodeHostRateLimits_GetCodeHostByIDError(t *testing.T) r := &schemaResolver{logger: logger, db: db} wantErr := errors.New("test error") + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + usersStore := dbmocks.NewMockUserStore() usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) db.UsersFunc.SetDefaultReturn(usersStore) @@ -136,8 +146,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_GetCodeHostByIDError(t *testing.T) _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, }) - assert.NotNil(t, err) - assert.Equal(t, wantErr.Error(), err.Error()) + require.NotNil(t, err) + require.Equal(t, wantErr.Error(), err.Error()) } func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) { @@ -146,13 +156,17 @@ func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) r := &schemaResolver{logger: logger, db: db} wantErr := errors.New("test error") + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + usersStore := dbmocks.NewMockUserStore() usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) db.UsersFunc.SetDefaultReturn(usersStore) codeHostStore := dbmocks.NewMockCodeHostStore() codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { - assert.Equal(t, int32(1), id) + require.Equal(t, int32(1), id) return &types.CodeHost{ID: 1}, nil }) codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { @@ -163,8 +177,8 @@ func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, }) - assert.NotNil(t, err) - assert.Equal(t, wantErr.Error(), err.Error()) + require.NotNil(t, err) + require.Equal(t, wantErr.Error(), err.Error()) } func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { @@ -178,20 +192,24 @@ func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { GitReplenishmentIntervalSeconds: 4, } + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + usersStore := dbmocks.NewMockUserStore() usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) db.UsersFunc.SetDefaultReturn(usersStore) codeHostStore := dbmocks.NewMockCodeHostStore() codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { - assert.Equal(t, int32(1), id) + require.Equal(t, int32(1), id) return &types.CodeHost{ID: 1}, nil }) codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { - assert.Equal(t, setCodeHostRateLimitsInput.APIQuota, *(host.APIRateLimitQuota)) - assert.Equal(t, setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, *(host.APIRateLimitIntervalSeconds)) - assert.Equal(t, setCodeHostRateLimitsInput.GitQuota, *(host.GitRateLimitQuota)) - assert.Equal(t, setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, *(host.GitRateLimitIntervalSeconds)) + require.Equal(t, setCodeHostRateLimitsInput.APIQuota, *(host.APIRateLimitQuota)) + require.Equal(t, setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, *(host.APIRateLimitIntervalSeconds)) + require.Equal(t, setCodeHostRateLimitsInput.GitQuota, *(host.GitRateLimitQuota)) + require.Equal(t, setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, *(host.GitRateLimitIntervalSeconds)) return nil }) db.CodeHostsFunc.SetDefaultReturn(codeHostStore)