In [1]:
import torch 
import numpy
import torch.nn as nn # 学习点1
import copy
import math
import torch.nn.functional as F # 学习点2

In [2]:
torch.cuda.current_device()

0

# Encoder和Decoder中的sub_layer的实现

In [69]:
# 重点理解
#1.计算q,k,v的时候必须要先reshape再permute而不能直接一步reshape到位。以k为例
# 直接经过linear projection得到的raw k中不同的head的信息都在最后一维中，如果
# 直接reshape成 batch_size,num_heads,seq_len.embedding_dim的大小会导致两个不
# 同的head的信息直接被存在了seq_len这一个维度，而没有在head维度被区分开，这显
# 然是不能被接受的。因此正确的做法是先将embedding维度拆分为heads*dim_k，然后再
# 执行permute操作交换维度信息。
class MultiHeadAttention(nn.Module):
    def __init__(self,num_heads,embedding_dim):
        super(MultiHeadAttention,self).__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.wq = nn.Linear(embedding_dim,embedding_dim)
        self.wk = nn.Linear(embedding_dim,embedding_dim)
        self.wv = nn.Linear(embedding_dim,embedding_dim)
        
    def forward(self,query,key,value,mask=None):
        # query,key,value的尺寸:(batch_size,seq_len,embedding_dim)
        batch_size,seq_len_1,embedding_dim = query.shape # 可能是encoder侧，也可能是decoder侧的序列长度
        seq_len_2 = key.shape[1] # 一定是encoder侧的序列长度
        #q = self.wq(query).reshape(batch_size,seq_len_1,self.num_heads,-1).permute(0,2,1,3)
        #k = self.wk(key).reshape(batch_size,seq_len_2,self.num_heads,-1).permute(0,2,1,3)
        #v = self.wv(value).reshape(batch_size,seq_len_2,self.num_heads,-1).permute(0,2,1,3)
        q = self.wq(query)
        k = self.wk(key)
        v = self.wv(value)
        #print(q.shape,k.shape,v.shape)
        # 此时的q,k,v的尺寸:(batch_size,num_heads,seq_len,dim_k)
        dim_k = self.embedding_dim//self.num_heads
        attent_score = torch.matmul(q,k.transpose(1,2))/math.sqrt(dim_k) # 在cuda上
        #print(mask.shape,attent_score.shape)
        if mask != None:
            mask = mask.to('cuda:0')
            attent_score = attent_score.masked_fill(mask,-1e9)
        scaled_attent_score = F.softmax(attent_score,-1) # (batch_size,num_heads,seq_len,seq_len)
        output = torch.matmul(scaled_attent_score,v).to('cuda:0') # (batch_size,num_heads,seq_len,dim_k)
        #print('multi:',output.device)
        #output = output_mat.permute(0,2,1,3).reshape(batch_size,seq_len_1,-1) # 重新将size复原为(batch_size,seq_len,embedding_dim)
        return output
                               
class FeedForward(nn.Module):
    def __init__(self,embedding_dim,dropout = 0.1):
        super(FeedForward,self).__init__()
        self.linear_1 = nn.Linear(embedding_dim,4 * embedding_dim)
        self.linear_2 = nn.Linear(4 * embedding_dim,embedding_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
    def forward(self,x):
        x = self.relu(self.linear_1(x))
        return self.linear_2(self.dropout(x))                               

# Encoder的实现

In [70]:
class Encoder(nn.Module):
    def __init__(self,layer,vocab_size,embedding_dim, N=6):
        super(Encoder,self).__init__()
        self.WordEmbedding = nn.Embedding(vocab_size,embedding_dim)
        self.PositionEmbedding  = PositionEncoding()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)]) # Encoder层中的多个layer
    def forward(self,x,mask):
        x = self.WordEmbedding(x)
       # print('Encoder:',x.device)
        x = self.PositionEmbedding(x)
        for layer in self.layers:
            x = layer(x,mask)
        return x
    
class PositionEncoding(nn.Module):
    def __init__(self):
        super(PositionEncoding,self).__init__()
        
    def forward(self,x):
        # size of x: (batch_zize,seq_len,embedding_dim)
        # print('PE',x.device)
        batch_size,seq_size,embedding_dim = x.shape
        pos = torch.arange(seq_size).reshape(-1,1)
        i = torch.arange(0,embedding_dim,2).reshape(1,-1)
        PE = torch.zeros(seq_size,embedding_dim)
        PE[:,0::2]= torch.sin(pos/torch.pow(10000,i/embedding_dim))
        PE[:,1::2] = torch.cos(pos/torch.pow(10000,i/embedding_dim))
        PE = PE.repeat(batch_size,1).reshape(batch_size,seq_size,embedding_dim).to('cuda:0')
        #print('PE:',PE.device)
        return x + PE

# 单个Encoder Block的实现
class EncoderLayer(nn.Module):
    def __init__(self,num_heads,embedding_size):
        super(EncoderLayer,self).__init__()
        self.multi_head_attention = MultiHeadAttention(num_heads,embedding_size) 
        self.ffn = FeedForward(embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
    
    def forward(self,x,enc_mask):
        x = self.norm(x + self.multi_head_attention(x,x,x,enc_mask))
        x = self.norm(x + self.ffn(x))
        return x


# Decoder的实现

In [71]:
class Decoder(nn.Module):
    def __init__(self,layer,vocab_size,embedding_dim,N=6):
        super(Decoder,self).__init__()
        self.WordEmbedding = nn.Embedding(vocab_size,embedding_dim)
        self.PositionEmbedding = PositionEncoding()
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(N))
        self.linear = nn.Linear(embedding_dim,vocab_size)
        
    def forward(self,x,enc_out,dec_mask):
        x = self.WordEmbedding(x) # x是idx，输出得到对应的词向量
        x = self.PositionEmbedding(x)
        for layer in self.layers:
            x = layer(x,enc_out,dec_mask)
        return self.linear(x)
    
class DecoderLayer(nn.Module):
    def __init__(self,num_heads,embedding_dim):
        super(DecoderLayer,self).__init__()
        self.masked_multi_head_attention = MultiHeadAttention(num_heads,embedding_dim)
        self.dec_attention = MultiHeadAttention(num_heads,embedding_dim)
        self.ffn = FeedForward(embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        
    # 这里的dec_seq_mask为什么需要直接传入
    def forward(self,x,enc_out,dec_hidden_mask):
        dec_seq_len = x.shape[1]
        dec_seq_mask = self.get_dec_seq_mask(dec_seq_len)
        a = self.masked_multi_head_attention(x,x,x,dec_seq_mask)
        #print(a.device)
        x = self.norm(self.masked_multi_head_attention(x,x,x,dec_seq_mask) + x)
        #print(x.shape,enc_out.shape)
        #print(dec_hidden_mask.shape)
        x = self.norm(self.dec_attention(x,enc_out,enc_out,dec_hidden_mask) + x)
        x = self.norm(self.ffn(x)+x)
        return x
    
    @staticmethod
    def get_dec_seq_mask(dec_seq_len):
        ones = torch.ones(1,dec_seq_len,dec_seq_len) # 第一维度是batch维度
        valid_mat = torch.tril(ones)
        mask_mat = (1 - valid_mat).to(torch.bool)
        return mask_mat


# Transformer的实现

In [72]:
class Transformer(nn.Module):
    def __init__(self,encoder,decoder):
        super(Transformer,self).__init__()
        self.Encoder = encoder
        self.Decoder = decoder
    def forward(self,x,y,enc_mask,dec_mask):
        enc_out = self.Encoder(x,enc_mask)
        dec_out = self.Decoder(y,enc_out,dec_mask)
        return dec_out

# Preprocessing

In [73]:
import pandas as pd
import numpy as np
import re

def normalize_text(content):
    content = re.sub(r"([.!?])", r" \1", content) # 将标点符号和字符分隔开，到时候方便分离字符和标点符号
    content = re.sub(r"[^a-zA-Z.!?]+", r" ", content)
    return content

def read_raw_data(filename):
    with open(filename,'r',encoding='utf-8') as f:
        lines = f.readlines()
        pairs = [[normalize_text(pair) for pair in line.split('\t')]for line in lines]
        # 过滤
        #content = re.sub(r"([.!?])", r" \1", content)
        #content = re.sub(r"[^a-zA-Z.!?]+", r" ", content)
    return pairs


def process_raw_data(pairs,reserved_tokens):
    # 将输入的raw data转化为
    src_lang = [pair[0] for pair in pairs]
    tgt_lang = [pair[1] for pair in pairs]
    
    vocab_src = Vocab(pairs = src_lang,reserved_tokens = reserved_tokens)
    vocab_tgt = Vocab(pairs = tgt_lang,reserved_tokens = reserved_tokens)
    return vocab_src,vocab_tgt



In [74]:
from collections import defaultdict

class Vocab:
    def __init__(self,pairs,reserved_tokens=None,min_freq=2):
        self.reserved_tokens = reserved_tokens
        self.idx2token = []
        self.token2idx = {}
        self.UNK_TOKEN = '<UNK>'
        token_freq = defaultdict(int) # 记录token出现的频率
        for pair in pairs:
            for token in pair.split(' '):
                token_freq[token] += 1
        #print(token_freq)
        # 将所有字符装入列表中
        unique_tokens = reserved_tokens if reserved_tokens else []
        unique_tokens += [token for token, freq in token_freq.items() if freq>=min_freq]
        
        # 检查是否将UNK的token装入了列表中
        if self.UNK_TOKEN not in unique_tokens:
            unique_token = [self.UNK_TOKEN] + unique_token
        
        # 构造idx2token和token2idx
        for token in unique_tokens:
            self.idx2token.append(token)
            self.token2idx[token]=len(self.idx2token)-1
        
        self.unk = self.token2idx[self.UNK_TOKEN] # 记录unk在词表中的下标，用于标记语料中未出现的token
        
    def __len__(self):
        # 返回词表长度
        return len(self.idx2token)
    
    def __getitem__(self,token):
        # 输入token，返回idx
        return self.token2idx.get(token,self.unk)
    
    def convert_token_to_idx(self,tokens):
        # 将输入的一串token转化为idx并以列表形式输出
        return [self[token] for token in tokens]
    
    def convert_idx_to_token(self,ids):
        # 将输入的一串ids转化为token并以列表形式输出
        return [self.idx2token[idx] for idx in ids]
        

In [75]:
from torch.utils.data import Dataset,DataLoader
from torch.nn.utils.rnn import pad_sequence

class TranslationDataset(Dataset):
    def __init__(self,vocab_src,vocab_tgt,pairs):
        self.src_vocab,self.tgt_vocab = vocab_src, vocab_tgt
        self.src_lang = [pair[0] for pair in pairs]
        self.tgt_lang = [pair[1] for pair in pairs]
        
    #__len
    def __len__(self):
        return len(self.src_lang)
    # 实现索引数据集中的某一个元素
    
    def __getitem__(self,idx):
        src_tokens = self.src_lang[idx]
        tgt_tokens = self.tgt_lang[idx]
        encoder_input = vocab_src.convert_token_to_idx(src_tokens.split())
        decoder_input = vocab_tgt.convert_token_to_idx(['<BOS>'] + tgt_tokens.split())
        decoder_output = vocab_tgt.convert_token_to_idx(tgt_tokens.split() + ['EOS'])
        return encoder_input,decoder_input,decoder_output
    
def collate_fn(batch):
    encoder_raw_input = [torch.tensor(example[0]) for example in batch]
    decoder_raw_input = [torch.tensor(example[1]) for example in batch]
    decoder_raw_output  = [torch.tensor(example[2]) for example in batch]
    
    encoder_input = pad_sequence(encoder_raw_input,batch_first=True,padding_value=vocab_src['[<PAD>]'])
    decoder_input = pad_sequence(decoder_raw_input,batch_first=True,padding_value=vocab_tgt['[<PAD>]'])
    decoder_output = pad_sequence(decoder_raw_output,batch_first=True,padding_value=vocab_tgt['[<PAD>]'])
    return encoder_input,decoder_input,decoder_output

def mask_process(enc_input,dec_input):

    # 生成encoder中的mask
    enc_valid_pos = torch.where(enc_input>0.0,1.0,0.0)
    enc_valid_pos = torch.unsqueeze(enc_valid_pos,2)
    enc_valid_mat = torch.bmm(enc_valid_pos,enc_valid_pos.transpose(1,2))
    enc_mask = (1 - enc_valid_mat).to(torch.bool)
    
    # 生成encoder_decoder_mask
    dec_valid_pos = torch.where(dec_input > 0.0,1.0,0.0)
    dec_valid_pos = torch.unsqueeze(dec_valid_pos,2)
    enc_dec_mat = torch.bmm(dec_valid_pos,enc_valid_pos.transpose(1,2)) #前后顺序不能调换，否则会导致维度信息错误
    enc_dec_mask = (1 - enc_dec_mat).to(torch.bool)
    
    return enc_mask,enc_dec_mask
    

In [76]:
import matplotlib.pyplot as plt

filename = r'C:\Users\吴昊伦\Desktop\nlp\data\data\eng-fra.txt'
reserved_tokens = ['<UNK>','<BOS>','<EOS>','<PAD>']
embedding_dim = 512
num_heads =  8
epochs = 10
batch_size = 10

pairs = read_raw_data(filename)
vocab_src,vocab_tgt = process_raw_data(pairs,reserved_tokens)

dataset = TranslationDataset(vocab_src,vocab_tgt,pairs)
dataloader = DataLoader(dataset,batch_size=batch_size,collate_fn=collate_fn,shuffle=True)

#print(next(iter(dataloader)))
src_size = len(vocab_src) # 源文本的词表长度
tgt_size = len(vocab_tgt) # 目标文本的词表长度
#print(src_size,tgt_size)

# 实例化encoder decoder和transformer
encoder = Encoder(EncoderLayer(num_heads,embedding_dim),src_size,embedding_dim)
decoder = Decoder(DecoderLayer(num_heads,embedding_dim),tgt_size,embedding_dim)
model = Transformer(encoder,decoder).to('cuda:0')
#print(next(model.parameters()).device)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
loss_func = nn.CrossEntropyLoss(ignore_index=vocab_tgt['<PAD>'])

def train(model,loss_func,dataloader,epochs):
    
    # batch normalization用的是每一个batch的均值和方差，启用dropout
    model.train()
    
    avg_loss = []
    for epoch in range(epochs):
        total_loss = 0
        for x,y_in,y_out in dataloader:
            #print(x)
            #print(y_in)
            # 将向量载入gpu中
            x = x.to('cuda:0')
            y_in = y_in.to('cuda:0')
            y_out = y_out.to('cuda:0')
            enc_mask,enc_dec_mask = mask_process(x,y_in)
            #print(enc_mask.device,enc_dec_mask.device)
            y_pred = model(x,y_in,enc_mask,enc_dec_mask)
            # print(y_pred.device)
            # print(y_pred.shape)
            
            # 清空梯度
            optimizer.zero_grad()
            
            # 计算损失
            loss = loss_func(y_pred.permute(0,2,1),y_out) #permute的作用是什么
            total_loss += loss
            
            # 更新梯度
            loss.backward()
            optimizer.step()
        print('epoch: '+ str(epoch)+ 'finish training')
        avg_loss.append(total_loss/len(x))
    return avg_loss

loss = train(model,loss_func,dataloader,epochs)
plt.plot(loss)
plt.show()

epoch: 0finish training
epoch: 1finish training
epoch: 2finish training


RuntimeError: [enforce fail at ..\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 266240 bytes.

# Word Embedding

In [7]:
max_num_src_words = 8 # 源单词序列的最大长度，即总共有多少个单词
max_num_tgt_words = 8

batch_size = 2 # 总共有几个序列

embedding_size = 8 # 原论文使用了512

src_len =  torch.tensor([2,4]).to(torch.int32)
tgt_len = torch.tensor([4,3]).to(torch.int32)

# 单词索引构成的源句子和目标句子，注意0索引表示padding，单词索引从1开始
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len) - L)),0) 
                     for L in src_len]) # 源序列输入时要保证长度一致，需要pad来填充,然后先将序列的tensor扩展成2维的再拼接起来
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len) - L)),0)
                     for L in tgt_len])
print(src_seq,'\n',tgt_seq)

tensor([[7, 5, 0, 0],
        [1, 1, 4, 6]]) 
 tensor([[5, 2, 5, 3],
        [2, 7, 7, 0]])


In [9]:
# 构造embedding,可以用pytorch的api实现,会取出对应索引对应的embedding
src_embedding_table = nn.Embedding(max_num_src_words + 1, embedding_size) # 需要+1是因为padding占据了一位
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, embedding_size)

src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

In [10]:
src_embedding_table.weight

Parameter containing:
tensor([[ 2.4201, -1.3914,  0.2174,  0.6065, -0.7812,  0.3648,  0.1808,  1.5536],
        [ 1.3456, -0.4914, -0.7922, -0.0246,  0.7840, -0.4474, -0.7568,  0.6310],
        [-0.7511, -1.8326,  0.5105, -0.6110,  0.9412,  0.8918, -0.5526, -0.2191],
        [-0.8040, -1.2852, -1.3060, -0.2187, -1.4082, -0.9575,  0.7691, -0.0717],
        [-0.4035,  0.2432,  0.9620, -0.3032,  0.5702, -0.8582,  0.5748, -1.5265],
        [-1.1838, -1.8929,  0.0855,  0.4732, -0.2798,  0.6351, -0.6859, -0.1743],
        [ 0.8086,  0.2111,  1.0780,  1.9412, -0.8829, -0.8943, -0.3597, -0.6305],
        [ 0.3129, -1.0148, -1.0447,  0.2352, -0.4093,  0.0517,  0.2710,  1.3452],
        [ 0.8691, -0.6900, -0.5903, -0.8140, -2.2298,  1.3410, -1.4226,  1.0431]],
       requires_grad=True)

# Position Embedding

In [12]:
pos_mat = torch.arange(max(src_len)).reshape(-1,1)
i_mat = torch.pow(10000,torch.arange(0,embedding_size,2).reshape(1,-1)/embedding_size)
PE_table = torch.zeros(max(src_len),embedding_size)
PE_table[:,0::2] = torch.sin(pos_mat/i_mat)
PE_table[:,1::2] = torch.cos(pos_mat/i_mat)

PE = nn.Embedding(max(src_len),embedding_size)
PE.weight = nn.Parameter(PE_table,requires_grad = False)

# 获取位置信息，得到源序列和目标序列中每一个序列对应的位置编码
src_pos = torch.cat([torch.unsqueeze(torch.arange(max(src_len)),0) for _ in src_len])
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max(tgt_len)),0) for _ in tgt_len])
src_pe = PE(src_pos)
tgt_pe = PE(tgt_pos)

# Encoder中的Self-Attention Mask

In [24]:
valid_encoder_pos_mat = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len) - L)),0) for L in src_len]),1)
valid_encoder_mat = torch.bmm(valid_encoder_pos_mat.transpose(1,2),valid_encoder_pos_mat)
mask_encoder_mat = (1 - valid_encoder_mat).to(torch.bool)
mask_encoer_mat

tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]])

#    intra-attention的Mask？？？

# 构造decoder的self-attention的mask

In [36]:
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L,L)),(0,max(tgt_len)-L,0,max(tgt_len)-L)),0) 
                                                      for L in tgt_len]) # pad的方向是左右上下
mask_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).to(torch.bool)

In [38]:
mask_decoder_tri_matrix

tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [ True,  True,  True,  True]]])

# 构造self-attention

In [None]:
def scaled_dot_product_attention(Q,K,V,masked_atten):
    # shape of Q,K,V: (batch_size*num_heads*seq_len*)
    score = torch.bmm(Q,K.transpose(-2,-1))/torch.sqrt(embedding_size)
    masked_score = score.masked_fill(masked_atten,-1e9)
    prob = F.softmax(masked_score,-1)
    context = torch.bmm(prob,V)
    return context

In [39]:
seq_len = 5
batch_size = 2
embedding_dim = 8
#pos = torch.arange(seq_len).reshape(-1,1).repeat(1,batch_size)
pos = torch.arange(seq_len).reshape(-1,1).repeat(batch_size,1)
#i_dim = torch.arange(0,embedding_dim,2).reshape(1,-1).repeat(batch_size,1)
i_dim = torch.arange(0,embedding_dim,2).reshape(1,-1).repeat(1,batch_size)
aa = torch.sin(pos/torch.pow(10000,i_dim/embedding_dim))

In [40]:
pos,i_dim

(tensor([[0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4]]),
 tensor([[0, 2, 4, 6, 0, 2, 4, 6]]))

In [35]:
aa.reshape(batch_size,)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.8415,  0.0998,  0.0100,  0.0010,  0.8415,  0.0998,  0.0100,  0.0010],
        [ 0.9093,  0.1987,  0.0200,  0.0020,  0.9093,  0.1987,  0.0200,  0.0020],
        [ 0.1411,  0.2955,  0.0300,  0.0030,  0.1411,  0.2955,  0.0300,  0.0030],
        [-0.7568,  0.3894,  0.0400,  0.0040, -0.7568,  0.3894,  0.0400,  0.0040],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.8415,  0.0998,  0.0100,  0.0010,  0.8415,  0.0998,  0.0100,  0.0010],
        [ 0.9093,  0.1987,  0.0200,  0.0020,  0.9093,  0.1987,  0.0200,  0.0020],
        [ 0.1411,  0.2955,  0.0300,  0.0030,  0.1411,  0.2955,  0.0300,  0.0030],
        [-0.7568,  0.3894,  0.0400,  0.0040, -0.7568,  0.3894,  0.0400,  0.0040]])

In [48]:
seq_len = 5
batch_size = 2
embedding_dim = 8
#pos = torch.arange(seq_len).reshape(-1,1).repeat(1,batch_size)
pos = torch.arange(seq_len).reshape(-1,1)
#i_dim = torch.arange(0,embedding_dim,2).reshape(1,-1).repeat(batch_size,1)
i_dim = torch.arange(0,embedding_dim,2).reshape(1,-1)
PE = torch.zeros(seq_len,embedding_dim)
PE[:,0::2] = torch.sin(pos/torch.pow(10000,i_dim/embedding_dim))
PE[:,1::2] = torch.cos(pos/torch.pow(10000,i_dim/embedding_dim))


In [49]:
PE

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
          9.9955e-01,  3.0000e-03,  1.0000e+00],
        [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
          9.9920e-01,  4.0000e-03,  9.9999e-01]])

In [50]:
PE.repeat(batch_size,1).reshape(batch_size,seq_len,embedding_dim)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00]

In [104]:
import random
import math
batch_size = 2
seq_len = 6
embedding_dim = 8
num_heads = 2

a = torch.randn(batch_size,seq_len,embedding_dim)
b = torch.randn(batch_size,seq_len,embedding_dim)
c = torch.randn(batch_size,seq_len,embedding_dim)
wq = nn.Linear(embedding_dim,embedding_dim)
wk = nn.Linear(embedding_dim,embedding_dim)
wv = nn.Linear(embedding_dim,embedding_dim)

In [108]:
key = a
query = b
value = c


k = wk(key) #(batch_size,seq_len_embedding_dim)
# print('raw k:',k,k.shape)
k = k.reshape(batch_size,seq_len,num_heads,-1) # batch_size,seq_len,num_heads,dim_k
# print('reshape k:',k,k.shape)
k=k.permute(0,2,1,3) # batch_size,num_heads.seq_len,dim_k
# print('final k:',k,k.shape)

q = wq(query).reshape(batch_size,seq_len,num_heads,-1).permute(0,2,1,3)
v = wv(query).reshape(batch_size,seq_len,num_heads,-1).permute(0,2,1,3) 
dim_k = embedding_dim//num_heads

attent_score = torch.matmul(q,k.transpose(2,3))/math.sqrt(dim_k)
#print(attent_score,attent_score.shape) 
scaled_attent_score = F.softmax(attent_score,-1) # (batch_size,num_heads,seq_len,seq_len)
#print(scaled_attent_score,scaled_attent_score.shape) 
context_mat = torch.matmul(scaled_attent_score,v) # (batch_size,)
print(context_mat,context_mat.shape) #(batch_size,num_heads,seq_len,dim_k)

tensor([[[[ 0.2291, -0.1363,  0.4670, -0.1756],
          [ 0.1886, -0.1864,  0.4660, -0.1887],
          [ 0.2121, -0.1470,  0.4563, -0.1611],
          [ 0.1453, -0.1846,  0.4733, -0.2189],
          [ 0.2455, -0.2170,  0.5185, -0.2419],
          [ 0.3388, -0.1870,  0.5067, -0.1635]],

         [[-0.1176,  0.6423,  0.2355,  0.1742],
          [ 0.0247,  0.4520,  0.3118,  0.0976],
          [ 0.1012,  0.3430,  0.3379,  0.0537],
          [ 0.0823,  0.3480,  0.3429,  0.0338],
          [-0.0160,  0.4432,  0.3425,  0.0157],
          [-0.0215,  0.4969,  0.2840,  0.1019]]],


        [[[-0.0488, -0.2397,  0.1415, -0.6647],
          [-0.0662, -0.2207,  0.1350, -0.7083],
          [-0.0219, -0.2494,  0.1512, -0.6331],
          [-0.0363, -0.2556,  0.1776, -0.6930],
          [-0.0749, -0.2134,  0.1155, -0.6936],
          [-0.1219, -0.1628,  0.0692, -0.7518]],

         [[-0.5035, -0.3416,  0.0799, -0.3281],
          [-0.5935, -0.2866, -0.1025, -0.3248],
          [-0.5911, -0.1463, -0.

In [89]:
k = wq(key) #(batch_size,seq_len_embedding_dim)
k = k.reshape(batch_size,num_heads,seq_len,-1) # batch_size,seq_lne,num_heads,dim_k
k

tensor([[[[ 0.7317, -0.0352,  0.1948,  0.4176],
          [-0.1013, -0.7229,  0.3077, -0.7610],
          [-0.1686, -0.1843,  0.7883,  0.1585],
          [ 0.2514, -0.1433, -0.3918, -0.6952],
          [ 0.7596,  0.7998,  0.5008,  0.0423],
          [-0.0270,  0.3743, -0.1011, -0.2615]],

         [[-0.6680, -0.8831, -0.6066,  0.8854],
          [ 0.2504, -0.0797, -1.4161, -0.8618],
          [ 0.1207, -0.3178,  0.1992,  0.5513],
          [-0.4767, -0.6457, -0.4611, -0.5847],
          [ 0.2248,  0.0268, -0.3078,  0.8197],
          [ 0.2010, -0.3610,  0.4303, -0.2490]]],


        [[[-0.1199, -0.4672,  1.0029,  0.7095],
          [ 0.2995,  0.8645, -0.9686, -0.5742],
          [ 1.1665,  0.2206,  0.1515, -0.1391],
          [ 0.0688,  0.2192,  0.3753, -0.5280],
          [ 0.1752,  0.1438, -0.3738,  1.6293],
          [-0.4363, -0.1756,  0.1196, -0.1044]],

         [[ 0.3226,  0.3679, -0.3191,  0.9891],
          [-0.4345, -0.6546,  0.8613, -0.1316],
          [ 1.5812,  0.5335,  0.

tensor([[0.4592]])