Skip to content

Commit

Permalink
fix: concurrency-safe graph utils
Browse files Browse the repository at this point in the history
Co-authored-by: Henning Perl <hperl@users.noreply.github.com>
  • Loading branch information
zepatrik and hperl committed Aug 19, 2022
1 parent 6f61af8 commit ea9dda9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
51 changes: 33 additions & 18 deletions internal/x/graph/graph_utils.go
Expand Up @@ -2,34 +2,49 @@ package graph

import (
"context"
"fmt"
"sync"

"github.com/gofrs/uuid"
"github.com/ory/keto/internal/relationtuple"
)

type contextKey string

const visitedMapKey = contextKey("visitedMap")

func CheckAndAddVisited(ctx context.Context, current uuid.UUID) (context.Context, bool) {
visitedMap, ok := ctx.Value(visitedMapKey).(map[uuid.UUID]struct{})
if !ok {
// for the first time initialize the map
visitedMap = make(map[uuid.UUID]struct{})
visitedMap[current] = struct{}{}
return context.WithValue(ctx, visitedMapKey, visitedMap), false
type stringSet struct {
m map[string]struct{}
l sync.Mutex
}

func newStringSet() *stringSet {
return &stringSet{m: make(map[string]struct{})}
}

func (s *stringSet) addNoDuplicate(el fmt.Stringer) bool {
s.l.Lock()
defer s.l.Unlock()

if _, found := s.m[el.String()]; found {
return true
}
s.m[el.String()] = struct{}{}
return false
}

// check if current node was already visited
if _, ok := visitedMap[current]; ok {
return ctx, true
func InitVisited(ctx context.Context) context.Context {
if _, ok := ctx.Value(visitedMapKey).(*stringSet); !ok {
ctx = context.WithValue(ctx, visitedMapKey, newStringSet())
}
return ctx
}

// set current entry to visited
visitedMap[current] = struct{}{}
func CheckAndAddVisited(ctx context.Context, current relationtuple.Subject) (context.Context, bool) {
set, ok := ctx.Value(visitedMapKey).(*stringSet)
if !ok {
set = newStringSet()
ctx = context.WithValue(ctx, visitedMapKey, set)
}

return context.WithValue(
ctx,
visitedMapKey,
visitedMap,
), false
return ctx, set.addNoDuplicate(current.UniqueID())
}
43 changes: 41 additions & 2 deletions internal/x/graph/graph_utils_test.go
Expand Up @@ -2,6 +2,7 @@ package graph

import (
"context"
"sync"
"testing"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -39,7 +40,7 @@ func TestEngineUtilsProvider_CheckVisited(t *testing.T) {
ctx := context.Background()
var isThereACycle bool
for i := range linkedList {
ctx, isThereACycle = CheckAndAddVisited(ctx, linkedList[i].UniqueID())
ctx, isThereACycle = CheckAndAddVisited(ctx, &linkedList[i])
if isThereACycle {
break
}
Expand Down Expand Up @@ -74,12 +75,50 @@ func TestEngineUtilsProvider_CheckVisited(t *testing.T) {
ctx := context.Background()
var isThereACycle bool
for i := range list {
ctx, isThereACycle = CheckAndAddVisited(ctx, list[i].UniqueID())
ctx, isThereACycle = CheckAndAddVisited(ctx, &list[i])
if isThereACycle {
break
}
}

assert.Equal(t, isThereACycle, false)
})

t.Run("case=no race condition during adding", func(t *testing.T) {
racyObj := uuid.Must(uuid.NewV4())
otherObj := uuid.Must(uuid.NewV4())
// we repeat this test a few times to ensure we don't have a race condition
// the race detector alone was not able to catch it
for i := 0; i < 500; i++ {
subject := &relationtuple.SubjectSet{
Namespace: "default",
Object: racyObj,
Relation: "connected",
}

ctx, _ := CheckAndAddVisited(
context.Background(),
&relationtuple.SubjectSet{Object: otherObj},
)
var wg sync.WaitGroup
var aCycle, bCycle bool
var aCtx, bCtx context.Context

wg.Add(2)
go func() {
aCtx, aCycle = CheckAndAddVisited(ctx, subject)
wg.Done()
}()
go func() {
bCtx, bCycle = CheckAndAddVisited(ctx, subject)
wg.Done()
}()

wg.Wait()
// one should be true, and one false
assert.False(t, aCycle && bCycle)
assert.True(t, aCycle || bCycle)
assert.Equal(t, aCtx, bCtx)
}
})
}

0 comments on commit ea9dda9

Please sign in to comment.