Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add split_by_row feature to CSVDocumentSplitter #9031

Merged
merged 13 commits into from
Mar 19, 2025
40 changes: 36 additions & 4 deletions haystack/components/preprocessors/csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple, get_args

from haystack import Document, component, logging
from haystack.lazy_imports import LazyImport
@@ -13,18 +13,24 @@

logger = logging.getLogger(__name__)

SplitMode = Literal["threshold", "row-wise"]


@component
class CSVDocumentSplitter:
"""
A component for splitting CSV documents into sub-tables based on empty rows and columns.
A component for splitting CSV documents into sub-tables based on split arguments.

The splitter identifies consecutive empty rows or columns that exceed a given threshold
The splitter supports two modes of operation:
- identify consecutive empty rows or columns that exceed a given threshold
and use them as delimiters to segment the document into smaller tables.
- split each row into a separate sub-table, represented as a Document.

"""

def __init__(
self,
split_mode: SplitMode = "threshold",
row_split_threshold: Optional[int] = 2,
column_split_threshold: Optional[int] = 2,
read_csv_kwargs: Optional[Dict[str, Any]] = None,
@@ -40,8 +46,16 @@ def __init__(
- `skip_blank_lines=False` to preserve blank lines
- `dtype=object` to prevent type inference (e.g., converting numbers to floats).
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for more information.
:param split_mode:
If `threshold`, the component will split the document based on the number of
consecutive empty rows or columns that exceed the `row_split_threshold` or `column_split_threshold`.
If `row-wise`, the component will split each row into a separate sub-table.
"""
pandas_import.check()
if split_mode not in get_args(SplitMode):
raise ValueError(
f"Split mode '{split_mode}' not recognized. Choose one among: {', '.join(get_args(SplitMode))}."
)
if row_split_threshold is not None and row_split_threshold < 1:
raise ValueError("row_split_threshold must be greater than 0")

@@ -54,6 +68,7 @@ def __init__(
self.row_split_threshold = row_split_threshold
self.column_split_threshold = column_split_threshold
self.read_csv_kwargs = read_csv_kwargs or {}
self.split_mode = split_mode

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
@@ -97,7 +112,11 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
split_documents.append(document)
continue

if self.row_split_threshold is not None and self.column_split_threshold is None:
if self.split_mode == "row-wise":
# each row is a separate sub-table
split_dfs = self._split_by_row(df=df)

elif self.row_split_threshold is not None and self.column_split_threshold is None:
# split by rows
split_dfs = self._split_dataframe(df=df, split_threshold=self.row_split_threshold, axis="row")
elif self.column_split_threshold is not None and self.row_split_threshold is None:
@@ -242,3 +261,16 @@ def _recursive_split(
result.append(table)

return result

def _split_by_row(self, df: "pd.DataFrame") -> List["pd.DataFrame"]:
"""Split each CSV row into a separate subtable"""
try:
split_dfs = []
for idx, row in enumerate(df.itertuples(index=False)):
split_df = pd.DataFrame(row).T
split_df.index = [idx] # Set the index of the new DataFrame to idx
split_dfs.append(split_df)
return split_dfs
except Exception as e:
logger.warning("Error while splitting CSV rows to documents: {error}", error=e)
return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added a new parameter `split_mode` to the `CSVDocumentSplitter` component to control the splitting mode.
The new parameter can be set to `row-wise` to split each row of the CSV file into a separate sub-table.
The default value is `threshold`, which is the previous behavior.
40 changes: 39 additions & 1 deletion test/components/preprocessors/test_csv_document_splitter.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import logging
from pandas import read_csv
from io import StringIO
from haystack import Document, Pipeline
@@ -15,6 +16,15 @@ def splitter() -> CSVDocumentSplitter:
return CSVDocumentSplitter()


@pytest.fixture
def csv_with_four_rows() -> str:
    """Return CSV text with four comma-separated, three-field lines, each ending in a newline."""
    return (
        "A,B,C\n"
        "1,2,3\n"
        "X,Y,Z\n"
        "7,8,9\n"
    )


@pytest.fixture
def two_tables_sep_by_two_empty_rows() -> str:
return """A,B,C
@@ -255,7 +265,12 @@ def test_to_dict_with_defaults(self) -> None:
config_serialized = component_to_dict(splitter, name="CSVDocumentSplitter")
config = {
"type": "haystack.components.preprocessors.csv_document_splitter.CSVDocumentSplitter",
"init_parameters": {"row_split_threshold": 2, "column_split_threshold": 2, "read_csv_kwargs": {}},
"init_parameters": {
"row_split_threshold": 2,
"column_split_threshold": 2,
"read_csv_kwargs": {},
"split_mode": "threshold",
},
}
assert config_serialized == config

@@ -268,6 +283,7 @@ def test_to_dict_non_defaults(self) -> None:
"row_split_threshold": 1,
"column_split_threshold": None,
"read_csv_kwargs": {"sep": ";"},
"split_mode": "threshold",
},
}
assert config_serialized == config
@@ -294,10 +310,32 @@ def test_from_dict_non_defaults(self) -> None:
"row_split_threshold": 1,
"column_split_threshold": None,
"read_csv_kwargs": {"sep": ";"},
"split_mode": "threshold",
},
},
name="CSVDocumentSplitter",
)
assert splitter.row_split_threshold == 1
assert splitter.column_split_threshold is None
assert splitter.read_csv_kwargs == {"sep": ";"}

def test_split_by_row(self, csv_with_four_rows: str) -> None:
    """
    In `row-wise` mode every line of the input CSV must come back as its own document.

    The fixture holds four lines, so four documents are expected, and the content of each one
    (including the last) is checked against the corresponding input line.
    """
    splitter = CSVDocumentSplitter(split_mode="row-wise")
    doc = Document(content=csv_with_four_rows)
    result = splitter.run([doc])["documents"]
    assert len(result) == 4
    assert result[0].content == "A,B,C\n"
    assert result[1].content == "1,2,3\n"
    assert result[2].content == "X,Y,Z\n"
    # The last row was counted by the length check but its content was never verified.
    assert result[3].content == "7,8,9\n"

def test_split_by_row_with_empty_rows(self, caplog) -> None:
    """
    Run the `row-wise` splitter on a document whose content is the empty string.

    Exactly one document with empty content is expected back. Log records are captured at
    ERROR level while the splitter runs, but nothing is asserted on them here.
    """
    empty_doc = Document(content="")
    row_splitter = CSVDocumentSplitter(split_mode="row-wise")
    with caplog.at_level(logging.ERROR):
        output = row_splitter.run([empty_doc])
    documents = output["documents"]
    assert len(documents) == 1
    assert documents[0].content == ""

def test_incorrect_split_mode(self) -> None:
    """Constructing the splitter with an unknown `split_mode` must raise a `ValueError`."""
    unknown_mode = "incorrect_mode"
    with pytest.raises(ValueError, match="not recognized"):
        CSVDocumentSplitter(split_mode=unknown_mode)
Loading
Oops, something went wrong.