Skip to content

Commit

Permalink
check for duplicates in dataframe renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 5, 2021
1 parent 2091863 commit 807209e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
20 changes: 18 additions & 2 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! DataFrame module.
use std::borrow::Cow;
use std::collections::HashSet;
use std::iter::Iterator;
use std::iter::{FromIterator, Iterator};
use std::mem;
use std::sync::Arc;

use ahash::RandomState;
use ahash::{AHashSet, RandomState};
use arrow::record_batch::RecordBatch;
use itertools::Itertools;
use rayon::prelude::*;
Expand Down Expand Up @@ -478,6 +478,14 @@ impl DataFrame {
if names.len() != self.columns.len() {
return Err(PolarsError::ShapeMisMatch("the provided slice with column names has not the same size as the DataFrame's width".into()));
}
let unique_names: AHashSet<&str, ahash::RandomState> =
AHashSet::from_iter(names.iter().map(|name| name.as_ref()));
if unique_names.len() != self.columns.len() {
return Err(PolarsError::SchemaMisMatch(
"duplicate column names found".into(),
));
}

let columns = mem::take(&mut self.columns);
self.columns = columns
.into_iter()
Expand Down Expand Up @@ -1447,6 +1455,14 @@ impl DataFrame {
self.select_mut(column)
.ok_or_else(|| PolarsError::NotFound(name.into()))
.map(|s| s.rename(name))?;

let unique_names: AHashSet<&str, ahash::RandomState> =
AHashSet::from_iter(self.columns.iter().map(|s| s.name()));
if unique_names.len() != self.columns.len() {
return Err(PolarsError::SchemaMisMatch(
"duplicate column names found".into(),
));
}
Ok(self)
}

Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,3 +1316,12 @@ def test_schema() -> None:
)
expected = {"foo": pl.Int64, "bar": pl.Float64, "ham": pl.Utf8}
assert df.schema == expected


def test_df_schema_unique() -> None:
    """Check that a DataFrame refuses operations producing duplicate column names."""
    frame = pl.DataFrame({"a": [1, 2], "b": [3, 4]})

    # Assigning a list of names that repeats "a" is expected to raise.
    with pytest.raises(Exception):
        frame.columns = ["a", "a"]

    # Renaming "b" to the already-present name "a" is expected to raise too.
    with pytest.raises(Exception):
        frame.rename({"b": "a"})

0 comments on commit 807209e

Please sign in to comment.