In [12]:
import torch

In [40]:
import abc
import math
import torch
from einops import rearrange, repeat
from torch import nn


class AttentionBias(nn.Module, abc.ABC):
    """Abstract base for modules that produce an additive attention bias.

    Stores the number of heads and the per-head width derived from ``dim``.
    Subclasses must implement :meth:`forward`.
    """

    def __init__(self, dim: int, num_heads: int):
        """Validate the head layout and record it.

        Args:
            dim: Total model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads; must be positive.

        Raises:
            AssertionError: If ``num_heads`` is not positive or does not
                divide ``dim`` evenly.
        """
        super().__init__()
        # Check positivity first so a zero head count fails the assertion
        # instead of reaching the division below.
        assert num_heads > 0
        per_head, leftover = divmod(dim, num_heads)
        assert leftover == 0

        self.head_dim = per_head
        self.num_heads = num_heads

    @abc.abstractmethod
    def forward(self, query_id, kv_id):
        """Compute the bias from query-side and key/value-side ids (abstract)."""


class BinaryAttentionBias(AttentionBias):
    """Per-head bias that depends only on whether two ids are equal.

    A two-row embedding table holds, for every head, one learned scalar for
    "ids differ" (row 0) and one for "ids match" (row 1).
    """

    def __init__(self, dim: int, num_heads: int):
        """Create the 2 x ``num_heads`` table of learnable bias values."""
        super().__init__(dim, num_heads)
        self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads)

    def forward(self, query_id, kv_id):
        """Return the bias for every (query, key) pair.

        ``query_id`` gains a trailing axis and ``kv_id`` gains an axis before
        its last one, so their comparison broadcasts to a pairwise grid over
        the last dimension of each input. In this notebook both inputs are
        ``[batch, 1, L]`` and the result is ``[batch, num_heads, L, L]``.
        """
        same_id = query_id.unsqueeze(-1) == kv_id.unsqueeze(-2)
        # Append two singleton axes so the head axis lines up with dim -3 of
        # the pairwise grid when broadcasting.
        table = rearrange(
            self.emb.weight, "two num_heads -> two num_heads 1 1")
        bias_if_different = table[:1]
        bias_if_same = table[1:]
        # Exactly one of the two boolean factors is True at each position, so
        # the sum selects a single table row per pair.
        return ~same_id * bias_if_different + same_id * bias_if_same
    

class TriangularCausalMask():
    """Boolean ``[B, 1, L, L]`` mask that is True strictly above the diagonal.

    Presumably True marks (query, key) pairs that attention must not use,
    i.e. keys that lie after the query - confirm against the attention code.
    """

    def __init__(self, B, L, device="cpu"):
        """Build the mask for batch size ``B`` and sequence length ``L``."""
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            # diagonal=1 keeps only entries whose column index exceeds the row index.
            strictly_upper = all_true.triu(diagonal=1)
            self._mask = strictly_upper.to(device)

    @property
    def mask(self):
        """The boolean mask tensor."""
        return self._mask
    

class TimerMultivariateMask():
    """Causal mask tiled over every pair of variables.

    The Kronecker product of an all-True ``n_vars x n_vars`` block with a
    strictly-upper-triangular ``[B, 1, n_tokens, n_tokens]`` mask yields a
    ``[B, 1, n_vars*n_tokens, n_vars*n_tokens]`` tensor in which every
    variable-pair block repeats the same causal pattern.
    """

    def __init__(self, B, n_vars, n_tokens, device="cpu"):
        """Build the mask for ``n_vars`` variables of ``n_tokens`` tokens each."""
        with torch.no_grad():
            every_pair = torch.ones((n_vars, n_vars), dtype=torch.bool)
            per_variable = torch.ones(
                [B, 1, n_tokens, n_tokens], dtype=torch.bool
            ).triu(diagonal=1)

            self._mask1 = every_pair.to(device)
            self._mask2 = per_variable.to(device)
            # kron pads the 2-D factor to 4-D, so the batch and singleton
            # axes of the causal factor are carried through unchanged.
            self._mask = torch.kron(self._mask1, self._mask2)

    @property
    def mask(self):
        """The boolean mask tensor."""
        return self._mask


class TimerCovariateMask():
    """Block-diagonal causal mask with the last variable opened to all others.

    False entries are first placed where a token looks at the same or an
    earlier token of its own variable; everything else is True. The rows of
    the last variable are then set to False over every column belonging to
    the other variables.
    """

    def __init__(self, B, n_vars, n_tokens, device="cpu"):
        """Build the ``[B, 1, n_vars*n_tokens, n_vars*n_tokens]`` mask."""
        with torch.no_grad():
            same_variable = torch.eye(n_vars, dtype=torch.bool)
            not_future = torch.ones(
                [B, 1, n_tokens, n_tokens], dtype=torch.bool
            ).tril()

            self._mask1 = same_variable.to(device)
            self._mask2 = not_future.to(device)

            # True in the product = same variable and not a later token.
            # The stored mask is its complement.
            permitted = torch.kron(self._mask1, self._mask2)
            self._mask = ~permitted
            # Last n_tokens rows (the final variable) x all columns except
            # the final variable's own block.
            self._mask[..., -n_tokens:, :-n_tokens] = False

    @property
    def mask(self):
        """The boolean mask tensor."""
        return self._mask


class TimerMultivariateWithCovariateMask():
    """Covariate-style mask in which the first ``n_pred_vars`` variables are coupled.

    Starts from the block-diagonal causal layout (each variable sees only
    its own current and earlier tokens), additionally lets the first
    ``n_pred_vars`` variables see one another causally, and finally sets the
    last variable's rows to False over every column of the other variables.
    """

    def __init__(self, batch_size, n_vars, n_tokens, n_pred_vars=1, device="cpu"):
        """Build the ``[batch_size, 1, n_vars*n_tokens, n_vars*n_tokens]`` mask."""
        with torch.no_grad():
            linked_vars = torch.eye(n_vars, dtype=torch.bool).to(device)
            # Top-left n_pred_vars x n_pred_vars block: these variables are
            # linked to each other, not only to themselves.
            linked_vars[:n_pred_vars, :n_pred_vars] = True
            self._mask1 = linked_vars

            self._mask2 = torch.ones(
                [batch_size, 1, n_tokens, n_tokens], dtype=torch.bool
            ).tril().to(device)

            permitted = torch.kron(self._mask1, self._mask2)
            self._mask = ~permitted
            # NOTE(review): this opens the rows of the LAST variable, while the
            # coupled block above covers the FIRST n_pred_vars variables -
            # confirm this asymmetry is intended.
            self._mask[..., -n_tokens:, :-n_tokens] = False

    @property
    def mask(self):
        """The boolean mask tensor."""
        return self._mask

In [43]:
# Demo configuration. Later cells read these names.
n_vars = 4
n_tokens = 5
n_pred_vars = 2
patch_len = 96
batch_size = 16
n_heads = 8
d_model = 1024
seed = 0

# The embedding initialisation and the random tensors below draw from the
# global RNG; seeding it makes a fresh Restart & Run All reproduce the same values.
torch.manual_seed(seed)

# Variable index of every token: each of the n_vars indices is repeated
# n_tokens times, then broadcast to [batch_size, 1, n_vars*n_tokens].
var_id = repeat(torch.arange(n_vars),
                'C -> (C n_tokens)', n_tokens=n_tokens)
var_id = repeat(var_id, 'L -> b h L', b=batch_size, h=1)

# Using the same ids on the query and key side compares every token with
# every other token.
bias = BinaryAttentionBias(dim=d_model, num_heads=n_heads)
attn_bias = bias(var_id, var_id)

# One instance of each mask variant, all on the default (CPU) device.
tri_attn_mask = TriangularCausalMask(batch_size, n_vars*n_tokens)
cov_attn_mask = TimerCovariateMask(batch_size, n_vars, n_tokens)
multi_attn_mask = TimerMultivariateMask(batch_size, n_vars, n_tokens)
multi_with_covariate_attn_mask = TimerMultivariateWithCovariateMask(batch_size, n_vars, n_tokens, n_pred_vars)

# Random attention inputs laid out as [batch, length, heads, features].
# NOTE(review): the last axis is d_model here, whereas AttentionBias derives a
# per-head width of dim // num_heads - confirm the full width is intended.
queries = torch.randn(batch_size, n_vars*n_tokens, n_heads, d_model)
keys = torch.randn(batch_size, n_vars*n_tokens, n_heads, d_model)
values = torch.randn(batch_size, n_vars*n_tokens, n_heads, d_model)
# Dot product over the feature axis for every (query, key) pair:
# result is [batch, heads, query_len, key_len].
scores = torch.einsum("blhe,bshe->bhls", queries, keys)

In [31]:
# Shape check for the variable-id tensor built in the setup cell.
var_id.shape # [batch_size, 1, n_vars*n_tokens]

torch.Size([16, 1, 20])

In [32]:
# Shape check for the per-head bias returned by BinaryAttentionBias.
attn_bias.shape # [batch_size, n_heads, n_vars*n_tokens, n_vars*n_tokens]

torch.Size([16, 8, 20, 20])

In [33]:
# Built with L = n_vars*n_tokens, so this is [batch_size, 1, L, L].
tri_attn_mask.mask.shape

torch.Size([16, 1, 20, 20])

In [34]:
# Raise the summarisation threshold so the mask is printed in full, and
# (optional) widen the printed line so each mask row stays on one line.
torch.set_printoptions(threshold=10_000_000, linewidth=200)
tri_attn_mask.mask

tensor([[[[False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False, False, False, False,  True,  True,  True,  True, 

In [35]:
# The Kronecker product scales the last two axes by n_vars:
# [batch_size, 1, n_vars*n_tokens, n_vars*n_tokens].
cov_attn_mask.mask.shape

torch.Size([16, 1, 20, 20])

In [45]:
# Raise the summarisation threshold so the mask is printed in full, and
# (optional) widen the printed line so each mask row stays on one line.
torch.set_printoptions(threshold=10_000_000, linewidth=200)
# First batch element, single head axis: a 2-D [n_vars*n_tokens, n_vars*n_tokens] view.
cov_attn_mask.mask[0, 0]

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False,  True,  True,  True,  True,  True,  True, 

In [17]:
# Raise the summarisation threshold so the mask is printed in full, and
# (optional) widen the printed line so each mask row stays on one line.
torch.set_printoptions(threshold=10_000_000, linewidth=200)
multi_with_covariate_attn_mask.mask

tensor([[[[False,  True,  True,  True,  True,  True,  True, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True],
          [False, False,  True,  True,  True,  True,  True, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False,  True,  True,  True,  True, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True],
          [False, False, False, False,  True,  True,  True, False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  Tr

In [31]:
# Expected [batch_size, 1, n_vars*n_tokens, n_vars*n_tokens].
# NOTE(review): the recorded output shows 35 on the last two axes, which does
# not match n_vars=4, n_tokens=5 in the setup cell - presumably left over from
# an earlier run; re-run to refresh.
multi_attn_mask.mask.shape

torch.Size([16, 1, 35, 35])

In [32]:
# Expected [batch_size, n_vars*n_tokens, n_heads, d_model].
# NOTE(review): the recorded length of 35 does not match n_vars*n_tokens = 20
# from the setup cell - presumably stale output; re-run to refresh.
queries.shape

torch.Size([16, 35, 8, 1024])

In [33]:
# The einsum "blhe,bshe->bhls" yields [batch_size, n_heads, query_len, key_len].
# NOTE(review): the recorded 35 x 35 does not match n_vars*n_tokens = 20 from
# the setup cell - presumably stale output; re-run to refresh.
scores.shape

torch.Size([16, 8, 35, 35])