-
Notifications
You must be signed in to change notification settings - Fork 280
/
stream.go
174 lines (148 loc) · 4.03 KB
/
stream.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package storage
import (
	"context"
	"errors"
	"sync"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// A RecordStream is a stream of records.
//
// A stream is typically drained with a Next/Record loop, after which Err is
// consulted and Close is called:
//
//	for stream.Next(false) {
//		use(stream.Record())
//	}
//	err := stream.Err()
type RecordStream interface {
	// Next advances the stream and reports whether a record is now available
	// through Record. If a record is ready it is returned immediately. If none
	// is ready and block is true, Next waits until one arrives or an error
	// occurs. A false result means either the stream is exhausted or it
	// failed; call Err to tell the two apart.
	Next(block bool) bool

	// Record returns the current record. It is only meaningful after Next
	// has returned true.
	Record() *databroker.Record

	// Err returns the error, if any, that ended the stream. A nil result
	// after Next returned false means the stream simply ran out of records.
	Err() error

	// Close closes the record stream and releases any underlying resources.
	Close() error
}
// A RecordStreamGenerator generates records for a record stream.
//
// Each invocation yields the next record. Returning ErrStreamDone (or an error
// wrapping it) signals that the generator has nothing more to produce; a
// stream built with NewRecordStream then discards it and continues with the
// following generator. Any other non-nil error terminates the stream.
//
// block carries the value passed to RecordStream.Next. For streams created by
// NewRecordStream, ctx is the stream's own context, which is canceled when the
// stream is closed.
type RecordStreamGenerator = func(ctx context.Context, block bool) (*databroker.Record, error)
// recordStream is the RecordStream implementation returned by NewRecordStream.
// It pulls records from an ordered list of generators, advancing to the next
// generator whenever the current one reports ErrStreamDone.
type recordStream struct {
	// generators still to be consulted; generators[0] is the active one.
	generators []RecordStreamGenerator
	// record and err hold the result of the most recent generator call.
	record *databroker.Record
	err    error

	// closeCtx is handed to every generator call. close cancels it; it is
	// invoked by Close and, when a backendClosed channel was supplied, by the
	// watcher goroutine started in NewRecordStream.
	closeCtx context.Context
	close    context.CancelFunc

	// onClose is an optional callback run by Close. closeOnce ensures it runs
	// at most once even if Close is called repeatedly (for example directly
	// and again through a concatenated stream that wraps this one).
	onClose   func()
	closeOnce sync.Once
}

// NewRecordStream creates a new RecordStream from a list of generators and an
// onClose function.
//
// The stream's context is derived from ctx and is passed to every generator.
// If backendClosed is non-nil, a goroutine cancels that context as soon as
// backendClosed becomes readable. The goroutine exits once the stream's
// context is done, so callers should Close the stream (or cancel ctx) to
// release it. onClose, if non-nil, is invoked by the first call to Close.
func NewRecordStream(
	ctx context.Context,
	backendClosed chan struct{},
	generators []RecordStreamGenerator,
	onClose func(),
) RecordStream {
	stream := &recordStream{
		generators: generators,
		onClose:    onClose,
	}
	stream.closeCtx, stream.close = context.WithCancel(ctx)
	if backendClosed != nil {
		go func() {
			defer stream.close()
			select {
			case <-backendClosed:
			case <-stream.closeCtx.Done():
			}
		}()
	}
	return stream
}

// Close cancels the stream's context and runs the onClose callback.
//
// Close may be called more than once: the context cancellation is itself
// idempotent, and onClose is guaranteed to run only on the first call.
// It always returns nil.
func (stream *recordStream) Close() error {
	stream.close()
	stream.closeOnce.Do(func() {
		if stream.onClose != nil {
			stream.onClose()
		}
	})
	return nil
}

// Next advances to the next record and reports whether one is available.
//
// Generators are consulted in order. A generator that returns ErrStreamDone
// (or an error wrapping it) is dropped and the next one is tried. Any other
// error is stored, reported by Err, and causes this and every later call to
// return false.
func (stream *recordStream) Next(block bool) bool {
	for len(stream.generators) > 0 && stream.err == nil {
		stream.record, stream.err = stream.generators[0](stream.closeCtx, block)
		if !errors.Is(stream.err, ErrStreamDone) {
			return stream.err == nil
		}
		// The active generator is exhausted: clear the sentinel and move on.
		stream.err = nil
		stream.generators = stream.generators[1:]
	}
	return false
}

// Record returns the record produced by the most recent generator call.
func (stream *recordStream) Record() *databroker.Record {
	return stream.record
}

// Err returns the error that stopped the stream, or nil if it ended normally
// (or has not ended yet).
func (stream *recordStream) Err() error {
	return stream.err
}
// RecordStreamToList converts a record stream to a list.
//
// The stream is drained with non-blocking calls to Next, so only the records
// that are immediately available are collected. The returned error is the
// stream's Err after draining. The stream is not closed; that remains the
// caller's responsibility.
func RecordStreamToList(recordStream RecordStream) ([]*databroker.Record, error) {
	var collected []*databroker.Record
	for {
		if !recordStream.Next(false) {
			break
		}
		collected = append(collected, recordStream.Record())
	}
	return collected, recordStream.Err()
}
// RecordListToStream converts a record list to a stream.
//
// The stream yields the records in order and then ends. It uses a single
// generator that ignores both its context and the block flag, because every
// record is already in memory.
func RecordListToStream(ctx context.Context, records []*databroker.Record) RecordStream {
	cursor := 0
	fromList := func(_ context.Context, _ bool) (*databroker.Record, error) {
		if cursor >= len(records) {
			return nil, ErrStreamDone
		}
		current := records[cursor]
		cursor++
		return current, nil
	}
	return NewRecordStream(ctx, nil, []RecordStreamGenerator{fromList}, nil)
}
// concatenatedRecordStream reads a sequence of RecordStreams back to back.
// index identifies the stream currently being read; once it reaches
// len(streams) every stream has been exhausted.
type concatenatedRecordStream struct {
	streams []RecordStream
	index   int
}

// NewConcatenatedRecordStream creates a new record stream that streams all the
// records from the first stream before streaming all the records of the
// subsequent streams.
func NewConcatenatedRecordStream(streams ...RecordStream) RecordStream {
	s := &concatenatedRecordStream{}
	s.streams = streams
	return s
}

// Close closes every underlying stream, including ones already exhausted.
// All streams are closed even if some fail; when several fail, the error from
// the last failing stream is returned.
func (stream *concatenatedRecordStream) Close() error {
	var lastErr error
	for i := range stream.streams {
		if closeErr := stream.streams[i].Close(); closeErr != nil {
			lastErr = closeErr
		}
	}
	return lastErr
}

// Next advances to the next record across all underlying streams.
// An exhausted stream is skipped. A stream that reports an error halts the
// concatenation without advancing, so Err surfaces that error.
func (stream *concatenatedRecordStream) Next(block bool) bool {
	for ; stream.index < len(stream.streams); stream.index++ {
		active := stream.streams[stream.index]
		if active.Next(block) {
			return true
		}
		if active.Err() != nil {
			return false
		}
	}
	return false
}

// Record returns the current record of the active stream, or nil once all
// streams have been exhausted.
func (stream *concatenatedRecordStream) Record() *databroker.Record {
	if stream.index < len(stream.streams) {
		return stream.streams[stream.index].Record()
	}
	return nil
}

// Err returns the active stream's error, or nil once all streams have been
// exhausted.
func (stream *concatenatedRecordStream) Err() error {
	if stream.index < len(stream.streams) {
		return stream.streams[stream.index].Err()
	}
	return nil
}