-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathtest_csv_document_splitter.py
336 lines (296 loc) · 12.7 KB
/
test_csv_document_splitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import logging
from pandas import read_csv
from io import StringIO
from haystack import Document, Pipeline
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.components.preprocessors.csv_document_splitter import CSVDocumentSplitter
@pytest.fixture
def splitter() -> CSVDocumentSplitter:
    """Provide a ``CSVDocumentSplitter`` constructed with no arguments (component defaults)."""
    default_splitter = CSVDocumentSplitter()
    return default_splitter
@pytest.fixture
def csv_with_four_rows() -> str:
    """Provide a four-row CSV (three columns) with no empty separator rows."""
    return (
        "A,B,C\n"
        "1,2,3\n"
        "X,Y,Z\n"
        "7,8,9\n"
    )
@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
    """Provide two 2-row tables separated by two all-empty rows (row indices 2 and 3)."""
    return (
        "A,B,C\n"
        "1,2,3\n"
        ",,\n"
        ",,\n"
        "X,Y,Z\n"
        "7,8,9\n"
    )
@pytest.fixture
def three_tables_sep_by_empty_rows() -> str:
    """Provide a CSV with a single empty row at index 1 and a pair of empty rows at indices 3-4."""
    return (
        "A,B,C\n"
        ",,\n"
        "1,2,3\n"
        ",,\n"
        ",,\n"
        "X,Y,Z\n"
        "7,8,9\n"
    )
@pytest.fixture
def two_tables_sep_by_two_empty_columns() -> str:
    """Provide two side-by-side tables separated by two all-empty columns (column indices 2 and 3)."""
    return (
        "A,B,,,X,Y\n"
        "1,2,,,7,8\n"
        "3,4,,,9,10\n"
    )
class TestFindSplitIndices:
    """Unit tests for ``CSVDocumentSplitter._find_split_indices``.

    Each test parses a CSV string into a headerless DataFrame and checks the
    index pairs reported for each run of empty rows/columns on the requested axis.
    """

    @staticmethod
    def _frame(csv_text: str):
        # Headerless parse with every cell kept as ``object`` dtype.
        return read_csv(StringIO(csv_text), header=None, dtype=object)  # type: ignore

    def test_find_split_indices_row_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str
    ) -> None:
        frame = self._frame(two_tables_sep_by_two_empty_rows)
        indices = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert indices == [(2, 3)]

    def test_find_split_indices_row_two_tables_with_empty_row(
        self, splitter: CSVDocumentSplitter, three_tables_sep_by_empty_rows: str
    ) -> None:
        # The lone empty row at index 1 is below the threshold of 2, so only the
        # consecutive empty rows at indices 3-4 are reported.
        frame = self._frame(three_tables_sep_by_empty_rows)
        indices = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert indices == [(3, 4)]

    def test_find_split_indices_row_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        csv_text = (
            "A,B,C\n"
            "1,2,3\n"
            ",,\n"
            ",,\n"
            "X,Y,Z\n"
            "7,8,9\n"
            ",,\n"
            ",,\n"
            "P,Q,R\n"
        )
        frame = self._frame(csv_text)
        indices = splitter._find_split_indices(frame, split_threshold=2, axis="row")
        assert indices == [(2, 3), (6, 7)]

    def test_find_split_indices_column_two_tables(
        self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str
    ) -> None:
        frame = self._frame(two_tables_sep_by_two_empty_columns)
        indices = splitter._find_split_indices(frame, split_threshold=1, axis="column")
        assert indices == [(2, 3)]

    def test_find_split_indices_column_two_tables_with_empty_column(self, splitter: CSVDocumentSplitter) -> None:
        # Column 1 is a single empty column (below threshold); columns 3-4 form the split.
        csv_text = (
            "A,,B,,,X,Y\n"
            "1,,2,,,7,8\n"
            "3,,4,,,9,10\n"
        )
        frame = self._frame(csv_text)
        indices = splitter._find_split_indices(frame, split_threshold=2, axis="column")
        assert indices == [(3, 4)]

    def test_find_split_indices_column_three_tables(self, splitter: CSVDocumentSplitter) -> None:
        csv_text = (
            "A,B,,,X,Y,,,P,Q\n"
            "1,2,,,7,8,,,11,12\n"
            "3,4,,,9,10,,,13,14\n"
        )
        frame = self._frame(csv_text)
        indices = splitter._find_split_indices(frame, split_threshold=2, axis="column")
        assert indices == [(2, 3), (6, 7)]
class TestInit:
    """Argument validation performed by ``CSVDocumentSplitter.__init__``."""

    # 0 is the boundary case: the error message requires the threshold to be
    # greater than 0, so both 0 and negative values must be rejected.
    @pytest.mark.parametrize("threshold", [0, -1])
    def test_row_split_threshold_raises_error(self, threshold: int) -> None:
        with pytest.raises(ValueError, match="row_split_threshold must be greater than 0"):
            CSVDocumentSplitter(row_split_threshold=threshold)

    @pytest.mark.parametrize("threshold", [0, -1])
    def test_column_split_threshold_raises_error(self, threshold: int) -> None:
        with pytest.raises(ValueError, match="column_split_threshold must be greater than 0"):
            CSVDocumentSplitter(column_split_threshold=threshold)

    def test_row_split_threshold_and_row_column_threshold_none(self) -> None:
        with pytest.raises(
            ValueError, match="At least one of row_split_threshold or column_split_threshold must be specified."
        ):
            CSVDocumentSplitter(row_split_threshold=None, column_split_threshold=None)
class TestCSVDocumentSplitter:
    """End-to-end tests for ``CSVDocumentSplitter.run`` plus component (de)serialization."""

    @staticmethod
    def _assert_splits(documents: list, expected_tables: list, expected_meta: list) -> None:
        """Assert the split documents match the expected contents and metadata, in order."""
        assert len(documents) == len(expected_tables) == len(expected_meta)
        for document, table, meta in zip(documents, expected_tables, expected_meta):
            assert document.content == table
            assert document.meta == meta

    def test_single_table_no_split(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = "A,B,C\n1,2,3\n4,5,6\n"
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content
        assert result[0].meta == {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0}

    def test_row_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_rows: str) -> None:
        doc = Document(content=two_tables_sep_by_two_empty_rows, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B,C\n1,2,3\n", "X,Y,Z\n7,8,9\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 1},
        ]
        self._assert_splits(result, expected_tables, expected_meta)

    def test_column_split(self, splitter: CSVDocumentSplitter, two_tables_sep_by_two_empty_columns: str) -> None:
        doc = Document(content=two_tables_sep_by_two_empty_columns, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 2
        expected_tables = ["A,B\n1,2\n3,4\n", "X,Y\n7,8\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
        ]
        self._assert_splits(result, expected_tables, expected_meta)

    def test_recursive_split_one_level(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = (
            "A,B,,,X,Y\n"
            "1,2,,,7,8\n"
            ",,,,,\n"
            ",,,,,\n"
            "P,Q,,,M,N\n"
            "3,4,,,9,10\n"
        )
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 4
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\n", "P,Q\n3,4\n", "M,N\n9,10\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        self._assert_splits(result, expected_tables, expected_meta)

    def test_recursive_split_two_levels(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = (
            "A,B,,,X,Y\n"
            "1,2,,,7,8\n"
            ",,,,M,N\n"
            ",,,,9,10\n"
            "P,Q,,,,\n"
            "3,4,,,,\n"
        )
        doc = Document(content=csv_content, id="test_id")
        result = splitter.run([doc])["documents"]
        assert len(result) == 3
        expected_tables = ["A,B\n1,2\n", "X,Y\n7,8\nM,N\n9,10\n", "P,Q\n3,4\n"]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
        ]
        self._assert_splits(result, expected_tables, expected_meta)

    def test_csv_with_blank_lines(self) -> None:
        # The truly blank line (row index 3) is the only row separator here; with
        # row_split_threshold=1 it splits the lower tables off starting at row 4.
        csv_data = (
            "ID,LeftVal,,,RightVal,Extra\n"
            "1,Hello,,,World,Joined\n"
            "2,StillLeft,,,StillRight,Bridge\n"
            "\n"
            "A,B,,,C,D\n"
            "E,F,,,G,H\n"
        )
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=1)
        result = splitter.run([Document(content=csv_data, id="test_id")])
        docs = result["documents"]
        assert len(docs) == 4
        expected_tables = [
            "ID,LeftVal\n1,Hello\n2,StillLeft\n",
            "RightVal,Extra\nWorld,Joined\nStillRight,Bridge\n",
            "A,B\nE,F\n",
            "C,D\nG,H\n",
        ]
        expected_meta = [
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 0, "split_id": 0},
            {"source_id": "test_id", "row_idx_start": 0, "col_idx_start": 4, "split_id": 1},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 0, "split_id": 2},
            {"source_id": "test_id", "row_idx_start": 4, "col_idx_start": 4, "split_id": 3},
        ]
        self._assert_splits(docs, expected_tables, expected_meta)

    def test_sub_table_with_one_row(self) -> None:
        splitter = CSVDocumentSplitter(row_split_threshold=1)
        doc = Document(content="A,B,C\n1,2,3\n,,\n4,5,6")
        split_result = splitter.run([doc])
        assert len(split_result["documents"]) == 2

    def test_threshold_no_effect(self, two_tables_sep_by_two_empty_rows: str) -> None:
        # Two empty rows are fewer than the threshold of 3, so nothing is split.
        splitter = CSVDocumentSplitter(row_split_threshold=3)
        doc = Document(content=two_tables_sep_by_two_empty_rows)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1

    def test_empty_input(self, splitter: CSVDocumentSplitter) -> None:
        csv_content = ""
        doc = Document(content=csv_content)
        result = splitter.run([doc])["documents"]
        assert len(result) == 1
        assert result[0].content == csv_content

    def test_empty_documents(self, splitter: CSVDocumentSplitter) -> None:
        result = splitter.run([])["documents"]
        assert len(result) == 0

    def test_to_dict_with_defaults(self) -> None:
        splitter = CSVDocumentSplitter()
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {
                "row_split_threshold": 2,
                "column_split_threshold": 2,
                "read_csv_kwargs": {},
                "split_by_row": False,
            },
        }
        assert config_serialized == config

    def test_to_dict_non_defaults(self) -> None:
        splitter = CSVDocumentSplitter(row_split_threshold=1, column_split_threshold=None, read_csv_kwargs={"sep": ";"})
        config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
        config = {
            "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
            "init_parameters": {
                "row_split_threshold": 1,
                "column_split_threshold": None,
                "read_csv_kwargs": {"sep": ";"},
                "split_by_row": False,
            },
        }
        assert config_serialized == config

    def test_from_dict_defaults(self) -> None:
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {},
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 2
        assert splitter.column_split_threshold == 2
        assert splitter.read_csv_kwargs == {}
        # split_by_row is serialized by to_dict (default False), so check it too.
        assert splitter.split_by_row is False

    def test_from_dict_non_defaults(self) -> None:
        splitter = component_from_dict(
            CSVDocumentSplitter,
            data={
                "type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
                "init_parameters": {
                    "row_split_threshold": 1,
                    "column_split_threshold": None,
                    "read_csv_kwargs": {"sep": ";"},
                    "split_by_row": True,
                },
            },
            name="CSVDocumentSplitter",
        )
        assert splitter.row_split_threshold == 1
        assert splitter.column_split_threshold is None
        assert splitter.read_csv_kwargs == {"sep": ";"}
        assert splitter.split_by_row is True

    def test_split_by_row(self, csv_with_four_rows: str) -> None:
        splitter = CSVDocumentSplitter(split_by_row=True)
        doc = Document(content=csv_with_four_rows)
        result = splitter.run([doc])["documents"]
        assert len(result) == 4
        assert result[0].content == "A,B,C\n"
        assert result[1].content == "1,2,3\n"
        assert result[2].content == "X,Y,Z\n"
        # The last row was previously unchecked despite asserting four results.
        assert result[3].content == "7,8,9\n"

    def test_split_by_row_with_empty_rows(self, caplog: pytest.LogCaptureFixture) -> None:
        splitter = CSVDocumentSplitter(split_by_row=True, row_split_threshold=2)
        doc = Document(content="")
        with caplog.at_level(logging.WARNING):
            result = splitter.run([doc])["documents"]
            # NOTE(review): caplog is captured but no log record is asserted; consider
            # asserting the expected warning/error once its message is confirmed.
            assert len(result) == 1
            assert result[0].content == ""