multihead_attention.py
import torch
import torch.nn as nn
from self_attention import SelfAttention
class MultiHeadAttention(nn.Module):
    """Multi-head attention: each head applies self-attention to its own
    learned projections of q, k, and v; the head outputs are concatenated
    and projected back down to d_model."""

    def __init__(self, n_heads: int = 6, d_model: int = 512, masked: bool = False):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.attention_mechanism = SelfAttention(masked=masked)
        # Use nn.ModuleList rather than a plain Python list so the per-head
        # projection weights are registered as parameters and are visible to
        # .parameters(), .to(), and state_dict().
        self.w_q = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_heads)])
        self.w_k = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_heads)])
        self.w_v = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_heads)])
        # Each head keeps the full d_model width here, so the output projection
        # maps the concatenation (d_model * n_heads) back to d_model.
        self.w_o = nn.Linear(d_model * n_heads, d_model)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        assert q.size(-1) == self.d_model, f"Error: the embedding space is not {self.d_model}-dim"
        # Collect every head's output and concatenate once along the feature
        # dimension. Concatenating into torch.empty(0) inside the loop fails:
        # the empty 1-D tensor does not match the heads' shape (and always
        # lives on the CPU).
        head_outputs = []
        for i in range(self.n_heads):
            proj_q = self.w_q[i](q)
            proj_k = self.w_k[i](k)
            proj_v = self.w_v[i](v)
            head_outputs.append(self.attention_mechanism(proj_q, proj_k, proj_v))
        concatenated_attentions = torch.cat(head_outputs, dim=-1)
        return self.w_o(concatenated_attentions)
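
# Minimal smoke test, included as a sketch. It assumes SelfAttention is a
# scaled dot-product attention whose forward takes (q, k, v) of shape
# (batch, seq_len, d_model) and returns a tensor of the same shape; adjust
# if your self_attention module differs.
if __name__ == "__main__":
    batch, seq_len, d_model = 2, 10, 512
    mha = MultiHeadAttention(n_heads=6, d_model=d_model)
    x = torch.randn(batch, seq_len, d_model)
    out = mha(x, x, x)  # self-attention: q, k, and v all come from x
    print(out.shape)    # expected: torch.Size([2, 10, 512])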