In [18]:
import torch
from torch import nn
from models import InterpretableMultiHeadAttention

In [19]:
# Seed the global RNG so the random inputs and randomly initialised layers
# created below are reproducible under "Restart Kernel -> Run All".
random_seed = 0
torch.manual_seed(random_seed)

# Toy dimensions reused by the shape checks in the following cells.
ts_embedding_size = 8  # also passed as d_model to the attention layer
n_head = 4
batch_size = 12
number_features = 6

# Attention layer under test; d_model is tied to the embedding width above.
att_layer = InterpretableMultiHeadAttention(
    n_head=n_head, d_model=ts_embedding_size, dropout=0.5
)

In [20]:
# Random stand-in input of shape (batch_size, number_features, ts_embedding_size).
x = torch.randn((batch_size, number_features, ts_embedding_size))
print(x.size())

torch.Size([12, 6, 8])


In [21]:
# Self-attention: the same tensor is supplied for all three positional inputs.
output, att = att_layer(x, x, x)
# Show the shape of the attended output, then of the returned attention tensor.
print(output.shape, att.shape, sep="\n")

torch.Size([12, 6, 8])
torch.Size([12, 6, 4, 6])


In [22]:
# Display the attention slice for batch element 0 at index 0 of the third axis.
# NOTE(review): the third axis has size 4 == n_head, so it is presumably the
# head axis — confirm against InterpretableMultiHeadAttention.
att[0, :, 0, :]

tensor([[0.1004, 0.1144, 0.0622, 0.2222, 0.3399, 0.1609],
        [0.1868, 0.2069, 0.1125, 0.1671, 0.2055, 0.1213],
        [0.2709, 0.3267, 0.0943, 0.1094, 0.1416, 0.0571],
        [0.0945, 0.1131, 0.0469, 0.2225, 0.3834, 0.1397],
        [0.0788, 0.0875, 0.0580, 0.2333, 0.3551, 0.1872],
        [0.1196, 0.0808, 0.6350, 0.0382, 0.0131, 0.1133]],
       grad_fn=<SliceBackward0>)

In [23]:
num_ts_features = 2
# One independent scalar -> embedding projection per time-series feature.
ts_embedders = nn.ModuleList(
    nn.Linear(in_features=1, out_features=ts_embedding_size)
    for _ in range(num_ts_features)
)

In [24]:
timesteps = 100
# Random stand-in series: (batch_size, timesteps, num_ts_features).
x_ts = torch.randn((batch_size, timesteps, num_ts_features))
print(x_ts.size())
# Embed every feature with its own linear layer (keeping a trailing axis of
# size 1 as the layer input), then join the embeddings along the last axis.
output = torch.cat(
    [
        ts_embedders[idx](x_ts[..., idx : idx + 1])
        for idx in range(num_ts_features)
    ],
    dim=-1,
)
output.shape

torch.Size([12, 100, 2])


torch.Size([12, 100, 16])

In [25]:
lstm_hidden_size = ts_embedding_size
# The LSTM consumes the per-feature embeddings concatenated on the last axis,
# so its input width is derived from the configuration rather than hard-coded
# (2 features * 8 dims = 16 with the current settings).
lstm = nn.LSTM(
    input_size=num_ts_features * ts_embedding_size,
    hidden_size=lstm_hidden_size,
    num_layers=2,
    batch_first=True,  # inputs are laid out as (batch, time, features)
    dropout=0.5,
)

In [26]:
# Run the embedded series through the LSTM; it returns the per-step outputs
# plus a (hidden state, cell state) pair.
output2, (hn, cn) = lstm(output)
print(output2.shape, hn.shape, cn.shape, sep="\n")

torch.Size([12, 100, 8])
torch.Size([2, 12, 8])
torch.Size([2, 12, 8])


In [27]:
# Lookup table with 3 rows, each as wide as the time-series embedding.
emb = nn.Embedding(num_embeddings=3, embedding_dim=ts_embedding_size)

In [28]:
num_output = 3
dims = (batch_size, num_output)

# Every index 0..num_output-1 must be a valid row of the embedding table,
# otherwise the lookup below fails with an index error.
assert num_output <= emb.num_embeddings, "emb has too few rows for num_output"

# Each row holds the indices 0..num_output-1. Built with one vectorised call
# instead of filling the columns one at a time in a Python loop.
seeder = torch.arange(num_output, dtype=torch.long).repeat(batch_size, 1)
assert seeder.shape == dims
print(seeder.shape)

# Look up one embedding vector per (sample, output slot).
seeder_out = emb(seeder)
print(seeder_out.shape)

torch.Size([12, 3])
torch.Size([12, 3, 8])


In [29]:
# Append the seed embeddings after the LSTM outputs along the second-to-last axis.
attn_in = torch.cat((output2, seeder_out), dim=-2)
print(attn_in.size())

torch.Size([12, 103, 8])


In [30]:
# Self-attention over the combined sequence (LSTM outputs followed by seeds).
output3, attn = att_layer(attn_in, attn_in, attn_in)
print(output3.shape, attn.shape, sep="\n")

torch.Size([12, 103, 8])
torch.Size([12, 103, 4, 103])


In [31]:
# Keep only the trailing seed positions that were appended after the LSTM
# outputs. There is one per output, so the slice width follows num_output
# instead of a hard-coded 3.
output4 = output3[:, -num_output:, :]
print(output4.shape)
# Rebind `output` to the attended seed slots; the per-output MLP cell below
# indexes into it.
output = output4

torch.Size([12, 3, 8])


In [32]:
# One small MLP per output: ts_embedding_size -> 64 -> 32, with a ReLU after
# each linear layer.
mlps_ts = nn.ModuleList(
    nn.Sequential(
        nn.Linear(in_features=ts_embedding_size, out_features=64),
        nn.ReLU(),
        nn.Linear(in_features=64, out_features=32),
        nn.ReLU(),
    )
    for _ in range(num_output)
)

In [33]:
num_tabular_features = 3
# MLP for the tabular inputs: num_tabular_features -> 64 -> 32, with a ReLU
# after each linear layer.
mlp_tab = nn.Sequential(
    nn.Linear(in_features=num_tabular_features, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=32),
    nn.ReLU(),
)

In [34]:
# One regression head per output: 64 -> 128 -> 1 with a ReLU in between.
regs = nn.ModuleList(
    nn.Sequential(
        nn.Linear(in_features=64, out_features=128),
        nn.ReLU(),
        nn.Linear(in_features=128, out_features=1),
    )
    for _ in range(num_output)
)

In [35]:
# Uniform random stand-in for the tabular inputs: (batch_size, num_tabular_features).
x_tab = torch.rand((batch_size, num_tabular_features))
print(x_tab.size())

torch.Size([12, 3])


In [36]:
# Encode the tabular inputs with the tabular MLP and check the resulting shape.
out_tab = mlp_tab(x_tab)
print(out_tab.size())

torch.Size([12, 32])


In [37]:
# Feed position 0 along axis 1 of `output` through the first per-output MLP.
out_ts = mlps_ts[0](output.select(dim=1, index=0))
print(out_ts.size())

torch.Size([12, 32])


In [38]:
# Join the tabular and time-series encodings side by side along axis 1.
comb = torch.cat([out_tab, out_ts], dim=1)
print(comb.size())

torch.Size([12, 64])


In [39]:
# Apply the first regression head to the combined encoding.
out = regs[0](comb)
print(out.size())

torch.Size([12, 1])


In [40]:
# Stack three copies of the single head output along axis 1.
# NOTE(review): this repeats head 0 three times; real per-output predictions
# would presumably use mlps_ts[i] and regs[i] for each i — confirm intent.
torch.concat((out, out, out), dim=1)

tensor([[-0.0162, -0.0162, -0.0162],
        [-0.0039, -0.0039, -0.0039],
        [ 0.0034,  0.0034,  0.0034],
        [-0.0083, -0.0083, -0.0083],
        [-0.0100, -0.0100, -0.0100],
        [-0.0018, -0.0018, -0.0018],
        [-0.0049, -0.0049, -0.0049],
        [ 0.0074,  0.0074,  0.0074],
        [-0.0004, -0.0004, -0.0004],
        [-0.0076, -0.0076, -0.0076],
        [ 0.0063,  0.0063,  0.0063],
        [ 0.0054,  0.0054,  0.0054]], grad_fn=<CatBackward0>)