Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add split_by_row feature to CSVDocumentSplitter #9031

Merged
merged 13 commits into from
Mar 19, 2025
34 changes: 31 additions & 3 deletions haystack/components/preprocessors/csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -17,17 +17,21 @@
@component
class CSVDocumentSplitter:
"""
A component for splitting CSV documents into sub-tables based on empty rows and columns.
A component for splitting CSV documents into sub-tables based on split arguments.

The splitter identifies consecutive empty rows or columns that exceed a given threshold
The splitter supports two modes of operation:
- identify consecutive empty rows or columns that exceed a given threshold
and uses them as delimiters to segment the document into smaller tables.
- split each row into a separate sub-table, represented as a Document.

"""

def __init__(
self,
row_split_threshold: Optional[int] = 2,
column_split_threshold: Optional[int] = 2,
read_csv_kwargs: Optional[Dict[str, Any]] = None,
split_by_row: bool = False,
) -> None:
"""
Initializes the CSVDocumentSplitter component.
@@ -40,6 +44,9 @@ def __init__(
- `skip_blank_lines=False` to preserve blank lines
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
:param split_by_row:
If `True`, each row is treated as an individual sub-table.
Overrides `row_split_threshold` and `column_split_threshold`, if enabled.
"""
pandas_import.check()
if row_split_threshold is not None and row_split_threshold < 1:
@@ -51,9 +58,13 @@ def __init__(
if row_split_threshold is None and column_split_threshold is None:
raise ValueError("At least one of row_split_threshold or column_split_threshold must be specified.")

if split_by_row:
logger.warning("split_by_row is set to True. The other split arguments will be ignored.")

self.row_split_threshold = row_split_threshold
self.column_split_threshold = column_split_threshold
self.read_csv_kwargs = read_csv_kwargs or {}
self.split_by_row = split_by_row

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
@@ -97,7 +108,11 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
split_documents.append(document)
continue

if self.row_split_threshold is not None and self.column_split_threshold is None:
if self.split_by_row:
# each row is a separate sub-table
split_dfs = self._split_by_row_mode(df=df)

elif self.row_split_threshold is not None and self.column_split_threshold is None:
# split by rows
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
elif self.column_split_threshold is not None and self.row_split_threshold is None:
@@ -242,3 +257,16 @@ def _recursive_split(
result.append(table)

return result

def _split_by_row_mode(self, df: "pd.DataFrame") -> List["pd.DataFrame"]:
"""Split each CSV row into a separate subtable"""
try:
split_dfs = []
for idx, row in enumerate(df.itertuples(index=False)):
split_df = pd.DataFrame(row).T
split_df.index = [idx] # Set the index of the new DataFrame to idx
split_dfs.append(split_df)
return split_dfs
except Exception as e:
logger.warning("Error while splitting CSV rows to documents: {error}", error=e)
return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add a parameter to the `CSVDocumentSplitter` component to convert each row of a CSV file into a separate Haystack Document.
35 changes: 34 additions & 1 deletion test/components/preprocessors/test_csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import logging
from pandas import read_csv
from io import StringIO
from haystack import Document, Pipeline
@@ -15,6 +16,15 @@ def splitter() -> CSVDocumentSplitter:
return CSVDocumentSplitter()


@pytest.fixture
def csv_with_four_rows() -> str:
return """A,B,C
1,2,3
X,Y,Z
7,8,9
"""


@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
return """A,B,C
@@ -255,7 +265,12 @@ def test_to_dict_with_defaults(self) -> None:
config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
config = {
"type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
"init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
"init_parameters": {
"row_split_threshold": 2,
"column_split_threshold": 2,
"read_csv_kwargs": {},
"split_by_row": False,
},
}
assert config_serialized == config

@@ -268,6 +283,7 @@ def test_to_dict_non_defaults(self) -> None:
"row_split_threshold": 1,
"column_split_threshold": None,
"read_csv_kwargs": {"sep": ";"},
"split_by_row": False,
},
}
assert config_serialized == config
@@ -301,3 +317,20 @@ def test_from_dict_non_defaults(self) -> None:
assert splitter.row_split_threshold == 1
assert splitter.column_split_threshold is None
assert splitter.read_csv_kwargs == {"sep": ";"}

def test_split_by_row(self, csv_with_four_rows: str) -> None:
splitter = CSVDocumentSplitter(split_by_row=True)
doc = Document(content=csv_with_four_rows)
result = splitter.run([doc])["documents"]
assert len(result) == 4
assert result[0].content == "A,B,C\n"
assert result[1].content == "1,2,3\n"
assert result[2].content == "X,Y,Z\n"

def test_split_by_row_with_empty_rows(self, caplog) -> None:
splitter = CSVDocumentSplitter(split_by_row=True, row_split_threshold=2)
doc = Document(content="""""")
with caplog.at_level(logging.WARNING):
result = splitter.run([doc])["documents"]
assert len(result) == 1
assert result[0].content == ""
Loading
Oops, something went wrong.