From c3c521a75fc21814248d6b44cce2e68118487812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alfonso=20Subiotto=20Marqu=C3=A9s?= Date: Thu, 19 Jan 2023 17:54:08 +0100 Subject: [PATCH] Fix SeekToRow on mergedRowGroupRows (#462) When a Read is performed after SeekToRow on mergedRowGroups, the rowIndex is checked against the seek index and advanced until the rowIndex == seek index. Previously, the rowIndex was not advanced in the normal read path, resulting in mistakenly dropping unread rows when advancing the rowIndex. --- merge.go | 11 +++++++--- merge_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/merge.go b/merge.go index 451ae1d3..734bb444 100644 --- a/merge.go +++ b/merge.go @@ -106,6 +106,12 @@ type mergedRowGroupRows struct { schema *Schema } +func (r *mergedRowGroupRows) readInternal(rows []Row) (int, error) { + n, err := r.merge.ReadRows(rows) + r.rowIndex += int64(n) + return n, err +} + func (r *mergedRowGroupRows) Close() (lastErr error) { r.merge.close() r.rowIndex = 0 @@ -126,14 +132,13 @@ func (r *mergedRowGroupRows) ReadRows(rows []Row) (int, error) { if n > len(rows) { n = len(rows) } - n, err := r.merge.ReadRows(rows[:n]) + n, err := r.readInternal(rows[:n]) if err != nil { return 0, err } - r.rowIndex += int64(n) } - return r.merge.ReadRows(rows) + return r.readInternal(rows) } func (r *mergedRowGroupRows) SeekToRow(rowIndex int64) error { diff --git a/merge_test.go b/merge_test.go index 16c21ea3..10e4e377 100644 --- a/merge_test.go +++ b/merge_test.go @@ -423,6 +423,67 @@ func TestMergeRowGroupsCursorsAreClosed(t *testing.T) { } } +func TestMergeRowGroupsSeekToRow(t *testing.T) { + type model struct { + A int + } + + schema := parquet.SchemaOf(model{}) + options := []parquet.RowGroupOption{ + parquet.SortingRowGroupConfig( + parquet.SortingColumns( + parquet.Ascending(schema.Columns()[0]...), + ), + ), + } + + rowGroups := make([]parquet.RowGroup, numRowGroups) + + counter := 0 + for i := range rowGroups { + rows := make([]interface{}, 0, rowsPerGroup) + for j := 0; j < rowsPerGroup; j++ { + rows = append(rows, model{A: counter}) + counter++ + } + rowGroups[i] = sortedRowGroup(options, rows...) + } + + m, err := parquet.MergeRowGroups(rowGroups, options...) + if err != nil { + t.Fatal(err) + } + + func() { + mergedRows := m.Rows() + defer mergedRows.Close() + + rbuf := make([]parquet.Row, 1) + cursor := int64(0) + for { + if err := mergedRows.SeekToRow(cursor); err != nil { + t.Fatal(err) + } + + if _, err := mergedRows.ReadRows(rbuf); err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Fatal(err) + } + v := model{} + if err := schema.Reconstruct(&v, rbuf[0]); err != nil { + t.Fatal(err) + } + if v.A != int(cursor) { + t.Fatalf("expected value %d, got %d", cursor, v.A) + } + + cursor++ + } + }() +} + func BenchmarkMergeRowGroups(b *testing.B) { for _, test := range readerTests { b.Run(test.scenario, func(b *testing.B) {