Skip to content

Commit

Permalink
Improve the speed of from_dataframe with a MultiIndex
Browse files Browse the repository at this point in the history
Fixes pydataGH-2459

Before:

    pandas.MultiIndexSeries.time_to_xarray
    ======= ========= ==========
    --             subset
    ------- --------------------
    dtype     True     False
    ======= ========= ==========
      int    505±0ms   37.1±0ms
     float   485±0ms   38.3±0ms
    ======= ========= ==========

After:

    pandas.MultiIndexSeries.time_to_xarray
    ======= ========= ==========
    --             subset
    ------- --------------------
    dtype     True     False
    ======= ========= ==========
      int    11.5±0ms   39.2±0ms
     float   12.5±0ms   26.6±0ms
    ======= ========= ==========

There are still some cases where we have to fall back to the existing
slow implementation, but they should now be relatively rare.
  • Loading branch information
shoyer committed Jun 26, 2020
1 parent c340961 commit aa5a7dd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
46 changes: 39 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4582,17 +4582,49 @@ def _set_sparse_data_from_dataframe(
def _set_numpy_data_from_dataframe(
    self, dataframe: pd.DataFrame, dims: tuple
) -> None:
    """Set one variable per DataFrame column, as dense numpy arrays.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Source data. If its index is a MultiIndex, each column is expanded
        onto the full cartesian product of the index levels, inserting
        missing values (with dtype promotion) where combinations are absent.
    dims : tuple
        Dimension names for the resulting variables; assigned as-is via
        ``self[name] = (dims, data)``.
    """
    idx = dataframe.index
    if not isinstance(idx, pd.MultiIndex):
        # Flat index: each column maps directly onto a 1D variable.
        for name, series in dataframe.items():
            self[name] = (dims, np.asarray(series))
        return

    # Target shape is the full product of all MultiIndex levels.
    shape = tuple(lev.size for lev in idx.levels)

    # Every element in the result is indexed by a unique combination of
    # MultiIndex levels, so if the product of level sizes exceeds the
    # number of source rows, we *must* be inserting missing values.
    definitely_inserting_na = np.prod(shape) > dataframe.shape[0]

    if definitely_inserting_na or all(
        issubclass(dtype.type, dtypes.TYPES_WITH_NA) for dtype in dataframe.dtypes
    ):
        # Fast path: scatter each column's values into a pre-filled array
        # using the integer level codes. Much faster than reindex:
        # https://stackoverflow.com/a/35049899/809705
        full_indexer = tuple(idx.codes)
        for name, series in dataframe.items():
            data = np.asarray(series)
            # Promote the dtype (e.g. int -> float) so a fill value exists.
            dtype, fill_value = dtypes.maybe_promote(data.dtype)
            new_data = np.full(shape, fill_value, dtype)
            new_data[full_indexer] = data
            self[name] = (dims, new_data)
    else:
        # Slow path. It can be very expensive to use get_indexer/reindex to
        # check whether all values are found in a MultiIndex:
        # https://github.com/pydata/xarray/issues/2459
        #
        # Unfortunately, we sometimes need to do this in order to return
        # the correct dtype for columns that don't support NA: if there are
        # no missing values, then the dtype should be preserved and we
        # cannot insert a fill value.
        full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names)
        dataframe = dataframe.reindex(full_idx)
        for name, series in dataframe.items():
            data = np.asarray(series).reshape(shape)
            self[name] = (dims, data)

@classmethod
def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Dataset":
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def __eq__(self, other):
]


TYPES_WITH_NA = (np.inexact, np.timedelta64, np.datetime64)


def maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote
Expand Down

0 comments on commit aa5a7dd

Please sign in to comment.