Skip to content

Commit

Permalink
Allow using new column names in dtype overwrite with pl.read_csv(); c…
Browse files Browse the repository at this point in the history
…loses #1492
  • Loading branch information
ghuls authored and ritchie46 committed Oct 24, 2021
1 parent 834214b commit a911dbe
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
62 changes: 61 additions & 1 deletion py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def read_csv(
rechunk: bool = True,
encoding: str = "utf8",
n_threads: Optional[int] = None,
dtype: Optional[Dict[str, Type["pl.DataType"]]] = None,
dtype: Optional[
Union[Dict[str, Type["pl.DataType"]], List[Type["pl.DataType"]]]
] = None,
new_columns: Optional[List[str]] = None,
use_pyarrow: bool = False,
low_memory: bool = False,
Expand Down Expand Up @@ -305,6 +307,64 @@ def read_csv(
return update_columns(df, new_columns) # type: ignore
return df # type: ignore

if new_columns and dtype and isinstance(dtype, dict):
current_columns = None

# As new column names are not available yet while parsing the CSV file, rename column names in
# dtype to old names (if possible) so they can be used during CSV parsing.
if columns:
if len(columns) < len(new_columns):
raise ValueError(
"More new colum names are specified than there are selected columns."
)

# Get column names of requested columns.
current_columns = columns[0 : len(new_columns)]
elif not has_headers:
# When there are no header, column names are autogenerated (and known).

if projection:
if columns and len(columns) < len(new_columns):
raise ValueError(
"More new colum names are specified than there are projected columns."
)
# Convert column indices from projection to 'column_1', 'column_2', ... column names.
current_columns = [
f"column_{column_idx + 1}" for column_idx in projection
]
else:
# Generate autogenerated 'column_1', 'column_2', ... column names for new column names.
current_columns = [
f"column_{column_idx}"
for column_idx in range(1, len(new_columns) + 1)
]
else:
# When a header is present, column names are not known yet.

if len(dtype) <= len(new_columns):
# If dtype dictionary contains less or same amount of values than new column names
# a list of dtypes can be created if all listed column names in dtype dictionary
# appear in the first consecutive new column names.
dtype_list = [
dtype[new_column_name]
for new_column_name in new_columns[0 : len(dtype)]
if new_column_name in dtype
]

if len(dtype_list) == len(dtype):
dtype = dtype_list

if current_columns and isinstance(dtype, dict):
new_to_current = {
new_column: current_column
for new_column, current_column in zip(new_columns, current_columns)
}
# Change new column names to current column names in dtype.
dtype = {
new_to_current.get(column_name, column_name): column_dtype
for column_name, column_dtype in dtype.items()
}

with _prepare_file_arg(file, **storage_options) as data:
df = pl.DataFrame.read_csv(
file=data,
Expand Down
37 changes: 37 additions & 0 deletions py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,43 @@ def test_partial_column_rename():
assert df.columns == ["foo", "b", "c"]


def test_column_rename_and_dtype_overwrite():
    """Dtype overwrites keyed by *new* column names are honored by read_csv."""
    data = """
a,b,c
1,2,3
1,2,3
"""
    # All three columns renamed; dtype dict uses the new names.
    out = pl.read_csv(
        io.StringIO(data),
        new_columns=["A", "B", "C"],
        dtype={"A": pl.Utf8, "B": pl.Int64, "C": pl.Float32},
    )
    assert out.dtypes == [pl.Utf8, pl.Int64, pl.Float32]

    # Combined with a column selection: only 'a' and 'c' are read, then renamed.
    out = pl.read_csv(
        io.StringIO(data),
        columns=["a", "c"],
        new_columns=["A", "C"],
        dtype={"A": pl.Utf8, "C": pl.Float32},
    )
    assert out.dtypes == [pl.Utf8, pl.Float32]

    # Headerless input: new names map onto the autogenerated column_1.. names.
    data = """
1,2,3
1,2,3
"""
    out = pl.read_csv(
        io.StringIO(data),
        new_columns=["A", "B", "C"],
        dtype={"A": pl.Utf8, "C": pl.Float32},
        has_headers=False,
    )
    assert out.dtypes == [pl.Utf8, pl.Int64, pl.Float32]


def test_compressed_csv():
# gzip compression
csv = """
Expand Down

0 comments on commit a911dbe

Please sign in to comment.