Skip to content

Commit

Permalink
Add flush method to writer
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacd9 committed Mar 7, 2024
1 parent b2b17ac commit c3edd76
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 6 deletions.
32 changes: 31 additions & 1 deletion writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,33 @@ func (w *Writer) spawn(f func()) {
}()
}

// Flush writes all currently buffered messages to the kafka cluster. This will
// block until all messages in the batch has been written to kafka, or until the
// context is canceled.
func (w *Writer) Flush(ctx context.Context) error {
w.mutex.Lock()

// flush all writers
for _, writer := range w.writers {
writer.flush()
}

w.mutex.Unlock()
done := make(chan struct{})

go func() {
w.group.Wait()
close(done)
}()

select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

// Close flushes pending writes, and waits for all writes to complete before
// returning. Calling Close also prevents new writes from being submitted to
// the writer, further calls to WriteMessages and the like will fail with
Expand Down Expand Up @@ -1184,7 +1211,7 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) {
batch.complete(err)
}

func (ptw *partitionWriter) close() {
func (ptw *partitionWriter) flush() {
ptw.mutex.Lock()
defer ptw.mutex.Unlock()

Expand All @@ -1194,7 +1221,10 @@ func (ptw *partitionWriter) close() {
ptw.currBatch = nil
batch.trigger()
}
}

func (ptw *partitionWriter) close() {
ptw.flush()
ptw.queue.Close()
}

Expand Down
73 changes: 68 additions & 5 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func TestWriter(t *testing.T) {
scenario: "closing a writer right after creating it returns promptly with no error",
function: testWriterClose,
},

{
scenario: "writing 1 message through a writer using round-robin balancing produces 1 message to the first partition",
function: testWriterRoundRobin1,
Expand All @@ -130,6 +129,10 @@ func TestWriter(t *testing.T) {
scenario: "writing a batch of messages",
function: testWriterBatchSize,
},
{
scenario: "writing and flushing a batch of messages",
function: testWriterBatchSize,
},

{
scenario: "writing messages with a small batch byte size",
Expand Down Expand Up @@ -450,7 +453,7 @@ func readPartition(topic string, partition int, offset int64) (msgs []Message, e
}
}

func testWriterBatchBytes(t *testing.T) {
func tetsWriterFlush(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)
Expand All @@ -461,9 +464,11 @@ func testWriterBatchBytes(t *testing.T) {
}

w := newTestWriter(WriterConfig{
Topic: topic,
BatchBytes: 50,
BatchTimeout: math.MaxInt32 * time.Second,
Topic: topic,
// Set the batch timeout to a large value to avoid the timeout
BatchSize: 1000,
BatchBytes: 1000000,
BatchTimeout: 1000 * time.Second,
Balancer: &RoundRobin{},
})
defer w.Close()
Expand All @@ -480,6 +485,11 @@ func testWriterBatchBytes(t *testing.T) {
return
}

if err := w.Flush(ctx); err != nil {
t.Errorf("flush error %v", err)
return
}

if w.Stats().Writes != 2 {
t.Error("didn't create expected batches")
return
Expand All @@ -503,6 +513,59 @@ func testWriterBatchBytes(t *testing.T) {
}
}

func testWriterBatchBytes(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)

offset, err := readOffset(topic, 0)
if err != nil {
t.Fatal(err)
}

w := newTestWriter(WriterConfig{
Topic: topic,
BatchBytes: 50,
BatchTimeout: math.MaxInt32 * time.Second,
Balancer: &RoundRobin{},
})
defer w.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := w.WriteMessages(ctx, []Message{
{Value: []byte("M0")},
{Value: []byte("M1")},
{Value: []byte("M2")},
{Value: []byte("M3")},
}...); err != nil {
t.Error(err)
return
}

if w.Stats().Writes != 1 {
t.Error("didn't create expected batches")
return
}
msgs, err := readPartition(topic, 0, offset)
if err != nil {
t.Error("error reading partition", err)
return
}

if len(msgs) != 4 {
t.Error("bad messages in partition", msgs)
return
}

for i, m := range msgs {
if string(m.Value) == "M"+strconv.Itoa(i) {
continue
}
t.Error("bad messages in partition", string(m.Value))
}
}

func testWriterBatchSize(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
Expand Down

0 comments on commit c3edd76

Please sign in to comment.