In [22]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
         https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split())) # ['hello', 'how', 'are', 'you',...]
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {i: w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)

token_list = list()
for sentence in sentences:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)


In [23]:
# BERT Parameters
maxlen = 30
batch_size = 6
max_pred = 5 # max tokens of prediction
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4 # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

In [24]:
# sample IsNext and NotNext to be same in small batch size
def make_data():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # MASK LM
        n_pred =  min(max_pred, max(1, int(len(input_ids) * 0.15))) # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']] # candidate masked position
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word2idx['[MASK]'] # make mask
            elif random() > 0.9:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                while index < 4: # can't involve 'CLS', 'SEP', 'PAD'
                  index = randint(0, vocab_size - 1)
                input_ids[pos] = index # replace

        # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
    return batch
# Proprecessing Finished

In [25]:
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens),\
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)

class MyDataSet(Data.Dataset):
  def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
    self.input_ids = input_ids
    self.segment_ids = segment_ids
    self.masked_tokens = masked_tokens
    self.masked_pos = masked_pos
    self.isNext = isNext
  
  def __len__(self):
    return len(self.input_ids)
  
  def __getitem__(self, idx):
    return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]

loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

In [26]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, seq_len = seq_q.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)  # [batch_size, 1, seq_len]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]

def gelu(x):
    """
      Implementation of the gelu activation function.
      For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
      Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, seq_len, seq_len]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size, seq_len, d_model], k: [batch_size, seq_len, d_model], v: [batch_size, seq_len, d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size, n_heads, seq_len, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size, n_heads, seq_len, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size, n_heads, seq_len, d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, seq_len, d_v], attn: [batch_size, n_heads, seq_len, seq_len]
        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size, seq_len, n_heads, d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual) # output: [batch_size, seq_len, d_model]

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        return self.fc2(gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, seq_len, d_model]
        return enc_outputs

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids) # [bach_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids) # [batch_size, maxlen, maxlen]
        for layer in self.layers:
            # output: [batch_size, max_len, d_model]
            output = layer(output, enc_self_attn_mask)
        # it will be decided by first token(CLS)
        h_pooled = self.fc(output[:, 0]) # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled) # [batch_size, 2] predict isNext

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model) # [batch_size, max_pred, d_model]
        print(output)
        print(masked_pos)
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        print(h_masked)
        h_masked = self.activ2(self.linear(h_masked)) # [batch_size, max_pred, d_model]
        logits_lm = self.fc2(h_masked) # [batch_size, max_pred, vocab_size]
        return logits_lm, logits_clsf
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)

In [27]:
for epoch in range(50):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
      logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
      loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
      loss_lm = (loss_lm.float()).mean()
      loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
      loss = loss_lm + loss_clsf
      if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

tensor([[[-0.0461, -0.2901, -0.1020,  ...,  0.1810, -0.0747,  0.0984],
         [-0.2309, -0.0934,  0.0122,  ..., -0.1251, -0.0738,  0.3697],
         [-0.1348, -0.1270,  0.2239,  ..., -0.0139,  0.1371,  0.1689],
         ...,
         [-0.2473, -0.2075,  0.0041,  ...,  0.1354,  0.2070,  0.0684],
         [-0.1707, -0.2790,  0.0476,  ..., -0.0213,  0.2834,  0.1874],
         [-0.3255, -0.1662,  0.1121,  ..., -0.1670,  0.1692,  0.3016]],

        [[-0.0392, -0.2737, -0.1117,  ...,  0.1702, -0.1179,  0.1165],
         [-0.1885,  0.0943, -0.0481,  ...,  0.0300,  0.0258,  0.1541],
         [-0.1121, -0.1045,  0.2073,  ..., -0.0204,  0.1076,  0.1771],
         ...,
         [-0.2069, -0.2072,  0.0065,  ...,  0.1378,  0.1939,  0.0895],
         [-0.1518, -0.2692,  0.0306,  ..., -0.0203,  0.2461,  0.2113],
         [-0.2781, -0.1743,  0.1078,  ..., -0.1679,  0.1411,  0.3257]],

        [[ 0.0120, -0.2599, -0.1031,  ...,  0.1863,  0.0013,  0.1489],
         [-0.0213, -0.0856, -0.0210,  ..., -0

tensor([[[ 0.2810, -0.1347, -0.0684,  ...,  0.3249, -0.2215, -0.1195],
         [ 0.0332, -0.0402, -0.0259,  ...,  0.1830,  0.0384, -0.1760],
         [-0.0414,  0.1846,  0.1344,  ...,  0.3789, -0.0663, -0.1828],
         ...,
         [ 0.0844, -0.2958,  0.0397,  ...,  0.2771, -0.0155,  0.0939],
         [ 0.1056, -0.2982, -0.0060,  ...,  0.2607, -0.0415,  0.0853],
         [ 0.0727, -0.2041,  0.1072,  ...,  0.1801, -0.0342,  0.0268]],

        [[ 0.2299, -0.1320, -0.1082,  ...,  0.3474, -0.2341, -0.1486],
         [ 0.0183,  0.2185, -0.0655,  ...,  0.3073, -0.1772,  0.0225],
         [ 0.0299,  0.1163,  0.0696,  ...,  0.1375, -0.0675, -0.0842],
         ...,
         [ 0.0428, -0.2446,  0.0136,  ...,  0.2773, -0.0250,  0.0978],
         [ 0.0817, -0.2666, -0.0102,  ...,  0.2882, -0.0539,  0.0732],
         [ 0.0204, -0.1347,  0.1165,  ...,  0.1858, -0.0477,  0.0407]],

        [[ 0.2681, -0.1451, -0.0621,  ...,  0.3635, -0.1909, -0.1400],
         [ 0.0072,  0.0801, -0.0250,  ...,  0

tensor([[[ 1.6471e-01, -1.7645e-01, -9.5574e-02,  ...,  3.7096e-01,
          -1.8346e-02, -1.0064e-01],
         [ 5.1482e-02,  1.4884e-01, -6.0363e-02,  ...,  2.1205e-01,
          -1.5063e-01, -2.7773e-04],
         [ 1.8296e-01,  2.8990e-02, -2.0156e-01,  ...,  2.7623e-01,
          -2.5822e-01,  1.0056e-01],
         ...,
         [-4.7566e-02, -1.5232e-01, -1.3458e-01,  ...,  4.0214e-01,
           1.3528e-02, -7.9288e-02],
         [-1.1112e-03, -2.1361e-01, -2.9515e-02,  ...,  2.5617e-01,
           1.1004e-01, -2.9835e-02],
         [-1.1273e-01, -7.2034e-03,  3.6254e-02,  ...,  1.1367e-01,
          -6.0628e-02, -6.7748e-03]],

        [[ 1.6090e-01, -1.7999e-01, -1.4777e-01,  ...,  3.3907e-01,
          -3.0355e-02, -7.3631e-02],
         [-3.0381e-02,  1.8582e-01, -2.5466e-01,  ...,  3.0319e-01,
          -1.4159e-01, -1.0051e-01],
         [ 3.4373e-02, -4.7270e-02,  3.0942e-02,  ...,  1.4420e-01,
          -1.0437e-01,  4.0592e-03],
         ...,
         [-1.8778e-02, -1

tensor([[[-6.1724e-02, -2.8791e-01, -6.4726e-02,  ...,  3.6025e-01,
          -2.0493e-01,  4.1731e-02],
         [ 1.3378e-02, -9.7123e-02,  2.0705e-02,  ...,  1.6234e-01,
           4.8093e-02, -1.8148e-01],
         [-1.3234e-01, -1.0775e-01,  9.3771e-02,  ...,  3.9179e-01,
          -2.2668e-02, -5.5853e-02],
         ...,
         [ 9.3192e-02, -1.8345e-01,  1.1060e-01,  ...,  5.5736e-01,
           2.1722e-02,  6.9369e-02],
         [-3.5237e-02, -1.9385e-01,  3.1466e-02,  ...,  4.6411e-01,
           2.6132e-02, -4.4617e-02],
         [-2.3722e-02, -8.0524e-02,  1.6880e-01,  ...,  3.0429e-01,
           9.1576e-02,  5.4797e-03]],

        [[ 6.7057e-03, -3.1623e-01, -4.1325e-02,  ...,  3.6815e-01,
          -2.3378e-01,  1.3721e-02],
         [ 8.8644e-02,  1.7966e-01,  3.2649e-02,  ...,  2.4582e-01,
          -2.4653e-01,  5.8716e-02],
         [ 3.7146e-02, -2.0139e-01,  4.5349e-02,  ...,  3.3639e-01,
          -7.3543e-03, -3.2903e-02],
         ...,
         [ 1.6376e-01, -1

tensor([[[ 8.9678e-02, -2.5186e-01, -2.8169e-02,  ...,  9.9845e-02,
          -7.9158e-02, -7.4135e-02],
         [ 3.4517e-02, -1.6770e-02,  6.8779e-02,  ...,  7.4081e-02,
           2.6150e-02, -1.9396e-01],
         [ 5.5029e-02,  9.8692e-03,  1.2726e-01,  ...,  2.4975e-01,
           5.3833e-02, -2.5595e-02],
         ...,
         [ 4.4052e-02, -4.0955e-01,  1.1662e-01,  ...,  2.0502e-01,
           2.0751e-01,  5.9564e-02],
         [ 2.2845e-03, -2.7265e-01,  6.0587e-02,  ...,  1.5735e-01,
           1.6453e-01,  1.1029e-01],
         [ 1.0935e-01, -2.0995e-01,  1.1345e-01,  ...,  6.8103e-02,
           4.9308e-03,  1.3388e-01]],

        [[ 1.7018e-01, -2.6636e-01, -3.8279e-02,  ...,  1.4632e-01,
          -6.7554e-02, -3.0696e-02],
         [-8.1935e-02, -5.8233e-02,  1.0698e-01,  ...,  1.1281e-01,
          -1.8870e-01,  7.2502e-02],
         [ 1.7452e-01, -1.4618e-01,  1.8228e-01,  ...,  2.0147e-01,
          -3.7581e-04,  1.4779e-01],
         ...,
         [ 2.9345e-02, -4

tensor([[[ 1.7357e-01, -2.9550e-01,  1.3575e-01,  ...,  3.7634e-01,
          -6.7866e-03, -5.5596e-02],
         [ 4.3886e-02,  1.3712e-01,  2.0952e-01,  ...,  1.4929e-01,
          -8.9045e-02, -1.3790e-02],
         [ 3.2564e-02, -1.1868e-01,  2.6791e-01,  ...,  1.7263e-01,
           8.7725e-02, -3.9332e-03],
         ...,
         [ 4.0239e-02, -2.8971e-01,  2.7828e-01,  ...,  2.3615e-01,
           1.7392e-01, -4.4211e-02],
         [ 7.3958e-02, -2.3421e-01,  1.9842e-01,  ...,  3.5337e-01,
           1.2364e-01, -1.0852e-01],
         [-6.7200e-03, -1.2722e-01,  3.0498e-01,  ..., -1.3301e-02,
           1.7779e-01,  6.7034e-03]],

        [[ 1.8720e-01, -2.9156e-01,  1.7957e-01,  ...,  4.1547e-01,
          -7.3709e-03, -1.1419e-01],
         [ 1.7742e-01,  8.7827e-02,  2.2484e-01,  ...,  2.7492e-01,
          -1.2166e-01, -1.2069e-01],
         [-1.1784e-01, -2.4216e-01,  1.4255e-01,  ...,  2.5134e-01,
          -6.8290e-02,  1.1047e-01],
         ...,
         [ 4.8697e-02, -2

tensor([[[ 0.1825, -0.0640,  0.0134,  ...,  0.1569, -0.1042, -0.2014],
         [-0.0010,  0.1064,  0.1306,  ...,  0.0044,  0.0958, -0.4034],
         [ 0.0506,  0.0367,  0.0064,  ...,  0.2281, -0.0896, -0.0492],
         ...,
         [ 0.0803, -0.0012, -0.0544,  ...,  0.3536,  0.0918, -0.0713],
         [-0.0780, -0.2012,  0.0471,  ...,  0.2117,  0.0215, -0.1117],
         [-0.0830, -0.0743,  0.0434,  ...,  0.1161, -0.1015,  0.0958]],

        [[ 0.1644, -0.0514, -0.0036,  ...,  0.2023, -0.1688, -0.2021],
         [ 0.0239,  0.1416,  0.1264,  ...,  0.0381,  0.0615, -0.4043],
         [ 0.0716,  0.0586,  0.0026,  ...,  0.2450, -0.1425, -0.0419],
         ...,
         [ 0.0765, -0.0027, -0.0634,  ...,  0.3543,  0.0667, -0.0548],
         [-0.0753, -0.1696,  0.0299,  ...,  0.2414,  0.0096, -0.1333],
         [-0.0729, -0.0512,  0.0073,  ...,  0.1266, -0.1062,  0.0743]],

        [[ 0.0919,  0.0315,  0.0019,  ...,  0.1568, -0.1863, -0.1417],
         [-0.0484,  0.4164, -0.0452,  ...,  0

tensor([[[ 0.0400, -0.0847,  0.0813,  ...,  0.3402,  0.0160, -0.1803],
         [-0.0681,  0.1779,  0.0367,  ...,  0.3487, -0.0699, -0.1988],
         [-0.0052, -0.1771,  0.0204,  ...,  0.5076, -0.0860,  0.0299],
         ...,
         [-0.0186, -0.0211,  0.1272,  ...,  0.5351,  0.0665, -0.0679],
         [-0.0709, -0.1163,  0.0905,  ...,  0.3868,  0.0262, -0.0652],
         [-0.0777,  0.0731,  0.1799,  ...,  0.2215, -0.0198, -0.0704]],

        [[ 0.1011, -0.1111,  0.0960,  ...,  0.3360,  0.0184, -0.1337],
         [-0.2030, -0.0232,  0.0009,  ...,  0.3006, -0.0451, -0.3206],
         [-0.1634, -0.0448,  0.3347,  ...,  0.4551,  0.0463, -0.0718],
         ...,
         [-0.0186, -0.0265,  0.1778,  ...,  0.5340,  0.0610, -0.0348],
         [-0.0804, -0.0825,  0.1298,  ...,  0.3828,  0.0523, -0.0065],
         [-0.0930,  0.0769,  0.1944,  ...,  0.2238,  0.0025, -0.0257]],

        [[ 0.0749, -0.1115,  0.0669,  ...,  0.3637,  0.0206, -0.1853],
         [-0.0147,  0.2005,  0.1444,  ...,  0

tensor([[[ 0.1176, -0.2827,  0.1669,  ...,  0.1896,  0.0168,  0.0143],
         [-0.1059,  0.0276,  0.1259,  ...,  0.2118, -0.0663,  0.0622],
         [-0.2267, -0.2648,  0.2848,  ...,  0.2018,  0.1341, -0.0800],
         ...,
         [-0.0083, -0.2543,  0.1826,  ...,  0.3068,  0.1999,  0.1472],
         [-0.0288, -0.2363,  0.1812,  ...,  0.1513,  0.0939,  0.1191],
         [-0.1216, -0.0371,  0.2116,  ...,  0.0431,  0.0404,  0.1600]],

        [[ 0.1181, -0.2245,  0.1084,  ...,  0.1985,  0.1122, -0.0280],
         [-0.0639,  0.0172,  0.1159,  ...,  0.1983, -0.0292, -0.0010],
         [ 0.0507, -0.0969, -0.0078,  ...,  0.2825, -0.0546,  0.0016],
         ...,
         [ 0.0523, -0.1742,  0.1504,  ...,  0.3305,  0.2324,  0.1372],
         [ 0.0254, -0.1579,  0.1081,  ...,  0.1638,  0.1011,  0.1339],
         [-0.1092,  0.0626,  0.2024,  ...,  0.0782,  0.0515,  0.1465]],

        [[ 0.1378, -0.2551,  0.0905,  ...,  0.1524,  0.0178, -0.0692],
         [-0.1604, -0.0411,  0.0809,  ...,  0

tensor([[[ 1.4662e-01, -4.4133e-02, -7.7793e-02,  ...,  2.8879e-01,
           6.7084e-02, -8.0186e-02],
         [-3.4942e-02,  2.7230e-01, -1.4950e-01,  ...,  4.2539e-01,
          -6.1032e-02, -1.3525e-01],
         [ 2.1102e-02, -1.4729e-01,  1.0306e-01,  ...,  1.8936e-01,
           2.6822e-01, -9.9433e-02],
         ...,
         [ 1.4838e-02,  5.7254e-03, -1.1201e-01,  ...,  4.4065e-01,
           2.2920e-01,  7.2608e-02],
         [ 6.5957e-02,  3.3453e-02, -5.4109e-02,  ...,  3.9790e-01,
           2.3709e-01,  4.8336e-02],
         [ 2.1851e-02,  3.9272e-02,  2.7555e-02,  ...,  1.8518e-01,
           3.2522e-01,  4.8216e-02]],

        [[ 1.4164e-01, -7.5752e-02, -7.9971e-02,  ...,  2.7583e-01,
           5.0496e-02, -5.6417e-02],
         [-9.1229e-02,  2.3692e-02, -3.6880e-02,  ...,  2.6362e-01,
          -1.2706e-01,  1.0040e-02],
         [ 2.2926e-02, -1.4497e-01,  8.4289e-02,  ...,  2.1585e-01,
           2.5057e-01, -7.7204e-02],
         ...,
         [ 1.3544e-02, -1

tensor([[[-0.1363, -0.1307, -0.1649,  ...,  0.1312, -0.0693,  0.1049],
         [-0.1671,  0.0697, -0.0088,  ...,  0.1483,  0.0277, -0.1915],
         [-0.1154, -0.0832,  0.2040,  ...,  0.3231, -0.0652, -0.0377],
         ...,
         [-0.2810, -0.3180,  0.0676,  ...,  0.3814,  0.1425,  0.1331],
         [-0.1892, -0.1858,  0.0885,  ...,  0.3995,  0.0334,  0.0792],
         [-0.1770, -0.2766,  0.1427,  ...,  0.2035,  0.1439,  0.1314]],

        [[-0.1090, -0.1676, -0.2086,  ...,  0.1512, -0.1176,  0.1404],
         [-0.1973,  0.2140, -0.1245,  ...,  0.4593, -0.1481,  0.1852],
         [-0.0738, -0.1412,  0.1208,  ..., -0.0299, -0.0147,  0.0236],
         ...,
         [-0.2200, -0.3626, -0.0047,  ...,  0.3821,  0.0673,  0.1615],
         [-0.1440, -0.2502,  0.0149,  ...,  0.4221,  0.0033,  0.0938],
         [-0.1503, -0.3239,  0.0478,  ...,  0.2234,  0.0764,  0.1282]],

        [[-0.1791, -0.1791, -0.1776,  ...,  0.1368, -0.1075,  0.1206],
         [-0.1626,  0.0117, -0.0055,  ...,  0

tensor([[[ 3.8013e-01, -2.2707e-01,  9.4579e-02,  ...,  2.6360e-01,
           8.2759e-02, -1.2340e-02],
         [ 1.1191e-01,  1.2898e-01,  1.3921e-01,  ...,  2.3688e-01,
           1.2009e-01, -3.9988e-01],
         [ 1.3353e-01,  1.1277e-01,  1.4758e-01,  ...,  3.1040e-01,
          -2.7943e-02, -1.0512e-01],
         ...,
         [ 2.6814e-01, -1.6198e-01,  7.8005e-02,  ...,  4.0076e-01,
           1.8446e-01, -5.8179e-02],
         [ 1.7931e-01, -1.2327e-01,  7.8945e-03,  ...,  2.6512e-01,
           1.5714e-01, -1.3634e-01],
         [ 9.8388e-02, -4.3411e-02,  6.5361e-02,  ...,  2.2515e-01,
           1.8838e-01,  5.2162e-02]],

        [[ 4.0459e-01, -2.2932e-01,  7.6914e-02,  ...,  2.8284e-01,
           1.4502e-01, -3.0298e-02],
         [ 9.9764e-02,  1.3485e-01,  1.2315e-01,  ...,  2.6237e-01,
           1.5858e-01, -3.8685e-01],
         [ 1.4834e-01,  1.4381e-01,  1.3167e-01,  ...,  3.1880e-01,
          -1.1247e-02, -9.6369e-02],
         ...,
         [ 2.4995e-01, -1

tensor([[[ 0.0557, -0.2670,  0.0739,  ...,  0.2070,  0.0118,  0.0615],
         [ 0.0419, -0.0804,  0.1410,  ...,  0.1239,  0.1898, -0.4000],
         [-0.1864,  0.0343,  0.0905,  ...,  0.2138,  0.0145, -0.0139],
         ...,
         [-0.1795, -0.2522,  0.0864,  ...,  0.4193,  0.1196,  0.0272],
         [-0.0630, -0.2292, -0.0220,  ...,  0.1935,  0.0920, -0.0304],
         [-0.2206, -0.1135,  0.1375,  ...,  0.1082,  0.0821, -0.0625]],

        [[ 0.0361, -0.2602,  0.0369,  ...,  0.2390,  0.0453,  0.0237],
         [-0.1499,  0.0282,  0.0958,  ...,  0.2374, -0.1159, -0.1387],
         [ 0.0326, -0.0402, -0.2177,  ...,  0.2232, -0.1002,  0.1132],
         ...,
         [-0.1895, -0.2511,  0.0742,  ...,  0.4680,  0.1389,  0.0057],
         [-0.0717, -0.2001, -0.0495,  ...,  0.2498,  0.0934, -0.0937],
         [-0.2520, -0.0883,  0.1146,  ...,  0.1754,  0.0831, -0.0877]],

        [[ 0.0444, -0.3380,  0.0403,  ...,  0.1680,  0.0575,  0.0107],
         [-0.0221,  0.0990, -0.0509,  ...,  0

tensor([[[ 0.2040, -0.1030,  0.0427,  ...,  0.3409, -0.0773, -0.2375],
         [-0.0811,  0.1993, -0.0228,  ...,  0.5659, -0.1991, -0.1796],
         [ 0.2053,  0.0223,  0.0996,  ...,  0.0763,  0.1032, -0.2128],
         ...,
         [ 0.1236, -0.0662, -0.0057,  ...,  0.3128,  0.0216, -0.2247],
         [ 0.1237, -0.1951,  0.0328,  ...,  0.2801,  0.0429, -0.1975],
         [ 0.1669, -0.0024, -0.0615,  ...,  0.1524, -0.1067, -0.1605]],

        [[ 0.1898, -0.0711,  0.0007,  ...,  0.3373, -0.1093, -0.2130],
         [ 0.0987,  0.0733,  0.1305,  ...,  0.2088,  0.0497, -0.4428],
         [-0.0332,  0.0988,  0.1210,  ...,  0.2522,  0.0172, -0.2311],
         ...,
         [ 0.1722, -0.0773, -0.0713,  ...,  0.2492,  0.0577, -0.2500],
         [ 0.1275, -0.1936, -0.0139,  ...,  0.2781,  0.0909, -0.2728],
         [ 0.1306,  0.0314, -0.1551,  ...,  0.1295, -0.1302, -0.1794]],

        [[ 0.1996, -0.0885,  0.0599,  ...,  0.4061, -0.0657, -0.2547],
         [ 0.1450,  0.1044,  0.0297,  ...,  0

tensor([[[ 0.2354, -0.1202,  0.1172,  ...,  0.3016, -0.0331, -0.1696],
         [ 0.0855,  0.0487,  0.1950,  ...,  0.0929,  0.0450, -0.3013],
         [-0.0352, -0.1788,  0.2131,  ...,  0.3164, -0.0980, -0.0013],
         ...,
         [ 0.1086, -0.0936,  0.1696,  ...,  0.3554,  0.1873, -0.0814],
         [-0.0037, -0.2362,  0.1606,  ...,  0.4148,  0.1161,  0.0217],
         [-0.0521, -0.0907,  0.2212,  ...,  0.0505,  0.0038,  0.0307]],

        [[ 0.2535, -0.0502,  0.1179,  ...,  0.2555, -0.0493, -0.1591],
         [ 0.0696,  0.0938,  0.1993,  ...,  0.1208,  0.0365, -0.2911],
         [-0.0183, -0.1006,  0.2357,  ...,  0.3097, -0.1237,  0.0443],
         ...,
         [ 0.1131, -0.0325,  0.1550,  ...,  0.3389,  0.1750, -0.0804],
         [-0.0144, -0.1848,  0.1553,  ...,  0.4005,  0.1075,  0.0124],
         [-0.0268, -0.0357,  0.2209,  ...,  0.0500, -0.0168,  0.0463]],

        [[ 0.2350, -0.1092,  0.1143,  ...,  0.2933, -0.0590, -0.1311],
         [-0.0093, -0.0692,  0.1918,  ...,  0

tensor([[[ 0.1628, -0.1870, -0.0570,  ...,  0.0825, -0.0364, -0.0112],
         [ 0.1109, -0.0131,  0.1054,  ..., -0.1368,  0.0896, -0.4806],
         [-0.1092, -0.0032,  0.1857,  ...,  0.2669,  0.0065, -0.0879],
         ...,
         [-0.0294, -0.0957,  0.0060,  ...,  0.2639,  0.1623, -0.0016],
         [-0.0474, -0.1669,  0.0205,  ...,  0.2054,  0.0392,  0.0495],
         [-0.1478, -0.0237, -0.0273,  ...,  0.0147, -0.0085,  0.0432]],

        [[ 0.2161, -0.1928, -0.0376,  ...,  0.0347, -0.0166, -0.0360],
         [-0.1052,  0.2297, -0.0939,  ...,  0.1365, -0.1665, -0.1777],
         [ 0.0101, -0.1258,  0.3756,  ...,  0.0193,  0.0642, -0.1839],
         ...,
         [-0.0329, -0.1200,  0.0138,  ...,  0.2552,  0.2084,  0.0025],
         [-0.0244, -0.1990,  0.0763,  ...,  0.1810,  0.0805,  0.0540],
         [-0.1372, -0.0326,  0.0732,  ..., -0.0008, -0.0157,  0.0266]],

        [[ 0.1861, -0.2047, -0.0633,  ...,  0.0067, -0.0282, -0.1160],
         [ 0.0215,  0.0131,  0.1062,  ..., -0

tensor([[[ 0.0131, -0.2125,  0.0091,  ...,  0.2986,  0.1077, -0.1338],
         [-0.0789, -0.0500,  0.1958,  ...,  0.0870, -0.0508, -0.1801],
         [-0.0421, -0.2051, -0.0711,  ...,  0.2772,  0.0718, -0.0640],
         ...,
         [-0.0771, -0.3633,  0.0722,  ...,  0.3361,  0.1697, -0.1658],
         [-0.1348, -0.2911,  0.0890,  ...,  0.2242,  0.0554, -0.2045],
         [-0.1622, -0.0345,  0.2453,  ...,  0.2572,  0.1158, -0.1127]],

        [[-0.0220, -0.2108,  0.0160,  ...,  0.3400,  0.0688, -0.0805],
         [-0.0927, -0.0358,  0.2108,  ...,  0.0712, -0.0394, -0.2325],
         [-0.0205, -0.1593,  0.2562,  ...,  0.2406,  0.1350, -0.2254],
         ...,
         [-0.1046, -0.2648,  0.0899,  ...,  0.2911,  0.1417, -0.1302],
         [-0.1586, -0.2421,  0.0997,  ...,  0.2119,  0.0903, -0.1388],
         [-0.1376,  0.0227,  0.2881,  ...,  0.2344,  0.0829, -0.0682]],

        [[-0.0159, -0.2562,  0.0285,  ...,  0.3104,  0.0518, -0.0727],
         [-0.0429,  0.2574,  0.0683,  ...,  0

tensor([[[ 1.7051e-01, -2.6279e-01,  2.1660e-01,  ...,  3.5423e-01,
          -9.0251e-02,  1.9767e-03],
         [ 7.9921e-02, -1.2902e-02,  1.4512e-01,  ...,  2.7553e-01,
          -1.7157e-01,  5.4638e-02],
         [ 2.7195e-01, -2.2567e-01, -1.3566e-02,  ...,  1.6457e-01,
           4.4415e-02,  2.1649e-02],
         ...,
         [ 9.2497e-02, -1.1371e-01,  3.9809e-02,  ...,  3.5339e-01,
           1.5106e-01,  1.8507e-01],
         [ 2.3677e-02, -1.2965e-01,  6.5483e-02,  ...,  2.7617e-01,
           1.1787e-02,  2.2128e-01],
         [ 7.4142e-02, -8.2841e-02,  1.8440e-01,  ...,  1.6562e-01,
           1.0400e-01,  1.8421e-01]],

        [[ 2.0133e-01, -2.9227e-01,  2.0254e-01,  ...,  2.8965e-01,
          -5.8761e-02,  1.6785e-02],
         [ 8.8508e-02, -6.3810e-02,  6.1299e-02,  ...,  2.1797e-01,
           3.4030e-02, -1.8094e-01],
         [ 2.1309e-02, -1.1997e-01,  9.9054e-02,  ...,  2.8925e-01,
           5.6529e-02, -1.0306e-01],
         ...,
         [ 1.0418e-01, -1

tensor([[[ 1.1488e-01, -1.4879e-01, -9.3396e-02,  ...,  2.3809e-01,
           1.4074e-01, -1.1980e-01],
         [ 7.5277e-02,  1.4661e-01, -3.3355e-03,  ...,  2.8097e-01,
           2.6748e-01, -1.1210e-01],
         [-2.5688e-02,  2.7545e-02,  6.0435e-02,  ...,  3.4906e-01,
           1.6923e-01, -5.1644e-02],
         ...,
         [ 6.7879e-02, -2.1438e-01,  5.7661e-02,  ...,  4.2972e-01,
           4.0509e-01,  3.8023e-02],
         [-1.1591e-01, -2.0284e-01,  1.0505e-01,  ...,  3.7937e-01,
           2.5967e-01,  4.2389e-02],
         [-1.7524e-02, -1.6894e-01,  1.5422e-01,  ...,  3.0477e-01,
           2.7564e-01,  1.2558e-02]],

        [[ 1.2633e-01, -1.7224e-01, -9.0552e-02,  ...,  2.5724e-01,
           1.9922e-01, -1.0260e-01],
         [ 4.0688e-02,  6.1843e-02, -1.0902e-01,  ...,  2.9293e-01,
          -3.0755e-02, -3.1096e-03],
         [ 1.8011e-01, -8.3832e-02, -1.5109e-01,  ...,  3.7774e-01,
           2.2442e-01,  1.2185e-01],
         ...,
         [ 6.9943e-02, -1

tensor([[[ 0.0779, -0.1027,  0.1024,  ...,  0.1389, -0.0114,  0.1035],
         [-0.1351,  0.2120,  0.1087,  ..., -0.0443,  0.0337, -0.1463],
         [-0.2475, -0.0215,  0.1035,  ...,  0.1812, -0.1321,  0.1440],
         ...,
         [-0.0407, -0.0608,  0.0460,  ...,  0.3082,  0.0858, -0.0593],
         [-0.1242, -0.0162,  0.0491,  ...,  0.1973,  0.1327,  0.0434],
         [-0.0891,  0.1140,  0.1205,  ...,  0.0915,  0.0262,  0.0951]],

        [[ 0.1213, -0.0849,  0.1241,  ...,  0.2300, -0.0643,  0.1666],
         [-0.0349,  0.3490,  0.0308,  ...,  0.2998, -0.1317,  0.2345],
         [-0.1227, -0.0342,  0.1140,  ...,  0.0848,  0.0506,  0.1641],
         ...,
         [-0.0270, -0.0016,  0.0864,  ...,  0.3678,  0.1071, -0.0501],
         [-0.0966,  0.0063,  0.0508,  ...,  0.2537,  0.1200,  0.0230],
         [-0.0420,  0.1279,  0.1104,  ...,  0.1171, -0.0129,  0.0693]],

        [[ 0.1102, -0.1124,  0.1219,  ...,  0.2141, -0.1074,  0.1603],
         [-0.1211,  0.1379,  0.1729,  ...,  0

tensor([[[ 3.7159e-02, -2.5422e-01, -8.1282e-03,  ...,  4.2135e-02,
          -1.1171e-01, -3.5862e-02],
         [ 3.2606e-02, -1.4915e-01,  1.0337e-01,  ..., -9.3041e-02,
          -6.0072e-02, -2.6198e-01],
         [-1.9542e-01, -1.2542e-02,  1.8344e-01,  ...,  8.8851e-02,
          -1.4024e-01, -1.8647e-03],
         ...,
         [-8.2166e-02, -1.3072e-01,  5.9921e-02,  ...,  2.1195e-01,
          -3.6652e-02,  1.4097e-01],
         [-3.9948e-02, -3.0230e-01,  1.5339e-01,  ...,  6.2706e-02,
          -9.3021e-02,  1.8439e-01],
         [-7.7201e-02, -9.9446e-02,  2.8715e-02,  ..., -1.4350e-01,
          -1.5289e-01,  1.1603e-01]],

        [[ 6.6725e-02, -1.8014e-01, -8.0491e-02,  ...,  1.7641e-02,
          -9.7558e-02, -6.5325e-02],
         [ 5.5151e-02, -1.6096e-01, -2.9018e-02,  ...,  1.2413e-01,
          -2.5502e-01, -9.2490e-02],
         [ 2.9110e-02, -1.2103e-01, -1.0433e-01,  ...,  1.1697e-01,
          -2.6156e-01,  1.6310e-01],
         ...,
         [-4.1771e-02, -1

tensor([[[ 0.1401, -0.1469,  0.0350,  ...,  0.1739,  0.0471, -0.1545],
         [-0.0723,  0.0724,  0.1067,  ...,  0.2358, -0.0758, -0.1187],
         [ 0.1956, -0.0125, -0.0274,  ...,  0.2020, -0.0557,  0.0680],
         ...,
         [-0.0104, -0.1174,  0.0148,  ...,  0.3855,  0.0941,  0.0192],
         [-0.0014, -0.2259,  0.1236,  ...,  0.3310,  0.0652,  0.0231],
         [ 0.0101, -0.0424,  0.1862,  ...,  0.1568,  0.1046,  0.0844]],

        [[ 0.1216, -0.1314, -0.0053,  ...,  0.1700,  0.0665, -0.1087],
         [ 0.0490,  0.1912,  0.0807,  ...,  0.1292,  0.2144, -0.3625],
         [ 0.0158,  0.1728,  0.2334,  ...,  0.3578, -0.0603, -0.0685],
         ...,
         [ 0.0034, -0.1090, -0.0387,  ...,  0.4312,  0.0779,  0.0097],
         [-0.0218, -0.2386,  0.0922,  ...,  0.3507,  0.0360,  0.0222],
         [ 0.0037, -0.0316,  0.1883,  ...,  0.1931,  0.0941,  0.0572]],

        [[ 0.1187, -0.1400, -0.0207,  ...,  0.1848,  0.0692, -0.0782],
         [ 0.2741,  0.1671,  0.0563,  ...,  0

tensor([[[ 0.1950, -0.0747, -0.1065,  ...,  0.3287, -0.1243, -0.0020],
         [ 0.0976,  0.1368, -0.1588,  ...,  0.2873, -0.2334, -0.1374],
         [-0.0200, -0.1200, -0.0408,  ...,  0.1336, -0.0313, -0.0438],
         ...,
         [ 0.3015, -0.0583, -0.0667,  ...,  0.3225, -0.0172, -0.0103],
         [-0.0510, -0.0993, -0.0118,  ...,  0.3421,  0.0240,  0.0711],
         [ 0.0205, -0.0727, -0.1206,  ...,  0.1482, -0.0699,  0.0378]],

        [[ 0.2764, -0.1355, -0.1379,  ...,  0.3838, -0.0923, -0.0637],
         [ 0.0975,  0.0678, -0.1577,  ...,  0.1294, -0.0291, -0.3000],
         [-0.0891, -0.1001, -0.0970,  ...,  0.2201, -0.2290, -0.1247],
         ...,
         [ 0.3184, -0.0890, -0.1083,  ...,  0.3706,  0.0014, -0.0175],
         [ 0.0213, -0.1336, -0.0569,  ...,  0.3105,  0.0122,  0.0752],
         [ 0.1084, -0.1302, -0.1316,  ...,  0.1412, -0.0879, -0.0093]],

        [[ 0.2286, -0.1263, -0.1295,  ...,  0.3715, -0.1322, -0.0841],
         [ 0.0254,  0.0721, -0.1170,  ...,  0

tensor([[[ 8.5558e-02, -1.8009e-01, -7.1627e-02,  ...,  3.4469e-01,
          -1.5599e-01, -5.9102e-02],
         [ 1.9304e-02,  2.8025e-01, -1.8902e-01,  ...,  4.9894e-01,
          -1.3073e-01, -7.6054e-02],
         [-8.8839e-02, -1.0097e-01, -3.6823e-02,  ...,  2.2965e-01,
          -9.7956e-02, -1.6922e-01],
         ...,
         [-1.4037e-01, -7.3385e-02, -8.9241e-03,  ...,  5.0836e-01,
          -6.6828e-02,  5.2726e-02],
         [-1.1398e-01, -2.0577e-01, -8.3688e-03,  ...,  3.4501e-01,
          -4.3064e-03, -6.0380e-02],
         [-2.2479e-01,  3.2870e-02,  1.5724e-02,  ...,  2.5074e-01,
          -1.2204e-01,  3.6664e-02]],

        [[ 1.1376e-01, -2.1941e-01, -9.4769e-02,  ...,  3.3931e-01,
          -1.4762e-01, -7.8635e-02],
         [-8.0502e-04, -6.5673e-03, -1.7716e-01,  ...,  3.7378e-01,
          -4.8029e-02, -3.3558e-01],
         [-1.2419e-01,  5.8789e-03, -1.7429e-01,  ...,  3.8398e-01,
          -2.1025e-01, -1.7920e-01],
         ...,
         [-9.3278e-02, -1

tensor([[[ 7.8545e-02, -1.3941e-01,  1.1067e-01,  ...,  2.6443e-01,
           1.0003e-01, -2.6292e-02],
         [-7.5320e-02,  3.0348e-02,  1.2469e-01,  ...,  2.9751e-01,
           6.8416e-02,  9.0976e-02],
         [ 1.0209e-03, -1.5567e-01,  1.8733e-01,  ...,  1.8137e-01,
           2.5773e-01,  7.9994e-03],
         ...,
         [ 4.1887e-02, -1.5469e-01,  1.9494e-01,  ...,  4.9116e-01,
           1.5815e-01,  1.4922e-01],
         [-4.2887e-02, -3.6034e-01,  2.8442e-01,  ...,  3.9591e-01,
           2.3044e-01,  9.5248e-02],
         [ 7.9074e-03, -1.5307e-01,  3.1317e-01,  ...,  1.9273e-01,
           1.9726e-01,  1.8006e-01]],

        [[ 2.0277e-01, -1.7859e-01,  1.0870e-01,  ...,  2.7807e-01,
           4.2132e-02, -2.1048e-02],
         [ 2.2105e-02, -2.1402e-03,  1.3721e-01,  ...,  3.8017e-01,
           2.9894e-02,  1.0991e-01],
         [ 1.2017e-01, -2.1751e-01,  1.7675e-02,  ...,  3.0233e-01,
           1.2192e-01,  9.6491e-02],
         ...,
         [ 1.5973e-01, -1

tensor([[[-0.1877,  0.0515, -0.0762,  ...,  0.2618, -0.0522, -0.0036],
         [-0.2198,  0.1636, -0.1538,  ...,  0.1454, -0.3140,  0.0474],
         [-0.3697, -0.1356, -0.2579,  ...,  0.0349, -0.1346, -0.0240],
         ...,
         [-0.2652, -0.0228,  0.0808,  ...,  0.1790,  0.0380,  0.0153],
         [-0.2116, -0.0364,  0.1177,  ...,  0.1621,  0.0149,  0.0041],
         [-0.1914,  0.1514,  0.0884,  ..., -0.0562, -0.0141,  0.0785]],

        [[-0.1451,  0.0028, -0.1076,  ...,  0.2908, -0.1061,  0.0195],
         [-0.3576,  0.2989, -0.2133,  ...,  0.3166, -0.1488, -0.0449],
         [-0.1724,  0.0692,  0.1073,  ...,  0.0104,  0.0180, -0.0525],
         ...,
         [-0.2185, -0.0454,  0.0743,  ...,  0.2194,  0.0807, -0.0598],
         [-0.1579, -0.0605,  0.1006,  ...,  0.2042,  0.0283,  0.0045],
         [-0.1876,  0.1187,  0.1082,  ..., -0.0125,  0.0213,  0.0563]],

        [[-0.1520,  0.0542, -0.0486,  ...,  0.2608, -0.0669,  0.0096],
         [-0.1802,  0.1737, -0.0468,  ...,  0

In [28]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]
print(text)
print('================================')
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
Where are you going today?
I am going shopping. What about you?
I am going to visit my grandmother. she is not very well
['[CLS]', '[MASK]', 'how', 'are', 'you', 'i', 'am', 'romeo', '[SEP]', 'i', 'am', 'going', 'to', 'visit', 'my', '[MASK]', 'she', 'is', 'not', 'very', '[MASK]', '[SEP]']
tensor([[[-0.0594, -0.3794,  0.0031,  ...,  0.2240, -0.3219, -0.0732],
         [-0.2404,  0.0142,  0.0810,  ...,  0.1536, -0.2299, -0.1594],
         [-0.0397, -0.2211, -0.2294,  ...,  0.0962, -0.2445, -0.0959],
         ...,
         [-0.2348, -0.3640, -0.0164,  ...,  0.3233, -0.0098, -0.1435],
         [-0.1622, -0.3274,  0.0663,  ...,  0.2726, -0.1130, -0.1256],
         [-0.2451, -0.3558,  0.1276,  ...,  0.0038, -0.1550, -0.1105]]],
       grad_fn=<AddBackward0>)
tensor([[[20, 20, 20,  ..., 20

In [29]:
#应用量化
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(40, 768)
    (pos_embed): Embedding(30, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0): EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_K): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_V): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): DynamicQuantizedLinear(in_features=768, out_features=3072, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (fc2): DynamicQuantizedLinear(in_features=3072, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
    )
    (1): Enc

In [30]:
for epoch in range(50):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        logits_lm, logits_clsf = quantized_model(input_ids, segment_ids, masked_pos)
        loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
        loss_lm = (loss_lm.float()).mean()
        loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
        loss = loss_lm + loss_clsf
        loss = loss.requires_grad_()
        if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

tensor([[[ 2.7664e-01, -3.5962e-01,  5.9782e-02,  ...,  3.5997e-01,
          -1.1840e-01,  4.7731e-02],
         [ 1.5979e-01, -3.3200e-02,  1.0703e-01,  ...,  3.8254e-01,
          -2.9675e-01, -1.6625e-01],
         [-1.3866e-01, -2.0457e-01, -6.6440e-03,  ...,  3.0492e-01,
          -1.8744e-01,  1.8672e-02],
         ...,
         [ 5.1000e-02, -2.3836e-01,  5.4464e-03,  ...,  4.6986e-01,
           8.9828e-02, -2.8349e-02],
         [-3.2916e-03, -2.5011e-01,  1.6300e-01,  ...,  3.0020e-01,
           1.1705e-01,  1.6779e-02],
         [-6.9355e-02, -1.0153e-01,  2.6717e-01,  ...,  3.1584e-01,
           3.0278e-02,  1.8750e-02]],

        [[ 3.9820e-01, -3.3301e-01,  4.4598e-02,  ...,  3.1002e-01,
          -1.0646e-01, -3.4093e-02],
         [-2.8541e-02,  2.7216e-01, -7.1069e-02,  ...,  4.5060e-01,
          -1.4216e-01,  6.1946e-02],
         [-3.0057e-02, -1.1560e-02,  3.0888e-01,  ...,  2.1811e-01,
          -4.7499e-02, -1.9151e-01],
         ...,
         [ 7.6677e-02, -2

tensor([[[ 0.1489, -0.1954,  0.1647,  ...,  0.2633, -0.0367, -0.0882],
         [ 0.0934, -0.0199,  0.0582,  ...,  0.3560, -0.2491, -0.0027],
         [ 0.0992, -0.3164, -0.0192,  ...,  0.2579, -0.0774,  0.1527],
         ...,
         [ 0.1098, -0.1974,  0.1797,  ...,  0.5246,  0.0829,  0.0338],
         [ 0.0033, -0.2478,  0.1671,  ...,  0.4649,  0.1018, -0.0253],
         [ 0.0012, -0.0950,  0.1198,  ...,  0.3320,  0.0167,  0.0714]],

        [[ 0.0818, -0.2131,  0.1533,  ...,  0.2255, -0.0224, -0.0690],
         [ 0.0359,  0.0488,  0.1182,  ...,  0.3882,  0.0096, -0.3993],
         [-0.0507, -0.2725,  0.2834,  ...,  0.3651, -0.1424, -0.1851],
         ...,
         [ 0.0452, -0.2178,  0.1465,  ...,  0.5364,  0.0636,  0.0314],
         [-0.0731, -0.2751,  0.1428,  ...,  0.4189,  0.0864, -0.0253],
         [-0.0460, -0.1246,  0.0665,  ...,  0.2675,  0.0128, -0.0034]],

        [[ 0.0830, -0.2257,  0.1729,  ...,  0.1961, -0.0066, -0.0435],
         [ 0.0683,  0.0006,  0.0440,  ...,  0

tensor([[[ 0.1677, -0.3820, -0.0492,  ...,  0.1233,  0.0060,  0.2301],
         [-0.0283, -0.1474, -0.0972,  ...,  0.2821, -0.0830,  0.2872],
         [-0.0407, -0.2770, -0.0287,  ...,  0.0991,  0.0136,  0.0887],
         ...,
         [ 0.0183, -0.3213,  0.0704,  ...,  0.1022,  0.1266,  0.2126],
         [-0.0991, -0.4379,  0.0442,  ...,  0.1198,  0.0782,  0.1792],
         [-0.0964, -0.2736, -0.0284,  ..., -0.0335,  0.1967,  0.0839]],

        [[ 0.1056, -0.3397, -0.1945,  ...,  0.0865, -0.0295,  0.2547],
         [ 0.1143,  0.0812, -0.1228,  ...,  0.2431, -0.1526,  0.1660],
         [-0.1570, -0.4456, -0.3144,  ...,  0.1128, -0.1313,  0.2264],
         ...,
         [-0.0605, -0.3113, -0.0284,  ...,  0.1239,  0.1447,  0.2374],
         [-0.1450, -0.4679, -0.0364,  ...,  0.1303, -0.0275,  0.1990],
         [-0.1205, -0.2969, -0.0778,  ...,  0.0111,  0.1807,  0.0770]],

        [[ 0.0714, -0.3420, -0.1227,  ...,  0.0689, -0.0481,  0.2323],
         [-0.1376, -0.2106, -0.0366,  ...,  0

tensor([[[ 0.1411, -0.1372, -0.0314,  ...,  0.2695, -0.1070, -0.0109],
         [-0.0295,  0.0999, -0.0739,  ...,  0.1489, -0.1684,  0.0626],
         [ 0.0470, -0.0950,  0.1614,  ..., -0.1175,  0.0393,  0.0353],
         ...,
         [ 0.1613, -0.1208, -0.1941,  ...,  0.2290,  0.0303,  0.0192],
         [ 0.0613, -0.2023,  0.0450,  ...,  0.2741,  0.0334, -0.0413],
         [ 0.0306, -0.0550,  0.0762,  ..., -0.0017,  0.0216,  0.0534]],

        [[ 0.1229, -0.1203, -0.0161,  ...,  0.2823, -0.0764, -0.0635],
         [-0.0558, -0.0546,  0.0865,  ...,  0.2122, -0.1531,  0.1290],
         [ 0.2678, -0.0022, -0.0853,  ...,  0.0629, -0.1963,  0.1138],
         ...,
         [ 0.0978, -0.1460, -0.1963,  ...,  0.2401,  0.0699, -0.0072],
         [ 0.0594, -0.1828,  0.0373,  ...,  0.2952,  0.1083,  0.0017],
         [-0.0241, -0.0464,  0.0295,  ..., -0.0479,  0.0185,  0.0632]],

        [[ 0.1647, -0.0566, -0.0571,  ...,  0.2441, -0.1315, -0.0068],
         [ 0.0615,  0.0299,  0.1377,  ...,  0

tensor([[[-0.0796, -0.1377, -0.0220,  ..., -0.0006, -0.0880, -0.1025],
         [-0.0618,  0.3538,  0.0750,  ...,  0.1033, -0.1925, -0.1647],
         [-0.2664, -0.0503, -0.2670,  ..., -0.0218,  0.0182, -0.0491],
         ...,
         [-0.0972, -0.0414,  0.0989,  ...,  0.2245,  0.0828, -0.0796],
         [-0.1834, -0.0739,  0.0489,  ...,  0.1377, -0.0304, -0.0140],
         [-0.1074,  0.0219,  0.1377,  ...,  0.0282,  0.0650,  0.0747]],

        [[-0.0540, -0.2212, -0.0318,  ...,  0.0526, -0.1047, -0.1295],
         [-0.1501,  0.2774, -0.1594,  ...,  0.0679, -0.0937, -0.0457],
         [-0.0802,  0.1135, -0.0797,  ..., -0.0621,  0.0824, -0.0972],
         ...,
         [-0.1062, -0.0670,  0.0251,  ...,  0.2813, -0.0169, -0.1020],
         [-0.2464, -0.1051, -0.0283,  ...,  0.1882, -0.0931, -0.0487],
         [-0.1361, -0.0219,  0.1007,  ...,  0.0886,  0.0146,  0.0377]],

        [[-0.0410, -0.1742,  0.0135,  ...,  0.0038, -0.1182, -0.1756],
         [-0.1483,  0.1480,  0.0547,  ...,  0

tensor([[[-0.0064, -0.1976, -0.0508,  ...,  0.1895, -0.1761, -0.0539],
         [-0.0804,  0.2218,  0.0522,  ...,  0.2906, -0.0928, -0.1462],
         [-0.1498,  0.0622,  0.1115,  ...,  0.4122,  0.0083,  0.0810],
         ...,
         [-0.0414, -0.2400,  0.0853,  ...,  0.3527,  0.0613,  0.1510],
         [-0.1240, -0.1707, -0.0008,  ...,  0.4049,  0.1369,  0.1373],
         [-0.1200, -0.0365,  0.0537,  ...,  0.2137,  0.0332,  0.1777]],

        [[-0.0231, -0.1696, -0.0270,  ...,  0.1849, -0.1894, -0.0620],
         [-0.1150,  0.2228,  0.0594,  ...,  0.3113, -0.1243, -0.1165],
         [-0.1054,  0.0117,  0.1254,  ...,  0.3592, -0.0118,  0.0171],
         ...,
         [-0.0381, -0.2335,  0.1139,  ...,  0.3796,  0.0794,  0.1371],
         [-0.0946, -0.1465,  0.0072,  ...,  0.3492,  0.1498,  0.1217],
         [-0.0887,  0.0244,  0.0854,  ...,  0.2122,  0.0457,  0.1868]],

        [[ 0.0095, -0.0950, -0.0226,  ...,  0.1664, -0.1362, -0.0228],
         [ 0.1224,  0.2639,  0.0514,  ...,  0

tensor([[[ 0.0761, -0.0776, -0.1184,  ...,  0.0834, -0.1833,  0.0674],
         [-0.1349,  0.2050,  0.0013,  ...,  0.1101, -0.2160, -0.1052],
         [ 0.0838,  0.0227,  0.1466,  ..., -0.2395, -0.0506, -0.0339],
         ...,
         [-0.1286, -0.0944,  0.0360,  ...,  0.2917,  0.0611,  0.0667],
         [-0.0853, -0.1793,  0.0249,  ...,  0.2980,  0.0517,  0.0929],
         [-0.1090,  0.0060,  0.0915,  ...,  0.0262, -0.1156,  0.1100]],

        [[ 0.0760, -0.0670, -0.0691,  ...,  0.0461, -0.1539,  0.0258],
         [-0.1697,  0.1194,  0.0956,  ...,  0.0984, -0.4713, -0.0552],
         [-0.2063, -0.1033, -0.1606,  ...,  0.0149, -0.2235, -0.0071],
         ...,
         [-0.1693, -0.0903,  0.1169,  ...,  0.3013,  0.0390,  0.0403],
         [-0.1068, -0.1695,  0.0596,  ...,  0.2375,  0.0517,  0.0435],
         [-0.1758,  0.0277,  0.1146,  ...,  0.0442, -0.1303,  0.0252]],

        [[ 0.0453, -0.0956, -0.0691,  ...,  0.0939, -0.0858,  0.1160],
         [-0.2191,  0.0781,  0.0976,  ...,  0

tensor([[[-6.8274e-02, -3.2082e-01, -1.4895e-02,  ...,  3.0697e-01,
          -4.2640e-02,  1.2016e-01],
         [-1.2217e-01, -4.9744e-02,  5.0575e-02,  ...,  4.2570e-01,
          -1.2324e-01, -2.6881e-02],
         [ 1.0089e-01, -7.8158e-02, -8.1725e-02,  ...,  3.0066e-01,
          -1.4350e-01,  2.9826e-02],
         ...,
         [-1.0182e-01, -2.0731e-01, -1.7366e-02,  ...,  4.6255e-01,
           1.7844e-01,  1.9467e-02],
         [-3.3152e-02, -2.9145e-01, -6.2501e-03,  ...,  3.0331e-01,
           5.1746e-02,  4.2804e-02],
         [-6.3194e-02, -3.4436e-02,  2.1674e-02,  ...,  2.7399e-01,
          -2.4804e-02,  9.2161e-02]],

        [[ 2.2895e-02, -2.5796e-01, -1.5967e-02,  ...,  3.0287e-01,
          -2.0217e-02,  9.4671e-02],
         [ 8.4048e-02,  2.6598e-02,  1.2824e-01,  ...,  2.5619e-01,
          -3.6444e-01, -9.5309e-02],
         [-1.3775e-01, -1.9209e-01, -1.5498e-01,  ...,  1.3277e-01,
          -3.5652e-01,  9.1380e-02],
         ...,
         [-7.9090e-02, -2

tensor([[[ 0.0132, -0.2890,  0.0948,  ...,  0.1585, -0.1522, -0.1963],
         [-0.2463,  0.0857,  0.0664,  ..., -0.0141, -0.3573, -0.2604],
         [-0.1395, -0.1468,  0.1171,  ...,  0.1262, -0.0986, -0.1463],
         ...,
         [-0.1981, -0.1810,  0.1621,  ...,  0.1890,  0.0512, -0.1827],
         [-0.1720, -0.1450,  0.0808,  ...,  0.1567,  0.0308, -0.1492],
         [-0.2357,  0.0668,  0.1830,  ...,  0.0319,  0.0022, -0.0702]],

        [[ 0.0328, -0.2987,  0.0709,  ...,  0.1171, -0.1673, -0.0806],
         [-0.0466,  0.0376,  0.0438,  ...,  0.0275,  0.0335, -0.3328],
         [-0.1997, -0.0642,  0.0097,  ...,  0.1383, -0.1184, -0.0558],
         ...,
         [-0.1129, -0.2190,  0.1278,  ...,  0.1759,  0.0963, -0.0956],
         [-0.1687, -0.2590,  0.0270,  ...,  0.1043,  0.0409, -0.0606],
         [-0.2433, -0.0077,  0.1321,  ..., -0.0061,  0.0032, -0.0145]],

        [[ 0.0081, -0.2428,  0.0573,  ...,  0.0910, -0.0982, -0.0709],
         [-0.0153,  0.0936,  0.0868,  ...,  0

tensor([[[ 0.2179, -0.1513, -0.0337,  ...,  0.0844,  0.2017, -0.3174],
         [ 0.1400,  0.0892, -0.0190,  ...,  0.1722,  0.2107, -0.5027],
         [ 0.0247, -0.0089,  0.1306,  ...,  0.3167,  0.1277, -0.2703],
         ...,
         [ 0.0257, -0.1359,  0.1676,  ...,  0.4087,  0.3017, -0.1448],
         [ 0.0198, -0.1463,  0.0601,  ...,  0.2701,  0.1452, -0.1454],
         [-0.0319, -0.0179,  0.1892,  ...,  0.2144,  0.2187, -0.1279]],

        [[ 0.2179, -0.1298, -0.0656,  ...,  0.1117,  0.1981, -0.2871],
         [ 0.1368,  0.0711, -0.0268,  ...,  0.1877,  0.2423, -0.5035],
         [ 0.0430,  0.0033,  0.1206,  ...,  0.3359,  0.1185, -0.2895],
         ...,
         [ 0.0190, -0.1431,  0.1729,  ...,  0.4083,  0.2908, -0.1140],
         [ 0.0693, -0.1689,  0.1015,  ...,  0.2881,  0.1860, -0.1080],
         [-0.0020, -0.0301,  0.1573,  ...,  0.2333,  0.2264, -0.0691]],

        [[ 0.2549, -0.2043,  0.0036,  ...,  0.1252,  0.2510, -0.2503],
         [ 0.1871,  0.3031, -0.0483,  ...,  0

tensor([[[ 0.0543, -0.2626,  0.1237,  ...,  0.1840,  0.0424, -0.0425],
         [-0.3165,  0.1135, -0.0239,  ...,  0.3232, -0.0833,  0.0032],
         [-0.0574, -0.2438,  0.2758,  ...,  0.1790,  0.0669, -0.0557],
         ...,
         [-0.0408, -0.3138,  0.1867,  ...,  0.2901,  0.1347, -0.0535],
         [-0.1177, -0.2898,  0.2610,  ...,  0.2866,  0.1062, -0.1293],
         [-0.2311, -0.1763,  0.2788,  ...,  0.1036,  0.1399, -0.0868]],

        [[ 0.0204, -0.2163,  0.0932,  ...,  0.2209,  0.0778,  0.0057],
         [-0.1138, -0.1413,  0.0924,  ...,  0.1419, -0.1356,  0.0236],
         [-0.0355, -0.1733,  0.2487,  ...,  0.2184,  0.0832, -0.0437],
         ...,
         [-0.0143, -0.3116,  0.1898,  ...,  0.3165,  0.1840, -0.0506],
         [-0.0810, -0.2470,  0.2554,  ...,  0.3171,  0.1089, -0.0871],
         [-0.2327, -0.1537,  0.2531,  ...,  0.0950,  0.1528, -0.0334]],

        [[ 0.0116, -0.2121,  0.0868,  ...,  0.1679,  0.0568, -0.0294],
         [-0.0301,  0.0443,  0.2164,  ...,  0

tensor([[[ 0.0466, -0.1964,  0.1279,  ...,  0.0996, -0.0910, -0.1481],
         [-0.1336, -0.0921,  0.2153,  ...,  0.1773, -0.1294, -0.2495],
         [-0.0039, -0.1736, -0.1072,  ...,  0.2478, -0.0373, -0.1234],
         ...,
         [-0.0270, -0.1299,  0.0554,  ...,  0.3843, -0.0981,  0.0407],
         [-0.0253, -0.1749,  0.1145,  ...,  0.2397,  0.0481, -0.0845],
         [-0.0199, -0.0077,  0.2677,  ..., -0.0066, -0.0932, -0.0771]],

        [[ 0.0139, -0.2322,  0.1221,  ...,  0.1667, -0.0410, -0.1030],
         [-0.2005,  0.1084,  0.0740,  ...,  0.3660, -0.0170, -0.1463],
         [-0.1535, -0.1370,  0.2042,  ...,  0.0788,  0.1216, -0.2015],
         ...,
         [-0.1062, -0.1289,  0.0142,  ...,  0.3909,  0.0070, -0.0011],
         [-0.1124, -0.1550,  0.1352,  ...,  0.3026,  0.0971, -0.1318],
         [-0.0744,  0.0104,  0.2335,  ...,  0.0377, -0.0133, -0.1052]],

        [[ 0.0206, -0.1675,  0.1575,  ...,  0.0756, -0.0614, -0.1134],
         [-0.1301,  0.0797,  0.2169,  ..., -0

tensor([[[ 0.1131, -0.3152,  0.2735,  ...,  0.1556,  0.0267, -0.0254],
         [ 0.1222,  0.0615,  0.1830,  ...,  0.2340, -0.0020, -0.1682],
         [-0.2182, -0.2904, -0.0342,  ...,  0.2082, -0.1400, -0.0525],
         ...,
         [ 0.0224, -0.3025,  0.0349,  ...,  0.3775,  0.1065, -0.0016],
         [-0.1093, -0.3953,  0.1523,  ...,  0.2873,  0.1068, -0.0546],
         [-0.1364, -0.1688,  0.2950,  ...,  0.0181,  0.1558,  0.0478]],

        [[ 0.0662, -0.3067,  0.2337,  ...,  0.1061, -0.0333, -0.0280],
         [-0.1408, -0.0941,  0.1275,  ...,  0.2038, -0.0921, -0.0751],
         [ 0.0709, -0.2088,  0.2334,  ...,  0.1054,  0.0330,  0.0779],
         ...,
         [ 0.0419, -0.2678,  0.0161,  ...,  0.3664,  0.1048,  0.0512],
         [-0.0823, -0.3693,  0.1182,  ...,  0.2732,  0.1114, -0.0457],
         [-0.0987, -0.1466,  0.2762,  ..., -0.0403,  0.1420,  0.0307]],

        [[ 0.1393, -0.2264,  0.1819,  ...,  0.1383,  0.0167,  0.0323],
         [-0.0971, -0.0928,  0.1176,  ...,  0

tensor([[[ 9.2673e-02, -2.6195e-01,  5.1705e-03,  ...,  2.3907e-01,
          -1.2316e-01, -2.2633e-01],
         [-2.0902e-01,  1.1541e-01, -4.7725e-02,  ...,  3.4712e-01,
          -1.0372e-01, -3.7092e-02],
         [-8.3082e-03, -1.5274e-01,  1.4915e-01,  ...,  2.8036e-01,
           6.4933e-02, -1.8998e-04],
         ...,
         [ 1.1597e-01, -2.6088e-01,  9.9757e-02,  ...,  4.2978e-01,
           1.4194e-01, -2.2763e-02],
         [-9.0227e-03, -2.7086e-01,  5.6613e-02,  ...,  3.1089e-01,
           9.6667e-02, -7.9355e-02],
         [-4.9407e-02, -1.4334e-01,  1.5360e-01,  ...,  1.7438e-01,
           7.1416e-02,  3.1772e-02]],

        [[ 1.2488e-01, -2.0879e-01,  1.9218e-02,  ...,  2.4137e-01,
          -1.6913e-01, -2.1739e-01],
         [-5.4697e-02, -3.7541e-02,  1.7050e-02,  ...,  1.6674e-01,
          -1.5448e-01, -1.5799e-01],
         [ 2.1387e-01, -1.8576e-01, -5.2280e-02,  ...,  3.1974e-01,
          -1.9781e-01,  1.2985e-01],
         ...,
         [ 1.2134e-01, -2

tensor([[[ 0.2260, -0.0739,  0.0547,  ...,  0.0698, -0.0018, -0.1256],
         [ 0.2684,  0.3490,  0.1671,  ...,  0.1235, -0.2265, -0.1530],
         [ 0.0946, -0.1553, -0.1151,  ...,  0.2434, -0.1447,  0.1129],
         ...,
         [-0.0156, -0.0720,  0.0961,  ...,  0.3600, -0.1281,  0.0834],
         [-0.0392, -0.2160,  0.1334,  ...,  0.1560,  0.0158,  0.0255],
         [ 0.0567,  0.1385,  0.1081,  ..., -0.0127,  0.0130,  0.1460]],

        [[ 0.2496, -0.1005,  0.0563,  ...,  0.1009,  0.0143, -0.1157],
         [ 0.0994,  0.4104, -0.0225,  ...,  0.2817, -0.0634, -0.1341],
         [ 0.0747, -0.0036,  0.2176,  ..., -0.0845,  0.1761,  0.0317],
         ...,
         [-0.0272, -0.1668,  0.1067,  ...,  0.3127, -0.0628,  0.1142],
         [-0.0470, -0.2207,  0.1470,  ...,  0.1435,  0.0989,  0.0012],
         [ 0.0359,  0.1174,  0.1427,  ..., -0.0729,  0.0716,  0.1362]],

        [[ 0.1819, -0.0791,  0.1002,  ...,  0.0330,  0.0308, -0.1391],
         [ 0.1374,  0.2403,  0.1827,  ...,  0

tensor([[[ 0.0617, -0.0989,  0.1726,  ...,  0.1107, -0.2600,  0.2069],
         [-0.1646, -0.0498,  0.2059,  ...,  0.1010, -0.0601, -0.3312],
         [-0.0921, -0.2114,  0.2496,  ...,  0.1340, -0.1207,  0.0533],
         ...,
         [-0.1473, -0.0429,  0.1920,  ...,  0.2604,  0.0195,  0.0847],
         [-0.0810, -0.2570,  0.2142,  ...,  0.2221, -0.0065,  0.0452],
         [-0.1502, -0.0646,  0.2937,  ...,  0.0333, -0.1440,  0.0929]],

        [[ 0.0583, -0.1435,  0.1939,  ...,  0.1007, -0.2757,  0.2968],
         [-0.1591,  0.0408,  0.2813,  ...,  0.3067, -0.3101,  0.0293],
         [-0.2421, -0.5430,  0.0997,  ...,  0.1078, -0.2788,  0.1041],
         ...,
         [-0.1530, -0.0832,  0.1663,  ...,  0.2761, -0.0244,  0.1362],
         [-0.1008, -0.2377,  0.2033,  ...,  0.2176, -0.0412,  0.0760],
         [-0.1042, -0.0496,  0.2936,  ...,  0.0420, -0.1472,  0.1661]],

        [[ 0.0701, -0.1342,  0.2242,  ...,  0.0751, -0.3042,  0.2293],
         [-0.1446, -0.0249,  0.2424,  ...,  0

tensor([[[ 0.1648, -0.2727,  0.0036,  ...,  0.0646, -0.1054,  0.1036],
         [ 0.0205, -0.1141,  0.0898,  ...,  0.0680, -0.2057,  0.0166],
         [-0.0029, -0.2423,  0.2587,  ...,  0.0414, -0.2036,  0.0279],
         ...,
         [ 0.0435, -0.2246,  0.1854,  ...,  0.4358,  0.1669,  0.0203],
         [-0.0163, -0.2006,  0.1881,  ...,  0.3148,  0.0005,  0.1323],
         [-0.0407, -0.1356,  0.1462,  ...,  0.0865,  0.0067,  0.2025]],

        [[ 0.2479, -0.2244,  0.0043,  ...,  0.0576, -0.0893,  0.0822],
         [-0.1178,  0.1808,  0.0324,  ...,  0.2059, -0.2300, -0.0536],
         [ 0.0789, -0.1485,  0.2524,  ...,  0.0451, -0.1658, -0.0192],
         ...,
         [ 0.0988, -0.1996,  0.1373,  ...,  0.4381,  0.2129, -0.0245],
         [ 0.0142, -0.1323,  0.1871,  ...,  0.3543,  0.0120,  0.1548],
         [ 0.0215, -0.0728,  0.1267,  ...,  0.0870,  0.0346,  0.1511]],

        [[ 0.1681, -0.2488,  0.0512,  ...,  0.0247, -0.1116,  0.0561],
         [ 0.0095, -0.1352,  0.1440,  ...,  0

tensor([[[-8.7632e-03, -1.6017e-01,  8.1358e-02,  ...,  9.7872e-02,
          -1.0130e-01,  1.1957e-02],
         [-8.4359e-02,  1.7903e-01,  9.7752e-02,  ...,  1.0827e-01,
          -1.8640e-01, -1.9352e-01],
         [ 4.8461e-02, -1.3649e-01,  3.0322e-01,  ..., -1.2828e-01,
          -9.4421e-02, -1.3642e-01],
         ...,
         [-3.6557e-02, -7.1078e-03, -1.1275e-02,  ...,  2.1074e-01,
           1.1065e-02,  2.2199e-02],
         [-1.8535e-01, -5.8573e-02,  6.1404e-02,  ...,  1.2760e-01,
           8.2304e-02,  2.1840e-03],
         [-3.0428e-02,  8.1287e-02,  9.8858e-02,  ..., -2.9546e-02,
          -7.8558e-02,  3.8159e-02]],

        [[ 3.8293e-02, -1.9995e-01,  1.9744e-01,  ...,  2.1463e-01,
          -2.0967e-03,  8.3197e-03],
         [-4.6095e-02,  1.4669e-01,  2.1485e-01,  ...,  5.1816e-02,
          -5.3858e-02, -5.0172e-01],
         [-4.8523e-02,  2.1150e-03,  1.4885e-01,  ...,  6.4637e-02,
          -5.6076e-02, -2.6437e-01],
         ...,
         [ 5.5598e-03, -3

tensor([[[ 2.5633e-03,  7.2281e-02, -5.8094e-02,  ...,  1.4214e-01,
          -2.7562e-02, -1.3139e-02],
         [-1.5733e-01,  5.3319e-02,  1.9053e-02,  ...,  2.0346e-01,
          -1.1817e-03, -3.0318e-01],
         [-1.9130e-01,  2.3100e-01,  1.8625e-01,  ...,  2.7121e-01,
          -2.0302e-02, -1.1086e-01],
         ...,
         [-1.8614e-01, -3.3860e-02, -1.0118e-01,  ...,  3.6201e-01,
          -6.1526e-02, -2.1538e-02],
         [-1.6624e-01, -1.2877e-01, -7.6986e-03,  ...,  2.7588e-01,
           4.3461e-02, -1.3253e-01],
         [-1.0439e-01, -1.4590e-03,  1.4926e-01,  ...,  1.7357e-01,
           1.2248e-01,  7.6126e-04]],

        [[-6.1758e-02,  7.2042e-02, -6.0211e-02,  ...,  1.9822e-01,
           8.1102e-04,  2.0063e-02],
         [-8.3781e-02,  2.1316e-01, -7.9377e-02,  ...,  3.6137e-01,
          -1.3201e-01, -1.1228e-01],
         [-2.3147e-01, -1.0091e-01, -1.6624e-01,  ...,  1.7264e-01,
          -7.2025e-02,  9.0625e-02],
         ...,
         [-2.4478e-01, -9

tensor([[[ 0.0460, -0.1356, -0.0612,  ...,  0.4352,  0.0543,  0.2508],
         [-0.2067,  0.0244,  0.0523,  ...,  0.2544, -0.1843,  0.3368],
         [-0.1297, -0.0642, -0.0804,  ...,  0.3387, -0.2025,  0.2722],
         ...,
         [-0.1653, -0.2078,  0.1289,  ...,  0.5194,  0.1080,  0.3717],
         [-0.1410, -0.3811,  0.1130,  ...,  0.2396,  0.1415,  0.2535],
         [-0.0545, -0.1844,  0.2115,  ...,  0.1259,  0.0476,  0.3074]],

        [[-0.0165, -0.2216, -0.0338,  ...,  0.3738,  0.0303,  0.3137],
         [-0.1035,  0.0479, -0.0171,  ...,  0.2662,  0.0586, -0.0046],
         [-0.3325,  0.0154,  0.1202,  ...,  0.2913, -0.0425,  0.2897],
         ...,
         [-0.1430, -0.2461,  0.1130,  ...,  0.4500,  0.0884,  0.3987],
         [-0.1319, -0.4137,  0.0918,  ...,  0.2417,  0.1416,  0.3350],
         [-0.0735, -0.2050,  0.2082,  ...,  0.1198,  0.0318,  0.3171]],

        [[ 0.0290, -0.2115, -0.0323,  ...,  0.4306,  0.0700,  0.3280],
         [-0.1288,  0.0696, -0.0119,  ...,  0

tensor([[[-0.0782, -0.1008, -0.1271,  ...,  0.1056,  0.0543, -0.0363],
         [-0.0815,  0.1191,  0.1008,  ...,  0.1939,  0.2176, -0.1776],
         [-0.2148,  0.1560,  0.0015,  ...,  0.3817, -0.0493,  0.1195],
         ...,
         [-0.1749, -0.0303,  0.0568,  ...,  0.5326,  0.1834,  0.0846],
         [-0.1299, -0.0113,  0.0056,  ...,  0.3840,  0.1816,  0.1026],
         [-0.1935,  0.0484,  0.1649,  ...,  0.1284,  0.1671,  0.1768]],

        [[-0.0337, -0.0793, -0.1163,  ...,  0.0950,  0.0821, -0.0233],
         [-0.0923,  0.1211,  0.1015,  ...,  0.2910, -0.0601,  0.1101],
         [-0.0267,  0.0352, -0.2182,  ...,  0.3038, -0.0759,  0.2099],
         ...,
         [-0.1475,  0.0261,  0.0043,  ...,  0.4892,  0.2533,  0.0763],
         [-0.1842, -0.0012, -0.0534,  ...,  0.3918,  0.1615,  0.0208],
         [-0.2439,  0.0909,  0.1436,  ...,  0.1299,  0.1856,  0.1512]],

        [[-0.0116, -0.0480, -0.1602,  ...,  0.1412,  0.0985, -0.0627],
         [-0.0903,  0.1371,  0.0924,  ...,  0

tensor([[[-1.0651e-01, -8.8798e-02, -2.4643e-01,  ...,  1.1754e-02,
          -1.6532e-01, -1.6231e-01],
         [-1.0491e-02,  7.3798e-02,  8.3683e-04,  ..., -3.1797e-02,
          -1.8008e-02, -2.7690e-01],
         [ 3.9659e-02, -1.0480e-01, -2.7016e-01,  ..., -1.2353e-02,
           1.2772e-02,  1.0905e-01],
         ...,
         [-1.3035e-01, -7.8216e-02, -1.4404e-01,  ...,  2.6211e-01,
           1.2661e-01, -1.6372e-01],
         [-1.6769e-01, -1.4377e-01, -1.1074e-01,  ...,  8.3100e-02,
           8.6463e-02, -2.1527e-01],
         [-1.1456e-01,  5.0513e-02, -5.0107e-02,  ..., -3.9367e-02,
          -4.3474e-02, -2.4885e-02]],

        [[-5.9823e-02, -7.1338e-02, -2.0537e-01,  ...,  2.6316e-02,
          -1.2417e-01, -1.4198e-01],
         [ 2.8619e-02,  1.6930e-01, -7.5020e-02,  ..., -8.0031e-02,
           2.1987e-01, -4.6094e-01],
         [ 1.1650e-02,  1.1247e-01,  9.1824e-02,  ...,  7.5888e-02,
           6.2468e-02, -4.3828e-02],
         ...,
         [-1.0930e-01, -9

tensor([[[ 0.2348, -0.0569,  0.0329,  ...,  0.3297, -0.0354,  0.1160],
         [ 0.0584,  0.3238,  0.0955,  ...,  0.1192,  0.1471, -0.2081],
         [ 0.0476,  0.1331,  0.0379,  ...,  0.3692,  0.0198,  0.0089],
         ...,
         [ 0.1121,  0.0198,  0.0180,  ...,  0.4342,  0.2477, -0.0093],
         [ 0.0081,  0.0029,  0.0516,  ...,  0.2765,  0.2513, -0.0176],
         [ 0.0431,  0.2031,  0.0015,  ...,  0.1296,  0.0925,  0.0906]],

        [[ 0.1937, -0.0963,  0.0848,  ...,  0.2408, -0.0237,  0.0973],
         [-0.0012,  0.2785,  0.1057,  ...,  0.0893,  0.1052, -0.2150],
         [ 0.0463,  0.0688,  0.0473,  ...,  0.3090,  0.0114, -0.0235],
         ...,
         [ 0.0814, -0.0348,  0.0385,  ...,  0.4147,  0.2191,  0.0124],
         [-0.0516, -0.0290,  0.0789,  ...,  0.2606,  0.2556, -0.0048],
         [ 0.0128,  0.1383,  0.0461,  ...,  0.1091,  0.0846,  0.0261]],

        [[ 0.1755, -0.1098,  0.0512,  ...,  0.3547, -0.0277,  0.0789],
         [ 0.0370,  0.2546,  0.0207,  ...,  0

tensor([[[ 0.2560, -0.1887,  0.1807,  ...,  0.2776,  0.1126,  0.1499],
         [ 0.1247, -0.0029,  0.1035,  ...,  0.1349,  0.1147, -0.3050],
         [ 0.0577,  0.0747,  0.2968,  ...,  0.2632,  0.1710,  0.0561],
         ...,
         [ 0.1085, -0.0378,  0.2449,  ...,  0.4581,  0.2337,  0.2805],
         [ 0.1371, -0.0924,  0.1857,  ...,  0.2887,  0.1769,  0.2249],
         [ 0.0555, -0.0471,  0.2212,  ...,  0.1414,  0.2097,  0.2702]],

        [[ 0.2613, -0.1208,  0.1919,  ...,  0.3074,  0.1165,  0.0638],
         [ 0.1173,  0.0393,  0.1717,  ...,  0.2948,  0.1221,  0.1216],
         [ 0.1597,  0.0229,  0.0513,  ...,  0.3463,  0.0996,  0.2581],
         ...,
         [ 0.1254, -0.0112,  0.2815,  ...,  0.4660,  0.2555,  0.2151],
         [ 0.1466, -0.0320,  0.2444,  ...,  0.3226,  0.2392,  0.1953],
         [ 0.0828, -0.0064,  0.2795,  ...,  0.1317,  0.2207,  0.2036]],

        [[ 0.2436, -0.1300,  0.1778,  ...,  0.2470,  0.1275,  0.1154],
         [ 0.1231,  0.0693,  0.1106,  ...,  0

tensor([[[ 2.0798e-01, -2.6018e-01,  5.4183e-02,  ...,  3.8158e-01,
          -1.9068e-01, -1.9797e-01],
         [ 1.5611e-01,  3.8748e-02,  3.2835e-02,  ...,  2.3012e-01,
          -1.3268e-01, -3.6332e-01],
         [ 6.2250e-03, -5.3776e-03,  1.9196e-01,  ...,  3.0967e-01,
          -1.2338e-01, -1.4690e-01],
         ...,
         [ 1.8054e-01, -9.2885e-02,  1.8180e-01,  ...,  3.8511e-01,
           1.8601e-02,  1.1206e-01],
         [ 6.6723e-02, -2.2268e-01,  6.2635e-02,  ...,  4.5882e-01,
          -1.6275e-01,  1.5735e-02],
         [ 1.3116e-01, -3.6556e-02,  1.9475e-01,  ...,  1.2971e-01,
          -5.2045e-02,  1.0782e-01]],

        [[ 2.1957e-01, -2.5796e-01,  2.9608e-02,  ...,  4.0469e-01,
          -2.3202e-01, -2.3282e-01],
         [ 1.8120e-01,  2.9476e-01,  1.5050e-01,  ...,  3.2137e-01,
          -3.2692e-01, -1.1573e-01],
         [-1.2418e-02, -1.0271e-01, -8.6283e-02,  ...,  2.8006e-01,
          -3.2672e-01, -4.6599e-02],
         ...,
         [ 1.3169e-01, -1

tensor([[[-0.1083, -0.2182, -0.1177,  ...,  0.1141,  0.0130, -0.1704],
         [-0.0433, -0.0292,  0.0031,  ...,  0.0691,  0.1025, -0.4166],
         [-0.2037, -0.1899,  0.0326,  ...,  0.1426,  0.0199, -0.1203],
         ...,
         [-0.0769, -0.2080, -0.0418,  ...,  0.3477,  0.0533, -0.0094],
         [-0.1125, -0.2281, -0.0958,  ...,  0.2939,  0.1122, -0.0127],
         [-0.0579, -0.1429,  0.0222,  ...,  0.1403,  0.1062,  0.1237]],

        [[-0.1297, -0.2382, -0.1021,  ...,  0.1514, -0.0441, -0.1769],
         [-0.0936,  0.0090,  0.0811,  ...,  0.2359,  0.0082, -0.0901],
         [-0.0556, -0.2277,  0.2669,  ..., -0.0532,  0.1507, -0.0044],
         ...,
         [-0.1341, -0.2193, -0.0350,  ...,  0.3146,  0.0937, -0.0320],
         [-0.0910, -0.2710, -0.1213,  ...,  0.2792,  0.1271, -0.0314],
         [-0.0795, -0.1724,  0.0088,  ...,  0.1344,  0.0619,  0.1350]],

        [[-0.1080, -0.2391, -0.1330,  ...,  0.1127, -0.0392, -0.1819],
         [-0.0622,  0.2465, -0.1075,  ...,  0

In [36]:
#输出对比测试结果
import os
import time

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def print_time_of_model(model):    
    eval_start_time = time.time()
    model(input_ids,segment_ids, masked_pos)
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

    
print_size_of_model(model)
print_size_of_model(quantized_model)

print_time_of_model(model)
print_time_of_model(quantized_model)

Size (MB): 160.844839
Size (MB): 40.572431
tensor([[[ 0.1075, -0.2140,  0.2019,  ...,  0.0748,  0.0040, -0.0091],
         [-0.2135,  0.0146,  0.2145,  ...,  0.3203,  0.0187,  0.0302],
         [-0.0475, -0.0472,  0.4097,  ...,  0.0062,  0.1150,  0.1236],
         ...,
         [ 0.0762, -0.1692,  0.3838,  ...,  0.2508,  0.1579, -0.0314],
         [ 0.0115, -0.2257,  0.4501,  ...,  0.0608,  0.1215, -0.0351],
         [ 0.0366,  0.0608,  0.4477,  ...,  0.0994,  0.0856,  0.1066]],

        [[ 0.1073, -0.2480,  0.2603,  ...,  0.0284, -0.0204, -0.0872],
         [-0.0607,  0.0438,  0.3346,  ...,  0.1798,  0.3206, -0.3581],
         [-0.0566, -0.0643,  0.4379,  ...,  0.1063, -0.0826, -0.0430],
         ...,
         [ 0.0905, -0.1859,  0.3882,  ...,  0.1841,  0.2086, -0.0687],
         [ 0.0583, -0.2357,  0.4723,  ...,  0.0506,  0.1448, -0.0780],
         [ 0.0781,  0.0565,  0.4837,  ...,  0.0984,  0.1535,  0.0639]],

        [[ 0.0547, -0.1746,  0.1964,  ...,  0.0904,  0.0060, -0.0171],
  