diff --git a/go.mod b/go.mod
index 5eb83ad7d5f..2e9961da205 100644
--- a/go.mod
+++ b/go.mod
@@ -13,6 +13,7 @@ require (
 	github.com/aws/aws-sdk-go-v2 v1.22.2
 	github.com/aws/aws-sdk-go-v2/config v1.18.42
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.42.1
+	github.com/bits-and-blooms/bitset v1.11.0
 	github.com/caddyserver/certmagic v0.19.2
 	github.com/cenkalti/backoff/v4 v4.2.1
 	github.com/cespare/xxhash/v2 v2.2.0
diff --git a/go.sum b/go.sum
index 948c3d1ffb5..34a2b925dec 100644
--- a/go.sum
+++ b/go.sum
@@ -137,6 +137,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.11.0 h1:RMyy2mBBShArUAhfVRZJ2xyBO58KCBCtZFShw3umo6k=
+github.com/bits-and-blooms/bitset v1.11.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
 github.com/bufbuild/buf v1.26.1 h1:+GdU4z2paCmDclnjLv7MqnVi3AGviImlIKhG0MHH9FA=
 github.com/bufbuild/buf v1.26.1/go.mod h1:UMPncXMWgrmIM+0QpwTEwjNr2SA0z2YIVZZsmNflvB4=
diff --git a/pkg/counter/counter.go b/pkg/counter/counter.go
new file mode 100644
index 00000000000..56b7558529c
--- /dev/null
+++ b/pkg/counter/counter.go
@@ -0,0 +1,86 @@
+// Package counter implements a linear probabilistic counter for cardinality estimation.
+package counter
+
+import (
+	"hash/crc32"
+	"math"
+
+	"github.com/bits-and-blooms/bitset"
+)
+
+const (
+	// DefaultCap is the default maximum number of unique elements a counter is sized for.
+	DefaultCap = 1 << 19
+	// loadFactor is the number of expected unique elements per bit of the bitmap.
+	loadFactor = 4
+)
+
+// Counter implements a simple probabilistic counter estimator with 1% estimation accuracy
+// as described in https://www.waitingforcode.com/big-data-algorithms/cardinality-estimation-linear-probabilistic-counting/read
+type Counter struct {
+	// Bits is the bitmap; each marked key sets the bit selected by its hash.
+	Bits *bitset.BitSet `json:"bits"`
+}
+
+// New creates a counter for the maximum amount of unique elements provided.
+// The bitmap always holds at least one bit, so Mark and Count are safe to
+// call even when cap is smaller than the load factor.
+func New(cap uint) *Counter {
+	// from paper: a load factor (number of unique values/hash table size) much larger
+	// than 1.0 (e.g., 12) can be used for accurate estimation (e.g., 1% of error)
+	size := cap / loadFactor
+	if size == 0 {
+		size = 1
+	}
+	return &Counter{Bits: bitset.New(size)}
+}
+
+// FromBinary unmarshals counter state previously produced by ToBinary.
+func FromBinary(data []byte) (*Counter, error) {
+	pc := &Counter{
+		Bits: &bitset.BitSet{},
+	}
+	if err := pc.Bits.UnmarshalBinary(data); err != nil {
+		return nil, err
+	}
+	return pc, nil
+}
+
+// ToBinary marshals counter state.
+func (c *Counter) ToBinary() ([]byte, error) {
+	return c.Bits.MarshalBinary()
+}
+
+// Reset clears every bit, returning the counter to its empty state.
+func (c *Counter) Reset() {
+	c.Bits.ClearAll()
+}
+
+// Mark marks key as present in the set.
+// It is a no-op on a zero-length bitmap (e.g. one restored by FromBinary),
+// which has no bit to set.
+func (c *Counter) Mark(key string) {
+	size := c.Bits.Len()
+	if size == 0 {
+		return
+	}
+	hash := crc32.ChecksumIEEE([]byte(key))
+	c.Bits.Set(uint(hash) % size)
+}
+
+// Count returns an estimate of distinct elements in the set.
+func (c *Counter) Count() uint {
+	size := float64(c.Bits.Len())
+	if size == 0 {
+		return 0
+	}
+	zeros := size - float64(c.Bits.Count())
+	if zeros == 0 {
+		// Saturated bitmap: log(0) is -Inf and converting +Inf to uint is
+		// implementation-defined. Report the largest estimate the bitmap can
+		// express (one zero bit left), but never less than the number of bits,
+		// since every set bit was produced by at least one distinct key.
+		return uint(math.Max(size*math.Log(size), size))
+	}
+	return uint(-size * math.Log(zeros/size))
+}
diff --git a/pkg/counter/counter_test.go b/pkg/counter/counter_test.go
new file mode 100644
index 00000000000..547808ee50e
--- /dev/null
+++ b/pkg/counter/counter_test.go
@@ -0,0 +1,89 @@
+package counter_test
+
+import (
+	"fmt"
+	"math"
+	"math/rand"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/pomerium/pomerium/pkg/counter"
+)
+
+// stableRandomUUIDs returns n UUIDs generated from a fixed seed, so every call yields the same sequence.
+func stableRandomUUIDs(n int) []string {
+	r := rand.New(rand.NewSource(1234567890))
+	out := make([]string, 0, n)
+	for i := 0; i < n; i++ {
+		// the error is ignored because a math/rand reader never fails
+		u, _ := uuid.NewRandomFromReader(r)
+		out = append(out, u.String())
+	}
+	return out
+}
+
+func TestStableRandomUUIDs(t *testing.T) {
+	t.Parallel()
+
+	assert.Equal(t, stableRandomUUIDs(20), stableRandomUUIDs(20))
+}
+
+func TestCounter(t *testing.T) {
+	t.Parallel()
+
+	limit := 1000
+	n := (limit * 8) / 10
+	for j := 0; j < 20; j++ {
+		t.Run(fmt.Sprint(j), func(t *testing.T) {
+			c := counter.New(uint(limit))
+			for _, id := range stableRandomUUIDs(n) {
+				c.Mark(id)
+			}
+			est := c.Count()
+			assert.LessOrEqual(t, math.Abs(float64(n)-float64(est)), math.Ceil(float64(n)*0.01))
+		})
+	}
+}
+
+func TestSerialize(t *testing.T) {
+	t.Parallel()
+
+	c := counter.New(counter.DefaultCap)
+	for _, id := range stableRandomUUIDs(20) {
+		c.Mark(id)
+	}
+	assert.EqualValues(t, 20, c.Count())
+
+	data, err := c.ToBinary()
+	require.NoError(t, err)
+
+	c2, err := counter.FromBinary(data)
+	require.NoError(t, err)
+
+	assert.EqualValues(t, 20, c2.Count())
+}
+
+func TestReset(t *testing.T) {
+	t.Parallel()
+
+	c := counter.New(counter.DefaultCap)
+	for _, id := range stableRandomUUIDs(20) {
+		c.Mark(id)
+	}
+	assert.EqualValues(t, 20, c.Count())
+	c.Reset()
+	assert.EqualValues(t, 0, c.Count())
+}
+
+// TestTinyCapacity checks that a capacity below the load factor still yields a usable counter.
+func TestTinyCapacity(t *testing.T) {
+	t.Parallel()
+
+	c := counter.New(0)
+	assert.EqualValues(t, 0, c.Count())
+	assert.NotPanics(t, func() { c.Mark("a") })
+	assert.EqualValues(t, 1, c.Count())
+}