# Torch Implementation of Set function for Time Series


Base Model: Deep-Set Architecture

$$ f(\mathcal{S})=g\left(\frac{1}{|\mathcal{S}|} \sum_{s_{j} \in \mathcal{S}} h\left(s_{j}\right)\right) $$

Modification: Scaled-Dot-Product
Paper:

$$ K_{j, i}=\left[f^{\prime}(\mathcal{S}), s_{j}\right]^{T} W_{i}$$

where $f'$ is another deep-set model. 

$$ e_{j, i}=\frac{K_{j, i} \cdot Q_{i}}{\sqrt{d}} \quad \text { and } \quad a_{j, i}=\frac{\exp \left(e_{j, i}\right)}{\sum_{j} \exp \left(e_{j, i}\right)}$$

> For each head, we multiply
the set element embeddings computed via the function h
with the attentions derived for the individual instances, i.e.

$$ r_{i}=\sum_{j} a_{j, i} h\left(s_{j}\right)$$

The final prediction is made by

$$  \hat{y} = g\Big(\sum_{s∈S} a(S, s) h(s)\Big) $$ 

#### Notes

- $g$ and $h$ are usually just MLPs, $f'$ is a DeepSet
- $m$ is the number of heads
- $W$, $Q$ are learnable. $Q$ is initialized with zeros
- $W_i$ has shape $(\dim(f')+\dim(s), d)$
- $Q$ has shape $(m, d)$
- $K$ has shape $(|S|, d)$
- $E$ has shape $(|S|, m)$
- $e_i$ is a vector of size $|S|$
- $a_i$ is a vector of size $|S|$
- $a(S, s)$ is $(|S|, m)$
- $h(s)$ is $(d,)$
- $r= [r_1, …, r_m] = \sum_{s∈S} a(S, s) h(s)$ is of shape $(m,d)$
- The authors do not seem to include a latent dimension?

## Simplified Equations

Rename: $h = ϕ$, $f' = ρ∘∑∘ψ$

$$ a_{j,i} = \operatorname{softmax}(e_i)_j = \sigma(e_i)_j$$

$$ e_{j,i} = \frac{1}{\sqrt{d}}K_{j, i}\cdot Q_{i} = \frac{1}{\sqrt{d}}\left[(ρ∘∑∘ψ)(\mathcal{S}), s_{j}\right]^{T} W_{i}\cdot Q_{i} $$


In [None]:
from math import isqrt, sqrt
from typing import Optional

import numpy as np
import torch
from torch import Tensor, jit, nn
from torch.nn import functional as F
from torch.nn.utils.rnn import *
from torchinfo import summary

### Some Tensors for demo purposes

In [None]:
# Demo dimensions: B sequences, sequence lengths drawn below Lmax, D features per element.
B, Lmax, D = 32, 50, 7
# m attention heads with latent key size d (symbols as in the notes above).
m, d = 6, 5

# Ragged batch: each sequence gets its own random length from randint(1, Lmax).
batch = [torch.randn(np.random.randint(1, Lmax), D) for _ in range(B)]
# Pad to a common length with NaN so padded rows can be recognised later.
s = pad_sequence(batch, padding_value=float("nan"), batch_first=True)
# True where a row is padding, detected through a NaN in its first feature.
mask = torch.isnan(s[..., 0])
# Padded sequence length shared by the whole batch.
L = int(s.shape[1])
s.shape, mask.shape

### MLP Component

In [None]:
class MLP(nn.Sequential):
    """Plain feed-forward stack used as a building block by the other modules.

    The network consists of ``num_layers`` hidden blocks, each a
    ``Linear(input_size, input_size)`` followed by ``ReLU``, and a final
    ``Linear(input_size, output_size)`` without activation. Every linear
    layer has both its weight and its bias re-initialised with
    Kaiming-normal (``nonlinearity="relu"``).

    Args:
        input_size: Width of the input and of every hidden layer.
        output_size: Width of the final linear layer.
        num_layers: Number of hidden ``Linear + ReLU`` blocks placed before
            the final linear layer.
    """

    def __init__(self, input_size: int, output_size: int, num_layers: int = 2):
        # Output width of each Linear, in order: hidden blocks keep the input
        # width, only the last layer maps to ``output_size``.
        widths = [input_size] * num_layers + [output_size]

        modules = []
        for position, width in enumerate(widths):
            linear = nn.Linear(input_size, width)
            # ``bias[None]`` is a 2-D view of the bias, which lets
            # kaiming_normal_ compute a fan-in for it.
            for param in (linear.weight, linear.bias[None]):
                nn.init.kaiming_normal_(param, nonlinearity="relu")
            modules.append(linear)
            # Every layer except the final one is followed by a ReLU.
            if position < num_layers:
                modules.append(nn.ReLU())

        super().__init__(*modules)


# Show the torchinfo layer summary of a small example MLP (3 inputs, 4 outputs, 2 hidden blocks).
summary(MLP(3, 4, 2))

### DeepSet Component

In [None]:
class DeepSet(nn.Module):
    """Deep-set model; signature ``[..., V, D] -> [..., F]``.

    Each set element is encoded independently, the encodings are averaged
    over the set axis (second-to-last), and the average is decoded.

    Args:
        input_size: Number of features ``D`` per set element.
        output_size: Number of output features ``F``.
        latent_size: Width ``E`` of the encoded elements; falls back to
            ``input_size`` when ``None``.
        encoder_layers: ``num_layers`` passed to the encoder MLP.
        decoder_layers: ``num_layers`` passed to the decoder MLP.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        latent_size: Optional[int] = None,
        encoder_layers: int = 2,
        decoder_layers: int = 2,
        # aggregation: Literal["min", "max", "sum", "mean", "prod"] = "sum",
    ):
        super().__init__()
        if latent_size is None:
            latent_size = input_size
        self.encoder = MLP(input_size, latent_size, encoder_layers)
        self.decoder = MLP(latent_size, output_size, decoder_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch of sets to one vector per set.

        Signature: ``[..., <Var>, D] -> [..., F]``

        Steps:

          - Encoder: ``[..., D] -> [..., E]``, applied to every element
          - Aggregation: ``[..., V, E] -> [..., E]``, a mean over the set
            axis that skips NaN entries
          - Decoder: ``[..., E] -> [..., F]``
        """
        per_element = self.encoder(x)
        # nanmean leaves out NaN entries, which is how the notebook's
        # NaN-padded rows are kept out of the average.
        pooled = torch.nanmean(per_element, dim=-2)
        return self.decoder(pooled)


# Show the torchinfo layer summary of a small example DeepSet (input 3, output 4, latent 5).
summary(DeepSet(3, 4, 5))

In [None]:
# Compile a DeepSet with TorchScript and apply it to the NaN-padded demo batch.
f = jit.script(DeepSet(7, 4))
f(s)

In [None]:
# Permutation-invariance check: shuffling the set axis must leave f(s) unchanged
# (up to atol); on failure the assertion message is the norm of the difference.
p = np.random.permutation(s.shape[1])
assert torch.allclose(f(s[..., p, :]), f(s), atol=1e-06), torch.linalg.norm(
    f(s[..., p, :]) - f(s)
)

### Scaled Dot Product Attention

Keys: $K_{ji} = [f(S), s_j]^T W_i$
- $K: |S|×d$. If we want to include batch-size, we need to pad things or operate on lists. 
    - let's do lists and hope torchscript takes care of it.
        - ⟹ But then we need to apply components in "listified" manner
        - Maybe we can write a decorator that automatically takes care of list inputs?
            - Would that work well with torchscript?
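    - so use padding, but make sure to mask the padded positions (e.g. set their scores to $-\infty$ before the softmax) so that they receive zero attention weight.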


In [None]:
# Random query Q (heads x latent) and random key / value projection weights for the demo.
Q = torch.randn(m, d)
K = torch.randn(D, m, d)
V = torch.randn(D, m, 17)
# Project every element into per-head values (17 value features per head); V is rebound to the result.
V = torch.einsum("...D, DMF -> ...MF ", s, V)
print(f"{Q.shape=}")
# Project every element into per-head keys; K is rebound to the result.
K = torch.einsum("...f, fmd -> ...md", s, K)
print(f"{K.shape=}")
# Scaled scores: dot product of key and query per element and head, divided by sqrt(d).
QK = torch.einsum("...md, md -> ...m", K, Q) / np.sqrt(d)
# Padded elements get -inf so the softmax gives them zero weight.
QK[mask] = float("-inf")
print(f"{QK.shape=}")
# Normalise the scores over the element axis (dim=1).
σ = nn.functional.softmax(QK, dim=1)
print(f"{σ.shape=}")
print(f"{V.shape=}")
# NOTE(review): this averages the weighted values (nanmean), whereas the formula
# for r_i in the notes above sums them — confirm which reduction is intended.
r = torch.nanmean(σ[..., None] * V, dim=1)
print(f"{r.shape=}")

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Multi-head attention pooling over a set with a learned query per head.

    For every head ``i`` and set element ``j`` the module computes

    - keys ``K_ji = k_j @ Wk[:, i]`` and values ``V_ji = v_j @ Wv[:, i]``,
    - scores ``e_ji = (K_ji . Wq[i]) / sqrt(d)`` with ``d = dim_k_latent``,
    - weights ``a_ji = softmax_j(e_ji)`` (normalised over the set axis),
    - the pooled value ``r_i = sum_j a_ji * V_ji``,

    and finally mixes all heads with ``Wo`` into ``output_size`` features.

    Args:
        dim_k: Number of features of each incoming key vector.
        dim_v: Number of features of each incoming value vector.
        output_size: Number of features returned per set.
        num_heads: Number of attention heads.
        dim_k_latent: Width of the projected keys and of the queries;
            defaults to ``max(1, isqrt(dim_k))``.
        dim_v_latent: Width of the projected values; defaults to ``dim_v``.

    Attributes:
        Wq: Queries, shape ``(num_heads, dim_k_latent)``, initialised to
            zero so that attention starts out uniform.
        Wk: Key projection, shape ``(dim_k, num_heads, dim_k_latent)``.
        Wv: Value projection, shape ``(dim_v, num_heads, dim_v_latent)``.
        Wo: Output projection, shape
            ``(num_heads, dim_v_latent, output_size)``.
        scale: Buffer holding ``1 / sqrt(dim_k_latent)``.
        attention_weights: Buffer overwritten on every forward call with
            the most recent attention weights.
    """

    def __init__(
        self,
        dim_k: int,
        dim_v: int,
        output_size: int,
        num_heads: int = 5,
        dim_k_latent: Optional[int] = None,
        dim_v_latent: Optional[int] = None,
    ) -> None:
        super().__init__()

        dim_k_latent = max(1, isqrt(dim_k)) if dim_k_latent is None else dim_k_latent
        dim_v_latent = dim_v if dim_v_latent is None else dim_v_latent

        Wq = torch.zeros((num_heads, dim_k_latent))
        # Dividing by sqrt(fan-in) keeps the projections at roughly unit scale.
        Wk = torch.randn((dim_k, num_heads, dim_k_latent)) / sqrt(dim_k)
        Wv = torch.randn((dim_v, num_heads, dim_v_latent)) / sqrt(dim_v)
        Wo = torch.randn((num_heads, dim_v_latent, output_size)) / sqrt(
            num_heads * dim_v_latent
        )

        self.Wq = nn.Parameter(Wq)
        self.Wk = nn.Parameter(Wk)
        self.Wv = nn.Parameter(Wv)
        self.Wo = nn.Parameter(Wo)
        # The query/key dot product runs over the latent key dimension, so
        # that is the ``d`` in ``1 / sqrt(d)`` -- not the raw input width.
        self.register_buffer("scale", torch.tensor(1 / sqrt(dim_k_latent)))
        self.register_buffer("attention_weights", torch.tensor([]))

    def forward(self, K: Tensor, V: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Pool a set of key/value pairs into one vector per set.

        Args:
            K: Keys of shape ``(..., L, dim_k)``.
            V: Values of shape ``(..., L, dim_v)``.
            mask: Boolean tensor of shape ``(..., L)``; ``True`` marks
                elements to exclude. When omitted, elements whose first
                key feature is NaN are excluded.

        Returns:
            Tensor of shape ``(..., output_size)``.
        """
        if mask is None:
            mask = torch.isnan(K[..., 0])

        K = torch.einsum("...d, dhk -> ...hk", K, self.Wk)
        V = torch.einsum("...d, dhv -> ...hv", V, self.Wv)
        QK = torch.einsum("hd, ...hd -> ...h", self.Wq, K)
        # Excluded elements get -inf so the softmax assigns them zero weight.
        QK[mask] = float("-inf")
        w = F.softmax(self.scale * QK, dim=-2)  # normalise over the set axis L
        self.attention_weights = w
        # r = sum_j a_j * V_j. The weights already sum to one, so the
        # reduction is a sum rather than a mean; nansum skips the NaN
        # products that NaN-padded values produce.
        QKV = torch.nansum(w[..., None] * V, dim=-3)  # ...Lh,...Lhv -> ...hv
        return torch.einsum("...hv, hvr -> ...r", QKV, self.Wo)

In [None]:
# Script the attention module and feed the demo batch as both keys and values
# (no mask is passed, so NaN-padded rows are detected inside forward).
model = jit.script(ScaledDotProductAttention(7, 7, 2))
model(s, s)

In [None]:
# Show the torchinfo summary of the scripted attention module.
summary(model)

In [None]:
# Compare the shape of the padded batch with the shape of its DeepSet summary.
s.shape, f(s).shape

In [None]:
# Display the padded sequence length.
L

In [None]:
# Inspect the shape obtained by repeating f(s) with (1, L, 1, 1).
# NOTE(review): presumably a first attempt at broadcasting the summary along the sequence axis.
f(s).repeat(1, L, 1, 1).shape

In [None]:
# Copy the set-level summary f(s) to every position: insert a length-1 element
# axis, then tile that axis L times.
fs = torch.tile(f(s).unsqueeze(-2), (L, 1))
fs.shape

In [None]:
# Attach the set summary to each element along the feature axis: [f(S), s_j].
torch.cat([fs, s], dim=-1)

In [None]:
from tsdm.encoders.torch import PositionalEncoder

In [None]:
class SetFuncTS(nn.Module):
    """Set-function model for time series built from the components above.

    A sequence of observations is treated as a set. Each element is
    re-encoded (timestamp -> time features), a DeepSet summary of the whole
    set is attached to every element to form the attention keys, an MLP
    produces the attention values, and the attention-pooled result is
    passed through a final MLP head.

    Args:
        input_size: Number of features per raw set element (timestamp,
            value and identifier columns together).
        output_size: Number of features returned by the head.
        latent_size: Output width of the attention layer; defaults to
            ``input_size``.
        dim_keys: Output width of the DeepSet summary; defaults to
            ``input_size``.
        dim_vals: Output width of the value encoder; defaults to
            ``input_size``.
        dim_time: Argument passed to ``PositionalEncoder``; defaults to 10.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        latent_size: Optional[int] = None,
        dim_keys: Optional[int] = None,
        dim_vals: Optional[int] = None,
        dim_time: Optional[int] = None,
    ) -> None:
        super().__init__()

        if dim_keys is None:
            dim_keys = input_size
        if dim_vals is None:
            dim_vals = input_size
        if dim_time is None:
            dim_time = 10
        if latent_size is None:
            latent_size = input_size

        # Width of one element after the scalar timestamp column has been
        # replaced by its encoding.
        # NOTE(review): assumes PositionalEncoder emits dim_time features per
        # timestamp, as implied by this size arithmetic — confirm.
        element_size = input_size + dim_time - 1

        # time_encoder
        # feature_encoder -> CNN?
        self.time_encoder = PositionalEncoder(dim_time, scale=1.0)
        self.key_encoder = DeepSet(element_size, dim_keys)
        self.value_encoder = MLP(element_size, dim_vals)
        self.attn = ScaledDotProductAttention(
            dim_keys + element_size, dim_vals, latent_size
        )
        self.head = MLP(latent_size, output_size)

    def forward(self, s: Tensor) -> Tensor:
        """Encode one set (or a padded batch of sets) of observations.

        ``s`` must be a tensor of shape ``L×(2+C)`` with rows
        ``sᵢ = [tᵢ, zᵢ, mᵢ]``, where

        - ``tᵢ`` is the timestamp,
        - ``zᵢ`` is the observed value,
        - ``mᵢ`` is the identifier, one-hot encoded over ``C`` classes.

        Rows whose encoded first feature is NaN are masked out of the
        attention.
        """
        timestamps = s[..., 0]
        values = s[..., 1:2]
        indicators = s[..., 2:]

        elements = torch.cat(
            [self.time_encoder(timestamps), values, indicators], dim=-1
        )

        # One summary vector per set, copied to every element of that set.
        set_summary = self.key_encoder(elements).unsqueeze(-2)
        context = torch.tile(set_summary, (elements.shape[-2], 1))

        keys = torch.cat([context, elements], dim=-1)
        vals = self.value_encoder(elements)
        padding = torch.isnan(elements[..., 0])

        pooled = self.attn(keys, vals, mask=padding)
        return self.head(pooled)

    @jit.export
    def batch_forward(self, s: list[Tensor]) -> Tensor:
        """Apply ``forward`` to each tensor in ``s`` and stack the results."""
        outputs = [self.forward(sample) for sample in s]
        return torch.stack(outputs)

In [None]:
# Try the positional encoder on the first feature column of the demo batch.
# NOTE(review): the arguments are presumably (num_dim, scale), matching the
# PositionalEncoder(dim_time, scale=1.0) call in the model cell — confirm.
g = PositionalEncoder(10, 0.9)
g.scales
g(s[:, :, 0]).shape

In [None]:
# Script the full model (input_size=7, output_size=8) and show its summary.
model = jit.script(SetFuncTS(7, 8))
summary(model)

In [None]:
# Three unpadded sequences of different lengths; batch_forward runs forward on
# each one separately and stacks the per-sequence outputs.
b = [torch.randn(16, 7), torch.randn(3, 7), torch.randn(7, 7)]
model.batch_forward(b)

## A second heading

and some more text

In [None]:
class SetFuncTS(nn.Module):
    # NOTE(review): this rebinds the SetFuncTS name defined in an earlier cell
    # and looks like an unfinished sketch rather than a working model.
    def __init__(self, num_dim: int):
        super().__init__()

        # NOTE(review): the next three lines only read the attributes; nothing
        # in this block assigns encoder / decoder / aggregator, and num_dim is
        # not used. Presumably placeholders for components still to be written.
        self.encoder
        self.decoder
        self.aggregator

    def forward(self, x: Tensor) -> Tensor:
        """Stack ``x`` along a new last axis and sum over that axis.

        Signature: `[..., <var>, ]`.

        Takes list of triplet-encoded data and applies.

        NOTE(review): the text above speaks of a list of tensors while ``x``
        is annotated as a single ``Tensor`` — confirm the intended input type.
        """
        t = torch.stack(x, dim=-1)
        return torch.sum(t, dim=-1)

In [None]:
# NOTE(review): SetFuncTS is instantiated without arguments here, although the
# __init__ in the cell above declares a required num_dim parameter.
jit.script(SetFuncTS())

```
>>>>>> input_shapes:                    [(16, 8), (16, 15009, 1), (16, 15009, 1), (16, 15009), (16,)]
>>>>>> lengths:                         (16,)
>>>>>> max length |S|:                  15009
>>>>>> sum lengths ∑|S|:                238416
>>>>>> transformed_times:               (16, 15009, 4)
>>>>>> transformed_measurements:        (16, 15009, 24)
>>>>>> combined_values:                 (16, 15009, 29)
>>>>>> demo_encoded:                    (16, 29)
>>>>>> combined_with_demo:              (16, 15010, 29)
>>>>>> mask:                            (16, 15010)
>>>>>> collected_values S:              (238432, 29)
>>>>>> encoded ϕ = h(s):                (238432, 256)
>>>>>> encoded ψ = f'(S):               (238432, 128)
>>>>>> agg ψ:                           (16, 128)
>>>>>> agg ρ:                           (16, 128)
>>>>>> combined [f(S),s]:               (238432, 157)
>>>>>> keys [f(S),s]ᵀW:                 (238432, 4, 1, 128)
>>>>>> preattn eᵢⱼ= KQ/√d:              (238432, 4, 1, 128)
>>>>>> attentions a(S):                 (4, 238432, 1)
>>>>>> weighted_values:                 (4, 238432, 256)
>>>>>> weighted_values a(S,s)h(s):      (238432, 1024)
>>>>>> aggregated_values ∑a(S,s)h(s):   (16, 1024)
>>>>>> output_values g(∑a(S,s)h(s)):    (16, 1)
```