<a href="https://colab.research.google.com/github/wangbxj1234/eeg/blob/main/normal_linear_withcausal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn


## Core code for the linear-attention modules (normal and causal versions). Please refer to each folder for the corresponding experiments.

class Linear_Attention(nn.Module):
    """Non-causal multi-head linear self-attention.

    Queries, keys and values are all projected from the same input. Queries
    and keys are passed through a positive feature map (sigmoid), and the
    attention output for every position ``l`` is

        out_l = sum_s (phi(q_l) . phi(k_s)) * v_s / (sum_s phi(q_l) . phi(k_s) + eps)

    computed in time linear in the sequence length by first aggregating
    ``sum_s phi(k_s) v_s^T`` once per head.

    Args:
        d_input: Feature size of the input tensor.
        d_model: Total width of the query/key/value projections. It is split
            evenly across heads by ``view``, so it must be divisible by
            ``n_heads``.
        d_output: Feature size of the returned tensor.
        n_heads: Number of attention heads.
        drop_out: Dropout probability applied to the final projection.
        eps: Small constant added to the normalizer to avoid division by zero.
    """

    def __init__(self, d_input, d_model, d_output, n_heads, drop_out=0.05, eps=1e-6):
        super(Linear_Attention, self).__init__()
        self.n_heads = n_heads
        self.query_projection = nn.Linear(d_input, d_model)
        self.key_projection = nn.Linear(d_input, d_model)
        self.value_projection = nn.Linear(d_input, d_model)
        self.out_projection = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(drop_out)
        self.eps = eps

    def kernel_method(self, x):
        """Return the positive feature map (element-wise sigmoid) of ``x``."""
        return torch.sigmoid(x)

    def forward(self, x):
        """Apply linear self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, L, d_input)``.

        Returns:
            Tensor of shape ``(B, L, d_output)``.
        """
        B, L, _ = x.shape
        # 1. Linear projections, split into heads: (B, L, H, D).
        queries = self.query_projection(x).view(B, L, self.n_heads, -1)
        keys = self.key_projection(x).view(B, L, self.n_heads, -1)
        values = self.value_projection(x).view(B, L, self.n_heads, -1)
        # 2. Non-negative feature map on queries and keys.
        queries = self.kernel_method(queries)
        keys = self.kernel_method(keys)
        # 3. Linear attention.
        # Aggregate keys and values over the sequence: (B, H, D, M), where D
        # indexes key features and M indexes value features.
        KV = torch.einsum("bshd,bshm->bhdm", keys, values)
        # Per-position normalizer: 1 / (phi(q_l) . sum_s phi(k_s) + eps).
        Z = 1 / (torch.einsum("blhd,bhd->blh", queries, keys.sum(dim=1)) + self.eps)
        # Contract the query with the *key* axis (D) of KV. The labels must
        # match the (B, H, D, M) layout built above; labelling KV as "bhmd"
        # would silently contract against the value axis instead, because both
        # per-head dimensions have the same size.
        out = torch.einsum("blhd,bhdm,blh->blhm", queries, KV, Z)
        # 4. Merge heads and apply the output projection.
        x = out.reshape(B, L, -1)
        x = self.out_projection(x)
        x = self.dropout(x)
        return x


class Linear_Attention_Causal(nn.Module):
    """Causal multi-head linear self-attention based on prefix sums.

    Queries, keys and values are projected from the same input; queries and
    keys go through a sigmoid feature map. Position ``t`` only attends to
    positions ``<= t``: both the key/value outer products and the keys are
    accumulated with a cumulative sum along the sequence axis.

    Note that the cumulative key/value tensor is materialised with shape
    ``(B, H, L, D, D)`` where ``D = d_model / n_heads``.

    Args:
        d_input: Feature size of the input tensor.
        d_model: Total width of the query/key/value projections (split across
            heads by ``view``).
        d_output: Feature size of the returned tensor.
        n_heads: Number of attention heads.
        drop_out: Dropout probability applied to the final projection.
        eps: Small constant added to the normalizer.
    """

    def __init__(self, d_input, d_model, d_output, n_heads, drop_out=0.05, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.n_heads = n_heads
        self.query_projection = nn.Linear(d_input, d_model)
        self.key_projection = nn.Linear(d_input, d_model)
        self.value_projection = nn.Linear(d_input, d_model)
        self.out_projection = nn.Linear(d_model, d_output)
        self.dropout = nn.Dropout(drop_out)

    def kernel_method(self, x):
        """Return the element-wise sigmoid of ``x`` (positive feature map)."""
        return torch.sigmoid(x)

    def forward(self, x):
        """Apply causal linear self-attention.

        Args:
            x: Tensor of shape ``(B, L, d_input)``.

        Returns:
            Tensor of shape ``(B, L, d_output)``.
        """
        batch, length, _ = x.shape
        heads = self.n_heads

        def split_heads(t):
            # (B, L, d_model) -> (B, L, H, D)
            return t.view(batch, length, heads, -1)

        # Project, split into heads, and make queries/keys non-negative.
        q = self.kernel_method(split_heads(self.query_projection(x)))
        k = self.kernel_method(split_heads(self.key_projection(x)))
        v = split_heads(self.value_projection(x))

        # Per-position outer product k_t v_t^T, then running sum over time
        # (axis 2 of the (B, H, L, D, M) tensor).
        kv_prefix = torch.einsum("btha,bthc->bhtac", k, v).cumsum(dim=2)
        # Running sum of keys over time for the normalizer.
        k_prefix = k.cumsum(dim=1)

        # Numerator: q_t applied to the accumulated key/value summary.
        numerator = torch.einsum("btha,bhtac->bthc", q, kv_prefix)
        # Denominator: q_t . sum_{s<=t} k_s, stabilised with eps.
        denominator = torch.einsum("btha,btha->bth", q, k_prefix) + self.eps
        attended = numerator / denominator.unsqueeze(-1)

        # Merge heads, project to d_output, apply dropout.
        merged = attended.reshape(batch, length, -1)
        return self.dropout(self.out_projection(merged))


if __name__ == '__main__':
    # Smoke test: push one random batch through both attention variants and
    # check that each returns a (batch, length, d_output) tensor.
    attn_normal = Linear_Attention(d_input=10, d_model=16, d_output=16, n_heads=4)
    attn_causal = Linear_Attention_Causal(d_input=10, d_model=16, d_output=16, n_heads=4)
    x = torch.rand(1, 100, 10)
    x_attn_normal, x_attn_causal = attn_normal(x), attn_causal(x)
    expected_shape = (1, 100, 16)
    for output in (x_attn_normal, x_attn_causal):
        assert output.shape == expected_shape

In [13]:
# Echo the shape of the non-causal attention output computed in the smoke test.
x_attn_normal.size()

torch.Size([1, 100, 16])

In [14]:
# Echo the shape of the causal attention output computed in the smoke test.
x_attn_causal.size()

torch.Size([1, 100, 16])