43 changes: 16 additions & 27 deletions pytorch_forecasting/models/timexer/_timexer.py
@@ -214,8 +214,13 @@ def __init__(
         if enc_in is None:
             self.enc_in = len(self.reals)
 
-        self.n_quantiles = None
+        # NOTE: assume point prediction as default here,
+        # with single median quantile being the point prediction.
+        # hence self.n_quantiles = 1 for point predictions.
+        self.n_quantiles = 1
Member:

Can you please clarify why we are setting n_quantiles = 1 and not n_quantiles = None here?
From what I have seen so far in the package, we always use n_quantiles = None when QuantileLoss is not in use, so n_quantiles = 1 may change that contract. It won't affect the behavior, I guess, but it can mislead the user into thinking that QuantileLoss is being used even when it is not.

@PranavBhatP (Contributor Author), Sep 30, 2025:


No, self.n_quantiles is set to 1 because in the case of a point prediction, the median quantile is what is being predicted. So instead of setting self.n_quantiles = None and adding step-outs that check the value of self.n_quantiles inside the layer architecture, we simply set it to 1 and let the loss handle the rest.

For some context, setting self.n_quantiles = 1 simplifies how we handle the distinction between a loss like MAE and QuantileLoss. Refer to this diff - https://github.com/sktime/pytorch-forecasting/pull/1936/files#diff-8c4bc78319ca93d3c0e93a38a0ee551ac3c4193f717955c0ce3477e4147a9153

TL;DR: it removes unnecessary "if" statements and step-outs in the logic for TimeXer's output.

Maybe a comment in the code would clarify this; I will add it, thanks for highlighting this.
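
For illustration only (not part of the PR; the class name is hypothetical), a minimal sketch of the pattern being described: with n_quantiles = 1 a point forecast is just a 1-quantile forecast, so the projection and reshape never branch on the loss type.

```python
import torch
import torch.nn as nn

# Hypothetical stand-alone head sketching the unified pattern:
# n_quantiles=1 makes a point forecast a 1-quantile forecast, so the
# projection and reshape are identical for MAE and QuantileLoss.
class UnifiedHead(nn.Module):
    def __init__(self, nf, target_window, n_quantiles=1):
        super().__init__()
        self.n_quantiles = n_quantiles
        self.linear = nn.Linear(nf, target_window * n_quantiles)

    def forward(self, x):
        # x: (batch_size, n_vars, nf)
        x = self.linear(x)
        batch_size, n_vars = x.shape[0], x.shape[1]
        # always 4-D: (batch_size, n_vars, target_window, n_quantiles)
        return x.reshape(batch_size, n_vars, -1, self.n_quantiles)

point_head = UnifiedHead(nf=32, target_window=24)                    # point forecast
quantile_head = UnifiedHead(nf=32, target_window=24, n_quantiles=3)  # 3 quantiles
x = torch.randn(8, 5, 32)
print(point_head(x).shape)     # torch.Size([8, 5, 24, 1])
print(quantile_head(x).shape)  # torch.Size([8, 5, 24, 3])
```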

@phoeenniixx (Member), Sep 30, 2025:


Ohh, I see, thanks!
If this is working, maybe we should then refactor the other places where this if/else block is used because of n_quantiles = None?
I am just thinking it would be good if everything were consistent with the rest of the API....

Member:


It's just that I am a bit concerned about the meaning here: if a person is using this estimator, they would see n_quantiles=1 even for point predictions (I am talking about a user who is just using the estimator and not actually reading the code), and this may confuse them: "why is this param not None like it is for other estimators in ptf?"

Also, I don't think this if/else block would be much of an overhead, so if keeping the block saves the user from getting confused, I think we should prefer that.

But that is my thought; please let me know what your reasoning is here...

@PranavBhatP (Contributor Author), Sep 30, 2025:


> It removes unnecessary "if" statements and step-outs in the logic for TimeXer's output.

The reasoning is as simple as this :). Since we are standardising the output format, it makes more sense to rename this attribute to self.output_dim or something like that, instead of self.n_quantiles. That way it has a more generic meaning, and we assign it a different value only when there is a step-out to a different case (QuantileLoss in our case).
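
As a rough sketch of that suggestion (the attribute name is hypothetical; nothing here is merged code), the assignment in __init__ would read:

```python
# Hypothetical rename: output_dim carries a generic meaning
# (width of the prediction head) and only the QuantileLoss
# step-out assigns it a loss-specific value.
self.output_dim = 1
if isinstance(loss, QuantileLoss):
    self.output_dim = len(loss.quantiles)
```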

Collaborator:


Could you kindly explain what API or parameter change, exactly, we are talking about?

@phoeenniixx (Member), Oct 5, 2025:


So, in this PR @PranavBhatP is changing the convention of using n_quantiles = None (which is what currently happens in the package) when the user is not using QuantileLoss (@PranavBhatP please correct me if I have understood it wrongly); instead, n_quantiles = 1 is being introduced in the absence of QuantileLoss (see comments #1936 (comment) and #1936 (comment)), and I was concerned about the meaning it could convey (see comment #1936 (comment)).

He suggested we could rename n_quantiles to output_dim, which would solve the issue of the meaning. But again, this would lead to changes across the whole package, and it would be breaking since a param name is being changed. So, as you said, we could deprecate it, but I think we should introduce it (if we decide to) in v2 and keep v1 intact, as v1 is already mature enough and the deprecation cycle may not even complete before we fully move to v2. I am assuming the deprecation cycle completes after 2 releases, and hopefully we will have released a full-fledged v2 by then. So, if we are going to make this change, I think we should do it in v2.

I am still not very sure whether this would be a good idea; that's why I wanted you and @agobbifbk to comment with your thoughts on this change.

Comment:


For me output_dim is a good idea; it is more general. I don't think it is necessary to deprecate the old name if we rename it, because it is not a parameter that the user can modify, is it? It just depends on the chosen loss function(s), right?

Member:


I agree, but I am assuming n_quantiles is currently used in the documentation to define the shape of the output when QuantileLoss is used, so I guess this arg is visible to the user in the documentation? If yes, we might need to deprecate it? I am not sure what the process is here 😅

Collaborator:


I think a design document is the right process. Write up the vignettes now - and analyze whether n_quantiles even appears in them.

Then propose alternative vignettes for future use, and from that develop a change and/or deprecation trajectory.


+        # set n_quantiles to the length of the quantiles list passed
+        # into the "quantiles" parameter when QuantileLoss is used.
         if isinstance(loss, QuantileLoss):
             self.n_quantiles = len(loss.quantiles)
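
A usage sketch of the resulting contract (the remaining TimeXer constructor arguments are elided, so the calls are left as comments):

```python
from pytorch_forecasting.metrics import MAE, QuantileLoss

# With a point loss, the model keeps the default:
#   model = TimeXer(..., loss=MAE())
#   model.n_quantiles  # -> 1

# With QuantileLoss, n_quantiles follows the quantile list:
#   model = TimeXer(..., loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]))
#   model.n_quantiles  # -> 3
```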

@@ -353,10 +358,7 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         enc_out = enc_out.permute(0, 1, 3, 2)
 
         dec_out = self.head(enc_out)
-        if self.n_quantiles is not None:
-            dec_out = dec_out.permute(0, 2, 1, 3)
-        else:
-            dec_out = dec_out.permute(0, 2, 1)
+        dec_out = dec_out.permute(0, 2, 1, 3)
 
         return dec_out
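
A quick shape check of the now-unconditional permute (dummy tensor; the dimension names are assumed from the surrounding code):

```python
import torch

# head output: (batch_size, n_vars, pred_len, n_quantiles); 1 in the point case
dec_out = torch.randn(8, 3, 24, 1)
dec_out = dec_out.permute(0, 2, 1, 3)
print(dec_out.shape)  # torch.Size([8, 24, 3, 1])
```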

@@ -395,10 +397,7 @@ def _forecast_multi(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         enc_out = enc_out.permute(0, 1, 3, 2)
 
         dec_out = self.head(enc_out)
-        if self.n_quantiles is not None:
-            dec_out = dec_out.permute(0, 2, 1, 3)
-        else:
-            dec_out = dec_out.permute(0, 2, 1)
+        dec_out = dec_out.permute(0, 2, 1, 3)
 
         return dec_out

@@ -470,25 +469,15 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         if prediction.size(2) != len(target_positions):
             prediction = prediction[:, :, : len(target_positions)]
 
-        # In the case of a single target, the result will be a torch.Tensor
-        # with shape (batch_size, prediction_length)
-        # In the case of multiple targets, the result will be a list of "n_targets"
-        # tensors with shape (batch_size, prediction_length)
-        # If quantile predictions are used, the result will have an additional
-        # dimension for quantiles, resulting in a shape of
-        # (batch_size, prediction_length, n_quantiles)
-        if self.n_quantiles is not None:
-            # quantile predictions.
-            if len(target_indices) == 1:
-                prediction = prediction[..., 0, :]
-            else:
-                prediction = [prediction[..., i, :] for i in target_indices]
-        else:
-            # point predictions.
-            if len(target_indices) == 1:
-                prediction = prediction[..., 0]
-            else:
-                prediction = [prediction[..., i] for i in target_indices]
+        # output format is (batch_size, prediction_length, n_quantiles)
+        # in case of quantile loss, the output n_quantiles = self.n_quantiles
+        # which is the length of a list of float. In case of MAE, MSE, etc.
+        # n_quantiles = 1 and it mimics the behavior of a point prediction.
+        # for multi-target forecasting, the output is a list of tensors.
+        if len(target_positions) == 1:
+            prediction = prediction[..., 0, :]
+        else:
+            prediction = [prediction[..., i, :] for i in target_indices]
         prediction = self.transform_output(
             prediction=prediction, target_scale=x["target_scale"]
         )
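
To make the new branch concrete, a small sketch with dummy shapes (variable names borrowed from the diff; the values are illustrative):

```python
import torch

# prediction after the head: (batch_size, pred_len, n_targets, n_quantiles)
prediction = torch.randn(8, 24, 2, 1)
target_positions = [0, 1]
target_indices = range(len(target_positions))

if len(target_positions) == 1:
    out = prediction[..., 0, :]  # single target: (batch_size, pred_len, n_quantiles)
else:
    out = [prediction[..., i, :] for i in target_indices]  # one tensor per target

print([t.shape for t in out])  # [torch.Size([8, 24, 1]), torch.Size([8, 24, 1])]
```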
1 change: 0 additions & 1 deletion pytorch_forecasting/models/timexer/_timexer_pkg.py
@@ -17,7 +17,6 @@ class TimeXer_pkg(_BasePtForecaster):
         "capability:pred_int": True,
         "capability:flexible_history_length": True,
         "capability:cold_start": False,
-        "tests:skip_by_name": "test_integration",
     }
 
     @classmethod
15 changes: 5 additions & 10 deletions pytorch_forecasting/models/timexer/sub_modules.py
@@ -183,29 +183,24 @@ class FlattenHead(nn.Module):
         nf (int): Number of features in the last layer.
         target_window (int): Target window size.
         head_dropout (float): Dropout rate for the head. Defaults to 0.
-        n_quantiles (int, optional): Number of quantiles. Defaults to None."""
+        n_quantiles (int, optional): Number of quantiles. Defaults to 1."""
 
-    def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None):
+    def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=1):
         super().__init__()
         self.n_vars = n_vars
         self.flatten = nn.Flatten(start_dim=-2)
-        self.linear = nn.Linear(nf, target_window)
         self.n_quantiles = n_quantiles
 
-        if self.n_quantiles is not None:
-            self.linear = nn.Linear(nf, target_window * n_quantiles)
-        else:
-            self.linear = nn.Linear(nf, target_window)
+        self.linear = nn.Linear(nf, target_window * n_quantiles)
         self.dropout = nn.Dropout(head_dropout)
 
     def forward(self, x):
         x = self.flatten(x)
         x = self.linear(x)
         x = self.dropout(x)
 
-        if self.n_quantiles is not None:
-            batch_size, n_vars = x.shape[0], x.shape[1]
-            x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
+        batch_size, n_vars = x.shape[0], x.shape[1]
+        x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
         return x
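
A quick check of the refactored head (the input shape is an assumption: FlattenHead flattens the last two dimensions, so they must multiply to nf):

```python
import torch

head = FlattenHead(n_vars=3, nf=64, target_window=24, n_quantiles=7)
x = torch.randn(8, 3, 16, 4)  # last two dims flatten to 16 * 4 = 64 = nf
print(head(x).shape)          # torch.Size([8, 3, 24, 7])

point_head = FlattenHead(n_vars=3, nf=64, target_window=24)  # default n_quantiles=1
print(point_head(x).shape)    # torch.Size([8, 3, 24, 1])
```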

