## 测试用

In [1]:
import os
import time
import json
import argparse
import torch
import torch.optim as optim

# from models import Model
from models import save_model
from dataset import get_dataset, get_dataloader
from metrics import recall,mean_rank, mean_reciprocal_rank

In [2]:
# Seed the CPU and CUDA random generators so repeated runs start identically.
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
# Use CUDA when it is available and fall back to the CPU otherwise.
# ("gpu" is not a valid torch device type and raises a RuntimeError.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Show the installed torch version and the selected device, one per line.
print(torch.__version__, device, sep="\n")

1.1.0
cuda


In [14]:
'''
 @Date  : 
 @Author: liuyouyuan
 @mail  : liuyouyuan@lizhi.fm
'''
'''
定义模型主体结构
'''
import math
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable

# Use CUDA when it is available and fall back to the CPU otherwise.
# ("gpu" is not a valid torch device type and raises a RuntimeError.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Generator(nn.Module):
    """Project features of width ``d_model`` onto ``vocab`` output scores.

    Only the linear projection is applied; ``forward`` does not take a
    softmax.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab)

    def forward(self, x):
        """Return the projected scores for ``x``."""
        scores = self.proj(x)
        return scores


class TextEncoder(nn.Module):
    """Stack of ``n_block`` TextBlock layers followed by a final LayerNorm."""

    def __init__(self, d_model, d_ff, n_head, dropout, n_block):
        super().__init__()
        blocks = []
        for _ in range(n_block):
            blocks.append(TextBlock(d_model, d_ff, n_head, dropout))
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNorm(d_model)

    def forward(self, x):
        """Run ``x`` through every block in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


class CommentDecoder(nn.Module):
    """Stack of ``n_block`` DecoderBlock layers followed by a final LayerNorm."""

    def __init__(self, d_model, d_ff, n_head, dropout, n_block):
        super().__init__()
        blocks = []
        for _ in range(n_block):
            blocks.append(DecoderBlock(d_model, d_ff, n_head, dropout))
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNorm(d_model)

    def forward(self, x, m, mask):
        """Decode ``x`` block by block, then normalise the result.

        ``m`` and ``mask`` are handed unchanged to every block.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, m, mask)
        return self.norm(hidden)


class TextBlock(nn.Module):
    """Encoder block: self-attention then feed-forward, each in a residual sublayer."""

    def __init__(self, d_model, d_ff, n_head, dropout):
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_head, d_model)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        wrappers = [SublayerConnection(d_model, dropout) for _ in range(2)]
        self.sublayer = nn.ModuleList(wrappers)

    def forward(self, x):
        """Apply the attention sublayer, then the feed-forward sublayer."""
        def attend_to_self(normed):
            # Query, key and value are all the same normalised input.
            return self.self_attn(normed, normed, normed)

        attn_wrapper, ff_wrapper = self.sublayer
        attended = attn_wrapper(x, attend_to_self)
        return ff_wrapper(attended, self.feed_forward)


class DecoderBlock(nn.Module):
    """Decoder block: masked self-attention, attention over ``m``, feed-forward.

    Each of the three steps runs inside its own residual sublayer.
    """

    def __init__(self, d_model, d_ff, n_head, dropout):
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_head, d_model)
        self.text_attn = MultiHeadedAttention(n_head, d_model)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        wrappers = [SublayerConnection(d_model, dropout) for _ in range(3)]
        self.sublayer = nn.ModuleList(wrappers)

    def forward(self, x, m, mask):
        """Return ``x`` after the three sublayers.

        ``mask`` is applied to the self-attention only; the attention over
        ``m`` is called without a mask.
        """
        def masked_self_attention(normed):
            return self.self_attn(normed, normed, normed, mask)

        def memory_attention(normed):
            # Queries come from the decoder; keys and values come from ``m``.
            return self.text_attn(normed, m, m)

        self_wrapper, text_wrapper, ff_wrapper = self.sublayer
        x = self_wrapper(x, masked_self_attention)
        x = text_wrapper(x, memory_attention)
        return ff_wrapper(x, self.feed_forward)


class LayerNorm(nn.Module):
    """Normalise the last dimension with a learnable gain and bias.

    ``a_2`` (initialised to ones) scales and ``b_2`` (initialised to zeros)
    shifts the normalised values; ``eps`` is added to the standard deviation.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2


class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sublayer.
    The input is normalised first, passed through the sublayer and dropout,
    and the result is added back onto the un-normalised input.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update


class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``d_model`` is split evenly across ``h`` heads, so the key and value
    width per head is ``d_model // h``. ``linears`` holds four projections:
    query, key, value and output, in that order. The attention weights of
    the most recent call are kept in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def attention(self, query, key, value, mask=None, dropout=None):
        """Return ``(weighted values, attention weights)``.

        Scores are the query-key dot products divided by the square root of
        the query's last dimension. Positions where ``mask == 0`` are set to
        -1e9 before the softmax; ``dropout`` (if given) is applied to the
        weights.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, and merge the heads again."""
        if mask is not None:
            # Insert a head axis so the same mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            projected = projection(tensor)
            return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        q_proj, k_proj, v_proj, out_proj = self.linears
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        context, self.attn = self.attention(query, key, value, mask=mask,
                                            dropout=self.dropout)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        return out_proj(merged)


class Embeddings(nn.Module):
    """Lookup table mapping ``vocab`` token ids to ``d_model``-wide vectors."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)
        self.d_model = d_model

    def forward(self, x):
        """Return the embedding vectors for the ids in ``x`` (no scaling)."""
        vectors = self.lut(x)
        return vectors


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table is computed once for ``max_len`` positions and registered as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``. Even columns hold
    sines and odd columns hold cosines.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        The table is taken to ``x``'s device instead of being forced onto
        CUDA, so the module also runs on the CPU. Buffers do not require
        gradients, so no wrapper is needed.
        """
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return self.dropout(x)


class PositionalEmb(nn.Module):
    """Add learned position embeddings to the input, then apply dropout.

    ``pe`` is a trainable ``nn.Embedding`` with ``max_len`` rows, not a fixed
    sinusoidal table.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # One trainable vector per position index.
        self.pe = torch.nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``dropout(x + pe(positions))`` for positions ``1..x.size(1)``.

        Positions stay 1-based, as before, so existing checkpoints keep their
        meaning. As a result a sequence of exactly ``max_len`` steps indexes
        past the table.
        """
        # torch.arange replaces the deprecated, end-inclusive torch.range.
        # The indices are built on x's device rather than forced onto CUDA.
        positions = torch.arange(1, x.size(1) + 1, dtype=torch.long, device=x.device)
        x = x + self.pe(positions).unsqueeze(0)
        return self.dropout(x)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: ``w_2(dropout(relu(w_1(x))))``."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w_2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


def subsequent_mask(batch, size):
    """Build a ``(batch, size, size)`` mask that hides later positions.

    Entry ``[b, i, j]`` is true when ``j <= i`` (a position may look at itself
    and at earlier positions) and false otherwise.
    """
    # Lower triangle including the diagonal marks the allowed positions.
    allowed = np.tril(np.ones((batch, size, size), dtype='uint8'))
    return torch.from_numpy(allowed) == 1


class Model(nn.Module):
    """Transformer that encodes a context ``T`` and decodes a comment ``Y``.

    The same embedding (token lookup plus positional encoding) is used for
    the context and for the comment. Token id 0 is ignored by the loss.

    Args:
        n_emb: stored on the instance; not used by the layers built here.
        n_hidden: model width used by the embedding, encoder and decoder.
        vocab_size: number of token ids and width of the output layer.
        dropout: dropout rate handed to the sub-modules.
        d_ff: hidden width of the feed-forward layers.
        n_head: number of attention heads.
        n_block: number of encoder blocks and of decoder blocks.
        debug: when true, ``forward`` prints its intermediate tensors and
            shapes. Defaults to True so the notebook output stays as it was.
    """

    def __init__(self, n_emb, n_hidden, vocab_size, dropout, d_ff, n_head, n_block, debug=True):
        super(Model, self).__init__()
        self.n_emb = n_emb
        self.n_hidden = n_hidden
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.debug = debug
        self.embedding = nn.Sequential(Embeddings(n_hidden, vocab_size), PositionalEncoding(n_hidden, dropout))
        self.text_encoder = TextEncoder(n_hidden, d_ff, n_head, dropout, n_block)
        self.comment_decoder = CommentDecoder(n_hidden, d_ff, n_head, dropout, n_block)
        self.output_layer = nn.Linear(self.n_hidden, self.vocab_size)
        # reduction='none' keeps one loss value per target token; it replaces
        # the deprecated reduce=False / size_average=False pair.
        self.criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=0)

    def _trace(self, *items):
        """Print ``items`` when debug output is enabled."""
        if self.debug:
            print(*items)

    def _decoder_mask(self, Y):
        """Return the look-ahead mask for ``Y[:, :-1]`` on ``Y``'s device."""
        return subsequent_mask(Y.size(0), Y.size(1) - 1).to(Y.device)

    def encode_text(self, T):
        """Embed the context ids ``T`` and run them through the text encoder."""
        x = self.embedding(T)
        out = self.text_encoder(x)
        return out

    def decode(self, Y, T, mask):
        """Return vocabulary scores for comment ids ``Y``.

        ``T`` is the encoded context and ``mask`` the self-attention mask.
        """
        embs = self.embedding(Y)
        out = self.comment_decoder(embs, T, mask)
        out = self.output_layer(out)
        return out

    def forward(self, Y, T):
        """Return the mean per-token cross-entropy of ``Y`` given ``T``.

        ``Y`` is indexed as ``(batch, length)``: ``Y[:, :-1]`` is fed to the
        decoder and ``Y[:, 1:]`` is the target. The mean runs over every
        target position, including ignored ones, which contribute zero.
        """
        enc_text = self.encode_text(T)
        self._trace("encode text:", enc_text)
        self._trace(enc_text.size())
        # Build the mask on Y's device instead of forcing it onto CUDA.
        mask = self._decoder_mask(Y)
        self._trace("mask:", mask, mask.size())
        self._trace(Y.size(0), Y.size(1)-1)
        outs = self.decode(Y[:,:-1], enc_text, mask)
        self._trace("outs:", outs, outs.size())
        # Go time-major so scores and targets flatten in the same order.
        Y = Y.t()
        self._trace("Y.t()", Y)
        self._trace(Y.size())
        self._trace("-"*20)
        outs = outs.transpose(0, 1)
        self._trace("outs.tr", outs)
        self._trace(outs.size())
        self._trace("-"*20)
        self._trace("OSIZE:", outs.contiguous().view(-1, self.vocab_size).size())
        self._trace("YSIZE:", Y[1:].contiguous().view(-1).size())
        loss = self.criterion(outs.contiguous().view(-1, self.vocab_size),
                              Y[1:].contiguous().view(-1))
        self._trace("Loss:", loss, loss.size())
        return torch.mean(loss)

    def ranking(self, Y, T):
        """Return candidate indices ordered by summed token loss, largest first.

        ``Y`` holds one candidate comment per row. ``T`` is a single context
        without a batch dimension; its encoding is repeated for every
        candidate.
        """
        nums = len(Y)
        out_text = self.encode_text(T.unsqueeze(0))
        out_text = out_text.repeat(nums, 1, 1)

        mask = self._decoder_mask(Y)
        outs = self.decode(Y[:, :-1], out_text, mask)

        Y = Y.t()
        outs = outs.transpose(0, 1)

        loss = self.criterion(outs.contiguous().view(-1, self.vocab_size),
                              Y[1:].contiguous().view(-1))

        # Rows are time steps and columns are candidates, so summing over
        # rows gives one total per candidate.
        loss = loss.view(-1, nums).sum(0)
        # NOTE(review): descending order puts the highest-loss candidate
        # first; confirm this is the order the ranking metrics expect.
        return torch.sort(loss, dim=0, descending=True)[1]


def save_model(path, model):
    """Write ``model``'s state dict to ``path`` with ``torch.save``."""
    torch.save(model.state_dict(), path)


In [17]:
def set_parser():
    """Parse the command-line options of train.py and return the namespace."""
    # (flag, type, default, help) for every supported option.
    options = (
        ('--n_emb', int, 512, "Embedding size"),
        ('--n_hidden', int, 512, "Hidden size"),
        ('--d_ff', int, 2048, "Hidden size of Feedforward"),
        ('--n_head', int, 8, "Number of head"),
        ('--n_block', int, 6, "Number of block"),
        ('--batch_size', int, 64, "Batch size"),
        ('--vocab_size', int, 30000, "Vocabulary size"),
        ('--epoch', int, 50, "Number of epoch"),
        ('--report', int, 500, "Number of report interval"),
        ('--lr', float, 3e-4, "Learning rate"),
        ('--dropout', float, 0.1, "Dropout rate"),
        ('--restore', str, '', "Restoring model path"),
        ('--mode', str, 'train', "Train or test"),
        ('--dir', str, 'ckpt', "Checkpoint directory"),
        ('--max_len', int, 20, "Limited length for text"),
        ('--n_com', int, 5, "Number of input comments"),
    )

    parser = argparse.ArgumentParser(description='train.py')
    for flag, kind, default, help_text in options:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    return parser.parse_args()

class Args(object):
    """Settings holder for notebook runs, used in place of parsed CLI options.

    Every attribute has a fixed value except ``mode``, which comes from the
    constructor argument.
    """

    def __init__(self, mode="train"):
        settings = (
            ("n_emb", 512),
            ("n_hidden", 512),
            ("d_ff", 2048),
            ("n_head", 8),
            ("n_block", 6),
            ("batch_size", 64),
            ("vocab_size", 30000),
            ("epoch", 50),
            ("report", 500),
            ("lr", 3e-4),
            ("dropout", 0.1),
            ("restore", ''),
            ("mode", mode),
            ("dir", "./test_dir"),
            ("max_len", 20),
            ("n_com", 5),
        )
        for name, value in settings:
            setattr(self, name, value)

class Config(object):
    """Resolve data file paths, load the vocabulary and prepare the checkpoint dir.

    Args:
        args: settings object. Its ``dir`` attribute is read, and its
            ``vocab_size`` attribute is overwritten with the number of
            entries in the loaded ``word2id`` table.
        data_path: directory holding the data files and ``dicts-30000.json``.

    Raises:
        OSError: if the vocabulary file cannot be opened.
        KeyError: if the vocabulary file lacks ``word2id`` or ``id2word``.
    """

    def __init__(self, args, data_path="data"):
        self.args = args
        self.data_path = data_path
        self.train_path = os.path.join(data_path, "train-context.json")
        self.dev_path = os.path.join(data_path, "dev-candidate.json")
        self.test_path = os.path.join(data_path, "test-candidate.json")
        self.vocab_path = os.path.join(data_path, "dicts-30000.json")
        # Read the vocabulary file once and close it promptly.
        with open(self.vocab_path, 'r', encoding='utf8') as vocab_file:
            vocabs = json.load(vocab_file)
        self.w2i_vocabs = vocabs['word2id']
        self.i2v_vocabs = vocabs['id2word']
        self.args.vocab_size = len(self.w2i_vocabs)
        if not os.path.exists(self.args.dir):
            # makedirs also creates missing parent directories and tolerates
            # the directory appearing between the check and the call.
            os.makedirs(self.args.dir, exist_ok=True)

def train(config):
    """Run a shortened debugging training pass and return the model.

    Loads the train and dev sets from ``config``, builds the model
    (restoring ``args.restore`` when it is set) and trains with Adam. The
    loop stops after 10 batches and after the first epoch; these are debug
    limits of this notebook. Batches and losses are printed as they go.

    Args:
        config: object exposing ``args``, ``train_path``, ``dev_path`` and
            ``w2i_vocabs``.

    Returns:
        The model after the shortened run.
    """
    args = config.args
    train_set = get_dataset(config.train_path, config.w2i_vocabs, config, is_train=True)
    dev_set = get_dataset(config.dev_path, config.w2i_vocabs, config, is_train=False)
    train_batch = get_dataloader(train_set, args.batch_size, is_train=True)
    model = Model(n_emb=args.n_emb, n_hidden=args.n_hidden, vocab_size=args.vocab_size,
                  dropout=args.dropout, d_ff=args.d_ff, n_head=args.n_head, n_block=args.n_block)
    if args.restore != '':
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on the device in use now.
        model_dict = torch.load(args.restore, map_location=device)
        model.load_state_dict(model_dict)
    model.to(device)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,model.parameters()), lr=args.lr)
    best_score = -1000000

    for i in range(args.epoch):
        model.train()
        report_loss, start_time, n_samples = 0, time.time(), 0
        count, total = 0, len(train_set) // args.batch_size + 1
        n = 0
        for batch in train_batch:
            Y, T = batch
            print("-------Y:", Y)
            # Row 63 only exists in a batch of at least 64 samples; skip the
            # sample dump for shorter batches instead of raising IndexError.
            if len(Y) > 63:
                print("Y[63]:", Y[63])
            print("-------T:", T)
            if len(T) > 63:
                print("T[63]:", T[63])
            Y = Y.to(device)
            T = T.to(device)
            print("Y size:", Y.size())
            print("T size:", T.size())
            optimizer.zero_grad()
            loss = model(Y, T)
            print("-----loss:", loss)
            print("*"*30)
            loss.backward()
            optimizer.step()
            report_loss += loss.item()
            n_samples += len(Y.data)
            print("n_samples:", n_samples)
            n += 1
            # Debug limit: stop after 10 batches. The break comes before
            # ``count`` is updated, so the 10th batch is never counted.
            if n ==10:
                break
            count += 1
            if count % args.report == 0 or count == total:
                print('%d/%d, epoch: %d, report_loss: %.3f, time: %.2f'
                      % (count, total, i+1, report_loss / n_samples, time.time() - start_time))
                score = eval(model, dev_set, args.batch_size)
                # eval() switches the model to evaluation mode; switch back.
                model.train()
                if score > best_score:
                    best_score = score
                    save_model(os.path.join(args.dir, 'best_checkpoint.pt'), model)
                else:
                    save_model(os.path.join(args.dir, 'checkpoint.pt'), model)
                report_loss, start_time, n_samples = 0, time.time(), 0
        # Debug limit: stop after the first epoch.
        if i == 0:
            break

    return model


def eval(model, dev_set, batch_size):
    """Sum the model loss over ``dev_set`` and return its negation.

    The model is put into evaluation mode and gradients are disabled. The
    summed loss and the elapsed time are printed. A higher return value
    means a lower total loss.
    """
    print("starting evaluating...")
    began = time.time()
    model.eval()
    loader = get_dataloader(dev_set, batch_size, is_train=False)

    total = 0
    with torch.no_grad():
        for Y, T in loader:
            batch_loss = model(Y.to(device), T.to(device))
            total += batch_loss.item()
    print(total)
    print("evaluting time:", time.time() - began)

    return -total


def test(test_set, model, max_samples=1):
    """Rank candidate comments for test samples and print retrieval metrics.

    For each evaluated sample the candidates are ordered by
    ``model.ranking``, and the ordering is stored with the reference
    ``data['candidate']`` mapping. Recall@1/5/10, mean rank and mean
    reciprocal rank are printed at the end, along with debug dumps per
    sample.

    Args:
        test_set: dataset exposing ``__len__`` and ``get_candidate(i)``,
            which returns ``(Y, T, data)``.
        model: model exposing ``eval()`` and ``ranking(Y, T)``.
        max_samples: how many samples to evaluate. The default of 1 keeps
            the notebook's behaviour of stopping after the first sample;
            pass ``None`` to evaluate the whole set.
    """
    print("starting testing...")
    start_time = time.time()
    model.eval()
    predictions, references = [], []
    n_eval = len(test_set) if max_samples is None else min(len(test_set), max_samples)
    with torch.no_grad():
        for i in range(n_eval):
            Y, T, data = test_set.get_candidate(i)
            print("-------Y:", Y)
            # Index 10 is only a debug probe; skip it for shorter inputs.
            if len(Y) > 10:
                print("Y[10]:", Y[10])
            print("-------T:", T)
            if len(T) > 10:
                print("T[10]:", T[10])
            print("data:", data)
            print("*"*40)
            Y = Y.to(device)
            T = T.to(device)
            ids = model.ranking(Y, T)
            print("ids:", ids)
            print("ids size", ids.size())

            candidate = []
            comments = list(data['candidate'].keys())
            print("comments:", comments)
            for idx in ids:
                # Turn the 0-dim tensor into a plain int before indexing.
                comment = comments[int(idx)]
                print("comments[id]:", comment)
                candidate.append(comment)
            predictions.append(candidate)
            print("prediction candidate:", candidate)
            references.append(data['candidate'])
            print("data candidate:", data['candidate'])
            # Progress marker every 100 samples on longer runs.
            if i > 0 and i % 100 == 0:
                print(i)

    recall_1 = recall(predictions, references, 1)
    recall_5 = recall(predictions, references, 5)
    recall_10 = recall(predictions, references, 10)
    mr = mean_rank(predictions, references)
    mrr = mean_reciprocal_rank(predictions, references)
    s = "r1={}, r5={}, r10={}, mr={}, mrr={}"
    print(s.format(recall_1, recall_5, recall_10, mr, mrr))

    print("testing time:", time.time() - start_time)

In [16]:
# Notebook entry point: use the in-notebook Args instead of the CLI parser.
# args = set_parser()
args = Args("train")
config = Config(args, data_path="../data")

print("mode:", args.mode)

if args.mode == 'train':
    train(config)
else:
    test_set = get_dataset(config.test_path, config.w2i_vocabs, config, is_train=False)
    model = Model(n_emb=args.n_emb, n_hidden=args.n_hidden, vocab_size=args.vocab_size,
              dropout=args.dropout, d_ff=args.d_ff, n_head=args.n_head, n_block=args.n_block)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be restored on the device in use now.
    model_dict = torch.load(args.restore, map_location=device)
    model.load_state_dict(model_dict)
    model.to(device)
    test(test_set, model)

mode: train
starting load...
loading time: 3.001723051071167
starting load...
loading time: 0.10674905776977539
-------Y: tensor([[   1, 1301, 3641,  ...,    0,    0,    0],
        [   1,    3,   16,  ...,    0,    0,    0],
        [   1, 4890,  340,  ...,    0,    0,    0],
        ...,
        [   1,  676, 3329,  ...,    0,    0,    0],
        [   1, 6615,  285,  ...,    0,    0,    0],
        [   1,    9,   48,  ...,    0,    0,    0]])
Y[63]: tensor([   1,    9,   48, 4483,    5,   78, 2083,  227,  227,  227,    2,    0,
           0,    0,    0,    0,    0,    0,    0,    0])
-------T: tensor([[   1,   71, 7314,  ...,    0,    0,    0],
        [   1,  501,    4,  ...,    0,    0,    0],
        [   1,    3,    4,  ...,    0,    0,    0],
        ...,
        [   1,  173,  434,  ...,    0,    0,    0],
        [   1,    3,  145,  ...,    0,    0,    0],
        [   1,  225,   13,  ...,    0,    0,    0]])
T[63]: tensor([    1,   225,    13,   146,    89,     7,     4,  4483, 1

outs.tr tensor([[[ 0.7030, -0.4805,  0.5400,  ..., -0.1906, -0.7496,  0.3430],
         [ 0.3052,  0.0879,  0.2838,  ..., -0.6523, -0.9702,  0.1908],
         [ 0.0373, -0.1900,  0.1614,  ..., -0.3116, -0.5760,  0.2814],
         ...,
         [-0.1124,  0.0360,  0.0865,  ...,  0.0020, -0.3929,  0.7360],
         [-0.0458, -0.1647,  0.7215,  ..., -0.3561, -0.8176,  0.0514],
         [ 0.3224,  0.1452,  0.1754,  ..., -0.2537,  0.0020,  0.1559]],

        [[ 0.6695, -0.1954, -0.5768,  ...,  1.0173,  0.2118,  0.8079],
         [ 0.1219,  0.0592, -0.0607,  ...,  1.2508, -0.4211,  0.1326],
         [ 0.6548,  0.1961, -0.8053,  ...,  0.2787, -0.3840,  0.6212],
         ...,
         [-0.1409, -0.3290, -0.3193,  ...,  0.4420, -0.4849,  0.0032],
         [ 0.6757, -0.4460, -0.4227,  ...,  0.3650,  0.2310, -0.3567],
         [ 0.2435,  0.4576, -0.4429,  ..., -0.1383,  0.6133,  0.4758]],

        [[ 0.6593, -0.4256,  0.9042,  ...,  0.8581, -0.0634, -0.0973],
         [ 0.8777,  0.5459, -0.1095, 

mask: tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
      

n_samples: 128
-------Y: tensor([[    1,   455,   455,  ...,     2,     0,     0],
        [    1,  2070,  4025,  ...,     0,     0,     0],
        [    1, 11583,     5,  ...,     0,     0,     0],
        ...,
        [    1,  1451,   276,  ...,     0,     0,     0],
        [    1,    33,    12,  ...,     0,     0,     0],
        [    1,     3,     3,  ...,     0,     0,     0]])
Y[63]: tensor([1, 3, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
-------T: tensor([[   1,  455,  455,  ...,  455,  455,    2],
        [   1,   24,  339,  ...,    0,    0,    0],
        [   1,  902, 1766,  ...,    0,    0,    0],
        ...,
        [   1,   85,   83,  ...,    0,    0,    0],
        [   1,  119,   85,  ...,    0,    0,    0],
        [   1,  257, 4717,  ...,    0,    0,    0]])
T[63]: tensor([    1,   257,  4717,   479,     5,     4,   303,    26,  8920,    24,
          743,   163,     4,   912,     4,   629,  1924,    30,    29,    10,
            4,   317,  1851,  1241,   

outs.tr tensor([[[-0.5599, -0.0340,  8.0108,  ..., -0.3983, -0.0898, -0.7315],
         [-0.6652, -0.2180,  8.4186,  ..., -0.4036, -0.2755, -0.5956],
         [-0.6804, -0.1674,  8.2724,  ..., -0.2529, -0.4348, -0.4464],
         ...,
         [-0.5859, -0.1681,  8.3683,  ..., -0.3181, -0.1877, -0.7049],
         [-0.8621, -0.1218,  8.4116,  ..., -0.6865, -0.3321, -0.5952],
         [-0.5632, -0.0843,  8.2205,  ..., -0.5197, -0.0678, -0.8492]],

        [[-0.5007, -0.5252,  8.5991,  ...,  0.0204, -0.2445, -0.5814],
         [-0.4123, -0.1236,  8.8284,  ..., -0.3509,  0.0114, -0.4810],
         [-0.3814, -0.0496,  8.7528,  ..., -0.5115, -0.1267, -0.6292],
         ...,
         [-0.3221, -0.3641,  8.5654,  ..., -0.1025, -0.1975, -0.7960],
         [-0.4785, -0.3198,  8.7589,  ..., -0.2232, -0.1424, -0.4703],
         [-0.6885, -0.3132,  8.9689,  ...,  0.1262, -0.2491, -0.1828]],

        [[-0.4441, -0.5829,  8.7403,  ...,  0.1422, -0.3530, -0.6025],
         [-0.5871, -0.0339,  8.7270, 

outs.tr tensor([[[-7.0789e-01, -3.3763e-01,  7.4684e+00,  ..., -4.7808e-02,
          -2.3631e-01, -9.0275e-01],
         [-5.1336e-01, -3.0271e-02,  7.5239e+00,  ..., -2.8536e-01,
          -3.8818e-01, -9.1280e-01],
         [-3.9732e-01, -3.7480e-01,  7.2771e+00,  ..., -1.6193e-01,
          -1.7415e-01, -1.0615e+00],
         ...,
         [-5.9846e-01, -3.2175e-01,  7.2506e+00,  ..., -2.2809e-01,
          -2.2195e-01, -1.1279e+00],
         [-6.2322e-01, -3.2996e-01,  7.2651e+00,  ..., -2.8500e-01,
          -3.4086e-01, -9.9605e-01],
         [-6.5649e-01, -2.0827e-01,  7.1392e+00,  ..., -2.1996e-01,
          -9.6384e-02, -9.4005e-01]],

        [[-8.9303e-01, -5.2885e-01,  7.8553e+00,  ...,  5.4307e-01,
          -4.3712e-01, -8.2020e-01],
         [-3.2141e-01, -3.7350e-01,  7.9409e+00,  ...,  4.7193e-02,
           1.5336e-01, -1.1159e+00],
         [-6.3934e-01, -3.8216e-01,  7.8821e+00,  ..., -1.9809e-02,
           9.5603e-02, -6.4535e-01],
         ...,
         [-3.8625

mask: tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
      

n_samples: 320
-------Y: tensor([[   1,   23, 1764,  ...,    0,    0,    0],
        [   1, 1962, 2329,  ...,    0,    0,    0],
        [   1,  597,   23,  ...,    0,    0,    0],
        ...,
        [   1,    3,    2,  ...,    0,    0,    0],
        [   1,  103,    5,  ...,    0,    0,    0],
        [   1,    3,    2,  ...,    0,    0,    0]])
Y[63]: tensor([1, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
-------T: tensor([[   1,   23, 1764,  ...,    0,    0,    0],
        [   1,  999,    3,  ...,    0,    0,    0],
        [   1,  261,    7,  ...,    0,    0,    0],
        ...,
        [   1,   33,    5,  ...,    0,    0,    0],
        [   1,  818,  453,  ...,    0,    0,    0],
        [   1,    3,    4,  ...,    0,    0,    0]])
T[63]: tensor([   1,    3,    4, 6669,    4,    3,    4, 2435,    4,    3,    4,  328,
           4,  151,    4,  125,    2,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,   

outs.tr tensor([[[-0.7017, -0.5310,  7.2711,  ..., -0.1230, -0.2783, -1.2039],
         [-0.5522, -0.4688,  7.2623,  ..., -0.2889, -0.3102, -1.3352],
         [-0.6510, -0.4771,  7.2141,  ..., -0.3753, -0.3174, -1.3882],
         ...,
         [-0.7982, -0.4820,  7.2374,  ..., -0.1853, -0.1657, -1.3591],
         [-0.5521, -0.4899,  7.1893,  ..., -0.1516, -0.0212, -1.1911],
         [-0.7913, -0.3478,  7.2788,  ..., -0.2848, -0.1305, -1.3208]],

        [[-0.5966, -0.5082,  7.7879,  ...,  0.0402, -0.0161, -1.1516],
         [-0.7037, -0.6040,  7.9192,  ..., -0.1253, -0.0630, -1.1228],
         [-0.6123, -0.3809,  7.7536,  ...,  0.1259,  0.1516, -0.8212],
         ...,
         [-0.7939, -0.5953,  7.8853,  ...,  0.0904, -0.2879, -0.9200],
         [-0.5299, -0.4509,  7.8303,  ...,  0.1261, -0.0236, -0.9113],
         [-0.8235, -0.4682,  8.0287,  ...,  0.1389, -0.2113, -0.9835]],

        [[-0.6172, -0.6087,  7.8301,  ..., -0.2962, -0.0505, -1.4283],
         [-0.6447, -0.5318,  7.9596, 

outs.tr tensor([[[-7.1115e-01, -5.8468e-01,  7.5689e+00,  ..., -3.8915e-01,
          -3.2984e-01, -1.4253e+00],
         [-8.1196e-01, -5.1145e-01,  7.6192e+00,  ..., -2.6243e-01,
          -2.5271e-01, -1.2620e+00],
         [-9.1993e-01, -4.9035e-01,  7.5953e+00,  ..., -2.0172e-01,
          -3.6427e-01, -1.3054e+00],
         ...,
         [-7.7187e-01, -4.7820e-01,  7.6495e+00,  ..., -2.1083e-01,
          -3.6777e-01, -1.3836e+00],
         [-8.3145e-01, -5.3505e-01,  7.6414e+00,  ..., -1.3356e-01,
          -4.2843e-01, -1.3168e+00],
         [-7.3792e-01, -6.7506e-01,  7.5336e+00,  ..., -2.0542e-01,
          -4.3589e-01, -1.4636e+00]],

        [[-8.2277e-01, -7.7271e-01,  8.1723e+00,  ..., -4.8633e-02,
          -3.0011e-01, -1.3747e+00],
         [-8.6409e-01, -7.9642e-01,  8.0776e+00,  ..., -1.7210e-01,
          -4.2621e-01, -1.0708e+00],
         [-6.9008e-01, -5.8346e-01,  8.0915e+00,  ..., -9.7044e-02,
          -5.3622e-03, -1.3579e+00],
         ...,
         [-6.2337

mask: tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
      

mask: tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
      

mask: tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 0, 0],
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
      

- Y 是真实值，在训练和测试的时候都只有一句话；
- T 是上下文评论，训练4-8句话；测试只有4句话？

In [18]:
# Notebook entry point for test mode: restore the best checkpoint and rank.
# args = set_parser()
args = Args("test")
args.restore = "../ckpt/best_checkpoint.pt"
config = Config(args, data_path="../data")

print("mode:", args.mode)

if args.mode == 'train':
    train(config)
else:
    test_set = get_dataset(config.test_path, config.w2i_vocabs, config, is_train=False)
    model = Model(n_emb=args.n_emb, n_hidden=args.n_hidden, vocab_size=args.vocab_size,
              dropout=args.dropout, d_ff=args.d_ff, n_head=args.n_head, n_block=args.n_block)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be restored on the device in use now.
    model_dict = torch.load(args.restore, map_location=device)
    model.load_state_dict(model_dict)
    model.to(device)
    test(test_set, model)

mode: test
starting load...
loading time: 0.14416289329528809




starting testing...
-------Y: tensor([[   1,  238,    2,  ...,    0,    0,    0],
        [   1,  260,    2,  ...,    0,    0,    0],
        [   1,  351,    2,  ...,    0,    0,    0],
        ...,
        [   1,   72, 2685,  ...,    0,    0,    0],
        [   1, 2509,  102,  ...,    0,    0,    0],
        [   1,    3,   13,  ...,    0,    0,    0]])
Y[10]: tensor([    1, 22693,   175, 19758, 12896, 12291,   175,   807,     3,  3143,
            2,     0,     0,     0,     0,     0,     0,     0,     0,     0])
-------T: tensor([   1, 2743,  137, 1152,    6,    9,  126,  703,    8,    4,  137,    4,
        2310,  489, 5968,    4,    3,    2,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,   