Skip to content

Commit

Permalink
Change make_samples of RNNNet and MLPNet to speed up the conver…
Browse files Browse the repository at this point in the history
…sion of data into `torch.Dataset` (#1281)
  • Loading branch information
ostreech1997 committed Jun 6, 2023
1 parent 8e8c0f6 commit 106b989
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 32 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Inherit `SaveMixin` from `AbstractSaveable` for mypy checker ([#1261](https://github.com/tinkoff-ai/etna/pull/1261))
- Update requirements for `holidays` and `scipy`, change saving library from `pickle` to `dill` in `SaveMixin` ([#1268](https://github.com/tinkoff-ai/etna/pull/1268))
- Update requirement for `ruptures`, add requirement for `sqlalchemy` ([#1276](https://github.com/tinkoff-ai/etna/pull/1276))
-
- Optimize `make_samples` of `RNNNet` and `MLPNet` ([#1281](https://github.com/tinkoff-ai/etna/pull/1281))
-
### Fixed
- Fix `plot_backtest` and `plot_backtest_interactive` on one-step forecast ([1260](https://github.com/tinkoff-ai/etna/pull/1260))
- Fix `BaseReconciliator` to work on `pandas==1.1.5` ([#1229](https://github.com/tinkoff-ai/etna/pull/1229))
Expand Down
34 changes: 22 additions & 12 deletions etna/models/nn/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,30 +109,34 @@ def step(self, batch: MLPBatch, *args, **kwargs): # type: ignore

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterable[dict]:
"""Make samples from segment DataFrame."""
values_real = (
df.select_dtypes(include=[np.number]).pipe(lambda x: x[[i for i in x.columns if i != "target"]]).values
)
values_target = df["target"].values
segment = df["segment"].values[0]

def _make(
values_target: np.ndarray, values_real: np.ndarray, segment: str, start_idx: int, decoder_length: int
) -> Optional[dict]:

def _make(df: pd.DataFrame, start_idx: int, decoder_length: int) -> Optional[dict]:
sample: Dict[str, Any] = {"decoder_real": list(), "decoder_target": list(), "segment": None}
total_length = len(df["target"])
total_sample_length = decoder_length

if total_sample_length + start_idx > total_length:
return None

sample["decoder_real"] = (
df.select_dtypes(include=[np.number])
.pipe(lambda x: x[[i for i in x.columns if i != "target"]])
.values[start_idx : start_idx + decoder_length]
)

target = df["target"].values[start_idx : start_idx + decoder_length].reshape(-1, 1)
sample["decoder_target"] = target
sample["segment"] = df["segment"].values[0]
sample["decoder_real"] = values_real[start_idx : start_idx + decoder_length]
sample["decoder_target"] = values_target[start_idx : start_idx + decoder_length].reshape(-1, 1)
sample["segment"] = segment
return sample

start_idx = 0
while True:
batch = _make(
df=df,
values_target=values_target,
values_real=values_real,
segment=segment,
start_idx=start_idx,
decoder_length=decoder_length,
)
Expand All @@ -142,7 +146,13 @@ def _make(df: pd.DataFrame, start_idx: int, decoder_length: int) -> Optional[dic
start_idx += decoder_length
if start_idx < len(df):
resid_length = len(df) - decoder_length
batch = _make(df=df, start_idx=resid_length, decoder_length=decoder_length)
batch = _make(
values_target=values_target,
values_real=values_real,
segment=segment,
start_idx=resid_length,
decoder_length=decoder_length,
)
if batch is not None:
yield batch

Expand Down
45 changes: 26 additions & 19 deletions etna/models/nn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,52 +131,59 @@ def step(self, batch: RNNBatch, *args, **kwargs): # type: ignore

def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterator[dict]:
"""Make samples from segment DataFrame."""
values_real = (
df.select_dtypes(include=[np.number])
.assign(target_shifted=df["target"].shift(1))
.drop(["target"], axis=1)
.pipe(lambda x: x[["target_shifted"] + [i for i in x.columns if i != "target_shifted"]])
.values
)
values_target = df["target"].values
segment = df["segment"].values[0]

def _make(
values_real: np.ndarray,
values_target: np.ndarray,
segment: str,
start_idx: int,
encoder_length: int,
decoder_length: int,
) -> Optional[dict]:

def _make(df: pd.DataFrame, start_idx: int, encoder_length: int, decoder_length: int) -> Optional[dict]:
sample: Dict[str, Any] = {
"encoder_real": list(),
"decoder_real": list(),
"encoder_target": list(),
"decoder_target": list(),
"segment": None,
}
total_length = len(df["target"])
total_length = len(values_target)
total_sample_length = encoder_length + decoder_length

if total_sample_length + start_idx > total_length:
return None

# Get shifted target and concatenate it with real values features
sample["decoder_real"] = (
df.select_dtypes(include=[np.number])
.pipe(lambda x: x[["target"] + [i for i in x.columns if i != "target"]])
.values[start_idx + encoder_length : start_idx + encoder_length + decoder_length]
)
sample["decoder_real"][:, 0] = (
df["target"].shift(1).values[start_idx + encoder_length : start_idx + encoder_length + decoder_length]
)
sample["decoder_real"] = values_real[start_idx + encoder_length : start_idx + total_sample_length]

# Get shifted target and concatenate it with real values features
sample["encoder_real"] = (
df.select_dtypes(include=[np.number])
.pipe(lambda x: x[["target"] + [i for i in x.columns if i != "target"]])
.values[start_idx : start_idx + encoder_length]
)
sample["encoder_real"][:, 0] = df["target"].shift(1).values[start_idx : start_idx + encoder_length]
sample["encoder_real"] = values_real[start_idx : start_idx + encoder_length]
sample["encoder_real"] = sample["encoder_real"][1:]

target = df["target"].values[start_idx : start_idx + encoder_length + decoder_length].reshape(-1, 1)
target = values_target[start_idx : start_idx + encoder_length + decoder_length].reshape(-1, 1)
sample["encoder_target"] = target[1:encoder_length]
sample["decoder_target"] = target[encoder_length:]

sample["segment"] = df["segment"].values[0]
sample["segment"] = segment

return sample

start_idx = 0
while True:
batch = _make(
df=df,
values_target=values_target,
values_real=values_real,
segment=segment,
start_idx=start_idx,
encoder_length=encoder_length,
decoder_length=decoder_length,
Expand Down

1 comment on commit 106b989

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.