Skip to content

Commit

Permalink
add example that uses peft config
Browse files Browse the repository at this point in the history
  • Loading branch information
geetu040 committed Jun 4, 2024
1 parent 2c913a5 commit 628b717
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions sktime/forecasting/hf_transformers_forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,37 @@ class HFTransformersForecaster(BaseForecaster):
>>> forecaster.fit(y) # doctest: +SKIP
>>> fh = [1, 2, 3]
>>> y_pred = forecaster.predict(fh) # doctest: +SKIP
>>> from sktime.forecasting.hf_transformers_forecaster import (
... HFTransformersForecaster,
... )
>>> from sktime.datasets import load_airline
>>> y = load_airline()
>>> forecaster = HFTransformersForecaster(
... model_path="huggingface/autoformer-tourism-monthly",
... fit_strategy="lora",
... training_args={
... "num_train_epochs": 20,
... "output_dir": "test_output",
... "per_device_train_batch_size": 32,
... },
... config={
... "lags_sequence": [1, 2, 3],
... "context_length": 2,
... "prediction_length": 4,
... "use_cpu": True,
... "label_length": 2,
... },
... peft_config_dict={
... "r": 8,
... "lora_alpha": 32,
... "target_modules": ["q_proj", "v_proj"],
... "lora_dropout": 0.01,
... },
... ) # doctest: +SKIP
>>> forecaster.fit(y) # doctest: +SKIP
>>> fh = [1, 2, 3]
>>> y_pred = forecaster.predict(fh) # doctest: +SKIP
"""

_tags = {
Expand Down

0 comments on commit 628b717

Please sign in to comment.