Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,13 +2975,12 @@ def __bool__(self) -> NoReturn:
def _get_reconciled_name_object(self, other):
"""
If the result of a set operation will be self,
return self, unless the name changes, in which
case make a shallow copy of self.
return a shallow copy of self.
"""
name = get_op_result_name(self, other)
if self.name is not name:
return self.rename(name)
return self
return self.copy()

@final
def _validate_sort_keyword(self, sort) -> None:
Expand Down
5 changes: 2 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4082,13 +4082,12 @@ def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
def _get_reconciled_name_object(self, other) -> MultiIndex:
"""
If the result of a set operation will be self,
return self, unless the names change, in which
case make a shallow copy of self.
return a shallow copy of self.
"""
names = self._maybe_match_names(other)
if self.names != names:
return self.rename(names)
return self
return self.copy()

def _maybe_match_names(self, other):
"""
Expand Down
87 changes: 82 additions & 5 deletions pandas/tests/indexes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def test_intersection(self, index, sort):

# Corner cases
inter = first.intersection(first, sort=sort)
assert inter is first
assert inter is not first

@pytest.mark.parametrize(
"index2_name,keeps_name",
Expand Down Expand Up @@ -812,16 +812,16 @@ def test_union_identity(self, index, sort):
first = index[5:20]

union = first.union(first, sort=sort)
# i.e. identity is not preserved when sort is True
assert (union is first) is (not sort)
# GH#63169 - identity is not preserved to prevent shared mutable state
assert union is not first

# This should no longer be the same object, since [] is not consistent,
# both objects will be recast to dtype('O')
union = first.union(Index([], dtype=first.dtype), sort=sort)
assert (union is first) is (not sort)
assert union is not first

union = Index([], dtype=first.dtype).union(first, sort=sort)
assert (union is first) is (not sort)
assert union is not first

@pytest.mark.parametrize("index", ["string"], indirect=True)
@pytest.mark.parametrize("second_name,expected", [(None, None), ("name", "name")])
Expand Down Expand Up @@ -984,3 +984,80 @@ def test_union_pyarrow_timestamp(self):
res = left.union(right)
expected = Index(["2020-01-01", "2020-01-02"], dtype=left.dtype)
tm.assert_index_equal(res, expected)


class TestSetOpsMutation:
def test_intersection_mutation_safety(self):
# GH#63169
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.intersection(index2)

assert result is not index1
assert result is not index2

tm.assert_index_equal(result, index1)
assert result.name == "original"

index1.name = "changed"

assert result.name == "original"
assert index1.name == "changed"

def test_union_mutation_safety(self):
# GH#63169
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.union(index2)

assert result is not index1
assert result is not index2

tm.assert_index_equal(result, index1)
assert result.name == "original"

index1.name = "changed"

assert result.name == "original"
assert index1.name == "changed"

def test_union_mutation_safety_other(self):
# GH#63169
index1 = Index([0, 1], name="original")
index2 = Index([0, 1], name="original")

result = index1.union(index2)

assert result is not index2

tm.assert_index_equal(result, index2)
assert result.name == "original"

index2.name = "changed"

assert result.name == "original"
assert index2.name == "changed"

def test_multiindex_intersection_mutation_safety(self):
# GH#63169
mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])
mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])

result = mi1.intersection(mi2)
assert result is not mi1

mi1.names = ["changed1", "changed2"]
assert result.names == ["x", "y"]

def test_multiindex_union_mutation_safety(self):
# GH#63169
mi1 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])
mi2 = MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["x", "y"])

result = mi1.union(mi2)
assert result is not mi1

mi1.names = ["changed1", "changed2"]
assert result.names == ["x", "y"]
4 changes: 2 additions & 2 deletions pandas/tests/indexes/timedeltas/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_intersection_bug_1708(self):

def test_intersection_equal(self, sort):
# GH 24471 Test intersection outcome given the sort keyword
# for equal indices intersection should return the original index
# GH#63169 intersection returns a copy to prevent shared mutable state
first = timedelta_range("1 day", periods=4, freq="h")
second = timedelta_range("1 day", periods=4, freq="h")
intersect = first.intersection(second, sort=sort)
Expand All @@ -124,7 +124,7 @@ def test_intersection_equal(self, sort):

# Corner cases
inter = first.intersection(first, sort=sort)
assert inter is first
assert inter is not first

@pytest.mark.parametrize("period_1, period_2", [(0, 4), (4, 0)])
def test_intersection_zero_length(self, period_1, period_2, sort):
Expand Down
Loading