From 395d7c74149f9ef192abe495455971289c910e49 Mon Sep 17 00:00:00 2001 From: Isaac Diamond Date: Thu, 7 Mar 2024 10:58:03 -0800 Subject: [PATCH] Add flush method to writer --- writer.go | 46 ++++++++++++++++++++++++++++++----- writer_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 6 deletions(-) diff --git a/writer.go b/writer.go index 3c7af907..656e15dc 100644 --- a/writer.go +++ b/writer.go @@ -548,6 +548,34 @@ func (w *Writer) spawn(f func()) { }() } +// Flush writes all currently buffered messages to the kafka cluster. This will +// block until all messages in the batch has been written to kafka, or until the +// context is canceled. +func (w *Writer) Flush(ctx context.Context) error { + w.mutex.Lock() + + var ch = make(chan chan struct{}, len(w.writers)) + + // flush all writers + for _, writer := range w.writers { + b := writer.flush() + ch <- b.done + } + + close(ch) + w.mutex.Unlock() + + for done := range ch { + select { + case <-done: + case <-ctx.Done(): + return ctx.Err() + } + } + + return nil +} + // Close flushes pending writes, and waits for all writes to complete before // returning. Calling Close also prevents new writes from being submitted to // the writer, further calls to WriteMessages and the like will fail with @@ -1184,17 +1212,23 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) { batch.complete(err) } -func (ptw *partitionWriter) close() { +func (ptw *partitionWriter) flush() *writeBatch { ptw.mutex.Lock() defer ptw.mutex.Unlock() - if ptw.currBatch != nil { - batch := ptw.currBatch - ptw.queue.Put(batch) - ptw.currBatch = nil - batch.trigger() + if ptw.currBatch == nil { + return nil } + batch := ptw.currBatch + ptw.queue.Put(batch) + ptw.currBatch = nil + batch.trigger() + return batch +} + +func (ptw *partitionWriter) close() { + ptw.flush() ptw.queue.Close() } diff --git a/writer_test.go b/writer_test.go index 6f894ecd..35d4f564 100644 --- a/writer_test.go +++ b/writer_test.go @@ -121,6 +121,11 @@ func TestWriter(t *testing.T) { function: testWriterMaxBytes, }, + { + scenario: "writing a batch of message and flush", + function: testWriterFlush, + }, + { scenario: "writing a batch of message based on batch byte size", function: testWriterBatchBytes, @@ -503,6 +508,67 @@ func testWriterBatchBytes(t *testing.T) { } } +func testWriterFlush(t *testing.T) { + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + offset, err := readOffset(topic, 0) + if err != nil { + t.Fatal(err) + } + + w := newTestWriter(WriterConfig{ + Topic: topic, + // Set the batch timeout to a large value to avoid the timeout + BatchSize: 1000, + BatchBytes: 1000000, + BatchTimeout: 1000 * time.Second, + Balancer: &RoundRobin{}, + Async: true, + }) + defer w.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := w.WriteMessages(ctx, []Message{ + {Value: []byte("M0")}, // 25 Bytes + {Value: []byte("M1")}, // 25 Bytes + {Value: []byte("M2")}, // 25 Bytes + {Value: []byte("M3")}, // 25 Bytes + }...); err != nil { + t.Error(err) + return + } + + if err := w.Flush(ctx); err != nil { + t.Errorf("flush error %v", err) + return + } + + if w.Stats().Writes != 1 { + t.Error("didn't create expected batches") + return + } + msgs, err := readPartition(topic, 0, offset) + if err != nil { + t.Error("error reading partition", err) + return + } + + if len(msgs) != 4 { + t.Error("bad messages in partition", msgs) + return + } + + for i, m := range msgs { + if string(m.Value) == "M"+strconv.Itoa(i) { + continue + } + t.Error("bad messages in partition", string(m.Value)) + } +} + func testWriterBatchSize(t *testing.T) { topic := makeTopic() createTopic(t, topic, 1)