
Enhance pytorch nn (microsoft#917)
* enhance pytorch_nn

* fix dim bug

* Black format

* Fix pylint error
you-n-g committed Feb 15, 2022
1 parent 25b0cc4 commit 74530be
Showing 13 changed files with 285 additions and 143 deletions.
5 changes: 5 additions & 0 deletions .pylintrc
@@ -0,0 +1,5 @@
[TYPECHECK]
# https://stackoverflow.com/a/53572939
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
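This whitelist is needed because pylint's static inference misses members that libraries like numpy and torch create at runtime, producing E1101 (no-member) false positives. A pure-Python illustration with a hypothetical class (torch itself is not required):

```python
# Members assigned dynamically are invisible to static analysis, so
# accessing them can trigger pylint's E1101 (no-member) false positive.
# This is the issue the generated-members whitelist above suppresses
# for numpy.* and torch.*.
class DynamicNamespace:
    def __init__(self):
        # attribute created at runtime, not declared in the class body
        setattr(self, "float32", "f4")

ns = DynamicNamespace()
value = ns.float32  # correct at runtime; pylint cannot prove it exists
```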
12 changes: 11 additions & 1 deletion docs/developer/code_standard.rst
@@ -14,9 +14,19 @@ Continuous Integration (CI) tools help you stick to the quality standards by run

When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.

1. Qlib will check the code format with black. The PR will raise an error if your code does not conform to Qlib's standard (e.g. a common error is the mixed use of spaces and tabs).
   You can fix the formatting by running the following commands:

.. code-block:: bash

    pip install black
    python -m black . -l 120
2. Qlib will check your code style with pylint. The checking command is implemented in the `github action workflow <https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66>`_.
   Sometimes pylint's restrictions are not entirely reasonable. You can ignore specific errors like this:

.. code-block:: python

    return -ICLoss()(pred, target, index)  # pylint: disable=E1130
4 changes: 2 additions & 2 deletions examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml
@@ -63,8 +63,6 @@ task:
module_path: qlib.contrib.model.pytorch_nn
kwargs:
loss: mse
input_dim: 157
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
@@ -73,6 +71,8 @@ task:
batch_size: 8192
GPU: 0
weight_decay: 0.0002
pt_model_kwargs:
input_dim: 157
dataset:
class: DatasetH
module_path: qlib.data.dataset
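The YAML change above moves the network-architecture option `input_dim` out of the trainer's flat kwargs and into a `pt_model_kwargs` dict. A hedged sketch of how such a dict can be forwarded verbatim to the network constructor (simplified hypothetical classes, not qlib's actual `DNNModelPytorch`):

```python
class Net:
    # stand-in for the underlying nn.Module
    def __init__(self, input_dim, output_dim=1):
        self.input_dim = input_dim
        self.output_dim = output_dim

class Trainer:
    # stand-in for the model wrapper configured by the YAML kwargs
    def __init__(self, loss="mse", lr=0.002, pt_model_kwargs=None):
        self.loss = loss
        self.lr = lr
        # build the network purely from the forwarded kwargs
        self.net = Net(**(pt_model_kwargs or {}))

trainer = Trainer(loss="mse", pt_model_kwargs={"input_dim": 157})
```

This keeps trainer options and network options in separate namespaces, so new architecture parameters need no changes to the trainer signature.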
4 changes: 2 additions & 2 deletions examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml
@@ -51,15 +51,15 @@ task:
module_path: qlib.contrib.model.pytorch_nn
kwargs:
loss: mse
input_dim: 360
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096
GPU: 0
pt_model_kwargs:
input_dim: 360
dataset:
class: DatasetH
module_path: qlib.data.dataset
3 changes: 3 additions & 0 deletions qlib/contrib/meta/data_selection/utils.py
@@ -9,6 +9,9 @@
class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50):
"""forward.
FIXME:
- Sometimes the result will differ slightly from `pandas.corr()`
- This may be caused by floating-point precision in the model;
:param pred:
:param y:
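The FIXME above concerns the IC (information coefficient), the Pearson correlation between predictions and labels; small deviations from `pandas.corr()` can come from reduced floating-point precision in the model. A plain-Python sketch of the quantity being compared (not qlib's tensor implementation):

```python
import math

def pearson_ic(pred, y):
    # Pearson correlation between predictions and labels; an ICLoss-style
    # objective returns its negative so minimizing the loss maximizes IC.
    n = len(pred)
    mean_p = sum(pred) / n
    mean_y = sum(y) / n
    cov = sum((p - mean_p) * (t - mean_y) for p, t in zip(pred, y))
    std_p = math.sqrt(sum((p - mean_p) ** 2 for p in pred))
    std_y = math.sqrt(sum((t - mean_y) ** 2 for t in y))
    return cov / (std_p * std_y)
```

On perfectly correlated inputs this returns exactly 1.0 in double precision, while the same computation on float32 tensors can drift in the last few bits, which is consistent with the FIXME note.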
14 changes: 10 additions & 4 deletions qlib/contrib/model/gbdt.py
@@ -10,6 +10,7 @@
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt
from ...data.dataset.weight import Reweighter
from qlib.workflow import R


class LGBModel(ModelFT, LightGBMFInt):
@@ -59,10 +60,12 @@ def fit(
num_boost_round=None,
early_stopping_rounds=None,
verbose_eval=20,
evals_result=dict(),
evals_result=None,
reweighter=None,
**kwargs
**kwargs,
):
if evals_result is None:
evals_result = {}  # avoid sharing a mutable default value across calls
ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l))
self.model = lgb.train(
@@ -76,10 +79,13 @@
),
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
**kwargs,
)
for k in names:
evals_result[k] = list(evals_result[k].values())[0]
for key, val in evals_result[k].items():
name = f"{key}.{k}"
for epoch, m in enumerate(val):
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
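The loop above walks LightGBM's `evals_result` structure ({dataset_name: {metric_name: [value per iteration]}}) and logs each value as a per-step metric, replacing `@` in metric names. A runnable sketch with a stand-in for `R.log_metrics` (the recorder here is hypothetical, not qlib's actual `R`):

```python
logged = []

def log_metrics(step=None, **metrics):
    # stand-in for qlib's R.log_metrics: record metrics with their step
    logged.append((step, metrics))

# LightGBM-style shape: {dataset_name: {metric_name: [per-iteration values]}}
evals_result = {"valid": {"l2@97": [0.50, 0.42]}}
for k in ["valid"]:
    for key, val in evals_result[k].items():
        name = f"{key}.{k}"
        for epoch, m in enumerate(val):
            log_metrics(**{name.replace("@", "_"): m}, step=epoch)
```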

def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
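The `evals_result=None` change in the diff above fixes a classic Python pitfall: default values are evaluated once, at function definition time, so `evals_result=dict()` would share a single dict across every call to `fit`. A minimal illustration:

```python
def bad(results=[]):
    # the default list is created once and shared by every call
    results.append(1)
    return results

def good(results=None):
    # a fresh list per call, mirroring the fix in gbdt.py
    if results is None:
        results = []
    results.append(1)
    return results
```

The second call to `bad()` returns `[1, 1]` because the first call's append persists in the shared default, while `good()` returns `[1]` every time.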
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_gats.py
@@ -263,8 +263,8 @@ def fit(

model_dict = self.GAT_model.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
} # pylint: disable=E1135
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
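The reformatted comprehension above is a standard partial-weights load: keep only pretrained entries whose keys exist in the target model's state dict, then merge them in before `load_state_dict`. With plain dicts standing in for `state_dict()` tensors (a sketch, not the GATs code itself; the key names are hypothetical):

```python
# target model's parameters (values stand in for tensors)
model_dict = {"gat.weight": 0.0, "head.weight": 0.0}
# pretrained checkpoint with one extra, incompatible entry
pretrained_state = {"gat.weight": 1.0, "old_head.bias": 9.0}

# keep only keys the target model actually has
pretrained_dict = {k: v for k, v in pretrained_state.items() if k in model_dict}
model_dict.update(pretrained_dict)
```

After the update, `gat.weight` comes from the checkpoint while `head.weight` keeps its freshly initialized value, so the subsequent `load_state_dict(model_dict)` sees no missing or unexpected keys.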
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_gats_ts.py
@@ -278,8 +278,8 @@ def fit(

model_dict = self.GAT_model.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
} # pylint: disable=E1135
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
}
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")
