BertEmbeddings.py
import torch
import torch.nn as nn


class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size, max_len, hidden_size, device, dropout_prob=0.1):
        super(BertEmbeddings, self).__init__()
        self.device = device
        self.max_len = max_len
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.type_embeddings = nn.Embedding(2, hidden_size)
        self.position_embeddings = nn.Embedding(self.max_len, hidden_size)
        self.emb_normalization = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(p=dropout_prob)

    def forward(self, input_token, segment_ids):
        token_embeddings = self.token_embeddings(input_token)
        type_embeddings = self.type_embeddings(segment_ids)
        # Generate fixed position ids: one row [0, 1, ..., seq_len - 1] per sequence in the batch
        batch_size, seq_len = input_token.size()
        position_ids = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, seq_len)
        position_embeddings = self.position_embeddings(position_ids)
        # Sum token, segment (type), and position embeddings, then apply LayerNorm and dropout
        embedding_x = token_embeddings + type_embeddings + position_embeddings
        embedding_x = self.emb_normalization(embedding_x)
        embedding_x = self.emb_dropout(embedding_x)
        return embedding_x
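

# A minimal usage sketch (not part of the original file): the hyperparameter values
# and batch shapes below are illustrative assumptions, not fixed by this repo.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = BertEmbeddings(
        vocab_size=30522, max_len=512, hidden_size=768, device=device
    ).to(device)
    # A batch of 2 sequences of length 8: random token ids and all-zero segment ids
    input_token = torch.randint(0, 30522, (2, 8), device=device)
    segment_ids = torch.zeros(2, 8, dtype=torch.long, device=device)
    out = embeddings(input_token, segment_ids)
    print(out.shape)  # expected: torch.Size([2, 8, 768])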