Skip to content

Commit

Permalink
Preserve column order
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Feb 22, 2024
1 parent 88f8761 commit a5e73c6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
26 changes: 12 additions & 14 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,10 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
}
}

/// Determine `values` columns.
/// Determine `values` columns, which is optional in `pivot` calls.
///
/// When the optional `values` parameter is `None`, we use all remaining columns in the `DataFrame`
/// after `index` and `columns` have been excluded. When `values` is `Some`, we return a vector of
/// strings.
/// If not specified (i.e. is `None`, we use all remaining columns in the `DataFrame`)after `index`
/// and `columns` have been excluded.
fn _get_values_columns<I, S>(
df: &DataFrame,
index: &[String],
Expand All @@ -98,20 +97,19 @@ where
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>(),
None => {
let column_names = df.get_column_names_owned();
let mut column_set = PlHashSet::<String>::with_capacity(column_names.len());
let mut column_set = PlHashSet::<String>::with_capacity(index.len() + columns.len());

// Column names are always unique.
column_names.into_iter().for_each(|s| {
column_set.insert_unique_unchecked(s.to_string());
});

// Remove `index` and `column` columns.
// Hash columns we don't want to include
index.iter().chain(columns.iter()).for_each(|s| {
column_set.remove(s);
column_set.insert_unique_unchecked(s.to_owned());
});

column_set.drain().collect()
// filter out
df.get_column_names_owned()
.into_iter()
.map(|s| s.to_string())
.filter(|s| !column_set.contains(s))
.collect()
},
}
}
Expand Down
4 changes: 1 addition & 3 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def test_pivot_no_values() -> None:
}
)

# the order of the output columns is volatile
assert set(result.columns) == set(expected.columns)
assert_frame_equal(result, expected.select(result.columns))
assert_frame_equal(result, expected)


def test_pivot_list() -> None:
Expand Down

0 comments on commit a5e73c6

Please sign in to comment.