Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,4 @@ docs/tutorials/pytorch-tabular-covertype/

# Pycharm
.idea/
test.ipynb
237 changes: 230 additions & 7 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
# For license information, see LICENSE.TXT
"""Tabular Model."""

import html
import inspect
import json
import os
import uuid
import warnings
from collections import defaultdict
from functools import partial
from pathlib import Path
from pprint import pformat
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import joblib
Expand All @@ -22,7 +26,9 @@
from pandas import DataFrame
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
GradientAccumulationScheduler,
)
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities.model_summary import summarize
from rich import print as rich_print
Expand All @@ -41,7 +47,11 @@
)
from pytorch_tabular.config.config import InferredConfig
from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel
from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer
from pytorch_tabular.models.common.layers.embeddings import (
Embedding1dLayer,
Embedding2dLayer,
PreEncoded1dLayer,
)
from pytorch_tabular.tabular_datamodule import TabularDatamodule
from pytorch_tabular.utils import (
OOMException,
Expand Down Expand Up @@ -262,7 +272,9 @@ def _setup_experiment_tracking(self):
"""Sets up the Experiment Tracking Framework according to the choices made in the Experimentconfig."""
if self.config.log_target == "tensorboard":
self.logger = pl.loggers.TensorBoardLogger(
name=self.run_name, save_dir=self.config.project_name, version=self.uid
name=self.run_name,
save_dir=self.config.project_name,
version=self.uid,
)
elif self.config.log_target == "wandb":
self.logger = pl.loggers.WandbLogger(
Expand Down Expand Up @@ -1647,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None:
"""Prints a summary of the model.

Args:
max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
Defaults to -1, which means will display all the modules.
max_depth (int): The maximum depth to traverse the modules and
displayed in the summary. Defaults to -1, which means will
display all the modules.

"""
if model is not None:
Expand All @@ -1666,8 +1679,215 @@ def summary(self, model=None, max_depth: int = -1) -> None:
"been initialized or passed in as an argument[/bold red]"
)

def ret_summary(self, model=None, max_depth: int = -1) -> str:
"""Returns a summary of the model as a string.

Args:
max_depth (int): The maximum depth to traverse the modules and
displayed in the summary. Defaults to -1, which means will
display all the modules.

Returns:
str: The summary of the model.

"""
if model is not None:
return str(summarize(model, max_depth=max_depth))
elif self.has_model:
return str(summarize(self.model, max_depth=max_depth))
else:
summary_str = f"{self.__class__.__name__}\n"
summary_str += "-" * 100 + "\n"
summary_str += "Config\n"
summary_str += "-" * 100 + "\n"
summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True)
summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument"
return summary_str

def __str__(self) -> str:
return self.summary()
"""Returns a readable summary of the TabularModel object."""
model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name + "(Not Initialized)"
return f"{self.__class__.__name__}(model={model_name})"

def __repr__(self) -> str:
"""Returns an unambiguous representation of the TabularModel object."""
config_str = json.dumps(OmegaConf.to_container(self.config, resolve=True), indent=4)
ret_str = f"{self.__class__.__name__}(\n"
if self.has_model:
ret_str += f" model={self.model.__class__.__name__},\n"
else:
ret_str += f" model={self.config._model_name} (Not Initialized),\n"
ret_str += f" config={config_str},\n"
return ret_str

def _repr_html_(self):
"""Generate an HTML representation for Jupyter Notebook."""
css = """
<style>
.main-container {
font-family: Arial, sans-serif;
font-size: 14px;
border: 1px dashed #ccc;
padding: 10px;
margin: 10px;
background-color: #f9f9f9;
}
.header {
background-color: #e8f4fc;
padding: 5px;
font-weight: bold;
text-align: center;
border-bottom: 1px solid #ccc;
}
.section {
margin: 10px 0;
padding: 10px;
border: 1px solid #ccc;
background-color: #ffffff;
}
.step {
border: 1px solid #ccc;
background-color: #f0f8ff;
margin: 5px 0;
padding: 5px;
}
.sub-step {
margin-left: 20px;
border: 1px solid #ddd;
background-color: #f9f9f9;
padding: 5px;
}
.toggle-button {
cursor: pointer;
font-size: 12px;
margin-right: 5px;
}
.toggle-button:hover {
color: #0056b3;
}
.hidden {
display: none;
}
table {
width: 100%;
border-collapse: collapse;
}
table, th, td {
border: 1px solid black;
}
th, td {
padding: 5px;
text-align: left;
}
</style>
<script>
function toggleVisibility(id) {
var element = document.getElementById(id);
if (element.classList.contains('hidden')) {
element.classList.remove('hidden');
} else {
element.classList.add('hidden');
}
}
</script>
"""

# Header (Main model name)
uid = str(uuid.uuid4())
model_status = "" if self.has_model else "(Not Initialized)"
model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name
header_html = f"<div class='header'>{html.escape(model_name)}{model_status}</div>"

# Config Section
config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)

# Summary Section
summary_html = (
""
if not self.has_model
else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid)
)

# Combine sections
return f"""
{css}
<div class='main-container'>
{header_html}
{config_html}
{summary_html}
</div>
"""

def _generate_collapsible_section(self, title, content, uid, is_dict=False):
container_id = title.lower().replace(" ", "_") + uid
if is_dict:
content = self._generate_nested_collapsible_sections(
OmegaConf.to_container(content, resolve=True), container_id
)
return f"""
<div>
<span
class="toggle-button"
onclick="toggleVisibility('{container_id}')"
>
&#9654;
</span>
<strong>{html.escape(title)}</strong>
<div id="{container_id}" class="hidden section">
{content}
</div>
</div>
"""

def _generate_nested_collapsible_sections(self, content, parent_id):
html_content = ""
for key, value in content.items():
if isinstance(value, dict):
nested_id = f"{parent_id}_{key}".replace(" ", "_")
nested_id = nested_id + str(uuid.uuid4())
nested_content = self._generate_nested_collapsible_sections(value, nested_id)
html_content += f"""
<div>
<span
class="toggle-button"
onclick="toggleVisibility('{nested_id}')"
>
&#9654;
</span>
<strong>{html.escape(key)}</strong>
<div id="{nested_id}" class="hidden section">
{nested_content}
</div>
</div>
"""
else:
html_content += f"<div><strong>{html.escape(key)}:</strong> {html.escape(str(value))}</div>"
return html_content

def _generate_model_summary_table(self):
model_summary = summarize(self.model, max_depth=1)
table_html = """
<table>
<tr>
<th><b>Layer</b></th>
<th><b>Type</b></th>
<th><b>Params</b></th>
<th><b>In sizes</b></th>
<th><b>Out sizes</b></th>
</tr>
"""
for name, layer in model_summary._layer_summary.items():
table_html += f"""
<tr>
<td>{html.escape(name)}</td>
<td>{html.escape(layer.layer_type)}</td>
<td>{html.escape(str(layer.num_parameters))}</td>
<td>{html.escape(str(layer.in_size))}</td>
<td>{html.escape(str(layer.out_size))}</td>
</tr>
"""
table_html += "</table>"
return table_html

def feature_importance(self) -> DataFrame:
"""Returns the feature importance of the model as a pandas DataFrame."""
Expand Down Expand Up @@ -1998,7 +2218,10 @@ def cross_validate(
# Initialize datamodule and model in the first fold
# uses train data from this fold to fit all transformers
datamodule = self.prepare_dataloader(
train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
train=train.iloc[train_idx],
validation=train.iloc[val_idx],
seed=42,
**prep_dl_kwargs,
)
model = self.prepare_model(datamodule, **prep_model_kwargs)
else:
Expand Down
63 changes: 62 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

MODEL_CONFIG_SAVE_TEST = [
(CategoryEmbeddingModelConfig, {"layers": "10-20"}),
(AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}),
(GANDALFConfig, {}),
(NodeConfig, {"num_trees": 100, "depth": 2}),
(TabNetModelConfig, {"n_a": 2, "n_d": 2}),
]
Expand Down Expand Up @@ -1247,3 +1247,64 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
# # there may be multiple models with the same score
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
# assert best_model.model._get_name() in best_models


@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]])
@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()])
@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"])
def test_str_repr(
regression_data,
model_config_class,
continuous_cols,
categorical_cols,
custom_metrics,
custom_loss,
custom_optimizer,
):
(train, test, target) = regression_data
data_config = DataConfig(
target=target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
)
model_config_class, model_config_params = model_config_class
model_config_params["task"] = "regression"
model_config = model_config_class(**model_config_params)
trainer_config = TrainerConfig(
max_epochs=3,
checkpoints=None,
early_stopping=None,
accelerator="cpu",
fast_dev_run=True,
)
optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
)
assert "Not Initialized" in str(tabular_model)
assert "Not Initialized" in repr(tabular_model)
assert "Model Summary" not in tabular_model._repr_html_()
assert "Model Config" in tabular_model._repr_html_()
assert "config" in tabular_model.__repr__()
assert "config" not in str(tabular_model)
tabular_model.fit(
train=train,
metrics=custom_metrics,
metrics_prob_inputs=None if custom_metrics is None else [False],
loss=custom_loss,
optimizer=custom_optimizer,
optimizer_params={},
)
assert model_config_class._model_name in str(tabular_model)
assert model_config_class._model_name in repr(tabular_model)
assert "Model Summary" in tabular_model._repr_html_()
assert "Model Config" in tabular_model._repr_html_()
assert "config" in tabular_model.__repr__()
assert model_config_class._model_name in tabular_model._repr_html_()
Loading