diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d50066a..11b73fd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer exclude: "setup.cfg" @@ -48,7 +48,7 @@ repos: ) - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.2 + rev: v0.6.4 hooks: - id: ruff args: ["--fix"] @@ -56,7 +56,7 @@ repos: - id: ruff - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + rev: v3.1.0 hooks: - id: prettier files: \.(json|yml|yaml|toml) diff --git a/examples/PyTorch Tabular with Bank Marketing Dataset.ipynb b/examples/PyTorch Tabular with Bank Marketing Dataset.ipynb index b2fa4273..28cf0d46 100644 --- a/examples/PyTorch Tabular with Bank Marketing Dataset.ipynb +++ b/examples/PyTorch Tabular with Bank Marketing Dataset.ipynb @@ -8,10 +8,9 @@ "outputs": [], "source": [ "import numpy as np\n", - "import pandas as pd\n", "from sklearn.datasets import fetch_openml\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import accuracy_score, log_loss" + "from sklearn.metrics import accuracy_score, log_loss\n", + "from sklearn.model_selection import train_test_split" ] }, { @@ -55,12 +54,23 @@ "metadata": {}, "outputs": [], "source": [ - "cat_cols = ['job', 'marital', 'education', 'default', 'housing',\n", - " 'loan', 'contact', 'day', 'month', 'campaign',\n", - " 'previous', 'poutcome']\n", + "cat_cols = [\n", + " \"job\",\n", + " \"marital\",\n", + " \"education\",\n", + " \"default\",\n", + " \"housing\",\n", + " \"loan\",\n", + " \"contact\",\n", + " \"day\",\n", + " \"month\",\n", + " \"campaign\",\n", + " \"previous\",\n", + " \"poutcome\",\n", + "]\n", "\n", - "num_cols = ['age', 'balance', 'duration', 'pdays']\n", - "target=[\"y\"]" + "num_cols = [\"age\", \"balance\", \"duration\", \"pdays\"]\n", + "target = [\"y\"]" ] }, { @@ -96,8 +106,8 @@ "test_enc = test.copy()\n", "for col in cat_cols:\n", " enc = OrdinalEncoder(handle_unknown=\"use_encoded_value\", encoded_missing_value=np.nan, unknown_value=np.nan)\n", - " train_enc[col] = enc.fit_transform(train_enc[col].values.reshape(-1,1))\n", - " test_enc[col] = enc.transform(test_enc[col].values.reshape(-1,1))" + " train_enc[col] = enc.fit_transform(train_enc[col].values.reshape(-1, 1))\n", + " test_enc[col] = enc.transform(test_enc[col].values.reshape(-1, 1))" ] }, { @@ -153,15 +163,15 @@ "outputs": [], "source": [ "from pytorch_tabular import TabularModel\n", + "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n", "from pytorch_tabular.models import (\n", - " CategoryEmbeddingModelConfig, \n", - " FTTransformerConfig, \n", - " TabNetModelConfig, \n", - " GatedAdditiveTreeEnsembleConfig, \n", - " TabTransformerConfig, \n", - " AutoIntConfig\n", + " AutoIntConfig,\n", + " CategoryEmbeddingModelConfig,\n", + " FTTransformerConfig,\n", + " GatedAdditiveTreeEnsembleConfig,\n", + " TabNetModelConfig,\n", + " TabTransformerConfig,\n", ")\n", - "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n", "from pytorch_tabular.models.common.heads import LinearHeadConfig" ] }, @@ -183,29 +193,29 @@ "outputs": [], "source": [ "data_config = DataConfig(\n", - " target=target, #target should always be a list.\n", + " target=target, # target should always be a list.\n", " continuous_cols=num_cols,\n", " categorical_cols=cat_cols,\n", ")\n", "\n", "trainer_config = TrainerConfig(\n", - "# auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n", + " # auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n", " batch_size=256,\n", " max_epochs=500,\n", - " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n", - " early_stopping_mode = \"min\", # Set the mode as min because for val_loss, lower is better\n", - " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n", - " checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n", - " load_best=True, # After training, load the best checkpoint\n", + " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n", + " early_stopping_mode=\"min\", # Set the mode as min because for val_loss, lower is better\n", + " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n", + " checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n", + " load_best=True, # After training, load the best checkpoint\n", ")\n", "\n", "optimizer_config = OptimizerConfig()\n", "\n", "head_config = LinearHeadConfig(\n", - " layers=\"\", # No additional layer in head, just a mapping layer to output_dim\n", + " layers=\"\", # No additional layer in head, just a mapping layer to output_dim\n", " dropout=0.1,\n", - " initialization=\"kaiming\"\n", - ").__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)" + " initialization=\"kaiming\",\n", + ").__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)" ] }, { @@ -442,10 +452,10 @@ "model_config = CategoryEmbeddingModelConfig(\n", " task=\"classification\",\n", " layers=\"64-32\", # Number of nodes in each layer\n", - " activation=\"ReLU\", # Activation between each layers\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " activation=\"ReLU\", # Activation between each layers\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -455,7 +465,7 @@ " trainer_config=trainer_config,\n", ")\n", "tabular_model.fit(train=train)\n", - "tabular_model.evaluate(test)\n" + "tabular_model.evaluate(test)" ] }, { @@ -709,9 +719,9 @@ "source": [ "model_config = GatedAdditiveTreeEnsembleConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -983,13 +993,13 @@ "source": [ "model_config = GatedAdditiveTreeEnsembleConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", " gflu_stages=4,\n", " num_trees=30,\n", " tree_depth=5,\n", - " chain_trees=False\n", + " chain_trees=False,\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -1265,9 +1275,9 @@ "source": [ "model_config = FTTransformerConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -1543,9 +1553,9 @@ "source": [ "model_config = TabTransformerConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -1819,9 +1829,9 @@ "source": [ "model_config = AutoIntConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", @@ -2095,9 +2105,9 @@ "source": [ "model_config = TabNetModelConfig(\n", " task=\"classification\",\n", - " learning_rate = 1e-3,\n", - " head = \"LinearHead\", #Linear Head\n", - " head_config = head_config, # Linear Head Config\n", + " learning_rate=1e-3,\n", + " head=\"LinearHead\", # Linear Head\n", + " head_config=head_config, # Linear Head Config\n", ")\n", "\n", "tabular_model = TabularModel(\n", diff --git a/examples/__only_for_dev__/to_test_classification.py b/examples/__only_for_dev__/to_test_classification.py index 7d1fe855..0884582c 100644 --- a/examples/__only_for_dev__/to_test_classification.py +++ b/examples/__only_for_dev__/to_test_classification.py @@ -1,6 +1,7 @@ from pathlib import Path import pandas as pd +from sklearn.model_selection import train_test_split # from torch.utils import data from pytorch_tabular.config import DataConfig, ExperimentConfig, OptimizerConfig, TrainerConfig @@ -9,7 +10,6 @@ # import wget from pytorch_tabular.utils import get_class_weighted_cross_entropy -from sklearn.model_selection import train_test_split # torch.manual_seed(0) # np.random.seed(0) diff --git a/examples/__only_for_dev__/to_test_node.py b/examples/__only_for_dev__/to_test_node.py index 22e0fb94..722c4743 100644 --- a/examples/__only_for_dev__/to_test_node.py +++ b/examples/__only_for_dev__/to_test_node.py @@ -3,10 +3,11 @@ import numpy as np import pandas as pd +from sklearn.datasets import fetch_california_housing, fetch_covtype + from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models.node import NodeConfig from pytorch_tabular.tabular_model import TabularModel -from sklearn.datasets import fetch_california_housing, fetch_covtype def regression_data(): diff --git a/examples/__only_for_dev__/to_test_regression.py b/examples/__only_for_dev__/to_test_regression.py index badf12b9..56c306dc 100644 --- a/examples/__only_for_dev__/to_test_regression.py +++ b/examples/__only_for_dev__/to_test_regression.py @@ -1,9 +1,10 @@ import pandas as pd import torch +from sklearn.datasets import fetch_california_housing + from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models.category_embedding.config import CategoryEmbeddingModelConfig from pytorch_tabular.tabular_model import TabularModel -from sklearn.datasets import fetch_california_housing # from pytorch_tabular.models.mixture_density import ( # CategoryEmbeddingMDNConfig, diff --git a/examples/__only_for_dev__/to_test_regression_custom_models.py b/examples/__only_for_dev__/to_test_regression_custom_models.py index 21e20898..9adb8ed0 100644 --- a/examples/__only_for_dev__/to_test_regression_custom_models.py +++ b/examples/__only_for_dev__/to_test_regression_custom_models.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn from omegaconf import DictConfig +from sklearn.datasets import fetch_california_housing + from pytorch_tabular.config import DataConfig, ModelConfig, OptimizerConfig, TrainerConfig # from pytorch_tabular.models.deep_gmm import ( @@ -14,7 +16,6 @@ # from pytorch_tabular.models.node import utils as utils from pytorch_tabular.tabular_model import TabularModel -from sklearn.datasets import fetch_california_housing @dataclass diff --git a/examples/covertype_classification.py b/examples/covertype_classification.py index c11272b2..2cd2a5b7 100644 --- a/examples/covertype_classification.py +++ b/examples/covertype_classification.py @@ -2,11 +2,12 @@ import pandas as pd import wget +from sklearn.model_selection import train_test_split + from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import CategoryEmbeddingModelConfig from pytorch_tabular.models.common.heads import LinearHeadConfig from pytorch_tabular.tabular_model import TabularModel -from sklearn.model_selection import train_test_split BASE_DIR = Path.home().joinpath("data") datafile = BASE_DIR.joinpath("covtype.data.gz") diff --git a/examples/covertype_classification_using_yaml.py b/examples/covertype_classification_using_yaml.py index c106e2eb..6e4bdd93 100644 --- a/examples/covertype_classification_using_yaml.py +++ b/examples/covertype_classification_using_yaml.py @@ -2,9 +2,10 @@ import pandas as pd import wget -from pytorch_tabular.tabular_model import TabularModel from sklearn.model_selection import train_test_split +from pytorch_tabular.tabular_model import TabularModel + BASE_DIR = Path.home().joinpath("data") datafile = BASE_DIR.joinpath("covtype.data.gz") datafile.parent.mkdir(parents=True, exist_ok=True) diff --git a/setup.py b/setup.py index 20de6fee..60e38980 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """The setup script.""" + import os from setuptools import find_packages, setup diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index 6c7e5823..8e8006c6 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -3,6 +3,7 @@ # For license information, see LICENSE.TXT # Modified https://github.com/tcassou/mlencoders/blob/master/mlencoders/base_encoder.py to suit NN encoding """Category Encoders.""" + from pandas import DataFrame, Series, unique try: diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 8b0c6035..7df8180f 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Config.""" + import os import re from dataclasses import MISSING, dataclass, field diff --git a/src/pytorch_tabular/models/autoint/autoint.py b/src/pytorch_tabular/models/autoint/autoint.py index 5d1a3c52..10105bae 100644 --- a/src/pytorch_tabular/models/autoint/autoint.py +++ b/src/pytorch_tabular/models/autoint/autoint.py @@ -3,6 +3,7 @@ # For license information, see LICENSE.TXT # Inspired by https://github.com/rixwew/pytorch-fm/blob/master/torchfm/model/afi.py """AutomaticFeatureInteraction Model.""" + import torch import torch.nn as nn from omegaconf import DictConfig diff --git a/src/pytorch_tabular/models/autoint/config.py b/src/pytorch_tabular/models/autoint/config.py index fbe84ef2..511b44d3 100644 --- a/src/pytorch_tabular/models/autoint/config.py +++ b/src/pytorch_tabular/models/autoint/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """AutomaticFeatureInteraction Config.""" + from dataclasses import dataclass, field from typing import Optional diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index 94993c76..1328d34c 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Base Model.""" + import importlib import warnings from abc import ABCMeta, abstractmethod diff --git a/src/pytorch_tabular/models/category_embedding/category_embedding_model.py b/src/pytorch_tabular/models/category_embedding/category_embedding_model.py index c54d866a..4e6562f6 100644 --- a/src/pytorch_tabular/models/category_embedding/category_embedding_model.py +++ b/src/pytorch_tabular/models/category_embedding/category_embedding_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Category Embedding Model.""" + import torch import torch.nn as nn from omegaconf import DictConfig diff --git a/src/pytorch_tabular/models/category_embedding/config.py b/src/pytorch_tabular/models/category_embedding/config.py index 02918135..99b77b29 100644 --- a/src/pytorch_tabular/models/category_embedding/config.py +++ b/src/pytorch_tabular/models/category_embedding/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Category Embedding Model Config.""" + from dataclasses import dataclass, field from pytorch_tabular.config import ModelConfig diff --git a/src/pytorch_tabular/models/common/layers/soft_trees.py b/src/pytorch_tabular/models/common/layers/soft_trees.py index e54b3876..921f5c3e 100644 --- a/src/pytorch_tabular/models/common/layers/soft_trees.py +++ b/src/pytorch_tabular/models/common/layers/soft_trees.py @@ -177,13 +177,12 @@ def initialize(self, input, eps=1e-6): self.log_temperatures.data[...] = torch.log(torch.as_tensor(temperatures) + eps) def __repr__(self): - return "{}(in_features={}, num_trees={}, depth={}, tree_dim={}, flatten_output={})".format( - self.__class__.__name__, - self.feature_selection_logits.shape[0], - self.num_trees, - self.depth, - self.tree_dim, - self.flatten_output, + return ( + f"{self.__class__.__name__}(in_features={self.feature_selection_logits.shape[0]}," + f" num_trees={self.num_trees}," + f" depth={self.depth}," + f" tree_dim={self.tree_dim}," + f" flatten_output={self.flatten_output})" ) diff --git a/src/pytorch_tabular/models/danet/config.py b/src/pytorch_tabular/models/danet/config.py index aea4f12f..13978296 100644 --- a/src/pytorch_tabular/models/danet/config.py +++ b/src/pytorch_tabular/models/danet/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """AutomaticFeatureInteraction Config.""" + from dataclasses import dataclass, field from typing import Optional diff --git a/src/pytorch_tabular/models/ft_transformer/config.py b/src/pytorch_tabular/models/ft_transformer/config.py index a30418f8..3697da51 100644 --- a/src/pytorch_tabular/models/ft_transformer/config.py +++ b/src/pytorch_tabular/models/ft_transformer/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """AutomaticFeatureInteraction Config.""" + from dataclasses import dataclass, field from typing import Optional diff --git a/src/pytorch_tabular/models/ft_transformer/ft_transformer.py b/src/pytorch_tabular/models/ft_transformer/ft_transformer.py index a78e671e..1920aa96 100644 --- a/src/pytorch_tabular/models/ft_transformer/ft_transformer.py +++ b/src/pytorch_tabular/models/ft_transformer/ft_transformer.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Feature Tokenizer Transformer Model.""" + from collections import OrderedDict import torch diff --git a/src/pytorch_tabular/models/gandalf/config.py b/src/pytorch_tabular/models/gandalf/config.py index ebfd5359..3b3c4883 100644 --- a/src/pytorch_tabular/models/gandalf/config.py +++ b/src/pytorch_tabular/models/gandalf/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """AutomaticFeatureInteraction Config.""" + from dataclasses import dataclass, field from pytorch_tabular.config import ModelConfig diff --git a/src/pytorch_tabular/models/gate/config.py b/src/pytorch_tabular/models/gate/config.py index bedf13d0..b8ba9729 100644 --- a/src/pytorch_tabular/models/gate/config.py +++ b/src/pytorch_tabular/models/gate/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """GatedAdditiveTreeEnsembleConfig Config.""" + from dataclasses import dataclass, field from pytorch_tabular.config import ModelConfig diff --git a/src/pytorch_tabular/models/mixture_density/config.py b/src/pytorch_tabular/models/mixture_density/config.py index 4fb4734a..428e5871 100644 --- a/src/pytorch_tabular/models/mixture_density/config.py +++ b/src/pytorch_tabular/models/mixture_density/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Mixture Density Head Config.""" + from dataclasses import dataclass, field from typing import Dict diff --git a/src/pytorch_tabular/models/mixture_density/mdn.py b/src/pytorch_tabular/models/mixture_density/mdn.py index b9caab41..6ae02db1 100644 --- a/src/pytorch_tabular/models/mixture_density/mdn.py +++ b/src/pytorch_tabular/models/mixture_density/mdn.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Mixture Density Models.""" + from typing import Dict, Optional, Union import torch diff --git a/src/pytorch_tabular/models/node/architecture_blocks.py b/src/pytorch_tabular/models/node/architecture_blocks.py index 22440bb1..c9d8059b 100644 --- a/src/pytorch_tabular/models/node/architecture_blocks.py +++ b/src/pytorch_tabular/models/node/architecture_blocks.py @@ -3,6 +3,7 @@ # https://github.com/Qwicen/node # For license information, see https://github.com/Qwicen/node/blob/master/LICENSE.md """Dense ODST Block.""" + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/src/pytorch_tabular/models/node/node_model.py b/src/pytorch_tabular/models/node/node_model.py index ab6ed1df..c774bb1a 100644 --- a/src/pytorch_tabular/models/node/node_model.py +++ b/src/pytorch_tabular/models/node/node_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Tabular Model.""" + import warnings import torch diff --git a/src/pytorch_tabular/models/tab_transformer/config.py b/src/pytorch_tabular/models/tab_transformer/config.py index 8b45ead7..d38986a5 100644 --- a/src/pytorch_tabular/models/tab_transformer/config.py +++ b/src/pytorch_tabular/models/tab_transformer/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """AutomaticFeatureInteraction Config.""" + from dataclasses import dataclass, field from typing import Optional diff --git a/src/pytorch_tabular/models/tab_transformer/tab_transformer.py b/src/pytorch_tabular/models/tab_transformer/tab_transformer.py index 24fc91b7..da12d833 100644 --- a/src/pytorch_tabular/models/tab_transformer/tab_transformer.py +++ b/src/pytorch_tabular/models/tab_transformer/tab_transformer.py @@ -12,6 +12,7 @@ # 4. LabML Annotated Deep Learning Papers - The position-wise FF was shamelessly copied from # https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers """TabTransformer Model.""" + from collections import OrderedDict from typing import Dict diff --git a/src/pytorch_tabular/models/tabnet/config.py b/src/pytorch_tabular/models/tabnet/config.py index 83253d78..ade0c6a0 100644 --- a/src/pytorch_tabular/models/tabnet/config.py +++ b/src/pytorch_tabular/models/tabnet/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Tabnet Model Config.""" + from dataclasses import dataclass, field from typing import List, Optional diff --git a/src/pytorch_tabular/models/tabnet/tabnet_model.py b/src/pytorch_tabular/models/tabnet/tabnet_model.py index 672b7c24..a11a117e 100644 --- a/src/pytorch_tabular/models/tabnet/tabnet_model.py +++ b/src/pytorch_tabular/models/tabnet/tabnet_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """TabNet Model.""" + from typing import Dict import torch diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index 19720cde..7db2b226 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """SSL Base Model.""" + import warnings from abc import ABCMeta, abstractmethod from typing import Dict, Optional diff --git a/src/pytorch_tabular/ssl_models/common/heads.py b/src/pytorch_tabular/ssl_models/common/heads.py index b7656195..4a0b8351 100644 --- a/src/pytorch_tabular/ssl_models/common/heads.py +++ b/src/pytorch_tabular/ssl_models/common/heads.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """SSL Heads.""" + import torch.nn as nn diff --git a/src/pytorch_tabular/ssl_models/common/noise_generators.py b/src/pytorch_tabular/ssl_models/common/noise_generators.py index bd4cb563..2da372b4 100644 --- a/src/pytorch_tabular/ssl_models/common/noise_generators.py +++ b/src/pytorch_tabular/ssl_models/common/noise_generators.py @@ -3,6 +3,7 @@ # For license information, see LICENSE.TXT # Inspired by implementation https://github.com/ryancheunggit/tabular_dae """DenoisingAutoEncoder Model.""" + import numpy as np import torch import torch.nn as nn diff --git a/src/pytorch_tabular/ssl_models/common/utils.py b/src/pytorch_tabular/ssl_models/common/utils.py index 35524c99..629d83ba 100644 --- a/src/pytorch_tabular/ssl_models/common/utils.py +++ b/src/pytorch_tabular/ssl_models/common/utils.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Utilities.""" + import torch.nn as nn import torch.nn.functional as F diff --git a/src/pytorch_tabular/ssl_models/dae/config.py b/src/pytorch_tabular/ssl_models/dae/config.py index b0508924..b1f74885 100644 --- a/src/pytorch_tabular/ssl_models/dae/config.py +++ b/src/pytorch_tabular/ssl_models/dae/config.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """DenoisingAutoEncoder Config.""" + from dataclasses import dataclass, field from typing import Dict, List, Optional diff --git a/src/pytorch_tabular/ssl_models/dae/dae.py b/src/pytorch_tabular/ssl_models/dae/dae.py index 8f2885e9..172586c4 100644 --- a/src/pytorch_tabular/ssl_models/dae/dae.py +++ b/src/pytorch_tabular/ssl_models/dae/dae.py @@ -3,6 +3,7 @@ # For license information, see LICENSE.TXT # Inspired by implementation https://github.com/ryancheunggit/tabular_dae """DenoisingAutoEncoder Model.""" + from collections import namedtuple from typing import Dict diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 2150d3d9..917bc931 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Tabular Data Module.""" + import re import warnings from enum import Enum diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 92187f29..900aa904 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Tabular Model.""" + import inspect import os import warnings diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py index 31e60bcb..9792ab61 100644 --- a/src/pytorch_tabular/tabular_model_tuner.py +++ b/src/pytorch_tabular/tabular_model_tuner.py @@ -2,6 +2,7 @@ # Author: Manu Joseph # For license information, see LICENSE.TXT """Tabular Model.""" + import warnings from collections import namedtuple from copy import deepcopy diff --git a/tests/___test_augmentations.py b/tests/___test_augmentations.py index a209ff1e..42e60b75 100644 --- a/tests/___test_augmentations.py +++ b/tests/___test_augmentations.py @@ -1,5 +1,6 @@ import numpy as np import torch + from pytorch_tabular.ssl_models.common.augmentations import _get_random_index, cutmix, mixup diff --git a/tests/test_autoint.py b/tests/test_autoint.py index 166dd8f7..025b0ed6 100644 --- a/tests/test_autoint.py +++ b/tests/test_autoint.py @@ -2,6 +2,7 @@ """Tests for `pytorch_tabular` package.""" import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import AutoIntConfig diff --git a/tests/test_categorical_embedding.py b/tests/test_categorical_embedding.py index efce3742..cce097ba 100644 --- a/tests/test_categorical_embedding.py +++ b/tests/test_categorical_embedding.py @@ -4,11 +4,12 @@ import numpy as np import pytest import torch +from sklearn.preprocessing import PowerTransformer + from pytorch_tabular import TabularModel from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import CategoryEmbeddingModelConfig -from sklearn.preprocessing import PowerTransformer def fake_metric(y_hat, y): diff --git a/tests/test_common.py b/tests/test_common.py index daf77872..5f7c4922 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,11 +1,16 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import copy import os import numpy as np import pytest import torch +from scipy.stats import uniform +from sklearn.metrics import accuracy_score, r2_score +from sklearn.model_selection import KFold + from pytorch_tabular import TabularModel, TabularModelTuner, model_sweep from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.config.config import SSLModelConfig @@ -20,9 +25,6 @@ TabNetModelConfig, ) from pytorch_tabular.ssl_models import DenoisingAutoEncoderConfig -from scipy.stats import uniform -from sklearn.metrics import accuracy_score, r2_score -from sklearn.model_selection import KFold # import os @@ -408,11 +410,12 @@ def test_save_for_inference( ) sv_dir = tmpdir.mkdir("saved_model") + model_name = "model.pt" if save_type == "pytorch" else "model.onnx" tabular_model.save_model_for_inference( - sv_dir / "model.pt" if type == "pytorch" else sv_dir / "model.onnx", + sv_dir / model_name, kind=save_type, ) - assert os.path.exists(sv_dir / "model.pt" if type == "pytorch" else sv_dir / "model.onnx") + assert os.path.exists(sv_dir / model_name) @pytest.mark.parametrize("model_config_class", MODEL_CONFIG_FEATURE_EXT_TEST) diff --git a/tests/test_danet.py b/tests/test_danet.py index dc01ecd7..f65ecd59 100644 --- a/tests/test_danet.py +++ b/tests/test_danet.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import DANetConfig diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index cc859af2..4285be95 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -1,13 +1,15 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import numpy as np import pytest +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import PowerTransformer + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import CategoryEmbeddingModelConfig from pytorch_tabular.tabular_datamodule import TabularDatamodule -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import PowerTransformer @pytest.mark.parametrize("multi_target", [True, False]) diff --git a/tests/test_ft_transformer.py b/tests/test_ft_transformer.py index f83b5b58..6fc626e3 100644 --- a/tests/test_ft_transformer.py +++ b/tests/test_ft_transformer.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig diff --git a/tests/test_gandalf.py b/tests/test_gandalf.py index 912cbd79..7a702d63 100644 --- a/tests/test_gandalf.py +++ b/tests/test_gandalf.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import GANDALFConfig diff --git a/tests/test_gate.py b/tests/test_gate.py index aed057f0..1dc17b8f 100644 --- a/tests/test_gate.py +++ b/tests/test_gate.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig diff --git a/tests/test_mdn.py b/tests/test_mdn.py index bd7fd546..ca78a7bc 100644 --- a/tests/test_mdn.py +++ b/tests/test_mdn.py @@ -2,6 +2,7 @@ """Tests for `pytorch_tabular` package.""" import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import MDNConfig diff --git a/tests/test_node.py b/tests/test_node.py index 31dcb06a..cfac47be 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig diff --git a/tests/test_ssl.py b/tests/test_ssl.py index aa92ac07..1e4b65dd 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,12 +1,14 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest import torch +from sklearn.model_selection import train_test_split + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import CategoryEmbeddingModelConfig from pytorch_tabular.ssl_models.dae import DenoisingAutoEncoderConfig -from sklearn.model_selection import train_test_split def fake_metric(y_hat, y): diff --git a/tests/test_tabnet.py b/tests/test_tabnet.py index 30b6ec99..cb135117 100644 --- a/tests/test_tabnet.py +++ b/tests/test_tabnet.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models import TabNetModelConfig diff --git a/tests/test_tabtransformer.py b/tests/test_tabtransformer.py index 2e0c1478..b0b64a93 100644 --- a/tests/test_tabtransformer.py +++ b/tests/test_tabtransformer.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Tests for `pytorch_tabular` package.""" + import pytest + from pytorch_tabular import TabularModel from pytorch_tabular.categorical_encoders import CategoricalEmbeddingTransformer from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig