Skip to content

Commit

Permalink
reconciler: allow custom comparison function
Browse files Browse the repository at this point in the history
  • Loading branch information
wasaga committed Nov 8, 2023
1 parent 62a9299 commit 469184d
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
4 changes: 2 additions & 2 deletions pkg/grpc/databroker/changeset.go
Expand Up @@ -9,13 +9,13 @@ import (

// GetChangeSet returns list of changes between the current and target record sets,
// that may be applied to the databroker to bring it to the target state.
func GetChangeSet(current, target RecordSetBundle) []*Record {
func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Record {
cs := &changeSet{now: timestamppb.Now()}

for _, rec := range current.GetRemoved(target).Flatten() {
cs.Remove(rec.GetType(), rec.GetId())
}
for _, rec := range current.GetModified(target).Flatten() {
for _, rec := range current.GetModified(target, cmpFn).Flatten() {
cs.Upsert(rec)
}
for _, rec := range current.GetAdded(target).Flatten() {
Expand Down
5 changes: 4 additions & 1 deletion pkg/grpc/databroker/reconciler.go
Expand Up @@ -17,6 +17,7 @@ type Reconciler struct {
name string
client DataBrokerServiceClient
currentStateBuilder StateBuilderFn
cmpFn RecordCompareFn
targetStateBuilder StateBuilderFn
setCurrentState func([]*Record)
trigger chan struct{}
Expand Down Expand Up @@ -58,6 +59,7 @@ func NewReconciler(
currentStateBuilder StateBuilderFn,
targetStateBuilder StateBuilderFn,
setCurrentState func([]*Record),
cmpFn RecordCompareFn,
opts ...ReconcilerOption,
) *Reconciler {
return &Reconciler{
Expand All @@ -68,6 +70,7 @@ func NewReconciler(
currentStateBuilder: currentStateBuilder,
targetStateBuilder: targetStateBuilder,
setCurrentState: setCurrentState,
cmpFn: cmpFn,
}
}

Expand Down Expand Up @@ -119,7 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context) error {
return fmt.Errorf("get config record sets: %w", err)
}

updates := GetChangeSet(current, target)
updates := GetChangeSet(current, target, r.cmpFn)

err = r.applyChanges(ctx, updates)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions pkg/grpc/databroker/recordset.go
Expand Up @@ -2,7 +2,6 @@ package databroker

import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

// RecordSetBundle is an index of databroker records by type
Expand All @@ -11,6 +10,9 @@ type RecordSetBundle map[string]RecordSet
// RecordSet is an index of databroker records by their id.
type RecordSet map[string]*Record

// RecordCompareFn is a function that compares two records.
type RecordCompareFn func(record1, record2 *Record) bool

// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle) RecordTypes() []string {
types := make([]string, 0, len(rsb))
Expand Down Expand Up @@ -53,14 +55,14 @@ func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
}

// GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
func (rsb RecordSetBundle) GetModified(other RecordSetBundle, cmpFn RecordCompareFn) RecordSetBundle {
modified := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
m := rs.GetModified(otherRS, cmpFn)
if len(m) > 0 {
modified[otherType] = m
}
Expand All @@ -86,15 +88,15 @@ func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {

// GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload.
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
func (rs RecordSet) GetModified(other RecordSet, cmpFn RecordCompareFn) RecordSet {
modified := make(RecordSet)
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}

if !proto.Equal(record, otherRecord) {
if !cmpFn(record, otherRecord) {
modified[id] = record
}
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/grpc/databroker/recordset_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"

"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
Expand All @@ -19,6 +20,10 @@ func TestRecords(t *testing.T) {
}
}

cmpFn := func(a, b *databroker.Record) bool {
return proto.Equal(a, b)
}

initial := make(databroker.RecordSetBundle)
initial.Add(tr("1", "a", "a-1"))
initial.Add(tr("2", "a", "a-2"))
Expand Down Expand Up @@ -68,7 +73,7 @@ func TestRecords(t *testing.T) {
},
})

modified := initial.GetModified(updated)
modified := initial.GetModified(updated, cmpFn)
equalJSON(modified, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"1": tr("1", "a", "a-1-1"),
Expand Down

0 comments on commit 469184d

Please sign in to comment.