Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/pytorch_tabular/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
Expand Down Expand Up @@ -506,6 +507,18 @@ def reset_weights(self):
reset_all_weights(self.head)
reset_all_weights(self.embedding_layer)

def feature_importance(self):
if hasattr(self.backbone, "feature_importance_"):
importance_df = pd.DataFrame(
{
"Features": self.hparams.categorical_cols + self.hparams.continuous_cols,
"importance": self.backbone.feature_importance_.detach().cpu().numpy(),
}
)
return importance_df
else:
raise ValueError("Feature Importance unavailable for this model.")


class _GenericModel(BaseModel):
def __init__(
Expand Down
11 changes: 2 additions & 9 deletions src/pytorch_tabular/models/ft_transformer/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import math
from collections import OrderedDict

import pandas as pd
import torch
import torch.nn as nn
from omegaconf import DictConfig
Expand Down Expand Up @@ -116,7 +115,7 @@ def _calculate_feature_importance(self):
for attn_weights in self.attention_weights_:
self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1)
self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[:, :-1]
self.feature_importance_ = self.local_feature_importance.mean(dim=0)
self.feature_importance_ = self.local_feature_importance.mean(dim=0).detach().cpu().numpy()
# self.feature_importance_count_+=attn_weights.shape[0]


Expand Down Expand Up @@ -146,12 +145,6 @@ def _build_network(self):

def feature_importance(self):
if self.hparams.attn_feature_importance:
importance_df = pd.DataFrame(
{
"Features": self.hparams.categorical_cols + self.hparams.continuous_cols,
"importance": self.backbone.feature_importance_.detach().cpu().numpy(),
}
)
return importance_df
return super().feature_importance()
else:
raise ValueError("If you want Feature Importance, `attn_feature_weights` should be `True`.")
4 changes: 4 additions & 0 deletions src/pytorch_tabular/models/gate/gate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
tree_outputs = tree_outputs.permute(1, 2, 0)
return tree_outputs

@property
def feature_importance_(self):
return self.gflus.feature_mask_function(self.gflus.feature_masks).sum(dim=0).detach().cpu().numpy()


class CustomHead(nn.Module):
"""Custom Head for GATE.
Expand Down
3 changes: 3 additions & 0 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,3 +1379,6 @@ def summary(self, max_depth: int = -1) -> None:

def __str__(self) -> str:
return self.summary()

def feature_importance(self):
return self.model.feature_importance()