In [1]:
import json
import torch

from transformers import BertTokenizer, AdamW
from collections import defaultdict
from random import choice

class Config:
    """
    句子最长长度是294 这里就不设参数限制长度了,每个batch 自适应长度
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_path = '/data/sunyd/pretrained_models/bert-base-chinese/'


        self.train_data_path = 'data/train.json'
        self.dev_data_path = 'data/dev.json'
        self.test_data_path = 'data/test.json'

        self.batch_size = 1

        self.rel_dict_path = 'data/rel.json'
        self.id2rel = json.load(open(self.rel_dict_path, encoding='utf8'))
        self.rel2id = {v: k for k, v in self.id2rel.items()} # 关系到id
        self.num_rel = len(self.rel2id)  # 关系的种类数
    

        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)

        
        self.learning_rate = 1e-5
        self.bert_dim = 768
        self.epochs = 10


In [2]:
from torch.utils.data import Dataset, DataLoader
import json


def collate_fn(batch):
    #  batch是一个列表，其中是一个一个的元组，每个元组是dataset中_getitem__的结果
    batch = list(zip(*batch))
    text = batch[0]
    triple = batch[1]
    del batch
    return text, triple


class MyDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        self.dataset = []
        with open(path, encoding='utf8') as F:
            for line in F:
                line = json.loads(line)
                self.dataset.append(line)

    def __getitem__(self, item):
        content = self.dataset[item]
        text = content['text']
        spo_list = content['spo_list']
        return text, spo_list

    def __len__(self):
        return len(self.dataset)


def create_data_iter(config):
    train_data = MyDataset(config.train_data_path)
    dev_data = MyDataset(config.dev_data_path)
    test_data = MyDataset(config.test_data_path)

    train_iter = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
    dev_iter = DataLoader(dev_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
    test_iter = DataLoader(test_data, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn)

    return train_iter, dev_iter, test_iter


class Batch:
    def __init__(self, config):
        self.tokenizer = config.tokenizer
        self.num_relations = config.num_rel
        self.rel2id = config.rel2id
        self.device = config.device

    def __call__(self, text, triple):
        text = self.tokenizer(text, padding=True).data
        batch_size = len(text['input_ids'])
        seq_len = len(text['input_ids'][0])
        sub_head = []
        sub_tail = []
        sub_heads = []
        sub_tails = []
        obj_heads = []
        obj_tails = []
        sub_len = []
        sub_head2tail = []

        for batch_index in range(batch_size):
            inner_input_ids = text['input_ids'][batch_index]  # 单个句子变成索引后
            inner_triples = triple[batch_index]
            inner_sub_heads, inner_sub_tails, inner_sub_head, inner_sub_tail, inner_sub_head2tail, inner_sub_len, inner_obj_heads, inner_obj_tails = \
                self.create_label(inner_triples, inner_input_ids, seq_len)
            sub_head.append(inner_sub_head)
            sub_tail.append(inner_sub_tail)
            sub_len.append(inner_sub_len)
            sub_head2tail.append(inner_sub_head2tail)
            sub_heads.append(inner_sub_heads)
            sub_tails.append(inner_sub_tails)
            obj_heads.append(inner_obj_heads)
            obj_tails.append(inner_obj_tails)

        input_ids = torch.tensor(text['input_ids']).to(self.device)
        mask = torch.tensor(text['attention_mask']).to(self.device)
        sub_head = torch.stack(sub_head).to(self.device)
        sub_tail = torch.stack(sub_tail).to(self.device)
        sub_heads = torch.stack(sub_heads).to(self.device)
        sub_tails = torch.stack(sub_tails).to(self.device)
        sub_len = torch.stack(sub_len).to(self.device)
        sub_head2tail = torch.stack(sub_head2tail).to(self.device)
        obj_heads = torch.stack(obj_heads).to(self.device)
        obj_tails = torch.stack(obj_tails).to(self.device)
        return {
                   'input_ids': input_ids,
                   'mask': mask,
                   'sub_head2tail': sub_head2tail,
                   'sub_len': sub_len
               }, {
                   'sub_heads': sub_heads,
                   'sub_tails': sub_tails,
                   'obj_heads': obj_heads,
                   'obj_tails': obj_tails
               }

    def create_label(self, inner_triples, inner_input_ids, seq_len):

        inner_sub_heads, inner_sub_tails = torch.zeros(seq_len), torch.zeros(seq_len)
        inner_sub_head, inner_sub_tail = torch.zeros(seq_len), torch.zeros(seq_len)
        
        inner_obj_heads = torch.zeros((seq_len, self.num_relations))
        inner_obj_tails = torch.zeros((seq_len, self.num_relations))
        
        inner_sub_head2tail = torch.zeros(seq_len)  # 随机抽取一个实体，从开头一个词到末尾词的索引

        # 因为数据预处理代码还待优化,会有不存在关系三元组的情况，
        # 初始化一个主词的长度为1，即没有主词默认主词长度为1，
        # 防止零除报错,初始化任何非零数字都可以，没有主词分子是全零矩阵
        inner_sub_len = torch.tensor([1], dtype=torch.float)
        # 主词到谓词的映射
        s2ro_map = defaultdict(list)
        for inner_triple in inner_triples:

            inner_triple = (
                self.tokenizer(inner_triple['subject'], add_special_tokens=False)['input_ids'],
                # self.rel_vocab.to_index(inner_triple['predicate']),
                 eval(self.rel2id[inner_triple['predicate']]),
                self.tokenizer(inner_triple['object'], add_special_tokens=False)['input_ids']
            )

            sub_head_idx = self.find_head_idx(inner_input_ids, inner_triple[0])
            obj_head_idx = self.find_head_idx(inner_input_ids, inner_triple[2])

            if sub_head_idx != -1 and obj_head_idx != -1:
                sub = (sub_head_idx, sub_head_idx + len(inner_triple[0]) - 1)
                # s2ro_map保存主语到谓语的映射
                s2ro_map[sub].append(
                    (obj_head_idx, obj_head_idx + len(inner_triple[2]) - 1, inner_triple[1]))  # {(3,5):[(7,8,0)]} 0是关系

        if s2ro_map:
            for s in s2ro_map:
                inner_sub_heads[s[0]] = 1
                inner_sub_tails[s[1]] = 1

            sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
            inner_sub_head[sub_head_idx] = 1
            inner_sub_tail[sub_tail_idx] = 1
            inner_sub_head2tail[sub_head_idx:sub_tail_idx + 1] = 1
            inner_sub_len = torch.tensor([sub_tail_idx + 1 - sub_head_idx], dtype=torch.float)
            for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
                inner_obj_heads[ro[0]][ro[2]] = 1
                inner_obj_tails[ro[1]][ro[2]] = 1

        return inner_sub_heads, inner_sub_tails, inner_sub_head, inner_sub_tail, inner_sub_head2tail, inner_sub_len, inner_obj_heads, inner_obj_tails

    @staticmethod
    def find_head_idx(source, target):
        target_len = len(target)
        for i in range(len(source)):
            if source[i: i + target_len] == target:
                return i
        return -1



In [3]:
import torch.nn as nn
import torch
from transformers import BertModel


class CasRel(nn.Module):
    def __init__(self, config):
        super(CasRel, self).__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(self.config.bert_path)
        self.sub_heads_linear = nn.Linear(self.config.bert_dim, 1)
        self.sub_tails_linear = nn.Linear(self.config.bert_dim, 1)
        self.obj_heads_linear = nn.Linear(self.config.bert_dim, self.config.num_rel)
        self.obj_tails_linear = nn.Linear(self.config.bert_dim, self.config.num_rel)
        self.alpha = 0.25
        self.gamma = 2

    def get_encoded_text(self, token_ids, mask):
        encoded_text = self.bert(token_ids, attention_mask=mask)[0]
        return encoded_text

    def get_subs(self, encoded_text):
        pred_sub_heads = torch.sigmoid(self.sub_heads_linear(encoded_text))
        pred_sub_tails = torch.sigmoid(self.sub_tails_linear(encoded_text))
        return pred_sub_heads, pred_sub_tails

    def get_objs_for_specific_sub(self, sub_head2tail, sub_len, encoded_text):
        # sub_head_mapping [batch, 1, seq] * encoded_text [batch, seq, dim]
        sub = torch.matmul(sub_head2tail, encoded_text)  # batch size,1,dim
        sub_len = sub_len.unsqueeze(1)
        sub = sub / sub_len  # batch size, 1,dim
        encoded_text = encoded_text + sub
        #  [batch size, seq len,bert_dim] -->[batch size, seq len,relathion counts]
        pred_obj_heads = torch.sigmoid(self.obj_heads_linear(encoded_text))
        pred_obj_tails = torch.sigmoid(self.obj_tails_linear(encoded_text))
        return pred_obj_heads, pred_obj_tails

    def forward(self, input_ids, mask, sub_head2tail, sub_len):
        """

        :param token_ids:[batch size, seq len]
        :param mask:[batch size, seq len]
        :param sub_head:[batch size, seq len]
        :param sub_tail:[batch size, seq len]
        :return:
        """
        encoded_text = self.get_encoded_text(input_ids, mask)
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        sub_head2tail = sub_head2tail.unsqueeze(1)  # [[batch size,1, seq len]]
        pred_obj_heads, pre_obj_tails = self.get_objs_for_specific_sub(sub_head2tail, sub_len, encoded_text)

        return {
            "pred_sub_heads": pred_sub_heads,
            "pred_sub_tails": pred_sub_tails,
            "pred_obj_heads": pred_obj_heads,
            "pred_obj_tails": pre_obj_tails,
            'mask': mask
        }

    def compute_loss(self, pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails, mask, sub_heads,
                     sub_tails, obj_heads, obj_tails):
        rel_count = obj_heads.shape[-1]
        rel_mask = mask.unsqueeze(-1).repeat(1, 1, rel_count)
        loss_1 = self.loss_fun(pred_sub_heads, sub_heads, mask)
        loss_2 = self.loss_fun(pred_sub_tails, sub_tails, mask)
        loss_3 = self.loss_fun(pred_obj_heads, obj_heads, rel_mask)
        loss_4 = self.loss_fun(pred_obj_tails, obj_tails, rel_mask)
        return loss_1 + loss_2 + loss_3 + loss_4

    def loss_fun(self, logist, label, mask):
        count = torch.sum(mask)
        logist = logist.view(-1)
        label = label.view(-1)
        mask = mask.view(-1)
        
        alpha_factor = torch.where(torch.eq(label,1), 1- self.alpha,self.alpha)
        focal_weight = torch.where(torch.eq(label,1),1-logist,logist)
        
        loss = -(torch.log(logist) * label + torch.log(1 - logist) * (1 - label)) * mask
        return torch.sum(focal_weight * loss) / count
    
    def predict(self, input_ids, mask):
        encoded_text = self.get_encoded_text(input_ids, mask)
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        pred_sub_heads = convert_score_to_zero_one(pred_sub_heads)
        pred_sub_tails = convert_score_to_zero_one(pred_sub_tails)
        subs = extract_sub(pred_sub_heads.squeeze(), pred_sub_tails.squeeze())
        res = []
        for sub in subs:
            # print('sub: ', sub)
            sub_text = ''.join(config.tokenizer.convert_ids_to_tokens(input_ids[0][sub[0]: sub[1] + 1]))
            # print('sub_text:', sub_text)
            sub_head2tail = torch.zeros(len(encoded_text[0]))
            # print('sub_head2tail:', sub_head2tail)
            sub_head2tail[sub[0]: sub[1] + 1] = 1
            # print('sub_head2tail:', sub_head2tail)
            sub_len = torch.tensor([sub[1] - sub[0] + 1])
            pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head2tail, sub_len, encoded_text)
            pred_obj_heads = convert_score_to_zero_one(pred_obj_heads)
            pred_obj_tails = convert_score_to_zero_one(pred_obj_tails)
            pred_ojbs = extract_obj_and_rel(pred_obj_heads.squeeze(), pred_obj_tails.squeeze())
            for obj in pred_ojbs:
                # print('obj:', obj)
                obj_text = ''.join(config.tokenizer.convert_ids_to_tokens(input_ids[0][obj[1]: obj[2] + 1]))
                # print('obj_text:', obj_text)
                # print('relation:', config.id2rel[str(obj[0])])
                res.append((config.id2rel[str(obj[0])], (sub_text, sub), (obj_text, obj[1: 3])))
        
        return res


In [4]:
def load_model(config):
    device = config.device
    model = CasRel(config)
    model.to(device)

    # prepare optimzier
    param_optimizer = list(model.named_parameters())

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate, eps=10e-8)
    sheduler = None

    return model, optimizer, sheduler, device


In [5]:
import pandas as pd
from tqdm import tqdm

def train_epoch(model, train_iter, dev_iter, optimizer, batch, best_triple_f1, epoch):
    for step, (text, triple) in enumerate(train_iter):
        model.train()
        inputs, labels = batch(text, triple)
        logist = model(**inputs)
        loss = model.compute_loss(**logist, **labels)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 500 == 1:
            sub_precision, sub_recall, sub_f1, triple_precision, triple_recall, triple_f1, df = test(model, dev_iter,
                                                                                                     batch)
            if triple_f1 > best_triple_f1:
                best_triple_f1 = triple_f1
                torch.save(model.state_dict(), 'best_f1.pth')
                print(
                    'epoch:{},step:{},sub_precision:{:.4f}, sub_recall:{:.4f}, sub_f1:{:.4f}, triple_precision:{:.4f}, triple_recall:{:.4f}, triple_f1:{:.4f},train loss:{:.4f}'.format(
                        epoch, step, sub_precision, sub_recall, sub_f1, triple_precision, triple_recall, triple_f1,
                        loss.item()))
                print(df)

    return best_triple_f1


def train(model, train_iter, dev_iter, optimizer, config):
    epochs = config.epochs
    best_triple_f1 = 0
    for epoch in range(epochs):
        best_triple_f1 = train_epoch(model, train_iter, dev_iter, optimizer, batch, best_triple_f1, epoch)


def test(model, dev_iter, batch):
    model.eval()
    df = pd.DataFrame(columns=['TP', 'PRED', "REAL", 'p', 'r', 'f1'], index=['sub', 'triple'])
    df.fillna(0, inplace=True)

    for text, triple in tqdm(dev_iter):
        inputs, labels = batch(text, triple)
        logist = model(**inputs)
        
        pred_sub_heads = convert_score_to_zero_one(logist['pred_sub_heads'])
        pred_sub_tails = convert_score_to_zero_one(logist['pred_sub_tails'])

        sub_heads = convert_score_to_zero_one(labels['sub_heads'])
        sub_tails = convert_score_to_zero_one(labels['sub_tails'])
        batch_size = inputs['input_ids'].shape[0]

        obj_heads = convert_score_to_zero_one(labels['obj_heads'])
        obj_tails = convert_score_to_zero_one(labels['obj_tails'])
        pred_obj_heads = convert_score_to_zero_one(logist['pred_obj_heads'])
        pred_obj_tails = convert_score_to_zero_one(logist['pred_obj_tails'])

        for batch_index in range(batch_size):
            pred_subs = extract_sub(pred_sub_heads[batch_index].squeeze(), pred_sub_tails[batch_index].squeeze())
            true_subs = extract_sub(sub_heads[batch_index].squeeze(), sub_tails[batch_index].squeeze())

            pred_ojbs = extract_obj_and_rel(pred_obj_heads[batch_index], pred_obj_tails[batch_index])
            true_objs = extract_obj_and_rel(obj_heads[batch_index], obj_tails[batch_index])

            df['PRED']['sub'] += len(pred_subs)
            df['REAL']['sub'] += len(true_subs)
            for true_sub in true_subs:
                if true_sub in pred_subs:
                    df['TP']['sub'] += 1

            df['PRED']['triple'] += len(pred_ojbs)
            df['REAL']['triple'] += len(true_objs)
            for true_obj in true_objs:
                if true_obj in pred_ojbs:
                    df['TP']['triple'] += 1

    df.loc['sub','p'] = df['TP']['sub'] / (df['PRED']['sub'] + 1e-9)
    df.loc['sub','r'] = df['TP']['sub'] / (df['REAL']['sub'] + 1e-9)
    df.loc['sub','f1'] = 2 * df['p']['sub'] * df['r']['sub'] / (df['p']['sub'] + df['r']['sub'] + 1e-9)
    
    sub_precision = df['TP']['sub'] / (df['PRED']['sub'] + 1e-9)
    sub_recall = df['TP']['sub'] / (df['REAL']['sub'] + 1e-9)
    sub_f1 = 2 * sub_precision * sub_recall  / (sub_precision + sub_recall  + 1e-9)

    df.loc['triple','p'] = df['TP']['triple'] / (df['PRED']['triple'] + 1e-9)
    df.loc['triple','r'] = df['TP']['triple'] / (df['REAL']['triple'] + 1e-9)
    df.loc['triple','f1'] = 2 * df['p']['triple'] * df['r']['triple'] / (
            df['p']['triple'] + df['r']['triple'] + 1e-9)
    
    
    triple_precision = df['TP']['triple'] / (df['PRED']['triple'] + 1e-9)
    triple_recall = df['TP']['triple'] / (df['REAL']['triple'] + 1e-9)
    triple_f1 = 2 * triple_precision * triple_recall / (
            triple_precision + triple_recall + 1e-9)

    return sub_precision, sub_recall,sub_f1, triple_precision, triple_recall, triple_f1, df


def extract_sub(pred_sub_heads, pred_sub_tails):
    subs = []
    heads = torch.arange(0, len(pred_sub_heads))[pred_sub_heads == 1]
    tails = torch.arange(0, len(pred_sub_tails))[pred_sub_tails == 1]

    for head, tail in zip(heads, tails):
        if tail >= head:
            subs.append((head.item(), tail.item()))
    return subs


def extract_obj_and_rel(obj_heads, obj_tails):
    obj_heads = obj_heads.T
    obj_tails = obj_tails.T
    rel_count = obj_heads.shape[0]
    obj_and_rels = []  # [(rel_index,strart_index,end_index),(rel_index,strart_index,end_index)]

    for rel_index in range(rel_count):
        obj_head = obj_heads[rel_index]
        obj_tail = obj_tails[rel_index]

        objs = extract_sub(obj_head, obj_tail)
        if objs:
            for obj in objs:
                start_index, end_index = obj
                obj_and_rels.append((rel_index, start_index, end_index))
    return obj_and_rels


def convert_score_to_zero_one(tensor):
    tensor[tensor>=0.5] = 1
    tensor[tensor<0.5] = 0
    return tensor

In [None]:
if __name__ == '__main__':
    config = Config()
    train_data = MyDataset(config.train_data_path)
    model, optimizer, sheduler, device = load_model(config)
    train_iter, dev_iter, test_iter = create_data_iter(config)
    batch = Batch(config)
    train(model, train_iter, dev_iter, optimizer, config)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [11:53<00:00, 15.67it/s]


epoch:0,step:1,sub_precision:0.0124, sub_recall:0.0046, sub_f1:0.0067, triple_precision:0.0001, triple_recall:0.0112, triple_f1:0.0003,train loss:1.4684
         TP     PRED   REAL         p         r        f1
sub      54     4359  11759  0.012388  0.004592  0.006701
triple  181  1418967  16225  0.000128  0.011156  0.000252


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [08:50<00:00, 21.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [05:46<00:00, 32.26it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [10:13<00:00, 18.23it/s]


epoch:0,step:1501,sub_precision:0.8574, sub_recall:0.8600, sub_f1:0.8587, triple_precision:0.9248, triple_recall:0.0151, triple_f1:0.0298,train loss:0.0136
           TP   PRED   REAL         p         r        f1
sub     10113  11795  11759  0.857397  0.860022  0.858708
triple    246    266  16239  0.924812  0.015149  0.029809


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [09:16<00:00, 20.12it/s]


epoch:0,step:2001,sub_precision:0.8461, sub_recall:0.8768, sub_f1:0.8611, triple_precision:0.8163, triple_recall:0.1484, triple_f1:0.2511,train loss:0.0151
           TP   PRED   REAL         p         r        f1
sub     10310  12186  11759  0.846053  0.876775  0.861140
triple   2408   2950  16229  0.816271  0.148376  0.251108


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [06:42<00:00, 27.83it/s]


epoch:0,step:2501,sub_precision:0.8871, sub_recall:0.8724, sub_f1:0.8797, triple_precision:0.8246, triple_recall:0.1922, triple_f1:0.3117,train loss:0.0122
           TP   PRED   REAL         p         r        f1
sub     10258  11563  11759  0.887140  0.872353  0.879684
triple   3121   3785  16240  0.824571  0.192180  0.311710


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [10:15<00:00, 18.17it/s]


epoch:0,step:3001,sub_precision:0.8889, sub_recall:0.8724, sub_f1:0.8806, triple_precision:0.7865, triple_recall:0.2714, triple_f1:0.4035,train loss:0.0107
           TP   PRED   REAL         p         r        f1
sub     10258  11540  11759  0.888908  0.872353  0.880553
triple   4405   5601  16231  0.786467  0.271394  0.403536


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [10:18<00:00, 18.10it/s]


epoch:0,step:3501,sub_precision:0.8972, sub_recall:0.8360, sub_f1:0.8655, triple_precision:0.7640, triple_recall:0.2907, triple_f1:0.4211,train loss:0.0052
          TP   PRED   REAL         p         r        f1
sub     9830  10956  11759  0.897225  0.835955  0.865507
triple  4718   6175  16231  0.764049  0.290678  0.421137


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [08:25<00:00, 22.15it/s]


epoch:0,step:4001,sub_precision:0.9011, sub_recall:0.8604, sub_f1:0.8803, triple_precision:0.6918, triple_recall:0.4415, triple_f1:0.5390,train loss:0.0027
           TP   PRED   REAL         p         r        f1
sub     10118  11229  11759  0.901060  0.860447  0.880285
triple   7167  10360  16234  0.691795  0.441481  0.538994


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [08:44<00:00, 21.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [10:15<00:00, 18.18it/s]


epoch:0,step:5001,sub_precision:0.9064, sub_recall:0.8633, sub_f1:0.8843, triple_precision:0.6938, triple_recall:0.4801, triple_f1:0.5675,train loss:0.0035
           TP   PRED   REAL         p         r        f1
sub     10151  11199  11759  0.906420  0.863254  0.884310
triple   7796  11236  16237  0.693841  0.480138  0.567539


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [10:09<00:00, 18.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:32<00:00, 41.14it/s]


epoch:0,step:6001,sub_precision:0.9026, sub_recall:0.8693, sub_f1:0.8856, triple_precision:0.6890, triple_recall:0.5404, triple_f1:0.6057,train loss:0.0026
           TP   PRED   REAL         p         r        f1
sub     10222  11325  11759  0.902605  0.869292  0.885635
triple   8772  12732  16232  0.688973  0.540414  0.605717


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:32<00:00, 41.06it/s]
 40%|████████████████████████████████████████████▎                                                                 | 4513/11191 [01:50<02:45, 40.27it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:24<00:00, 42.27it/s]


epoch:0,step:10001,sub_precision:0.9435, sub_recall:0.8453, sub_f1:0.8917, triple_precision:0.7317, triple_recall:0.5638, triple_f1:0.6369,train loss:0.0034
          TP   PRED   REAL         p         r        f1
sub     9940  10535  11759  0.943522  0.845310  0.891720
triple  9149  12504  16228  0.731686  0.563779  0.636851


 34%|█████████████████████████████████████▏                                                                        | 3779/11191 [01:28<02:56, 41.94it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:24<00:00, 42.23it/s]


epoch:0,step:11501,sub_precision:0.9302, sub_recall:0.8672, sub_f1:0.8976, triple_precision:0.6617, triple_recall:0.6745, triple_f1:0.6680,train loss:0.0040
           TP   PRED   REAL         p         r        f1
sub     10197  10962  11759  0.930213  0.867166  0.897584
triple  10949  16547  16232  0.661691  0.674532  0.668050


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:23<00:00, 42.43it/s]


epoch:0,step:13501,sub_precision:0.9181, sub_recall:0.8792, sub_f1:0.8982, triple_precision:0.6736, triple_recall:0.7102, triple_f1:0.6914,train loss:0.1830
           TP   PRED   REAL         p         r        f1
sub     10338  11260  11759  0.918117  0.879156  0.898215
triple  11532  17119  16237  0.673637  0.710230  0.691450


 48%|█████████████████████████████████████████████████████                                                         | 5401/11191 [02:07<02:15, 42.65it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:25<00:00, 42.19it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:26<00:00, 42.02it/s]
 61%|███████████████████████████████████████████████████████████████████                                           | 6828/11191 [02:41<01:43, 42.06it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop se

epoch:0,step:33001,sub_precision:0.9259, sub_recall:0.8931, sub_f1:0.9092, triple_precision:0.6797, triple_recall:0.7727, triple_f1:0.7232,train loss:0.0021
           TP   PRED   REAL         p         r        f1
sub     10502  11343  11759  0.925857  0.893103  0.909185
triple  12539  18449  16227  0.679657  0.772724  0.723209


 29%|███████████████████████████████▍                                                                              | 3201/11191 [01:16<03:09, 42.10it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:22<00:00, 42.65it/s]


epoch:0,step:39001,sub_precision:0.9530, sub_recall:0.8727, sub_f1:0.9111, triple_precision:0.7031, triple_recall:0.7775, triple_f1:0.7384,train loss:0.1434
           TP   PRED   REAL         p         r        f1
sub     10262  10768  11759  0.953009  0.872693  0.911084
triple  12624  17955  16236  0.703091  0.777531  0.738440


 34%|█████████████████████████████████████▏                                                                        | 3779/11191 [01:28<02:52, 43.04it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:24<00:00, 42.37it/s]
 52%|████████████████████████████████████████████████████████▊                                                     | 5779/11191 [02:16<02:05, 43.24it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current valu

epoch:1,step:10001,sub_precision:0.9465, sub_recall:0.9050, sub_f1:0.9253, triple_precision:0.7152, triple_recall:0.7783, triple_f1:0.7454,train loss:0.0005
           TP   PRED   REAL         p         r        f1
sub     10642  11244  11759  0.946460  0.905009  0.925271
triple  12647  17684  16249  0.715166  0.778325  0.745410


 66%|█████████████████████████████████████████████████████████████████████████                                     | 7437/11191 [02:54<01:31, 41.12it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:26<00:00, 42.01it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:25<00:00, 42.09it/s]


epoch:1,step:18001,sub_precision:0.9338, sub_recall:0.9230, sub_f1:0.9284, triple_precision:0.7024, triple_recall:0.8089, triple_f1:0.7519,train loss:0.0016
           TP   PRED   REAL         p         r        f1
sub     10853  11622  11759  0.933832  0.922953  0.928361
triple  13135  18699  16239  0.702444  0.808855  0.751903


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:25<00:00, 42.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:25<00:00, 42.17it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:26<00:00, 41.92it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:27<00:00, 41.91it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:24<00:00, 42.28it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11191/11191 [04:22<00:00, 42.69it/s]
100%|█████████████████████████████████████████████████████████████████████████████

In [9]:
config = Config()
model = CasRel(config)
model.load_state_dict(torch.load('best_f1.pth'))
model.eval()
sentence = '周星驰导演的大话西游很好看'
inputs = config.tokenizer(sentence, return_tensors='pt')
model.predict(inputs['input_ids'], inputs['attention_mask'])

[('导演', ('大话西游', (7, 10)), ('周星驰', (1, 3)))]