In [212]:
import torch
from torch import nn
from models import InterpretableMultiHeadAttention

In [213]:
ts_embedding_size = 8
n_head = 4
batch_size = 12
number_features = 6

att_layer = InterpretableMultiHeadAttention(
    n_head=n_head, d_model=ts_embedding_size, dropout=0.5
)

In [214]:
x = torch.randn(batch_size, number_features, ts_embedding_size)
print(x.shape)

torch.Size([12, 6, 8])


In [215]:
output, att = att_layer(x, x, x)
print(output.shape)
print(att.shape)

torch.Size([12, 6, 8])
torch.Size([12, 6, 4, 6])


In [216]:
att[0, :, 0, :]

tensor([[0.0836, 0.2742, 0.1204, 0.0922, 0.3247, 0.1049],
        [0.1619, 0.2253, 0.2448, 0.0971, 0.2617, 0.0093],
        [0.0784, 0.2437, 0.1030, 0.0980, 0.2793, 0.1975],
        [0.0947, 0.1163, 0.0855, 0.1280, 0.1134, 0.4620],
        [0.1190, 0.1641, 0.1203, 0.1421, 0.1668, 0.2876],
        [0.2248, 0.1342, 0.2038, 0.1942, 0.1273, 0.1157]],
       grad_fn=<SliceBackward0>)

In [217]:
num_ts_features = 2
ts_embedders = nn.ModuleList(
    [nn.Linear(1, ts_embedding_size) for _ in range(num_ts_features)]
)

In [218]:
timesteps = 100
x_ts = torch.randn(batch_size, timesteps, num_ts_features)
print(x_ts.shape)
output = torch.concat(
    [
        ts_embedders[i](x_ts[Ellipsis, i].unsqueeze(-1))
        for i in range(0, num_ts_features)
    ],
    axis=-1,
)
output.shape

torch.Size([12, 100, 2])


torch.Size([12, 100, 16])

In [219]:
lstm_hidden_size = ts_embedding_size
lstm = nn.LSTM(
    input_size=16,
    hidden_size=lstm_hidden_size,
    num_layers=2,
    batch_first=True,
    dropout=0.5,
)

In [220]:
output2, (hn, cn) = lstm(output)
print(output2.shape)
print(hn.shape)
print(cn.shape)

torch.Size([12, 100, 8])
torch.Size([2, 12, 8])
torch.Size([2, 12, 8])


In [221]:
emb = nn.Embedding(3, ts_embedding_size)

In [222]:
num_output = 3
dims = (batch_size, num_output)
seeder = torch.zeros(dims, dtype=torch.long)
for i in range(0, num_output):
    seeder[:, i] = i
print(seeder.shape)

seeder_out = emb(seeder)
print(seeder_out.shape)

torch.Size([12, 3])
torch.Size([12, 3, 8])


In [223]:
attn_in = torch.concat([output2, seeder_out], axis=-2)
print(attn_in.shape)

torch.Size([12, 103, 8])


In [224]:
output3, attn = att_layer(attn_in, attn_in, attn_in)
print(output3.shape)
print(attn.shape)

torch.Size([12, 103, 8])
torch.Size([12, 103, 4, 103])


In [225]:
output4 = output3[:, -3:, :]
print(output4.shape)
output = output4

torch.Size([12, 3, 8])


In [226]:
mlps_ts = nn.ModuleList(
    [
        nn.Sequential(
            nn.Linear(ts_embedding_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )
        for _ in range(0, num_output)
    ]
)

In [227]:
num_tabular_features = 3
mlp_tab = nn.Sequential(
            nn.Linear(num_tabular_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
        )

In [228]:
regs = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))
                for _ in range(0, num_output)
            ]
        )

In [229]:
x_tab = torch.rand(batch_size, num_tabular_features)
print(x_tab.shape)

torch.Size([12, 3])


In [230]:
out_tab = mlp_tab(x_tab)
print(out_tab.shape)

torch.Size([12, 32])


In [231]:
out_ts = mlps_ts[0](output[:, 0, :])
print(out_ts.shape)

torch.Size([12, 32])


In [232]:
comb = torch.concat((out_tab, out_ts), dim=1)
print(comb.shape)

torch.Size([12, 64])


In [234]:
out = regs[0](comb)
print(out.shape)

torch.Size([12, 1])


In [None]:
torch.concat([out,out,out], dim=1)

tensor([[-0.0558, -0.0558, -0.0558],
        [-0.0617, -0.0617, -0.0617],
        [-0.0491, -0.0491, -0.0491],
        [-0.0562, -0.0562, -0.0562],
        [-0.0495, -0.0495, -0.0495],
        [-0.0198, -0.0198, -0.0198],
        [-0.0313, -0.0313, -0.0313],
        [-0.0325, -0.0325, -0.0325],
        [-0.0538, -0.0538, -0.0538],
        [-0.0289, -0.0289, -0.0289],
        [-0.0378, -0.0378, -0.0378],
        [-0.0182, -0.0182, -0.0182]], grad_fn=<SqueezeBackward1>)