diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py
index 938e52c5..4c98c4a0 100644
--- a/src/pytorch_tabular/models/base_model.py
+++ b/src/pytorch_tabular/models/base_model.py
@@ -7,6 +7,7 @@
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import pandas as pd
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
@@ -506,6 +507,18 @@ def reset_weights(self):
         reset_all_weights(self.head)
         reset_all_weights(self.embedding_layer)
 
+    def feature_importance(self):
+        if hasattr(self.backbone, "feature_importance_"):
+            importance_df = pd.DataFrame(
+                {
+                    "Features": self.hparams.categorical_cols + self.hparams.continuous_cols,
+                    "importance": self.backbone.feature_importance_,
+                }
+            )
+            return importance_df
+        else:
+            raise ValueError("Feature Importance unavailable for this model.")
+
 
 class _GenericModel(BaseModel):
     def __init__(
diff --git a/src/pytorch_tabular/models/ft_transformer/ft_transformer.py b/src/pytorch_tabular/models/ft_transformer/ft_transformer.py
index 1f68a4ed..fba4e579 100644
--- a/src/pytorch_tabular/models/ft_transformer/ft_transformer.py
+++ b/src/pytorch_tabular/models/ft_transformer/ft_transformer.py
@@ -5,7 +5,6 @@
 import math
 from collections import OrderedDict
 
-import pandas as pd
 import torch
 import torch.nn as nn
 from omegaconf import DictConfig
@@ -116,7 +115,7 @@ def _calculate_feature_importance(self):
         for attn_weights in self.attention_weights_:
             self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1)
         self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[:, :-1]
-        self.feature_importance_ = self.local_feature_importance.mean(dim=0)
+        self.feature_importance_ = self.local_feature_importance.mean(dim=0).detach().cpu().numpy()
         # self.feature_importance_count_+=attn_weights.shape[0]
 
 
@@ -146,12 +145,6 @@ def _build_network(self):
 
     def feature_importance(self):
         if self.hparams.attn_feature_importance:
-            importance_df = pd.DataFrame(
-                {
-                    "Features": self.hparams.categorical_cols + self.hparams.continuous_cols,
-                    "importance": self.backbone.feature_importance_.detach().cpu().numpy(),
-                }
-            )
-            return importance_df
+            return super().feature_importance()
         else:
             raise ValueError("If you want Feature Importance, `attn_feature_weights` should be `True`.")
diff --git a/src/pytorch_tabular/models/gate/gate_model.py b/src/pytorch_tabular/models/gate/gate_model.py
index faad0a98..00b9a1bf 100644
--- a/src/pytorch_tabular/models/gate/gate_model.py
+++ b/src/pytorch_tabular/models/gate/gate_model.py
@@ -129,6 +129,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         tree_outputs = tree_outputs.permute(1, 2, 0)
         return tree_outputs
 
+    @property
+    def feature_importance_(self):
+        return self.gflus.feature_mask_function(self.gflus.feature_masks).sum(dim=0).detach().cpu().numpy()
+
 
 class CustomHead(nn.Module):
     """Custom Head for GATE.
diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py
index 8b1870d4..1c4a9c92 100644
--- a/src/pytorch_tabular/tabular_model.py
+++ b/src/pytorch_tabular/tabular_model.py
@@ -1379,3 +1379,6 @@ def summary(self, max_depth: int = -1) -> None:
 
     def __str__(self) -> str:
         return self.summary()
+
+    def feature_importance(self):
+        return self.model.feature_importance()
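
Taken together, the patch moves the feature-importance DataFrame construction out of the FT-Transformer model into `BaseModel`, has each backbone expose a numpy `feature_importance_` attribute (FT-Transformer converts in `_calculate_feature_importance`, GATE via a new property), and surfaces the result on `TabularModel.feature_importance()`. A minimal usage sketch of the new API follows; the column names, config values, and the `train` DataFrame are illustrative assumptions, not part of the diff (note that for FT-Transformer, `attn_feature_importance` must be enabled or the model raises):

```python
# Sketch: querying feature importance after fitting a GATE model.
# Assumes `train` is a pandas DataFrame containing the columns listed below;
# all column names and hyperparameter values here are illustrative only.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig

data_config = DataConfig(
    target=["target"],
    continuous_cols=["feature_a", "feature_b"],
    categorical_cols=["feature_c"],
)
model_config = GatedAdditiveTreeEnsembleConfig(task="classification")
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)
tabular_model.fit(train=train)

# New in this diff: a DataFrame with "Features" and "importance" columns,
# delegated TabularModel -> BaseModel -> backbone.feature_importance_.
importance_df = tabular_model.feature_importance()
print(importance_df.sort_values("importance", ascending=False))
```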