Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 81 additions & 8 deletions pipelines/pipelines.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,6 @@ func Map[S, T any](ctx context.Context, in <-chan S, f func(S) T, opts ...Option
}, configure(opts)))
}

func return1[T any](chans []chan T) <-chan T {
return chans[0]
}

func return2[T any](chans []chan T) (<-chan T, <-chan T) {
return chans[0], chans[1]
}

func doMap[S, T any](ctx context.Context, in <-chan S, f func(S) T, result chan<- T) {
for {
select {
Expand Down Expand Up @@ -510,3 +502,84 @@ func Reduce[S, T string](ctx context.Context, in <-chan S, f func(T, S) T) (T, e
}
}
}

// ErrorSink provides an error-handling solution for pipelines created by this package. It manages a
// pipeline stage which can receive fatal and non-fatal errors that may occur during the course of a pipeline.
type ErrorSink struct {
errors chan errWrapper
cancel context.CancelFunc
errs []error
needLock bool
lock *sync.Mutex
}

// NewErrorSink returns a new ErrorSink, along with a context which is cancelled when a fatal error is sent to the
// ErrorSink. Starts a new, configurable pipeline stage which catches any errors reported.
func NewErrorSink(ctx context.Context, opts ...Option) (context.Context, *ErrorSink) {
ctx, cancel := context.WithCancel(ctx)
result := &ErrorSink{cancel: cancel}
config := configure(opts)
if config.workers > 1 {
result.needLock = true
result.lock = new(sync.Mutex)
}
doWithConf(ctx, func(ctx context.Context, in ...chan errWrapper) {
result.doErrSink(ctx, in[0])
}, config)
return ctx, result
}

func (s *ErrorSink) doErrSink(ctx context.Context, errors chan errWrapper) {
s.errors = errors
for {
select {
case <-ctx.Done():
return
case werr := <-errors:
s.appendErr(werr.err)
if werr.isFatal {
s.cancel()
}
}
}
}

func (s *ErrorSink) appendErr(err error) {
if s.needLock {
s.lock.Lock()
defer s.lock.Unlock()
}
s.errs = append(s.errs, err)
}

// Fatal sends a fatal error to this ErrorSink, cancelling the child context which was created by NewErrorSink,
// as well as reporting this error.
func (s *ErrorSink) Fatal(err error) {
s.errors <- errWrapper{isFatal: true, err: err}
}

// Error sends a non-fatal error to this ErrorSink, which is reported and included along with All()
func (s *ErrorSink) Error(err error) {
s.errors <- errWrapper{isFatal: false, err: err}
}

// All returns all errors which have been received by this ErrorSink so far. Subsequent calls to All can return strictly
// more errors, but will never return fewer errors. The only way to be certain that all errors from a pipeline have been
// reported is to pass WithWaitGroup to every pipeline stage which sends an error to this ErrorSink and wait for all
// stages to terminate before calling All().
func (s *ErrorSink) All() []error {
return s.errs
}

type errWrapper struct {
isFatal bool
err error
}

func return1[T any](chans []chan T) <-chan T {
return chans[0]
}

func return2[T any](chans []chan T) (<-chan T, <-chan T) {
return chans[0], chans[1]
}
50 changes: 50 additions & 0 deletions pipelines/pipelines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,56 @@ func ExampleForkMapCtx() {
}
}

func ExampleErrorSink() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ctx, errs := pipelines.NewErrorSink(ctx)

urls := pipelines.Chan([]string{
"https://httpstat.us/200",
"https://httpstat.us/410",
"wrong://not.a.url/test", // malformed URL; triggers a fatal error
"https://httpstat.us/502",
})

// fetch a bunch of URLs, reporting errors along the way.
responses := pipelines.OptionMapCtx(ctx, urls, func(ctx context.Context, url string) *http.Response {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
errs.Fatal(fmt.Errorf("error forming request context: %w", err))
return nil
}
resp, err := http.DefaultClient.Do(req)

if err != nil {
errs.Error(fmt.Errorf("error fetching %s: %w", url, err))
return nil
}
if resp.StatusCode >= 400 {
errs.Error(fmt.Errorf("unsuccessful %s: %d", url, resp.StatusCode))
return nil
}
return resp
})

// retrieve all responses; there should be only one
for response := range responses {
fmt.Printf("success: %s: %d\n", response.Request.URL, response.StatusCode)
}

// retrieve all errors; the 502 error should be skipped, since the malformed URL triggers
// a fatal error.
for _, err := range errs.All() {
fmt.Printf("error: %v\n", err.Error())
}

// Output:
// success: https://httpstat.us/200: 200
// error: unsuccessful https://httpstat.us/410: 410
// error: error fetching wrong://not.a.url/test: Get "wrong://not.a.url/test": unsupported protocol scheme "wrong"
}

func ExampleWithWaitGroup() {
ctx := context.Background()

Expand Down