From 4ea831673fc7d9869b0462898044b168701b5c5b Mon Sep 17 00:00:00 2001 From: Michal Niewrzal Date: Tue, 9 Jan 2024 11:13:39 +0100 Subject: [PATCH] satellite/metainfo: stop using project limit cache for uploads/downloads To reduce number of requests to DB we are getting all project limits values with api key lookup. All limits are retrieved only once and cached together. Later we can remove project limit cache completely. Change-Id: Ib2fd1c290e63949885182f24d125229ad4a7537c --- satellite/accounting/projectlimitcache.go | 1 + satellite/accounting/projectusage.go | 149 ++++++++++--------- satellite/accounting/projectusage_test.go | 16 +- satellite/admin.go | 4 +- satellite/api.go | 4 +- satellite/console/apikeys.go | 9 +- satellite/metainfo/endpoint_object.go | 24 +-- satellite/metainfo/endpoint_segment.go | 26 +--- satellite/metainfo/validation.go | 58 ++++++-- satellite/payments/stripe/accounts_test.go | 4 +- satellite/satellitedb/apikeys.go | 14 +- satellite/satellitedb/consoledb.go | 2 +- satellite/satellitedb/dbx/project.dbx | 2 +- satellite/satellitedb/dbx/satellitedb.dbx.go | 41 ++--- 14 files changed, 195 insertions(+), 159 deletions(-) diff --git a/satellite/accounting/projectlimitcache.go b/satellite/accounting/projectlimitcache.go index d5a777da30bc..70e31d0adaa0 100644 --- a/satellite/accounting/projectlimitcache.go +++ b/satellite/accounting/projectlimitcache.go @@ -39,6 +39,7 @@ type ProjectLimitConfig struct { // ProjectLimitCache stores the values for both storage usage limit and bandwidth limit for // each project ID if they differ from the default limits. +// TODO remove this cache as its not used anymore. type ProjectLimitCache struct { projectLimitDB ProjectLimitDB diff --git a/satellite/accounting/projectusage.go b/satellite/accounting/projectusage.go index 09ded6f9f835..6636153f52bd 100644 --- a/satellite/accounting/projectusage.go +++ b/satellite/accounting/projectusage.go @@ -30,23 +30,31 @@ var ErrProjectLimitExceeded = errs.Class("project limit") type Service struct { projectAccountingDB ProjectAccounting liveAccounting Cache - projectLimitCache *ProjectLimitCache metabaseDB metabase.DB bandwidthCacheTTL time.Duration nowFn func() time.Time + + defaultMaxStorage memory.Size + defaultMaxBandwidth memory.Size + defaultMaxSegments int64 asOfSystemInterval time.Duration } // NewService created new instance of project usage service. -func NewService(projectAccountingDB ProjectAccounting, liveAccounting Cache, limitCache *ProjectLimitCache, metabaseDB metabase.DB, bandwidthCacheTTL, asOfSystemInterval time.Duration) *Service { +func NewService(projectAccountingDB ProjectAccounting, liveAccounting Cache, metabaseDB metabase.DB, bandwidthCacheTTL time.Duration, + defaultMaxStorage, defaultMaxBandwidth memory.Size, defaultMaxSegments int64, asOfSystemInterval time.Duration) *Service { return &Service{ projectAccountingDB: projectAccountingDB, liveAccounting: liveAccounting, - projectLimitCache: limitCache, metabaseDB: metabaseDB, bandwidthCacheTTL: bandwidthCacheTTL, - nowFn: time.Now, - asOfSystemInterval: asOfSystemInterval, + + defaultMaxStorage: defaultMaxStorage, + defaultMaxBandwidth: defaultMaxBandwidth, + defaultMaxSegments: defaultMaxSegments, + + asOfSystemInterval: asOfSystemInterval, + nowFn: time.Now, } } @@ -57,48 +65,33 @@ func NewService(projectAccountingDB ProjectAccounting, liveAccounting Cache, lim // Among others,it can return one of the following errors returned by // storj.io/storj/satellite/accounting.Cache except the ErrKeyNotFound, wrapped // by ErrProjectUsage. -func (usage *Service) ExceedsBandwidthUsage(ctx context.Context, projectID uuid.UUID) (_ bool, limit memory.Size, err error) { +func (usage *Service) ExceedsBandwidthUsage(ctx context.Context, projectID uuid.UUID, limits ProjectLimits) (_ bool, limit memory.Size, err error) { defer mon.Task()(&ctx)(&err) - var ( - group errgroup.Group - bandwidthUsage int64 - ) + limit = usage.defaultMaxBandwidth + if limits.Bandwidth != nil { + limit = memory.Size(*limits.Bandwidth) + } - group.Go(func() error { - var err error - limit, err = usage.projectLimitCache.GetBandwidthLimit(ctx, projectID) - return err - }) - group.Go(func() error { - var err error + // Get the current bandwidth usage from cache. + bandwidthUsage, err := usage.liveAccounting.GetProjectBandwidthUsage(ctx, projectID, usage.nowFn()) + if err != nil { + // Verify If the cache key was not found + if ErrKeyNotFound.Has(err) { - // Get the current bandwidth usage from cache. - bandwidthUsage, err = usage.liveAccounting.GetProjectBandwidthUsage(ctx, projectID, usage.nowFn()) - if err != nil { - // Verify If the cache key was not found - if ErrKeyNotFound.Has(err) { - - // Get current bandwidth value from database. - now := usage.nowFn() - bandwidthUsage, err = usage.GetProjectBandwidth(ctx, projectID, now.Year(), now.Month(), now.Day()) - if err != nil { - return err - } - - // Create cache key with database value. - _, err = usage.liveAccounting.InsertProjectBandwidthUsage(ctx, projectID, bandwidthUsage, usage.bandwidthCacheTTL, usage.nowFn()) - if err != nil { - return err - } + // Get current bandwidth value from database. + now := usage.nowFn() + bandwidthUsage, err = usage.GetProjectBandwidth(ctx, projectID, now.Year(), now.Month(), now.Day()) + if err != nil { + return false, 0, ErrProjectUsage.Wrap(err) } - } - return err - }) - err = group.Wait() - if err != nil { - return false, 0, ErrProjectUsage.Wrap(err) + // Create cache key with database value. + _, err = usage.liveAccounting.InsertProjectBandwidthUsage(ctx, projectID, bandwidthUsage, usage.bandwidthCacheTTL, usage.nowFn()) + if err != nil { + return false, 0, ErrProjectUsage.Wrap(err) + } + } } // Verify the bandwidth usage cache. @@ -120,22 +113,21 @@ type UploadLimit struct { // ExceedsUploadLimits returns combined checks for storage and segment limits. // Supply nonzero headroom parameters to check if there is room for a new object. func (usage *Service) ExceedsUploadLimits( - ctx context.Context, projectID uuid.UUID, storageSizeHeadroom int64, segmentCountHeadroom int64) (limit UploadLimit, err error) { + ctx context.Context, projectID uuid.UUID, storageSizeHeadroom int64, segmentCountHeadroom int64, limits ProjectLimits) (limit UploadLimit, err error) { defer mon.Task()(&ctx)(&err) - var group errgroup.Group - var segmentUsage, storageUsage int64 - - group.Go(func() error { - var err error - limits, err := usage.projectLimitCache.GetLimits(ctx, projectID) - if err != nil { - return err - } + limit.SegmentsLimit = usage.defaultMaxSegments + if limits.Segments != nil { limit.SegmentsLimit = *limits.Segments + } + + limit.StorageLimit = usage.defaultMaxStorage + if limits.Usage != nil { limit.StorageLimit = memory.Size(*limits.Usage) - return nil - }) + } + + var group errgroup.Group + var segmentUsage, storageUsage int64 group.Go(func() error { var err error @@ -166,20 +158,25 @@ func (usage *Service) ExceedsUploadLimits( // AddProjectUsageUpToLimit increases segment and storage usage up to the projects limit. // If the limit is exceeded, neither usage is increased and accounting.ErrProjectLimitExceeded is returned. -func (usage *Service) AddProjectUsageUpToLimit(ctx context.Context, projectID uuid.UUID, storage int64, segments int64) (err error) { +func (usage *Service) AddProjectUsageUpToLimit(ctx context.Context, projectID uuid.UUID, storage int64, segments int64, limits ProjectLimits) (err error) { defer mon.Task()(&ctx, projectID)(&err) - limits, err := usage.projectLimitCache.GetLimits(ctx, projectID) - if err != nil { - return err + segmentsLimit := usage.defaultMaxSegments + if limits.Segments != nil { + segmentsLimit = *limits.Segments } - err = usage.liveAccounting.AddProjectStorageUsageUpToLimit(ctx, projectID, storage, *limits.Usage) + storageLimit := usage.defaultMaxStorage + if limits.Usage != nil { + storageLimit = memory.Size(*limits.Usage) + } + + err = usage.liveAccounting.AddProjectStorageUsageUpToLimit(ctx, projectID, storage, storageLimit.Int64()) if err != nil { return err } - err = usage.liveAccounting.AddProjectSegmentUsageUpToLimit(ctx, projectID, segments, *limits.Segments) + err = usage.liveAccounting.AddProjectSegmentUsageUpToLimit(ctx, projectID, segments, segmentsLimit) if ErrProjectLimitExceeded.Has(err) { // roll back storage increase err = usage.liveAccounting.AddProjectStorageUsage(ctx, projectID, -1*storage) @@ -252,32 +249,46 @@ func (usage *Service) GetProjectBandwidth(ctx context.Context, projectID uuid.UU // GetProjectStorageLimit returns current project storage limit. func (usage *Service) GetProjectStorageLimit(ctx context.Context, projectID uuid.UUID) (_ memory.Size, err error) { defer mon.Task()(&ctx, projectID)(&err) - limits, err := usage.projectLimitCache.GetLimits(ctx, projectID) + storageLimit, err := usage.projectAccountingDB.GetProjectStorageLimit(ctx, projectID) if err != nil { return 0, ErrProjectUsage.Wrap(err) } - return memory.Size(*limits.Usage), nil + if storageLimit == nil { + return usage.defaultMaxStorage, nil + } + + return memory.Size(*storageLimit), nil } // GetProjectBandwidthLimit returns current project bandwidth limit. func (usage *Service) GetProjectBandwidthLimit(ctx context.Context, projectID uuid.UUID) (_ memory.Size, err error) { defer mon.Task()(&ctx, projectID)(&err) - return usage.projectLimitCache.GetBandwidthLimit(ctx, projectID) + bandwidthLimit, err := usage.projectAccountingDB.GetProjectBandwidthLimit(ctx, projectID) + if err != nil { + return 0, ErrProjectUsage.Wrap(err) + } + + if bandwidthLimit == nil { + return usage.defaultMaxBandwidth, nil + } + + return memory.Size(*bandwidthLimit), nil } // GetProjectSegmentLimit returns current project segment limit. func (usage *Service) GetProjectSegmentLimit(ctx context.Context, projectID uuid.UUID) (_ memory.Size, err error) { defer mon.Task()(&ctx, projectID)(&err) - return usage.projectLimitCache.GetSegmentLimit(ctx, projectID) -} + segmentLimit, err := usage.projectAccountingDB.GetProjectSegmentLimit(ctx, projectID) + if err != nil { + return 0, ErrProjectUsage.Wrap(err) + } -// UpdateProjectLimits sets new value for project's bandwidth and storage limit. -// TODO remove because it's not used. -func (usage *Service) UpdateProjectLimits(ctx context.Context, projectID uuid.UUID, limit memory.Size) (err error) { - defer mon.Task()(&ctx, projectID)(&err) + if segmentLimit == nil { + return memory.Size(usage.defaultMaxSegments), nil + } - return ErrProjectUsage.Wrap(usage.projectAccountingDB.UpdateProjectUsageLimit(ctx, projectID, limit)) + return memory.Size(*segmentLimit), nil } // GetProjectBandwidthUsage get the current bandwidth usage from cache. diff --git a/satellite/accounting/projectusage_test.go b/satellite/accounting/projectusage_test.go index 572326699fce..b399320078af 100644 --- a/satellite/accounting/projectusage_test.go +++ b/satellite/accounting/projectusage_test.go @@ -157,7 +157,7 @@ func TestProjectUsageBandwidth(t *testing.T) { projectUsage.SetNow(func() time.Time { return now }) - actualExceeded, _, err := projectUsage.ExceedsBandwidthUsage(ctx, bucket.ProjectID) + actualExceeded, _, err := projectUsage.ExceedsBandwidthUsage(ctx, bucket.ProjectID, accounting.ProjectLimits{}) require.NoError(t, err) require.Equal(t, testCase.expectedExceeded, actualExceeded) @@ -486,20 +486,22 @@ func TestProjectUsageCustomLimit(t *testing.T) { project := projects[0] // set custom usage limit for project - expectedLimit := memory.Size(memory.GiB.Int64() * 10) - err = acctDB.UpdateProjectUsageLimit(ctx, project.ID, expectedLimit) + expectedLimit := memory.GiB.Int64() * 10 + err = acctDB.UpdateProjectUsageLimit(ctx, project.ID, memory.Size(expectedLimit)) require.NoError(t, err) projectUsage := planet.Satellites[0].Accounting.ProjectUsage // Setup: add data to live accounting to exceed new limit - err = projectUsage.AddProjectStorageUsage(ctx, project.ID, expectedLimit.Int64()) + err = projectUsage.AddProjectStorageUsage(ctx, project.ID, expectedLimit) require.NoError(t, err) - limit, err := projectUsage.ExceedsUploadLimits(ctx, project.ID, 1, 1) + limit, err := projectUsage.ExceedsUploadLimits(ctx, project.ID, 1, 1, accounting.ProjectLimits{ + Usage: &expectedLimit, + }) require.NoError(t, err) require.True(t, limit.ExceedsStorage) - require.Equal(t, expectedLimit.Int64(), limit.StorageLimit.Int64()) + require.Equal(t, expectedLimit, limit.StorageLimit.Int64()) // Setup: create some bytes for the uplink to upload expectedData := testrand.Bytes(50 * memory.KiB) @@ -890,7 +892,7 @@ func TestProjectUsageBandwidthResetAfter3days(t *testing.T) { return tt.now }) - actualExceeded, _, err := projectUsage.ExceedsBandwidthUsage(ctx, bucket.ProjectID) + actualExceeded, _, err := projectUsage.ExceedsBandwidthUsage(ctx, bucket.ProjectID, accounting.ProjectLimits{}) require.NoError(t, err) require.Equal(t, tt.expectedExceeds, actualExceeded, tt.description) } diff --git a/satellite/admin.go b/satellite/admin.go index 9179a848a9e5..c1937cb843d9 100644 --- a/satellite/admin.go +++ b/satellite/admin.go @@ -240,9 +240,11 @@ func NewAdmin(log *zap.Logger, full *identity.FullIdentity, db DB, metabaseDB *m peer.Accounting.Service = accounting.NewService( peer.DB.ProjectAccounting(), peer.LiveAccounting.Cache, - peer.ProjectLimits.Cache, *metabaseDB, config.LiveAccounting.BandwidthCacheTTL, + config.Console.Config.UsageLimits.Storage.Free, + config.Console.Config.UsageLimits.Bandwidth.Free, + config.Console.Config.UsageLimits.Segment.Free, config.LiveAccounting.AsOfSystemInterval, ) } diff --git a/satellite/api.go b/satellite/api.go index 8d277a7264d2..1431e33c362a 100644 --- a/satellite/api.go +++ b/satellite/api.go @@ -352,9 +352,11 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB, peer.Accounting.ProjectUsage = accounting.NewService( peer.DB.ProjectAccounting(), peer.LiveAccounting.Cache, - peer.ProjectLimits.Cache, *metabaseDB, config.LiveAccounting.BandwidthCacheTTL, + config.Console.Config.UsageLimits.Storage.Free, + config.Console.Config.UsageLimits.Bandwidth.Free, + config.Console.Config.UsageLimits.Segment.Free, config.LiveAccounting.AsOfSystemInterval, ) } diff --git a/satellite/console/apikeys.go b/satellite/console/apikeys.go index fbabdf8a84e2..feb15ed3170f 100644 --- a/satellite/console/apikeys.go +++ b/satellite/console/apikeys.go @@ -62,8 +62,13 @@ type APIKeyInfo struct { Secret []byte `json:"-"` CreatedAt time.Time `json:"createdAt"` - ProjectRateLimit *int `json:"-"` - ProjectBurstLimit *int `json:"-"` + // TODO move this closer to metainfo + ProjectRateLimit *int + ProjectBurstLimit *int + + ProjectStorageLimit *int64 + ProjectSegmentsLimit *int64 + ProjectBandwidthLimit *int64 } // APIKeyCursor holds info for api keys cursor pagination. diff --git a/satellite/metainfo/endpoint_object.go b/satellite/metainfo/endpoint_object.go index ed8b188e1279..1ef9ce8225c2 100644 --- a/satellite/metainfo/endpoint_object.go +++ b/satellite/metainfo/endpoint_object.go @@ -96,7 +96,7 @@ func (endpoint *Endpoint) BeginObject(ctx context.Context, req *pb.ObjectBeginRe return nil, rpcstatus.Errorf(rpcstatus.InvalidArgument, "key length is too big, got %v, maximum allowed is %v", objectKeyLength, endpoint.config.MaxEncryptedObjectKeyLength) } - err = endpoint.checkUploadLimits(ctx, keyInfo.ProjectID) + err = endpoint.checkUploadLimits(ctx, keyInfo) if err != nil { return nil, err } @@ -488,22 +488,8 @@ func (endpoint *Endpoint) DownloadObject(ctx context.Context, req *pb.ObjectDown return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } - if exceeded, limit, err := endpoint.projectUsage.ExceedsBandwidthUsage(ctx, keyInfo.ProjectID); err != nil { - if errs2.IsCanceled(err) { - return nil, rpcstatus.Wrap(rpcstatus.Canceled, err) - } - - endpoint.log.Error( - "Retrieving project bandwidth total failed; bandwidth limit won't be enforced", - zap.Stringer("Project ID", keyInfo.ProjectID), - zap.Error(err), - ) - } else if exceeded { - endpoint.log.Warn("Monthly bandwidth limit exceeded", - zap.Stringer("Limit", limit), - zap.Stringer("Project ID", keyInfo.ProjectID), - ) - return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") + if err := endpoint.checkDownloadLimits(ctx, keyInfo); err != nil { + return nil, err } var object metabase.Object @@ -2043,7 +2029,7 @@ func (endpoint *Endpoint) BeginCopyObject(ctx context.Context, req *pb.ObjectBeg ObjectKey: metabase.ObjectKey(req.EncryptedObjectKey), }, VerifyLimits: func(encryptedObjectSize int64, nSegments int64) error { - return endpoint.checkUploadLimitsForNewObject(ctx, keyInfo.ProjectID, encryptedObjectSize, nSegments) + return endpoint.checkUploadLimitsForNewObject(ctx, keyInfo, encryptedObjectSize, nSegments) }, }) if err != nil { @@ -2166,7 +2152,7 @@ func (endpoint *Endpoint) FinishCopyObject(ctx context.Context, req *pb.ObjectFi NewDisallowDelete: false, VerifyLimits: func(encryptedObjectSize int64, nSegments int64) error { - return endpoint.addStorageUsageUpToLimit(ctx, keyInfo.ProjectID, encryptedObjectSize, nSegments) + return endpoint.addStorageUsageUpToLimit(ctx, keyInfo, encryptedObjectSize, nSegments) }, }) if err != nil { diff --git a/satellite/metainfo/endpoint_segment.go b/satellite/metainfo/endpoint_segment.go index 2109758bed42..9386e45a5480 100644 --- a/satellite/metainfo/endpoint_segment.go +++ b/satellite/metainfo/endpoint_segment.go @@ -61,7 +61,7 @@ func (endpoint *Endpoint) beginSegment(ctx context.Context, req *pb.SegmentBegin return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "segment index must be greater then 0") } - if err := endpoint.checkUploadLimits(ctx, keyInfo.ProjectID); err != nil { + if err := endpoint.checkUploadLimits(ctx, keyInfo); err != nil { return nil, err } @@ -204,7 +204,7 @@ func (endpoint *Endpoint) RetryBeginSegmentPieces(ctx context.Context, req *pb.R pieceNumberSet[pieceNumber] = struct{}{} } - if err := endpoint.checkUploadLimits(ctx, keyInfo.ProjectID); err != nil { + if err := endpoint.checkUploadLimits(ctx, keyInfo); err != nil { return nil, err } @@ -390,7 +390,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } - if err := endpoint.checkUploadLimits(ctx, keyInfo.ProjectID); err != nil { + if err := endpoint.checkUploadLimits(ctx, keyInfo); err != nil { return nil, err } @@ -468,7 +468,7 @@ func (endpoint *Endpoint) MakeInlineSegment(ctx context.Context, req *pb.Segment return nil, rpcstatus.Error(rpcstatus.Internal, "unable to parse stream id") } - if err := endpoint.checkUploadLimits(ctx, keyInfo.ProjectID); err != nil { + if err := endpoint.checkUploadLimits(ctx, keyInfo); err != nil { return nil, err } @@ -635,22 +635,8 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo bucket := metabase.BucketLocation{ProjectID: keyInfo.ProjectID, BucketName: string(streamID.Bucket)} - if exceeded, limit, err := endpoint.projectUsage.ExceedsBandwidthUsage(ctx, keyInfo.ProjectID); err != nil { - if errs2.IsCanceled(err) { - return nil, rpcstatus.Wrap(rpcstatus.Canceled, err) - } - - endpoint.log.Error( - "Retrieving project bandwidth total failed; bandwidth limit won't be enforced", - zap.Stringer("Project ID", keyInfo.ProjectID), - zap.Error(err), - ) - } else if exceeded { - endpoint.log.Warn("Monthly bandwidth limit exceeded", - zap.Stringer("Limit", limit), - zap.Stringer("Project ID", keyInfo.ProjectID), - ) - return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") + if err := endpoint.checkDownloadLimits(ctx, keyInfo); err != nil { + return nil, err } id, err := uuid.FromBytes(streamID.StreamId) diff --git a/satellite/metainfo/validation.go b/satellite/metainfo/validation.go index 12bb9a80006c..a879aea2c97f 100644 --- a/satellite/metainfo/validation.go +++ b/satellite/metainfo/validation.go @@ -387,28 +387,49 @@ func (endpoint *Endpoint) validateRemoteSegment(ctx context.Context, commitReque return nil } -func (endpoint *Endpoint) checkUploadLimits(ctx context.Context, projectID uuid.UUID) error { - return endpoint.checkUploadLimitsForNewObject(ctx, projectID, 1, 1) +func (endpoint *Endpoint) checkDownloadLimits(ctx context.Context, keyInfo *console.APIKeyInfo) error { + if exceeded, limit, err := endpoint.projectUsage.ExceedsBandwidthUsage(ctx, keyInfo.ProjectID, keyInfoToLimits(keyInfo)); err != nil { + if errs2.IsCanceled(err) { + return rpcstatus.Wrap(rpcstatus.Canceled, err) + } + + endpoint.log.Error( + "Retrieving project bandwidth total failed; bandwidth limit won't be enforced", + zap.Stringer("Project ID", keyInfo.ProjectID), + zap.Error(err), + ) + } else if exceeded { + endpoint.log.Warn("Monthly bandwidth limit exceeded", + zap.Stringer("Limit", limit), + zap.Stringer("Project ID", keyInfo.ProjectID), + ) + return rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") + } + return nil +} + +func (endpoint *Endpoint) checkUploadLimits(ctx context.Context, keyInfo *console.APIKeyInfo) error { + return endpoint.checkUploadLimitsForNewObject(ctx, keyInfo, 1, 1) } func (endpoint *Endpoint) checkUploadLimitsForNewObject( - ctx context.Context, projectID uuid.UUID, newObjectSize int64, newObjectSegmentCount int64, + ctx context.Context, keyInfo *console.APIKeyInfo, newObjectSize int64, newObjectSegmentCount int64, ) error { - if limit, err := endpoint.projectUsage.ExceedsUploadLimits(ctx, projectID, newObjectSize, newObjectSegmentCount); err != nil { + if limit, err := endpoint.projectUsage.ExceedsUploadLimits(ctx, keyInfo.ProjectID, newObjectSize, newObjectSegmentCount, keyInfoToLimits(keyInfo)); err != nil { if errs2.IsCanceled(err) { return rpcstatus.Wrap(rpcstatus.Canceled, err) } endpoint.log.Error( "Retrieving project upload limit failed; limit won't be enforced", - zap.Stringer("Project ID", projectID), + zap.Stringer("Project ID", keyInfo.ProjectID), zap.Error(err), ) } else { if limit.ExceedsSegments { endpoint.log.Warn("Segment limit exceeded", zap.String("Limit", strconv.Itoa(int(limit.SegmentsLimit))), - zap.Stringer("Project ID", projectID), + zap.Stringer("Project ID", keyInfo.ProjectID), ) return rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Segments Limit") } @@ -416,7 +437,7 @@ func (endpoint *Endpoint) checkUploadLimitsForNewObject( if limit.ExceedsStorage { endpoint.log.Warn("Storage limit exceeded", zap.String("Limit", strconv.Itoa(limit.StorageLimit.Int())), - zap.Stringer("Project ID", projectID), + zap.Stringer("Project ID", keyInfo.ProjectID), ) return rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Storage Limit") } @@ -463,13 +484,13 @@ func (endpoint *Endpoint) addToUploadLimits(ctx context.Context, projectID uuid. return nil } -func (endpoint *Endpoint) addStorageUsageUpToLimit(ctx context.Context, projectID uuid.UUID, storage int64, segments int64) (err error) { - err = endpoint.projectUsage.AddProjectUsageUpToLimit(ctx, projectID, storage, segments) +func (endpoint *Endpoint) addStorageUsageUpToLimit(ctx context.Context, keyInfo *console.APIKeyInfo, storage int64, segments int64) (err error) { + err = endpoint.projectUsage.AddProjectUsageUpToLimit(ctx, keyInfo.ProjectID, storage, segments, keyInfoToLimits(keyInfo)) if err != nil { if accounting.ErrProjectLimitExceeded.Has(err) { endpoint.log.Warn("Upload limit exceeded", - zap.Stringer("Project ID", projectID), + zap.Stringer("Project ID", keyInfo.ProjectID), zap.Error(err), ) return rpcstatus.Error(rpcstatus.ResourceExhausted, err.Error()) @@ -481,7 +502,7 @@ func (endpoint *Endpoint) addStorageUsageUpToLimit(ctx context.Context, projectI endpoint.log.Error( "Updating project upload limits failed; limits won't be enforced", - zap.Stringer("Project ID", projectID), + zap.Stringer("Project ID", keyInfo.ProjectID), zap.Error(err), ) } @@ -523,3 +544,18 @@ func (endpoint *Endpoint) checkObjectUploadRate(ctx context.Context, projectID u return nil } + +func keyInfoToLimits(keyInfo *console.APIKeyInfo) accounting.ProjectLimits { + if keyInfo == nil { + return accounting.ProjectLimits{} + } + + return accounting.ProjectLimits{ + Bandwidth: keyInfo.ProjectBandwidthLimit, + Usage: keyInfo.ProjectStorageLimit, + Segments: keyInfo.ProjectSegmentsLimit, + + RateLimit: keyInfo.ProjectRateLimit, + BurstLimit: keyInfo.ProjectBurstLimit, + } +} diff --git a/satellite/payments/stripe/accounts_test.go b/satellite/payments/stripe/accounts_test.go index 059991b89a97..dc3ffbae1f63 100644 --- a/satellite/payments/stripe/accounts_test.go +++ b/satellite/payments/stripe/accounts_test.go @@ -39,9 +39,7 @@ func TestSignupCouponCodes(t *testing.T) { cache, err := live.OpenCache(ctx, log.Named("cache"), live.Config{StorageBackend: "redis://" + redis.Addr() + "?db=0"}) require.NoError(t, err) - projectLimitCache := accounting.NewProjectLimitCache(db.ProjectAccounting(), 0, 0, 0, accounting.ProjectLimitConfig{CacheCapacity: 100}) - - projectUsage := accounting.NewService(db.ProjectAccounting(), cache, projectLimitCache, *sat.API.Metainfo.Metabase, 5*time.Minute, -10*time.Second) + projectUsage := accounting.NewService(db.ProjectAccounting(), cache, *sat.API.Metainfo.Metabase, 5*time.Minute, 0, 0, 0, -10*time.Second) pc := paymentsconfig.Config{ UsagePrice: paymentsconfig.ProjectUsagePrice{ diff --git a/satellite/satellitedb/apikeys.go b/satellite/satellitedb/apikeys.go index a7840b3f6db1..cfe6817478e7 100644 --- a/satellite/satellitedb/apikeys.go +++ b/satellite/satellitedb/apikeys.go @@ -21,7 +21,7 @@ var _ console.APIKeys = (*apikeys)(nil) // apikeys is an implementation of satellite.APIKeys. type apikeys struct { methods dbx.Methods - lru *lrucache.ExpiringLRUOf[*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row] + lru *lrucache.ExpiringLRUOf[*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row] db *satelliteDB } @@ -139,13 +139,13 @@ func (keys *apikeys) Get(ctx context.Context, id uuid.UUID) (_ *console.APIKeyIn func (keys *apikeys) GetByHead(ctx context.Context, head []byte) (_ *console.APIKeyInfo, err error) { defer mon.Task()(&ctx)(&err) - dbKey, err := keys.lru.Get(ctx, string(head), func() (*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row, error) { - return keys.methods.Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_By_ApiKey_Head(ctx, dbx.ApiKey_Head(head)) + dbKey, err := keys.lru.Get(ctx, string(head), func() (*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row, error) { + return keys.methods.Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_By_ApiKey_Head(ctx, dbx.ApiKey_Head(head)) }) if err != nil { return nil, err } - return fromDBXApiKeyProjectPublicIdProjectRateLimitProjectBurstLimitRow(ctx, dbKey) + return fromDBXApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row(ctx, dbKey) } // GetByNameAndProjectID implements satellite.APIKeys. @@ -291,7 +291,7 @@ func fromDBXApiKeyProjectPublicIdRow(ctx context.Context, row *dbx.ApiKey_Projec return result, nil } -func fromDBXApiKeyProjectPublicIdProjectRateLimitProjectBurstLimitRow(ctx context.Context, row *dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row) (_ *console.APIKeyInfo, err error) { +func fromDBXApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row(ctx context.Context, row *dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row) (_ *console.APIKeyInfo, err error) { defer mon.Task()(&ctx)(&err) result, err := apiKeyToAPIKeyInfo(ctx, &row.ApiKey) @@ -305,6 +305,10 @@ func fromDBXApiKeyProjectPublicIdProjectRateLimitProjectBurstLimitRow(ctx contex result.ProjectRateLimit = row.Project_RateLimit result.ProjectBurstLimit = row.Project_BurstLimit + result.ProjectBandwidthLimit = row.Project_BandwidthLimit + result.ProjectStorageLimit = row.Project_UsageLimit + result.ProjectSegmentsLimit = row.Project_SegmentLimit + return result, nil } diff --git a/satellite/satellitedb/consoledb.go b/satellite/satellitedb/consoledb.go index cd5d3c46492f..1f8c4a21d28c 100644 --- a/satellite/satellitedb/consoledb.go +++ b/satellite/satellitedb/consoledb.go @@ -58,7 +58,7 @@ func (db *ConsoleDB) APIKeys() console.APIKeys { options.Name = "satellitedb-apikeys" db.apikeys = &apikeys{ methods: db.methods, - lru: lrucache.NewOf[*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row](options), + lru: lrucache.NewOf[*dbx.ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row](options), db: db.db, } }) diff --git a/satellite/satellitedb/dbx/project.dbx b/satellite/satellitedb/dbx/project.dbx index d11a5f03fdff..f2401d17acae 100644 --- a/satellite/satellitedb/dbx/project.dbx +++ b/satellite/satellitedb/dbx/project.dbx @@ -243,7 +243,7 @@ read one ( where api_key.id = ? ) read one ( - select api_key project.public_id project.rate_limit project.burst_limit + select api_key project.public_id project.rate_limit project.burst_limit project.segment_limit project.usage_limit project.bandwidth_limit join project.id = api_key.project_id where api_key.head = ? ) diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index e6914e436be1..c58b32a48579 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -12033,11 +12033,14 @@ func (h *__sqlbundle_Hole) Render() string { // end runtime support for building sql statements // -type ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row struct { - ApiKey ApiKey - Project_PublicId []byte - Project_RateLimit *int - Project_BurstLimit *int +type ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row struct { + ApiKey ApiKey + Project_PublicId []byte + Project_RateLimit *int + Project_BurstLimit *int + Project_SegmentLimit *int64 + Project_UsageLimit *int64 + Project_BandwidthLimit *int64 } type ApiKey_Project_PublicId_Row struct { @@ -16430,12 +16433,12 @@ func (obj *pgxImpl) Get_ApiKey_Project_PublicId_By_ApiKey_Id(ctx context.Context } -func (obj *pgxImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_By_ApiKey_Head(ctx context.Context, +func (obj *pgxImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_By_ApiKey_Head(ctx context.Context, api_key_head ApiKey_Head_Field) ( - row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row, err error) { + row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row, err error) { defer mon.Task()(&ctx)(&err) - var __embed_stmt = __sqlbundle_Literal("SELECT api_keys.id, api_keys.project_id, api_keys.head, api_keys.name, api_keys.secret, api_keys.user_agent, api_keys.created_at, projects.public_id, projects.rate_limit, projects.burst_limit FROM projects JOIN api_keys ON projects.id = api_keys.project_id WHERE api_keys.head = ?") + var __embed_stmt = __sqlbundle_Literal("SELECT api_keys.id, api_keys.project_id, api_keys.head, api_keys.name, api_keys.secret, api_keys.user_agent, api_keys.created_at, projects.public_id, projects.rate_limit, projects.burst_limit, projects.segment_limit, projects.usage_limit, projects.bandwidth_limit FROM projects JOIN api_keys ON projects.id = api_keys.project_id WHERE api_keys.head = ?") var __values []interface{} __values = append(__values, api_key_head.value()) @@ -16443,10 +16446,10 @@ func (obj *pgxImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstL var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) obj.logStmt(__stmt, __values...) - row = &ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row{} - err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.ApiKey.Id, &row.ApiKey.ProjectId, &row.ApiKey.Head, &row.ApiKey.Name, &row.ApiKey.Secret, &row.ApiKey.UserAgent, &row.ApiKey.CreatedAt, &row.Project_PublicId, &row.Project_RateLimit, &row.Project_BurstLimit) + row = &ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row{} + err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.ApiKey.Id, &row.ApiKey.ProjectId, &row.ApiKey.Head, &row.ApiKey.Name, &row.ApiKey.Secret, &row.ApiKey.UserAgent, &row.ApiKey.CreatedAt, &row.Project_PublicId, &row.Project_RateLimit, &row.Project_BurstLimit, &row.Project_SegmentLimit, &row.Project_UsageLimit, &row.Project_BandwidthLimit) if err != nil { - return (*ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row)(nil), obj.makeErr(err) + return (*ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row)(nil), obj.makeErr(err) } return row, nil @@ -24857,12 +24860,12 @@ func (obj *pgxcockroachImpl) Get_ApiKey_Project_PublicId_By_ApiKey_Id(ctx contex } -func (obj *pgxcockroachImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_By_ApiKey_Head(ctx context.Context, +func (obj *pgxcockroachImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_By_ApiKey_Head(ctx context.Context, api_key_head ApiKey_Head_Field) ( - row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row, err error) { + row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row, err error) { defer mon.Task()(&ctx)(&err) - var __embed_stmt = __sqlbundle_Literal("SELECT api_keys.id, api_keys.project_id, api_keys.head, api_keys.name, api_keys.secret, api_keys.user_agent, api_keys.created_at, projects.public_id, projects.rate_limit, projects.burst_limit FROM projects JOIN api_keys ON projects.id = api_keys.project_id WHERE api_keys.head = ?") + var __embed_stmt = __sqlbundle_Literal("SELECT api_keys.id, api_keys.project_id, api_keys.head, api_keys.name, api_keys.secret, api_keys.user_agent, api_keys.created_at, projects.public_id, projects.rate_limit, projects.burst_limit, projects.segment_limit, projects.usage_limit, projects.bandwidth_limit FROM projects JOIN api_keys ON projects.id = api_keys.project_id WHERE api_keys.head = ?") var __values []interface{} __values = append(__values, api_key_head.value()) @@ -24870,10 +24873,10 @@ func (obj *pgxcockroachImpl) Get_ApiKey_Project_PublicId_Project_RateLimit_Proje var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) obj.logStmt(__stmt, __values...) - row = &ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row{} - err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.ApiKey.Id, &row.ApiKey.ProjectId, &row.ApiKey.Head, &row.ApiKey.Name, &row.ApiKey.Secret, &row.ApiKey.UserAgent, &row.ApiKey.CreatedAt, &row.Project_PublicId, &row.Project_RateLimit, &row.Project_BurstLimit) + row = &ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row{} + err = obj.queryRowContext(ctx, __stmt, __values...).Scan(&row.ApiKey.Id, &row.ApiKey.ProjectId, &row.ApiKey.Head, &row.ApiKey.Name, &row.ApiKey.Secret, &row.ApiKey.UserAgent, &row.ApiKey.CreatedAt, &row.Project_PublicId, &row.Project_RateLimit, &row.Project_BurstLimit, &row.Project_SegmentLimit, &row.Project_UsageLimit, &row.Project_BandwidthLimit) if err != nil { - return (*ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row)(nil), obj.makeErr(err) + return (*ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row)(nil), obj.makeErr(err) } return row, nil @@ -29544,9 +29547,9 @@ type Methods interface { api_key_project_id ApiKey_ProjectId_Field) ( row *ApiKey_Project_PublicId_Row, err error) - Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_By_ApiKey_Head(ctx context.Context, + Get_ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_By_ApiKey_Head(ctx context.Context, api_key_head ApiKey_Head_Field) ( - row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Row, err error) + row *ApiKey_Project_PublicId_Project_RateLimit_Project_BurstLimit_Project_SegmentLimit_Project_UsageLimit_Project_BandwidthLimit_Row, err error) Get_BillingBalance_Balance_By_UserId(ctx context.Context, billing_balance_user_id BillingBalance_UserId_Field) (