Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reconciler: allow custom comparison function #4726

Merged
merged 1 commit into from Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/grpc/databroker/changeset.go
Expand Up @@ -9,13 +9,13 @@ import (

// GetChangeSet returns list of changes between the current and target record sets,
// that may be applied to the databroker to bring it to the target state.
func GetChangeSet(current, target RecordSetBundle) []*Record {
func GetChangeSet(current, target RecordSetBundle, cmpFn RecordCompareFn) []*Record {
cs := &changeSet{now: timestamppb.Now()}

for _, rec := range current.GetRemoved(target).Flatten() {
cs.Remove(rec.GetType(), rec.GetId())
}
for _, rec := range current.GetModified(target).Flatten() {
for _, rec := range current.GetModified(target, cmpFn).Flatten() {
cs.Upsert(rec)
}
for _, rec := range current.GetAdded(target).Flatten() {
Expand Down
5 changes: 4 additions & 1 deletion pkg/grpc/databroker/reconciler.go
Expand Up @@ -17,6 +17,7 @@ type Reconciler struct {
name string
client DataBrokerServiceClient
currentStateBuilder StateBuilderFn
cmpFn RecordCompareFn
targetStateBuilder StateBuilderFn
setCurrentState func([]*Record)
trigger chan struct{}
Expand Down Expand Up @@ -58,6 +59,7 @@ func NewReconciler(
currentStateBuilder StateBuilderFn,
targetStateBuilder StateBuilderFn,
setCurrentState func([]*Record),
cmpFn RecordCompareFn,
opts ...ReconcilerOption,
) *Reconciler {
return &Reconciler{
Expand All @@ -68,6 +70,7 @@ func NewReconciler(
currentStateBuilder: currentStateBuilder,
targetStateBuilder: targetStateBuilder,
setCurrentState: setCurrentState,
cmpFn: cmpFn,
}
}

Expand Down Expand Up @@ -119,7 +122,7 @@ func (r *Reconciler) reconcile(ctx context.Context) error {
return fmt.Errorf("get config record sets: %w", err)
}

updates := GetChangeSet(current, target)
updates := GetChangeSet(current, target, r.cmpFn)

err = r.applyChanges(ctx, updates)
if err != nil {
Expand Down
12 changes: 7 additions & 5 deletions pkg/grpc/databroker/recordset.go
Expand Up @@ -2,7 +2,6 @@ package databroker

import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

// RecordSetBundle is an index of databroker records by type
Expand All @@ -11,6 +10,9 @@ type RecordSetBundle map[string]RecordSet
// RecordSet is an index of databroker records by their id.
type RecordSet map[string]*Record

// RecordCompareFn is a function that compares two records.
type RecordCompareFn func(record1, record2 *Record) bool

// RecordTypes returns the types of records in the bundle.
func (rsb RecordSetBundle) RecordTypes() []string {
types := make([]string, 0, len(rsb))
Expand Down Expand Up @@ -53,14 +55,14 @@ func (rsb RecordSetBundle) GetRemoved(other RecordSetBundle) RecordSetBundle {
}

// GetModified returns the records that are in both rs and other but have different data.
func (rsb RecordSetBundle) GetModified(other RecordSetBundle) RecordSetBundle {
func (rsb RecordSetBundle) GetModified(other RecordSetBundle, cmpFn RecordCompareFn) RecordSetBundle {
modified := make(RecordSetBundle)
for otherType, otherRS := range other {
rs, ok := rsb[otherType]
if !ok {
continue
}
m := rs.GetModified(otherRS)
m := rs.GetModified(otherRS, cmpFn)
if len(m) > 0 {
modified[otherType] = m
}
Expand All @@ -86,15 +88,15 @@ func (rs RecordSet) GetRemoved(other RecordSet) RecordSet {

// GetModified returns the records that are in both rs and other but have different data.
// by comparing the protobuf bytes of the payload.
func (rs RecordSet) GetModified(other RecordSet) RecordSet {
func (rs RecordSet) GetModified(other RecordSet, cmpFn RecordCompareFn) RecordSet {
modified := make(RecordSet)
for id, record := range other {
otherRecord, ok := rs[id]
if !ok {
continue
}

if !proto.Equal(record, otherRecord) {
if !cmpFn(record, otherRecord) {
modified[id] = record
}
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/grpc/databroker/recordset_test.go
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"

"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
Expand All @@ -19,6 +20,10 @@ func TestRecords(t *testing.T) {
}
}

cmpFn := func(a, b *databroker.Record) bool {
return proto.Equal(a, b)
}

initial := make(databroker.RecordSetBundle)
initial.Add(tr("1", "a", "a-1"))
initial.Add(tr("2", "a", "a-2"))
Expand Down Expand Up @@ -68,7 +73,7 @@ func TestRecords(t *testing.T) {
},
})

modified := initial.GetModified(updated)
modified := initial.GetModified(updated, cmpFn)
equalJSON(modified, databroker.RecordSetBundle{
"a": databroker.RecordSet{
"1": tr("1", "a", "a-1-1"),
Expand Down