diff --git a/satellite/admin/project.go b/satellite/admin/project.go index f13c0b1d5e2e..d47e2ce1746d 100644 --- a/satellite/admin/project.go +++ b/satellite/admin/project.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "time" "github.com/gorilla/mux" "github.com/gorilla/schema" @@ -17,6 +18,7 @@ import ( "storj.io/common/storj" "storj.io/common/uuid" "storj.io/storj/satellite/console" + "storj.io/storj/satellite/payments/stripecoinpayments" ) func (server *Server) getProjectLimit(w http.ResponseWriter, r *http.Request) { @@ -259,6 +261,50 @@ func (server *Server) deleteProject(w http.ResponseWriter, r *http.Request) { return } + // do not delete projects that have usage for the current month. + year, month, _ := time.Now().UTC().Date() + firstOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, time.UTC) + + currentUsage, err := server.db.ProjectAccounting().GetProjectTotal(ctx, projectUUID, firstOfMonth, time.Now()) + if err != nil { + http.Error(w, fmt.Sprintf("unable to list project usage: %v", err), http.StatusInternalServerError) + return + } + if currentUsage.Storage > 0 || currentUsage.Egress > 0 || currentUsage.ObjectCount > 0 { + http.Error(w, "usage for current month exists", http.StatusConflict) + return + } + + // if usage of last month exist, make sure to look for billing records + lastMonthUsage, err := server.db.ProjectAccounting().GetProjectTotal(ctx, projectUUID, firstOfMonth.AddDate(0, -1, 0), firstOfMonth.AddDate(0, 0, -1)) + if err != nil { + http.Error(w, "error getting project totals", http.StatusInternalServerError) + return + } + + if lastMonthUsage.Storage > 0 || lastMonthUsage.Egress > 0 || lastMonthUsage.ObjectCount > 0 { + err := server.db.StripeCoinPayments().ProjectRecords().Check(ctx, projectUUID, firstOfMonth.AddDate(0, -1, 0), firstOfMonth.Add(-time.Hour)) + switch err { + case stripecoinpayments.ErrProjectRecordExists: + record, err := server.db.StripeCoinPayments().ProjectRecords().Get(ctx, projectUUID, firstOfMonth.AddDate(0, -1, 0), firstOfMonth.Add(-time.Hour)) + if err != nil { + http.Error(w, fmt.Sprintf("unable to get project records: %v", err), http.StatusInternalServerError) + return + } + // state = 0 means unapplied and not invoiced yet. + if record.State == 0 { + http.Error(w, "unapplied project invoice record exist", http.StatusConflict) + return + } + case nil: + http.Error(w, "usage for last month exist, but is not billed yet", http.StatusConflict) + return + default: + http.Error(w, fmt.Sprintf("unable to get project records: %v", err), http.StatusInternalServerError) + return + } + } + err = server.db.Console().Projects().Delete(ctx, projectUUID) if err != nil { http.Error(w, fmt.Sprintf("unable to delete project: %v", err), http.StatusInternalServerError) diff --git a/satellite/admin/project_test.go b/satellite/admin/project_test.go index 76c9ef4dd77c..81e7db62a362 100644 --- a/satellite/admin/project_test.go +++ b/satellite/admin/project_test.go @@ -11,6 +11,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -21,6 +22,7 @@ import ( "storj.io/common/uuid" "storj.io/storj/private/testplanet" "storj.io/storj/satellite" + "storj.io/storj/satellite/accounting" "storj.io/storj/satellite/console" ) @@ -193,6 +195,171 @@ func TestDeleteProject(t *testing.T) { }) } +func TestDeleteProjectWithUsageCurrentMonth(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, + StorageNodeCount: 0, + UplinkCount: 1, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Admin.Address = "127.0.0.1:0" + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + address := planet.Satellites[0].Admin.Admin.Listener.Addr() + projectID := planet.Uplinks[0].Projects[0].ID + + apiKeys, err := planet.Satellites[0].DB.Console().APIKeys().GetPagedByProjectID(ctx, projectID, console.APIKeyCursor{ + Page: 1, + Limit: 2, + Search: "", + }) + require.NoError(t, err) + require.Len(t, apiKeys.APIKeys, 1) + + err = planet.Satellites[0].DB.Console().APIKeys().Delete(ctx, apiKeys.APIKeys[0].ID) + require.NoError(t, err) + + accTime := time.Now().UTC().AddDate(0,0,-1) + tally := accounting.BucketStorageTally{ + BucketName: "test", + ProjectID: projectID, + IntervalStart: accTime, + ObjectCount: 1, + InlineSegmentCount: 1, + RemoteSegmentCount: 1, + InlineBytes: 10, + RemoteBytes: 640000, + MetadataSize: 2, + } + err = planet.Satellites[0].DB.ProjectAccounting().CreateStorageTally(ctx, tally) + require.NoError(t, err) + tally = accounting.BucketStorageTally{ + BucketName: "test", + ProjectID: projectID, + IntervalStart: accTime.AddDate(0,0,1), + ObjectCount: 1, + InlineSegmentCount: 1, + RemoteSegmentCount: 1, + InlineBytes: 10, + RemoteBytes: 640000, + MetadataSize: 2, + } + err = planet.Satellites[0].DB.ProjectAccounting().CreateStorageTally(ctx, tally) + require.NoError(t, err) + + inline, remote, err := planet.Satellites[0].DB.ProjectAccounting().GetStorageTotals(ctx, projectID) + require.NoError(t, err) + require.Equal(t, int64(10), inline) + require.Equal(t, int64(640000), remote) + + bw, err := planet.Satellites[0].DB.ProjectAccounting().GetAllocatedBandwidthTotal(ctx, projectID, accTime.AddDate(0,0,-1)) + require.NoError(t, err) + require.EqualValues(t, 0, bw) + + usage, err := planet.Satellites[0].DB.ProjectAccounting().GetProjectTotal(ctx, projectID, accTime.AddDate(0,0,-1), accTime.AddDate(0,0,2)) + require.NoError(t, err) + require.NotEqual(t, 0, usage.Egress) + require.NotEqual(t, float64(0), usage.Storage) + + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("http://"+address.String()+"/api/project/%s", projectID), nil) + require.NoError(t, err) + req.Header.Set("Authorization", "very-secret-token") + + response, err := http.DefaultClient.Do(req) + require.NoError(t, err) + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + require.Equal(t, "usage for current month exists\n", string(responseBody)) + require.NoError(t, response.Body.Close()) + require.Equal(t, http.StatusConflict, response.StatusCode) + }) +} + +func TestDeleteProjectWithUsagePreviousMonth(t *testing.T) { + testplanet.Run(t, testplanet.Config{ + SatelliteCount: 1, + StorageNodeCount: 0, + UplinkCount: 1, + Reconfigure: testplanet.Reconfigure{ + Satellite: func(log *zap.Logger, index int, config *satellite.Config) { + config.Admin.Address = "127.0.0.1:0" + }, + }, + }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { + address := planet.Satellites[0].Admin.Admin.Listener.Addr() + projectID := planet.Uplinks[0].Projects[0].ID + + apiKeys, err := planet.Satellites[0].DB.Console().APIKeys().GetPagedByProjectID(ctx, projectID, console.APIKeyCursor{ + Page: 1, + Limit: 2, + Search: "", + }) + require.NoError(t, err) + require.Len(t, apiKeys.APIKeys, 1) + + err = planet.Satellites[0].DB.Console().APIKeys().Delete(ctx, apiKeys.APIKeys[0].ID) + require.NoError(t, err) + + //ToDo: Improve updating of DB entries + accTime := time.Now().UTC().AddDate(0,-1,0) + tally := accounting.BucketStorageTally{ + BucketName: "test", + ProjectID: projectID, + IntervalStart: accTime, + ObjectCount: 1, + InlineSegmentCount: 1, + RemoteSegmentCount: 1, + InlineBytes: 10, + RemoteBytes: 640000, + MetadataSize: 2, + } + err = planet.Satellites[0].DB.ProjectAccounting().CreateStorageTally(ctx, tally) + require.NoError(t, err) + tally = accounting.BucketStorageTally{ + BucketName: "test", + ProjectID: projectID, + IntervalStart: accTime.AddDate(0,0,1), + ObjectCount: 1, + InlineSegmentCount: 1, + RemoteSegmentCount: 1, + InlineBytes: 10, + RemoteBytes: 640000, + MetadataSize: 2, + } + err = planet.Satellites[0].DB.ProjectAccounting().CreateStorageTally(ctx, tally) + require.NoError(t, err) + + inline, remote, err := planet.Satellites[0].DB.ProjectAccounting().GetStorageTotals(ctx, projectID) + require.NoError(t, err) + require.Equal(t, int64(10), inline) + require.Equal(t, int64(640000), remote) + + bw, err := planet.Satellites[0].DB.ProjectAccounting().GetAllocatedBandwidthTotal(ctx, projectID, accTime.AddDate(0,0,-1)) + require.NoError(t, err) + require.EqualValues(t, 0, bw) + + usage, err := planet.Satellites[0].DB.ProjectAccounting().GetProjectTotal(ctx, projectID, accTime.AddDate(0,0,-1), accTime.AddDate(0,0,2)) + require.NoError(t, err) + require.NotEqual(t, 0, usage.Egress) + require.NotEqual(t, float64(0), usage.Storage) + + + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("http://"+address.String()+"/api/project/%s", projectID), nil) + require.NoError(t, err) + req.Header.Set("Authorization", "very-secret-token") + + response, err := http.DefaultClient.Do(req) + require.NoError(t, err) + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + require.Equal(t, "usage for last month exist, but is not billed yet\n", string(responseBody)) + require.NoError(t, response.Body.Close()) + require.Equal(t, http.StatusConflict, response.StatusCode) + }) +} + func assertGet(t *testing.T, link string, expected string) { t.Helper() diff --git a/satellite/payments/stripecoinpayments/projectrecords.go b/satellite/payments/stripecoinpayments/projectrecords.go index c9716590e24e..c82d67c6fba8 100644 --- a/satellite/payments/stripecoinpayments/projectrecords.go +++ b/satellite/payments/stripecoinpayments/projectrecords.go @@ -47,6 +47,7 @@ type ProjectRecord struct { Objects float64 PeriodStart time.Time PeriodEnd time.Time + State int } // ProjectRecordsPage holds project records and diff --git a/satellite/satellitedb/invoiceprojectrecords.go b/satellite/satellitedb/invoiceprojectrecords.go index 9a8aada1e0fc..888a5370eb1c 100644 --- a/satellite/satellitedb/invoiceprojectrecords.go +++ b/satellite/satellitedb/invoiceprojectrecords.go @@ -208,5 +208,6 @@ func fromDBXInvoiceProjectRecord(dbxRecord *dbx.StripecoinpaymentsInvoiceProject Objects: float64(dbxRecord.Objects), PeriodStart: dbxRecord.PeriodStart, PeriodEnd: dbxRecord.PeriodEnd, + State: dbxRecord.State, }, nil } diff --git a/satellite/satellitedb/projectaccounting.go b/satellite/satellitedb/projectaccounting.go index fe188c4003cb..b9f571490eb4 100644 --- a/satellite/satellitedb/projectaccounting.go +++ b/satellite/satellitedb/projectaccounting.go @@ -23,12 +23,12 @@ import ( // ensure that ProjectAccounting implements accounting.ProjectAccounting. var _ accounting.ProjectAccounting = (*ProjectAccounting)(nil) -// ProjectAccounting implements the accounting/db ProjectAccounting interface +// ProjectAccounting implements the accounting/db ProjectAccounting interface. type ProjectAccounting struct { db *satelliteDB } -// SaveTallies saves the latest bucket info +// SaveTallies saves the latest bucket info. func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time.Time, bucketTallies map[string]*accounting.BucketTally) (err error) { defer mon.Task()(&ctx)(&err) if len(bucketTallies) == 0 { @@ -69,7 +69,7 @@ func (db *ProjectAccounting) SaveTallies(ctx context.Context, intervalStart time return Error.Wrap(err) } -// GetTallies saves the latest bucket info +// GetTallies saves the latest bucket info. func (db *ProjectAccounting) GetTallies(ctx context.Context) (tallies []accounting.BucketTally, err error) { defer mon.Task()(&ctx)(&err) @@ -99,7 +99,7 @@ func (db *ProjectAccounting) GetTallies(ctx context.Context) (tallies []accounti return tallies, nil } -// CreateStorageTally creates a record in the bucket_storage_tallies accounting table +// CreateStorageTally creates a record in the bucket_storage_tallies accounting table. func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accounting.BucketStorageTally) (err error) { defer mon.Task()(&ctx)(&err) @@ -117,7 +117,7 @@ func (db *ProjectAccounting) CreateStorageTally(ctx context.Context, tally accou )) } -// GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame +// GetAllocatedBandwidthTotal returns the sum of GET bandwidth usage allocated for a projectID for a time frame. func (db *ProjectAccounting) GetAllocatedBandwidthTotal(ctx context.Context, projectID uuid.UUID, from time.Time) (_ int64, err error) { defer mon.Task()(&ctx)(&err) var sum *int64 @@ -146,7 +146,7 @@ func (db *ProjectAccounting) GetProjectAllocatedBandwidth(ctx context.Context, p return *egress, err } -// GetStorageTotals returns the current inline and remote storage usage for a projectID +// GetStorageTotals returns the current inline and remote storage usage for a projectID. func (db *ProjectAccounting) GetStorageTotals(ctx context.Context, projectID uuid.UUID) (inline int64, remote int64, err error) { defer mon.Task()(&ctx)(&err) var inlineSum, remoteSum sql.NullInt64 @@ -228,8 +228,7 @@ func (db *ProjectAccounting) GetProjectBandwidthLimit(ctx context.Context, proje func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid.UUID, since, before time.Time) (usage *accounting.ProjectUsage, err error) { defer mon.Task()(&ctx)(&err) since = timeTruncateDown(since) - - bucketNames, err := db.getBuckets(ctx, projectID, since, before) + bucketNames, err := db.getBuckets(ctx, projectID) if err != nil { return nil, err } @@ -244,7 +243,7 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid bucket_storage_tallies WHERE bucket_storage_tallies.project_id = ? AND - bucket_storage_tallies.bucket_name = ? AND + bucket_storage_tallies.bucket_name = ? AND bucket_storage_tallies.interval_start >= ? AND bucket_storage_tallies.interval_start <= ? ORDER BY bucket_storage_tallies.interval_start DESC @@ -259,7 +258,6 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid if err != nil { return nil, err } - // generating tallies for each bucket name. for storageTalliesRows.Next() { tally := accounting.BucketStorageTally{} @@ -287,14 +285,11 @@ func (db *ProjectAccounting) GetProjectTotal(ctx context.Context, projectID uuid usage = new(accounting.ProjectUsage) usage.Egress = memory.Size(totalEgress).Int64() - // sum up storage and objects for _, tallies := range bucketsTallies { for i := len(tallies) - 1; i > 0; i-- { current := (tallies)[i] - hours := (tallies)[i-1].IntervalStart.Sub(current.IntervalStart).Hours() - usage.Storage += memory.Size(current.InlineBytes).Float64() * hours usage.Storage += memory.Size(current.RemoteBytes).Float64() * hours usage.ObjectCount += float64(current.ObjectCount) * hours @@ -329,13 +324,13 @@ func (db *ProjectAccounting) getTotalEgress(ctx context.Context, projectID uuid. return totalEgress, err } -// GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period +// GetBucketUsageRollups retrieves summed usage rollups for every bucket of particular project for a given period. func (db *ProjectAccounting) GetBucketUsageRollups(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ []accounting.BucketUsageRollup, err error) { defer mon.Task()(&ctx)(&err) since = timeTruncateDown(since.UTC()) before = before.UTC() - buckets, err := db.getBuckets(ctx, projectID, since, before) + buckets, err := db.getBuckets(ctx, projectID) if err != nil { return nil, err } @@ -478,7 +473,7 @@ func (db *ProjectAccounting) prefixMatch(expr string, prefix []byte) (string, [] } -// GetBucketTotals retrieves bucket usage totals for period of time +// GetBucketTotals retrieves bucket usage totals for period of time. func (db *ProjectAccounting) GetBucketTotals(ctx context.Context, projectID uuid.UUID, cursor accounting.BucketUsageCursor, since, before time.Time) (_ *accounting.BucketUsagePage, err error) { defer mon.Task()(&ctx)(&err) since = timeTruncateDown(since) @@ -617,14 +612,14 @@ func (db *ProjectAccounting) GetBucketTotals(ctx context.Context, projectID uuid return page, nil } -// getBuckets list all bucket of certain project for given period -func (db *ProjectAccounting) getBuckets(ctx context.Context, projectID uuid.UUID, since, before time.Time) (_ []string, err error) { +// getBuckets list all bucket of certain project. +func (db *ProjectAccounting) getBuckets(ctx context.Context, projectID uuid.UUID) (_ []string, err error) { defer mon.Task()(&ctx)(&err) bucketsQuery := db.db.Rebind(`SELECT DISTINCT bucket_name - FROM bucket_bandwidth_rollups - WHERE project_id = ? AND interval_start >= ? AND interval_start <= ?`) + FROM bucket_storage_tallies + WHERE project_id = ?`) - bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:], since, before) + bucketRows, err := db.db.QueryContext(ctx, bucketsQuery, projectID[:]) if err != nil { return nil, err } @@ -644,7 +639,7 @@ func (db *ProjectAccounting) getBuckets(ctx context.Context, projectID uuid.UUID return buckets, bucketRows.Err() } -// timeTruncateDown truncates down to the hour before to be in sync with orders endpoint +// timeTruncateDown truncates down to the hour before to be in sync with orders endpoint. func timeTruncateDown(t time.Time) time.Time { return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 0, t.Location()) }