In [1]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
         https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split())) # ['hello', 'how', 'are', 'you',...]
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {i: w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)

token_list = list()
for sentence in sentences:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)


In [2]:
# BERT Parameters
maxlen = 30
batch_size = 6
max_pred = 5 # max tokens of prediction
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4 # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

In [3]:
# sample IsNext and NotNext to be same in small batch size
def make_data():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # MASK LM
        n_pred =  min(max_pred, max(1, int(len(input_ids) * 0.15))) # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']] # candidate masked position
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word2idx['[MASK]'] # make mask
            elif random() > 0.9:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                while index < 4: # can't involve 'CLS', 'SEP', 'PAD'
                  index = randint(0, vocab_size - 1)
                input_ids[pos] = index # replace

        # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
    return batch
# Proprecessing Finished

In [4]:
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens),\
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)

class MyDataSet(Data.Dataset):
  def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
    self.input_ids = input_ids
    self.segment_ids = segment_ids
    self.masked_tokens = masked_tokens
    self.masked_pos = masked_pos
    self.isNext = isNext
  
  def __len__(self):
    return len(self.input_ids)
  
  def __getitem__(self, idx):
    return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]

loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

In [5]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, seq_len = seq_q.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)  # [batch_size, 1, seq_len]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]

def gelu(x):
    """
      Implementation of the gelu activation function.
      For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
      Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, seq_len, seq_len]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size, seq_len, d_model], k: [batch_size, seq_len, d_model], v: [batch_size, seq_len, d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size, n_heads, seq_len, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size, n_heads, seq_len, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size, n_heads, seq_len, d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, seq_len, d_v], attn: [batch_size, n_heads, seq_len, seq_len]
        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size, seq_len, n_heads, d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual) # output: [batch_size, seq_len, d_model]

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        return self.fc2(gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, seq_len, d_model]
        return enc_outputs

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids) # [bach_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids) # [batch_size, maxlen, maxlen]
        for layer in self.layers:
            # output: [batch_size, max_len, d_model]
            output = layer(output, enc_self_attn_mask)
        # it will be decided by first token(CLS)
        h_pooled = self.fc(output[:, 0]) # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled) # [batch_size, 2] predict isNext

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model) # [batch_size, max_pred, d_model]
        print(output)
        print(masked_pos)
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        print(h_masked)
        h_masked = self.activ2(self.linear(h_masked)) # [batch_size, max_pred, d_model]
        logits_lm = self.fc2(h_masked) # [batch_size, max_pred, vocab_size]
        return logits_lm, logits_clsf
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)

In [6]:
for epoch in range(50):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
      logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
      loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
      loss_lm = (loss_lm.float()).mean()
      loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
      loss = loss_lm + loss_clsf
      if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

tensor([[[ 0.0150,  0.0803,  0.1335,  ..., -0.3374, -0.3911, -0.1357],
         [ 0.0618,  0.1217,  0.1166,  ..., -0.3112, -0.4714, -0.1042],
         [ 0.0191, -0.1239,  0.0297,  ..., -0.2181, -0.3603, -0.0241],
         ...,
         [ 0.0772,  0.1451,  0.1039,  ..., -0.3639, -0.3891, -0.1062],
         [ 0.1547,  0.0765,  0.3791,  ..., -0.4754, -0.3614, -0.0632],
         [ 0.1023, -0.1089,  0.2651,  ..., -0.2462, -0.2246, -0.1020]],

        [[-0.0158,  0.0852,  0.1496,  ..., -0.3371, -0.3361, -0.2252],
         [ 0.2166,  0.0245,  0.1520,  ..., -0.4077, -0.3218,  0.0730],
         [-0.0993, -0.0466, -0.0851,  ..., -0.2425, -0.3154, -0.1258],
         ...,
         [ 0.0849,  0.2205,  0.1768,  ..., -0.3987, -0.3296, -0.1214],
         [ 0.0749,  0.1199,  0.3998,  ..., -0.4249, -0.3395, -0.1028],
         [ 0.1194, -0.0553,  0.2519,  ..., -0.2607, -0.2059, -0.1631]],

        [[-0.0107,  0.1061,  0.1554,  ..., -0.3116, -0.3071, -0.1047],
         [-0.0178, -0.0493,  0.1607,  ..., -0

tensor([[[-0.0228,  0.0656,  0.1545,  ..., -0.1720,  0.0957,  0.2239],
         [-0.0444,  0.0155, -0.0922,  ..., -0.1839, -0.0811,  0.1703],
         [-0.1373, -0.0934,  0.1707,  ..., -0.1467,  0.0479,  0.1973],
         ...,
         [-0.0060, -0.0056,  0.1709,  ..., -0.1987, -0.1090,  0.1479],
         [-0.0032, -0.0122,  0.3005,  ..., -0.3724, -0.1475,  0.2039],
         [ 0.0313, -0.0902,  0.1320,  ..., -0.2628, -0.0077,  0.1501]],

        [[-0.0385,  0.1181,  0.1162,  ..., -0.1644,  0.1541,  0.2183],
         [ 0.0928, -0.0966, -0.0299,  ..., -0.5027, -0.0795,  0.2208],
         [-0.4651, -0.0932, -0.0381,  ..., -0.2185,  0.0760,  0.0272],
         ...,
         [ 0.0056,  0.0172,  0.1593,  ..., -0.1823, -0.0318,  0.1085],
         [ 0.0349,  0.0462,  0.2897,  ..., -0.3616, -0.0601,  0.1113],
         [ 0.0192, -0.0361,  0.0916,  ..., -0.2205, -0.0067,  0.1678]],

        [[ 0.0299,  0.0761,  0.1414,  ..., -0.1759,  0.0938,  0.2133],
         [-0.0208,  0.0568,  0.0321,  ..., -0

tensor([[[ 1.0727e-01,  2.0586e-01, -9.5499e-03,  ..., -3.9485e-01,
           4.5807e-02, -1.3133e-01],
         [ 2.1407e-01,  6.1268e-02,  9.9928e-03,  ..., -2.3036e-01,
          -7.8879e-02, -1.4146e-04],
         [ 6.6606e-04,  4.3951e-02, -2.2786e-01,  ..., -1.9354e-01,
          -7.2799e-02, -4.2412e-02],
         ...,
         [ 1.2831e-01,  2.0523e-01, -1.9605e-03,  ..., -2.5680e-01,
          -2.0889e-01,  1.6366e-02],
         [ 1.4610e-01,  1.7800e-01,  2.1594e-01,  ..., -5.1211e-01,
          -1.5636e-01,  1.6795e-01],
         [ 5.2598e-02,  9.4322e-02,  1.4128e-01,  ..., -1.6139e-01,
          -3.5008e-02,  5.0725e-02]],

        [[ 9.8840e-02,  2.1264e-01,  1.6759e-03,  ..., -3.6480e-01,
          -4.7370e-03, -1.7283e-01],
         [ 1.4008e-01,  1.5641e-01, -1.6904e-02,  ..., -2.1028e-01,
          -1.6264e-01, -1.6149e-01],
         [ 2.7039e-01,  6.2256e-02, -4.9685e-02,  ..., -3.0533e-01,
          -9.7589e-02, -4.3638e-02],
         ...,
         [ 9.8611e-02,  2

tensor([[[ 0.1249,  0.2139,  0.0741,  ..., -0.3574, -0.2017, -0.0032],
         [ 0.1047,  0.2048, -0.0505,  ..., -0.3441, -0.0971, -0.0725],
         [ 0.1459,  0.0270,  0.0532,  ..., -0.3920, -0.1251,  0.1291],
         ...,
         [ 0.2061,  0.1866,  0.1231,  ..., -0.4352, -0.1675,  0.1245],
         [ 0.1970,  0.1577,  0.3981,  ..., -0.5813, -0.1088,  0.2193],
         [ 0.1104,  0.1028,  0.0941,  ..., -0.2478, -0.0438,  0.1709]],

        [[ 0.0610,  0.2664,  0.1059,  ..., -0.3174, -0.2440, -0.0053],
         [ 0.0476,  0.0616, -0.0374,  ..., -0.3224, -0.1740,  0.1494],
         [-0.1165,  0.2155,  0.0110,  ..., -0.3376, -0.0808,  0.0198],
         ...,
         [ 0.1682,  0.2178,  0.1478,  ..., -0.4723, -0.1797,  0.1267],
         [ 0.1713,  0.1871,  0.4113,  ..., -0.5911, -0.1432,  0.2329],
         [ 0.0715,  0.1227,  0.1313,  ..., -0.2827, -0.0602,  0.1582]],

        [[ 0.1181,  0.2653,  0.1096,  ..., -0.3692, -0.3136,  0.0072],
         [ 0.1208,  0.1664, -0.0194,  ..., -0

tensor([[[ 1.9250e-01,  3.8072e-01,  1.1791e-01,  ..., -1.0606e-01,
           2.9466e-02,  2.6639e-02],
         [ 1.4175e-01,  3.0009e-01,  4.4127e-02,  ..., -1.9462e-01,
          -1.8417e-01, -8.3322e-02],
         [ 1.4853e-01,  1.2062e-01, -9.0015e-02,  ..., -9.3508e-02,
           2.7332e-02, -3.2882e-01],
         ...,
         [ 2.1124e-01,  1.5213e-01,  2.7659e-02,  ..., -2.2728e-01,
          -8.8903e-02, -1.9994e-01],
         [ 1.5941e-01,  2.2887e-01,  2.1698e-01,  ..., -3.4272e-01,
          -1.4041e-01, -2.5159e-02],
         [ 1.0716e-01,  4.3527e-02,  1.6366e-01,  ..., -6.4023e-02,
          -4.0749e-02, -1.6748e-01]],

        [[ 1.9954e-01,  3.7819e-01,  9.0425e-02,  ..., -1.3831e-01,
           1.5226e-02,  3.8371e-02],
         [ 2.6849e-01,  1.8602e-01,  1.5387e-01,  ..., -2.7568e-01,
          -1.2610e-01, -1.8698e-01],
         [-1.9623e-01,  1.9743e-01,  1.2035e-01,  ..., -2.6188e-01,
           4.1125e-02, -2.4891e-01],
         ...,
         [ 2.1129e-01,  1

tensor([[[ 0.0924,  0.1465,  0.0046,  ..., -0.1630,  0.2289, -0.0701],
         [ 0.3897,  0.1629, -0.1574,  ..., -0.1943, -0.0643, -0.1739],
         [ 0.1649,  0.0010, -0.2404,  ..., -0.1368, -0.1025, -0.0121],
         ...,
         [ 0.2802,  0.1738,  0.0413,  ..., -0.2981, -0.1327, -0.0541],
         [ 0.1759,  0.2015,  0.1050,  ..., -0.3727, -0.0356,  0.0181],
         [ 0.2259,  0.1141,  0.0641,  ..., -0.2274,  0.0787, -0.0452]],

        [[ 0.0969,  0.1654,  0.0149,  ..., -0.1420,  0.1496, -0.0643],
         [ 0.3137,  0.1153,  0.0070,  ..., -0.4477, -0.1981, -0.0208],
         [-0.0456,  0.1014, -0.3548,  ..., -0.1586, -0.0691, -0.1591],
         ...,
         [ 0.2639,  0.2248,  0.0751,  ..., -0.2370, -0.1738, -0.0667],
         [ 0.1897,  0.2469,  0.1170,  ..., -0.3409, -0.1022,  0.0068],
         [ 0.2441,  0.1185,  0.0727,  ..., -0.1965,  0.0519, -0.1163]],

        [[ 0.1218,  0.1968,  0.0277,  ..., -0.1046,  0.1443, -0.0897],
         [ 0.3085,  0.0683,  0.0439,  ..., -0

tensor([[[-5.6560e-02,  1.1504e-01,  3.3937e-02,  ..., -2.0709e-01,
          -7.2420e-02,  3.8503e-02],
         [ 8.1761e-02, -1.8925e-02, -2.3909e-01,  ..., -2.2823e-01,
          -2.1250e-01, -3.6650e-02],
         [ 2.0916e-02, -9.3452e-02, -1.4598e-01,  ..., -1.9704e-01,
          -2.6663e-01,  5.4565e-02],
         ...,
         [ 1.4600e-01, -8.8969e-02, -8.7737e-02,  ..., -2.0590e-01,
          -2.0192e-01,  4.5762e-02],
         [ 1.3105e-01, -1.2731e-01,  1.9155e-01,  ..., -4.2451e-01,
          -9.6739e-02,  9.8420e-02],
         [ 1.5328e-02, -2.1023e-01,  4.6805e-02,  ..., -2.0174e-01,
          -7.8800e-02,  1.1155e-01]],

        [[-1.6898e-02,  1.2959e-01,  7.2008e-02,  ..., -1.6285e-01,
          -1.3976e-01,  1.2982e-01],
         [-1.7620e-02, -5.4740e-02, -5.3956e-03,  ..., -3.8346e-01,
          -3.2089e-01,  9.2270e-02],
         [-9.8762e-02, -1.7658e-01, -1.1264e-01,  ...,  3.5493e-02,
          -2.1656e-01, -1.8720e-02],
         ...,
         [ 1.2236e-01, -1

tensor([[[-0.0335,  0.3167,  0.1128,  ..., -0.2106, -0.2212, -0.3009],
         [ 0.0539,  0.1163,  0.0191,  ..., -0.2629, -0.2133,  0.0382],
         [-0.1397, -0.0717, -0.1933,  ..., -0.0431, -0.0340, -0.1124],
         ...,
         [ 0.1058,  0.0232,  0.1326,  ..., -0.0753, -0.2343, -0.0381],
         [-0.0190,  0.0254,  0.3215,  ..., -0.2841, -0.2951, -0.0381],
         [-0.0471,  0.0071,  0.1459,  ..., -0.0838, -0.0762, -0.0265]],

        [[-0.0648,  0.2984,  0.1103,  ..., -0.2882, -0.2270, -0.2374],
         [ 0.1370, -0.0872, -0.0091,  ..., -0.4345, -0.2785,  0.1081],
         [-0.3426, -0.0864, -0.2467,  ..., -0.2518,  0.0216, -0.0968],
         ...,
         [ 0.0338,  0.0239,  0.1286,  ..., -0.0992, -0.1951, -0.0588],
         [-0.0551,  0.0329,  0.3193,  ..., -0.3192, -0.3166, -0.0271],
         [-0.0537, -0.0561,  0.1442,  ..., -0.1688, -0.0796, -0.0321]],

        [[-0.0611,  0.3398,  0.1184,  ..., -0.1536, -0.1318, -0.2746],
         [ 0.0842,  0.1433, -0.0234,  ..., -0

tensor([[[ 0.0356,  0.2647,  0.0827,  ..., -0.2182, -0.1329,  0.0625],
         [ 0.0578, -0.0375, -0.2203,  ..., -0.2124, -0.0326,  0.1343],
         [-0.1775,  0.0164, -0.3019,  ..., -0.2059, -0.0533,  0.0383],
         ...,
         [ 0.1928, -0.0452, -0.0721,  ..., -0.0992, -0.1154,  0.0262],
         [ 0.1198,  0.1028,  0.1525,  ..., -0.2896, -0.2076,  0.1590],
         [ 0.1257, -0.0692,  0.0418,  ...,  0.0350, -0.0078,  0.1881]],

        [[ 0.0396,  0.2708,  0.1157,  ..., -0.1949, -0.0715,  0.0420],
         [ 0.1889,  0.0656, -0.1013,  ..., -0.2842, -0.1179,  0.1611],
         [-0.2043, -0.0091, -0.2531,  ..., -0.2790, -0.0903,  0.0550],
         ...,
         [ 0.2388, -0.0430, -0.0490,  ..., -0.1399, -0.0731, -0.0027],
         [ 0.1184,  0.0817,  0.1827,  ..., -0.2933, -0.1802,  0.1716],
         [ 0.1422, -0.0563,  0.0632,  ...,  0.0034,  0.0312,  0.1411]],

        [[ 0.0644,  0.2629,  0.1004,  ..., -0.2166, -0.0566,  0.0502],
         [ 0.2001,  0.1382, -0.0732,  ..., -0

tensor([[[-1.3239e-02,  2.3954e-01,  2.0703e-02,  ..., -1.6392e-01,
           3.9570e-02, -2.0134e-01],
         [ 6.1418e-02,  1.7012e-01, -1.8457e-01,  ..., -3.9826e-01,
          -1.6973e-01,  4.7658e-02],
         [-9.8653e-02, -6.9630e-02, -1.4390e-01,  ..., -2.1431e-01,
          -1.2889e-01, -9.4440e-02],
         ...,
         [ 3.6003e-02,  1.3515e-01, -1.2152e-01,  ..., -2.9255e-01,
          -1.5555e-01, -1.2264e-01],
         [ 9.4359e-02,  1.5363e-01,  1.0688e-01,  ..., -5.1170e-01,
          -1.1705e-01, -3.7239e-02],
         [ 5.7133e-04,  4.2848e-02, -8.6165e-02,  ..., -2.9303e-01,
          -5.5852e-02, -1.7272e-01]],

        [[-1.0760e-02,  2.1819e-01, -1.6887e-03,  ..., -1.5049e-01,
          -1.7184e-02, -1.9047e-01],
         [-1.2283e-02,  2.7484e-02, -2.0156e-01,  ..., -1.7833e-01,
          -1.2716e-01,  2.8643e-02],
         [-1.3120e-01,  1.5271e-01, -1.7905e-01,  ..., -3.7366e-01,
          -1.5829e-01, -1.1398e-01],
         ...,
         [ 1.4685e-02,  1

tensor([[[-6.5774e-02,  1.8747e-01,  6.7579e-02,  ..., -2.1956e-01,
          -1.6957e-03,  7.0118e-02],
         [ 3.4895e-04,  4.5233e-02, -1.8076e-01,  ..., -3.0029e-01,
          -2.4027e-02, -3.8537e-02],
         [ 7.6840e-02, -1.0445e-02, -2.1888e-01,  ..., -2.8904e-01,
          -2.7111e-02,  3.0516e-02],
         ...,
         [-2.9442e-02,  5.9976e-02, -4.7252e-02,  ..., -2.4942e-01,
          -1.7941e-01, -4.4764e-02],
         [ 4.6661e-03,  6.8007e-02,  1.9246e-01,  ..., -3.2022e-01,
          -2.1711e-01, -2.9978e-02],
         [-8.2957e-02,  4.2859e-02,  1.5115e-02,  ..., -8.2222e-02,
           9.4001e-02,  8.1546e-02]],

        [[-4.9893e-02,  1.7845e-01,  5.5359e-02,  ..., -2.0940e-01,
           7.5676e-02,  6.9486e-02],
         [ 1.8393e-01, -7.7102e-02, -6.5511e-02,  ..., -4.7635e-01,
          -1.1096e-02,  3.6409e-02],
         [-3.0248e-01, -1.1883e-01, -8.5704e-02,  ..., -2.2810e-01,
           4.5840e-02,  9.3770e-02],
         ...,
         [ 3.0672e-02,  1

tensor([[[ 4.0847e-03,  2.5095e-01,  6.9564e-03,  ..., -1.8289e-01,
          -5.5967e-02,  1.2422e-01],
         [-2.6882e-02,  2.3251e-01, -2.1973e-01,  ..., -3.0688e-01,
          -2.5822e-01,  1.4873e-01],
         [-2.3713e-02, -3.7429e-02, -2.5265e-01,  ..., -3.1824e-01,
          -2.2003e-01, -8.9830e-03],
         ...,
         [ 2.4182e-02,  1.1257e-01,  4.8621e-02,  ..., -2.7176e-01,
          -3.4820e-01,  2.2002e-02],
         [ 2.0974e-02,  1.8874e-01,  2.3569e-01,  ..., -3.6727e-01,
          -2.5686e-01,  4.5463e-02],
         [ 2.1578e-02,  1.1228e-01,  8.9333e-02,  ..., -1.6251e-01,
          -1.2728e-01,  5.7889e-02]],

        [[-6.9854e-03,  2.1396e-01, -2.8339e-02,  ..., -1.7581e-01,
           1.6834e-02,  7.5103e-02],
         [-3.2349e-02,  2.0236e-01, -2.5128e-01,  ..., -1.7128e-01,
          -1.5192e-01,  4.8155e-02],
         [ 3.2116e-02,  3.3441e-02, -1.5878e-01,  ..., -1.8849e-01,
          -2.7037e-01,  7.6179e-02],
         ...,
         [-1.3334e-02,  1

tensor([[[-0.0918,  0.2836,  0.4143,  ..., -0.3979,  0.0159, -0.0614],
         [ 0.1473,  0.0050,  0.3139,  ..., -0.4542, -0.1262, -0.1244],
         [-0.2911,  0.1165,  0.1430,  ..., -0.2324, -0.0808, -0.1236],
         ...,
         [ 0.1012,  0.1242,  0.1800,  ..., -0.3883, -0.0624, -0.1340],
         [ 0.0713,  0.1558,  0.4705,  ..., -0.5997, -0.1664,  0.0043],
         [ 0.0218,  0.0310,  0.3572,  ..., -0.3846,  0.0036, -0.0823]],

        [[-0.0804,  0.3203,  0.3237,  ..., -0.3612,  0.0366, -0.1371],
         [ 0.0973,  0.1511,  0.0520,  ..., -0.2653, -0.1802, -0.1645],
         [ 0.0777,  0.0658,  0.0749,  ..., -0.2267, -0.0708, -0.0788],
         ...,
         [ 0.0691,  0.0858,  0.1982,  ..., -0.4151, -0.1351, -0.0951],
         [ 0.0594,  0.1887,  0.4452,  ..., -0.6060, -0.1658,  0.0600],
         [ 0.0292,  0.0433,  0.3358,  ..., -0.3625, -0.0463, -0.0966]],

        [[-0.0897,  0.3233,  0.3681,  ..., -0.3933,  0.0151, -0.1062],
         [ 0.0769,  0.1122,  0.0904,  ..., -0

tensor([[[ 0.2276,  0.2920,  0.1917,  ..., -0.2947,  0.0321,  0.2067],
         [ 0.2835, -0.0077,  0.0561,  ..., -0.4389, -0.0894,  0.0589],
         [-0.0496,  0.0432, -0.0365,  ..., -0.2939,  0.1177,  0.0465],
         ...,
         [ 0.2903,  0.2270,  0.2114,  ..., -0.2990, -0.0731,  0.0572],
         [ 0.2726,  0.3529,  0.4676,  ..., -0.4742, -0.1617,  0.1092],
         [ 0.2985,  0.1160,  0.3604,  ..., -0.1950,  0.1359, -0.0045]],

        [[ 0.1667,  0.3825,  0.2119,  ..., -0.2864,  0.0669,  0.1957],
         [ 0.2589,  0.1230,  0.3201,  ..., -0.4260, -0.0213, -0.0175],
         [-0.1617,  0.1451,  0.1103,  ..., -0.3485,  0.0504,  0.0285],
         ...,
         [ 0.2911,  0.1936,  0.2756,  ..., -0.2277, -0.0411,  0.1088],
         [ 0.1949,  0.3010,  0.4621,  ..., -0.4327, -0.1411,  0.1447],
         [ 0.1848,  0.0827,  0.3090,  ..., -0.1272,  0.1231,  0.0028]],

        [[ 0.1852,  0.2711,  0.1296,  ..., -0.2947,  0.1285,  0.1930],
         [ 0.2007,  0.1876, -0.0550,  ..., -0

tensor([[[ 0.0616,  0.2594,  0.2343,  ..., -0.4014, -0.1078,  0.0074],
         [ 0.1885,  0.1903,  0.0074,  ..., -0.3321, -0.2521, -0.0501],
         [ 0.0812,  0.0423,  0.0465,  ..., -0.2746, -0.2372,  0.0602],
         ...,
         [ 0.0924,  0.1071,  0.2666,  ..., -0.2092, -0.3271, -0.0367],
         [ 0.0699,  0.2315,  0.5334,  ..., -0.5078, -0.3499,  0.0485],
         [ 0.0876, -0.0361,  0.3131,  ..., -0.2910, -0.1581,  0.0022]],

        [[ 0.1193,  0.2035,  0.2119,  ..., -0.3229, -0.0814,  0.0190],
         [ 0.2201,  0.1576, -0.0180,  ..., -0.2626, -0.2307, -0.0555],
         [ 0.0793, -0.0038,  0.0554,  ..., -0.2266, -0.2426,  0.0805],
         ...,
         [ 0.1004,  0.0393,  0.2512,  ..., -0.2054, -0.3057, -0.0643],
         [ 0.0721,  0.1601,  0.4975,  ..., -0.4177, -0.3355,  0.0345],
         [ 0.0569, -0.0770,  0.2673,  ..., -0.2205, -0.1186, -0.0210]],

        [[ 0.1066,  0.2635,  0.2083,  ..., -0.3906, -0.0571,  0.0101],
         [ 0.1599,  0.0482, -0.0326,  ..., -0

tensor([[[ 0.1308,  0.2217,  0.1642,  ..., -0.2458, -0.2053,  0.1431],
         [ 0.0693,  0.1224,  0.0998,  ..., -0.2725, -0.1569, -0.0136],
         [ 0.1975, -0.1111,  0.1397,  ..., -0.2770, -0.1902,  0.1892],
         ...,
         [ 0.2446,  0.0296,  0.2826,  ..., -0.2691, -0.3952,  0.0209],
         [ 0.3264, -0.0157,  0.4570,  ..., -0.4773, -0.2227,  0.1085],
         [ 0.2209, -0.0976,  0.3511,  ..., -0.2678, -0.2532,  0.0297]],

        [[ 0.1354,  0.1847,  0.1641,  ..., -0.2663, -0.2252,  0.1549],
         [ 0.2322, -0.0728,  0.1770,  ..., -0.3896, -0.3169,  0.1804],
         [-0.1241, -0.1968, -0.1111,  ..., -0.4014, -0.0752,  0.0795],
         ...,
         [ 0.2560,  0.0170,  0.3054,  ..., -0.2949, -0.3782, -0.0079],
         [ 0.3551, -0.0584,  0.4570,  ..., -0.4825, -0.2559,  0.1115],
         [ 0.2763, -0.1231,  0.3491,  ..., -0.2837, -0.2832,  0.0649]],

        [[ 0.1169,  0.1896,  0.0456,  ..., -0.2445, -0.2029,  0.0648],
         [ 0.0588,  0.0541,  0.1169,  ..., -0

tensor([[[-0.0282,  0.0848,  0.1087,  ..., -0.2567, -0.0700,  0.1813],
         [-0.0801, -0.1205, -0.0181,  ..., -0.0809, -0.1879,  0.1099],
         [-0.1451, -0.1585, -0.1890,  ..., -0.2228, -0.2888, -0.0663],
         ...,
         [ 0.1524, -0.0332, -0.0726,  ..., -0.2358, -0.3165,  0.0425],
         [ 0.2404, -0.0885,  0.0781,  ..., -0.4075, -0.2118,  0.1055],
         [ 0.0841, -0.1349,  0.0295,  ..., -0.1883, -0.2708,  0.2462]],

        [[-0.1122,  0.1071,  0.1480,  ..., -0.2824,  0.0021,  0.2349],
         [-0.0407, -0.0421,  0.0856,  ..., -0.2565, -0.2271,  0.0813],
         [-0.1192, -0.0912, -0.1978,  ..., -0.0801, -0.1383,  0.0316],
         ...,
         [ 0.1353, -0.0246, -0.0761,  ..., -0.1870, -0.3300,  0.0764],
         [ 0.2327, -0.0812,  0.1383,  ..., -0.3957, -0.2652,  0.1315],
         [ 0.0374, -0.1400,  0.0643,  ..., -0.1463, -0.2045,  0.2592]],

        [[-0.1151,  0.1407,  0.1130,  ..., -0.2347,  0.0134,  0.2282],
         [ 0.0929,  0.0328,  0.0214,  ..., -0

tensor([[[ 0.1314,  0.2830,  0.2558,  ..., -0.2922,  0.0818,  0.0723],
         [ 0.0711,  0.1715, -0.0169,  ..., -0.3298,  0.0283,  0.1197],
         [-0.1553,  0.1784, -0.0450,  ..., -0.3157, -0.1196,  0.1771],
         ...,
         [ 0.3395,  0.2420,  0.1266,  ..., -0.5233, -0.2744, -0.0070],
         [ 0.4477,  0.2481,  0.3712,  ..., -0.5336, -0.1971,  0.1040],
         [ 0.3420,  0.1284,  0.2734,  ..., -0.3497, -0.1345, -0.0217]],

        [[ 0.1485,  0.1666,  0.1766,  ..., -0.2283,  0.0330,  0.1092],
         [ 0.3057,  0.0684,  0.0770,  ..., -0.4334, -0.2494,  0.2174],
         [-0.1187,  0.1169,  0.0225,  ..., -0.1772,  0.0628,  0.0377],
         ...,
         [ 0.3478,  0.1680,  0.0909,  ..., -0.4516, -0.2075,  0.0358],
         [ 0.4253,  0.1986,  0.3555,  ..., -0.3988, -0.1775,  0.0762],
         [ 0.3604,  0.1056,  0.2456,  ..., -0.2354, -0.1135,  0.0221]],

        [[ 0.1397,  0.3167,  0.2399,  ..., -0.2656,  0.0910,  0.0956],
         [ 0.2048,  0.2839,  0.0604,  ..., -0

tensor([[[-9.2868e-02,  3.0554e-01,  1.0226e-01,  ..., -3.8989e-01,
          -6.6917e-02,  1.7263e-01],
         [ 2.3554e-02,  2.7661e-01, -6.6588e-02,  ..., -3.8477e-01,
          -1.7880e-01,  1.9818e-02],
         [ 1.9718e-02, -1.0733e-01, -3.7347e-02,  ..., -2.2064e-01,
          -1.4323e-01,  1.2583e-01],
         ...,
         [ 8.1571e-02,  5.4404e-02,  1.5817e-01,  ..., -2.8512e-01,
          -2.6250e-01, -1.1456e-02],
         [ 1.7503e-01,  9.7616e-02,  2.5931e-01,  ..., -5.1418e-01,
          -2.1858e-01,  2.3319e-01],
         [ 5.9306e-02,  6.7785e-02,  1.4073e-01,  ..., -2.4027e-01,
          -5.4803e-02,  1.3430e-01]],

        [[-1.2600e-01,  3.3263e-01,  2.1401e-02,  ..., -3.6105e-01,
           9.9689e-03,  1.3380e-01],
         [ 3.2000e-02,  3.3657e-01, -1.2928e-01,  ..., -3.6781e-01,
          -1.3689e-01, -2.3799e-02],
         [ 1.0998e-05, -5.3232e-02, -3.9459e-02,  ..., -1.8428e-01,
          -2.0780e-02,  8.0772e-02],
         ...,
         [ 8.1346e-02,  1

tensor([[[-0.1771,  0.2250,  0.2477,  ..., -0.4530,  0.0978,  0.0432],
         [-0.2399,  0.0422,  0.0350,  ..., -0.3629,  0.0928,  0.0498],
         [-0.4625,  0.1354, -0.0805,  ..., -0.1847,  0.0332, -0.0643],
         ...,
         [-0.0537,  0.0910,  0.2463,  ..., -0.2537,  0.0435,  0.1145],
         [-0.1619,  0.1286,  0.2777,  ..., -0.4216, -0.0816,  0.1520],
         [-0.1426, -0.0162,  0.1162,  ..., -0.2789,  0.1346, -0.0377]],

        [[-0.0613,  0.2641,  0.1588,  ..., -0.5609,  0.0575,  0.0300],
         [ 0.0272,  0.0029,  0.0073,  ..., -0.4312,  0.0193,  0.0690],
         [-0.3541, -0.0263,  0.0686,  ..., -0.4120,  0.0404, -0.0054],
         ...,
         [ 0.0327,  0.1171,  0.2050,  ..., -0.3505,  0.0028,  0.0597],
         [-0.0896,  0.1588,  0.1807,  ..., -0.4828, -0.1211,  0.0728],
         [-0.0458, -0.0200,  0.0854,  ..., -0.3619,  0.0853, -0.1141]],

        [[-0.1196,  0.1911,  0.1636,  ..., -0.4188,  0.1335,  0.0529],
         [ 0.0219, -0.0183,  0.0368,  ..., -0

tensor([[[ 0.0405,  0.0739,  0.0044,  ..., -0.2418, -0.0883,  0.0976],
         [-0.1415, -0.1696, -0.1555,  ..., -0.2784, -0.1589,  0.1546],
         [-0.1923, -0.0481, -0.2344,  ..., -0.3073, -0.1674, -0.0383],
         ...,
         [ 0.1440,  0.0948,  0.1317,  ..., -0.4287, -0.1804,  0.0433],
         [ 0.1117,  0.0488,  0.2142,  ..., -0.5064, -0.1722,  0.1460],
         [ 0.1183, -0.0221, -0.0324,  ..., -0.3500, -0.0878,  0.0858]],

        [[ 0.0920,  0.0128,  0.0303,  ..., -0.2839, -0.0787,  0.0749],
         [ 0.2733, -0.1728, -0.0440,  ..., -0.4082, -0.2048,  0.1032],
         [-0.0642, -0.1191, -0.3048,  ..., -0.4358, -0.0864,  0.0452],
         ...,
         [ 0.1961,  0.0102,  0.0739,  ..., -0.4277, -0.1594,  0.0186],
         [ 0.1465,  0.0319,  0.1532,  ..., -0.5266, -0.1529,  0.0950],
         [ 0.1470, -0.0399, -0.0643,  ..., -0.3551, -0.0679,  0.0731]],

        [[ 0.0508,  0.0417,  0.0208,  ..., -0.2645, -0.0943,  0.0463],
         [ 0.0228, -0.0141, -0.0789,  ..., -0

tensor([[[ 0.0402,  0.0182,  0.2277,  ..., -0.2353,  0.0083,  0.2618],
         [-0.0349,  0.1021,  0.0534,  ..., -0.1706, -0.2458,  0.1909],
         [-0.0746, -0.0616, -0.0038,  ..., -0.0943, -0.0384,  0.0305],
         ...,
         [ 0.0373,  0.0168,  0.0781,  ..., -0.1935, -0.3003, -0.0215],
         [ 0.0972, -0.0014,  0.1623,  ..., -0.3243, -0.1518,  0.0440],
         [-0.0504, -0.0559,  0.1591,  ..., -0.0299, -0.1617,  0.0589]],

        [[-0.0353,  0.0719,  0.2034,  ..., -0.3070,  0.0847,  0.3216],
         [ 0.1228,  0.0122,  0.1144,  ..., -0.2581, -0.1654,  0.1880],
         [-0.3077, -0.1568, -0.2126,  ..., -0.1072, -0.1851,  0.0221],
         ...,
         [-0.0395,  0.0235,  0.0725,  ..., -0.1966, -0.2455,  0.0491],
         [ 0.0211, -0.0036,  0.1819,  ..., -0.4086, -0.1366,  0.1181],
         [-0.0879, -0.0571,  0.1536,  ..., -0.1017, -0.0500,  0.1278]],

        [[ 0.0141,  0.0501,  0.2192,  ..., -0.1563,  0.0323,  0.3107],
         [ 0.0709,  0.1695, -0.0835,  ..., -0

tensor([[[-1.4653e-01,  4.1417e-01,  6.4677e-03,  ..., -1.4338e-01,
           1.2599e-01,  2.1922e-01],
         [ 5.5504e-02,  3.1529e-01, -1.1773e-01,  ..., -1.5186e-01,
           3.1506e-02,  9.5915e-03],
         [ 7.9712e-03,  1.4327e-01, -1.5618e-01,  ..., -3.3633e-02,
           2.7093e-01,  1.1917e-01],
         ...,
         [ 1.9878e-02,  2.5943e-01,  6.0418e-02,  ..., -1.4727e-01,
          -9.4030e-02,  1.2171e-01],
         [-3.0112e-02,  2.6323e-01,  3.0410e-01,  ..., -1.6471e-01,
           6.0322e-03,  1.7699e-01],
         [-4.1230e-02,  1.3085e-01,  2.3422e-01,  ..., -1.5087e-01,
           7.1170e-02,  1.0207e-01]],

        [[-1.8309e-01,  3.7764e-01, -3.9486e-02,  ..., -1.2928e-01,
           1.6479e-01,  1.5790e-01],
         [-1.4623e-01,  1.3140e-01, -1.4046e-01,  ..., -6.8126e-02,
           6.6271e-02,  6.3682e-02],
         [-1.2726e-01,  2.4526e-01, -2.2020e-01,  ..., -1.0418e-01,
           2.2515e-01, -2.4950e-02],
         ...,
         [ 4.5271e-02,  2

tensor([[[-1.3971e-02,  1.6734e-01,  1.4760e-01,  ..., -7.9969e-02,
          -7.4909e-02, -1.7170e-01],
         [ 2.6353e-01,  1.0887e-01, -3.0726e-02,  ..., -2.7440e-01,
          -2.7288e-01,  2.9347e-02],
         [-1.9998e-01,  2.0377e-02, -8.3577e-02,  ..., -1.3478e-01,
          -1.2508e-02, -7.6427e-02],
         ...,
         [ 2.1549e-01,  1.3468e-01,  1.7366e-01,  ..., -2.1289e-01,
          -4.3143e-01,  3.1832e-02],
         [ 2.9547e-01,  2.2665e-01,  2.5475e-01,  ..., -2.3428e-01,
          -2.8201e-01,  2.3576e-02],
         [ 2.8110e-01,  5.4864e-02,  1.2143e-01,  ..., -8.0708e-02,
          -2.1251e-01, -1.7581e-01]],

        [[ 7.2166e-03,  1.8540e-01,  8.0365e-02,  ..., -1.2230e-01,
          -7.8814e-02, -1.7042e-01],
         [ 1.8241e-02,  2.3287e-01,  2.6398e-02,  ..., -4.0637e-01,
          -3.4825e-01, -2.5984e-04],
         [-8.0930e-02, -2.5953e-02, -2.0414e-01,  ..., -8.3779e-02,
          -1.4310e-01, -2.0818e-01],
         ...,
         [ 2.5890e-01,  1

tensor([[[-0.0334,  0.2658,  0.1294,  ..., -0.3646, -0.2184, -0.0817],
         [-0.0615,  0.2708,  0.1594,  ..., -0.3895, -0.3595,  0.0938],
         [ 0.0657,  0.0231, -0.0865,  ..., -0.0925, -0.0177,  0.0054],
         ...,
         [ 0.1372,  0.1231,  0.0412,  ..., -0.2142, -0.3604, -0.0227],
         [ 0.1629,  0.2417,  0.2912,  ..., -0.4163, -0.3327,  0.1296],
         [ 0.1305,  0.1231,  0.1652,  ..., -0.1483, -0.3030,  0.0926]],

        [[-0.0621,  0.2669,  0.1046,  ..., -0.3565, -0.2744, -0.0886],
         [ 0.1222, -0.0402,  0.1264,  ..., -0.3270, -0.4386,  0.0772],
         [-0.2476,  0.0599, -0.1901,  ..., -0.2195, -0.1011, -0.1187],
         ...,
         [ 0.1500,  0.0667, -0.0011,  ..., -0.2079, -0.4292, -0.0957],
         [ 0.1160,  0.2441,  0.2561,  ..., -0.4647, -0.4021,  0.0669],
         [ 0.0974,  0.0554,  0.1861,  ..., -0.1453, -0.3533,  0.0670]],

        [[-0.0329,  0.2514,  0.0600,  ..., -0.3669, -0.2726, -0.1273],
         [ 0.1343,  0.1964,  0.0726,  ..., -0

tensor([[[-0.0624,  0.4178,  0.1188,  ..., -0.2217,  0.0522,  0.1910],
         [-0.1654, -0.0031, -0.0533,  ..., -0.1364, -0.0737, -0.0078],
         [-0.3550,  0.1038, -0.0404,  ..., -0.1816, -0.0869,  0.0301],
         ...,
         [ 0.0594,  0.0524,  0.1749,  ..., -0.2240, -0.2220, -0.1007],
         [ 0.1319,  0.0988,  0.3121,  ..., -0.3675, -0.3067,  0.0526],
         [-0.0865, -0.0565,  0.1548,  ..., -0.1663, -0.0310, -0.0647]],

        [[-0.0731,  0.3666,  0.1064,  ..., -0.2150,  0.0760,  0.2232],
         [-0.1503,  0.1640, -0.0235,  ..., -0.2560, -0.2187,  0.1516],
         [-0.2081,  0.0254, -0.0548,  ..., -0.0150, -0.0073,  0.1054],
         ...,
         [ 0.0695,  0.0493,  0.1688,  ..., -0.2627, -0.2151,  0.0200],
         [ 0.1139,  0.0446,  0.2979,  ..., -0.3740, -0.3184,  0.1438],
         [-0.1155, -0.1050,  0.1441,  ..., -0.1836, -0.0154, -0.0152]],

        [[-0.0476,  0.4083,  0.1394,  ..., -0.2065,  0.0818,  0.1897],
         [ 0.0974, -0.0638, -0.0033,  ..., -0

tensor([[[-1.4417e-01,  2.6365e-01,  2.5368e-01,  ..., -2.5475e-01,
          -1.7271e-02,  1.3302e-01],
         [ 4.2612e-02,  1.7545e-01, -9.1141e-02,  ..., -2.4009e-01,
          -1.0801e-01, -6.6152e-02],
         [ 5.1491e-02, -9.8267e-02, -2.8291e-01,  ..., -2.2033e-01,
          -4.8582e-02,  9.2415e-02],
         ...,
         [-8.7097e-02,  1.4182e-01,  7.8921e-02,  ..., -2.3432e-01,
          -2.8608e-01, -2.4909e-02],
         [ 4.6073e-02,  8.5275e-02,  2.2382e-01,  ..., -2.6793e-01,
          -1.6501e-01,  5.8066e-02],
         [-9.0084e-02,  1.4203e-01,  1.4989e-01,  ..., -1.3411e-01,
          -1.0319e-01,  6.1693e-02]],

        [[-1.3082e-01,  2.3478e-01,  2.2200e-01,  ..., -2.4169e-01,
          -2.8670e-02,  5.0762e-02],
         [-2.3651e-02,  2.7595e-02,  7.3399e-02,  ..., -2.6565e-01,
          -3.4617e-01, -4.1058e-03],
         [-2.9819e-01, -1.2732e-01, -1.0045e-01,  ..., -6.3821e-02,
           6.7644e-02,  4.1613e-02],
         ...,
         [-8.4998e-02,  1

tensor([[[-0.0297,  0.1708,  0.3825,  ..., -0.1184,  0.0157, -0.1317],
         [ 0.2721, -0.0909,  0.2505,  ..., -0.2795, -0.1713, -0.0327],
         [-0.2040, -0.1758, -0.0304,  ..., -0.2477,  0.0096, -0.1276],
         ...,
         [ 0.1037,  0.0358,  0.2597,  ..., -0.1078, -0.2125, -0.0411],
         [ 0.1731,  0.0440,  0.4619,  ..., -0.2471, -0.0451,  0.0098],
         [-0.0355, -0.1314,  0.4633,  ...,  0.0693, -0.0266, -0.0248]],

        [[ 0.0190,  0.1686,  0.3185,  ..., -0.1305, -0.0085, -0.0573],
         [ 0.2615, -0.1131,  0.3118,  ..., -0.3475, -0.1727, -0.0536],
         [-0.1379, -0.2416,  0.1280,  ..., -0.2799,  0.0328, -0.0084],
         ...,
         [ 0.2342,  0.0588,  0.1746,  ..., -0.1636, -0.3627, -0.0339],
         [ 0.2835, -0.0010,  0.4205,  ..., -0.2834, -0.0864,  0.0534],
         [ 0.1415, -0.1580,  0.4150,  ..., -0.0342, -0.0767,  0.0498]],

        [[-0.0208,  0.1928,  0.3636,  ..., -0.1637, -0.0315, -0.0530],
         [ 0.0970,  0.0218,  0.2331,  ..., -0

In [7]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]
print(text)
print('================================')
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
Where are you going today?
I am going shopping. What about you?
I am going to visit my grandmother. she is not very well
['[CLS]', 'oh', 'congratulations', 'juliet', '[SEP]', 'hello', 'how', '[MASK]', 'you', 'i', 'am', 'romeo', '[SEP]']
tensor([[[ 0.1184,  0.3888,  0.2697,  ..., -0.3624, -0.3447, -0.0079],
         [ 0.1363,  0.1355,  0.0647,  ..., -0.3423, -0.4019, -0.0362],
         [ 0.0880,  0.0470,  0.0516,  ..., -0.3133, -0.1603,  0.1373],
         ...,
         [ 0.1700,  0.1891,  0.1666,  ..., -0.4446, -0.3746,  0.0226],
         [ 0.2080,  0.2752,  0.3133,  ..., -0.5361, -0.3782,  0.0177],
         [ 0.3132,  0.0837,  0.2141,  ..., -0.2997, -0.3266, -0.0464]]],
       grad_fn=<AddBackward0>)
tensor([[[7, 7, 7,  ..., 7, 7, 7],
         [0, 0, 0,  ..., 0, 0, 0],
         [0,

In [8]:
#应用量化
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(40, 768)
    (pos_embed): Embedding(30, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0): EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_K): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_V): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): DynamicQuantizedLinear(in_features=768, out_features=3072, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (fc2): DynamicQuantizedLinear(in_features=3072, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
    )
    (1): Enc

In [14]:
#输出对比测试结果
import os
import time

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def print_time_of_model(model):    
    eval_start_time = time.time()
    model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

    
print_size_of_model(model)
print_size_of_model(quantized_model)

print_time_of_model(model)
print_time_of_model(quantized_model)

Size (MB): 160.844839
Size (MB): 40.572431
tensor([[[-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [ 0.1013,  0.3634,  0.0050,  ..., -0.1247, -0.0733, -0.1555],
         [-0.0312,  0.2100, -0.0643,  ..., -0.1972, -0.0326, -0.1339],
         ...,
         [-0.0090,  0.1465,  0.0373,  ..., -0.2756, -0.2431, -0.1753],
         [-0.0873,  0.1474,  0.2928,  ..., -0.3138, -0.1397, -0.0847],
         [ 0.0250,  0.1036,  0.1591,  ..., -0.1132, -0.0359, -0.0897]]],
       grad_fn=<AddBackward0>)
tensor([[[7, 7, 7,  ..., 7, 7, 7],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]])
tensor([[[ 0.0538,  0.1581,  0.0941,  ..., -0.4250,  0.0572,  0.0348],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3