diff --git a/reciprocalspaceship/dataset.py b/reciprocalspaceship/dataset.py index 4b0afcf7..8aeb539e 100644 --- a/reciprocalspaceship/dataset.py +++ b/reciprocalspaceship/dataset.py @@ -110,28 +110,33 @@ def set_index(self, keys, **kwargs): return super().set_index(keys, **kwargs) def reset_index(self, **kwargs): + + # GH#6: Handle level argument to reset_index + columns = kwargs.get("level") + if columns is None: + columns = list(self._cache_index_dtypes.keys()) + drop = kwargs.get("drop", False) + + def _handle_cached_dtypes(dataset, columns, drop): + """Use _cache_index_dtypes to restore dtypes""" + if drop: + for key in columns: + dataset._cache_index_dtypes.pop(key) + else: + for key in columns: + dtype = dataset._cache_index_dtypes.pop(key) + dataset[key] = dataset[key].astype(dtype) + return dataset if kwargs.get("inplace", False): super().reset_index(**kwargs) - - # Cast all values to cached dtypes - if not kwargs.get("drop", False): - for key in self._cache_index_dtypes.keys(): - dtype = self._cache_index_dtypes[key] - self[key] = self[key].astype(dtype) - self._cache_index_dtypes = {} + _handle_cached_dtypes(self, columns, drop) return - else: - newdf = super().reset_index(**kwargs) - - # Cast all values to cached dtypes - if not kwargs.get("drop", False): - for key in newdf._cache_index_dtypes.keys(): - dtype = newdf._cache_index_dtypes[key] - newdf[key] = newdf[key].astype(dtype) - newdf._cache_index_dtypes = {} - return newdf + dataset = super().reset_index(**kwargs) + dataset._cache_index_dtypes = dataset._cache_index_dtypes.copy() + dataset = _handle_cached_dtypes(dataset, columns, drop) + return dataset @classmethod def from_gemmi(cls, gemmiMtz):