I'm wondering how to use Integrated Gradients with the following model, since it has embedding layers.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TabularModel(nn.Module):
    def __init__(self, embedding_sizes, n_cont):
        super().__init__()
        # One embedding table per categorical column: (num_categories, embedding_dim)
        self.embeddings = nn.ModuleList([nn.Embedding(categories, size) for categories, size in embedding_sizes])
        n_emb = sum(e.embedding_dim for e in self.embeddings)
        self.n_emb, self.n_cont = n_emb, n_cont
        self.lin1 = nn.Linear(self.n_emb + self.n_cont, 100)
        self.lin2 = nn.Linear(100, 50)
        self.lin3 = nn.Linear(50, 5)
        self.bn1 = nn.BatchNorm1d(self.n_cont)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d(50)
        self.emb_drop = nn.Dropout(0.2)
        self.drops = nn.Dropout(0.1)

    def forward(self, x_cat, x_cont):
        # x_cat: integer category indices, shape (batch, n_categorical_columns)
        x = [e(x_cat[:, i]) for i, e in enumerate(self.embeddings)]
        x = torch.cat(x, 1)
        x = self.emb_drop(x)
        # x_cont: float continuous features, shape (batch, n_cont)
        x2 = self.bn1(x_cont)
        x = torch.cat([x, x2], 1)
        x = F.relu(self.lin1(x))
        x = self.drops(x)
        x = self.bn2(x)
        x = F.relu(self.lin2(x))
        x = self.drops(x)
        x = self.bn3(x)
        x = self.lin3(x)
        return x
```
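One common way around the integer-index problem, sketched here in plain PyTorch rather than through Captum's API, is to attribute to the *outputs* of the embedding layers (which are continuous) together with `x_cont`. Captum's `LayerIntegratedGradients` is built for this kind of layer attribution; this sketch instead hand-rolls the IG path integral with a midpoint Riemann sum (Captum defaults to Gauss-Legendre) so each step is visible. Everything here beyond the question's architecture is a hypothetical assumption: the `embed` / `forward_from_embeddings` split of `forward`, the `integrated_gradients` helper, all-zero baselines, target class 0, and the toy embedding sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TabularModel(nn.Module):
    # Same layers as the question's model; forward is split (hypothetically)
    # so the dense part can be called directly on float embedding outputs.
    def __init__(self, embedding_sizes, n_cont):
        super().__init__()
        self.embeddings = nn.ModuleList([nn.Embedding(c, s) for c, s in embedding_sizes])
        n_emb = sum(e.embedding_dim for e in self.embeddings)
        self.lin1 = nn.Linear(n_emb + n_cont, 100)
        self.lin2 = nn.Linear(100, 50)
        self.lin3 = nn.Linear(50, 5)
        self.bn1 = nn.BatchNorm1d(n_cont)
        self.bn2 = nn.BatchNorm1d(100)
        self.bn3 = nn.BatchNorm1d(50)
        self.emb_drop = nn.Dropout(0.2)
        self.drops = nn.Dropout(0.1)

    def embed(self, x_cat):
        # Integer category indices -> concatenated float embeddings.
        return torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embeddings)], 1)

    def forward_from_embeddings(self, emb, x_cont):
        x = torch.cat([self.emb_drop(emb), self.bn1(x_cont)], 1)
        x = self.bn2(self.drops(F.relu(self.lin1(x))))
        x = self.bn3(self.drops(F.relu(self.lin2(x))))
        return self.lin3(x)

    def forward(self, x_cat, x_cont):
        return self.forward_from_embeddings(self.embed(x_cat), x_cont)


def integrated_gradients(f, inputs, baselines, target, n_steps=64):
    # Midpoint Riemann approximation of IG along the straight line from each
    # baseline to its input; returns one attribution tensor per input.
    total = [torch.zeros_like(x) for x in inputs]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        scaled = [(b + alpha * (x - b)).detach().requires_grad_(True)
                  for x, b in zip(inputs, baselines)]
        out = f(*scaled)[:, target].sum()
        grads = torch.autograd.grad(out, scaled)
        for t, g in zip(total, grads):
            t += g  # accumulate path gradients in place
    return tuple((x - b) * t / n_steps for x, b, t in zip(inputs, baselines, total))


torch.manual_seed(0)
# eval(): dropout off, BatchNorm uses running stats, so f is fixed along the path.
model = TabularModel([(10, 3), (30, 4), (5, 2), (40, 5)], n_cont=6).eval()
x_cat = torch.randint(0, 5, (8, 4))
x_cont = torch.randn(8, 6)

with torch.no_grad():
    emb = model.embed(x_cat)

attr_emb, attr_cont = integrated_gradients(
    model.forward_from_embeddings,
    (emb, x_cont),
    (torch.zeros_like(emb), torch.zeros_like(x_cont)),
    target=0,
    n_steps=200,
)

# One score per categorical column: sum each embedding's slice of attributions.
sizes = [e.embedding_dim for e in model.embeddings]
per_feature = torch.stack([a.sum(1) for a in attr_emb.split(sizes, dim=1)], 1)
print(per_feature)
print(attr_cont)
```

By the completeness property of IG, the attributions for each row should sum to roughly the change in the target logit between the baseline and the actual input, which makes a cheap sanity check on the step count.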
The main issue is passing the inputs to Integrated Gradients:

```python
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)
ig.attribute(inputs)
```

This raises the following error:
```
AssertionError: Baseline can be provided as a tensor for just one input and broadcasted to the batch or input and baseline must have the same shape or the baseline corresponding to each input tensor must be a scalar. Found baseline: tensor([[8.0000e+00, 1.0138e+03, 8.2027e+01, ..., 1.4000e+01, 0.0000e+00,
         0.0000e+00],
        [8.0000e+00, 1.0161e+03, 8.7000e+01, ..., 6.6700e+01, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0226e+03, 4.8000e+01, ..., 2.4700e+01, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 1.0208e+03, 8.2000e+01, ..., 1.2400e+01, 0.0000e+00,
         0.0000e+00],
        [7.0000e+00, 1.0142e+03, 9.8000e+01, ..., 1.1000e+00, 1.0000e+00,
         0.0000e+00],
        [0.0000e+00, 1.0230e+03, 7.6000e+01, ..., 3.6900e+01, 0.0000e+00,
         0.0000e+00]]) and input: tensor([[ 4,  8,  0,  3],
        [ 5,  8,  1, 15],
        [ 2, 13,  0, 29],
        ...,
        [ 5,  1,  0, 21],
        [ 0, 23,  0,  5],
        [ 5,  5,  4, 11]])
```
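One plausible reading of this error, assuming `inputs` was unpacked into two positional arguments (for example `ig.attribute(x_cat, x_cont)`), is that the second tensor got bound to `baselines`: Captum documents the signature as `IntegratedGradients.attribute(inputs, baselines=None, target=None, additional_forward_args=None, ...)`, and the message does report the float continuous tensor as the "baseline" and the integer category tensor as the "input". The pure-Python sketch here only illustrates that argument binding; `attribute` is a hypothetical stand-in that copies the leading parameter order and performs no attribution, and the strings are placeholders for the real tensors.

```python
# Hypothetical stand-in mirroring only the leading parameter order of
# Captum's IntegratedGradients.attribute; it just reports how arguments bind.
def attribute(inputs, baselines=None, target=None, additional_forward_args=None):
    return {
        "inputs": inputs,
        "baselines": baselines,
        "target": target,
        "additional_forward_args": additional_forward_args,
    }


# Placeholders standing in for the two model inputs.
x_cat = "x_cat: integer indices, shape (N, 4)"
x_cont = "x_cont: float features, shape (N, n_cont)"

# Two positional arguments: the continuous tensor is taken as the baseline.
wrong = attribute(x_cat, x_cont)
print("baselines ->", wrong["baselines"])

# A tuple of inputs keeps `baselines` at its default.
right = attribute((x_cat, x_cont))
print("inputs ->", right["inputs"], "| baselines ->", right["baselines"])
```

Fixing the binding alone is unlikely to be enough for this model: IG interpolates each input between its baseline and its actual value, and the integer indices in `x_cat` cannot be meaningfully interpolated before the embedding lookup, which is why attribution is usually taken at the embedding outputs instead (the case Captum's `LayerIntegratedGradients` targets).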