Backport PR #51803 on branch 2.0.x (CoW: Add reference tracking to index when created from series) (#52000)
phofl committed Mar 16, 2023
1 parent 6b6c336 commit fbc660b
Showing 13 changed files with 330 additions and 15 deletions.
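
This backport makes an Index that is built from a Series (or another Index) without an explicit copy register itself with the Copy-on-Write reference tracking, so a later in-place edit of the Series copies the shared block instead of silently mutating the index. A minimal sketch of the intended user-visible behaviour, assuming pandas 2.0.x with the copy_on_write option enabled (illustrative, not part of this diff):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

ser = pd.Series([1, 2, 3])
idx = pd.Index(ser)                     # shares data with ser, no upfront copy
expected = idx.copy(deep=True)

ser.iloc[0] = 100                       # CoW copies ser's block; idx is untouched
pd.testing.assert_index_equal(idx, expected)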
10 changes: 10 additions & 0 deletions pandas/_libs/internals.pyx
@@ -890,6 +890,16 @@ cdef class BlockValuesRefs:
"""
self.referenced_blocks.append(weakref.ref(blk))

def add_index_reference(self, index: object) -> None:
"""Adds a new reference to our reference collection when creating an index.

Parameters
----------
index: object
The index that the new reference should point to.
"""
self.referenced_blocks.append(weakref.ref(index))

def has_reference(self) -> bool:
"""Checks if block has foreign references.

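add_index_reference mirrors the existing add_reference hook for blocks: it stores only a weak reference to the new index, so the tracking never keeps an index alive on its own. A plain-Python sketch of that bookkeeping (a conceptual stand-in for the Cython BlockValuesRefs class; the dead-reference pruning shown in has_reference is an assumption, not copied from this diff):

import weakref

class ReferenceTracker:
    """Stand-in for BlockValuesRefs: records every object viewing a block's values."""

    def __init__(self):
        self.referenced_blocks = []

    def add_index_reference(self, index):
        # store a weakref only, so the tracker does not keep the index alive
        self.referenced_blocks.append(weakref.ref(index))

    def has_reference(self):
        # prune references whose targets were garbage collected, then report
        # whether anything still shares the values
        self.referenced_blocks = [r for r in self.referenced_blocks if r() is not None]
        return len(self.referenced_blocks) > 0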
2 changes: 1 addition & 1 deletion pandas/core/frame.py
@@ -5879,7 +5879,7 @@ def set_index(
names.append(None)
# from here, col can only be a column label
else:
arrays.append(frame[col]._values)
arrays.append(frame[col])
names.append(col)
if drop:
to_remove.append(col)
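Passing the column Series itself instead of its bare ._values lets Index.__new__ pick up the column's _references, so an index produced by set_index is insulated from later in-place edits of the column it was built from. A hedged sketch of the behaviour this hunk is meant to guarantee under copy_on_write (illustrative, not one of the tests in this commit):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
result = df.set_index("a", drop=False)   # index shares memory with column "a"
expected = result.index.copy(deep=True)

result.iloc[0, 0] = 100                  # in-place edit of "a" copies the block first
pd.testing.assert_index_equal(result.index, expected)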
46 changes: 36 additions & 10 deletions pandas/core/indexes/base.py
@@ -73,7 +73,10 @@
rewrite_exception,
)

from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.astype import (
astype_array,
astype_is_view,
)
from pandas.core.dtypes.cast import (
LossySetitemError,
can_hold_element,
@@ -457,6 +460,8 @@ def _engine_type(

str = CachedAccessor("str", StringMethods)

_references = None

# --------------------------------------------------------------------
# Constructors

@@ -477,6 +482,10 @@ def __new__(

data_dtype = getattr(data, "dtype", None)

refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references

# range
if isinstance(data, (range, RangeIndex)):
result = RangeIndex(start=data, copy=copy, name=name)
@@ -550,7 +559,7 @@ def __new__(
klass = cls._dtype_to_subclass(arr.dtype)

arr = klass._ensure_array(arr, arr.dtype, copy=False)
return klass._simple_new(arr, name)
return klass._simple_new(arr, name, refs=refs)

@classmethod
def _ensure_array(cls, data, dtype, copy: bool):
@@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):

@classmethod
def _simple_new(
cls: type[_IndexT], values: ArrayLike, name: Hashable = None
cls: type[_IndexT], values: ArrayLike, name: Hashable = None, refs=None
) -> _IndexT:
"""
We require that we have a dtype compat for the values. If we are passed
@@ -644,6 +653,9 @@ def _simple_new(
result._name = name
result._cache = {}
result._reset_identity()
result._references = refs
if refs is not None:
refs.add_index_reference(result)

return result

@@ -740,13 +752,13 @@ def _shallow_copy(self: _IndexT, values, name: Hashable = no_default) -> _IndexT
"""
name = self._name if name is no_default else name

return self._simple_new(values, name=name)
return self._simple_new(values, name=name, refs=self._references)

def _view(self: _IndexT) -> _IndexT:
"""
fastpath to make a shallow copy, i.e. new object with same data.
"""
result = self._simple_new(self._values, name=self._name)
result = self._simple_new(self._values, name=self._name, refs=self._references)

result._cache = self._cache
return result
@@ -956,7 +968,7 @@ def view(self, cls=None):
# of types.
arr_cls = idx_cls._data_cls
arr = arr_cls(self._data.view("i8"), dtype=dtype)
return idx_cls._simple_new(arr, name=self.name)
return idx_cls._simple_new(arr, name=self.name, refs=self._references)

result = self._data.view(cls)
else:
@@ -1012,7 +1024,15 @@ def astype(self, dtype, copy: bool = True):
new_values = astype_array(values, dtype=dtype, copy=copy)

# pass copy=False because any copying will be done in the astype above
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
if (
not copy
and self._references is not None
and astype_is_view(self.dtype, dtype)
):
result._references = self._references
result._references.add_index_reference(result)
return result

_index_shared_docs[
"take"
@@ -5155,7 +5175,9 @@ def __getitem__(self, key):
# pessimization com.is_bool_indexer and ndim checks.
result = getitem(key)
# Going through simple_new for performance.
return type(self)._simple_new(result, name=self._name)
return type(self)._simple_new(
result, name=self._name, refs=self._references
)

if com.is_bool_indexer(key):
# if we have list[bools, length=1e5] then doing this check+convert
@@ -5181,7 +5203,7 @@ def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT:
Fastpath for __getitem__ when we know we have a slice.
"""
res = self._data[slobj]
return type(self)._simple_new(res, name=self._name)
return type(self)._simple_new(res, name=self._name, refs=self._references)

@final
def _can_hold_identifiers_and_holds_name(self, name) -> bool:
@@ -6700,7 +6722,11 @@ def infer_objects(self, copy: bool = True) -> Index:
)
if copy and res_values is values:
return self.copy()
return Index(res_values, name=self.name)
result = Index(res_values, name=self.name)
if not copy and res_values is values and self._references is not None:
result._references = self._references
result._references.add_index_reference(result)
return result

# --------------------------------------------------------------------
# Generated Arithmetic, Comparison, and Unary Methods
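Beyond the constructor, the fastpaths that reuse the underlying values (_simple_new, _shallow_copy, _view, view, __getitem__, _getitem_slice, and the no-copy astype/infer_objects branches) now forward _references, so indexes derived from a tracked index stay tracked as well. A rough sketch of that propagation under assumed copy_on_write semantics (an illustration of the intent, not an exhaustive test):

import pandas as pd

pd.set_option("mode.copy_on_write", True)

ser = pd.Series([1, 2, 3])
idx = pd.Index(ser)
sliced = idx[1:]                         # _getitem_slice forwards the references
viewed = idx.view()                      # _view keeps them as well

expected = sliced.copy(deep=True)
ser.iloc[1] = 100                        # ser copies its block; derived indexes stay intact
pd.testing.assert_index_equal(sliced, expected)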
8 changes: 6 additions & 2 deletions pandas/core/indexes/datetimes.py
@@ -44,6 +44,7 @@
is_datetime64tz_dtype,
is_scalar,
)
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.datetimes import (
@@ -266,7 +267,7 @@ def strftime(self, date_format) -> Index:
@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> DatetimeIndex:
arr = self._data.tz_convert(tz)
return type(self)._simple_new(arr, name=self.name)
return type(self)._simple_new(arr, name=self.name, refs=self._references)

@doc(DatetimeArray.tz_localize)
def tz_localize(
@@ -345,8 +346,11 @@ def __new__(
yearfirst=yearfirst,
ambiguous=ambiguous,
)
refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

subarr = cls._simple_new(dtarr, name=name)
subarr = cls._simple_new(dtarr, name=name, refs=refs)
return subarr

# --------------------------------------------------------------------
1 change: 1 addition & 0 deletions pandas/core/indexes/multi.py
@@ -353,6 +353,7 @@ def __new__(
result._codes = new_codes

result._reset_identity()
result._references = None

return result

7 changes: 6 additions & 1 deletion pandas/core/indexes/period.py
@@ -28,6 +28,7 @@

from pandas.core.dtypes.common import is_integer
from pandas.core.dtypes.dtypes import PeriodDtype
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.missing import is_valid_na_for_dtype

from pandas.core.arrays.period import (
@@ -217,6 +218,10 @@ def __new__(
"second",
}

refs = None
if not copy and isinstance(data, (Index, ABCSeries)):
refs = data._references

if not set(fields).issubset(valid_field_set):
argument = list(set(fields) - valid_field_set)[0]
raise TypeError(f"__new__() got an unexpected keyword argument {argument}")
@@ -257,7 +262,7 @@
if copy:
data = data.copy()

return cls._simple_new(data, name=name)
return cls._simple_new(data, name=name, refs=refs)

# ------------------------------------------------------------------------
# Data
1 change: 1 addition & 0 deletions pandas/core/indexes/range.py
@@ -175,6 +175,7 @@ def _simple_new( # type: ignore[override]
result._name = name
result._cache = {}
result._reset_identity()
result._references = None
return result

@classmethod
7 changes: 6 additions & 1 deletion pandas/core/indexes/timedeltas.py
@@ -17,6 +17,7 @@
is_scalar,
is_timedelta64_dtype,
)
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays.timedeltas import TimedeltaArray
@@ -168,7 +169,11 @@ def __new__(
tdarr = TimedeltaArray._from_sequence_not_strict(
data, freq=freq, unit=unit, dtype=dtype, copy=copy
)
return cls._simple_new(tdarr, name=name)
refs = None
if not copy and isinstance(data, (ABCSeries, Index)):
refs = data._references

return cls._simple_new(tdarr, name=name, refs=refs)

# -------------------------------------------------------------------

Empty file.
56 changes: 56 additions & 0 deletions pandas/tests/copy_view/index/test_datetimeindex.py
@@ -0,0 +1,56 @@
import pytest

from pandas import (
DatetimeIndex,
Series,
Timestamp,
date_range,
)
import pandas._testing as tm


@pytest.mark.parametrize(
"cons",
[
lambda x: DatetimeIndex(x),
lambda x: DatetimeIndex(DatetimeIndex(x)),
],
)
def test_datetimeindex(using_copy_on_write, cons):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = cons(ser)
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_convert(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_convert("US/Eastern")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_tz_localize(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
idx = DatetimeIndex(ser).tz_localize("Europe/Berlin")
expected = idx.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(idx, expected)


def test_datetimeindex_isocalendar(using_copy_on_write):
dt = date_range("2019-12-31", periods=3, freq="D")
ser = Series(dt)
df = DatetimeIndex(ser).isocalendar()
expected = df.index.copy(deep=True)
ser.iloc[0] = Timestamp("2020-12-31")
if using_copy_on_write:
tm.assert_index_equal(df.index, expected)
