
[ENH] pytorch forecasting adapter with Global Forecasting API #6228

Merged (92 commits, Jun 21, 2024)

Conversation

@Xinyu-Wu-0000 (Contributor) commented Mar 28, 2024

Reference Issues/PRs

Related: #4651, closes #4923

Main Topic

A pytorch forecasting adapter with Global Forecasting API and several algorithms for design validation.

Details

I'm developing a pytorch forecasting adapter with the Global Forecasting API. To ensure a well-designed implementation, I'd like to discuss some design aspects.

New Base Class for Minimal Impact

A new base class, GlobalBaseForecaster, has been added to minimize the impact on existing forecasters and simplify testing. As discussed in #4651, the plan is to manage the Global Forecasting API via tags only. However, a phased approach might be beneficial. If a tag-based approach is confirmed, we can merge GlobalBaseForecaster back into BaseForecaster after design validation.
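For illustration, a minimal sketch of what purely tag-based discovery could look like, assuming a hypothetical tag name capability:global_forecasting (the final tag name is part of the design discussion in #4651, not fixed by this PR):

from sktime.registry import all_estimators

# hypothetical tag name, used here only for illustration
TAG = "capability:global_forecasting"

# list all forecasters that declare global forecasting support via the tag;
# with a tag-only design, no separate base class is needed for lookup
global_forecasters = all_estimators(
    estimator_types="forecaster",
    filter_tags={TAG: True},
)
for name, cls in global_forecasters:
    print(name)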

Data Type Conversion Challenges

Data type conversion presents a challenge because PyTorch forecasting expects TimeSeriesDataSet as input. While TimeSeriesDataSet can be created from a pandas.DataFrame, it requires numerous parameters. Determining where to pass these parameters is a key question.
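To give a sense of the parameter burden, a minimal sketch of constructing a TimeSeriesDataSet from a long-format DataFrame (toy column names; real use cases add encoders, known/unknown covariates, normalizers, and more):

import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet

# toy long-format data: two series, one float target, integer time index
df = pd.DataFrame(
    {
        "series": ["a"] * 30 + ["b"] * 30,
        "time_idx": list(range(30)) * 2,
        "value": [float(i) for i in range(60)],
    }
)

# even this minimal construction needs several dataset-specific parameters
dataset = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=12,
    max_prediction_length=6,
    time_varying_unknown_reals=["value"],
)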

Placing them in fit would introduce inconsistency with the existing API. If we put them in __init__, it would be very counterintuitive to define how the data conversion works while initializing the algorithm.

A similar issue arises during trainer initialization. Currently, trainer_params: Dict[str, Any] is used within __init__ to obtain trainer initialization parameters. However, the API for passing these parameters to trainer.fit is yet to be designed.
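As a sketch of the open question, one possible shape for the adapter's constructor. The class name, dataset_params, and trainer_fit_params are illustrative assumptions, not the PR's actual API; only trainer_params exists in the current draft:

from typing import Any, Dict, Optional


class PytorchForecastingAdapterSketch:
    """Illustrative skeleton only, not the PR's actual adapter class."""

    def __init__(
        self,
        dataset_params: Optional[Dict[str, Any]] = None,
        trainer_params: Optional[Dict[str, Any]] = None,
        trainer_fit_params: Optional[Dict[str, Any]] = None,
    ):
        # dataset_params would be forwarded to TimeSeriesDataSet(...),
        # trainer_params to Trainer(...); where trainer_fit_params should be
        # unpacked (trainer.fit) is the open API question discussed above
        self.dataset_params = dataset_params or {}
        self.trainer_params = trainer_params or {}
        self.trainer_fit_params = trainer_fit_params or {}

    def _build_trainer(self):
        import lightning.pytorch as pl

        return pl.Trainer(**self.trainer_params)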

To convert pytorch_forecasting.models.base_model.Prediction back to a pandas.DataFrame, a custom conversion method is required. Refer to the following issues for more information: jdb78/pytorch-forecasting#734, jdb78/pytorch-forecasting#177.
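Absent a built-in converter, a rough sketch of what a manual conversion could look like, assuming predict was called with return_index=True so that prediction.output is an (n_samples, horizon) tensor and prediction.index is a DataFrame carrying group ids and the first predicted time_idx (column names follow the toy example above):

import pandas as pd


def prediction_to_dataframe(prediction, target_name="value"):
    # point forecasts as a numpy array, one row per (series, window)
    output = prediction.output.cpu().numpy()
    index = prediction.index
    rows = []
    for i in range(len(index)):
        for h in range(output.shape[1]):
            rows.append(
                {
                    "series": index.iloc[i]["series"],
                    "time_idx": index.iloc[i]["time_idx"] + h,
                    target_name: float(output[i, h]),
                }
            )
    return pd.DataFrame(rows).set_index(["series", "time_idx"])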

Train/Validation Strategy

Training a model in PyTorch forecasting necessitates passing both the training and validation datasets together to the training algorithm. This allows for monitoring training progress, adjusting the learning rate, saving the model, or even stopping training prematurely. This differs from the typical sktime approach where only the training data is passed to fit and the test data is used for validation after training. Any suggestions on how to best address this discrepancy?
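For reference, the pytorch-forecasting tutorial carves the validation set out of the same data, which the adapter could replicate inside fit (df and the dataset parameters as in the sketch above; whether and how to expose the split to the user is exactly the open question):

from pytorch_forecasting import TimeSeriesDataSet

max_prediction_length = 6
training_cutoff = df["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    df[df["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=12,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["value"],
)

# validation reuses the training dataset's parameters on the full data and
# predicts the last max_prediction_length points of each series
validation = TimeSeriesDataSet.from_dataset(
    training, df, predict=True, stop_randomization=True
)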

@benHeid @fkiraly Thank you very much for the feedback on my GSoC proposal! Any suggestions on implementation details or the overall design would be greatly appreciated.

@Xinyu-Wu-0000 changed the title from "[ENH] Global pytorch-forecasting" to "[ENH] pytorch forecasting adapter with Global Forecasting API" on Mar 28, 2024
@benHeid (Contributor) commented Mar 30, 2024

Just a general comment. I would propose to split this into multiple PRs. This would make it easier to review. I would propose a PR for the pytorch-forecasting adapter (first PR) and a second PR that introduces the global forecasting.

@fkiraly (Collaborator) left a comment

just commenting - very interesting!

I suppose a separate class will give us the opportunity to develop this capability. Ultimately, we may decide to merge it into BaseForecaster, or not.

@Xinyu-Wu-0000 (Contributor, Author)

> Just a general comment. I would propose to split this into multiple PRs. This would make it easier to review. I would propose a PR for the pytorch-forecasting adapter (first PR) and a second PR that introduces the global forecasting.

Yeah, splitting this into two PRs would make the workflow clearer. But I am using pytorch-forecasting as an experiment to try the global forecasting API design, and ultimately DL models will need the global forecasting API anyway. Would it be more convenient to have a single PR?

@fkiraly (Collaborator) commented Mar 31, 2024

Good question, @Xinyu-Wu-0000 - as long as this is experimental, it's up to you.

It's always good to have one or two examples working for new API development, so it makes sense to have examples in the PR.

There might, though, be substantial challenges coming from pytorch-forecasting in isolation which have nothing to do with the global forecasting extension (you already list many of the serious ones, e.g., loader, prediction object), so I wonder if there is a simpler example to develop around.

Either way, that's not a strict requirement, as long as you are working in an exploratory sense.

@fkiraly added the module:forecasting and enhancement labels on Mar 31, 2024
@Xinyu-Wu-0000 (Contributor, Author)

> so I wonder if there is a simpler example to develop around.

Maybe NeuralForecast could be a simpler example, as it has already been interfaced and all NeuralForecast models are capable of global forecasting; however, several PRs are currently working on NeuralForecast. I chose pytorch-forecasting to minimize the impact on the existing code base, since extending the global forecasting API will be quite a big change.

> It's always good to have one or two examples working for new API development, so it makes sense to have examples in the PR.

I just made it work for an example from pytorch-forecasting, the first tutorial in the pytorch-forecasting documentation.
The test script for the example I used:
test_script.txt

By the way, are we going to have a release with partial global forecasting API support? Something like version 0.30, with only the NeuralForecast and pytorch-forecasting models supporting the global forecasting API.

@fkiraly (Collaborator) commented Apr 1, 2024

> By the way, are we going to have a release with partial global forecasting API support?

Yes, I think that's a valid upgrade plan, e.g., first release only some forecasters, and later merge base classes if everything is robust.

It could even be 0.29.0 in theory, because your plan does not impact existing classes.

@Xinyu-Wu-0000 (Contributor, Author)

Tests on macOS are having timeout problems again after restarting. pytorch-forecasting requires python<=3.10, but Install and test / test-full (3.11, macos-13) failed too. Checking the log file of Install and test / test-full (3.10, macos-13) (pull_request), pytorch-forecasting==1.0.0, lightning==2.3.0 and torch==2.2.2 are installed correctly. Any ideas about the timeout? @benHeid @fkiraly

@benHeid (Contributor) commented Jun 19, 2024

I'll take a more detailed look at it this evening.

@fkiraly I think that for some reason more tests are executed in the other jobs than normally. Might it be the case that the GlobalForecasting Tests are executed there?

@benHeid (Contributor) commented Jun 19, 2024

It seems that the PyTorch forecasting logging is quite aggressive:

[screenshot: verbose pytorch-forecasting log output]

Part of the full_test run... Should we try to disable them? Or at least set verbosity flags?

@Xinyu-Wu-0000 (Contributor, Author)

> Part of the full_test run... Should we try to disable them? Or at least set verbosity flags?

I have tried to reduce the logging from pytorch-forecasting before, but without success. There is also an issue about this in their repository: jdb78/pytorch-forecasting#1576.

I tried Configure Console Logging from pytorch-lightning, but it still printed a lot to the console:
[screenshot: console output after configuring logging]

Even after redirecting stdout, it still prints a lot:

import os
import sys


class HiddenPrints:
    """Context manager that temporarily redirects stdout to os.devnull."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout


with HiddenPrints():
    model.fit(y=y, fh=fh)

[screenshot: remaining output with stdout redirected]

@fkiraly (Collaborator) commented Jun 19, 2024

> @fkiraly I think that for some reason more tests are executed in the other jobs than normally. Might it be the case that the GlobalForecasting tests are executed there?

That is because of the incremental test logic.

If the base modules are changed, then every forecaster is being tested.

@fkiraly (Collaborator) commented Jun 19, 2024

@Xinyu-Wu-0000, @benHeid, one way to deal with aggressive messages is to redirect the stdout.

I've written a context manager for this in skbase for another use case, and will be moving it to a public module so we can use it:
sktime/skbase#338

It would be great if you could test whether it works.

@benHeid (Contributor) commented Jun 19, 2024

There are two failing CI pipelines.

  • FAILED sktime/forecasting/model_selection/tests/test_tune.py::test_gscv_backends[backend_set7] - KeyError: "None of [PeriodIndex(['1962'], dtype='period[Y-DEC]', name='Period')] are in the [index]"
  • FAILED sktime/tests/test_all_estimators.py::TestAllEstimators::test_non_state_changing_method_contract[ProximityTree-0-ClassifierFitPredict-predict] - _pickle.PicklingError: Could not pickle object as excessively deep recursion required.

@fkiraly Both seem to be unrelated to this PR...

@Xinyu-Wu-0000 (Contributor, Author) commented Jun 20, 2024

> one way to deal with aggressive messages is to redirect the stdout.

I tried copying the code into a script to test it:

import io
import sys
import warnings


class StdoutMute:
    """A context manager to suppress stdout.

    Used to suppress stdout when importing modules; a ModuleNotFoundError
    mentioning a soft dependency is suppressed to a warning, while all
    other exceptions are raised.
    """

    def __init__(self, active=True):
        self.active = active

    def __enter__(self):
        """Context manager entry point."""
        # capture stdout if active
        # store the original stdout so it can be restored in __exit__
        if self.active:
            self._stdout = sys.stdout
            sys.stdout = io.StringIO()

    def __exit__(self, type, value, traceback):  # noqa: A002
        """Context manager exit point."""
        # restore stdout if active
        # if not active, nothing needs to be done, since stdout was not replaced
        if self.active:
            sys.stdout = self._stdout

        if type is not None:
            # if a ModuleNotFoundError is raised,
            # we suppress to a warning if "soft dependency" is in the error message
            # otherwise, raise
            if type is ModuleNotFoundError:
                if "soft dependency" not in str(value):
                    return False
                warnings.warn(str(value), ImportWarning, stacklevel=2)
                return True

            # all other exceptions are raised
            return False
        # if no exception was raised, return True to indicate successful exit
        # return statement not needed as type was None, but included for clarity
        return True


print("__________________________________________")
with StdoutMute():
    model.fit(y=y, fh=fh)

However, it still printed a lot:

[screenshot: printed output with StdoutMute active]

Even after blocking both stdout and stderr, it still printed a few lines:

[screenshot: remaining lines with both streams blocked]
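For completeness, a minimal sketch of silencing both streams with the standard library, plus lowering the lightning logger level. It assumes the remaining lines are emitted via the logging module on stderr by the "lightning.pytorch" logger, and reuses model, y, fh from the test script above:

import contextlib
import io
import logging

# assumption: the surviving lines come from lightning's logger writing to
# stderr, so the logger level is lowered in addition to stream redirection
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# redirect both stdout and stderr into throwaway buffers around fit
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
    io.StringIO()
):
    model.fit(y=y, fh=fh)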

@benHeid (Contributor) commented Jun 20, 2024

Mhm, but are the print statements between the tests removed?

I suppose that the tensorflow warnings are okay to be shown.

@Xinyu-Wu-0000 (Contributor, Author)

> are the print statements between the tests removed?

Yes, if both stdout and stderr are blocked. If only stdout is blocked, statements between the tests will still be printed like this:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/xinyu/Documents/sktime/lightning_logs/7056098100484793063/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 0     
3  | prescalers                         | ModuleDict                      | 32    
4  | static_variable_selection          | VariableSelectionNetwork        | 0     
5  | encoder_variable_selection         | VariableSelectionNetwork        | 1.2 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 528   
7  | static_context_variable_selection  | GatedResidualNetwork            | 1.1 K 
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1.1 K 
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1.1 K 
10 | static_context_enrichment          | GatedResidualNetwork            | 1.1 K 
11 | lstm_encoder                       | LSTM                            | 2.2 K 
12 | lstm_decoder                       | LSTM                            | 2.2 K 
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 544   
14 | post_lstm_add_norm_encoder         | AddNorm                         | 32    
15 | static_enrichment                  | GatedResidualNetwork            | 1.4 K 
16 | multihead_attn                     | InterpretableMultiHeadAttention | 676   
17 | post_attn_gate_norm                | GateAddNorm                     | 576   
18 | pos_wise_ff                        | GatedResidualNetwork            | 1.1 K 
19 | pre_output_gate_norm               | GateAddNorm                     | 576   
20 | output_layer                       | Linear                          | 119   
----------------------------------------------------------------------------------------
15.5 K    Trainable params
0         Non-trainable params
15.5 K    Total params
0.062     Total estimated model params size (MB)
/home/xinyu/Documents/sktime/env/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/xinyu/Documents/sktime/env/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/xinyu/Documents/sktime/env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=1` reached.
/home/xinyu/Documents/sktime/./sktime/forecasting/base/adapters/_pytorchforecasting.py:477: FutureWarning: A value is trying to be set on a copy of a DataFrame or Series through chained assignment using an inplace method.
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  data["_target_column"].fillna(0, inplace=True)
/home/xinyu/Documents/sktime/env/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/home/xinyu/Documents/sktime/env/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.

@benHeid (Contributor) commented Jun 20, 2024

Can you check if there is a verbosity parameter in PyTorch forecasting? Exposing this would enable users to control the output.

@Xinyu-Wu-0000 (Contributor, Author) commented Jun 20, 2024

> Can you check if there is a verbosity parameter in PyTorch forecasting? Exposing this would enable users to control the output.

I have checked (#6228 (comment)); I couldn't find a verbosity parameter in pytorch-forecasting or pytorch-lightning.

Should we expose a parameter to redirect stdout and stderr?

@fkiraly (Collaborator) commented Jun 20, 2024

Not directly related to the discussion, but a planning comment: could we look towards closing out the remaining issues in review and see if we can merge? @benHeid, could you kindly state what, in your opinion, are the blocking change requests that we need to focus on getting addressed?

@fkiraly (Collaborator) commented Jun 20, 2024

@Xinyu-Wu-0000, can you kindly go through the change requests above and do the following:

  • In every conversation, please comment on how you addressed it or ask for clarification. That is, please write a reply: this can be an action you took, a decision not to take an action, or further questions on the requirement. The requester can then follow up or close.
  • For every point in a "change request" post, kindly do the same.

@Xinyu-Wu-0000 (Contributor, Author)

All comments are addressed except this one, about whether to move some code to the base class:

#6228 (comment)

@fkiraly (Collaborator) left a comment

First, I want to congratulate you on this monumental contribution! This is a framework cornerstone for going forward with foundation models, reduction forecasters, etc. I would also like to congratulate you on your understanding of the various parts of the code base coming together here, @Xinyu-Wu-0000.

I'm sure we will end up refining this in the future as we implement more examples, and this may not be entirely final, but we are a good way there. Let's refine in separate pull requests.

@fkiraly (Collaborator) commented Jun 20, 2024

@benHeid, are your comments addressed? Anything still outstanding?

@benHeid (Contributor) commented Jun 21, 2024

>> Can you check if there is a verbosity parameter in PyTorch forecasting? Exposing this would enable users to control the output.
>
> I have checked (#6228 (comment)); I couldn't find a verbosity parameter in pytorch-forecasting or pytorch-lightning.
>
> Should we expose a parameter to redirect stdout and stderr?

I think this would be cool, but not in this PR. I would suggest opening a new one for this, with a corresponding issue.

@benHeid (Contributor) left a comment

@Xinyu-Wu-0000, congratulations on this very nice contribution. I think this is quite an important one for sktime.

@fkiraly merged commit 2b4300d into sktime:main on Jun 21, 2024
155 of 163 checks passed
Spinachboul pushed a commit to Spinachboul/sktime that referenced this pull request on Jun 29, 2024.