In [1]:
import torch
from torch import nn
from models import VariableSelectionNetwork, MyModel

In [2]:
# Configuration: toy batch dimensions plus one scalar -> vector embedder per
# tabular feature.
SEED = 0
# Seed the global torch generator so the nn.Linear initialisation below and
# every later random draw (e.g. torch.randn) are reproducible on
# Restart & Run All.
torch.manual_seed(SEED)

batch_size = 5
num_tabular_features = 16
embedding_size = 4
# One independent nn.Linear(1, embedding_size) per feature, so each scalar
# column gets its own learned embedding.
embedder_array = nn.ModuleList(
    [nn.Linear(1, embedding_size) for _ in range(num_tabular_features)]
)

In [3]:
# Random input batch of shape (batch_size, num_tabular_features).
x = torch.randn((batch_size, num_tabular_features))
x.size()

torch.Size([5, 16])

In [4]:
# Embed every feature column with its own Linear layer. Slicing with
# `col:col + 1` keeps a trailing size-1 dimension, which matches the
# nn.Linear(1, embedding_size) input width.
embedd_input = [
    embedder_array[col](x[..., col : col + 1])
    for col in range(num_tabular_features)
]

In [5]:
# Stack the per-feature embeddings along a new second-to-last axis: a list of
# num_tabular_features tensors of shape (batch_size, embedding_size) becomes a
# single (batch_size, num_tabular_features, embedding_size) tensor.
res = torch.stack(embedd_input, dim=-2)
res.size()

torch.Size([5, 16, 4])

In [6]:
# Variable selection over the stacked embeddings: num_inputs features, each
# embedding_size wide.
# NOTE(review): hidden_dim was previously described as "the output size".
# Here hidden_dim == num_inputs == 16, so the (5, 16) shapes printed after
# the forward pass cannot tell the two apart — confirm against
# models.VariableSelectionNetwork.
vsn = VariableSelectionNetwork(
    input_dim=embedding_size,
    num_inputs=num_tabular_features,
    hidden_dim=num_tabular_features,  # reportedly the output width (see note above)
)

In [7]:
# Run the selection network on the stacked embeddings; it returns a pair.
# NOTE(review): `weights` is presumably one selection weight per feature
# (its shape matches (batch_size, num_tabular_features)) — confirm in models.py.
res2, weights = vsn(res)
print(res2.shape, weights.shape, sep="\n")

torch.Size([5, 16])
torch.Size([5, 16])


In [8]:
# End-to-end model that takes the raw (batch_size, num_tabular_features) input.
# NOTE(review): judging by the forward-pass output, MyModel returns a
# (batch_size, output_size) tensor plus a (batch_size, num_tabular_features)
# weight tensor — confirm in models.py.
model = MyModel(
    embedding_size=embedding_size,
    num_tabular_features=num_tabular_features,
    output_size=3,
)

In [9]:
# Forward pass on the random batch. The model returns a 2-tuple, which is
# unpacked into named results and then displayed unchanged.
predictions, feature_weights = model(x)
predictions, feature_weights

(tensor([[-0.0671,  0.0328, -0.1152],
         [-0.0656,  0.0320, -0.1121],
         [-0.0651,  0.0276, -0.1138],
         [-0.0732,  0.0315, -0.1128],
         [-0.0696,  0.0363, -0.1079]], grad_fn=<SqueezeBackward1>),
 tensor([[0.0134, 0.1721, 0.3118, 0.0079, 0.0205, 0.0109, 0.0770, 0.0174, 0.0567,
          0.0557, 0.0639, 0.0266, 0.0601, 0.0122, 0.0679, 0.0260],
         [0.0264, 0.1182, 0.2391, 0.0093, 0.0304, 0.0054, 0.1685, 0.0430, 0.0577,
          0.0520, 0.0209, 0.0395, 0.0170, 0.0155, 0.0988, 0.0583],
         [0.0201, 0.0516, 0.1364, 0.0052, 0.0214, 0.0060, 0.1461, 0.0353, 0.1128,
          0.1131, 0.0376, 0.0317, 0.0773, 0.0271, 0.1267, 0.0515],
         [0.0165, 0.1964, 0.2246, 0.0208, 0.0301, 0.0131, 0.0643, 0.0084, 0.0409,
          0.0159, 0.1040, 0.1175, 0.0818, 0.0175, 0.0347, 0.0135],
         [0.0242, 0.1945, 0.0246, 0.0343, 0.0134, 0.0059, 0.1625, 0.0204, 0.0248,
          0.0262, 0.2130, 0.1376, 0.0490, 0.0186, 0.0229, 0.0282]],
        grad_fn=<SoftmaxBackward0>