In [1]:
import torch
import copy_model
from functools import partial

## to_eachhead

self.qvkの出力をqvkそれぞれに分割をして、それぞれをヘッドごとに分割する。

(B, n, 3d) -> (B, n, d) x 3 (qvk) -> (B, h, n, d')

- inputs
    - x (torch.tensor) : (B, n, 3d) output of self.qvk
    - head_num : head数
    - split_num : 分割数、qvkに分割する場合は、split_num=3
- outputs
    - out (list)
        - out = [q, v, ...(split num)]
            - q (torch.tensor) : (B, h, n, d')
            - v (torch.tensor) : (B, h, n, d')
            - k (torch.tensor) : (B, h, n, d')
                - ただしd'はマルチヘッドアテンションを行う時の次元数

In [2]:
# Toy batch shaped like the output of self.qvk: (B, n, 3d).
B, N, D = 2, 10, 3 * 32
# Consecutive integers (rather than random values) keep every element traceable.
data = torch.arange(B * N * D).reshape((B, N, D))
print(data)
# Split into split_num=3 tensors and divide each of them into head_num=4 heads.
q, v, k = copy_model.to_eachhead(data, split_num=3, head_num=4)

tensor([[[   0,    1,    2,  ...,   93,   94,   95],
         [  96,   97,   98,  ...,  189,  190,  191],
         [ 192,  193,  194,  ...,  285,  286,  287],
         ...,
         [ 672,  673,  674,  ...,  765,  766,  767],
         [ 768,  769,  770,  ...,  861,  862,  863],
         [ 864,  865,  866,  ...,  957,  958,  959]],

        [[ 960,  961,  962,  ..., 1053, 1054, 1055],
         [1056, 1057, 1058,  ..., 1149, 1150, 1151],
         [1152, 1153, 1154,  ..., 1245, 1246, 1247],
         ...,
         [1632, 1633, 1634,  ..., 1725, 1726, 1727],
         [1728, 1729, 1730,  ..., 1821, 1822, 1823],
         [1824, 1825, 1826,  ..., 1917, 1918, 1919]]])


## concat_head

ヘッドをもとに戻す

- inputs
    - x (torch.tensor) : (B, h, n, d')
- outputs
    - out (torch.tensor) : (B, n, d) (d = d' x h)

In [3]:
# Merge the heads of `q` (returned by to_eachhead in the previous cell) back together.
copy_model.concat_head(q)

tensor([[[   0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,
            11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,
            22,   23,   24,   25,   26,   27,   28,   29,   30,   31],
         [  96,   97,   98,   99,  100,  101,  102,  103,  104,  105,  106,
           107,  108,  109,  110,  111,  112,  113,  114,  115,  116,  117,
           118,  119,  120,  121,  122,  123,  124,  125,  126,  127],
         [ 192,  193,  194,  195,  196,  197,  198,  199,  200,  201,  202,
           203,  204,  205,  206,  207,  208,  209,  210,  211,  212,  213,
           214,  215,  216,  217,  218,  219,  220,  221,  222,  223],
         [ 288,  289,  290,  291,  292,  293,  294,  295,  296,  297,  298,
           299,  300,  301,  302,  303,  304,  305,  306,  307,  308,  309,
           310,  311,  312,  313,  314,  315,  316,  317,  318,  319],
         [ 384,  385,  386,  387,  388,  389,  390,  391,  392,  393,  394,
           395,  396,  397,  398

## MultiHeadSelfAttention

multiheadselfattention
head増やす(B, H, N, D) -> selfattention function -> output

- args:
    - dim (int) : 
    - attn_type (str) : linear -> LinearAttention / full -> Vanilla
    - head_num (int) : 

- inputs:
    - x (torch.tensor) : (B, N, D)

- outputs:
    - out (torch.tensor) : (B, N, D)

In [4]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand((B, N, D))
# dim is tied to D so the module always matches the last axis of `data`.
mhsa = copy_model.MultiHeadSelfAttention(head_num=4, attn_type="linear", dim=D)
# Show the output shape.
mhsa(data).size()

torch.Size([2, 10, 32])

## FeedForward

feedforward module. 2層のaffine層

- args:
    - dim (int)
    - hid_dim (int)

- inputs:
    - x (torch.tensor) : (B, N, D)

- outputs:
    - out (torch.tensor) : (B, N, D)

In [2]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand((B, N, D))
# Positional arguments: the notes list them as dim, hid_dim.
ff = copy_model.FeedForward(D, 64)
# Show the output shape.
ff(data).size()

torch.Size([2, 10, 32])

## EncoderLayer

コピータスクのエンコーダレイヤー
selfattention -> feedforward
residual passとそれに伴ったLayerNormを実装

- args:
    - dim : 潜在次元数
    - attn_type : attentionのタイプ
    - head_num : ヘッド数
    - ff_hidnum (int) : feedforwardでの隠れ層の次元

- inputs:
    - x (torch.tensor) : (B, N, D)

- outputs:
    - out (torch.tensor) : (B, N, D)

In [3]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand((B, N, D))
# Positional arguments: the notes list them as dim, attn_type, head_num, ff_hidnum.
el = copy_model.EncoderLayer(D, "linear", 4, 256)
# Show the output shape.
el(data).size()

torch.Size([2, 10, 32])

## Encoder

コピータスクのエンコーダ
EncoderLayerを所望の数積み重ねる

- args:
    - depth : 層の数
    - dim : 潜在次元数
    - head_num : ヘッド数
    - attn_type : linear -> LinearAttention / full -> Vanilla
    - ff_hidnum : feedforwardにおける潜在次元数

- inputs:
    - x (torch.tensor) : (B, N, D)

- outputs:
    - x : (torch.tensor) : (B, N, D)

In [4]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand((B, N, D))
# Positional arguments: the notes list them as depth, dim, head_num, attn_type, ff_hidnum.
enc = copy_model.Encoder(4, D, 4, "linear", 256)
# Show the output shape.
enc(data).size()

torch.Size([2, 10, 32])

## MultiHeadCausalAttention

Causal attentionをやります。
head増やす(B, H, N, D) -> causalattention function -> output

- args:
    - dim (int) : 
    - attn_type (str) : linear -> LinearAttention / full -> Vanilla
    - head_num (int) : 

- inputs:
    - x (torch.tensor) : (B, N, D)

- outputs:
    - out (torch.tensor) : (B, N, D)

In [3]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand(B, N, D)
# Named `mhca` (not `mhsa`) because this is the causal-attention module, and
# dim is tied to D instead of repeating the literal 32.
mhca = copy_model.MultiHeadCausalAttention(dim=D, attn_type="linear", head_num=4)
# Show the output shape.
mhca(data).shape

torch.Size([2, 10, 32])

## MultiHeadSourceAttention

source attention. this is for attention using output of encoder(memory). 

- args:
    - dim (int) : 特徴次元数
    - attn_type (str) : linear -> LinearAttention / full -> Vanilla
    - head_num (int) : ヘッド数

- inputs:
    - x (torch.tensor) : (B, N, D) input tensor
    - memory (torch.tensor) : (B, M, D) output of encoder (M may differ from N)

- outputs:
    - out (torch.tensor) : (B, N, D)

In [4]:
B, N, D = 2, 10, 32
# Random input of shape (B, N, D).
data = torch.rand(B, N, D)
# The memory is deliberately longer (N + 5) than the input sequence.
memory = torch.rand(B, N + 5, D)
# Named `src_attn` (not `mhsa`) because this is the source-attention module, and
# dim is tied to D instead of repeating the literal 32.
src_attn = copy_model.MultiHeadSourceAttention(dim=D, attn_type="linear", head_num=4)
# Show the output shape.
src_attn(data, memory).shape

torch.Size([2, 10, 32])

## DecoderLayer

コピータスクのデコーダレイヤー
(self)causalattention -> sourceattention -> feedforward
residual passとそれに伴ったLayerNormを実装

- args:
    - dim (int) : 潜在次元数
    - attn_type (str) : attentionのタイプ
    - head_num (int) : ヘッド数
    - ff_hidnum (int) : feedforwardでの隠れ層の次元

- inputs:
    - x (torch.tensor) : (B, N, D)
    - memory (torch.tensor) : (B, M, D) output of encoder

- outputs:
    - out (torch.tensor) : (B, N, D)

In [3]:
B, N, D = 2, 10, 32
# The memory is drawn first; it is 5 steps longer than the decoder input.
memory = torch.rand((B, N + 5, D))
data = torch.rand((B, N, D))
# Positional arguments: the notes list them as dim, attn_type, head_num, ff_hidnum.
dl = copy_model.DecoderLayer(D, "linear", 4, 256)
# Show the output shape.
dl(data, memory).size()

torch.Size([2, 10, 32])

## Decoder

コピータスクのデコーダ
DecoderLayerを所望の数積み重ねる

- args:
    - depth : 層の数
    - dim : 潜在次元数
    - head_num : ヘッド数
    - attn_type : linear -> LinearAttention / full -> Vanilla
    - ff_hidnum : feedforwardにおける潜在次元数

- inputs:
    - x (torch.tensor) : (B, N, D)
    - memory (torch.tensor) : (B, M, D) output of encoder

- outputs:
    - x : (torch.tensor) : (B, N, D)

In [3]:
B, N, D = 2, 10, 32
# The memory is drawn first; it is 5 steps longer than the decoder input.
memory = torch.rand((B, N + 5, D))
data = torch.rand((B, N, D))
# Positional arguments: the notes list them as depth, dim, head_num, attn_type, ff_hidnum.
dec = copy_model.Decoder(4, D, 4, "linear", 256)
# Show the output shape.
dec(data, memory).size()

torch.Size([2, 10, 32])

## FinalLayer

出力の直前の層
output of transformer -> linear -> output
nn.CrossEntropyLossでソフトマックスを行うので、ここでは実装しない

args:
  - dim (int) : 特徴次元
  - vocab_num (int) : 語彙数
  - hid_dim (int) : 中間層のユニット数

In [3]:
B, N, D = 2, 10, 32
# One vocabulary size for both the variable and the layer; the variable used to
# be 10 and unused while the layer was built with a hard-coded 30.
vocab_num = 30
# Random input of shape (B, N, D).
data = torch.rand(B, N, D)
# Positional arguments: the notes list them as dim, vocab_num, hidden size.
final = copy_model.FinalLayer(D, vocab_num, 2048)
# Show the output shape; the last axis equals vocab_num.
final(data).shape

torch.Size([2, 10, 30])

## CopyModel

コピータスク専用のTransformerモデル。(マスクは考えない)
position -> encoder -> decoder -> finallayer(最後にsoftmaxしない)

- args:
    - device (str) : cpu or gpu name
    - ed (int) : 潜在次元数
    - vocab_num (int) : vocabulary size
    - N_enc (int) : number of encoderlayer
    - N_dec (int) : number of decoderlayer
    - h_enc (int) : number of multihead in encoder
    - h_dec (int) : number of multihead in decoder

- inputs:
    - x (torch.tensor) : (B, len_x)
    - y (torch.tensor) : (B, len_y)

- outputs:
    - out (torch.tensor) : (B, len_y, vocab_num) softmax前の出力

In [4]:
# B and N match the (4, 9) token tensors that are fed to the model.
B, N, D = 4, 9, 32
vocab_num = 36
# Token ids 0..B*N-1 laid out as (B, N); every id is below vocab_num.
x = torch.arange(B * N).reshape(B, N)
y = torch.arange(B * N).reshape(B, N)
# NOTE(review): 10 positional arguments are passed but the notes list only 7
# (device, ed, vocab_num, N_enc, N_dec, h_enc, h_dec). "full" is presumably
# attn_type, and 2048 / 64 are undocumented -- confirm against CopyModel.
cp = copy_model.CopyModel("cpu", D, vocab_num, "full", 4, 4, 8, 8, 2048, 64)
# Show the output shape.
cp(x, y).shape

torch.Size([4, 9, 36])

## full_attention

Scaled Dot-Product Attention (論文Fig.2)

inputs:
  - query (torch.tensor) (B, h, n, d)
  - key (torch.tensor) (B, h, m, d) (m: key/valueの系列長。nと異なってもよい)
  - value (torch.tensor) (B, h, m, d)
  - causal (bool) : Trueの時、時間マスク(三角行列)を使用
  - dropout (float) : ドロップアウトの割合(使用するなら)

return:
  - out (torch.tensor) (B, h, n, d)

In [2]:
B, H, N_kv, N_q, D = 3, 4, 16, 32, 256

# The query length (N_q) differs from the key/value length (N_kv).
q = torch.rand((B, H, N_q, D))
k, v = (torch.rand((B, H, N_kv, D)) for _ in range(2))
# Show the output shape.
copy_model.full_attention(q, k, v).size()

torch.Size([3, 4, 32, 256])

In [3]:
B, H, N, D = 3, 4, 16, 256

# q, k and v share one shape here; they are drawn in that order.
q, k, v = (torch.rand((B, H, N, D)) for _ in range(3))
# Same call as the previous cell, this time with causal=True.
copy_model.full_attention(q, k, v, causal=True).size()

torch.Size([3, 4, 16, 256])

## phi

nonlinear function for linear attention, which is described in the paper.

$$
\phi(x) = \mathrm{elu}(x) + 1
$$

In [4]:
# Float inputs from -36 to 35, so both negative and positive values are covered.
data = torch.arange(-36, 36, dtype=torch.float32).reshape((6, 2, 6))
copy_model.phi(data)

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [0.0000e+00, 5.9605e-08, 1.1921e-07, 2.9802e-07, 8.3447e-07,
          2.2650e-06]],

        [[6.1393e-06, 1.6689e-05, 4.5419e-05, 1.2338e-04, 3.3545e-04,
          9.1189e-04],
         [2.4788e-03, 6.7379e-03, 1.8316e-02, 4.9787e-02, 1.3534e-01,
          3.6788e-01]],

        [[1.0000e+00, 2.0000e+00, 3.0000e+00, 4.0000e+00, 5.0000e+00,
          6.0000e+00],
         [7.0000e+00, 8.0000e+00, 9.0000e+00, 1.0000e+01, 1.1000e+01,
          1.2000e+01]],

        [[1.3000e+01, 1.4000e+01, 1.5000e+01, 1.6000e+01, 1.7000e+01,
          1.8000e+01],
         [1.9000e+01, 2.0000e+01, 2.1000e+01, 2.2000e+01, 2.3000e+01,
          2.4000e+01]],

        [[2.5000e+01, 2.6000e+01, 2.7000e+01, 2.8000e+01, 2.

In [5]:
# Last slice of `data` along the first axis.
data[-1, :, :]

tensor([[24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35.]])

## linear_attn_elu

In [2]:
B, H, N, D = 2, 4, 10, 32

# q, k and v share one shape here; they are drawn in that order.
q, k, v = (torch.rand((B, H, N, D)) for _ in range(3))
# Show the output shape.
copy_model.linear_attn_elu(q, k, v).size()

torch.Size([2, 4, 10, 32])

In [3]:
B, H, N_kv, N_q, D = 3, 4, 16, 32, 256

# The query length (N_q) differs from the key/value length (N_kv).
q = torch.rand((B, H, N_q, D))
k, v = (torch.rand((B, H, N_kv, D)) for _ in range(2))
# Show the output shape.
copy_model.linear_attn_elu(q, k, v).size()

torch.Size([3, 4, 32, 256])

## causal_linear_attn_elu

In [2]:
B, H, N, D = 2, 4, 10, 32

# Fixed seed so q, k and v are the same on every run.
torch.manual_seed(0)

# Drawn in the order q, k, v.
q, k, v = (torch.rand((B, H, N, D)) for _ in range(3))
# Disabled experiment, kept for reference:
# q[:,:,9,:] = 10.
# k[:,:,9,:] = 10.
# v[:,:,9,:] = 10.
o = copy_model.causal_linear_attn_elu(q, k, v)

In [3]:
# Output vector of batch 0, head 0 at position 9 (the last position when N = 10).
o[0, 0, 9]

tensor([0.0821, 0.0814, 0.0923, 0.1001, 0.0840, 0.0994, 0.0919, 0.0922, 0.0957,
        0.0820, 0.0912, 0.0923, 0.0944, 0.0991, 0.0950, 0.0850, 0.0853, 0.0866,
        0.0948, 0.0935, 0.0881, 0.0990, 0.0844, 0.0868, 0.0889, 0.0910, 0.0916,
        0.0960, 0.0853, 0.0939, 0.0908, 0.0908])

In [2]:
B, H, N, D = 2, 4, 10, 32

# Fixed seed so q, k and v are the same on every run.
torch.manual_seed(0)

# Drawn in the order q, k, v.
q, k, v = (torch.rand((B, H, N, D)) for _ in range(3))
# Disabled experiment, kept for reference:
# q[:,:,9,:] = 10.
# k[:,:,9,:] = 10.
# v[:,:,9,:] = 10.
# NOTE(review): same seed and inputs as the earlier causal cell, yet the recorded
# output differs -- presumably copy_model changed between the two runs; re-run
# on a fresh kernel to confirm which result is current.
o = copy_model.causal_linear_attn_elu(q, k, v)
# Output vector of batch 0, head 0 at position 9.
o[0, 0, 9]

tensor([0.4663, 0.6051, 0.4732, 0.4689, 0.4473, 0.5631, 0.5896, 0.5327, 0.3713,
        0.5081, 0.5653, 0.5160, 0.5343, 0.4732, 0.5112, 0.6187, 0.4419, 0.3951,
        0.6316, 0.5591, 0.5036, 0.7060, 0.6905, 0.2582, 0.5555, 0.4821, 0.5335,
        0.5393, 0.4463, 0.5635, 0.5828, 0.3450])