In [0]:
import torch
import torch.nn as nn
import numpy as np

In [0]:
TIME_LEN = 1000

In [0]:
FEAT_DIM_GROUP1 = 14  # グループ１の特徴量数
FEAT_DIM_GROUP2 = 7  # グループ2の特徴量数
FEAT_DIM_GROUP3 = 10  # グループ3の特徴量数
FEAT_DIM_GROUP4 = 6  # グループ4の特徴量数

FEAT_DIMS = (
    FEAT_DIM_GROUP1,
    FEAT_DIM_GROUP2,
    FEAT_DIM_GROUP3,
    FEAT_DIM_GROUP4
)

In [0]:
X_feat_group1 = np.random.randn(TIME_LEN, FEAT_DIM_GROUP1)
X_feat_group2 = np.random.randn(TIME_LEN, FEAT_DIM_GROUP2)
X_feat_group3 = np.random.randn(TIME_LEN, FEAT_DIM_GROUP3)
X_feat_group4 = np.random.randn(TIME_LEN, FEAT_DIM_GROUP4)

y = np.random.poisson(10, TIME_LEN)

In [270]:
X_feat_group1.shape

(1000, 14)

In [0]:
SEQ_LEN = 50

In [0]:
X_feat_group1_tensor = []
X_feat_group2_tensor = []
X_feat_group3_tensor = []
X_feat_group4_tensor = []
y_tensor = []

for t in range(TIME_LEN-SEQ_LEN-1):
    X_feat_group1_tensor.append(torch.from_numpy(X_feat_group1[t:t+SEQ_LEN, :]))
    X_feat_group2_tensor.append(torch.from_numpy(X_feat_group2[t:t+SEQ_LEN, :]))
    X_feat_group3_tensor.append(torch.from_numpy(X_feat_group3[t:t+SEQ_LEN, :]))
    X_feat_group4_tensor.append(torch.from_numpy(X_feat_group4[t:t+SEQ_LEN, :]))
    y_tensor.append(y[t+1:t+SEQ_LEN+1])

X_feat_group1_tensor = torch.stack(X_feat_group1_tensor, dim=0).float()
X_feat_group2_tensor = torch.stack(X_feat_group2_tensor, dim=0).float()
X_feat_group3_tensor = torch.stack(X_feat_group3_tensor, dim=0).float()
X_feat_group4_tensor = torch.stack(X_feat_group4_tensor, dim=0).float()
y_tensor = torch.from_numpy(np.array(y_tensor)).float()

In [273]:
X_feat_group1_tensor.shape

torch.Size([949, 50, 14])

In [274]:
y_tensor.shape

torch.Size([949, 50])

In [0]:
class ARMDN(nn.Module):

    def __init__(self, feat_dims, embedding_dim=4, gauss_component_n=10):

        super(ARMDN, self).__init__()

        self._feat_layers = []
        self._embedding_dim = embedding_dim
        self._gauss_component_n = gauss_component_n
        self._z_mu = nn.Parameter(torch.randn(gauss_component_n).unsqueeze(0).unsqueeze(0))
        self._z_sigma = nn.Parameter(torch.randn(gauss_component_n).unsqueeze(0).unsqueeze(0))
        self._z_p = nn.Parameter(torch.randn(gauss_component_n).unsqueeze(0).unsqueeze(0))

        for feat_dim in feat_dims:
            self._feat_layers.append(
                nn.Linear(feat_dim, self._embedding_dim)
            )
        
        self._lstm = nn.LSTM(
            input_size=embedding_dim*len(feat_dims),
            hidden_size=self._gauss_component_n,
            batch_first=True
        )

    def forward(self, xs):
        assert len(xs) == len(self._feat_layers)

        seq_len = xs[0].shape[1]

        embs = []
        for x, feat_layer in zip(xs, self._feat_layers):
            emb_timeseries = []
            for t in range(seq_len):
                emb_timeseries.append(feat_layer(x[:, t, :]))
            # print(torch.stack(emb_timeseries, dim=1).shape) => (BATCH_SIZE, SEQ_LEN, FEAT_DIM)
            embs.append(torch.stack(emb_timeseries, dim=1))

        embs = torch.cat(embs, dim=2)  # (BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM*len(FEAT_DIMS))

        lstm_output, (h_n, c_n) = self._lstm(embs)  # lstm_output.shape = (BATCH_SIZE, SEQ_LEN, gauss_component_n)

        y_pred = self.pred(lstm_output)
        return y_pred
        
    def gaussian_mixture(self, lstm_output):
        means = lstm_output * self._z_mu  # (BATCH_SIZE, SEQ_LEN, gauss_component_n)
        sigmas = torch.exp(lstm_output * self._z_sigma)  # (BATCH_SIZE, SEQ_LEN, gauss_component_n)
        ps = torch.exp(lstm_output * self._z_p) / (lstm_output * self._z_p).sum(axis=2, keepdim=True)  # (BATCH_SIZE, SEQ_LEN, gauss_component_n)

        return means, sigmas, ps

    def pred(self, lstm_output):
        means, sigmas, ps = self.gaussian_mixture(lstm_output)
        y_pred = (ps*torch.exp(-0.5*(lstm_output-means)/sigmas)).sum(axis=2)

        return y_pred

    def loss(self, y, y_pred):
        return ((y - y_pred).abs() / y).mean()

In [0]:
EMBEDDING_DIM = 4

In [0]:
armdn = ARMDN(
    feat_dims=FEAT_DIMS,
    embedding_dim=EMBEDDING_DIM,
    gauss_component_n=10,
)

In [0]:
BATCH_SIZE = 32

In [0]:
X = (
    X_feat_group1_tensor[:BATCH_SIZE],
    X_feat_group2_tensor[:BATCH_SIZE],
    X_feat_group3_tensor[:BATCH_SIZE],
    X_feat_group4_tensor[:BATCH_SIZE]
)

In [0]:
out = armdn(X)

In [281]:
out.shape  # (BATCH_SIZE, SEQ_LEN)

torch.Size([32, 50])

In [282]:
y_tensor[:BATCH_SIZE].shape

torch.Size([32, 50])

In [283]:
armdn.loss(y_tensor[:BATCH_SIZE], out)

tensor(4.7437, grad_fn=<MeanBackward0>)