Skip to content

Commit

Permalink
Add flush method to writer
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacd9 committed Apr 22, 2024
1 parent b2b17ac commit dafa65d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 6 deletions.
53 changes: 47 additions & 6 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,41 @@ func (w *Writer) spawn(f func()) {
}()
}

// Flush writes all currently buffered messages to the kafka cluster. This will
// block until all messages in the batch has been written to kafka, or until the
// context is canceled.
func (w *Writer) Flush(ctx context.Context) error {
w.mutex.Lock()

var wg sync.WaitGroup

// flush all writers
for _, writer := range w.writers {
w := writer
wg.Add(1)
go func() {
b := w.flush()
<-b.done
wg.Done()
}()
}

w.mutex.Unlock()
done := make(chan struct{})

go func() {
wg.Wait()
close(done)
}()

select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

// Close flushes pending writes, and waits for all writes to complete before
// returning. Calling Close also prevents new writes from being submitted to
// the writer, further calls to WriteMessages and the like will fail with
Expand Down Expand Up @@ -1184,17 +1219,23 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) {
batch.complete(err)
}

func (ptw *partitionWriter) close() {
func (ptw *partitionWriter) flush() *writeBatch {
ptw.mutex.Lock()
defer ptw.mutex.Unlock()

if ptw.currBatch != nil {
batch := ptw.currBatch
ptw.queue.Put(batch)
ptw.currBatch = nil
batch.trigger()
if ptw.currBatch == nil {
return nil
}

batch := ptw.currBatch
ptw.queue.Put(batch)
ptw.currBatch = nil
batch.trigger()
return batch
}

func (ptw *partitionWriter) close() {
ptw.flush()
ptw.queue.Close()
}

Expand Down
66 changes: 66 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ func TestWriter(t *testing.T) {
function: testWriterMaxBytes,
},

{
scenario: "writing a batch of message and flush",
function: testWriterFlush,
},

{
scenario: "writing a batch of message based on batch byte size",
function: testWriterBatchBytes,
Expand Down Expand Up @@ -503,6 +508,67 @@ func testWriterBatchBytes(t *testing.T) {
}
}

func testWriterFlush(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)

offset, err := readOffset(topic, 0)
if err != nil {
t.Fatal(err)
}

w := newTestWriter(WriterConfig{
Topic: topic,
// Set the batch timeout to a large value to avoid the timeout
BatchSize: 1000,
BatchBytes: 1000000,
BatchTimeout: 1000 * time.Second,
Balancer: &RoundRobin{},
Async: true,
})
defer w.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := w.WriteMessages(ctx, []Message{
{Value: []byte("M0")}, // 25 Bytes
{Value: []byte("M1")}, // 25 Bytes
{Value: []byte("M2")}, // 25 Bytes
{Value: []byte("M3")}, // 25 Bytes
}...); err != nil {
t.Error(err)
return
}

if err := w.Flush(ctx); err != nil {
t.Errorf("flush error %v", err)
return
}

if w.Stats().Writes != 1 {
t.Error("didn't create expected batches")
return
}
msgs, err := readPartition(topic, 0, offset)
if err != nil {
t.Error("error reading partition", err)
return
}

if len(msgs) != 4 {
t.Error("bad messages in partition", msgs)
return
}

for i, m := range msgs {
if string(m.Value) == "M"+strconv.Itoa(i) {
continue
}
t.Error("bad messages in partition", string(m.Value))
}
}

func testWriterBatchSize(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
Expand Down

0 comments on commit dafa65d

Please sign in to comment.