Commit

apply suggestions from PR review
dennisbader committed Jun 17, 2024
1 parent 4e193bb commit a1fd48c
Showing 3 changed files with 25 additions and 170 deletions.
52 changes: 25 additions & 27 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -726,16 +726,15 @@ def fit(
        By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory")
        for seamless forecasting. Changing them should be done with care to avoid unexpected behavior.
        sample_weight
-           Optionally, some sample weights to apply to the target `series` labels.
-           They are applied per observation, per label (each step in `output_chunk_length`), and per component.
-           If a series or sequence of series, then those weights are used. If the weight series only have a single
-           component / column, then the weights are applied globally to all components in `series`. Otherwise, for
-           component-specific weights, the number of components must match those of `series`.
-           If a string, then the weights are generated using built-in weighting functions. The available options are
-           `"linear_decay"` or `"exponential_decay"`. The weights are only computed the longest series in `series`,
-           and then applied globally to all `series` to have a common time weighting.
+           Optionally, some sample weights to apply to the target `series` labels. They are applied per observation,
+           per label (each step in `output_chunk_length`), and per component.
+           If a `TimeSeries` or `Sequence[TimeSeries]`, then those weights are used. The number of series must
+           match the number of target `series` and each series must contain at least all time steps from the
+           corresponding target `series`. If the weight series only have a single component / column, then the weights
+           are applied globally to all components in `series`. Otherwise, for component-specific weights, the number
+           of components must match those of `series`.
+           If a string, then the weights are generated using built-in weighting functions. The available options are
+           `"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are
+           computed globally based on the length of the longest series in `series`. Then for each series, the weights
+           are extracted from the end of the global weights. This gives a common time weighting across all series.
        val_sample_weight
            Same as for `sample_weight` but for the evaluation dataset.
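To make the reworded docstring concrete, here is a minimal sketch of how the documented `sample_weight` options could be passed to `fit`. It is illustrative only: the choice of `DLinearModel` and the toy data are assumptions, not part of this diff.

import numpy as np
from darts import TimeSeries
from darts.models import DLinearModel

series = TimeSeries.from_values(np.arange(100, dtype=np.float32))
model = DLinearModel(input_chunk_length=12, output_chunk_length=6, n_epochs=1)

# built-in weighting: "linear" or "exponential" decay, lower weight further in the past
model.fit(series, sample_weight="linear")

# custom weights as a TimeSeries: a single component is applied globally to all
# components of the target series
weights = TimeSeries.from_values(np.linspace(0.1, 1.0, 100, dtype=np.float32))
model.fit(series, sample_weight=weights)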
@@ -976,7 +975,7 @@ def _setup_for_train(
"""
self._verify_train_dataset_type(train_dataset)

# Pro-actively catch length exceptions to display nicer messages
# proactively catch length exceptions to display nicer messages
train_length_ok, val_length_ok = True, True
try:
len(train_dataset)
@@ -1000,16 +999,16 @@
            )

        train_sample = train_dataset[0]
-       # the last two elements are sample weights and future target
+       # ignore sample weights [-2] for model dimensions
        train_sample_no_weight = train_sample[:-2] + train_sample[-1:]
        if self.model is None:
-           # Build model, based on the dimensions of the first series in the train set.
+           # build model based on the dimensions of the first series in the train set.
            self.train_sample = train_sample_no_weight
            self.output_dim = train_sample[-1].shape[1]
            model = self._init_model(trainer)
        else:
            model = self.model
-           # Check existing model has input/output dims matching what's provided in the training set.
+           # check existing model has input/output dims matching what's provided in the training set.
            raise_if_not(
                len(train_sample_no_weight) == len(self.train_sample),
                "The size of the training set samples (tuples) does not match what the model has been"
@@ -1071,7 +1070,7 @@ def _setup_for_train(
            logger=logger,
        )

-       # Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
+       # setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
        # least one batch no matter the chosen batch size
        dataloader_kwargs = dict(
            {
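The `drop_last` behavior described in that comment can be verified with plain PyTorch, independent of Darts:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float())
# drop_last=False keeps the ragged final batch, so every sample is seen and at
# least one batch exists even when batch_size exceeds the dataset length
print(len(DataLoader(ds, batch_size=3, drop_last=False)))  # 4 batches: 3+3+3+1
print(len(DataLoader(ds, batch_size=3, drop_last=True)))   # 3 batches, last sample dropped
print(len(DataLoader(ds, batch_size=16, drop_last=True)))  # 0 batches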
@@ -1089,7 +1088,7 @@
            **dataloader_kwargs,
        )

-       # Prepare validation data
+       # prepare validation data
        dataloader_kwargs["shuffle"] = False
        val_loader = (
            None
@@ -1221,16 +1220,15 @@ def lr_find(
        val_future_covariates
            Optionally, the future covariates corresponding to the validation series (must match ``covariates``)
        sample_weight
-           Optionally, some sample weights to apply to the target `series` labels.
-           They are applied per observation, per label (each step in `output_chunk_length`), and per component.
-           If a series or sequence of series, then those weights are used. If the weight series only have a single
-           component / column, then the weights are applied globally to all components in `series`. Otherwise, for
-           component-specific weights, the number of components must match those of `series`.
-           If a string, then the weights are generated using built-in weighting functions. The available options are
-           `"linear_decay"` or `"exponential_decay"`. The weights are only computed the longest series in `series`,
-           and then applied globally to all `series` to have a common time weighting.
+           Optionally, some sample weights to apply to the target `series` labels. They are applied per observation,
+           per label (each step in `output_chunk_length`), and per component.
+           If a `TimeSeries` or `Sequence[TimeSeries]`, then those weights are used. The number of series must
+           match the number of target `series` and each series must contain at least all time steps from the
+           corresponding target `series`. If the weight series only have a single component / column, then the weights
+           are applied globally to all components in `series`. Otherwise, for component-specific weights, the number
+           of components must match those of `series`.
+           If a string, then the weights are generated using built-in weighting functions. The available options are
+           `"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are
+           computed globally based on the length of the longest series in `series`. Then for each series, the weights
+           are extracted from the end of the global weights. This gives a common time weighting across all series.
        val_sample_weight
            Same as for `sample_weight` but for the evaluation dataset.
        trainer
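The "computed globally, extracted from the end" behavior in the new wording can be sketched in a few lines. The decay formulas below are hypothetical placeholders, not Darts' internals; only the trailing-slice idea matters:

import numpy as np

def built_in_weights(max_len: int, kind: str) -> np.ndarray:
    # hypothetical decay curves, for illustration only
    if kind == "linear":
        return np.linspace(0.0, 1.0, max_len)
    return np.exp(np.linspace(-3.0, 0.0, max_len))  # "exponential"

lengths = [100, 60]                      # two target series of different lengths
w = built_in_weights(max(lengths), "linear")
per_series = [w[-n:] for n in lengths]   # each series takes the trailing slice,
                                         # so all series share a common time weighting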
@@ -1546,7 +1544,7 @@ def predict_from_dataset(
        Returns one or more forecasts for time series.
        """

-       # We need to call super's super's method directly, because GlobalForecastingModel expects series:
+       # we need to call super's super's method directly, because GlobalForecastingModel expects series:
        ForecastingModel.predict(self, n, num_samples)

        self._verify_inference_dataset_type(input_series_dataset)
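The "super's super" comment refers to the plain-function call idiom for skipping an intermediate override; a minimal standalone illustration (class names are invented):

class Base:
    def predict(self):
        return "Base.predict"

class Middle(Base):
    def predict(self):
        return "Middle.predict"

class Leaf(Middle):
    def predict(self):
        # calling the grandparent's method directly bypasses Middle.predict
        return Base.predict(self)

print(Leaf().predict())  # Base.predict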
@@ -2016,7 +2014,7 @@ def load_weights_from_checkpoint(
        # meaningful error message if parameters are incompatible with the ckpt weights
        self._check_ckpt_parameters(tfm_save)

-       # instanciate the model without having to call `fit_from_dataset`
+       # instantiate the model without having to call `fit_from_dataset`
        self.model = self._init_model()
        # cast model precision to correct type
        self.model.to_dtype(ckpt["model_dtype"])
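The precision cast is ordinary parameter-dtype conversion; in plain torch (not the Darts helper shown in the diff) it might look like this:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ckpt_dtype = torch.float64  # e.g. recovered from checkpoint metadata
model.to(ckpt_dtype)        # casts all parameters and buffers in place
print(next(model.parameters()).dtype)  # torch.float64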
@@ -2336,7 +2334,7 @@ def _check_ckpt_parameters(self, tfm_save):
"The values of the hyper-parameters in the model and loaded checkpoint should be identical."
]

# warning messages formated to facilate copy-pasting
# warning messages formatted to facilitate copy-pasting
if len(missing_params) > 0:
msg += ["missing :"]
msg += [
Binary file not shown.
143 changes: 0 additions & 143 deletions darts/utils/data/training_dataset.py
@@ -248,149 +248,6 @@ def _memory_indexer(
            sample_weight_end,
        )

-    def _memory_indexer2(
-        self,
-        target_idx: int,
-        target_series: TimeSeries,
-        shift: int,
-        input_chunk_length: int,
-        output_chunk_length: int,
-        end_of_output_idx: int,
-        covariate_series: Optional[TimeSeries] = None,
-        covariate_type: CovariateType = CovariateType.NONE,
-        sample_weight: Optional[TimeSeries] = None,
-    ) -> SampleIndexType:
-        """Returns the (start, end) indices for past target, future target and covariates (sub sets) of the current
-        sample `i` from `target_idx`.
-
-        Works for all TimeSeries index types: pd.DatetimeIndex, pd.RangeIndex (and the deprecated Int64Index)
-
-        When `target_idx` is observed for the first time, it stores the position of the sample `0` within the full
-        target time series and the (start, end) indices of all sub sets.
-        This allows to calculate the sub set indices for all future samples `i` by simply adjusting for the difference
-        between the positions of sample `i` and sample `0`.
-
-        Parameters
-        ----------
-        target_idx
-            index of the current target TimeSeries.
-        target_series
-            current target TimeSeries.
-        shift
-            The number of time steps by which to shift the output chunks relative to the input chunks.
-        input_chunk_length
-            The length of the emitted past series.
-        output_chunk_length
-            The length of the emitted future output series.
-        end_of_output_idx
-            the index where the output chunk of the current sample ends in `target_series`.
-        covariate_series
-            current covariate TimeSeries.
-        covariate_type
-            the type of covariate to extract. Instance of `CovariateType`: One of (`CovariateType.PAST`,
-            `CovariateType.FUTURE`, `CovariateType.NONE`).
-        sample_weight
-            current sample weight TimeSeries.
-        """
-
-        covariate_start, covariate_end = None, None
-
-        # the first time target_idx is observed
-        if target_idx not in self._index_memory:
-            start_of_output_idx = end_of_output_idx - output_chunk_length
-            start_of_input_idx = start_of_output_idx - shift
-
-            # select forecast point and target period, using the previously computed indexes
-            future_start, future_end = (
-                start_of_output_idx,
-                start_of_output_idx + output_chunk_length,
-            )
-
-            # select input period; look at the `input_chunk_length` points after start of input
-            past_start, past_end = (
-                start_of_input_idx,
-                start_of_input_idx + input_chunk_length,
-            )
-
-            if covariate_type is not CovariateType.NONE:
-                # not CovariateType.Future -> both CovariateType.PAST and CovariateType.HISTORIC_FUTURE
-                start = (
-                    future_start
-                    if covariate_type is CovariateType.FUTURE
-                    else past_start
-                )
-                end = future_end if covariate_type is CovariateType.FUTURE else past_end
-
-                # we need to be careful with getting ranges and indexes:
-                # to get entire range, full_range = ts[:len(ts)]; to get last index: last_idx = ts[len(ts) - 1]
-
-                # extract actual index value (respects datetime- and integer-based indexes; also from non-zero start)
-                start_time = target_series.time_index[start]
-                end_time = target_series.time_index[end - 1]
-
-                if (
-                    start_time not in covariate_series.time_index
-                    or end_time not in covariate_series.time_index
-                ):
-                    raise_log(
-                        ValueError(
-                            f"Missing covariates; could not find {covariate_type.value} covariates in index "
-                            f"value range: {start_time} - {end_time}."
-                        ),
-                        logger=logger,
-                    )
-
-                # extract the index position (index) from index value
-                covariate_start = covariate_series.time_index.get_loc(start_time)
-                covariate_end = covariate_series.time_index.get_loc(end_time) + 1
-
-            # sample weight
-            sample_weight_start, sample_weight_end = self._index_memory[target_idx][
-                "sample_weight"
-            ]
-
-            # store position of initial sample and all relevant sub set indices
-            self._index_memory[target_idx] = {
-                "end_of_output_idx": end_of_output_idx,
-                "past_target": (past_start, past_end),
-                "future_target": (future_start, future_end),
-                "covariate": (covariate_start, covariate_end),
-                "sample_weight": (sample_weight_start, sample_weight_end),
-            }
-        else:
-            # load position of initial sample and its sub set indices
-            end_of_output_idx_last = self._index_memory[target_idx]["end_of_output_idx"]
-            past_start, past_end = self._index_memory[target_idx]["past_target"]
-            future_start, future_end = self._index_memory[target_idx]["future_target"]
-            covariate_start, covariate_end = self._index_memory[target_idx]["covariate"]
-            sample_weight_start, sample_weight_end = self._index_memory[target_idx][
-                "sample_weight"
-            ]
-
-            # evaluate how much the new sample needs to be shifted, and shift all indexes
-            idx_shift = end_of_output_idx - end_of_output_idx_last
-            past_start += idx_shift
-            past_end += idx_shift
-            future_start += idx_shift
-            future_end += idx_shift
-            covariate_start = (
-                covariate_start + idx_shift if covariate_start is not None else None
-            )
-            covariate_end = (
-                covariate_end + idx_shift if covariate_end is not None else None
-            )
-
-        return (
-            past_start,
-            past_end,
-            future_start,
-            future_end,
-            covariate_start,
-            covariate_end,
-            sample_weight_start,
-            sample_weight_end,
-        )


class PastCovariatesTrainingDataset(TrainingDataset, ABC):
    def __init__(self):
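For reference, the index-memory technique described in the deleted `_memory_indexer2` docstring (cache sample 0's chunk boundaries once, then shift them for later samples) reduces to a few lines. This sketch is a simplification under stated assumptions: it assumes the output chunk directly follows the input chunk, and the names are hypothetical, not darts' API:

_index_memory: dict = {}

def chunk_bounds(target_idx, end_of_output_idx, input_chunk_length, output_chunk_length):
    if target_idx not in _index_memory:
        # first sample for this series: compute and cache the boundaries
        future = (end_of_output_idx - output_chunk_length, end_of_output_idx)
        past = (future[0] - input_chunk_length, future[0])
        _index_memory[target_idx] = (end_of_output_idx, past, future)
        return past, future
    end0, past0, future0 = _index_memory[target_idx]
    d = end_of_output_idx - end0  # offset of this sample relative to sample 0
    return (past0[0] + d, past0[1] + d), (future0[0] + d, future0[1] + d)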
