-
Notifications
You must be signed in to change notification settings - Fork 468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[devops] Upgrade to lightning 2.0 #1514
Changes from 1 commit
26cb5bd
6cac2bb
35133f5
32146f7
4176a75
2fd30b1
329931a
1222282
b059658
5e0428d
293755e
65caa47
7f62f38
7cc25c9
7f71169
2b2a060
2001d8d
6ebb274
a0f921a
cadcc7b
64b4e17
c4773e9
cfbe96e
891ba53
0be90e7
c4037b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,15 +5,15 @@ | |
import types | ||
from collections import OrderedDict | ||
from dataclasses import dataclass, field | ||
from typing import Callable, Iterable, List, Optional | ||
from typing import Callable, List, Optional | ||
from typing import OrderedDict as OrderedDictType | ||
from typing import Type, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
|
||
from neuralprophet import df_utils, np_types, utils, utils_torch | ||
from neuralprophet import df_utils, np_types, utils_torch | ||
from neuralprophet.custom_loss_metrics import PinballLoss | ||
from neuralprophet.hdays_utils import get_holidays_from_country | ||
|
||
|
@@ -306,17 +306,17 @@ | |
self.trend_global_local = "global" | ||
|
||
# If trend_local_reg < 0 | ||
if self.trend_local_reg < 0: | ||
log.error("Invalid negative trend_local_reg '{}'. Set to False".format(self.trend_local_reg)) | ||
self.trend_local_reg = False | ||
|
||
# If trend_local_reg = True | ||
if self.trend_local_reg: | ||
if self.trend_local_reg is True: | ||
log.error("trend_local_reg = True. Default trend_local_reg value set to 1") | ||
self.trend_local_reg = 1 | ||
|
||
# If Trend modelling is global. | ||
if self.trend_global_local == "global" and self.trend_local_reg: | ||
if self.trend_global_local == "global" and self.trend_local_reg is True: | ||
log.error("Trend modeling is '{}'. Setting the trend_local_reg to False".format(self.trend_global_local)) | ||
self.trend_local_reg = False | ||
|
||
|
@@ -356,13 +356,13 @@ | |
log.error("Invalid global_local mode '{}'. Set to 'global'".format(self.global_local)) | ||
self.global_local = "global" | ||
|
||
self.periods = OrderedDict( | ||
{ | ||
"yearly": Season( | ||
resolution=6, | ||
period=365.25, | ||
arg=self.yearly_arg, | ||
global_local=( | ||
Check failure on line 365 in neuralprophet/configure.py
|
||
self.yearly_global_local | ||
if self.yearly_global_local in ["global", "local"] | ||
else self.global_local | ||
|
@@ -373,7 +373,7 @@ | |
resolution=3, | ||
period=7, | ||
arg=self.weekly_arg, | ||
global_local=( | ||
Check failure on line 376 in neuralprophet/configure.py
|
||
self.weekly_global_local | ||
if self.weekly_global_local in ["global", "local"] | ||
else self.global_local | ||
|
@@ -384,7 +384,7 @@ | |
resolution=6, | ||
period=1, | ||
arg=self.daily_arg, | ||
global_local=( | ||
Check failure on line 387 in neuralprophet/configure.py
|
||
self.daily_global_local if self.daily_global_local in ["global", "local"] else self.global_local | ||
), | ||
condition_name=None, | ||
|
@@ -393,17 +393,17 @@ | |
) | ||
|
||
# If seasonality_local_reg < 0 | ||
if self.seasonality_local_reg < 0: | ||
log.error("Invalid negative seasonality_local_reg '{}'. Set to False".format(self.seasonality_local_reg)) | ||
self.seasonality_local_reg = False | ||
|
||
# If seasonality_local_reg = True | ||
if self.seasonality_local_reg: | ||
if self.seasonality_local_reg is True: | ||
log.error("seasonality_local_reg = True. Default seasonality_local_reg value set to 1") | ||
self.seasonality_local_reg = 1 | ||
|
||
# If Season modelling is global. | ||
if self.global_local == "global" and self.seasonality_local_reg: | ||
if self.global_local == "global" and self.seasonality_local_reg is True: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, your change was actually correct before: |
||
log.error( | ||
"Seasonality modeling is '{}'. Setting the seasonality_local_reg to False".format(self.global_local) | ||
) | ||
|
@@ -414,7 +414,7 @@ | |
resolution=resolution, | ||
period=period, | ||
arg=arg, | ||
global_local=global_local if global_local in ["global", "local"] else self.global_local, | ||
Check failure on line 417 in neuralprophet/configure.py
|
||
condition_name=condition_name, | ||
) | ||
|
||
|
@@ -490,7 +490,7 @@ | |
regressors: OrderedDict = field(init=False) # contains RegressorConfig objects | ||
|
||
def __post_init__(self): | ||
self.regressors = None | ||
|
||
|
||
@dataclass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, your change was actually correct before:
if self.trend_global_local == "global" and self.trend_local_reg: