diff --git a/cmd/frontend/graphqlbackend/BUILD.bazel b/cmd/frontend/graphqlbackend/BUILD.bazel index 5fd77045177..27445a0b3bc 100644 --- a/cmd/frontend/graphqlbackend/BUILD.bazel +++ b/cmd/frontend/graphqlbackend/BUILD.bazel @@ -388,6 +388,7 @@ go_test( "access_requests_test.go", "access_tokens_test.go", "client_configuration_test.go", + "code_host_test.go", "code_hosts_test.go", "event_log_test.go", "event_logs_test.go", diff --git a/cmd/frontend/graphqlbackend/code_host.go b/cmd/frontend/graphqlbackend/code_host.go index 07d3f07d805..58ee8525c76 100644 --- a/cmd/frontend/graphqlbackend/code_host.go +++ b/cmd/frontend/graphqlbackend/code_host.go @@ -1,16 +1,34 @@ package graphqlbackend import ( + "context" + "github.com/graph-gophers/graphql-go" + "github.com/sourcegraph/sourcegraph/internal/auth" "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" ) +var errCodeHostRateLimitsMustBePositiveIntegers = errors.New("rate limit settings must be positive integers") + type codeHostResolver struct { ch *types.CodeHost db database.DB } +type SetCodeHostRateLimitsArgs struct { + Input SetCodeHostRateLimitsInput +} + +type SetCodeHostRateLimitsInput struct { + CodeHostID graphql.ID + APIQuota int32 + APIReplenishmentIntervalSeconds int32 + GitQuota int32 + GitReplenishmentIntervalSeconds int32 +} + func (r *codeHostResolver) ID() graphql.ID { return MarshalCodeHostID(r.ch.ID) } @@ -65,3 +83,39 @@ func (r *codeHostResolver) ExternalServices(args *CodeHostExternalServicesArgs) } return &externalServiceConnectionResolver{db: r.db, opt: opt}, nil } + +func (r *schemaResolver) SetCodeHostRateLimits(ctx context.Context, args SetCodeHostRateLimitsArgs) (*EmptyResponse, error) { + // Security 🚨: Code Hosts may only be updated by site admins. + if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.db); err != nil { + return nil, err + } + + input := args.Input + if input.APIQuota < 0 || input.GitQuota < 0 || input.APIReplenishmentIntervalSeconds < 0 || input.GitReplenishmentIntervalSeconds < 0 { + return nil, errCodeHostRateLimitsMustBePositiveIntegers + } + + codeHostID, err := UnmarshalCodeHostID(args.Input.CodeHostID) + if err != nil { + return nil, errors.Wrap(err, "invalid code host id") + } + + err = r.db.WithTransact(ctx, func(tx database.DB) (err error) { + codeHost, err := tx.CodeHosts().GetByID(ctx, codeHostID) + if err != nil { + return err + } + codeHost.APIRateLimitQuota = &input.APIQuota + codeHost.APIRateLimitIntervalSeconds = &input.APIReplenishmentIntervalSeconds + codeHost.GitRateLimitQuota = &input.GitQuota + codeHost.GitRateLimitIntervalSeconds = &input.GitReplenishmentIntervalSeconds + + err = tx.CodeHosts().Update(ctx, codeHost) + if err != nil { + return err + } + return nil + }) + + return &EmptyResponse{}, err +} diff --git a/cmd/frontend/graphqlbackend/code_host_test.go b/cmd/frontend/graphqlbackend/code_host_test.go new file mode 100644 index 00000000000..fb356d3a006 --- /dev/null +++ b/cmd/frontend/graphqlbackend/code_host_test.go @@ -0,0 +1,241 @@ +package graphqlbackend + +import ( + "context" + "testing" + + "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/sourcegraph/internal/auth" + "github.com/sourcegraph/sourcegraph/internal/database" + "github.com/sourcegraph/sourcegraph/internal/database/dbmocks" + "github.com/sourcegraph/sourcegraph/internal/types" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/stretchr/testify/require" +) + +func TestSchemaResolver_SetCodeHostRateLimits_NotASiteAdmin(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: false}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{}, + }) + require.NotNil(t, err) + require.Equal(t, auth.ErrMustBeSiteAdmin, err) +} + +func TestSchemaResolver_SetCodeHostRateLimits_InvalidConfigs(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + ctx := context.Background() + wantErr := errors.New("rate limit settings must be positive integers") + + tests := []struct { + name string + args SetCodeHostRateLimitsArgs + wantErr error + }{ + { + name: "Negative APIQuota", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: -1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative APIReplenishmentIntervalSeconds", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIReplenishmentIntervalSeconds: -1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative GitQuota", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + GitQuota: -1, + }, + }, + wantErr: wantErr, + }, + { + name: "Negative GitReplenishmentIntervalSeconds", + args: SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + GitReplenishmentIntervalSeconds: -1, + }, + }, + wantErr: wantErr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + _, err := r.SetCodeHostRateLimits(ctx, test.args) + require.NotNil(t, err) + require.Equal(t, errCodeHostRateLimitsMustBePositiveIntegers, err) + }) + } +} + +func TestSchemaResolver_SetCodeHostRateLimits_InvalidCodeHostID(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + return nil, wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: ""}, + }) + require.NotNil(t, err) + require.Equal(t, "invalid code host id: invalid graphql.ID", err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_GetCodeHostByIDError(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + return nil, wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, + }) + require.NotNil(t, err) + require.Equal(t, wantErr.Error(), err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_UpdateCodeHostError(t *testing.T) { + logger := logtest.Scoped(t) + db := dbmocks.NewMockDB() + r := &schemaResolver{logger: logger, db: db} + wantErr := errors.New("test error") + + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + require.Equal(t, int32(1), id) + return &types.CodeHost{ID: 1}, nil + }) + codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { + return wantErr + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + _, err := r.SetCodeHostRateLimits(context.Background(), SetCodeHostRateLimitsArgs{ + Input: SetCodeHostRateLimitsInput{CodeHostID: "Q29kZUhvc3Q6MQ=="}, + }) + require.NotNil(t, err) + require.Equal(t, wantErr.Error(), err.Error()) +} + +func TestSchemaResolver_SetCodeHostRateLimits_Success(t *testing.T) { + db := dbmocks.NewMockDB() + ctx := context.Background() + setCodeHostRateLimitsInput := SetCodeHostRateLimitsInput{ + CodeHostID: "Q29kZUhvc3Q6MQ==", + APIQuota: 1, + APIReplenishmentIntervalSeconds: 2, + GitQuota: 3, + GitReplenishmentIntervalSeconds: 4, + } + + db.WithTransactFunc.SetDefaultHook(func(ctx context.Context, f func(database.DB) error) error { + return f(db) + }) + + usersStore := dbmocks.NewMockUserStore() + usersStore.GetByCurrentAuthUserFunc.SetDefaultReturn(&types.User{SiteAdmin: true}, nil) + db.UsersFunc.SetDefaultReturn(usersStore) + + codeHostStore := dbmocks.NewMockCodeHostStore() + codeHostStore.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int32) (*types.CodeHost, error) { + require.Equal(t, int32(1), id) + return &types.CodeHost{ID: 1}, nil + }) + codeHostStore.UpdateFunc.SetDefaultHook(func(ctx context.Context, host *types.CodeHost) error { + require.Equal(t, setCodeHostRateLimitsInput.APIQuota, *(host.APIRateLimitQuota)) + require.Equal(t, setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, *(host.APIRateLimitIntervalSeconds)) + require.Equal(t, setCodeHostRateLimitsInput.GitQuota, *(host.GitRateLimitQuota)) + require.Equal(t, setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, *(host.GitRateLimitIntervalSeconds)) + return nil + }) + db.CodeHostsFunc.SetDefaultReturn(codeHostStore) + + variables := map[string]any{ + "input": map[string]any{ + "codeHostID": string(setCodeHostRateLimitsInput.CodeHostID), + "apiQuota": setCodeHostRateLimitsInput.APIQuota, + "apiReplenishmentIntervalSeconds": setCodeHostRateLimitsInput.APIReplenishmentIntervalSeconds, + "gitQuota": setCodeHostRateLimitsInput.GitQuota, + "gitReplenishmentIntervalSeconds": setCodeHostRateLimitsInput.GitReplenishmentIntervalSeconds, + }, + } + RunTest(t, &Test{ + Context: ctx, + Schema: mustParseGraphQLSchema(t, db), + Variables: variables, + Query: `mutation setCodeHostRateLimits($input:SetCodeHostRateLimitsInput!) { + setCodeHostRateLimits(input:$input) { + alwaysNil + } + }`, + ExpectedResult: `{ + "setCodeHostRateLimits": { + "alwaysNil": null + } + }`, + }) +} diff --git a/cmd/frontend/graphqlbackend/schema.graphql b/cmd/frontend/graphqlbackend/schema.graphql index bf7436df7bf..6d08bf08216 100755 --- a/cmd/frontend/graphqlbackend/schema.graphql +++ b/cmd/frontend/graphqlbackend/schema.graphql @@ -10604,6 +10604,43 @@ type PerforceChangelist { commit: GitCommit! } +extend type Mutation { + """ + Updates a code host's rate limit configurations. All rate limit values must be positive integers. + """ + setCodeHostRateLimits(input: SetCodeHostRateLimitsInput!): EmptyResponse +} + +""" +SetCodeHostRateLimitsInput represents the input for configuring rate limits for a code host. +""" +input SetCodeHostRateLimitsInput { + """ + ID of the code host for which rate limits are being set. + """ + codeHostID: ID! + + """ + The maximum number of API requests allowed per time window defined by apiReplenishmentIntervalSeconds. + """ + apiQuota: Int! + + """ + The time interval at which the apiQuota's worth of API requests are replenished. + """ + apiReplenishmentIntervalSeconds: Int! + + """ + The maximum number of Git requests allowed per time window defined by gitReplenishmentIntervalSeconds. + """ + gitQuota: Int! + + """ + The time interval at which the gitQuota's worth of Git requests are replenished. + """ + gitReplenishmentIntervalSeconds: Int! +} + extend type Query { """ List of all configured code hosts on this instance.