Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
util/parallel: add package for parallel execution
- Loading branch information
Showing
2 changed files
with
196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (C) 2017 ScyllaDB | ||
|
||
package parallel | ||
|
||
import ( | ||
"go.uber.org/atomic" | ||
"go.uber.org/multierr" | ||
) | ||
|
||
// NoLimit means full parallelism mode. | ||
const NoLimit = 0 | ||
|
||
// ErrAbort is a special kind of error that aborts all further execution. | ||
// Function calls that are in progress will continue to execute but no new | ||
// functions will be called. | ||
type ErrAbort struct { | ||
error | ||
} | ||
|
||
// Abort is special kind of error that aborts all further execution. | ||
func Abort(err error) ErrAbort { | ||
return ErrAbort{error: err} | ||
} | ||
|
||
func isErrAbort(err error) (bool, error) { | ||
a, ok := err.(ErrAbort) | ||
if !ok { | ||
return false, nil | ||
} | ||
return true, a.error | ||
} | ||
|
||
// Run executes function f with arguments ranging from 0 to n-1 executing at | ||
// most limit in parallel. | ||
// If limit is 0 it runs f(0),f(1),...,f(n-1) in parallel. | ||
func Run(n, limit int, f func(i int) error) error { | ||
if limit <= 0 || limit > n { | ||
limit = n | ||
} | ||
|
||
var ( | ||
idx = atomic.NewInt32(0) | ||
out = make(chan error) | ||
abrt = atomic.NewBool(false) | ||
) | ||
for j := 0; j < limit; j++ { | ||
go func() { | ||
for { | ||
// Exit when there is nothing to do | ||
i := int(idx.Inc()) - 1 | ||
if i >= n { | ||
return | ||
} | ||
|
||
// Exit if aborted | ||
if abrt.Load() { | ||
out <- nil | ||
continue | ||
} | ||
|
||
// Execute | ||
err := f(i) | ||
if ok, inner := isErrAbort(err); ok { | ||
abrt.Store(true) | ||
err = inner | ||
} | ||
out <- err | ||
} | ||
}() | ||
} | ||
|
||
var errs error | ||
for i := 0; i < n; i++ { | ||
errs = multierr.Append(errs, <-out) | ||
} | ||
return errs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// Copyright (C) 2017 ScyllaDB | ||
|
||
package parallel | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
"github.com/scylladb/scylla-operator/pkg/util/timeutc" | ||
"go.uber.org/atomic" | ||
) | ||
|
||
func TestRun(t *testing.T) { | ||
t.Parallel() | ||
|
||
const ( | ||
n = 50 | ||
wait = 5 * time.Millisecond | ||
) | ||
|
||
table := []struct { | ||
Name string | ||
Limit int | ||
Duration time.Duration | ||
}{ | ||
// This test is flaky under race | ||
//{ | ||
// Name: "No limit", | ||
// Duration: wait, | ||
//}, | ||
{ | ||
Name: "One by one", | ||
Limit: 1, | ||
Duration: n * wait, | ||
}, | ||
{ | ||
Name: "Five by five", | ||
Limit: 5, | ||
Duration: n / 5 * wait, | ||
}, | ||
} | ||
|
||
for i := range table { | ||
test := table[i] | ||
|
||
t.Run(test.Name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
active := atomic.NewInt32(0) | ||
f := func(i int) error { | ||
v := active.Inc() | ||
if test.Limit != NoLimit { | ||
if v > int32(test.Limit) { | ||
t.Errorf("limit exeded, got %d", v) | ||
} | ||
} | ||
time.Sleep(wait) | ||
active.Dec() | ||
return nil | ||
} | ||
|
||
start := timeutc.Now() | ||
if err := Run(n, test.Limit, f); err != nil { | ||
t.Error("Run() error", err) | ||
} | ||
d := timeutc.Since(start) | ||
if a, b := epsilonRange(test.Duration); d < a || d > b { | ||
t.Errorf("Run() not within expected time margin %v got %v", test.Duration, d) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestIsErrAbort(t *testing.T) { | ||
t.Parallel() | ||
|
||
t.Run("nil", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
if ok, err := isErrAbort(Abort(nil)); !ok || err != nil { | ||
t.Errorf("isErrAbort() = (%v, %v), expected (%v, %v))", ok, err, true, nil) | ||
} | ||
}) | ||
|
||
t.Run("not nil", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
err := errors.New("too") | ||
|
||
if ok, inner := isErrAbort(Abort(err)); !ok || inner != err { | ||
t.Errorf("isErrAbort() = (%v, %v), expected (%v, %v))", ok, inner, true, err) | ||
} | ||
}) | ||
} | ||
|
||
func TestAbort(t *testing.T) { | ||
t.Parallel() | ||
|
||
called := atomic.NewInt32(0) | ||
f := func(i int) error { | ||
called.Inc() | ||
return Abort(errors.New("boo")) | ||
} | ||
|
||
if err := Run(10, 1, f); err == nil { | ||
t.Error("Run() expected error") | ||
} | ||
|
||
if c := called.Load(); c != 1 { | ||
t.Errorf("Called %d times expected 1", c) | ||
} | ||
} | ||
|
||
// EpsilonRange returns start and end of range 5% close to provided value. | ||
func epsilonRange(d time.Duration) (a, b time.Duration) { | ||
e := time.Duration(float64(d) * 1.05) | ||
return d - e, d + e | ||
} |