diff --git a/.gitignore b/.gitignore index 79707d0f..ecb1e054 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ docs/tutorials/pytorch-tabular-covertype/ # Pycharm .idea/ +test.ipynb diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 11234934..217e7b30 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -3,12 +3,16 @@ # For license information, see LICENSE.TXT """Tabular Model.""" +import html import inspect +import json import os +import uuid import warnings from collections import defaultdict from functools import partial from pathlib import Path +from pprint import pformat from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import joblib @@ -22,7 +26,9 @@ from pandas import DataFrame from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import RichProgressBar -from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler +from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( + GradientAccumulationScheduler, +) from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities.model_summary import summarize from rich import print as rich_print @@ -41,7 +47,11 @@ ) from pytorch_tabular.config.config import InferredConfig from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel -from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer +from pytorch_tabular.models.common.layers.embeddings import ( + Embedding1dLayer, + Embedding2dLayer, + PreEncoded1dLayer, +) from pytorch_tabular.tabular_datamodule import TabularDatamodule from pytorch_tabular.utils import ( OOMException, @@ -262,7 +272,9 @@ def _setup_experiment_tracking(self): """Sets up the Experiment Tracking Framework according to the choices made in the Experimentconfig.""" if self.config.log_target == "tensorboard": self.logger = pl.loggers.TensorBoardLogger( - name=self.run_name, save_dir=self.config.project_name, version=self.uid + name=self.run_name, + save_dir=self.config.project_name, + version=self.uid, ) elif self.config.log_target == "wandb": self.logger = pl.loggers.WandbLogger( @@ -1647,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None: """Prints a summary of the model. Args: - max_depth (int): The maximum depth to traverse the modules and displayed in the summary. - Defaults to -1, which means will display all the modules. + max_depth (int): The maximum depth to traverse the modules and + displayed in the summary. Defaults to -1, which means will + display all the modules. """ if model is not None: @@ -1666,8 +1679,215 @@ def summary(self, model=None, max_depth: int = -1) -> None: "been initialized or passed in as an argument[/bold red]" ) + def ret_summary(self, model=None, max_depth: int = -1) -> str: + """Returns a summary of the model as a string. + + Args: + max_depth (int): The maximum depth to traverse the modules and + displayed in the summary. Defaults to -1, which means will + display all the modules. + + Returns: + str: The summary of the model. + + """ + if model is not None: + return str(summarize(model, max_depth=max_depth)) + elif self.has_model: + return str(summarize(self.model, max_depth=max_depth)) + else: + summary_str = f"{self.__class__.__name__}\n" + summary_str += "-" * 100 + "\n" + summary_str += "Config\n" + summary_str += "-" * 100 + "\n" + summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True) + summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument" + return summary_str + def __str__(self) -> str: - return self.summary() + """Returns a readable summary of the TabularModel object.""" + model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name + "(Not Initialized)" + return f"{self.__class__.__name__}(model={model_name})" + + def __repr__(self) -> str: + """Returns an unambiguous representation of the TabularModel object.""" + config_str = json.dumps(OmegaConf.to_container(self.config, resolve=True), indent=4) + ret_str = f"{self.__class__.__name__}(\n" + if self.has_model: + ret_str += f" model={self.model.__class__.__name__},\n" + else: + ret_str += f" model={self.config._model_name} (Not Initialized),\n" + ret_str += f" config={config_str},\n" + return ret_str + + def _repr_html_(self): + """Generate an HTML representation for Jupyter Notebook.""" + css = """ + + + """ + + # Header (Main model name) + uid = str(uuid.uuid4()) + model_status = "" if self.has_model else "(Not Initialized)" + model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name + header_html = f"
| Layer | +Type | +Params | +In sizes | +Out sizes | +
|---|---|---|---|---|
| {html.escape(name)} | +{html.escape(layer.layer_type)} | +{html.escape(str(layer.num_parameters))} | +{html.escape(str(layer.in_size))} | +{html.escape(str(layer.out_size))} | +