From f2c2780d415ea4c8a1af18e22ae8751a3944ca9d Mon Sep 17 00:00:00 2001 From: Manu Joseph Date: Sun, 24 Nov 2024 06:43:59 +0530 Subject: [PATCH 1/4] added informative repr, str and repr_html --- .gitignore | 1 + src/pytorch_tabular/tabular_model.py | 585 ++++++++++++++++++++++----- 2 files changed, 477 insertions(+), 109 deletions(-) diff --git a/.gitignore b/.gitignore index 79707d0f..ecb1e054 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ docs/tutorials/pytorch-tabular-covertype/ # Pycharm .idea/ +test.ipynb diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 11234934..472ffda6 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -3,7 +3,10 @@ # For license information, see LICENSE.TXT """Tabular Model.""" +import html import inspect +import json +import uuid import os import warnings from collections import defaultdict @@ -22,11 +25,14 @@ from pandas import DataFrame from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import RichProgressBar -from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler +from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( + GradientAccumulationScheduler, +) from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities.model_summary import summarize from rich import print as rich_print from rich.pretty import pprint +from pprint import pformat from sklearn.base import TransformerMixin from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold from torch import nn @@ -41,7 +47,11 @@ ) from pytorch_tabular.config.config import InferredConfig from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel -from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer +from pytorch_tabular.models.common.layers.embeddings import ( + Embedding1dLayer, + Embedding2dLayer, + PreEncoded1dLayer, +) from pytorch_tabular.tabular_datamodule import TabularDatamodule from pytorch_tabular.utils import ( OOMException, @@ -119,17 +129,23 @@ def __init__( self.verbose = verbose self.exp_manager = ExperimentRunManager() if config is None: - assert any(c is not None for c in (data_config, model_config, optimizer_config, trainer_config)), ( + assert any( + c is not None + for c in (data_config, model_config, optimizer_config, trainer_config) + ), ( "If `config` is None, `data_config`, `model_config`," " `trainer_config`, and `optimizer_config` cannot be None" ) data_config = self._read_parse_config(data_config, DataConfig) model_config = self._read_parse_config(model_config, ModelConfig) trainer_config = self._read_parse_config(trainer_config, TrainerConfig) - optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig) + optimizer_config = self._read_parse_config( + optimizer_config, OptimizerConfig + ) if model_config.task != "ssl": assert data_config.target is not None, ( - "`target` in data_config should not be None for" f" {model_config.task} task" + "`target` in data_config should not be None for" + f" {model_config.task} task" ) if experiment_config is None: if self.verbose: @@ -142,7 +158,9 @@ def __init__( OmegaConf.to_container(optimizer_config), ) else: - experiment_config = self._read_parse_config(experiment_config, ExperimentConfig) + experiment_config = self._read_parse_config( + experiment_config, ExperimentConfig + ) self.track_experiment = True 
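# --- Illustrative sketch (hedged, not taken from this patch): how the four configs that the
# --- __init__ logic above validates and parses are typically constructed and passed in.
# --- The target/column names and the CategoryEmbeddingModelConfig choice are assumptions
# --- for illustration only; they mirror the test added later in this patch series.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

data_config = DataConfig(
    target=["target"],                   # hypothetical target column
    continuous_cols=["num_a", "num_b"],  # hypothetical continuous features
    categorical_cols=["cat_a"],          # hypothetical categorical feature
)
model_config = CategoryEmbeddingModelConfig(task="regression", layers="10-20")
trainer_config = TrainerConfig(max_epochs=3, accelerator="cpu", fast_dev_run=True)
optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)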
self.config = OmegaConf.merge( OmegaConf.to_container(data_config), @@ -169,7 +187,9 @@ def __init__( self.exp_manager = ExperimentRunManager() if model_callable is None: - self.model_callable = getattr_nested(self.config._module_src, self.config._model_name) + self.model_callable = getattr_nested( + self.config._module_src, self.config._model_name + ) self.custom_model = False else: self.model_callable = model_callable @@ -216,7 +236,9 @@ def _run_validation(self): if ( (len(self.config.target_range) != len(self.config.target)) or any(len(range_) != 2 for range_ in self.config.target_range) - or any(range_[0] > range_[1] for range_ in self.config.target_range) + or any( + range_[0] > range_[1] for range_ in self.config.target_range + ) ): raise ValueError( "Targe Range, if defined, should be list tuples of length" @@ -234,7 +256,8 @@ def _read_parse_config(self, config, cls): **{ k: v for k, v in _config.items() - if (k in cls.__dataclass_fields__.keys()) and (cls.__dataclass_fields__[k].init) + if (k in cls.__dataclass_fields__.keys()) + and (cls.__dataclass_fields__[k].init) } ) else: @@ -251,7 +274,10 @@ def _get_run_name_uid(self) -> Tuple[str, int]: """ if hasattr(self.config, "run_name") and self.config.run_name is not None: name = self.config.run_name - elif hasattr(self.config, "checkpoints_name") and self.config.checkpoints_name is not None: + elif ( + hasattr(self.config, "checkpoints_name") + and self.config.checkpoints_name is not None + ): name = self.config.checkpoints_name else: name = self.config.task @@ -262,7 +288,9 @@ def _setup_experiment_tracking(self): """Sets up the Experiment Tracking Framework according to the choices made in the Experimentconfig.""" if self.config.log_target == "tensorboard": self.logger = pl.loggers.TensorBoardLogger( - name=self.run_name, save_dir=self.config.project_name, version=self.uid + name=self.run_name, + save_dir=self.config.project_name, + version=self.uid, ) elif self.config.log_target == "wandb": self.logger = pl.loggers.WandbLogger( @@ -272,7 +300,8 @@ def _setup_experiment_tracking(self): ) else: raise NotImplementedError( - f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]" + f"{self.config.log_target} is not implemented. Try one of [wandb," + " tensorboard]" ) def _prepare_callbacks(self, callbacks=None) -> List: @@ -308,13 +337,17 @@ def _prepare_callbacks(self, callbacks=None) -> List: self.config.enable_checkpointing = True else: self.config.enable_checkpointing = False - if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True): + if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get( + "enable_progress_bar", True + ): callbacks.append(RichProgressBar()) if self.verbose: logger.debug(f"Callbacks used: {callbacks}") return callbacks - def _prepare_trainer(self, callbacks: List, max_epochs: int = None, min_epochs: int = None) -> pl.Trainer: + def _prepare_trainer( + self, callbacks: List, max_epochs: int = None, min_epochs: int = None + ) -> pl.Trainer: """Prepares the Trainer object. 
Args: @@ -335,11 +368,15 @@ def _prepare_trainer(self, callbacks: List, max_epochs: int = None, min_epochs: # Getting Trainer Arguments from the init signature trainer_sig = inspect.signature(pl.Trainer.__init__) trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"] - trainer_args_config = {k: v for k, v in self.config.items() if k in trainer_args} + trainer_args_config = { + k: v for k, v in self.config.items() if k in trainer_args + } # For some weird reason, checkpoint_callback is not appearing in the Trainer vars trainer_args_config["enable_checkpointing"] = self.config.enable_checkpointing # turn off progress bar if progress_bar=='none' - trainer_args_config["enable_progress_bar"] = self.config.progress_bar != "none" + trainer_args_config["enable_progress_bar"] = ( + self.config.progress_bar != "none" + ) # Adding trainer_kwargs from config to trainer_args trainer_args_config.update(self.config.trainer_kwargs) if trainer_args_config["devices"] == -1: @@ -362,14 +399,20 @@ def _check_and_set_target_transform(self, target_transform): pass else: raise ValueError( - "`target_transform` should wither be an sklearn Transformer or a" " tuple of callables." + "`target_transform` should wither be an sklearn Transformer or a" + " tuple of callables." ) if self.config.task == "classification" and target_transform is not None: - logger.warning("For classification task, target transform is not used. Ignoring the" " parameter") + logger.warning( + "For classification task, target transform is not used. Ignoring the" + " parameter" + ) target_transform = None return target_transform - def _prepare_for_training(self, model, datamodule, callbacks=None, max_epochs=None, min_epochs=None): + def _prepare_for_training( + self, model, datamodule, callbacks=None, max_epochs=None, min_epochs=None + ): self.callbacks = self._prepare_callbacks(callbacks) self.trainer = self._prepare_trainer(self.callbacks, max_epochs, min_epochs) self.model = model @@ -419,11 +462,17 @@ def load_model(cls, dir: str, map_location=None, strict=True): callbacks = joblib.load(os.path.join(dir, "callbacks.sav")) # Excluding Gradient Accumulation Scheduler Callback as we are creating # a new one in trainer - callbacks = [c for c in callbacks if not isinstance(c, GradientAccumulationScheduler)] + callbacks = [ + c + for c in callbacks + if not isinstance(c, GradientAccumulationScheduler) + ] else: callbacks = [] if os.path.exists(os.path.join(dir, "custom_model_callable.sav")): - model_callable = joblib.load(os.path.join(dir, "custom_model_callable.sav")) + model_callable = joblib.load( + os.path.join(dir, "custom_model_callable.sav") + ) custom_model = True else: model_callable = getattr_nested(config._module_src, config._model_name) @@ -441,7 +490,9 @@ def load_model(cls, dir: str, map_location=None, strict=True): if custom_params.get("custom_loss") is not None: model_args["loss"] = "MSELoss" # For compatibility. Not Used if custom_params.get("custom_metrics") is not None: - model_args["metrics"] = ["mean_squared_error"] # For compatibility. Not Used + model_args["metrics"] = [ + "mean_squared_error" + ] # For compatibility. Not Used model_args["metrics_params"] = [{}] # For compatibility. Not Used model_args["metrics_prob_inputs"] = [False] # For compatibility. 
Not Used if custom_params.get("custom_optimizer") is not None: @@ -481,9 +532,13 @@ def load_model(cls, dir: str, map_location=None, strict=True): model.loss = custom_params["custom_loss"] if custom_params.get("custom_metrics") is not None: model.custom_metrics = custom_params.get("custom_metrics") - model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")] + model.hparams.metrics = [ + m.__name__ for m in custom_params.get("custom_metrics") + ] model.hparams.metrics_params = [{}] - model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs") + model.hparams.metrics_prob_input = custom_params.get( + "custom_metrics_prob_inputs" + ) model._setup_loss() model._setup_metrics() tabular_model = cls(config=config, model_callable=model_callable) @@ -586,7 +641,9 @@ def prepare_model( if self.verbose: logger.info(f"Preparing the Model: {self.config._model_name}") # Fetching the config as some data specific configs have been added in the datamodule - self.inferred_config = self._read_parse_config(datamodule.update_config(self.config), InferredConfig) + self.inferred_config = self._read_parse_config( + datamodule.update_config(self.config), InferredConfig + ) model = self.model_callable( self.config, custom_loss=loss, # Unused in SSL tasks @@ -601,7 +658,9 @@ def prepare_model( if self.model_state_dict_path is not None: self._load_weights(model, self.model_state_dict_path) if self.track_experiment and self.config.log_target == "wandb": - self.logger.watch(model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq) + self.logger.watch( + model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq + ) return model def train( @@ -633,7 +692,9 @@ def train( pl.Trainer: The PyTorch Lightning Trainer instance """ - self._prepare_for_training(model, datamodule, callbacks, max_epochs, min_epochs) + self._prepare_for_training( + model, datamodule, callbacks, max_epochs, min_epochs + ) train_loader, val_loader = ( self.datamodule.train_dataloader(), self.datamodule.val_dataloader(), @@ -651,7 +712,10 @@ def train( if oom_handler.oom_triggered: raise OOMException( "OOM detected during LR Find. Try reducing your batch_size or the" - " model parameters." + "/n" + "Original Error: " + oom_handler.oom_msg + " model parameters." + + "/n" + + "Original Error: " + + oom_handler.oom_msg ) if self.verbose: logger.info( @@ -757,7 +821,8 @@ def fit( """ assert self.config.task != "ssl", ( - "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning" + "`fit` is not valid for SSL task. Please use `pretrain` for" + " semi-supervised learning" ) if metrics is not None: assert len(metrics) == len( @@ -791,7 +856,9 @@ def fit( optimizer_params or {}, ) - return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom) + return self.train( + model, datamodule, callbacks, max_epochs, min_epochs, handle_oom + ) def pretrain( self, @@ -840,7 +907,8 @@ def pretrain( """ assert self.config.task == "ssl", ( - f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead." + f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" + " instead." 
) seed = seed or self.config.seed if seed: @@ -986,7 +1054,10 @@ def create_finetune_model( if self.track_experiment: # Renaming the experiment run so that a different log is created for finetuning if self.verbose: - logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}") + logger.info( + "Renaming the experiment run for finetuning as" + f" {config['run_name'] + '_finetuned'}" + ) config["run_name"] = config["run_name"] + "_finetuned" datamodule = self.datamodule.copy( @@ -1007,9 +1078,16 @@ def create_finetune_model( if not hasattr(config, "metrics_prob_input"): config.metrics_prob_input = metrics_prob_input or [False] if metrics is not None: - assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same" - assert len(metrics) == len(metrics_prob_input), "Number of metrics and metrics_prob_input should be same" - metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics] + assert len(metrics) == len( + metrics_params + ), "Number of metrics and metrics_params should be same" + assert len(metrics) == len( + metrics_prob_input + ), "Number of metrics and metrics_prob_input should be same" + metrics = [ + getattr(torchmetrics.functional, m) if isinstance(m, str) else m + for m in metrics + ] if task == "regression": loss = loss or torch.nn.MSELoss() if metrics is None: @@ -1031,7 +1109,9 @@ def create_finetune_model( for i, mp in enumerate(metrics_params): # For classification task, output_dim == number of classses metrics_params[i]["task"] = mp.get("task", "multiclass") - metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim) + metrics_params[i]["num_classes"] = mp.get( + "num_classes", inferred_config.output_dim + ) metrics_params[i]["top_k"] = mp.get("top_k", 1) else: raise ValueError(f"Task {task} not supported") @@ -1089,7 +1169,8 @@ def finetune( """ assert self._is_finetune_model, ( - "finetune() can only be called on a finetune model created using" " `TabularModel.create_finetune_model()`" + "finetune() can only be called on a finetune model created using" + " `TabularModel.create_finetune_model()`" ) seed_everything(self.config.seed) if freeze_backbone: @@ -1146,7 +1227,9 @@ def find_learning_rate( The suggested learning rate and the learning rate finder results """ - self._prepare_for_training(model, datamodule, callbacks, max_epochs=None, min_epochs=None) + self._prepare_for_training( + model, datamodule, callbacks, max_epochs=None, min_epochs=None + ) train_loader, _ = datamodule.train_dataloader(), datamodule.val_dataloader() lr_finder = Tuner(self.trainer).lr_find( model=self.model, @@ -1227,11 +1310,15 @@ def _generate_predictions( continue # Skipping empty list batch[k] = v.to(model.device) if is_probabilistic: - samples, ret_value = model.sample(batch, n_samples, ret_model_output=True) + samples, ret_value = model.sample( + batch, n_samples, ret_model_output=True + ) y_hat = torch.mean(samples, dim=-1) quantile_preds = [] for q in quantiles: - quantile_preds.append(torch.quantile(samples, q=q, dim=-1).unsqueeze(1)) + quantile_preds.append( + torch.quantile(samples, q=q, dim=-1).unsqueeze(1) + ) else: y_hat, ret_value = model.predict(batch, ret_model_output=True) if ret_logits: @@ -1239,12 +1326,16 @@ def _generate_predictions( logits_predictions[k].append(v.detach().cpu()) point_predictions.append(y_hat.detach().cpu()) if is_probabilistic: - quantile_predictions.append(torch.cat(quantile_preds, dim=-1).detach().cpu()) + 
quantile_predictions.append( + torch.cat(quantile_preds, dim=-1).detach().cpu() + ) point_predictions = torch.cat(point_predictions, dim=0) if point_predictions.ndim == 1: point_predictions = point_predictions.unsqueeze(-1) if is_probabilistic: - quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1) + quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze( + -1 + ) if quantile_predictions.ndim == 2: quantile_predictions = quantile_predictions.unsqueeze(-1) return point_predictions, quantile_predictions, logits_predictions @@ -1260,7 +1351,9 @@ def _format_predicitons( include_input_features, is_probabilistic, ): - pred_df = test.copy() if include_input_features else DataFrame(index=test.index) + pred_df = ( + test.copy() if include_input_features else DataFrame(index=test.index) + ) if self.config.task == "regression": point_predictions = point_predictions.numpy() # Probabilistic Models are only implemented for Regression @@ -1269,35 +1362,50 @@ def _format_predicitons( for i, target_col in enumerate(self.config.target): if self.datamodule.do_target_transform: if self.config.target[i] in pred_df.columns: - pred_df[self.config.target[i]] = self.datamodule.target_transforms[i].inverse_transform( + pred_df[ + self.config.target[i] + ] = self.datamodule.target_transforms[i].inverse_transform( pred_df[self.config.target[i]].values.reshape(-1, 1) ) - pred_df[f"{target_col}_prediction"] = self.datamodule.target_transforms[i].inverse_transform( - point_predictions[:, i].reshape(-1, 1) + pred_df[f"{target_col}_prediction"] = ( + self.datamodule.target_transforms[i].inverse_transform( + point_predictions[:, i].reshape(-1, 1) + ) ) if is_probabilistic: for j, q in enumerate(quantiles): col_ = f"{target_col}_q{int(q*100)}" - pred_df[col_] = self.datamodule.target_transforms[i].inverse_transform( + pred_df[col_] = self.datamodule.target_transforms[ + i + ].inverse_transform( quantile_predictions[:, j, i].reshape(-1, 1) ) else: pred_df[f"{target_col}_prediction"] = point_predictions[:, i] if is_probabilistic: for j, q in enumerate(quantiles): - pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1) + pred_df[f"{target_col}_q{int(q*100)}"] = ( + quantile_predictions[:, j, i].reshape(-1, 1) + ) elif self.config.task == "classification": start_index = 0 for i, target_col in enumerate(self.config.target): - end_index = start_index + self.datamodule._inferred_config.output_cardinality[i] - prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy() + end_index = ( + start_index + + self.datamodule._inferred_config.output_cardinality[i] + ) + prob_prediction = nn.Softmax(dim=-1)( + point_predictions[:, start_index:end_index] + ).numpy() start_index = end_index for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_): - pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j] - pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform( - np.argmax(prob_prediction, axis=1) - ) + pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[ + :, j + ] + pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[ + i + ].inverse_transform(np.argmax(prob_prediction, axis=1)) warnings.warn( "Classification prediction column will be renamed to" " `{target_col}_prediction` in the next release to maintain" @@ -1347,7 +1455,9 @@ def _predict( If classification, it returns probabilities and final prediction """ - assert all(q <= 1 and q >= 0 for q in 
quantiles), "Quantiles should be a decimal between 0 and 1" + assert all( + q <= 1 and q >= 0 for q in quantiles + ), "Quantiles should be a decimal between 0 and 1" model = self.model # default if device is not None: if isinstance(device, str): @@ -1356,7 +1466,9 @@ def _predict( model = self.model.to(device) model.eval() inference_dataloader = self.datamodule.prepare_inference_dataloader(test) - is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic + is_probabilistic = ( + hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic + ) if progress_bar == "rich": from rich.progress import track @@ -1368,14 +1480,16 @@ def _predict( progress_bar = partial(tqdm, description="Generating Predictions...") else: progress_bar = lambda it: it # E731 - point_predictions, quantile_predictions, logits_predictions = self._generate_predictions( - model, - inference_dataloader, - quantiles, - n_samples, - ret_logits, - progress_bar, - is_probabilistic, + point_predictions, quantile_predictions, logits_predictions = ( + self._generate_predictions( + model, + inference_dataloader, + quantiles, + n_samples, + ret_logits, + progress_bar, + is_probabilistic, + ) ) pred_df = self._format_predicitons( test, @@ -1458,7 +1572,9 @@ def predict( if test_time_augmentation: assert num_tta > 0, "num_tta should be greater than 0" assert alpha_tta > 0, "alpha_tta should be greater than 0" - assert include_input_features is False, "include_input_features cannot be True for TTA." + assert ( + include_input_features is False + ), "include_input_features cannot be True for TTA." if not callable(aggregate_tta): assert aggregate_tta in [ "mean", @@ -1466,14 +1582,21 @@ def predict( "min", "max", "hard_voting", - ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'" + ], ( + "aggregate should be one of 'mean', 'median', 'min', 'max', or" + " 'hard_voting'" + ) if self.config.task == "regression": - assert aggregate_tta != "hard_voting", "hard_voting is only available for classification" + assert ( + aggregate_tta != "hard_voting" + ), "hard_voting is only available for classification" torch.manual_seed(tta_seed) def add_noise(module, input, output): - return output + alpha_tta * torch.randn_like(output, memory_format=torch.contiguous_format) + return output + alpha_tta * torch.randn_like( + output, memory_format=torch.contiguous_format + ) # Register the hook to the embedding_layer handle = self.model.embedding_layer.register_forward_hook(add_noise) @@ -1493,7 +1616,9 @@ def add_noise(module, input, output): pred_prob_l.append(pred_df.values[:, : -len(self.config.target)]) elif self.config.task == "regression": pred_prob_l.append(pred_df.values) - pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None) + pred_df = self._combine_predictions( + pred_prob_l, pred_idx, aggregate_tta, None + ) # Remove the hook handle.remove() else: @@ -1520,10 +1645,14 @@ def load_best_model(self) -> None: ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) self.model.load_state_dict(ckpt["state_dict"]) else: - logger.warning("No best model available to load. Did you run it more than 1" " epoch?...") + logger.warning( + "No best model available to load. Did you run it more than 1" + " epoch?..." + ) else: logger.warning( - "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work" + "No best model available to load. 
Checkpoint Callback needs to be" + " enabled for this to work" ) def save_datamodule(self, dir: str, inference_only: bool = False) -> None: @@ -1572,12 +1701,20 @@ def save_model(self, dir: str, inference_only: bool = False) -> None: custom_params = {} custom_params["custom_loss"] = getattr(self.model, "custom_loss", None) custom_params["custom_metrics"] = getattr(self.model, "custom_metrics", None) - custom_params["custom_metrics_prob_inputs"] = getattr(self.model, "custom_metrics_prob_inputs", None) - custom_params["custom_optimizer"] = getattr(self.model, "custom_optimizer", None) - custom_params["custom_optimizer_params"] = getattr(self.model, "custom_optimizer_params", None) + custom_params["custom_metrics_prob_inputs"] = getattr( + self.model, "custom_metrics_prob_inputs", None + ) + custom_params["custom_optimizer"] = getattr( + self.model, "custom_optimizer", None + ) + custom_params["custom_optimizer_params"] = getattr( + self.model, "custom_optimizer_params", None + ) joblib.dump(custom_params, os.path.join(dir, "custom_params.sav")) if self.custom_model: - joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav")) + joblib.dump( + self.model_callable, os.path.join(dir, "custom_model_callable.sav") + ) def save_weights(self, path: Union[str, Path]) -> None: """Saves the model weights in the specified directory. @@ -1622,7 +1759,9 @@ def save_model_for_inference( elif kind == "onnx": # Export the model onnx_export_params["input_names"] = ["categorical", "continuous"] - onnx_export_params["output_names"] = onnx_export_params.get("output_names", ["output"]) + onnx_export_params["output_names"] = onnx_export_params.get( + "output_names", ["output"] + ) onnx_export_params["dynamic_axes"] = { onnx_export_params["input_names"][0]: {0: "batch_size"}, onnx_export_params["output_names"][0]: {0: "batch_size"}, @@ -1666,8 +1805,169 @@ def summary(self, model=None, max_depth: int = -1) -> None: "been initialized or passed in as an argument[/bold red]" ) + def ret_summary(self, model=None, max_depth: int = -1) -> str: + """Returns a summary of the model as a string. + + Args: + max_depth (int): The maximum depth to traverse the modules and displayed in the summary. + Defaults to -1, which means will display all the modules. + + Returns: + str: The summary of the model. + """ + if model is not None: + return str(summarize(model, max_depth=max_depth)) + elif self.has_model: + return str(summarize(self.model, max_depth=max_depth)) + else: + summary_str = f"{self.__class__.__name__}\n" + summary_str += "-" * 100 + "\n" + summary_str += "Config\n" + summary_str += "-" * 100 + "\n" + summary_str += pformat( + self.config.__dict__["_content"], indent=4, width=80, compact=True + ) + summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument" + return summary_str + def __str__(self) -> str: - return self.summary() + """Returns a readable summary of the TabularModel object.""" + return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else 'None'})" + + def _repr_html_(self): + """Generate an HTML representation for Jupyter Notebook.""" + css = """ + + + """ + + # Header (Main model name) + uid = str(uuid.uuid4()) + model_status = "" if self.has_model else "(Not Initialized)" + header_html = f"
{html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}
" + + # Config Section + config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True) + + # Summary Section + summary_html = "" if not self.has_model else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid) + + # Combine sections + return f""" + {css} +
+ {header_html} + {config_html} + {summary_html} +
+ """ + + def _generate_collapsible_section(self, title, content, uid, is_dict=False): + container_id = title.lower().replace(" ", "_") + uid + if is_dict: + content = self._generate_nested_collapsible_sections(OmegaConf.to_container(content, resolve=True), container_id) + return f""" +
+ + {html.escape(title)} + +
+ """ + + def _generate_nested_collapsible_sections(self, content, parent_id): + html_content = "" + for key, value in content.items(): + if isinstance(value, dict): + nested_id = f"{parent_id}_{key}".replace(" ", "_") + nested_id = nested_id + str(uuid.uuid4()) + nested_content = self._generate_nested_collapsible_sections(value, nested_id) + html_content += f""" +
+ + {html.escape(key)} + +
+ """ + else: + html_content += f"
{html.escape(key)}: {html.escape(str(value))}
" + return html_content + + def _generate_model_summary_table(self): + model_summary = summarize(self.model, max_depth=1) + table_html = "" + for name, layer in model_summary._layer_summary.items(): + table_html += f"" + table_html += "
LayerTypeParamsIn sizesOut sizes
{html.escape(name)}{html.escape(layer.layer_type)}{html.escape(str(layer.num_parameters))}{html.escape(str(layer.in_size))}{html.escape(str(layer.out_size))}
" + return table_html def feature_importance(self) -> DataFrame: """Returns the feature importance of the model as a pandas DataFrame.""" @@ -1711,7 +2011,8 @@ def _prepare_baselines_captum( baselines = baselines.mean(dim=0, keepdim=True) else: raise ValueError( - "Invalid value for `baselines`. Please refer to the documentation" " for more details." + "Invalid value for `baselines`. Please refer to the documentation" + " for more details." ) return baselines @@ -1728,7 +2029,11 @@ def _handle_categorical_embeddings_attributions( cat_attributions = [] index_counter = self.model.hparams.continuous_dim for _, embed_dim in self.model.hparams.embedding_dims: - cat_attributions.append(attributions[:, index_counter : index_counter + embed_dim].sum(dim=1)) + cat_attributions.append( + attributions[ + :, index_counter : index_counter + embed_dim + ].sum(dim=1) + ) index_counter += embed_dim cat_attributions = torch.stack(cat_attributions, dim=1) attributions = torch.cat( @@ -1779,7 +2084,9 @@ def explain( DataFrame: The dataframe with the feature importance """ - assert CAPTUM_INSTALLED, "Captum not installed. Please install using `pip install captum` or " + assert ( + CAPTUM_INSTALLED + ), "Captum not installed. Please install using `pip install captum` or " "install PyTorch Tabular using `pip install pytorch-tabular[extra]`" ALLOWED_METHODS = [ "GradientShap", @@ -1801,7 +2108,9 @@ def explain( " IntegratedGradients etc." ) if method in ["FeaturePermutation", "FeatureAblation"]: - assert data.shape[0] > 1, f"{method} only works when the number of samples is greater than 1" + assert ( + data.shape[0] > 1 + ), f"{method} only works when the number of samples is greater than 1" if len(data) <= 100: warnings.warn( f"{method} gives better results when the number of samples is" @@ -1820,44 +2129,60 @@ def explain( "FeaturePermutation", "LRP", ] - if is_full_baselines and (baselines is None or isinstance(baselines, (float, int))): + if is_full_baselines and ( + baselines is None or isinstance(baselines, (float, int)) + ): raise ValueError( f"baselines cannot be a scalar or None for {method}. 
Please " "provide a tensor or a string like `b|`" ) if is_not_supported: - raise NotImplementedError(f"Attributions are not implemented for {self.model._get_name()}") + raise NotImplementedError( + f"Attributions are not implemented for {self.model._get_name()}" + ) - is_embedding1d = isinstance(self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer)) + is_embedding1d = isinstance( + self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer) + ) is_embedding2d = isinstance(self.model.embedding_layer, Embedding2dLayer) # Models like NODE may have no embedding dims (doing leaveOneOut encoding) even if categorical_dim > 0 is_embbeding_dims = ( - hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None + hasattr(self.model.hparams, "embedding_dims") + and self.model.hparams.embedding_dims is not None ) if (not is_embedding1d) and (not is_embedding2d): raise NotImplementedError( - "Attributions are not implemented for models with this type of" " embedding layer" + "Attributions are not implemented for models with this type of" + " embedding layer" ) test_dl = self.datamodule.prepare_inference_dataloader(data) self.model.eval() # prepare import for Captum tensor_inp, tensor_tgt = self._prepare_input_for_captum(test_dl) - baselines = self._prepare_baselines_captum(baselines, test_dl, do_baselines, is_full_baselines) + baselines = self._prepare_baselines_captum( + baselines, test_dl, do_baselines, is_full_baselines + ) # prepare model for Captum try: interp_model = _CaptumModel(self.model) - captum_interp_cls = getattr(captum.attr, method)(interp_model, **method_args) + captum_interp_cls = getattr(captum.attr, method)( + interp_model, **method_args + ) if do_baselines: attributions = captum_interp_cls.attribute( tensor_inp, baselines=baselines, - target=(tensor_tgt if self.config.task == "classification" else None), + target=( + tensor_tgt if self.config.task == "classification" else None + ), **kwargs, ) else: attributions = captum_interp_cls.attribute( tensor_inp, - target=(tensor_tgt if self.config.task == "classification" else None), + target=( + tensor_tgt if self.config.task == "classification" else None + ), **kwargs, ) attributions = self._handle_categorical_embeddings_attributions( @@ -1865,7 +2190,10 @@ def explain( ) finally: self.model.train() - assert attributions.shape[1] == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim, ( + assert ( + attributions.shape[1] + == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim + ), ( "Something went wrong. 
The number of features in the attributions" f" ({attributions.shape[1]}) does not match the number of features in" " the model" @@ -1979,7 +2307,9 @@ def cross_validate( is_callable_metric = True if isinstance(cv, BaseCrossValidator): - it = enumerate(cv.split(train, y=train[self.config.target], groups=groups)) + it = enumerate( + cv.split(train, y=train[self.config.target], groups=groups) + ) else: # when iterable is directly passed it = enumerate(cv) @@ -1998,27 +2328,40 @@ def cross_validate( # Initialize datamodule and model in the first fold # uses train data from this fold to fit all transformers datamodule = self.prepare_dataloader( - train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs + train=train.iloc[train_idx], + validation=train.iloc[val_idx], + seed=42, + **prep_dl_kwargs, ) model = self.prepare_model(datamodule, **prep_model_kwargs) else: # Preprocess the current fold data using the fitted transformers and save in datamodule - datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference") - datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference") + datamodule.train, _ = datamodule.preprocess_data( + train.iloc[train_idx], stage="inference" + ) + datamodule.validation, _ = datamodule.preprocess_data( + train.iloc[val_idx], stage="inference" + ) # Train the model handle_oom = train_kwargs.pop("handle_oom", handle_oom) self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs) if return_oof or is_callable_metric: - preds = self.predict(train.iloc[val_idx], include_input_features=False) + preds = self.predict( + train.iloc[val_idx], include_input_features=False + ) oof_preds.append(preds) if is_callable_metric: - cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds)) + cv_metrics.append( + metric(train.iloc[val_idx][self.config.target], preds) + ) else: result = self.evaluate(train.iloc[val_idx], verbose=False) cv_metrics.append(result[0][metric]) if verbose: - logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}") + logger.info( + f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}" + ) self.model.reset_weights() return cv_metrics, oof_preds @@ -2052,7 +2395,11 @@ def _combine_predictions( if aggregate == "hard_voting": pred_df = pd.DataFrame( np.concatenate(pred_prob_l, axis=1), - columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes], + columns=[ + f"{c}_probability_fold_{i}" + for i in range(len(pred_prob_l)) + for c in classes + ], index=pred_idx, ) pred_df["prediction"] = classes[final_pred] @@ -2061,14 +2408,21 @@ def _combine_predictions( pred_df = pd.DataFrame( bagged_pred, # FIXME - columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_], + columns=[ + f"{c}_probability" + for c in self.datamodule.label_encoder[0].classes_ + ], index=pred_idx, ) pred_df["prediction"] = final_pred elif self.config.task == "regression": - pred_df = pd.DataFrame(bagged_pred, columns=self.config.target, index=pred_idx) + pred_df = pd.DataFrame( + bagged_pred, columns=self.config.target, index=pred_idx + ) else: - raise NotImplementedError(f"Task {self.config.task} not supported for bagging") + raise NotImplementedError( + f"Task {self.config.task} not supported for bagging" + ) return pred_df def bagging_predict( @@ -2140,23 +2494,30 @@ def bagging_predict( """ if weights is not None: - assert len(weights) == cv.n_splits, "Number of weights should be equal to the number of folds" 
+ assert ( + len(weights) == cv.n_splits + ), "Number of weights should be equal to the number of folds" assert self.config.task in [ "classification", "regression", ], "Bagging is only available for classification and regression" if not callable(aggregate): assert aggregate in ["mean", "median", "min", "max", "hard_voting"], ( - "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'" + "aggregate should be one of 'mean', 'median', 'min', 'max', or" + " 'hard_voting'" ) if self.config.task == "regression": - assert aggregate != "hard_voting", "hard_voting is only available for classification" + assert ( + aggregate != "hard_voting" + ), "hard_voting is only available for classification" cv = self._check_cv(cv) prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs) pred_prob_l = [] datamodule = None model = None - for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)): + for fold, (train_idx, val_idx) in enumerate( + cv.split(train, y=train[self.config.target], groups=groups) + ): if verbose: logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}") train_fold = train.iloc[train_idx] @@ -2166,12 +2527,18 @@ def bagging_predict( if datamodule is None: # Initialize datamodule and model in the first fold # uses train data from this fold to fit all transformers - datamodule = self.prepare_dataloader(train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs) + datamodule = self.prepare_dataloader( + train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs + ) model = self.prepare_model(datamodule, **prep_model_kwargs) else: # Preprocess the current fold data using the fitted transformers and save in datamodule - datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference") - datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference") + datamodule.train, _ = datamodule.preprocess_data( + train_fold, stage="inference" + ) + datamodule.validation, _ = datamodule.preprocess_data( + val_fold, stage="inference" + ) # Train the model handle_oom = train_kwargs.pop("handle_oom", handle_oom) From be9c5634356083983c01e79494cfdad642f95b06 Mon Sep 17 00:00:00 2001 From: Manu Joseph Date: Sun, 24 Nov 2024 07:23:30 +0530 Subject: [PATCH 2/4] fixed some issues and added test cases --- src/pytorch_tabular/tabular_model.py | 35 +++++++++++++--- tests/test_common.py | 62 +++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 472ffda6..a6c01993 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -1832,7 +1832,20 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str: def __str__(self) -> str: """Returns a readable summary of the TabularModel object.""" - return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else 'None'})" + return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else self.config._model_name+'(Not Initialized)'})" + + def __repr__(self) -> str: + """Returns an unambiguous representation of the TabularModel object.""" + config_str = json.dumps( + OmegaConf.to_container(self.config, resolve=True), indent=4 + ) + ret_str = f"{self.__class__.__name__}(\n" + if self.has_model: + ret_str += f" model={self.model.__class__.__name__},\n" + else: + ret_str += f" model={self.config._model_name} (Not Initialized),\n" + 
ret_str += f" config={config_str},\n" + return ret_str def _repr_html_(self): """Generate an HTML representation for Jupyter Notebook.""" @@ -1912,10 +1925,18 @@ def _repr_html_(self): header_html = f"
{html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}
" # Config Section - config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True) + config_html = self._generate_collapsible_section( + "Model Config", self.config, uid=uid, is_dict=True + ) # Summary Section - summary_html = "" if not self.has_model else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid) + summary_html = ( + "" + if not self.has_model + else self._generate_collapsible_section( + "Model Summary", self._generate_model_summary_table(), uid=uid + ) + ) # Combine sections return f""" @@ -1930,7 +1951,9 @@ def _repr_html_(self): def _generate_collapsible_section(self, title, content, uid, is_dict=False): container_id = title.lower().replace(" ", "_") + uid if is_dict: - content = self._generate_nested_collapsible_sections(OmegaConf.to_container(content, resolve=True), container_id) + content = self._generate_nested_collapsible_sections( + OmegaConf.to_container(content, resolve=True), container_id + ) return f"""
@@ -1947,7 +1970,9 @@ def _generate_nested_collapsible_sections(self, content, parent_id): if isinstance(value, dict): nested_id = f"{parent_id}_{key}".replace(" ", "_") nested_id = nested_id + str(uuid.uuid4()) - nested_content = self._generate_nested_collapsible_sections(value, nested_id) + nested_content = self._generate_nested_collapsible_sections( + value, nested_id + ) html_content += f"""
diff --git a/tests/test_common.py b/tests/test_common.py index bd5d428e..b67bc529 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,7 +31,7 @@ MODEL_CONFIG_SAVE_TEST = [ (CategoryEmbeddingModelConfig, {"layers": "10-20"}), - (AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}), + (GANDALFConfig, {}), (NodeConfig, {"num_trees": 100, "depth": 2}), (TabNetModelConfig, {"n_a": 2, "n_d": 2}), ] @@ -1247,3 +1247,63 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols, # # there may be multiple models with the same score # best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist() # assert best_model.model._get_name() in best_models + +@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST) +@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)]) +@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]]) +@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]]) +@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()]) +@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"]) +def test_str_repr( + regression_data, + model_config_class, + continuous_cols, + categorical_cols, + custom_metrics, + custom_loss, + custom_optimizer, +): + (train, test, target) = regression_data + data_config = DataConfig( + target=target, + continuous_cols=continuous_cols, + categorical_cols=categorical_cols, + ) + model_config_class, model_config_params = model_config_class + model_config_params["task"] = "regression" + model_config = model_config_class(**model_config_params) + trainer_config = TrainerConfig( + max_epochs=3, + checkpoints=None, + early_stopping=None, + accelerator="cpu", + fast_dev_run=True, + ) + optimizer_config = OptimizerConfig() + + tabular_model = TabularModel( + data_config=data_config, + model_config=model_config, + optimizer_config=optimizer_config, + trainer_config=trainer_config, + ) + assert "Not Initialized" in str(tabular_model) + assert "Not Initialized" in repr(tabular_model) + assert "Model Summary" not in tabular_model._repr_html_() + assert "Model Config" in tabular_model._repr_html_() + assert "config" in tabular_model.__repr__() + assert "config" not in str(tabular_model) + tabular_model.fit( + train=train, + metrics=custom_metrics, + metrics_prob_inputs=None if custom_metrics is None else [False], + loss=custom_loss, + optimizer=custom_optimizer, + optimizer_params={} + ) + assert model_config_class._model_name in str(tabular_model) + assert model_config_class._model_name in repr(tabular_model) + assert "Model Summary" in tabular_model._repr_html_() + assert "Model Config" in tabular_model._repr_html_() + assert "config" in tabular_model.__repr__() + assert model_config_class._model_name in tabular_model._repr_html_() From a9a949e88fefcd8de40a2227562e508988925e84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 00:28:01 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_tabular/tabular_model.py | 424 +++++++-------------------- tests/test_common.py | 3 +- 2 files changed, 114 insertions(+), 313 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index a6c01993..d0bbea0f 100644 --- a/src/pytorch_tabular/tabular_model.py +++ 
b/src/pytorch_tabular/tabular_model.py @@ -6,12 +6,13 @@ import html import inspect import json -import uuid import os +import uuid import warnings from collections import defaultdict from functools import partial from pathlib import Path +from pprint import pformat from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import joblib @@ -32,7 +33,6 @@ from pytorch_lightning.utilities.model_summary import summarize from rich import print as rich_print from rich.pretty import pprint -from pprint import pformat from sklearn.base import TransformerMixin from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold from torch import nn @@ -129,23 +129,17 @@ def __init__( self.verbose = verbose self.exp_manager = ExperimentRunManager() if config is None: - assert any( - c is not None - for c in (data_config, model_config, optimizer_config, trainer_config) - ), ( + assert any(c is not None for c in (data_config, model_config, optimizer_config, trainer_config)), ( "If `config` is None, `data_config`, `model_config`," " `trainer_config`, and `optimizer_config` cannot be None" ) data_config = self._read_parse_config(data_config, DataConfig) model_config = self._read_parse_config(model_config, ModelConfig) trainer_config = self._read_parse_config(trainer_config, TrainerConfig) - optimizer_config = self._read_parse_config( - optimizer_config, OptimizerConfig - ) + optimizer_config = self._read_parse_config(optimizer_config, OptimizerConfig) if model_config.task != "ssl": assert data_config.target is not None, ( - "`target` in data_config should not be None for" - f" {model_config.task} task" + "`target` in data_config should not be None for" f" {model_config.task} task" ) if experiment_config is None: if self.verbose: @@ -158,9 +152,7 @@ def __init__( OmegaConf.to_container(optimizer_config), ) else: - experiment_config = self._read_parse_config( - experiment_config, ExperimentConfig - ) + experiment_config = self._read_parse_config(experiment_config, ExperimentConfig) self.track_experiment = True self.config = OmegaConf.merge( OmegaConf.to_container(data_config), @@ -187,9 +179,7 @@ def __init__( self.exp_manager = ExperimentRunManager() if model_callable is None: - self.model_callable = getattr_nested( - self.config._module_src, self.config._model_name - ) + self.model_callable = getattr_nested(self.config._module_src, self.config._model_name) self.custom_model = False else: self.model_callable = model_callable @@ -236,9 +226,7 @@ def _run_validation(self): if ( (len(self.config.target_range) != len(self.config.target)) or any(len(range_) != 2 for range_ in self.config.target_range) - or any( - range_[0] > range_[1] for range_ in self.config.target_range - ) + or any(range_[0] > range_[1] for range_ in self.config.target_range) ): raise ValueError( "Targe Range, if defined, should be list tuples of length" @@ -256,8 +244,7 @@ def _read_parse_config(self, config, cls): **{ k: v for k, v in _config.items() - if (k in cls.__dataclass_fields__.keys()) - and (cls.__dataclass_fields__[k].init) + if (k in cls.__dataclass_fields__.keys()) and (cls.__dataclass_fields__[k].init) } ) else: @@ -274,10 +261,7 @@ def _get_run_name_uid(self) -> Tuple[str, int]: """ if hasattr(self.config, "run_name") and self.config.run_name is not None: name = self.config.run_name - elif ( - hasattr(self.config, "checkpoints_name") - and self.config.checkpoints_name is not None - ): + elif hasattr(self.config, "checkpoints_name") and self.config.checkpoints_name is not None: name = 
self.config.checkpoints_name else: name = self.config.task @@ -300,8 +284,7 @@ def _setup_experiment_tracking(self): ) else: raise NotImplementedError( - f"{self.config.log_target} is not implemented. Try one of [wandb," - " tensorboard]" + f"{self.config.log_target} is not implemented. Try one of [wandb," " tensorboard]" ) def _prepare_callbacks(self, callbacks=None) -> List: @@ -337,17 +320,13 @@ def _prepare_callbacks(self, callbacks=None) -> List: self.config.enable_checkpointing = True else: self.config.enable_checkpointing = False - if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get( - "enable_progress_bar", True - ): + if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True): callbacks.append(RichProgressBar()) if self.verbose: logger.debug(f"Callbacks used: {callbacks}") return callbacks - def _prepare_trainer( - self, callbacks: List, max_epochs: int = None, min_epochs: int = None - ) -> pl.Trainer: + def _prepare_trainer(self, callbacks: List, max_epochs: int = None, min_epochs: int = None) -> pl.Trainer: """Prepares the Trainer object. Args: @@ -368,15 +347,11 @@ def _prepare_trainer( # Getting Trainer Arguments from the init signature trainer_sig = inspect.signature(pl.Trainer.__init__) trainer_args = [p for p in trainer_sig.parameters.keys() if p != "self"] - trainer_args_config = { - k: v for k, v in self.config.items() if k in trainer_args - } + trainer_args_config = {k: v for k, v in self.config.items() if k in trainer_args} # For some weird reason, checkpoint_callback is not appearing in the Trainer vars trainer_args_config["enable_checkpointing"] = self.config.enable_checkpointing # turn off progress bar if progress_bar=='none' - trainer_args_config["enable_progress_bar"] = ( - self.config.progress_bar != "none" - ) + trainer_args_config["enable_progress_bar"] = self.config.progress_bar != "none" # Adding trainer_kwargs from config to trainer_args trainer_args_config.update(self.config.trainer_kwargs) if trainer_args_config["devices"] == -1: @@ -399,20 +374,14 @@ def _check_and_set_target_transform(self, target_transform): pass else: raise ValueError( - "`target_transform` should wither be an sklearn Transformer or a" - " tuple of callables." + "`target_transform` should wither be an sklearn Transformer or a" " tuple of callables." ) if self.config.task == "classification" and target_transform is not None: - logger.warning( - "For classification task, target transform is not used. Ignoring the" - " parameter" - ) + logger.warning("For classification task, target transform is not used. 
Ignoring the" " parameter") target_transform = None return target_transform - def _prepare_for_training( - self, model, datamodule, callbacks=None, max_epochs=None, min_epochs=None - ): + def _prepare_for_training(self, model, datamodule, callbacks=None, max_epochs=None, min_epochs=None): self.callbacks = self._prepare_callbacks(callbacks) self.trainer = self._prepare_trainer(self.callbacks, max_epochs, min_epochs) self.model = model @@ -462,17 +431,11 @@ def load_model(cls, dir: str, map_location=None, strict=True): callbacks = joblib.load(os.path.join(dir, "callbacks.sav")) # Excluding Gradient Accumulation Scheduler Callback as we are creating # a new one in trainer - callbacks = [ - c - for c in callbacks - if not isinstance(c, GradientAccumulationScheduler) - ] + callbacks = [c for c in callbacks if not isinstance(c, GradientAccumulationScheduler)] else: callbacks = [] if os.path.exists(os.path.join(dir, "custom_model_callable.sav")): - model_callable = joblib.load( - os.path.join(dir, "custom_model_callable.sav") - ) + model_callable = joblib.load(os.path.join(dir, "custom_model_callable.sav")) custom_model = True else: model_callable = getattr_nested(config._module_src, config._model_name) @@ -490,9 +453,7 @@ def load_model(cls, dir: str, map_location=None, strict=True): if custom_params.get("custom_loss") is not None: model_args["loss"] = "MSELoss" # For compatibility. Not Used if custom_params.get("custom_metrics") is not None: - model_args["metrics"] = [ - "mean_squared_error" - ] # For compatibility. Not Used + model_args["metrics"] = ["mean_squared_error"] # For compatibility. Not Used model_args["metrics_params"] = [{}] # For compatibility. Not Used model_args["metrics_prob_inputs"] = [False] # For compatibility. Not Used if custom_params.get("custom_optimizer") is not None: @@ -532,13 +493,9 @@ def load_model(cls, dir: str, map_location=None, strict=True): model.loss = custom_params["custom_loss"] if custom_params.get("custom_metrics") is not None: model.custom_metrics = custom_params.get("custom_metrics") - model.hparams.metrics = [ - m.__name__ for m in custom_params.get("custom_metrics") - ] + model.hparams.metrics = [m.__name__ for m in custom_params.get("custom_metrics")] model.hparams.metrics_params = [{}] - model.hparams.metrics_prob_input = custom_params.get( - "custom_metrics_prob_inputs" - ) + model.hparams.metrics_prob_input = custom_params.get("custom_metrics_prob_inputs") model._setup_loss() model._setup_metrics() tabular_model = cls(config=config, model_callable=model_callable) @@ -641,9 +598,7 @@ def prepare_model( if self.verbose: logger.info(f"Preparing the Model: {self.config._model_name}") # Fetching the config as some data specific configs have been added in the datamodule - self.inferred_config = self._read_parse_config( - datamodule.update_config(self.config), InferredConfig - ) + self.inferred_config = self._read_parse_config(datamodule.update_config(self.config), InferredConfig) model = self.model_callable( self.config, custom_loss=loss, # Unused in SSL tasks @@ -658,9 +613,7 @@ def prepare_model( if self.model_state_dict_path is not None: self._load_weights(model, self.model_state_dict_path) if self.track_experiment and self.config.log_target == "wandb": - self.logger.watch( - model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq - ) + self.logger.watch(model, log=self.config.exp_watch, log_freq=self.config.exp_log_freq) return model def train( @@ -692,9 +645,7 @@ def train( pl.Trainer: The PyTorch Lightning Trainer instance """ - 
self._prepare_for_training( - model, datamodule, callbacks, max_epochs, min_epochs - ) + self._prepare_for_training(model, datamodule, callbacks, max_epochs, min_epochs) train_loader, val_loader = ( self.datamodule.train_dataloader(), self.datamodule.val_dataloader(), @@ -712,10 +663,7 @@ def train( if oom_handler.oom_triggered: raise OOMException( "OOM detected during LR Find. Try reducing your batch_size or the" - " model parameters." - + "/n" - + "Original Error: " - + oom_handler.oom_msg + " model parameters." + "/n" + "Original Error: " + oom_handler.oom_msg ) if self.verbose: logger.info( @@ -821,8 +769,7 @@ def fit( """ assert self.config.task != "ssl", ( - "`fit` is not valid for SSL task. Please use `pretrain` for" - " semi-supervised learning" + "`fit` is not valid for SSL task. Please use `pretrain` for" " semi-supervised learning" ) if metrics is not None: assert len(metrics) == len( @@ -856,9 +803,7 @@ def fit( optimizer_params or {}, ) - return self.train( - model, datamodule, callbacks, max_epochs, min_epochs, handle_oom - ) + return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom) def pretrain( self, @@ -907,8 +852,7 @@ def pretrain( """ assert self.config.task == "ssl", ( - f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" - " instead." + f"`pretrain` is not valid for {self.config.task} task. Please use `fit`" " instead." ) seed = seed or self.config.seed if seed: @@ -1054,10 +998,7 @@ def create_finetune_model( if self.track_experiment: # Renaming the experiment run so that a different log is created for finetuning if self.verbose: - logger.info( - "Renaming the experiment run for finetuning as" - f" {config['run_name'] + '_finetuned'}" - ) + logger.info("Renaming the experiment run for finetuning as" f" {config['run_name'] + '_finetuned'}") config["run_name"] = config["run_name"] + "_finetuned" datamodule = self.datamodule.copy( @@ -1078,16 +1019,9 @@ def create_finetune_model( if not hasattr(config, "metrics_prob_input"): config.metrics_prob_input = metrics_prob_input or [False] if metrics is not None: - assert len(metrics) == len( - metrics_params - ), "Number of metrics and metrics_params should be same" - assert len(metrics) == len( - metrics_prob_input - ), "Number of metrics and metrics_prob_input should be same" - metrics = [ - getattr(torchmetrics.functional, m) if isinstance(m, str) else m - for m in metrics - ] + assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same" + assert len(metrics) == len(metrics_prob_input), "Number of metrics and metrics_prob_input should be same" + metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics] if task == "regression": loss = loss or torch.nn.MSELoss() if metrics is None: @@ -1109,9 +1043,7 @@ def create_finetune_model( for i, mp in enumerate(metrics_params): # For classification task, output_dim == number of classses metrics_params[i]["task"] = mp.get("task", "multiclass") - metrics_params[i]["num_classes"] = mp.get( - "num_classes", inferred_config.output_dim - ) + metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim) metrics_params[i]["top_k"] = mp.get("top_k", 1) else: raise ValueError(f"Task {task} not supported") @@ -1169,8 +1101,7 @@ def finetune( """ assert self._is_finetune_model, ( - "finetune() can only be called on a finetune model created using" - " `TabularModel.create_finetune_model()`" + "finetune() can only be called on a finetune model created 
using" " `TabularModel.create_finetune_model()`" ) seed_everything(self.config.seed) if freeze_backbone: @@ -1227,9 +1158,7 @@ def find_learning_rate( The suggested learning rate and the learning rate finder results """ - self._prepare_for_training( - model, datamodule, callbacks, max_epochs=None, min_epochs=None - ) + self._prepare_for_training(model, datamodule, callbacks, max_epochs=None, min_epochs=None) train_loader, _ = datamodule.train_dataloader(), datamodule.val_dataloader() lr_finder = Tuner(self.trainer).lr_find( model=self.model, @@ -1310,15 +1239,11 @@ def _generate_predictions( continue # Skipping empty list batch[k] = v.to(model.device) if is_probabilistic: - samples, ret_value = model.sample( - batch, n_samples, ret_model_output=True - ) + samples, ret_value = model.sample(batch, n_samples, ret_model_output=True) y_hat = torch.mean(samples, dim=-1) quantile_preds = [] for q in quantiles: - quantile_preds.append( - torch.quantile(samples, q=q, dim=-1).unsqueeze(1) - ) + quantile_preds.append(torch.quantile(samples, q=q, dim=-1).unsqueeze(1)) else: y_hat, ret_value = model.predict(batch, ret_model_output=True) if ret_logits: @@ -1326,16 +1251,12 @@ def _generate_predictions( logits_predictions[k].append(v.detach().cpu()) point_predictions.append(y_hat.detach().cpu()) if is_probabilistic: - quantile_predictions.append( - torch.cat(quantile_preds, dim=-1).detach().cpu() - ) + quantile_predictions.append(torch.cat(quantile_preds, dim=-1).detach().cpu()) point_predictions = torch.cat(point_predictions, dim=0) if point_predictions.ndim == 1: point_predictions = point_predictions.unsqueeze(-1) if is_probabilistic: - quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze( - -1 - ) + quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1) if quantile_predictions.ndim == 2: quantile_predictions = quantile_predictions.unsqueeze(-1) return point_predictions, quantile_predictions, logits_predictions @@ -1351,9 +1272,7 @@ def _format_predicitons( include_input_features, is_probabilistic, ): - pred_df = ( - test.copy() if include_input_features else DataFrame(index=test.index) - ) + pred_df = test.copy() if include_input_features else DataFrame(index=test.index) if self.config.task == "regression": point_predictions = point_predictions.numpy() # Probabilistic Models are only implemented for Regression @@ -1362,50 +1281,35 @@ def _format_predicitons( for i, target_col in enumerate(self.config.target): if self.datamodule.do_target_transform: if self.config.target[i] in pred_df.columns: - pred_df[ - self.config.target[i] - ] = self.datamodule.target_transforms[i].inverse_transform( + pred_df[self.config.target[i]] = self.datamodule.target_transforms[i].inverse_transform( pred_df[self.config.target[i]].values.reshape(-1, 1) ) - pred_df[f"{target_col}_prediction"] = ( - self.datamodule.target_transforms[i].inverse_transform( - point_predictions[:, i].reshape(-1, 1) - ) + pred_df[f"{target_col}_prediction"] = self.datamodule.target_transforms[i].inverse_transform( + point_predictions[:, i].reshape(-1, 1) ) if is_probabilistic: for j, q in enumerate(quantiles): col_ = f"{target_col}_q{int(q*100)}" - pred_df[col_] = self.datamodule.target_transforms[ - i - ].inverse_transform( + pred_df[col_] = self.datamodule.target_transforms[i].inverse_transform( quantile_predictions[:, j, i].reshape(-1, 1) ) else: pred_df[f"{target_col}_prediction"] = point_predictions[:, i] if is_probabilistic: for j, q in enumerate(quantiles): - pred_df[f"{target_col}_q{int(q*100)}"] = ( 
- quantile_predictions[:, j, i].reshape(-1, 1) - ) + pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1) elif self.config.task == "classification": start_index = 0 for i, target_col in enumerate(self.config.target): - end_index = ( - start_index - + self.datamodule._inferred_config.output_cardinality[i] - ) - prob_prediction = nn.Softmax(dim=-1)( - point_predictions[:, start_index:end_index] - ).numpy() + end_index = start_index + self.datamodule._inferred_config.output_cardinality[i] + prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy() start_index = end_index for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_): - pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[ - :, j - ] - pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[ - i - ].inverse_transform(np.argmax(prob_prediction, axis=1)) + pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j] + pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform( + np.argmax(prob_prediction, axis=1) + ) warnings.warn( "Classification prediction column will be renamed to" " `{target_col}_prediction` in the next release to maintain" @@ -1455,9 +1359,7 @@ def _predict( If classification, it returns probabilities and final prediction """ - assert all( - q <= 1 and q >= 0 for q in quantiles - ), "Quantiles should be a decimal between 0 and 1" + assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1" model = self.model # default if device is not None: if isinstance(device, str): @@ -1466,9 +1368,7 @@ def _predict( model = self.model.to(device) model.eval() inference_dataloader = self.datamodule.prepare_inference_dataloader(test) - is_probabilistic = ( - hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic - ) + is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic if progress_bar == "rich": from rich.progress import track @@ -1480,16 +1380,14 @@ def _predict( progress_bar = partial(tqdm, description="Generating Predictions...") else: progress_bar = lambda it: it # E731 - point_predictions, quantile_predictions, logits_predictions = ( - self._generate_predictions( - model, - inference_dataloader, - quantiles, - n_samples, - ret_logits, - progress_bar, - is_probabilistic, - ) + point_predictions, quantile_predictions, logits_predictions = self._generate_predictions( + model, + inference_dataloader, + quantiles, + n_samples, + ret_logits, + progress_bar, + is_probabilistic, ) pred_df = self._format_predicitons( test, @@ -1572,9 +1470,7 @@ def predict( if test_time_augmentation: assert num_tta > 0, "num_tta should be greater than 0" assert alpha_tta > 0, "alpha_tta should be greater than 0" - assert ( - include_input_features is False - ), "include_input_features cannot be True for TTA." + assert include_input_features is False, "include_input_features cannot be True for TTA." 
if not callable(aggregate_tta): assert aggregate_tta in [ "mean", @@ -1582,21 +1478,14 @@ def predict( "min", "max", "hard_voting", - ], ( - "aggregate should be one of 'mean', 'median', 'min', 'max', or" - " 'hard_voting'" - ) + ], "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'" if self.config.task == "regression": - assert ( - aggregate_tta != "hard_voting" - ), "hard_voting is only available for classification" + assert aggregate_tta != "hard_voting", "hard_voting is only available for classification" torch.manual_seed(tta_seed) def add_noise(module, input, output): - return output + alpha_tta * torch.randn_like( - output, memory_format=torch.contiguous_format - ) + return output + alpha_tta * torch.randn_like(output, memory_format=torch.contiguous_format) # Register the hook to the embedding_layer handle = self.model.embedding_layer.register_forward_hook(add_noise) @@ -1616,9 +1505,7 @@ def add_noise(module, input, output): pred_prob_l.append(pred_df.values[:, : -len(self.config.target)]) elif self.config.task == "regression": pred_prob_l.append(pred_df.values) - pred_df = self._combine_predictions( - pred_prob_l, pred_idx, aggregate_tta, None - ) + pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None) # Remove the hook handle.remove() else: @@ -1645,14 +1532,10 @@ def load_best_model(self) -> None: ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) self.model.load_state_dict(ckpt["state_dict"]) else: - logger.warning( - "No best model available to load. Did you run it more than 1" - " epoch?..." - ) + logger.warning("No best model available to load. Did you run it more than 1" " epoch?...") else: logger.warning( - "No best model available to load. Checkpoint Callback needs to be" - " enabled for this to work" + "No best model available to load. Checkpoint Callback needs to be" " enabled for this to work" ) def save_datamodule(self, dir: str, inference_only: bool = False) -> None: @@ -1701,20 +1584,12 @@ def save_model(self, dir: str, inference_only: bool = False) -> None: custom_params = {} custom_params["custom_loss"] = getattr(self.model, "custom_loss", None) custom_params["custom_metrics"] = getattr(self.model, "custom_metrics", None) - custom_params["custom_metrics_prob_inputs"] = getattr( - self.model, "custom_metrics_prob_inputs", None - ) - custom_params["custom_optimizer"] = getattr( - self.model, "custom_optimizer", None - ) - custom_params["custom_optimizer_params"] = getattr( - self.model, "custom_optimizer_params", None - ) + custom_params["custom_metrics_prob_inputs"] = getattr(self.model, "custom_metrics_prob_inputs", None) + custom_params["custom_optimizer"] = getattr(self.model, "custom_optimizer", None) + custom_params["custom_optimizer_params"] = getattr(self.model, "custom_optimizer_params", None) joblib.dump(custom_params, os.path.join(dir, "custom_params.sav")) if self.custom_model: - joblib.dump( - self.model_callable, os.path.join(dir, "custom_model_callable.sav") - ) + joblib.dump(self.model_callable, os.path.join(dir, "custom_model_callable.sav")) def save_weights(self, path: Union[str, Path]) -> None: """Saves the model weights in the specified directory. 
@@ -1759,9 +1634,7 @@ def save_model_for_inference( elif kind == "onnx": # Export the model onnx_export_params["input_names"] = ["categorical", "continuous"] - onnx_export_params["output_names"] = onnx_export_params.get( - "output_names", ["output"] - ) + onnx_export_params["output_names"] = onnx_export_params.get("output_names", ["output"]) onnx_export_params["dynamic_axes"] = { onnx_export_params["input_names"][0]: {0: "batch_size"}, onnx_export_params["output_names"][0]: {0: "batch_size"}, @@ -1814,6 +1687,7 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str: Returns: str: The summary of the model. + """ if model is not None: return str(summarize(model, max_depth=max_depth)) @@ -1824,9 +1698,7 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str: summary_str += "-" * 100 + "\n" summary_str += "Config\n" summary_str += "-" * 100 + "\n" - summary_str += pformat( - self.config.__dict__["_content"], indent=4, width=80, compact=True - ) + summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True) summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument" return summary_str @@ -1836,9 +1708,7 @@ def __str__(self) -> str: def __repr__(self) -> str: """Returns an unambiguous representation of the TabularModel object.""" - config_str = json.dumps( - OmegaConf.to_container(self.config, resolve=True), indent=4 - ) + config_str = json.dumps(OmegaConf.to_container(self.config, resolve=True), indent=4) ret_str = f"{self.__class__.__name__}(\n" if self.has_model: ret_str += f" model={self.model.__class__.__name__},\n" @@ -1925,17 +1795,13 @@ def _repr_html_(self): header_html = f"
        {html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}
        "
         # Config Section
-        config_html = self._generate_collapsible_section(
-            "Model Config", self.config, uid=uid, is_dict=True
-        )
+        config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)
         # Summary Section
         summary_html = (
             ""
             if not self.has_model
-            else self._generate_collapsible_section(
-                "Model Summary", self._generate_model_summary_table(), uid=uid
-            )
+            else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid)
         )
         # Combine sections
@@ -1970,9 +1836,7 @@ def _generate_nested_collapsible_sections(self, content, parent_id):
             if isinstance(value, dict):
                 nested_id = f"{parent_id}_{key}".replace(" ", "_")
                 nested_id = nested_id + str(uuid.uuid4())
-                nested_content = self._generate_nested_collapsible_sections(
-                    value, nested_id
-                )
+                nested_content = self._generate_nested_collapsible_sections(value, nested_id)
                 html_content += f"""
@@ -2036,8 +1900,7 @@ def _prepare_baselines_captum( baselines = baselines.mean(dim=0, keepdim=True) else: raise ValueError( - "Invalid value for `baselines`. Please refer to the documentation" - " for more details." + "Invalid value for `baselines`. Please refer to the documentation" " for more details." ) return baselines @@ -2054,11 +1917,7 @@ def _handle_categorical_embeddings_attributions( cat_attributions = [] index_counter = self.model.hparams.continuous_dim for _, embed_dim in self.model.hparams.embedding_dims: - cat_attributions.append( - attributions[ - :, index_counter : index_counter + embed_dim - ].sum(dim=1) - ) + cat_attributions.append(attributions[:, index_counter : index_counter + embed_dim].sum(dim=1)) index_counter += embed_dim cat_attributions = torch.stack(cat_attributions, dim=1) attributions = torch.cat( @@ -2109,9 +1968,7 @@ def explain( DataFrame: The dataframe with the feature importance """ - assert ( - CAPTUM_INSTALLED - ), "Captum not installed. Please install using `pip install captum` or " + assert CAPTUM_INSTALLED, "Captum not installed. Please install using `pip install captum` or " "install PyTorch Tabular using `pip install pytorch-tabular[extra]`" ALLOWED_METHODS = [ "GradientShap", @@ -2133,9 +1990,7 @@ def explain( " IntegratedGradients etc." ) if method in ["FeaturePermutation", "FeatureAblation"]: - assert ( - data.shape[0] > 1 - ), f"{method} only works when the number of samples is greater than 1" + assert data.shape[0] > 1, f"{method} only works when the number of samples is greater than 1" if len(data) <= 100: warnings.warn( f"{method} gives better results when the number of samples is" @@ -2154,60 +2009,44 @@ def explain( "FeaturePermutation", "LRP", ] - if is_full_baselines and ( - baselines is None or isinstance(baselines, (float, int)) - ): + if is_full_baselines and (baselines is None or isinstance(baselines, (float, int))): raise ValueError( f"baselines cannot be a scalar or None for {method}. 
Please " "provide a tensor or a string like `b|`" ) if is_not_supported: - raise NotImplementedError( - f"Attributions are not implemented for {self.model._get_name()}" - ) + raise NotImplementedError(f"Attributions are not implemented for {self.model._get_name()}") - is_embedding1d = isinstance( - self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer) - ) + is_embedding1d = isinstance(self.model.embedding_layer, (Embedding1dLayer, PreEncoded1dLayer)) is_embedding2d = isinstance(self.model.embedding_layer, Embedding2dLayer) # Models like NODE may have no embedding dims (doing leaveOneOut encoding) even if categorical_dim > 0 is_embbeding_dims = ( - hasattr(self.model.hparams, "embedding_dims") - and self.model.hparams.embedding_dims is not None + hasattr(self.model.hparams, "embedding_dims") and self.model.hparams.embedding_dims is not None ) if (not is_embedding1d) and (not is_embedding2d): raise NotImplementedError( - "Attributions are not implemented for models with this type of" - " embedding layer" + "Attributions are not implemented for models with this type of" " embedding layer" ) test_dl = self.datamodule.prepare_inference_dataloader(data) self.model.eval() # prepare import for Captum tensor_inp, tensor_tgt = self._prepare_input_for_captum(test_dl) - baselines = self._prepare_baselines_captum( - baselines, test_dl, do_baselines, is_full_baselines - ) + baselines = self._prepare_baselines_captum(baselines, test_dl, do_baselines, is_full_baselines) # prepare model for Captum try: interp_model = _CaptumModel(self.model) - captum_interp_cls = getattr(captum.attr, method)( - interp_model, **method_args - ) + captum_interp_cls = getattr(captum.attr, method)(interp_model, **method_args) if do_baselines: attributions = captum_interp_cls.attribute( tensor_inp, baselines=baselines, - target=( - tensor_tgt if self.config.task == "classification" else None - ), + target=(tensor_tgt if self.config.task == "classification" else None), **kwargs, ) else: attributions = captum_interp_cls.attribute( tensor_inp, - target=( - tensor_tgt if self.config.task == "classification" else None - ), + target=(tensor_tgt if self.config.task == "classification" else None), **kwargs, ) attributions = self._handle_categorical_embeddings_attributions( @@ -2215,10 +2054,7 @@ def explain( ) finally: self.model.train() - assert ( - attributions.shape[1] - == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim - ), ( + assert attributions.shape[1] == self.model.hparams.continuous_dim + self.model.hparams.categorical_dim, ( "Something went wrong. 
The number of features in the attributions" f" ({attributions.shape[1]}) does not match the number of features in" " the model" @@ -2332,9 +2168,7 @@ def cross_validate( is_callable_metric = True if isinstance(cv, BaseCrossValidator): - it = enumerate( - cv.split(train, y=train[self.config.target], groups=groups) - ) + it = enumerate(cv.split(train, y=train[self.config.target], groups=groups)) else: # when iterable is directly passed it = enumerate(cv) @@ -2361,32 +2195,22 @@ def cross_validate( model = self.prepare_model(datamodule, **prep_model_kwargs) else: # Preprocess the current fold data using the fitted transformers and save in datamodule - datamodule.train, _ = datamodule.preprocess_data( - train.iloc[train_idx], stage="inference" - ) - datamodule.validation, _ = datamodule.preprocess_data( - train.iloc[val_idx], stage="inference" - ) + datamodule.train, _ = datamodule.preprocess_data(train.iloc[train_idx], stage="inference") + datamodule.validation, _ = datamodule.preprocess_data(train.iloc[val_idx], stage="inference") # Train the model handle_oom = train_kwargs.pop("handle_oom", handle_oom) self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs) if return_oof or is_callable_metric: - preds = self.predict( - train.iloc[val_idx], include_input_features=False - ) + preds = self.predict(train.iloc[val_idx], include_input_features=False) oof_preds.append(preds) if is_callable_metric: - cv_metrics.append( - metric(train.iloc[val_idx][self.config.target], preds) - ) + cv_metrics.append(metric(train.iloc[val_idx][self.config.target], preds)) else: result = self.evaluate(train.iloc[val_idx], verbose=False) cv_metrics.append(result[0][metric]) if verbose: - logger.info( - f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}" - ) + logger.info(f"Fold {fold+1}/{cv.get_n_splits()} score: {cv_metrics[-1]}") self.model.reset_weights() return cv_metrics, oof_preds @@ -2420,11 +2244,7 @@ def _combine_predictions( if aggregate == "hard_voting": pred_df = pd.DataFrame( np.concatenate(pred_prob_l, axis=1), - columns=[ - f"{c}_probability_fold_{i}" - for i in range(len(pred_prob_l)) - for c in classes - ], + columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes], index=pred_idx, ) pred_df["prediction"] = classes[final_pred] @@ -2433,21 +2253,14 @@ def _combine_predictions( pred_df = pd.DataFrame( bagged_pred, # FIXME - columns=[ - f"{c}_probability" - for c in self.datamodule.label_encoder[0].classes_ - ], + columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_], index=pred_idx, ) pred_df["prediction"] = final_pred elif self.config.task == "regression": - pred_df = pd.DataFrame( - bagged_pred, columns=self.config.target, index=pred_idx - ) + pred_df = pd.DataFrame(bagged_pred, columns=self.config.target, index=pred_idx) else: - raise NotImplementedError( - f"Task {self.config.task} not supported for bagging" - ) + raise NotImplementedError(f"Task {self.config.task} not supported for bagging") return pred_df def bagging_predict( @@ -2519,30 +2332,23 @@ def bagging_predict( """ if weights is not None: - assert ( - len(weights) == cv.n_splits - ), "Number of weights should be equal to the number of folds" + assert len(weights) == cv.n_splits, "Number of weights should be equal to the number of folds" assert self.config.task in [ "classification", "regression", ], "Bagging is only available for classification and regression" if not callable(aggregate): assert aggregate in ["mean", "median", "min", "max", "hard_voting"], ( - 
"aggregate should be one of 'mean', 'median', 'min', 'max', or" - " 'hard_voting'" + "aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'" ) if self.config.task == "regression": - assert ( - aggregate != "hard_voting" - ), "hard_voting is only available for classification" + assert aggregate != "hard_voting", "hard_voting is only available for classification" cv = self._check_cv(cv) prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs) pred_prob_l = [] datamodule = None model = None - for fold, (train_idx, val_idx) in enumerate( - cv.split(train, y=train[self.config.target], groups=groups) - ): + for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)): if verbose: logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}") train_fold = train.iloc[train_idx] @@ -2552,18 +2358,12 @@ def bagging_predict( if datamodule is None: # Initialize datamodule and model in the first fold # uses train data from this fold to fit all transformers - datamodule = self.prepare_dataloader( - train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs - ) + datamodule = self.prepare_dataloader(train=train_fold, validation=val_fold, seed=42, **prep_dl_kwargs) model = self.prepare_model(datamodule, **prep_model_kwargs) else: # Preprocess the current fold data using the fitted transformers and save in datamodule - datamodule.train, _ = datamodule.preprocess_data( - train_fold, stage="inference" - ) - datamodule.validation, _ = datamodule.preprocess_data( - val_fold, stage="inference" - ) + datamodule.train, _ = datamodule.preprocess_data(train_fold, stage="inference") + datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference") # Train the model handle_oom = train_kwargs.pop("handle_oom", handle_oom) diff --git a/tests/test_common.py b/tests/test_common.py index b67bc529..d0c3e26c 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1248,6 +1248,7 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols, # best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist() # assert best_model.model._get_name() in best_models + @pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST) @pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)]) @pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]]) @@ -1299,7 +1300,7 @@ def test_str_repr( metrics_prob_inputs=None if custom_metrics is None else [False], loss=custom_loss, optimizer=custom_optimizer, - optimizer_params={} + optimizer_params={}, ) assert model_config_class._model_name in str(tabular_model) assert model_config_class._model_name in repr(tabular_model) From ed602ba4c8461e48ca8bb621a39085031b8cc2cc Mon Sep 17 00:00:00 2001 From: Manu Joseph Date: Mon, 25 Nov 2024 09:54:24 +0530 Subject: [PATCH 4/4] fixed some precommit errors --- src/pytorch_tabular/tabular_model.py | 53 ++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index d0bbea0f..217e7b30 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -1659,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None: """Prints a summary of the model. Args: - max_depth (int): The maximum depth to traverse the modules and displayed in the summary. - Defaults to -1, which means will display all the modules. 
+            max_depth (int): The maximum depth to traverse the modules to
+                display in the summary. Defaults to -1, which displays all
+                the modules.

         """
         if model is not None:
@@ -1682,8 +1683,9 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
         """Returns a summary of the model as a string.

         Args:
-            max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
-                Defaults to -1, which means will display all the modules.
+            max_depth (int): The maximum depth to traverse the modules to
+                display in the summary. Defaults to -1, which displays all
+                the modules.

         Returns:
             str: The summary of the model.
@@ -1699,12 +1701,13 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
             summary_str += "Config\n"
             summary_str += "-" * 100 + "\n"
             summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True)
-            summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument"
+            summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument"
             return summary_str

     def __str__(self) -> str:
         """Returns a readable summary of the TabularModel object."""
-        return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else self.config._model_name+'(Not Initialized)'})"
+        model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name + "(Not Initialized)"
+        return f"{self.__class__.__name__}(model={model_name})"

     def __repr__(self) -> str:
         """Returns an unambiguous representation of the TabularModel object."""
@@ -1792,7 +1795,8 @@ def _repr_html_(self):
         # Header (Main model name)
         uid = str(uuid.uuid4())
         model_status = "" if self.has_model else "(Not Initialized)"
-        header_html = f"
        {html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}
        "
+        model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name
+        header_html = f"
        {html.escape(model_name)}{model_status}
        "
         # Config Section
         config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)
@@ -1822,7 +1826,12 @@ def _generate_collapsible_section(self, title, content, uid, is_dict=False):
             )
         return f"""
-
+
+            ▶
+            {html.escape(title)}
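
A minimal usage sketch of the informative representations added in this series. The config values, column names, and model choice below are illustrative placeholders, not values taken from the patch; only the public TabularModel constructor and the __str__/__repr__/_repr_html_ hooks shown above are assumed.

# Hypothetical sketch: exercising the new __str__, __repr__, and _repr_html_.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

tabular_model = TabularModel(
    data_config=DataConfig(target=["target"], continuous_cols=["num_1", "num_2"], categorical_cols=["cat_1"]),
    model_config=CategoryEmbeddingModelConfig(task="classification"),
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(),
)

# __str__: short human-readable form, e.g.
# "TabularModel(model=CategoryEmbeddingModel(Not Initialized))"
print(tabular_model)

# __repr__: unambiguous form that embeds the resolved config as indented JSON.
print(repr(tabular_model))

# In Jupyter, evaluating the bare object triggers _repr_html_ and renders the
# collapsible "Model Config" and "Model Summary" sections.
tabular_model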
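
The exact markup emitted by _generate_collapsible_section is not reproduced here, so as a purely illustrative sketch of the same general pattern (not the HTML used by pytorch-tabular), a collapsible config section for a _repr_html_ hook can be built with a plain details element:

# Generic illustration of a collapsible section; NOT pytorch-tabular's markup.
import html
import json


def collapsible_section(title: str, content: dict) -> str:
    """Render `content` as a <details> block titled `title`."""
    body = html.escape(json.dumps(content, indent=2, default=str))
    return f"<details><summary>{html.escape(title)}</summary><pre>{body}</pre></details>"


# Example: collapsible_section("Model Config", {"task": "classification", "batch_size": 1024})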
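
For reference, the embedding-noise forward hook used by predict(..., test_time_augmentation=True) boils down to the loop below, written out standalone. Here predict_fn, dataloader, num_tta, and alpha are placeholders; only model.embedding_layer and the randn_like perturbation, both visible in the diff above, are taken from the patch, and the real method additionally handles classification aggregation and index alignment.

# Standalone sketch of the test-time-augmentation idea: perturb the embedding
# output with Gaussian noise, predict several times, and average the runs.
import torch


def predict_with_tta(model, predict_fn, dataloader, num_tta: int = 5, alpha: float = 0.1) -> torch.Tensor:
    """Average predictions over `num_tta` noisy passes; `predict_fn(batch)` returns a tensor."""

    def add_noise(module, inputs, output):
        # Same perturbation as the hook registered on model.embedding_layer in the patch.
        return output + alpha * torch.randn_like(output, memory_format=torch.contiguous_format)

    handle = model.embedding_layer.register_forward_hook(add_noise)
    try:
        runs = []
        for _ in range(num_tta):
            with torch.no_grad():
                preds = [predict_fn(batch) for batch in dataloader]
            runs.append(torch.cat(preds, dim=0))
    finally:
        # Always detach the hook so later predictions are noise-free.
        handle.remove()
    return torch.stack(runs, dim=0).mean(dim=0)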