In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import jax
import jax.numpy as jnp
from jax import random


In [3]:
class Projector(nn.Module):
    """
    Factory for the projection layers (Q, K, V) of a single attention head.
    Every call builds and returns ONE fresh, independently initialised set of
    three linear layers, each mapping dim_model -> dim_head. To obtain the
    projections for 8 heads, call the instance 8 times.
    Args:
        num_heads: number of heads in MHA, default 8
        dim_head: dimension of each attention head, default 64
    Returns (on call):
        tuple (fc_q, fc_k, fc_v) of nn.Linear(num_heads * dim_head, dim_head)
    """
    def __init__(self, num_heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.dim_model = num_heads * dim_head  # width of the vectors being projected

    def __call__(self):
        # Layers are created in Q, K, V order; nothing is cached on the instance.
        in_features, out_features = self.dim_model, self.dim_head
        return tuple(nn.Linear(in_features, out_features) for _ in range(3))


class MultiHeadAttention(nn.Module):
    """
    Class for multi-head attention (MHA) module in vanilla transformer
    We apply linear transformation to input vector by each attention head's projection matrix (8, 512, 64)
    Other approaches are possible, such as using one projection matrix for all attention heads (1, 512, 512)
    and then split into each attention heads (8, 512, 64)
    Args:
        dim_model: dimension of model's latent vector space, default 512 from official paper
        num_heads: number of heads in MHA, default 8 from official paper
        dropout: dropout rate, default 0.1
    Raises:
        ValueError: if num_heads is not positive or dim_model is not divisible by num_heads
    Math:
        MHA(Q, K, V) = Concat(Head1, Head2, ... Head8) * W_concat
    Reference:
        https://arxiv.org/abs/1706.03762
    """
    def __init__(self, dim_model: int = 512, num_heads: int = 8, dropout: float = 0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        if num_heads <= 0 or dim_model % num_heads != 0:
            # Otherwise the per-head width is truncated and forward() fails later with a shape error.
            raise ValueError(
                f"dim_model ({dim_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim = dim_model
        self.num_heads = num_heads
        # NOTE(review): the dropout rate is stored but not applied anywhere in this module.
        self.dropout = dropout
        self.dim_head = self.dim // self.num_heads  # dimension of each attention head
        self.dot_scale = torch.sqrt(torch.tensor(float(self.dim_head)))  # scale factor for Q•K^T Result

        # linear combination: projection matrix(Q_1, K_1, V_1, ... Q_n, K_n, V_n) for each attention head
        self.projector = Projector(self.num_heads, self.dim_head)  # init instance
        # nn.ModuleList (not a plain list) so the per-head layers are registered as sub-modules:
        # they then show up in .parameters() / state_dict() and follow .to(device).
        # Indexing stays projector_list[head][0|1|2] -> (Q, K, V) layer.
        self.projector_list = nn.ModuleList(
            [nn.ModuleList(list(self.projector())) for _ in range(self.num_heads)]
        )  # call instance once per head
        self.fc_concat = nn.Linear(self.dim, self.dim)  # for concatenation of each attention head

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        1) make Q, K, V matrix for each attention head: [BS, HEAD, SEQ_LEN, DIM_HEAD], ex) [10, 8, 512, 64]
        2) Do self-attention in each attention head
            - Matmul (Q, K^T) with scale factor (sqrt(DIM_HEAD))
            - Mask for padding token (Option for Decoder)
            - Softmax
            - Matmul (Softmax, V)
        3) Concatenate each attention head & linear transformation (512, 512)
        Args:
            x: input of shape [BS, SEQ_LEN, DIM]
            mask: optional index into the [BS, HEAD, SEQ_LEN, SEQ_LEN] score tensor; the selected
                scores are set to -inf before the softmax (presumably a boolean tensor - confirm with callers)
        Returns:
            tensor of shape [BS, SEQ_LEN, DIM]
        """
        # 1) make Q, K, V matrix for each attention head, stacked along a new HEAD axis
        Q = torch.stack([head[0](x) for head in self.projector_list], dim=1)
        K = torch.stack([head[1](x) for head in self.projector_list], dim=1)
        V = torch.stack([head[2](x) for head in self.projector_list], dim=1)

        # 2) Do self-attention in each attention head
        attention_score = torch.matmul(Q, K.transpose(-1, -2)) / self.dot_scale
        if mask is not None:  # for padding token
            attention_score[mask] = float('-inf')
        attention_dist = F.softmax(attention_score, dim=-1)  # [BS, HEAD, SEQ_LEN, SEQ_LEN]
        # [BS, HEAD, SEQ_LEN, DIM_HEAD] -> [BS, SEQ_LEN, HEAD, DIM_HEAD] -> [BS, SEQ_LEN, DIM]
        attention_matrix = torch.matmul(attention_dist, V).transpose(1, 2).reshape(x.shape[0], x.shape[1], self.dim)

        # 3) Concatenate each attention head & linear transformation (512, 512)
        x = self.fc_concat(attention_matrix)
        return x

In [None]:
""" Debug for MultiHeadAttention """

# Smoke test: push one random batch through a default-configured MultiHeadAttention
# and display the output together with its shape.
x = torch.randn((10, 512, 512))
test_head = MultiHeadAttention()
test_result = test_head(x)
test_result, test_result.shape

In [47]:
""" torch.reshape test for making input shape in Vision Transformers """
patch_size, num_patches = 16, 32  # 512 / 16 = 32 patches per side
images = torch.randn(10, 3, 512, 512)  # [BS, C, H, W]
# A bare reshape of [BS, C, H, W] only reinterprets memory: each resulting row would be a run of
# consecutive pixels from one channel, not a 16x16 spatial patch. Split H and W into
# (patch index, offset inside patch), move both patch indices to the front, then flatten,
# so that row n holds all channels of the n-th spatial patch.
x = (
    images.reshape(images.shape[0], images.shape[1], num_patches, patch_size, num_patches, patch_size)
    .permute(0, 2, 4, 1, 3, 5)  # [BS, patch_row, patch_col, C, patch_size, patch_size]
    .reshape(images.shape[0], num_patches**2, patch_size**2 * images.shape[1])  # [BS, NUM_PATCHES, C * P * P]
)
x.shape

torch.Size([10, 1024, 768])

In [48]:
""" Check Input Embedding shape """
# Project every flattened patch (768 values) into a 1024-dimensional embedding.
# Note: x is overwritten, so this cell must run exactly once after x has its last dim equal to 768.
input_embedding = nn.Linear(in_features=768, out_features=1024)
x = input_embedding(x)
x.shape

torch.Size([10, 1024, 1024])

In [51]:
""" make classification token for Vision Transformers """
# One all-zero classification token per batch element, with the same embedding width as x
# (zeros are just a placeholder; the init method can be changed).
cls_token = torch.zeros((x.size(0), 1, x.size(2)))
cls_token.shape

torch.Size([10, 1, 1024])

In [52]:
# Prepend the classification token along the sequence axis and check the resulting shape.
torch.cat((cls_token, x), dim=1).shape

torch.Size([10, 1025, 1024])

In [11]:
""" Test for Hybrid Model """
x = torch.randn(10, 3, 512, 512)  # [BS, C, H, W]
dim_model, patch_size, num_patches = 512, 16, 32
# Non-overlapping patch embedding: kernel and stride both equal the patch size,
# so each output position sees exactly one 16x16 patch.
conv = nn.Conv2d(in_channels=3, out_channels=dim_model, kernel_size=patch_size, stride=patch_size)
# [BS, DIM, 32, 32] -> [BS, DIM, 1024] -> [BS, 1024, DIM]
x = conv(x).flatten(start_dim=2).transpose(1, 2)
x.shape

torch.Size([10, 1024, 512])

In [4]:
32 ** 2  # 32 patches per side -> total number of patches

1024

In [49]:
""" Test for DeBERTa Disentangled Self-Attention """
batch, sequence, dim_model, dim_head, k = 10, 512, 2048, 64, 512
position_embedding = nn.Embedding(num_embeddings=2 * k, embedding_dim=dim_model)  # one row per relative position
x = torch.randn(sequence, dim_model)  # [Sequence, Dim] - this test has no batch axis
p_x = position_embedding(torch.arange(2 * k))  # [2k, Dim]

# Per-head projection layers (content Q/K/V plus the two relative-position projections).
fc_q = nn.Linear(dim_model, dim_head)
fc_k = nn.Linear(dim_model, dim_head)
fc_v = nn.Linear(dim_model, dim_head)
fc_qr = nn.Linear(dim_model, dim_head)  # projector for Relative Position Query matrix
fc_kr = nn.Linear(dim_model, dim_head)  # projector for Relative Position Key matrix

q = fc_q(x)  # [Sequence, Dim_Head]
kr = fc_kr(p_x)  # [2k, Dim_Head]

# c2p attention matrix, built one query row at a time:
# row i holds the score of token i against every relative-position key -> [Sequence, 2k]
tmp_c2p = torch.stack([torch.matmul(q_row, kr.transpose(-1, -2)) for q_row in q], dim=0)
tmp_c2p, tmp_c2p.shape

(tensor([[ 0.0238, -4.5300, -3.4175,  ..., -0.9470, -3.4201, -0.5022],
         [-1.5505,  1.3956, -1.6272,  ...,  1.0928,  0.7447,  1.1566],
         [-3.2708,  4.4402, -0.8418,  ...,  1.0128,  0.4196, -1.3004],
         ...,
         [ 0.7353,  0.9112, -3.1176,  ...,  0.8015, -1.4260, -2.5051],
         [-1.5237, -2.7802,  0.5513,  ..., -4.7021,  6.9739, -2.3313],
         [ 1.5769, -1.0775,  2.1967,  ..., -1.8082,  3.3897, -1.3888]],
        grad_fn=<StackBackward0>),
 torch.Size([512, 1024]))

i번째 토큰의 latent vector space 1024 차원에서 max relative position 값인 k만 뽑아 내는게 목적
그리고 빼는 것도 max sequence length 만큼 상대 위치 임베딩 토큰을 구하는거구나.

In [52]:
""" tmp_c2p matrix calculation """
# Same content-to-position scores as the row-by-row version, computed with a single matmul:
# [Sequence, Dim_Head] x [Dim_Head, 2k] -> [Sequence, 2k]
tmp_c2p = q @ kr.transpose(-1, -2)
tmp_c2p, tmp_c2p.shape

(tensor([[ 0.0238, -4.5300, -3.4175,  ..., -0.9470, -3.4201, -0.5022],
         [-1.5505,  1.3956, -1.6272,  ...,  1.0928,  0.7447,  1.1566],
         [-3.2708,  4.4402, -0.8418,  ...,  1.0128,  0.4196, -1.3004],
         ...,
         [ 0.7353,  0.9112, -3.1176,  ...,  0.8015, -1.4260, -2.5051],
         [-1.5237, -2.7802,  0.5513,  ..., -4.7021,  6.9739, -2.3313],
         [ 1.5769, -1.0775,  2.1967,  ..., -1.8082,  3.3897, -1.3888]],
        grad_fn=<MmBackward0>),
 torch.Size([512, 1024]))

In [None]:
""" c2p matrix calculation """
# c2p[i, j] = tmp_c2p[i, delta(i, j)], where delta(i, j) is the relative-position bucket:
#   0          if i - j <= -k
#   2k - 1     if i - j >= k
#   i - j + k  otherwise
# Assumes tmp_c2p is [Sequence, 2k] (content queries x relative-position keys).
token_pos = torch.arange(tmp_c2p.shape[0])
rel_dist = token_pos[:, None] - token_pos[None, :]  # entry (i, j) = i - j
c2p_index = torch.where(
    rel_dist <= -k,
    torch.zeros_like(rel_dist),
    torch.where(rel_dist >= k, torch.full_like(rel_dist, 2 * k - 1), rel_dist + k),
)
# Pick, for every (i, j) pair, the score stored at that pair's relative-position bucket.
c2p_attention = torch.gather(tmp_c2p, -1, c2p_index)  # [Sequence, Sequence]
c2p_attention.shape

In [55]:
# Small random vector used to try out element-wise clamping.
a = torch.randn((4,))
a

tensor([ 1.3991,  1.2138, -0.2081, -1.0195])

In [56]:
# A tensor lower bound (here built with torch.linspace) lets torch.clamp apply a different
# limit to every element, i.e. an element-wise clamp.
# NOTE(review): this name shadows the Python builtin `min` for the rest of the session;
# it is kept because the clamp cell passes it as `min=min`.
min = torch.linspace(start=-1, end=1, steps=4)
min

tensor([-1.0000, -0.3333,  0.3333,  1.0000])

In [57]:
# Element-wise lower bound: each a[i] is raised to min[i] when it falls below it.
a.clamp(min=min)

tensor([1.3991, 1.2138, 0.3333, 1.0000])

In [None]:
# torch.gather needs an index tensor: for every (i, j) token pair use the relative-position
# bucket clamp(i - j + k, 0, 2k - 1) to select a column of tmp_c2p.
# Assumes tmp_c2p is [sequence, 2k].
gather_index = torch.clamp(
    torch.arange(sequence)[:, None] - torch.arange(sequence)[None, :] + k,
    min=0,
    max=2 * k - 1,
)
torch.gather(tmp_c2p, -1, gather_index)

In [45]:
def relative_position_bucket(c: int, p: int, k: int) -> torch.Tensor:
    """
    Making relative position bucket for each token

    Maps the signed distance (c - p) into the index range [0, 2k - 1] of the
    relative position embedding table; distances beyond the window are clipped
    to the first / last bucket.

    Args:
        c: content token index from content matrix
           (an integer tensor is also accepted and is broadcast against p)
        p: position token index from position matrix
           (an integer tensor is also accepted and is broadcast against c)
        k: max relative distance; the embedding table has 2k buckets
    Returns:
        integer tensor holding the bucket index (0-dim when c and p are plain ints)
    Raises:
        ValueError: if k is not positive
    Math:
        bucket(c, p) = 0 (c - p <= -k)
                       2k - 1 (c - p >= k)
                       c - p + k (other)
    Reference:
        https://arxiv.org/abs/2006.03654
        https://arxiv.org/abs/2111.09543
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k}")
    token_index = torch.as_tensor(c) - torch.as_tensor(p)
    # Shifting by k and clipping to [0, 2k - 1] reproduces all three cases of the formula:
    # c - p <= -k lands on 0, c - p >= k lands on 2k - 1, anything else is c - p + k.
    return torch.clamp(token_index + k, min=0, max=2 * k - 1)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [59]:
# Element-wise difference of two index vectors: [0..5] - [1..6] -> all -1.
torch.arange(6) - torch.arange(1, 7)

tensor([-1, -1, -1, -1, -1, -1])

In [None]:
# Scratch loop: for every row index i of x and every j in [0, 2 * sequence), index tmp_c2p at
# [i, relative_position_bucket(i, j)] and pass that value to torch.expand; the result is discarded.
# NOTE(review): relative_position_bucket is declared with three parameters (c, p, k) but is
# called here with only two - confirm the intended k.
# NOTE(review): torch.expand does not look like a public torch function (Tensor.expand is) - confirm.
for i in range(x.shape[0]):
    for j in range(2 * sequence):
     torch.expand(tmp_c2p[i, relative_position_bucket(i, j)])

In [18]:
# Look up the embedding rows for indices 0..1023
# (presumably the full 2k-row relative position table when k = 512 - confirm).
position = position_embedding(torch.arange(start=0, end=1024))
position

tensor([[-0.9897,  0.8158,  0.6102,  ...,  0.6439, -0.8470,  1.1054],
        [ 0.7618, -0.4408, -0.0770,  ..., -0.8948,  1.3229,  0.2300],
        [ 0.4758, -0.1387, -0.8447,  ..., -2.4888, -0.5423, -1.0494],
        ...,
        [-0.1688, -0.1389,  1.5126,  ..., -0.7173, -2.0270,  1.1630],
        [ 1.8704,  0.8813,  1.1547,  ...,  0.0560, -1.6586,  1.0536],
        [-0.1594, -1.0213, -1.0092,  ..., -0.4617,  0.3844,  0.1689]],
       grad_fn=<EmbeddingBackward0>)

In [42]:
# Toy gather index: 5 rows, each selecting columns 0..5.
c2p = [list(range(6)) for _ in range(5)]
c2p

[[0, 1, 2, 3, 4, 5],
 [0, 1, 2, 3, 4, 5],
 [0, 1, 2, 3, 4, 5],
 [0, 1, 2, 3, 4, 5],
 [0, 1, 2, 3, 4, 5]]

In [43]:
# Gather along dim 1: output[r, j] = position[r, c2p[r][j]] for the first len(c2p) rows.
# Note: c2p is rebound from the index list to the gathered tensor, so this cell
# must not be re-run without first re-running the cell that builds the list.
c2p = position.gather(1, torch.tensor(c2p))
c2p

tensor([[-0.9897,  0.8158,  0.6102,  0.7443,  0.6303,  0.7030],
        [ 0.7618, -0.4408, -0.0770,  0.4425,  0.5437,  0.8620],
        [ 0.4758, -0.1387, -0.8447,  0.2820,  0.1380,  0.3166],
        [-1.3630, -0.7809,  0.7347,  1.5590, -0.4329,  1.8260],
        [ 0.0065, -0.6625,  0.3151, -0.9509, -0.4702, -0.7222]],
       grad_fn=<GatherBackward0>)