In [1]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
         https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split())) # ['hello', 'how', 'are', 'you',...]
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {i: w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)

token_list = list()
for sentence in sentences:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)


In [2]:
# BERT Parameters
maxlen = 30
batch_size = 6
max_pred = 5 # max tokens of prediction
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4 # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

In [3]:
# sample IsNext and NotNext to be same in small batch size
def make_data():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # MASK LM
        n_pred =  min(max_pred, max(1, int(len(input_ids) * 0.15))) # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']] # candidate masked position
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word2idx['[MASK]'] # make mask
            elif random() > 0.9:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                while index < 4: # can't involve 'CLS', 'SEP', 'PAD'
                  index = randint(0, vocab_size - 1)
                input_ids[pos] = index # replace

        # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
    return batch
# Proprecessing Finished

In [14]:
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens),\
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)

class MyDataSet(Data.Dataset):
  def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
    self.input_ids = input_ids
    self.segment_ids = segment_ids
    self.masked_tokens = masked_tokens
    self.masked_pos = masked_pos
    self.isNext = isNext
  
  def __len__(self):
    return len(self.input_ids)
  
  def __getitem__(self, idx):
    return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]

loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

In [5]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, seq_len = seq_q.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)  # [batch_size, 1, seq_len]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]

def gelu(x):
    """
      Implementation of the gelu activation function.
      For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
      Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, seq_len, seq_len]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size, seq_len, d_model], k: [batch_size, seq_len, d_model], v: [batch_size, seq_len, d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size, n_heads, seq_len, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size, n_heads, seq_len, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size, n_heads, seq_len, d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]

        # context: [batch_size, n_heads, seq_len, d_v], attn: [batch_size, n_heads, seq_len, seq_len]
        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size, seq_len, n_heads, d_v]
        output = nn.Linear(n_heads * d_v, d_model)(context)
        return nn.LayerNorm(d_model)(output + residual) # output: [batch_size, seq_len, d_model]

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        return self.fc2(gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, seq_len, d_model]
        return enc_outputs

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 2)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids) # [bach_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids) # [batch_size, maxlen, maxlen]
        for layer in self.layers:
            # output: [batch_size, max_len, d_model]
            output = layer(output, enc_self_attn_mask)
        # it will be decided by first token(CLS)
        h_pooled = self.fc(output[:, 0]) # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled) # [batch_size, 2] predict isNext

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model) # [batch_size, max_pred, d_model]
        print(output)
        print(masked_pos)
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        print(h_masked)
        h_masked = self.activ2(self.linear(h_masked)) # [batch_size, max_pred, d_model]
        logits_lm = self.fc2(h_masked) # [batch_size, max_pred, vocab_size]
        return logits_lm, logits_clsf
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)

In [6]:
for epoch in range(50):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
      logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
      loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
      loss_lm = (loss_lm.float()).mean()
      loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
      loss = loss_lm + loss_clsf
      if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

tensor([[[-0.1890, -0.2047, -0.2048,  ..., -0.1148,  0.0296, -0.0562],
         [-0.2405, -0.3417, -0.0919,  ..., -0.3263, -0.0249, -0.1323],
         [-0.0160, -0.2199, -0.1080,  ..., -0.2453,  0.1712,  0.1507],
         ...,
         [-0.2875, -0.0638,  0.0447,  ..., -0.3648,  0.2224, -0.1167],
         [-0.2470, -0.2487,  0.0178,  ..., -0.3012,  0.2738, -0.2479],
         [-0.2400, -0.0523,  0.0175,  ..., -0.2968,  0.2168, -0.1821]],

        [[-0.1725, -0.1996, -0.1884,  ..., -0.1130,  0.0326, -0.0812],
         [-0.1072, -0.1677, -0.1199,  ..., -0.2078,  0.0545, -0.2205],
         [-0.2399, -0.2136, -0.0255,  ..., -0.2398,  0.1503, -0.1368],
         ...,
         [-0.2905, -0.0583,  0.0849,  ..., -0.3584,  0.2173, -0.1282],
         [-0.2474, -0.2421,  0.0303,  ..., -0.2975,  0.2728, -0.2805],
         [-0.2062, -0.0566,  0.0550,  ..., -0.2699,  0.2090, -0.1790]],

        [[-0.2201, -0.2065, -0.2235,  ..., -0.1209, -0.0117, -0.0681],
         [ 0.0215, -0.1838, -0.1555,  ..., -0

tensor([[[-1.9830e-01, -1.5718e-01, -1.1011e-01,  ..., -1.8848e-01,
          -1.0198e-01,  6.4003e-02],
         [-1.3842e-01, -4.1150e-02, -5.3377e-02,  ..., -2.9528e-01,
           6.0719e-02,  1.3173e-01],
         [-3.7655e-01, -3.5033e-02,  2.8300e-01,  ..., -4.2525e-01,
           1.9320e-01,  7.1725e-02],
         ...,
         [ 6.5367e-03,  1.2483e-01,  2.2110e-01,  ..., -4.6466e-01,
           1.4870e-01,  1.6419e-01],
         [-1.9873e-01, -6.8287e-02,  3.1001e-02,  ..., -3.1050e-01,
           1.2702e-01,  1.6161e-02],
         [-1.2139e-01,  8.2297e-02,  3.9865e-02,  ..., -3.5388e-01,
           1.1662e-01, -7.7283e-02]],

        [[-2.1384e-01, -1.5208e-01, -1.5504e-01,  ..., -1.5914e-01,
          -8.7663e-02,  7.9434e-02],
         [ 6.2923e-02, -4.0285e-02, -1.7200e-01,  ..., -3.4454e-01,
           5.0803e-02,  1.7967e-01],
         [-1.9627e-01,  3.6522e-02,  6.8572e-02,  ..., -3.9959e-01,
           2.3291e-01,  2.3896e-01],
         ...,
         [ 3.3989e-02,  1

tensor([[[-0.1477, -0.1404, -0.0245,  ..., -0.0996,  0.0876,  0.2169],
         [-0.0323, -0.1452,  0.0961,  ..., -0.2407, -0.0916,  0.0592],
         [-0.2040,  0.0046,  0.1274,  ..., -0.2908,  0.1965,  0.3312],
         ...,
         [-0.1218,  0.1213,  0.1609,  ..., -0.1837,  0.1424,  0.2412],
         [-0.1253,  0.0248,  0.0837,  ..., -0.0895,  0.2290,  0.2060],
         [-0.0606,  0.1323,  0.0519,  ..., -0.3462,  0.0172,  0.1034]],

        [[-0.1362, -0.1076, -0.0134,  ..., -0.1032,  0.0891,  0.2882],
         [-0.1729, -0.2731,  0.0501,  ..., -0.2094,  0.0215,  0.1072],
         [ 0.1400, -0.0529,  0.0329,  ..., -0.0675,  0.0782,  0.3856],
         ...,
         [-0.1375,  0.1609,  0.1891,  ..., -0.1806,  0.1707,  0.2948],
         [-0.0939,  0.0222,  0.0971,  ..., -0.0743,  0.2670,  0.2531],
         [-0.0381,  0.1625,  0.0418,  ..., -0.3359,  0.0820,  0.1273]],

        [[-0.1603, -0.1438, -0.0045,  ..., -0.0973,  0.0994,  0.2157],
         [-0.0188, -0.1363,  0.1210,  ..., -0

tensor([[[ 6.9922e-02, -3.0296e-01,  2.0565e-03,  ..., -1.7813e-01,
          -2.0459e-01,  2.9132e-01],
         [-9.4232e-02, -5.0865e-01,  1.7347e-01,  ..., -1.6538e-01,
          -7.9225e-02,  3.9345e-02],
         [ 5.4752e-02, -2.8704e-01,  5.9012e-02,  ..., -1.1269e-01,
           1.2715e-02,  2.7836e-01],
         ...,
         [ 6.0927e-02,  3.3208e-02,  2.0071e-01,  ..., -1.3122e-01,
           7.4629e-02,  2.4390e-01],
         [-3.8553e-02, -1.0023e-02,  4.8741e-02,  ...,  1.9256e-02,
           1.5721e-01,  2.2546e-01],
         [-3.5276e-02, -1.3512e-01, -3.0656e-02,  ..., -4.2986e-02,
           1.6184e-01,  1.6710e-01]],

        [[ 5.7057e-02, -2.5897e-01,  8.8584e-03,  ..., -1.9942e-01,
          -1.8799e-01,  2.9907e-01],
         [-1.5540e-01, -2.8811e-01,  3.9999e-02,  ..., -1.4086e-01,
           1.0090e-01, -6.5531e-03],
         [-2.4203e-01, -2.4990e-01,  5.5162e-02,  ..., -2.0493e-01,
           2.2449e-02,  1.2748e-01],
         ...,
         [ 6.1451e-02,  6

tensor([[[ 7.6813e-04, -1.9081e-01, -4.4405e-02,  ...,  2.4953e-03,
           5.4715e-02,  8.3438e-02],
         [ 1.6854e-01, -1.6881e-01,  5.9063e-02,  ...,  7.9896e-02,
          -3.9219e-02, -6.3611e-02],
         [-9.0341e-03, -1.1810e-01,  2.8803e-01,  ..., -9.7909e-02,
           2.1439e-01, -3.0443e-02],
         ...,
         [ 1.5825e-01,  1.3871e-02,  4.1992e-01,  ...,  1.7410e-01,
           3.8398e-02,  6.3399e-02],
         [ 7.1310e-02, -2.9374e-02,  1.6768e-01,  ...,  2.0313e-01,
           2.0948e-01, -8.6528e-02],
         [ 1.7504e-01, -3.9283e-02,  1.9454e-01,  ...,  9.8100e-02,
           1.5326e-01, -2.6984e-02]],

        [[ 1.4732e-03, -2.3269e-01, -1.0755e-02,  ..., -1.8513e-03,
          -7.2802e-03,  1.0304e-01],
         [ 1.3736e-01, -2.8581e-01,  1.1084e-01,  ...,  5.5271e-03,
          -1.9314e-03, -8.6556e-02],
         [ 2.5679e-01, -1.6934e-02,  1.7235e-01,  ...,  5.8336e-02,
           1.1586e-01,  6.6626e-02],
         ...,
         [ 1.5223e-01,  4

tensor([[[-0.0875, -0.1747,  0.0854,  ..., -0.0399, -0.2761,  0.2265],
         [-0.0754, -0.0561,  0.1113,  ..., -0.1867, -0.1855,  0.1894],
         [-0.1908, -0.3011,  0.1749,  ..., -0.3622,  0.2075,  0.2237],
         ...,
         [-0.0869, -0.0009,  0.1595,  ..., -0.0563,  0.0918,  0.0587],
         [-0.1013,  0.0007,  0.2102,  ...,  0.0043,  0.2034,  0.0348],
         [ 0.0148,  0.0429,  0.0347,  ..., -0.1359,  0.0841,  0.0472]],

        [[-0.0391, -0.1958,  0.0435,  ..., -0.0686, -0.2702,  0.2061],
         [-0.1163, -0.0783, -0.0496,  ..., -0.1843, -0.2591, -0.0579],
         [-0.1347, -0.1799,  0.2032,  ..., -0.3591,  0.1094,  0.1988],
         ...,
         [-0.0800,  0.0333,  0.1483,  ..., -0.0458,  0.0852,  0.0416],
         [-0.0892,  0.0185,  0.2048,  ...,  0.0097,  0.1721,  0.0363],
         [ 0.0091,  0.0615,  0.0265,  ..., -0.1214,  0.0651,  0.0348]],

        [[-0.0273, -0.1425,  0.0909,  ..., -0.1118, -0.2736,  0.1921],
         [-0.1241, -0.0491, -0.0459,  ..., -0

tensor([[[-0.0397, -0.0912,  0.0024,  ..., -0.2220,  0.0108,  0.1123],
         [ 0.1615, -0.0672, -0.0303,  ..., -0.1917, -0.0800,  0.1081],
         [ 0.0929, -0.1159,  0.0246,  ..., -0.3138,  0.3289,  0.0999],
         ...,
         [ 0.2292, -0.0257,  0.1657,  ..., -0.1294,  0.0477,  0.2110],
         [ 0.0869, -0.1118,  0.0966,  ..., -0.1043,  0.2123,  0.0279],
         [ 0.2478,  0.0690, -0.0292,  ..., -0.1585,  0.1847,  0.1260]],

        [[-0.0727, -0.0996,  0.0210,  ..., -0.2837,  0.0069,  0.1111],
         [ 0.1610, -0.0441, -0.1651,  ..., -0.2043,  0.0019,  0.0937],
         [ 0.0558, -0.1331,  0.0064,  ..., -0.2651,  0.1849,  0.1248],
         ...,
         [ 0.2097, -0.0297,  0.1325,  ..., -0.1528,  0.0518,  0.1899],
         [ 0.0601, -0.1548,  0.0672,  ..., -0.1479,  0.2311,  0.0109],
         [ 0.2305,  0.0425, -0.0740,  ..., -0.1640,  0.2046,  0.1183]],

        [[-0.0116, -0.0813,  0.0121,  ..., -0.2279, -0.0228,  0.0926],
         [ 0.1965, -0.2169, -0.0912,  ..., -0

tensor([[[-1.2474e-02, -2.4623e-01, -2.3811e-02,  ...,  1.4242e-02,
          -1.7052e-01,  2.0166e-02],
         [ 2.8852e-02, -5.2255e-02, -7.4380e-02,  ..., -1.0605e-01,
          -7.8145e-02, -7.9392e-02],
         [-1.3281e-01, -2.8398e-01,  1.1454e-02,  ..., -1.8962e-01,
           1.0183e-01, -2.5263e-02],
         ...,
         [-2.8351e-02, -4.7295e-02,  1.8322e-01,  ..., -6.2098e-02,
           8.1863e-03,  1.0854e-01],
         [-2.2497e-03, -2.3328e-01, -5.2009e-02,  ..., -1.7653e-01,
           2.3660e-01, -4.9563e-02],
         [ 5.7140e-02, -9.2424e-02, -6.6250e-02,  ..., -9.6865e-02,
           6.7854e-02,  3.6410e-03]],

        [[ 1.4354e-03, -2.8867e-01,  3.7062e-02,  ..., -1.4684e-02,
          -1.1302e-01,  6.6258e-02],
         [ 2.8003e-01, -1.8329e-01, -5.6926e-02,  ..., -2.2256e-01,
           8.0600e-02,  6.5266e-02],
         [ 4.9588e-02, -1.2900e-01,  5.2066e-02,  ..., -8.9896e-02,
           1.7505e-02,  1.6516e-01],
         ...,
         [ 7.0881e-03, -9

tensor([[[-0.2089, -0.1718,  0.0821,  ..., -0.1632,  0.0445,  0.1529],
         [-0.1292, -0.3270,  0.0008,  ..., -0.2003, -0.0447,  0.0112],
         [ 0.1421, -0.1677,  0.0276,  ..., -0.1240,  0.1236,  0.2571],
         ...,
         [-0.0419, -0.1004,  0.0957,  ..., -0.1113,  0.1174,  0.1238],
         [-0.0917, -0.1935,  0.0122,  ..., -0.1439,  0.3015, -0.0292],
         [ 0.0719, -0.0725, -0.0782,  ..., -0.2054,  0.2786,  0.1144]],

        [[-0.2219, -0.2050,  0.1167,  ..., -0.1937,  0.0469,  0.1429],
         [-0.1804, -0.1192,  0.0931,  ..., -0.1576, -0.1105,  0.2021],
         [-0.4257, -0.1642,  0.1968,  ..., -0.3211,  0.2059,  0.0308],
         ...,
         [-0.0608, -0.1061,  0.1361,  ..., -0.1319,  0.1042,  0.1408],
         [-0.1091, -0.1717,  0.0635,  ..., -0.1150,  0.2777, -0.0329],
         [ 0.0610, -0.0681, -0.0504,  ..., -0.2324,  0.2819,  0.1079]],

        [[-0.2084, -0.1742,  0.0576,  ..., -0.1322,  0.0323,  0.1442],
         [-0.1061, -0.2095,  0.0271,  ..., -0

tensor([[[-0.3421, -0.0841,  0.0770,  ..., -0.0714, -0.2595,  0.3970],
         [-0.1833, -0.0210,  0.0817,  ..., -0.1423, -0.0834,  0.2006],
         [-0.0786,  0.0392, -0.0414,  ..., -0.1718,  0.0351,  0.2613],
         ...,
         [-0.2502,  0.1383,  0.2548,  ..., -0.2665, -0.0239,  0.2583],
         [-0.3395,  0.1045,  0.1589,  ..., -0.1813,  0.1348,  0.1568],
         [-0.1802,  0.2177,  0.1670,  ..., -0.3007,  0.0156,  0.0864]],

        [[-0.3550, -0.0440,  0.0876,  ..., -0.1045, -0.2587,  0.4116],
         [-0.2201,  0.2503,  0.0799,  ..., -0.1789, -0.0933,  0.2162],
         [-0.2744,  0.0926,  0.1678,  ..., -0.3183,  0.0418,  0.2334],
         ...,
         [-0.2226,  0.1445,  0.2314,  ..., -0.3139, -0.0271,  0.2820],
         [-0.3127,  0.1328,  0.1461,  ..., -0.2102,  0.1448,  0.1463],
         [-0.1736,  0.2514,  0.1562,  ..., -0.3164,  0.0394,  0.1219]],

        [[-0.3209, -0.1109,  0.1089,  ..., -0.1215, -0.2553,  0.4156],
         [-0.0669,  0.2479,  0.0481,  ..., -0

tensor([[[-0.0189, -0.3012, -0.1237,  ..., -0.2120, -0.0522,  0.0058],
         [ 0.0706, -0.0592, -0.0825,  ..., -0.1844, -0.0026, -0.0827],
         [-0.1397, -0.0077, -0.1830,  ..., -0.2581,  0.0057, -0.0493],
         ...,
         [-0.2040, -0.1613,  0.1188,  ..., -0.0005, -0.0771, -0.1905],
         [-0.1539, -0.1526,  0.0443,  ..., -0.1147,  0.1138, -0.4064],
         [-0.0400, -0.0842, -0.0508,  ..., -0.2606,  0.0492, -0.3038]],

        [[-0.0199, -0.2392, -0.1286,  ..., -0.2092, -0.0430,  0.0010],
         [-0.0949, -0.1906, -0.0465,  ..., -0.1989,  0.0164, -0.1114],
         [ 0.0237, -0.1214, -0.1026,  ..., -0.2115,  0.0211, -0.0246],
         ...,
         [-0.1586, -0.1315,  0.0902,  ...,  0.0332, -0.0450, -0.1789],
         [-0.1440, -0.1245,  0.0447,  ..., -0.0839,  0.1137, -0.3783],
         [-0.0071, -0.0654, -0.0585,  ..., -0.2378,  0.0650, -0.3044]],

        [[-0.0271, -0.2560, -0.1137,  ..., -0.2008, -0.0622, -0.0066],
         [-0.0957, -0.2087, -0.0327,  ..., -0

tensor([[[ 4.2066e-02, -1.8263e-01,  1.4823e-02,  ..., -6.8522e-02,
          -3.3402e-01,  4.8651e-02],
         [ 1.7768e-02, -2.4828e-01, -3.9047e-02,  ..., -7.1438e-02,
          -8.6986e-02,  8.6671e-02],
         [-2.3925e-01, -1.5622e-01,  2.0610e-01,  ..., -2.8063e-01,
           1.3418e-02,  3.7504e-02],
         ...,
         [ 9.0804e-02, -5.3319e-02,  2.8907e-01,  ..., -1.3551e-01,
          -1.4483e-01,  1.5905e-01],
         [-8.7196e-02, -2.1938e-01,  2.5405e-01,  ..., -4.4384e-02,
          -6.4297e-03, -5.4122e-02],
         [ 1.0216e-01,  1.3620e-01,  5.4578e-02,  ..., -1.5236e-01,
          -1.6002e-01,  6.2898e-02]],

        [[ 8.5969e-02, -1.7583e-01,  4.8249e-02,  ..., -6.8223e-02,
          -3.5412e-01,  6.1751e-02],
         [ 6.0207e-03, -2.1252e-01,  1.5304e-01,  ..., -7.7524e-02,
          -2.9727e-01,  1.2902e-01],
         [-1.0550e-01, -2.4242e-01,  1.7196e-01,  ..., -3.3736e-01,
          -1.7756e-01,  6.5865e-02],
         ...,
         [ 1.1120e-01, -1

tensor([[[-1.9096e-02,  9.6399e-02,  6.8821e-02,  ..., -1.2666e-01,
          -7.3327e-02,  1.0140e-01],
         [ 1.7565e-01,  7.4609e-02, -4.3921e-03,  ..., -2.4316e-01,
           3.5674e-02,  2.8609e-02],
         [ 7.3003e-02,  1.2550e-01,  1.0424e-01,  ..., -1.4818e-01,
           1.8449e-01,  2.5037e-01],
         ...,
         [-1.1932e-01, -2.4497e-02,  1.0490e-01,  ..., -1.6723e-01,
           5.6965e-02, -1.4083e-02],
         [-1.3417e-01, -5.8150e-02,  1.4129e-01,  ..., -1.0330e-01,
           2.3587e-01, -1.4186e-01],
         [-1.1339e-01,  4.4422e-04,  1.0256e-01,  ..., -1.4270e-01,
           1.0405e-01, -3.6034e-02]],

        [[-4.2901e-02,  1.3013e-01,  5.1507e-02,  ..., -1.5766e-01,
          -5.3802e-02,  1.3695e-01],
         [-1.0088e-01,  5.2755e-02,  2.2882e-02,  ..., -1.0572e-01,
          -4.9491e-02,  2.3370e-03],
         [-1.5387e-01, -1.4837e-01,  2.5222e-01,  ..., -2.6201e-01,
           8.3660e-02,  4.9130e-02],
         ...,
         [-1.0602e-01,  2

tensor([[[-1.2536e-01, -3.4185e-01,  2.1958e-02,  ..., -1.3211e-01,
          -9.8865e-02, -1.1105e-01],
         [-5.3830e-02, -2.3655e-01, -4.5002e-02,  ..., -1.1897e-01,
           1.8698e-01, -9.3135e-02],
         [-4.0899e-02, -3.4121e-02, -2.5336e-02,  ..., -8.2486e-02,
           8.9818e-02,  8.3954e-03],
         ...,
         [-1.9401e-01, -4.6627e-02,  3.8764e-02,  ..., -2.9861e-02,
           5.5104e-02, -6.5088e-02],
         [-1.8790e-01, -9.1148e-02, -2.4330e-02,  ..., -9.3477e-02,
           2.5267e-01, -1.1566e-01],
         [-9.6810e-02, -6.0387e-02, -1.6884e-02,  ..., -1.7433e-01,
           7.6957e-02, -1.5873e-01]],

        [[-1.0890e-01, -3.1573e-01, -2.0752e-02,  ..., -1.5849e-01,
          -1.5106e-01, -1.3729e-01],
         [ 6.0595e-02,  6.4750e-02, -2.9316e-01,  ..., -2.6263e-01,
           4.4323e-02,  2.0789e-02],
         [-6.3963e-02,  1.6591e-01, -1.0890e-01,  ..., -1.1328e-01,
           2.3683e-01,  1.2414e-01],
         ...,
         [-1.9074e-01,  2

tensor([[[-9.9124e-02, -3.8391e-02, -7.8220e-03,  ..., -8.6779e-02,
          -1.9533e-02,  2.5432e-01],
         [-9.9272e-02,  1.0052e-01, -8.9254e-02,  ..., -9.0228e-02,
           1.2470e-01,  8.2169e-02],
         [-3.5118e-01,  6.1032e-02,  1.2810e-01,  ..., -2.5998e-01,
           1.7182e-01,  2.1179e-01],
         ...,
         [-5.5579e-02,  1.0963e-01,  1.6497e-01,  ..., -1.2956e-01,
           4.1315e-02,  4.0611e-01],
         [-1.0245e-01, -6.7696e-02,  3.1929e-03,  ..., -1.0839e-01,
           2.7852e-01,  2.3618e-01],
         [-1.4938e-01,  1.2972e-01,  4.5240e-02,  ..., -3.2183e-01,
           7.7248e-02,  3.0059e-01]],

        [[-4.7073e-02, -8.6185e-02, -5.6863e-04,  ..., -1.2798e-01,
          -8.2319e-03,  2.3267e-01],
         [-6.3396e-02,  1.7028e-02, -1.5006e-02,  ..., -1.3341e-01,
           6.0206e-02,  2.2761e-01],
         [-1.9426e-01,  4.5931e-02,  7.3081e-02,  ..., -4.2561e-01,
           2.9681e-01,  1.2225e-01],
         ...,
         [-4.1674e-02,  8

tensor([[[-2.0345e-01, -1.4176e-01, -1.5154e-02,  ..., -5.7865e-02,
          -3.1438e-02,  7.4507e-02],
         [-1.3723e-01, -1.8628e-01,  6.6593e-02,  ..., -1.9351e-01,
           1.9416e-02, -1.3714e-01],
         [-1.0321e-01,  9.9979e-03,  2.6201e-04,  ..., -1.2337e-01,
           1.1067e-01,  1.4480e-02],
         ...,
         [-5.5581e-02, -1.8395e-03,  1.0012e-01,  ..., -7.8077e-02,
           1.0575e-01,  5.6057e-02],
         [-7.1734e-02, -6.2849e-02,  1.0764e-01,  ...,  4.7148e-02,
           2.1344e-01, -9.6001e-02],
         [-1.0057e-01, -1.2061e-02, -4.7348e-02,  ..., -7.0609e-02,
           5.6569e-02,  3.3517e-02]],

        [[-1.9205e-01, -1.5139e-01, -7.1581e-03,  ..., -5.4034e-02,
          -3.2861e-02,  8.0587e-02],
         [-1.2958e-01, -1.9073e-01,  7.7031e-02,  ..., -1.8165e-01,
           1.2409e-02, -1.3174e-01],
         [-8.9665e-02, -1.1911e-02,  9.7973e-03,  ..., -1.1339e-01,
           9.3022e-02,  1.3020e-02],
         ...,
         [-4.9253e-02, -1

tensor([[[-8.3762e-02, -3.5138e-01,  1.9411e-02,  ..., -3.9285e-02,
          -2.8167e-03,  2.5499e-01],
         [-6.8193e-02, -1.7124e-01,  8.7868e-02,  ..., -1.1022e-01,
          -3.0235e-02,  2.7533e-01],
         [-1.9318e-01, -2.8467e-01,  1.7522e-01,  ..., -5.3078e-01,
           1.9344e-01,  2.1831e-01],
         ...,
         [-5.1625e-02, -1.6167e-01,  8.3178e-02,  ..., -5.6329e-02,
           1.0136e-01,  8.8356e-02],
         [-1.1942e-01, -2.6212e-01,  7.1868e-02,  ..., -1.9843e-01,
           1.0536e-01,  1.1895e-01],
         [ 1.6921e-02, -1.9341e-01,  8.4799e-03,  ..., -9.9586e-02,
           1.0401e-01,  5.6711e-02]],

        [[-3.0992e-02, -3.2068e-01,  6.8140e-02,  ..., -6.1164e-02,
           1.0899e-02,  2.5543e-01],
         [-1.0300e-02, -2.0551e-01,  1.6310e-02,  ..., -9.6956e-02,
           3.6028e-02,  1.9759e-01],
         [-2.6157e-01, -1.9666e-01,  2.1116e-01,  ..., -3.2161e-01,
           1.8850e-01,  1.2495e-01],
         ...,
         [-5.0784e-02, -1

tensor([[[-1.3014e-01, -2.6416e-01, -4.8497e-02,  ...,  2.7329e-02,
          -6.5582e-02,  1.8898e-01],
         [ 2.6883e-02, -1.3533e-01,  5.4669e-02,  ..., -7.8593e-02,
          -2.1283e-02, -1.8644e-02],
         [-2.1972e-01, -1.9937e-01,  1.8585e-01,  ..., -2.1916e-01,
           7.7030e-02,  1.3149e-01],
         ...,
         [-2.1629e-01, -4.9328e-02,  3.0600e-01,  ..., -1.4477e-01,
           1.0235e-01,  1.4213e-01],
         [-1.1819e-01, -9.9771e-02,  3.0037e-01,  ..., -1.6119e-02,
           2.0853e-01, -1.0228e-02],
         [-7.1658e-02, -2.5244e-02,  1.0086e-01,  ..., -1.3331e-01,
           1.1993e-01,  4.6432e-02]],

        [[-1.6114e-01, -2.4406e-01, -8.9579e-02,  ...,  2.3318e-03,
          -5.8635e-02,  1.5934e-01],
         [-2.1394e-01, -3.1390e-01,  7.6532e-02,  ..., -2.2083e-01,
           5.2824e-02,  9.1479e-02],
         [-3.8285e-02, -1.1016e-01,  1.0332e-02,  ..., -1.3494e-01,
           1.2093e-02,  2.7339e-01],
         ...,
         [-2.2718e-01, -1

tensor([[[ 0.0208, -0.1437,  0.0815,  ..., -0.1547, -0.1285,  0.1595],
         [ 0.0657,  0.0641,  0.0062,  ..., -0.1647, -0.0361,  0.1948],
         [-0.0563, -0.0280,  0.1139,  ...,  0.0233,  0.1711,  0.1702],
         ...,
         [-0.0658,  0.1616,  0.1918,  ..., -0.1089, -0.0478,  0.0214],
         [-0.1328,  0.0908,  0.1629,  ..., -0.0597,  0.2544,  0.0347],
         [-0.1167,  0.2585,  0.1140,  ..., -0.2025,  0.1024,  0.0027]],

        [[ 0.0135, -0.1206,  0.1124,  ..., -0.1342, -0.1003,  0.1408],
         [-0.0469, -0.0863,  0.0809,  ..., -0.0522, -0.0730,  0.1599],
         [-0.2003,  0.0016,  0.2061,  ..., -0.2405,  0.2563,  0.0445],
         ...,
         [-0.0845,  0.1726,  0.2354,  ..., -0.1528, -0.0538,  0.0214],
         [-0.1474,  0.0731,  0.2007,  ..., -0.0916,  0.2644,  0.0186],
         [-0.0712,  0.2203,  0.1472,  ..., -0.2065,  0.1111,  0.0079]],

        [[ 0.0289, -0.1154,  0.0987,  ..., -0.1333, -0.1138,  0.1119],
         [-0.1143, -0.0610,  0.0949,  ..., -0

tensor([[[ 0.0342, -0.1712,  0.1517,  ..., -0.1777, -0.1456,  0.1891],
         [ 0.2890, -0.0839, -0.0270,  ..., -0.2289, -0.2477,  0.1762],
         [ 0.0792, -0.1216,  0.1798,  ..., -0.1232,  0.0518,  0.2026],
         ...,
         [-0.0903,  0.0035,  0.3010,  ..., -0.1535, -0.1836,  0.1173],
         [-0.1511, -0.0517,  0.2509,  ..., -0.0741, -0.0384,  0.0739],
         [-0.1469, -0.0150,  0.2000,  ..., -0.3034, -0.1303, -0.0217]],

        [[ 0.0268, -0.2059,  0.1199,  ..., -0.1709, -0.2107,  0.1501],
         [ 0.2481, -0.1174, -0.0625,  ..., -0.1843, -0.2842,  0.0581],
         [-0.2434, -0.0602,  0.1517,  ..., -0.1778,  0.0501,  0.0214],
         ...,
         [-0.0679,  0.0308,  0.3187,  ..., -0.1543, -0.1827,  0.1591],
         [-0.1175, -0.0507,  0.3116,  ..., -0.0588, -0.0562,  0.0721],
         [-0.1436,  0.0416,  0.2270,  ..., -0.3309, -0.0928, -0.0318]],

        [[ 0.0299, -0.1869,  0.1520,  ..., -0.1831, -0.1768,  0.1861],
         [ 0.0776, -0.0402,  0.1211,  ..., -0

tensor([[[-1.9166e-02, -1.9422e-01, -3.7973e-02,  ..., -8.8577e-03,
          -6.5482e-02,  1.3947e-01],
         [ 4.9121e-02, -1.1404e-02, -6.9418e-02,  ..., -1.0484e-01,
          -4.1598e-02,  1.8552e-02],
         [-2.7077e-01,  5.3302e-02,  6.4856e-02,  ..., -2.9384e-01,
           2.0197e-01,  1.5910e-01],
         ...,
         [-1.0771e-01,  1.6807e-02,  1.4208e-01,  ..., -1.3273e-01,
           1.3621e-02,  1.7101e-01],
         [-6.9202e-03, -8.9603e-02,  1.3315e-01,  ..., -3.2016e-02,
           1.3293e-01,  1.1963e-01],
         [-1.7399e-02, -2.3020e-02, -7.6976e-02,  ..., -2.2119e-01,
           2.7308e-02,  1.6907e-02]],

        [[-1.3460e-02, -1.9295e-01, -5.7450e-02,  ..., -2.7454e-02,
          -1.4867e-02,  1.1390e-01],
         [ 3.5609e-01,  1.1811e-01, -1.9223e-01,  ..., -1.5308e-01,
           4.3914e-02,  8.3301e-02],
         [-2.7029e-02,  9.9771e-02, -5.1681e-02,  ..., -2.3393e-01,
           1.6353e-01,  2.8723e-01],
         ...,
         [-6.7843e-02,  4

tensor([[[-0.0134, -0.1950, -0.0029,  ..., -0.0458, -0.2054,  0.1123],
         [ 0.0141, -0.1390,  0.0866,  ..., -0.1056, -0.0445, -0.0761],
         [ 0.2642, -0.0108,  0.0244,  ..., -0.1278, -0.0532, -0.0826],
         ...,
         [-0.0712, -0.0473,  0.1691,  ..., -0.1529,  0.0153, -0.1332],
         [ 0.0440, -0.0065,  0.1383,  ..., -0.1209,  0.1085, -0.2389],
         [ 0.1006,  0.0468,  0.0608,  ..., -0.1195,  0.0548, -0.1772]],

        [[ 0.0188, -0.2611, -0.0120,  ..., -0.0092, -0.2130,  0.0701],
         [ 0.2847, -0.0102,  0.0196,  ..., -0.1588, -0.1158, -0.0906],
         [ 0.1089,  0.0401,  0.0782,  ..., -0.0825,  0.0352, -0.0055],
         ...,
         [-0.0412, -0.0530,  0.1744,  ..., -0.1110, -0.0135, -0.1499],
         [ 0.0425, -0.0377,  0.1420,  ..., -0.0521,  0.0882, -0.2498],
         [ 0.0931,  0.0291,  0.0555,  ..., -0.0671,  0.0450, -0.2060]],

        [[ 0.0073, -0.2162, -0.0169,  ..., -0.0271, -0.1757,  0.1032],
         [-0.0791,  0.0556, -0.0504,  ..., -0

tensor([[[-7.8048e-02, -9.9583e-02,  6.0976e-02,  ..., -2.0099e-01,
          -6.0113e-02,  1.0071e-01],
         [-4.9042e-02,  9.4270e-03,  1.3168e-01,  ..., -2.5171e-01,
           7.3133e-02,  1.4212e-01],
         [-4.3278e-02,  5.2418e-03,  3.1778e-01,  ..., -5.1160e-01,
           3.9794e-01,  3.2447e-02],
         ...,
         [-4.4760e-02, -1.4573e-02,  2.7102e-01,  ..., -2.7006e-01,
           1.1688e-01,  1.1351e-01],
         [-4.1822e-02, -4.9548e-02,  1.6327e-01,  ..., -2.2021e-01,
           2.4983e-01, -9.2748e-02],
         [ 1.6885e-02, -1.5692e-02,  1.2288e-01,  ..., -3.6370e-01,
           2.1703e-01, -7.0287e-03]],

        [[-9.4861e-02, -1.0328e-01,  6.2118e-02,  ..., -1.8473e-01,
          -1.1480e-01,  1.0379e-01],
         [-5.7338e-02,  1.4346e-01, -7.2836e-02,  ..., -2.0052e-01,
           1.4623e-01,  5.9063e-02],
         [-2.1563e-01,  1.6570e-01,  2.2039e-01,  ..., -3.8384e-01,
           3.1713e-01,  5.1892e-02],
         ...,
         [-6.7732e-02, -1

tensor([[[ 0.0064, -0.1316, -0.1033,  ..., -0.0151, -0.2202,  0.0680],
         [ 0.0979, -0.0321, -0.2463,  ...,  0.0616, -0.0600, -0.1576],
         [-0.0829, -0.0785, -0.0261,  ..., -0.0702,  0.0617, -0.0920],
         ...,
         [ 0.1074, -0.0735,  0.0673,  ..., -0.0687, -0.1118, -0.0036],
         [-0.0587, -0.1041,  0.0106,  ..., -0.0334,  0.1509, -0.1080],
         [ 0.1933,  0.0274, -0.1833,  ..., -0.1958, -0.0837, -0.1144]],

        [[ 0.0350, -0.1437, -0.0586,  ..., -0.0323, -0.2554,  0.1106],
         [ 0.1078, -0.0169, -0.1521,  ..., -0.1217, -0.0651, -0.1478],
         [-0.0144, -0.1890,  0.0516,  ..., -0.3086,  0.1830,  0.0061],
         ...,
         [ 0.1224, -0.0952,  0.1629,  ..., -0.0261, -0.0988,  0.0635],
         [-0.0566, -0.1039,  0.0408,  ...,  0.0235,  0.1216, -0.0641],
         [ 0.2064,  0.0307, -0.1506,  ..., -0.1667, -0.0646, -0.0782]],

        [[ 0.0372, -0.1175, -0.0767,  ..., -0.0478, -0.2524,  0.0937],
         [ 0.0953, -0.1197, -0.1170,  ..., -0

tensor([[[ 0.0012, -0.1154, -0.0089,  ..., -0.1302, -0.2517,  0.0956],
         [ 0.0732, -0.3041,  0.2137,  ..., -0.1249,  0.0226, -0.0626],
         [ 0.2291, -0.1791,  0.1931,  ..., -0.0626, -0.1043,  0.1047],
         ...,
         [ 0.0069,  0.0249,  0.3591,  ..., -0.0991,  0.0194,  0.1687],
         [-0.0752, -0.1588,  0.2907,  ..., -0.0504,  0.1680, -0.0417],
         [-0.0287, -0.0054,  0.1306,  ..., -0.2675,  0.0419, -0.0781]],

        [[ 0.0325, -0.0789,  0.0097,  ..., -0.1704, -0.1865,  0.1037],
         [ 0.0576, -0.1286,  0.2426,  ..., -0.1044, -0.0939, -0.0301],
         [-0.1313, -0.0622,  0.3575,  ..., -0.2985,  0.0686,  0.1242],
         ...,
         [ 0.0234,  0.0119,  0.3889,  ..., -0.1351,  0.0039,  0.2234],
         [-0.1027, -0.1495,  0.3346,  ..., -0.0677,  0.1724, -0.0018],
         [-0.0232, -0.0313,  0.1796,  ..., -0.2698,  0.0460, -0.0430]],

        [[ 0.0089, -0.1072, -0.0102,  ..., -0.1510, -0.2434,  0.1019],
         [ 0.0865, -0.2962,  0.2050,  ..., -0

tensor([[[ 0.0189, -0.2331,  0.0735,  ..., -0.0527, -0.0735,  0.3549],
         [-0.0944, -0.1423,  0.0650,  ..., -0.1319,  0.0023,  0.1295],
         [-0.2135, -0.3154,  0.2211,  ..., -0.2387,  0.1041,  0.1999],
         ...,
         [-0.0048, -0.0416,  0.3030,  ..., -0.1560,  0.2279,  0.2397],
         [-0.1341, -0.0810,  0.1428,  ...,  0.0463,  0.2497,  0.1384],
         [ 0.0309,  0.1201,  0.2619,  ..., -0.1999,  0.2455,  0.1644]],

        [[-0.0249, -0.2032,  0.0752,  ..., -0.0953, -0.0724,  0.3625],
         [ 0.0076, -0.0732,  0.0461,  ..., -0.1231, -0.0725,  0.1448],
         [-0.2419, -0.1632,  0.2591,  ..., -0.1469,  0.1543,  0.0687],
         ...,
         [-0.0042, -0.0328,  0.2857,  ..., -0.1463,  0.2045,  0.2092],
         [-0.1000, -0.0815,  0.1553,  ...,  0.0464,  0.2506,  0.0942],
         [ 0.0415,  0.1293,  0.2564,  ..., -0.2116,  0.2314,  0.1299]],

        [[ 0.0035, -0.2159,  0.0965,  ..., -0.1235, -0.0753,  0.3777],
         [ 0.3690, -0.1061, -0.0032,  ..., -0

tensor([[[ 2.3803e-03, -1.0679e-01,  6.3592e-02,  ...,  4.7589e-02,
          -1.3268e-01,  1.2087e-01],
         [ 1.6512e-01,  7.7961e-03,  8.1979e-03,  ...,  6.8996e-02,
          -2.3766e-01,  2.3244e-03],
         [-1.2624e-01,  5.1367e-02,  1.4520e-01,  ..., -1.0439e-01,
           7.5600e-03, -1.0471e-02],
         ...,
         [ 5.4015e-02,  8.6398e-02,  4.6966e-01,  ...,  1.1132e-01,
           3.1848e-02,  1.5952e-01],
         [-3.8553e-02, -1.3031e-01,  2.5901e-01,  ...,  4.5021e-02,
           1.7793e-01,  1.3894e-02],
         [-6.3381e-02,  1.1831e-01,  7.4949e-02,  ..., -1.2194e-01,
           3.3186e-02,  4.0262e-02]],

        [[-2.8110e-02, -1.1430e-01,  2.9421e-02,  ...,  7.4369e-02,
          -1.4078e-01,  1.5278e-01],
         [ 1.4800e-01, -1.2734e-01,  1.2456e-01,  ...,  3.5820e-02,
          -2.0979e-01,  3.2311e-02],
         [ 1.0425e-01,  7.1398e-02,  1.2799e-01,  ..., -4.1028e-03,
          -9.8901e-02,  2.9544e-01],
         ...,
         [ 2.0712e-02,  8

In [7]:
# Predict mask tokens ans isNext
input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]
print(text)
print('================================')
print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thank you Romeo
Where are you going today?
I am going shopping. What about you?
I am going to visit my grandmother. she is not very well
['[CLS]', 'i', 'am', 'going', 'shopping', 'what', 'about', 'you', '[SEP]', '[MASK]', 'meet', 'you', 'too', '[MASK]', 'are', 'you', 'today', '[SEP]']
tensor([[[-0.0524, -0.3062,  0.1610,  ..., -0.0787, -0.0590,  0.0542],
         [ 0.0260, -0.0610,  0.0867,  ..., -0.1049,  0.0397,  0.0261],
         [-0.2035, -0.1793,  0.2307,  ..., -0.3219,  0.1046,  0.0444],
         ...,
         [ 0.0889,  0.0687,  0.1409,  ..., -0.1033,  0.0302,  0.2241],
         [-0.0332, -0.1121,  0.1001,  ...,  0.0480,  0.1896,  0.0707],
         [ 0.0794, -0.0790, -0.0026,  ..., -0.0914,  0.1828,  0.0604]]],
       grad_fn=<AddBackward0>)
tensor([[[13, 13, 13,  ..., 13, 13, 13],
        

In [18]:
#应用量化
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(40, 768)
    (pos_embed): Embedding(30, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0): EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_K): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (W_V): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): DynamicQuantizedLinear(in_features=768, out_features=3072, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
        (fc2): DynamicQuantizedLinear(in_features=3072, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      )
    )
    (1): Enc

In [20]:
for epoch in range(50):
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
        logits_lm, logits_clsf = quantized_model(input_ids, segment_ids, masked_pos)
        loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
        loss_lm = (loss_lm.float()).mean()
        loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
        loss = loss_lm + loss_clsf
        loss = loss.requires_grad_()
        if (epoch + 1) % 10 == 0:
          print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

tensor([[[-4.5361e-02, -2.6862e-01,  1.2328e-01,  ...,  3.0652e-02,
          -3.0078e-01,  4.1041e-01],
         [ 5.7225e-02, -1.7477e-01, -1.0097e-01,  ..., -4.1559e-02,
          -1.4546e-01,  2.8240e-01],
         [-2.5026e-02, -5.9221e-02,  3.3736e-01,  ..., -8.0557e-02,
          -4.7665e-02,  5.0360e-01],
         ...,
         [ 9.7117e-02,  3.7378e-02,  3.5947e-01,  ..., -2.7847e-02,
          -9.0566e-02,  3.3642e-01],
         [-8.6474e-02, -1.2759e-01,  1.7857e-01,  ...,  9.4139e-03,
           1.1239e-01,  5.5147e-02],
         [ 6.7228e-02, -1.1400e-04,  1.4150e-01,  ..., -1.1729e-01,
           3.9413e-02,  2.1787e-01]],

        [[-3.6517e-02, -2.3113e-01,  1.7809e-01,  ..., -3.2067e-02,
          -3.4148e-01,  4.2978e-01],
         [-7.3748e-02, -2.7809e-01, -1.1410e-01,  ..., -1.8960e-01,
          -3.7664e-01,  1.3209e-01],
         [-1.6435e-01, -8.4231e-02,  3.7981e-01,  ...,  4.6508e-02,
          -8.1816e-02,  3.0883e-01],
         ...,
         [ 7.1829e-02,  1

tensor([[[-7.8861e-02,  5.8903e-03,  2.7070e-01,  ..., -1.8538e-01,
           3.9787e-02,  4.4764e-01],
         [-1.9361e-01, -1.1296e-01,  1.8048e-02,  ..., -4.4199e-01,
           4.5542e-04,  2.8669e-02],
         [-2.2868e-01,  7.5635e-02,  2.2954e-01,  ..., -3.0984e-01,
           9.6575e-02,  1.7695e-01],
         ...,
         [-1.1106e-01,  1.7068e-01,  2.8166e-01,  ..., -4.1894e-01,
           7.8964e-02,  9.9363e-02],
         [-1.7448e-01,  7.5842e-02,  2.6021e-01,  ..., -2.5776e-01,
           2.9554e-01,  1.5747e-01],
         [-1.4401e-01,  2.5757e-01,  1.5581e-01,  ..., -4.5316e-01,
           2.2580e-01,  9.8509e-02]],

        [[-7.2553e-02, -2.8941e-02,  1.2555e-01,  ..., -2.4239e-01,
           5.5442e-02,  3.8937e-01],
         [-1.2507e-01,  1.3276e-01, -9.8665e-02,  ..., -2.2166e-01,
           7.0628e-02, -4.6646e-02],
         [-5.4482e-02, -2.1045e-02,  2.5838e-01,  ..., -3.6251e-01,
           1.4757e-01,  1.8119e-01],
         ...,
         [-3.7713e-02,  1

tensor([[[-0.1346, -0.0445, -0.0639,  ...,  0.0309, -0.2569,  0.1471],
         [ 0.0074, -0.2104,  0.0026,  ..., -0.2077, -0.2321, -0.0557],
         [-0.2568,  0.0141,  0.3042,  ..., -0.0575, -0.1919,  0.1884],
         ...,
         [ 0.0291,  0.0703,  0.1652,  ..., -0.1938, -0.2448,  0.0712],
         [-0.2446,  0.0470,  0.1844,  ..., -0.1311,  0.0919,  0.0158],
         [-0.0471,  0.1718, -0.0103,  ..., -0.3085, -0.0540, -0.0669]],

        [[-0.1965, -0.0353, -0.1173,  ...,  0.0503, -0.2130,  0.0531],
         [-0.0006, -0.2287, -0.0284,  ..., -0.2457, -0.2090, -0.0821],
         [-0.2846,  0.0484,  0.2618,  ..., -0.0585, -0.1715,  0.1254],
         ...,
         [-0.0489,  0.0698,  0.1476,  ..., -0.2508, -0.2142,  0.0210],
         [-0.2935,  0.0835,  0.1480,  ..., -0.1129,  0.0988, -0.0578],
         [-0.0747,  0.1332,  0.0106,  ..., -0.3303, -0.0130, -0.0488]],

        [[-0.1554, -0.0115, -0.0930,  ...,  0.0699, -0.2703,  0.1051],
         [ 0.0086, -0.0300,  0.0974,  ..., -0

tensor([[[-0.2328, -0.0053,  0.0056,  ..., -0.1266, -0.0343,  0.2050],
         [ 0.0924, -0.2653,  0.0536,  ..., -0.2413,  0.2893, -0.0344],
         [-0.1007, -0.0770,  0.0267,  ..., -0.3767,  0.1600,  0.1487],
         ...,
         [-0.1285, -0.0102,  0.2208,  ..., -0.0900,  0.0643,  0.1528],
         [-0.1600, -0.1514,  0.0453,  ..., -0.1401,  0.1812,  0.0805],
         [-0.1789,  0.0610,  0.0507,  ..., -0.1061,  0.1829,  0.0392]],

        [[-0.2678, -0.0007,  0.1068,  ..., -0.1214, -0.0713,  0.0965],
         [-0.0525, -0.2534,  0.1184,  ..., -0.1025,  0.2152, -0.1975],
         [-0.0550, -0.1138,  0.0483,  ..., -0.1190, -0.0034,  0.0638],
         ...,
         [-0.1308, -0.0401,  0.2613,  ..., -0.0735,  0.0465,  0.1015],
         [-0.1616, -0.1280,  0.0952,  ..., -0.0900,  0.1244,  0.0311],
         [-0.1763,  0.0400,  0.1326,  ..., -0.0793,  0.1288, -0.0479]],

        [[-0.2577,  0.0617,  0.0421,  ..., -0.1329, -0.1109,  0.1832],
         [ 0.0480, -0.1740,  0.0655,  ..., -0

tensor([[[-0.3287, -0.1151, -0.0987,  ..., -0.1951, -0.3350,  0.2946],
         [-0.1463, -0.0082, -0.1155,  ..., -0.1302, -0.2121,  0.1453],
         [-0.1571, -0.0465,  0.1659,  ..., -0.2199, -0.1179,  0.3634],
         ...,
         [-0.2576,  0.0196,  0.0766,  ..., -0.2742, -0.3533,  0.2345],
         [-0.2683, -0.1478,  0.0413,  ..., -0.1580, -0.1322,  0.0981],
         [-0.1256,  0.0448, -0.0712,  ..., -0.2430, -0.2455,  0.0066]],

        [[-0.3434, -0.1849, -0.1070,  ..., -0.1429, -0.3622,  0.2973],
         [-0.0615, -0.2754, -0.2497,  ..., -0.1149, -0.0534,  0.1870],
         [-0.1411,  0.0670, -0.1077,  ..., -0.1126, -0.0729,  0.2749],
         ...,
         [-0.2101, -0.0116, -0.0232,  ..., -0.2624, -0.3056,  0.2666],
         [-0.2436, -0.1283,  0.0268,  ..., -0.1743, -0.1148,  0.1052],
         [-0.1047, -0.0227, -0.0938,  ..., -0.2498, -0.2123,  0.0428]],

        [[-0.2967, -0.0961, -0.0599,  ..., -0.1771, -0.3502,  0.3004],
         [ 0.1938, -0.1119, -0.0192,  ..., -0

tensor([[[-7.6993e-02, -2.1990e-01,  1.3536e-01,  ...,  1.0981e-03,
          -2.1304e-01,  1.4499e-01],
         [ 9.5283e-02, -8.9285e-02,  7.4873e-02,  ..., -1.3717e-01,
           3.5355e-02,  3.3482e-02],
         [ 1.4970e-02, -4.2437e-02,  2.2615e-01,  ..., -1.8632e-01,
           2.1069e-02,  1.8337e-01],
         ...,
         [ 7.4966e-03, -1.1196e-01,  2.9901e-01,  ..., -1.8988e-01,
          -1.3610e-01,  8.5647e-02],
         [-1.1524e-01, -1.7824e-01,  2.1561e-01,  ..., -8.4961e-02,
           7.1682e-02, -4.7030e-02],
         [ 4.3955e-02, -9.1464e-02,  2.4394e-01,  ..., -3.4209e-01,
           4.0921e-02, -1.2196e-01]],

        [[-7.2211e-02, -2.1056e-01,  1.2672e-01,  ..., -3.1233e-02,
          -1.8593e-01,  1.1373e-01],
         [ 4.2744e-02, -3.8612e-01, -1.1694e-01,  ..., -2.1716e-01,
          -6.0107e-02,  4.0077e-02],
         [-1.2384e-01, -1.0761e-01,  3.4286e-01,  ..., -1.5611e-01,
           9.8844e-02,  1.2125e-01],
         ...,
         [ 2.3952e-02, -4

tensor([[[-0.0643,  0.0522, -0.0446,  ..., -0.1055, -0.0958,  0.0929],
         [ 0.0102, -0.2109, -0.1621,  ..., -0.2529, -0.1491,  0.0187],
         [-0.0973,  0.1218,  0.3568,  ..., -0.2076,  0.0667,  0.1589],
         ...,
         [-0.0218,  0.0696,  0.2266,  ..., -0.1318, -0.0624,  0.1818],
         [-0.0734, -0.0923,  0.0932,  ..., -0.0315,  0.2039,  0.0056],
         [ 0.0513,  0.0449, -0.0528,  ..., -0.1583, -0.0234,  0.0145]],

        [[-0.0553,  0.0493, -0.0355,  ..., -0.1015, -0.0901,  0.1046],
         [ 0.3174, -0.0202, -0.0827,  ..., -0.2201, -0.1679,  0.0952],
         [ 0.0688,  0.0421,  0.0024,  ..., -0.2369,  0.1741,  0.1528],
         ...,
         [-0.0842,  0.0818,  0.2390,  ..., -0.1045, -0.0229,  0.1873],
         [-0.0273, -0.0679,  0.0521,  ..., -0.0432,  0.1619,  0.0056],
         [ 0.0582,  0.0501,  0.0005,  ..., -0.1921, -0.0039,  0.0576]],

        [[-0.1015,  0.0471,  0.0064,  ..., -0.1017, -0.0791,  0.1169],
         [ 0.0889, -0.0827, -0.0117,  ..., -0

tensor([[[-4.6397e-02, -7.4024e-02,  3.9684e-02,  ..., -1.8319e-02,
           4.7021e-02,  1.6050e-01],
         [ 1.5857e-03, -1.3568e-01, -7.6978e-02,  ..., -1.0229e-01,
           1.2747e-01, -3.4636e-02],
         [ 2.5961e-02,  8.3676e-02,  5.3084e-02,  ..., -9.9314e-02,
          -2.2302e-03,  1.3663e-01],
         ...,
         [ 2.1039e-02,  1.3979e-01,  4.4893e-02,  ..., -9.9615e-02,
           1.5361e-03,  1.4774e-01],
         [-1.5680e-01, -7.1803e-03,  6.8600e-02,  ...,  8.8648e-02,
           1.0748e-01,  7.0700e-02],
         [-9.0116e-02,  3.9780e-02,  5.4334e-02,  ..., -1.9882e-01,
           1.2958e-01,  1.2514e-01]],

        [[ 1.3940e-02, -1.5378e-01,  8.0290e-02,  ..., -4.1606e-02,
           8.0182e-02,  1.9013e-01],
         [ 2.2531e-01, -1.6347e-01, -7.9570e-02,  ..., -2.8246e-01,
           6.6654e-02,  3.5950e-02],
         [ 7.5718e-02, -5.7135e-03,  9.5849e-02,  ..., -3.6819e-01,
           1.7064e-01,  3.3689e-01],
         ...,
         [ 2.6199e-02,  1

tensor([[[-1.6286e-01, -2.0575e-01, -1.3084e-02,  ..., -3.0991e-02,
          -1.4958e-01,  1.6804e-01],
         [-3.9925e-02, -1.3142e-03, -1.2145e-02,  ..., -5.4543e-02,
          -6.5902e-02, -2.2095e-02],
         [ 1.3099e-01,  3.7053e-02,  1.2544e-01,  ..., -8.7321e-02,
           3.1012e-02,  3.5157e-01],
         ...,
         [ 1.1448e-01,  5.3757e-02,  1.0903e-01,  ..., -1.6131e-01,
           4.1199e-02,  2.2127e-01],
         [-3.3328e-03, -1.1520e-03,  1.8119e-01,  ..., -7.6688e-02,
           2.4983e-01, -9.6485e-03],
         [ 1.3227e-01, -3.0337e-02,  1.1569e-02,  ..., -2.2431e-03,
           1.7350e-01,  4.6121e-02]],

        [[-1.4373e-01, -1.5137e-01,  1.7958e-04,  ...,  1.3820e-02,
          -1.3382e-01,  1.4083e-01],
         [ 2.2022e-01, -1.8603e-02, -1.4407e-01,  ..., -1.8658e-01,
          -6.0026e-02,  2.9750e-02],
         [ 5.5622e-02, -5.5893e-02, -1.1436e-01,  ..., -3.1627e-01,
           1.8485e-01,  2.5377e-01],
         ...,
         [ 3.8655e-02,  2

tensor([[[-3.9779e-02, -1.5618e-02,  3.4112e-02,  ..., -1.3549e-01,
          -3.7591e-02,  2.9862e-01],
         [ 9.6094e-02, -7.6128e-02,  1.3297e-02,  ..., -3.8313e-01,
           4.9722e-02,  2.3396e-01],
         [-3.8877e-05, -9.1838e-02,  7.4245e-02,  ..., -3.2583e-01,
           1.0163e-01,  3.0983e-01],
         ...,
         [-5.0423e-02,  1.1671e-01,  3.1608e-01,  ..., -2.0913e-01,
           5.5499e-02,  3.1570e-01],
         [-5.6732e-02,  3.2951e-03,  1.2042e-01,  ..., -2.1905e-01,
           1.1928e-01,  1.2287e-01],
         [ 1.1265e-01,  6.7148e-02,  7.5526e-02,  ..., -3.0054e-01,
           7.0378e-02,  1.8904e-01]],

        [[-7.7484e-02, -4.7723e-02,  2.5853e-02,  ..., -1.6188e-01,
          -4.2188e-02,  3.1086e-01],
         [ 6.7311e-02, -2.2806e-01, -1.7761e-02,  ..., -3.9004e-01,
           7.5301e-03,  1.5321e-01],
         [ 3.4149e-02,  5.6612e-02,  5.0975e-02,  ..., -1.5679e-01,
           6.7212e-02,  3.2317e-01],
         ...,
         [-7.3060e-02,  4

tensor([[[-0.2244, -0.3630,  0.1041,  ..., -0.0884, -0.1901,  0.1079],
         [-0.0502, -0.1083, -0.1124,  ..., -0.1498, -0.3660, -0.1369],
         [-0.2398,  0.0207,  0.2149,  ..., -0.1234,  0.0361,  0.0858],
         ...,
         [-0.0991, -0.0013,  0.3279,  ..., -0.0299, -0.2188,  0.1522],
         [ 0.0038, -0.0779,  0.1522,  ...,  0.0419, -0.0007, -0.0025],
         [-0.0065,  0.0034,  0.0356,  ..., -0.0681, -0.0462,  0.1138]],

        [[-0.2464, -0.4324,  0.0665,  ..., -0.1268, -0.2235, -0.0009],
         [-0.0371, -0.1379,  0.0587,  ..., -0.2621, -0.3976,  0.0192],
         [-0.0981, -0.0098,  0.0567,  ..., -0.2087,  0.0209,  0.0266],
         ...,
         [-0.1268, -0.0203,  0.3210,  ..., -0.0899, -0.2643,  0.1184],
         [-0.0423, -0.1369,  0.1702,  ...,  0.0254, -0.0579, -0.0246],
         [-0.0339, -0.0131,  0.0459,  ..., -0.1187, -0.0569,  0.0939]],

        [[-0.2671, -0.4651,  0.0820,  ..., -0.0994, -0.1677,  0.0216],
         [-0.0583, -0.1789, -0.0922,  ..., -0

tensor([[[-1.4036e-01, -1.7772e-01, -3.7137e-03,  ..., -2.5990e-01,
          -8.7595e-03,  2.2953e-01],
         [ 4.9947e-02, -3.1191e-01, -1.0434e-01,  ..., -5.6602e-01,
           6.0029e-02, -3.4696e-02],
         [-1.2756e-01,  1.2136e-02, -1.6628e-02,  ..., -3.6400e-01,
           1.6900e-01,  1.9367e-01],
         ...,
         [-2.1869e-01,  1.6078e-01,  3.0861e-01,  ..., -2.6700e-01,
           4.2522e-02,  1.6309e-02],
         [-3.0653e-01, -7.7265e-02,  1.0176e-01,  ..., -2.0325e-01,
           1.7100e-01, -2.9731e-02],
         [-5.5702e-02,  1.0547e-01,  1.2851e-01,  ..., -3.3523e-01,
           1.3955e-01,  1.4133e-02]],

        [[-6.9586e-02, -1.2368e-01,  1.6688e-02,  ..., -2.3532e-01,
          -4.0266e-02,  2.5516e-01],
         [ 4.6969e-02, -3.2198e-01, -6.2966e-02,  ..., -3.6379e-01,
          -8.5606e-02,  9.9123e-02],
         [-2.2566e-01,  1.3281e-01,  1.0938e-01,  ..., -2.2738e-01,
           1.5970e-02,  1.7028e-01],
         ...,
         [-1.7631e-01,  1

tensor([[[ 0.0544, -0.1037, -0.2371,  ..., -0.0262, -0.0626,  0.0084],
         [ 0.0912, -0.2579, -0.2194,  ..., -0.3353,  0.0124, -0.0706],
         [-0.0783,  0.0784, -0.0413,  ..., -0.0946,  0.1149,  0.0796],
         ...,
         [-0.0022,  0.0802,  0.0230,  ..., -0.1592,  0.0714, -0.0782],
         [-0.0169, -0.0166, -0.0279,  ..., -0.0846,  0.1988, -0.0257],
         [-0.0103,  0.0466, -0.1892,  ..., -0.1164,  0.0311, -0.0081]],

        [[ 0.0567, -0.0744, -0.1755,  ..., -0.0497, -0.0468,  0.0704],
         [ 0.0881, -0.0038, -0.0593,  ..., -0.1912,  0.0090, -0.0390],
         [-0.1119,  0.1135, -0.0528,  ..., -0.2560,  0.0999,  0.1989],
         ...,
         [ 0.0209,  0.0739,  0.1159,  ..., -0.2523,  0.1202, -0.0508],
         [-0.0148,  0.0352,  0.0937,  ..., -0.1661,  0.2195, -0.0310],
         [-0.0013,  0.0368, -0.0943,  ..., -0.1770,  0.0852, -0.0637]],

        [[ 0.0499, -0.1140, -0.2474,  ..., -0.0963, -0.0826, -0.0250],
         [-0.0513, -0.1647, -0.2196,  ..., -0

tensor([[[-0.1059, -0.2237,  0.0218,  ..., -0.0871, -0.3864,  0.1186],
         [-0.1440, -0.1178,  0.1005,  ..., -0.0729, -0.2255,  0.0193],
         [ 0.0043, -0.1083,  0.2410,  ..., -0.0418,  0.0564,  0.3523],
         ...,
         [-0.0603, -0.0355,  0.2258,  ..., -0.1025, -0.0796,  0.1164],
         [-0.1747, -0.1240,  0.1835,  ..., -0.0646,  0.1587, -0.0456],
         [-0.0519, -0.0881,  0.0737,  ..., -0.2086, -0.0049, -0.0417]],

        [[-0.1224, -0.1766,  0.0491,  ..., -0.1611, -0.4673,  0.1880],
         [-0.0901, -0.3589, -0.0461,  ..., -0.0436, -0.3140,  0.0208],
         [ 0.0018, -0.0780,  0.3811,  ..., -0.0731, -0.0488,  0.2241],
         ...,
         [-0.0707, -0.0785,  0.2302,  ..., -0.1696, -0.1530,  0.1386],
         [-0.2726, -0.1602,  0.1442,  ..., -0.1037,  0.0351, -0.0271],
         [-0.0572, -0.1054,  0.1161,  ..., -0.3065, -0.0688, -0.0409]],

        [[-0.1363, -0.1259,  0.0067,  ..., -0.0985, -0.4000,  0.1856],
         [ 0.1796, -0.1214, -0.0638,  ..., -0

tensor([[[ 0.0868, -0.2036, -0.1606,  ..., -0.1266, -0.2077,  0.1667],
         [-0.1754, -0.1476, -0.1709,  ...,  0.0334, -0.2362,  0.0551],
         [ 0.0058, -0.0022, -0.0095,  ..., -0.0731, -0.2063,  0.3042],
         ...,
         [-0.0040,  0.0150,  0.0765,  ...,  0.0005, -0.2010,  0.1753],
         [-0.1219, -0.1840,  0.0177,  ..., -0.0647, -0.1327,  0.0435],
         [-0.0796,  0.0787, -0.0988,  ..., -0.0457, -0.1113,  0.0358]],

        [[ 0.0743, -0.2422, -0.1173,  ..., -0.0886, -0.1452,  0.1591],
         [ 0.2880, -0.0532, -0.0198,  ..., -0.1924, -0.1260,  0.1050],
         [-0.1537, -0.1058, -0.0733,  ..., -0.1959, -0.0813,  0.0846],
         ...,
         [ 0.0325, -0.0513,  0.0751,  ..., -0.0251, -0.1149,  0.1257],
         [-0.1483, -0.2063,  0.0438,  ..., -0.0972, -0.0763, -0.0788],
         [-0.1295, -0.0455, -0.0750,  ..., -0.0726, -0.1396,  0.0255]],

        [[ 0.0392, -0.1857, -0.0662,  ..., -0.1111, -0.1880,  0.1952],
         [-0.0568, -0.1699, -0.1074,  ..., -0

tensor([[[ 0.0321, -0.2475, -0.0819,  ..., -0.2091, -0.0097,  0.0751],
         [ 0.0188,  0.0991, -0.2194,  ..., -0.1293, -0.1371, -0.0614],
         [ 0.0643,  0.1051,  0.1065,  ..., -0.1865,  0.2015,  0.3184],
         ...,
         [ 0.0571,  0.0405,  0.2806,  ..., -0.0563,  0.1007,  0.2336],
         [-0.0164, -0.1079,  0.1336,  ..., -0.1886,  0.2853,  0.0312],
         [ 0.0177,  0.0707,  0.0099,  ..., -0.3717,  0.2166,  0.0206]],

        [[ 0.0514, -0.2182,  0.0075,  ..., -0.1119,  0.0490,  0.0403],
         [ 0.2590,  0.0244, -0.2824,  ..., -0.2274, -0.0502,  0.0308],
         [ 0.0803, -0.0884,  0.0665,  ..., -0.2749,  0.2689,  0.2219],
         ...,
         [ 0.0770,  0.0279,  0.3447,  ..., -0.0655,  0.1237,  0.1890],
         [ 0.0184, -0.0937,  0.2408,  ..., -0.0831,  0.2155,  0.0333],
         [ 0.0752,  0.1024,  0.0894,  ..., -0.3554,  0.1713,  0.0388]],

        [[ 0.0938, -0.2591,  0.0235,  ..., -0.1474,  0.0749,  0.0388],
         [-0.0368, -0.0798, -0.0895,  ..., -0

tensor([[[-0.0010, -0.0727,  0.1739,  ..., -0.2432, -0.1108,  0.2797],
         [-0.0458, -0.2393,  0.1916,  ..., -0.2769, -0.1185,  0.0786],
         [-0.1563, -0.0423,  0.3463,  ..., -0.2419, -0.0434,  0.3102],
         ...,
         [-0.0653,  0.0423,  0.3404,  ..., -0.2952,  0.0288,  0.2871],
         [-0.1693, -0.0275,  0.1833,  ..., -0.2765,  0.1251,  0.1322],
         [-0.1062,  0.0323,  0.1294,  ..., -0.3425,  0.0028,  0.3041]],

        [[ 0.0420, -0.1228,  0.1980,  ..., -0.1749, -0.0718,  0.2735],
         [ 0.0187, -0.2748,  0.0943,  ..., -0.3047,  0.1674,  0.2651],
         [-0.2113, -0.0621,  0.0992,  ..., -0.3546,  0.0078,  0.2472],
         ...,
         [-0.0920, -0.0052,  0.3450,  ..., -0.2113,  0.0094,  0.2947],
         [-0.1540, -0.0655,  0.1299,  ..., -0.1653,  0.1618,  0.0805],
         [-0.0779, -0.0271,  0.1538,  ..., -0.2938, -0.0139,  0.2955]],

        [[ 0.0543, -0.1761,  0.1305,  ..., -0.1572, -0.1078,  0.2505],
         [-0.0527, -0.2146, -0.0091,  ..., -0

tensor([[[-0.0254, -0.0935,  0.2888,  ..., -0.1248, -0.0810,  0.0425],
         [ 0.4271,  0.0193,  0.0382,  ..., -0.1081, -0.0233, -0.1153],
         [-0.0241, -0.0436,  0.0936,  ..., -0.1790,  0.2092,  0.0149],
         ...,
         [ 0.1291,  0.0998,  0.4418,  ..., -0.1112,  0.0279, -0.0488],
         [-0.0322,  0.0754,  0.2988,  ..., -0.0585,  0.2815,  0.0580],
         [ 0.1027,  0.1675,  0.1840,  ..., -0.2610, -0.0607, -0.1132]],

        [[-0.0195, -0.1452,  0.2940,  ..., -0.1414, -0.1652,  0.0519],
         [ 0.2442, -0.1711,  0.1414,  ..., -0.3683,  0.0247, -0.1903],
         [ 0.0318, -0.0290,  0.0800,  ..., -0.1859,  0.1395,  0.0368],
         ...,
         [ 0.0568,  0.0693,  0.4471,  ..., -0.1043, -0.0250, -0.0424],
         [-0.0726,  0.0551,  0.2847,  ..., -0.0728,  0.1964,  0.0873],
         [ 0.0575,  0.0951,  0.2068,  ..., -0.3076, -0.0683, -0.0660]],

        [[-0.0348, -0.0803,  0.2396,  ..., -0.0813, -0.1237,  0.0913],
         [ 0.0827, -0.1187, -0.0334,  ..., -0

tensor([[[-0.0287, -0.2005, -0.0822,  ..., -0.0939, -0.1217,  0.1027],
         [ 0.0667, -0.4018, -0.0195,  ..., -0.1087, -0.0158, -0.0172],
         [ 0.1768, -0.0440, -0.1440,  ..., -0.0695,  0.1254,  0.0868],
         ...,
         [ 0.0296, -0.0378,  0.2684,  ...,  0.0145,  0.1157, -0.0287],
         [ 0.0768, -0.1212,  0.0237,  ..., -0.0376,  0.2330, -0.0071],
         [ 0.1730, -0.0536,  0.0366,  ..., -0.0747,  0.0550, -0.0413]],

        [[ 0.0085, -0.2694, -0.0638,  ..., -0.0893, -0.1677,  0.1181],
         [ 0.0977, -0.2636, -0.0710,  ..., -0.2756,  0.0048, -0.0285],
         [ 0.0119, -0.1797, -0.1681,  ..., -0.1732,  0.1132,  0.1886],
         ...,
         [ 0.0448, -0.0868,  0.2386,  ...,  0.0091,  0.0381, -0.0043],
         [ 0.0709, -0.1587,  0.0468,  ..., -0.0284,  0.1445,  0.0171],
         [ 0.1526, -0.1776,  0.0199,  ..., -0.0575,  0.0147,  0.0222]],

        [[ 0.0032, -0.2540, -0.0563,  ..., -0.0957, -0.1626,  0.0985],
         [ 0.2448,  0.0630, -0.1033,  ..., -0

tensor([[[ 0.0087, -0.0741, -0.1156,  ..., -0.2295, -0.2517,  0.1010],
         [-0.1284, -0.1251, -0.1893,  ..., -0.4252, -0.2155, -0.1400],
         [-0.1806,  0.0623,  0.1899,  ..., -0.3228, -0.1547,  0.0149],
         ...,
         [ 0.0616,  0.0384,  0.0038,  ..., -0.2342,  0.0057, -0.0137],
         [-0.0536, -0.0588,  0.0314,  ..., -0.2438,  0.1114, -0.1716],
         [-0.0168,  0.0639, -0.0504,  ..., -0.3528,  0.1454, -0.0782]],

        [[ 0.0066, -0.0305, -0.0845,  ..., -0.2223, -0.2075,  0.1205],
         [-0.0924, -0.0071, -0.1201,  ..., -0.2760, -0.1128, -0.0037],
         [-0.1214,  0.1229,  0.0796,  ..., -0.3611, -0.1066,  0.1638],
         ...,
         [ 0.0793,  0.0960,  0.0367,  ..., -0.2107, -0.0403,  0.0367],
         [-0.0265, -0.0129,  0.0208,  ..., -0.2577,  0.0730, -0.1312],
         [-0.0024,  0.0898, -0.0578,  ..., -0.2724,  0.1086, -0.0897]],

        [[ 0.0043, -0.0967, -0.1100,  ..., -0.2502, -0.2060,  0.0819],
         [-0.1069, -0.0923, -0.2615,  ..., -0

tensor([[[-0.2390, -0.0674,  0.0525,  ..., -0.0756, -0.2482,  0.0329],
         [-0.2126, -0.0776, -0.1426,  ..., -0.2763,  0.0609, -0.1783],
         [-0.2132,  0.1924, -0.0894,  ..., -0.1361, -0.1831,  0.1250],
         ...,
         [-0.1325,  0.2935,  0.2044,  ..., -0.1123, -0.1045,  0.2319],
         [-0.1262,  0.0956,  0.0423,  ..., -0.0506,  0.1074,  0.0831],
         [-0.0814,  0.1750,  0.0812,  ..., -0.0677, -0.0948, -0.0405]],

        [[-0.2227, -0.0738,  0.0957,  ..., -0.1199, -0.2438,  0.0269],
         [-0.1533, -0.0316, -0.1187,  ..., -0.2707, -0.0999, -0.0294],
         [-0.2566,  0.2642,  0.2659,  ..., -0.1574, -0.0206,  0.2021],
         ...,
         [-0.1900,  0.3246,  0.1614,  ..., -0.1475, -0.0947,  0.1845],
         [-0.0637,  0.1632, -0.0302,  ..., -0.1573,  0.0503,  0.0310],
         [-0.0653,  0.2018,  0.1066,  ..., -0.1416, -0.0670, -0.0709]],

        [[-0.2231, -0.0736,  0.0891,  ..., -0.0234, -0.2526,  0.0644],
         [-0.1365,  0.0352,  0.0058,  ..., -0

tensor([[[ 1.3214e-01, -5.4318e-02,  5.1706e-02,  ..., -6.8993e-02,
          -6.4121e-02,  5.5526e-02],
         [ 3.5787e-01,  1.2809e-02, -1.1902e-01,  ..., -1.9262e-02,
          -5.1885e-02,  6.0760e-03],
         [ 2.2401e-01, -1.3825e-01, -7.7769e-02,  ..., -9.7204e-04,
           7.5425e-02,  2.4357e-01],
         ...,
         [ 7.5653e-02,  7.7197e-02,  1.0566e-01,  ..., -1.6826e-02,
           1.8034e-01,  1.5889e-01],
         [ 1.0462e-01, -1.2012e-01,  8.0519e-03,  ..., -4.4394e-02,
           1.6293e-01,  4.3087e-03],
         [ 1.4978e-01,  1.4247e-01, -2.8019e-02,  ..., -3.3214e-02,
           9.0474e-02, -1.1291e-01]],

        [[ 7.4560e-02, -6.0263e-02,  8.9569e-02,  ..., -1.4452e-01,
          -1.0218e-02,  4.8827e-02],
         [ 1.6179e-01, -2.1929e-01, -1.4691e-01,  ..., -2.1503e-01,
          -8.0092e-02, -5.3124e-02],
         [ 6.0228e-02, -1.1529e-01,  7.7350e-02,  ...,  2.2745e-01,
          -5.8256e-02,  2.5247e-01],
         ...,
         [ 8.3326e-02,  6

tensor([[[-6.5557e-02, -2.0110e-01, -1.0981e-01,  ..., -6.0361e-02,
          -2.2253e-01,  2.7277e-01],
         [ 8.0152e-02, -2.3183e-01, -2.3966e-01,  ..., -1.6539e-01,
          -1.6436e-01,  2.0333e-02],
         [-6.2312e-02, -1.2381e-01, -1.4818e-02,  ..., -1.3260e-01,
           4.5685e-02,  3.6016e-01],
         ...,
         [-1.3178e-01,  3.0816e-02,  1.6719e-01,  ..., -2.2408e-01,
           1.1338e-01,  1.8584e-01],
         [-1.6642e-01, -1.3750e-01,  8.5831e-03,  ..., -1.3518e-01,
           2.2169e-01,  2.1512e-01],
         [-6.0564e-02, -7.4185e-03, -3.0646e-02,  ..., -2.4849e-01,
           3.6508e-02,  2.0141e-01]],

        [[-2.1106e-01, -2.6583e-01, -1.1607e-02,  ..., -9.9097e-02,
          -2.8504e-01,  1.9755e-01],
         [-8.6660e-02, -1.5857e-01, -2.0408e-01,  ..., -1.0270e-01,
           7.8510e-02, -7.3425e-02],
         [-1.0360e-01,  3.4580e-02, -1.1430e-01,  ..., -6.7908e-02,
          -1.1535e-01,  2.9646e-01],
         ...,
         [-1.9724e-01,  6

tensor([[[-0.1986, -0.4907,  0.0353,  ..., -0.2287, -0.1444,  0.0299],
         [-0.0579, -0.4232, -0.1628,  ..., -0.1678, -0.0597, -0.0291],
         [-0.2629, -0.2477,  0.2901,  ..., -0.0747, -0.0178,  0.2145],
         ...,
         [-0.2073, -0.1418,  0.2125,  ..., -0.0824, -0.0568,  0.1592],
         [-0.1718, -0.3667, -0.0407,  ..., -0.0136,  0.0380, -0.0487],
         [-0.0876, -0.1954,  0.0059,  ..., -0.1270,  0.0781,  0.0652]],

        [[-0.2153, -0.5091,  0.0654,  ..., -0.2164, -0.2129,  0.0845],
         [ 0.2382, -0.3513, -0.0955,  ..., -0.1535, -0.1078,  0.0965],
         [-0.2106, -0.1245,  0.1061,  ..., -0.2925,  0.0307,  0.2588],
         ...,
         [-0.1974, -0.1123,  0.2165,  ..., -0.1663, -0.0491,  0.2010],
         [-0.2134, -0.3492, -0.0077,  ..., -0.0264,  0.0269, -0.0065],
         [-0.1059, -0.1735,  0.0409,  ..., -0.1355,  0.0633,  0.1097]],

        [[-0.2182, -0.4441,  0.0229,  ..., -0.1872, -0.2270,  0.0311],
         [-0.0991, -0.4064, -0.1182,  ..., -0

tensor([[[-1.0678e-01, -1.7020e-01,  9.9393e-02,  ..., -5.0486e-02,
          -3.1164e-02,  2.7956e-01],
         [ 2.1715e-02, -2.0063e-01, -2.2930e-01,  ..., -2.1506e-01,
          -4.7115e-02,  3.7824e-02],
         [ 8.0500e-02,  5.7872e-02,  2.0028e-01,  ..., -8.0905e-04,
           2.1893e-01,  1.6405e-01],
         ...,
         [ 6.7831e-02,  9.5958e-02,  1.0843e-01,  ...,  1.3349e-02,
           3.3987e-01,  2.2393e-01],
         [ 2.0899e-02,  1.7156e-02, -2.5643e-02,  ..., -8.6159e-03,
           3.4667e-01,  7.9283e-02],
         [ 9.8312e-02,  1.6584e-01, -6.2019e-02,  ..., -1.4649e-01,
           2.7048e-01,  9.5084e-02]],

        [[-2.1997e-02, -1.1562e-01,  1.5396e-02,  ..., -1.0971e-01,
           6.2690e-02,  1.7194e-01],
         [ 9.5476e-02, -2.1652e-01, -2.3563e-01,  ..., -2.3363e-01,
           2.8711e-01,  5.9265e-03],
         [ 5.2253e-02,  3.9302e-02, -1.8159e-01,  ..., -6.4795e-02,
           2.0214e-01,  1.9184e-02],
         ...,
         [ 1.5221e-01,  1

tensor([[[-4.8496e-02, -5.3446e-02,  1.4614e-01,  ..., -3.6586e-01,
          -1.5976e-01, -3.8269e-02],
         [ 5.1885e-02, -7.2440e-02, -1.7434e-02,  ..., -5.1509e-01,
           4.1137e-02, -1.2272e-01],
         [-1.0107e-01,  1.9424e-01,  2.1478e-03,  ..., -5.1318e-01,
           1.6277e-01,  1.8634e-01],
         ...,
         [-1.0974e-01,  1.6562e-01,  2.2119e-01,  ..., -3.5374e-01,
          -2.3994e-02,  1.1636e-01],
         [-4.2417e-02,  1.1729e-01,  1.6526e-01,  ..., -2.4555e-01,
           3.1540e-02,  7.1546e-02],
         [ 3.1126e-02,  1.5160e-01,  1.0695e-01,  ..., -3.9925e-01,
           8.3536e-02,  4.8366e-02]],

        [[-5.6911e-02,  3.4640e-02,  7.4069e-02,  ..., -3.2995e-01,
          -1.8203e-01,  8.0709e-03],
         [ 7.2530e-02, -5.4314e-02, -1.2166e-01,  ..., -4.1576e-01,
           3.1568e-02, -1.3659e-01],
         [-8.2678e-02,  3.2998e-01, -2.3085e-02,  ..., -2.4887e-01,
          -2.1753e-02,  1.6434e-02],
         ...,
         [-8.6150e-02,  2

tensor([[[-0.0213,  0.0616, -0.1021,  ..., -0.1512, -0.0934,  0.0888],
         [-0.0008, -0.0235, -0.1631,  ..., -0.3180,  0.2197, -0.1029],
         [ 0.0478,  0.1819, -0.2076,  ..., -0.1377,  0.1240,  0.1259],
         ...,
         [-0.0913,  0.2002,  0.1970,  ..., -0.2526,  0.0491,  0.1427],
         [-0.2130,  0.1312,  0.0377,  ..., -0.2041,  0.1769, -0.0445],
         [-0.0051,  0.2271, -0.0252,  ..., -0.3665,  0.1445, -0.0049]],

        [[-0.0283,  0.0866, -0.0970,  ..., -0.0879, -0.0210,  0.0973],
         [ 0.0528,  0.2139, -0.1347,  ..., -0.2851,  0.1516, -0.0570],
         [ 0.0037,  0.2518,  0.1657,  ..., -0.1669,  0.2155,  0.1492],
         ...,
         [-0.1106,  0.2683,  0.2543,  ..., -0.2591,  0.0367,  0.0799],
         [-0.2066,  0.1944,  0.0806,  ..., -0.2246,  0.1469, -0.1179],
         [-0.0382,  0.2652,  0.0319,  ..., -0.3230,  0.0605, -0.0462]],

        [[-0.0852,  0.1517, -0.1084,  ..., -0.0869, -0.0656,  0.1494],
         [ 0.2428,  0.3217, -0.0971,  ..., -0

In [14]:
#输出对比测试结果
import os
import time

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def print_time_of_model(model):    
    eval_start_time = time.time()
    model(torch.LongTensor([input_ids]), \
                 torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
    eval_end_time = time.time()
    eval_duration_time = eval_end_time - eval_start_time
    print("Evaluate total time (seconds): {0:.1f}".format(eval_duration_time))

    
print_size_of_model(model)
print_size_of_model(quantized_model)

print_time_of_model(model)
print_time_of_model(quantized_model)

Size (MB): 160.844839
Size (MB): 40.572431
tensor([[[-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [ 0.1013,  0.3634,  0.0050,  ..., -0.1247, -0.0733, -0.1555],
         [-0.0312,  0.2100, -0.0643,  ..., -0.1972, -0.0326, -0.1339],
         ...,
         [-0.0090,  0.1465,  0.0373,  ..., -0.2756, -0.2431, -0.1753],
         [-0.0873,  0.1474,  0.2928,  ..., -0.3138, -0.1397, -0.0847],
         [ 0.0250,  0.1036,  0.1591,  ..., -0.1132, -0.0359, -0.0897]]],
       grad_fn=<AddBackward0>)
tensor([[[7, 7, 7,  ..., 7, 7, 7],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]])
tensor([[[ 0.0538,  0.1581,  0.0941,  ..., -0.4250,  0.0572,  0.0348],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3327,  0.0380,  ..., -0.3048, -0.0289, -0.1779],
         [-0.0738,  0.3