From a2d5d8062bada3209ae2c0cd978f91b1d5a03d10 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 9 Jun 2022 12:18:20 -0600 Subject: [PATCH] postgres: use CTE and GENERATED version number instead of serialized transaction (#3408) * postgres: use CTE and GENERATED version number instead of serialized transaction * update server version * fix indexing CIDRs --- pkg/grpc/databroker/syncer.go | 7 ---- pkg/grpc/databroker/syncer_test.go | 10 +---- pkg/storage/postgres/backend.go | 61 ++++++++++-------------------- pkg/storage/postgres/migrate.go | 35 +++++++++++++++++ pkg/storage/postgres/postgres.go | 50 ++++++++++++------------ 5 files changed, 81 insertions(+), 82 deletions(-) diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index 4abd16bd504..844d4f18609 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -2,7 +2,6 @@ package databroker import ( "context" - "fmt" "time" backoff "github.com/cenkalti/backoff/v4" @@ -174,12 +173,6 @@ func (syncer *Syncer) sync(ctx context.Context) error { rec := res.GetRecord() log.Debug(logCtxRec(ctx, rec)).Msg("syncer got record") - if syncer.recordVersion != res.GetRecord().GetVersion()-1 { - log.Error(logCtxRec(ctx, rec)).Err(err). - Msg("aborted sync due to missing record") - syncer.serverVersion = 0 - return fmt.Errorf("missing record version") - } syncer.recordVersion = res.GetRecord().GetVersion() if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() { ctx := logCtxRec(ctx, rec) diff --git a/pkg/grpc/databroker/syncer_test.go b/pkg/grpc/databroker/syncer_test.go index 7e51561e58f..22d7f40a9a4 100644 --- a/pkg/grpc/databroker/syncer_test.go +++ b/pkg/grpc/databroker/syncer_test.go @@ -205,15 +205,9 @@ func TestSyncer(t *testing.T) { select { case <-ctx.Done(): - t.Fatal("6. expected call to clear records due to skipped version") - case <-clearCh: - } - - select { - case <-ctx.Done(): - t.Fatal("7. expected call to update records") + t.Fatal("6. expected call to update records") case records := <-updateCh: - testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}, {"id": "r5", "version": "1004"}]`, records) + testutil.AssertProtoJSONEqual(t, `[{"id": "r5", "version": "1004"}]`, records) } assert.NoError(t, syncer.Close()) diff --git a/pkg/storage/postgres/backend.go b/pkg/storage/postgres/backend.go index e8013efc67e..6eba9186ae8 100644 --- a/pkg/storage/postgres/backend.go +++ b/pkg/storage/postgres/backend.go @@ -158,53 +158,32 @@ func (backend *Backend) Put( return 0, err } - err = pool.BeginTxFunc(ctx, pgx.TxOptions{ - IsoLevel: pgx.Serializable, - AccessMode: pgx.ReadWrite, - }, func(tx pgx.Tx) error { - now := timestamppb.Now() + now := timestamppb.Now() - recordVersion, err := getLatestRecordVersion(ctx, tx) + // add all the records + recordTypes := map[string]struct{}{} + for i, record := range records { + recordTypes[record.GetType()] = struct{}{} + + record = dup(record) + record.ModifiedAt = now + err := putRecordAndChange(ctx, pool, record) if err != nil { - return fmt.Errorf("storage/postgres: error getting latest record version: %w", err) + return serverVersion, fmt.Errorf("storage/postgres: error saving record: %w", err) } + records[i] = record + } - // add all the records - recordTypes := map[string]struct{}{} - for i, record := range records { - recordTypes[record.GetType()] = struct{}{} - - record = dup(record) - record.ModifiedAt = now - record.Version = recordVersion + uint64(i) + 1 - err := putRecordChange(ctx, tx, record) - if err != nil { - return fmt.Errorf("storage/postgres: error saving record change: %w", err) - } - - err = putRecord(ctx, tx, record) - if err != nil { - return fmt.Errorf("storage/postgres: error saving record: %w", err) - } - records[i] = record + // enforce options for each record type + for recordType := range recordTypes { + options, err := getOptions(ctx, pool, recordType) + if err != nil { + return serverVersion, fmt.Errorf("storage/postgres: error getting options: %w", err) } - - // enforce options for each record type - for recordType := range recordTypes { - options, err := getOptions(ctx, tx, recordType) - if err != nil { - return fmt.Errorf("storage/postgres: error getting options: %w", err) - } - err = enforceOptions(ctx, tx, recordType, options) - if err != nil { - return fmt.Errorf("storage/postgres: error enforcing options: %w", err) - } + err = enforceOptions(ctx, pool, recordType, options) + if err != nil { + return serverVersion, fmt.Errorf("storage/postgres: error enforcing options: %w", err) } - - return nil - }) - if err != nil { - return serverVersion, err } err = signalRecordChange(ctx, pool) diff --git a/pkg/storage/postgres/migrate.go b/pkg/storage/postgres/migrate.go index fb22fc76855..d65e2809b2e 100644 --- a/pkg/storage/postgres/migrate.go +++ b/pkg/storage/postgres/migrate.go @@ -77,6 +77,41 @@ var migrations = []func(context.Context, pgx.Tx) error{ return err } + return nil + }, + 2: func(ctx context.Context, tx pgx.Tx) error { + serverVersion := uint64(cryptutil.NewRandomUInt32()) + _, err := tx.Exec(ctx, ` + UPDATE `+schemaName+`.`+migrationInfoTableName+` + SET server_version = $1 + `, serverVersion) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + DELETE FROM `+schemaName+`.`+recordChangesTableName+` + `) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + DELETE FROM `+schemaName+`.`+recordsTableName+` + `) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + ALTER TABLE `+schemaName+`.`+recordChangesTableName+` + ALTER COLUMN version + ADD GENERATED BY DEFAULT AS IDENTITY + `) + if err != nil { + return err + } + return nil }, } diff --git a/pkg/storage/postgres/postgres.go b/pkg/storage/postgres/postgres.go index da8b55ea024..2809f166e89 100644 --- a/pkg/storage/postgres/postgres.go +++ b/pkg/storage/postgres/postgres.go @@ -225,7 +225,7 @@ func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string return leaseHolderID, err } -func putRecordChange(ctx context.Context, q querier, record *databroker.Record) error { +func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error { data, err := jsonbFromAny(record.GetData()) if err != nil { return err @@ -233,40 +233,38 @@ func putRecordChange(ctx context.Context, q querier, record *databroker.Record) modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt()) deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt()) - _, err = q.Exec(ctx, ` - INSERT INTO `+schemaName+`.`+recordChangesTableName+` (type, id, version, data, modified_at, deleted_at) - VALUES ($1, $2, $3, $4, $5, $6) - `, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt, deletedAt) - if err != nil { - return err + indexCIDR := &pgtype.Text{Status: pgtype.Null} + if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil { + _ = indexCIDR.Set(cidr.String()) } - return nil -} - -func putRecord(ctx context.Context, q querier, record *databroker.Record) error { - data, err := jsonbFromAny(record.GetData()) - if err != nil { - return err - } - - modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt()) - if record.GetDeletedAt() == nil { - _, err = q.Exec(ctx, ` - INSERT INTO `+schemaName+`.`+recordsTableName+` (type, id, version, data, modified_at) + query := ` + WITH t1 AS ( + INSERT INTO ` + schemaName + `.` + recordChangesTableName + ` (type, id, data, modified_at, deleted_at) VALUES ($1, $2, $3, $4, $5) + RETURNING * + ) + ` + if record.GetDeletedAt() == nil { + query += ` + INSERT INTO ` + schemaName + `.` + recordsTableName + ` (type, id, version, data, modified_at, index_cidr) + VALUES ($1, $2, (SELECT version FROM t1), $3, $4, $6) ON CONFLICT (type, id) DO UPDATE - SET version=$3, data=$4, modified_at=$5 - `, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt) + SET version=(SELECT version FROM t1), data=$3, modified_at=$4, index_cidr=$6 + RETURNING ` + schemaName + `.` + recordsTableName + `.version + ` } else { - _, err = q.Exec(ctx, ` - DELETE FROM `+schemaName+`.`+recordsTableName+` - WHERE type=$1 AND id=$2 AND version<$3 - `, record.GetType(), record.GetId(), record.GetVersion()) + query += ` + DELETE FROM ` + schemaName + `.` + recordsTableName + ` + WHERE type=$1 AND id=$2 + RETURNING ` + schemaName + `.` + recordsTableName + `.version + ` } + err = q.QueryRow(ctx, query, record.GetType(), record.GetId(), data, modifiedAt, deletedAt, indexCIDR).Scan(&record.Version) if err != nil { return err } + return nil }