Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove dataframe.columns name to avoid error with xarray #1938

Merged
merged 9 commits into from
Aug 9, 2023
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

[Full Changelog](https://github.com/unit8co/darts/compare/0.25.0...master)

### For users of the library:

**Fixed**
- Fixed a bug in `TimeSeries.from_dataframe()` when using a pandas.DataFrame with `df.columns.name != None`. [#1938](https://github.com/unit8co/darts/pull/1938) by [Antoine Madrona](https://github.com/madtoinou).


## [0.25.0](https://github.com/unit8co/darts/tree/0.25.0) (2023-08-04)
### For users of the library:

Expand Down
28 changes: 27 additions & 1 deletion darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from darts import TimeSeries, concatenate
from darts.tests.base_test_class import DartsBaseTestClass
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.timeseries_generation import (
constant_timeseries,
generate_index,
linear_timeseries,
)


class TimeSeriesTestCase(DartsBaseTestClass):
Expand Down Expand Up @@ -2202,6 +2206,28 @@ def test_time_col_convert_garbage(self):
with self.assertRaises(AttributeError):
TimeSeries.from_dataframe(df=df, time_col="Time")

def test_df_named_columns_index(self):
time_index = generate_index(
start=pd.Timestamp("2000-01-01"), length=4, freq="D", name="index"
)
df = pd.DataFrame(
data=np.arange(4),
index=time_index,
columns=["y"],
)
df.columns.name = "id"
ts = TimeSeries.from_dataframe(df)

exp_ts = TimeSeries.from_times_and_values(
times=time_index,
values=np.arange(4),
columns=["y"],
)
# check that series are exactly identical
self.assertEqual(ts, exp_ts)
# check that the original df was not changed
self.assertEqual(df.columns.name, "id")


class SimpleStatisticsTestCase(DartsBaseTestClass):

Expand Down
3 changes: 3 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,9 @@ def from_dataframe(
if not time_index.name:
time_index.name = time_col if time_col else DIMS[0]

if series_df.columns.name:
series_df.columns.name = None
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

xa = xr.DataArray(
series_df.values[:, :, np.newaxis],
dims=(time_index.name,) + DIMS[-2:],
Expand Down