# Setup

In [1]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
! pip install fastNLP
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
from torch.utils.data import Dataset, DataLoader
import json
import torch
from fastNLP import Vocabulary
from transformers import BertTokenizer, AdamW
from collections import defaultdict
from random import choice
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [5]:
'cuda' if torch.cuda.is_available() else 'cpu'

'cuda'

In [6]:
class Config:
    """
    句子最长长度是294 这里就不设参数限制长度了,每个batch 自适应长度
    """

    def __init__(self, model_name, epochs):
        self.model_name = model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.bert_path = '../input/huggingface-bert/bert-base-chinese'

        self.train_file = '/content/drive/MyDrive/Final Project/Datafile/train.json'
        self.test_file = '/content/drive/MyDrive/Final Project/Datafile/test.json'
        self.dev_file = '/content/drive/MyDrive/Final Project/Datafile/dev.json'

        self.batch_size = 16

        self.rel_dict_path = '/content/drive/MyDrive/Final Project/Datafile/rel.json'
        id2rel = json.load(open(self.rel_dict_path, encoding='utf8'))
        self.rel_vocab = Vocabulary(unknown=None, padding=None)
        self.rel_vocab.add_word_lst(list(id2rel.values()))  # relation to index
        self.num_rel = 18 

        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        #self.tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
        #self.tokenizer_wwn = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm")

        
        self.learning_rate = 1e-5
        self.bert_dim = 768
        self.epochs = epochs
config = Config("hfl/chinese-bert-wwm",5)

# Data Preprocessing

In [7]:
def collate_fn(batch):
    batch = list(zip(*batch))
    text = batch[0]
    triple = batch[1]
    del batch
    return text, triple

class MyDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        self.dataset = []
        with open(path, encoding='utf8') as F:
            for line in F:
                line = json.loads(line)
                self.dataset.append(line)

    def __getitem__(self, item):
        content = self.dataset[item]
        text = content['text']
        spo_list = content['spo_list']
        return text, spo_list

    def __len__(self):
        return len(self.dataset)

In [8]:
train_data = MyDataset(config.train_file)
test_data = MyDataset(config.test_file)
dev_data = MyDataset(config.dev_file)

In [30]:
dev_data.__len__()

11191

In [9]:
train_data.__len__()

55959

In [10]:
test_data.__len__()

13417

In [11]:
for i in train_data:
  print(i)
  break;

('笔       名：木斧原       名：杨莆曾  用  名：穆新文、牧羊、寒白、洋漾出生日期：1931—职       业：作家、诗人性    别： 男民    族： 回族政治面貌：中共党员 祖       籍：固原县出  生  地：成都', [{'predicate': '民族', 'object_type': '文本', 'subject_type': '人物', 'object': '回族', 'subject': '木斧'}, {'predicate': '出生日期', 'object_type': '日期', 'subject_type': '人物', 'object': '1931', 'subject': '木斧'}, {'predicate': '出生地', 'object_type': '地点', 'subject_type': '人物', 'object': '成都', 'subject': '木斧'}])


In [12]:
def create_data_iter(config):
    train_data = MyDataset(config.train_file)
    dev_data = MyDataset(config.dev_file)
    test_data = MyDataset(config.test_file)
    

    train_iter = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
    dev_iter = DataLoader(dev_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
    test_iter = DataLoader(test_data, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn)

    return train_iter, dev_iter, test_iter

In [13]:
train_iter = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
for i in train_iter:
    text_raw, triple = i
    break

In [14]:
text_raw

('师坦，男，1966年3月出生，汉族，河北怀来人，中共党员，研究生学历，工程师、注册规划师',
 '《汗血宝马》是由吴子牛执导的作品，聂远参加演出',
 '之前第一部吴若希的《越难越爱》也不错，每次到有情感线的地方就放这首，泪流满面',
 '苏杰，国家二级指挥，1978年毕业于北京市艺术学校、曾进修于中央音乐学院指挥系受教于黄飞立、郑晓英、秋里、李德伦',
 '杨力华，男，1962年出生，中山大学数学与计算科学学院教授、博士生导师，副院长',
 '4月，孔维领衔主演的武侠话剧《新龙门客栈》在大连文化中心开演12',
 '《婚礼策划人》是2002年播出的电视剧，由武内英树，羽住英一郎导演',
 '乔恩·海德，男，1977年在美国中西部科罗拉多州北部的柯林斯堡出生，是临床医生詹姆斯·海德六个孩子中的老四',
 '山东浩信集团有限公司于2007年09月10日在昌邑市市场监督管理局登记成立',
 '祁永膺生于1853，去世于1905，字伯福，别字荫杰，号子服',
 '范业展，男，1964年出生，毕业于北京林业大学，学士学位，高级工程师',
 '《花落随》是李维演唱的歌曲，收录在专辑《烟花问》',
 '徐贇：我毕业华东交通大学艺术学院之后一直从事店面设计、商业空间设计和品牌终端形象设计',
 '《物联网应用启示录》是2011年机械工业出版社出版的图书，作者是陈海滢',
 '《电信行业节能减排技术、方法与案例》是2010年1月人民邮电出版社出版的图书，作者是秦廷奎',
 '亚美尼亚革命联盟1890年成立于俄罗斯帝国梯弗里斯（今[格鲁吉亚]首都第比利斯）')

In [15]:
print(triple)

([{'predicate': '出生地', 'object_type': '地点', 'subject_type': '人物', 'object': '河北怀来', 'subject': '师坦'}, {'predicate': '民族', 'object_type': '文本', 'subject_type': '人物', 'object': '汉族', 'subject': '师坦'}, {'predicate': '出生日期', 'object_type': '日期', 'subject_type': '人物', 'object': '1966年3月', 'subject': '师坦'}], [{'predicate': '导演', 'object_type': '人物', 'subject_type': '影视作品', 'object': '吴子牛', 'subject': '汗血宝马'}], [{'predicate': '歌手', 'object_type': '人物', 'subject_type': '歌曲', 'object': '吴若希', 'subject': '越难越爱'}], [{'predicate': '毕业院校', 'object_type': '学校', 'subject_type': '人物', 'object': '北京市艺术学校', 'subject': '苏杰'}], [{'predicate': '出生日期', 'object_type': '日期', 'subject_type': '人物', 'object': '1962年', 'subject': '杨力华'}], [{'predicate': '主演', 'object_type': '人物', 'subject_type': '影视作品', 'object': '孔维', 'subject': '新龙门客栈'}], [{'predicate': '导演', 'object_type': '人物', 'subject_type': '影视作品', 'object': '羽住英一郎', 'subject': '婚礼策划人'}, {'predicate': '导演', 'object_type': '人物', 'subject_type': '影视作品', 'obj

In [16]:
class Batch:
    def __init__(self, config):
        self.tokenizer = config.tokenizer
        self.num_relations = config.num_rel
        self.rel_vocab = config.rel_vocab
        self.device = config.device

    def __call__(self, text, triple):
        ''' Generate inputs and labels for model training
        Parameters: 
          text: raw text sentences with given batch size generated by dataloader
          triple: triple for sentences with given batch size generated by dataloader
        return:
          inputs -> tensor 
          labels -> tensor
        '''
        text = self.tokenizer(text, padding=True).data
        batch_size = len(text['input_ids'])
        seq_len = len(text['input_ids'][0])
        sub_head = []
        sub_tail = []
        sub_heads = []
        sub_tails = []
        obj_heads = []
        obj_tails = []
        sub_len = []
        sub_head2tail = []

        for batch_index in range(batch_size):
            inner_input_ids = text['input_ids'][batch_index]  # for each sentence index in batch 
            inner_triples = triple[batch_index] 
            inner_sub_heads, inner_sub_tails, inner_sub_head, inner_sub_tail, inner_sub_head2tail, inner_sub_len, inner_obj_heads, inner_obj_tails = \
                self.create_label(inner_triples, inner_input_ids, seq_len)
            sub_head.append(inner_sub_head)
            sub_tail.append(inner_sub_tail)
            sub_len.append(inner_sub_len)
            sub_head2tail.append(inner_sub_head2tail)
            sub_heads.append(inner_sub_heads)
            sub_tails.append(inner_sub_tails)
            obj_heads.append(inner_obj_heads)
            obj_tails.append(inner_obj_tails)

        input_ids = torch.tensor(text['input_ids']).to(self.device)
        mask = torch.tensor(text['attention_mask']).to(self.device)
        sub_head = torch.stack(sub_head).to(self.device)
        sub_tail = torch.stack(sub_tail).to(self.device)
        sub_heads = torch.stack(sub_heads).to(self.device)
        sub_tails = torch.stack(sub_tails).to(self.device)
        sub_len = torch.stack(sub_len).to(self.device)
        sub_head2tail = torch.stack(sub_head2tail).to(self.device)
        obj_heads = torch.stack(obj_heads).to(self.device)
        obj_tails = torch.stack(obj_tails).to(self.device)

        return {
                   'input_ids': input_ids,
                   'mask': mask,
                   'sub_head2tail': sub_head2tail,
                   'sub_len': sub_len
               }, {
                   'sub_heads': sub_heads,
                   'sub_tails': sub_tails,
                   'obj_heads': obj_heads,
                   'obj_tails': obj_tails
               }

    def create_label(self, inner_triples, inner_input_ids, seq_len):
        '''generate labels for given text sentence
        Parameters:
          inner_triples: triple for each sentence within the batch
          inner_input_ids: word index for each sentence within the batch
          seq_len: padding size within the batch
        Return:
          inner_sub_heads, inner_sub_tails, inner_sub_head, inner_sub_tail, inner_sub_head2tail -> tensor(seq_len)
          inner_obj_heads, inner_obj_tails -> tensor(seq_len, num_relations)
        '''
        inner_sub_heads, inner_sub_tails = torch.zeros(seq_len), torch.zeros(seq_len)
        inner_sub_head, inner_sub_tail = torch.zeros(seq_len), torch.zeros(seq_len)
        inner_obj_heads = torch.zeros((seq_len, self.num_relations))
        inner_obj_tails = torch.zeros((seq_len, self.num_relations))
        inner_sub_head2tail = torch.zeros(seq_len)  # 随机抽取一个实体，从开头一个词到末尾词的索引

        # 因为数据预处理代码还待优化,会有不存在关系三元组的情况，
        # 初始化一个主词的长度为1，即没有主词默认主词长度为1，
        # 防止零除报错,初始化任何非零数字都可以，没有主词分子是全零矩阵
        inner_sub_len = torch.tensor([1], dtype=torch.float)
        # 主词到谓词的映射
        s2ro_map = defaultdict(list)
        for inner_triple in inner_triples:

            inner_triple = (
                self.tokenizer(inner_triple['subject'], add_special_tokens=False)['input_ids'],
                self.rel_vocab.to_index(inner_triple['predicate']),
                self.tokenizer(inner_triple['object'], add_special_tokens=False)['input_ids']
            )

            sub_head_idx = self.find_head_idx(inner_input_ids, inner_triple[0])
            obj_head_idx = self.find_head_idx(inner_input_ids, inner_triple[2])

            if sub_head_idx != -1 and obj_head_idx != -1:
                sub = (sub_head_idx, sub_head_idx + len(inner_triple[0]) - 1)
                # s2ro_map保存主语到谓语的映射
                s2ro_map[sub].append(
                    (obj_head_idx, obj_head_idx + len(inner_triple[2]) - 1, inner_triple[1]))  # {(3,5):[(7,8,0)]} 0 is relation

        if s2ro_map:
            for s in s2ro_map:
                inner_sub_heads[s[0]] = 1
                inner_sub_tails[s[1]] = 1

            sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
            inner_sub_head[sub_head_idx] = 1
            inner_sub_tail[sub_tail_idx] = 1
            inner_sub_head2tail[sub_head_idx:sub_tail_idx + 1] = 1
            inner_sub_len = torch.tensor([sub_tail_idx + 1 - sub_head_idx], dtype=torch.float)
            for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []):
                inner_obj_heads[ro[0]][ro[2]] = 1
                inner_obj_tails[ro[1]][ro[2]] = 1

        return inner_sub_heads, inner_sub_tails, inner_sub_head, inner_sub_tail, inner_sub_head2tail, inner_sub_len, inner_obj_heads, inner_obj_tails

    @staticmethod
    def find_head_idx(source, target):
        target_len = len(target)
        for i in range(len(source)):
            if source[i: i + target_len] == target:
                return i
        return -1

# Modeling - CasRel



In [17]:
import torch.nn as nn
import torch
from transformers import BertModel

In [18]:
class CasRel(nn.Module):
    def __init__(self, config):
        super(CasRel, self).__init__()
        self.config = config
        self.model_name = self.config.model_name
        #self.bert = BertModel.from_pretrained("bert-base-chinese")
        self.bert = BertModel.from_pretrained(self.model_name)
        self.sub_heads_linear = nn.Linear(self.config.bert_dim, 1)
        self.sub_tails_linear = nn.Linear(self.config.bert_dim, 1)
        self.obj_heads_linear = nn.Linear(self.config.bert_dim, self.config.num_rel)
        self.obj_tails_linear = nn.Linear(self.config.bert_dim, self.config.num_rel)
        self.alpha = 0.25
        self.gamma = 2

    def get_encoded_text(self, token_ids, mask):        
        encoded_text = self.bert(token_ids, attention_mask=mask)[0]
        return encoded_text

    def get_subs(self, encoded_text):
        pred_sub_heads = torch.sigmoid(self.sub_heads_linear(encoded_text))
        pred_sub_tails = torch.sigmoid(self.sub_tails_linear(encoded_text))
        return pred_sub_heads, pred_sub_tails

    def get_objs_for_specific_sub(self, sub_head2tail, sub_len, encoded_text):
        # sub_head_mapping [batch, 1, seq] * encoded_text [batch, seq, dim]
        sub = torch.matmul(sub_head2tail, encoded_text)  # batch size,1,dim
        sub_len = sub_len.unsqueeze(1)
        sub = sub / sub_len  # batch size, 1,dim
        encoded_text = encoded_text + sub
        #  [batch size, seq len,bert_dim] -->[batch size, seq len,relathion counts]
        pred_obj_heads = torch.sigmoid(self.obj_heads_linear(encoded_text))
        pred_obj_tails = torch.sigmoid(self.obj_tails_linear(encoded_text))
        return pred_obj_heads, pred_obj_tails

    def forward(self, input_ids, mask, sub_head2tail, sub_len):
        """

        :param token_ids:[batch size, seq len]
        :param mask:[batch size, seq len]
        :param sub_head:[batch size, seq len]
        :param sub_tail:[batch size, seq len]
        :return:
        """
         #mat1 and mat2 shapes cannot be multiplied (324x21128 and 768x1)
        encoded_text = self.get_encoded_text(input_ids, mask)
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        sub_head2tail = sub_head2tail.unsqueeze(1)  # [[batch size,1, seq len]]
        pred_obj_heads, pre_obj_tails = self.get_objs_for_specific_sub(sub_head2tail, sub_len, encoded_text)

        return {
            "pred_sub_heads": pred_sub_heads,
            "pred_sub_tails": pred_sub_tails,
            "pred_obj_heads": pred_obj_heads,
            "pred_obj_tails": pre_obj_tails,
            'mask': mask
        }

    def compute_loss(self, pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails, mask, sub_heads,
                     sub_tails, obj_heads, obj_tails):
        rel_count = obj_heads.shape[-1]
        rel_mask = mask.unsqueeze(-1).repeat(1, 1, rel_count)
        loss_1 = self.loss_fun(pred_sub_heads, sub_heads, mask)
        loss_2 = self.loss_fun(pred_sub_tails, sub_tails, mask)
        loss_3 = self.loss_fun(pred_obj_heads, obj_heads, rel_mask)
        loss_4 = self.loss_fun(pred_obj_tails, obj_tails, rel_mask)
        return loss_1 + loss_2 + loss_3 + loss_4

    def loss_fun(self, logist, label, mask):
        count = torch.sum(mask)
        logist = logist.view(-1)
        label = label.view(-1)
        mask = mask.view(-1)
        
        alpha_factor = torch.where(torch.eq(label,1), 1- self.alpha,self.alpha)
        focal_weight = torch.where(torch.eq(label,1),1-logist,logist)
        
        loss = -(torch.log(logist) * label + torch.log(1 - logist) * (1 - label)) * mask
        return torch.sum(focal_weight * loss) / count

In [19]:

def load_model(config):
    device = config.device
    model = CasRel(config)
    model.to(device)

    # prepare optimzier
    param_optimizer = list(model.named_parameters())

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]

    optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate, eps=10e-8)
    sheduler = None

    return model, optimizer, sheduler, device

# Model Training

In [20]:
import pandas as pd
from tqdm import tqdm

In [27]:
def train_epoch(model, train_iter, dev_iter, optimizer, batch, best_triple_f1, epoch, save_path):
    for step, (text, triple) in enumerate(train_iter):
        model.train()
        inputs, labels = batch(text, triple)
        logist = model(**inputs)
        loss = model.compute_loss(**logist, **labels)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 500 == 1:
            sub_precision, sub_recall, sub_f1, triple_precision, triple_recall, triple_f1, df = test(model, dev_iter,
                                                                                                     batch)
            if triple_f1 > best_triple_f1:
                best_triple_f1 = triple_f1
                torch.save(model.state_dict(), save_path)
                print(
                    'epoch:{},step:{},sub_precision:{:.4f}, sub_recall:{:.4f}, sub_f1:{:.4f}, triple_precision:{:.4f}, triple_recall:{:.4f}, triple_f1:{:.4f},train loss:{:.4f}'.format(
                        epoch, step, sub_precision, sub_recall, sub_f1, triple_precision, triple_recall, triple_f1,
                        loss.item()))
                print(df)

    return best_triple_f1


def train(model, train_iter, dev_iter, optimizer, batch, config, save_path):
    epochs = config.epochs
    best_triple_f1 = 0
    for epoch in range(epochs):
        best_triple_f1 = train_epoch(model, train_iter, dev_iter, optimizer, batch, best_triple_f1, epoch, save_path)


def test(model, dev_iter, batch):
    model.eval()
    df = pd.DataFrame(columns=['TP', 'PRED', "REAL", 'p', 'r', 'f1'], index=['sub', 'triple'])
    df.fillna(0, inplace=True)

    for text, triple in tqdm(dev_iter):
        inputs, labels = batch(text, triple)
        # print("text: ", text)
        # print("triple: ", triple)
        # print("labels: ", labels)
        logist = model(**inputs)
        
        pred_sub_heads = convert_score_to_zero_one(logist['pred_sub_heads'])
        pred_sub_tails = convert_score_to_zero_one(logist['pred_sub_tails'])

        sub_heads = convert_score_to_zero_one(labels['sub_heads'])
        #print("labels['sub_heads']: ", labels['sub_heads'])
        sub_tails = convert_score_to_zero_one(labels['sub_tails'])
        batch_size = inputs['input_ids'].shape[0]

        obj_heads = convert_score_to_zero_one(labels['obj_heads'])
        obj_tails = convert_score_to_zero_one(labels['obj_tails'])
        pred_obj_heads = convert_score_to_zero_one(logist['pred_obj_heads'])
        pred_obj_tails = convert_score_to_zero_one(logist['pred_obj_tails'])

        for batch_index in range(batch_size):
            pred_subs = extract_sub(pred_sub_heads[batch_index].squeeze(), pred_sub_tails[batch_index].squeeze())
            #print("pred_subs: ", pred_subs)
            true_subs = extract_sub(sub_heads[batch_index].squeeze(), sub_tails[batch_index].squeeze())
            #print("true_subs: ", true_subs)
            pred_ojbs = extract_obj_and_rel(pred_obj_heads[batch_index], pred_obj_tails[batch_index])
            #print("pred_ojbs: ", pred_ojbs)
            true_objs = extract_obj_and_rel(obj_heads[batch_index], obj_tails[batch_index])
            #print("true_objs: ", true_objs)

            df['PRED']['sub'] += len(pred_subs)
            df['REAL']['sub'] += len(true_subs)
            for true_sub in true_subs:
                if true_sub in pred_subs:
                    df['TP']['sub'] += 1

            df['PRED']['triple'] += len(pred_ojbs)
            df['REAL']['triple'] += len(true_objs)
            for true_obj in true_objs:
                if true_obj in pred_ojbs:
                    df['TP']['triple'] += 1

    df.loc['sub','p'] = df['TP']['sub'] / (df['PRED']['sub'] + 1e-9)
    df.loc['sub','r'] = df['TP']['sub'] / (df['REAL']['sub'] + 1e-9)
    df.loc['sub','f1'] = 2 * df['p']['sub'] * df['r']['sub'] / (df['p']['sub'] + df['r']['sub'] + 1e-9)
    
    sub_precision = df['TP']['sub'] / (df['PRED']['sub'] + 1e-9)
    sub_recall = df['TP']['sub'] / (df['REAL']['sub'] + 1e-9)
    sub_f1 = 2 * sub_precision * sub_recall  / (sub_precision + sub_recall  + 1e-9)

    df.loc['triple','p'] = df['TP']['triple'] / (df['PRED']['triple'] + 1e-9)
    df.loc['triple','r'] = df['TP']['triple'] / (df['REAL']['triple'] + 1e-9)
    df.loc['triple','f1'] = 2 * df['p']['triple'] * df['r']['triple'] / (
            df['p']['triple'] + df['r']['triple'] + 1e-9)
    
    
    triple_precision = df['TP']['triple'] / (df['PRED']['triple'] + 1e-9)
    triple_recall = df['TP']['triple'] / (df['REAL']['triple'] + 1e-9)
    triple_f1 = 2 * triple_precision * triple_recall / (
            triple_precision + triple_recall + 1e-9)

    return sub_precision, sub_recall,sub_f1, triple_precision, triple_recall, triple_f1, df


def extract_sub(pred_sub_heads, pred_sub_tails):
    subs = []
    heads = torch.arange(0, len(pred_sub_heads))[pred_sub_heads == 1].to(torch.device('cuda'))
    tails = torch.arange(0, len(pred_sub_tails))[pred_sub_tails == 1].to(torch.device('cuda'))

    for head, tail in zip(heads, tails):
        if tail >= head:
            subs.append((head.item(), tail.item()))
    return subs


def extract_obj_and_rel(obj_heads, obj_tails):
    obj_heads = obj_heads.T
    obj_tails = obj_tails.T
    rel_count = obj_heads.shape[0]
    obj_and_rels = []  # [(rel_index,strart_index,end_index),(rel_index,strart_index,end_index)]

    for rel_index in range(rel_count):
        obj_head = obj_heads[rel_index]
        obj_tail = obj_tails[rel_index]

        objs = extract_sub(obj_head, obj_tail)
        if objs:
            for obj in objs:
                start_index, end_index = obj
                obj_and_rels.append((rel_index, start_index, end_index))
    return obj_and_rels


def convert_score_to_zero_one(tensor):
    tensor[tensor>=0.5] = 1
    tensor[tensor<0.5] = 0
    return tensor

## Bert CasRel

In [None]:
config = Config("bert-base-chinese",8)
train_data = MyDataset(config.train_file)
bert_base_model, optimizer, sheduler, device = load_model(config)
train_iter, dev_iter, test_iter = create_data_iter(config)
batch = Batch(config)

Downloading:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/624 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
train(bert_base_model, train_iter, dev_iter, optimizer, batch, config, 'base_best_f1.pth')
torch.save(bert_base_model, 'base_model')

100%|██████████| 700/700 [02:28<00:00,  4.71it/s]


epoch:0,step:1,sub_precision:0.0005, sub_recall:0.0066, sub_f1:0.0010, triple_precision:0.0000, triple_recall:0.0047, triple_f1:0.0001,train loss:1.3757
        TP     PRED   REAL         p         r        f1
sub     78   142565  11759  0.000547  0.006633  0.001011
triple  77  1929978  16230  0.000040  0.004744  0.000079


100%|██████████| 700/700 [01:52<00:00,  6.20it/s]
100%|██████████| 700/700 [01:53<00:00,  6.15it/s]


epoch:0,step:1001,sub_precision:0.8917, sub_recall:0.8814, sub_f1:0.8865, triple_precision:0.8314, triple_recall:0.1479, triple_f1:0.2511,train loss:0.0427
           TP   PRED   REAL         p         r        f1
sub     10364  11623  11759  0.891680  0.881367  0.886494
triple   2401   2888  16238  0.831371  0.147863  0.251072


100%|██████████| 700/700 [01:55<00:00,  6.07it/s]


epoch:0,step:1501,sub_precision:0.8723, sub_recall:0.8911, sub_f1:0.8816, triple_precision:0.7107, triple_recall:0.4174, triple_f1:0.5259,train loss:0.0170
           TP   PRED   REAL         p         r        f1
sub     10478  12012  11759  0.872294  0.891062  0.881578
triple   6776   9534  16233  0.710720  0.417421  0.525944


100%|██████████| 700/700 [01:56<00:00,  6.02it/s]


epoch:0,step:2001,sub_precision:0.8945, sub_recall:0.8803, sub_f1:0.8873, triple_precision:0.6799, triple_recall:0.5510, triple_f1:0.6087,train loss:0.0152
           TP   PRED   REAL         p         r        f1
sub     10351  11572  11759  0.894487  0.880262  0.887317
triple   8947  13159  16237  0.679915  0.551025  0.608722


100%|██████████| 700/700 [01:55<00:00,  6.04it/s]


epoch:0,step:2501,sub_precision:0.9247, sub_recall:0.8787, sub_f1:0.9011, triple_precision:0.7267, triple_recall:0.5384, triple_f1:0.6185,train loss:0.0064
           TP   PRED   REAL         p         r        f1
sub     10333  11175  11759  0.924653  0.878731  0.901108
triple   8736  12022  16227  0.726668  0.538362  0.618500


100%|██████████| 700/700 [01:56<00:00,  6.00it/s]


epoch:0,step:3001,sub_precision:0.8309, sub_recall:0.9025, sub_f1:0.8652, triple_precision:0.6608, triple_recall:0.6864, triple_f1:0.6733,train loss:0.0083
           TP   PRED   REAL         p         r        f1
sub     10612  12771  11759  0.830945  0.902458  0.865226
triple  11141  16860  16232  0.660795  0.686360  0.673335


100%|██████████| 700/700 [01:56<00:00,  6.02it/s]


epoch:1,step:1,sub_precision:0.8966, sub_recall:0.8794, sub_f1:0.8879, triple_precision:0.7072, triple_recall:0.6445, triple_f1:0.6744,train loss:0.0115
           TP   PRED   REAL         p         r        f1
sub     10341  11534  11759  0.896567  0.879412  0.887906
triple  10466  14799  16240  0.707210  0.644458  0.674377


100%|██████████| 700/700 [01:56<00:00,  5.99it/s]


epoch:1,step:501,sub_precision:0.9109, sub_recall:0.9020, sub_f1:0.9065, triple_precision:0.6687, triple_recall:0.7143, triple_f1:0.6907,train loss:0.0157
           TP   PRED   REAL         p         r        f1
sub     10607  11644  11759  0.910941  0.902032  0.906465
triple  11593  17337  16230  0.668685  0.714295  0.690738


100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]


epoch:1,step:1501,sub_precision:0.9126, sub_recall:0.9099, sub_f1:0.9113, triple_precision:0.6988, triple_recall:0.6935, triple_f1:0.6961,train loss:0.0083
           TP   PRED   REAL         p         r        f1
sub     10700  11725  11759  0.912580  0.909941  0.911259
triple  11257  16110  16231  0.698759  0.693549  0.696144


100%|██████████| 700/700 [01:56<00:00,  6.00it/s]
100%|██████████| 700/700 [01:56<00:00,  6.00it/s]


epoch:1,step:2501,sub_precision:0.8932, sub_recall:0.9192, sub_f1:0.9060, triple_precision:0.6949, triple_recall:0.7361, triple_f1:0.7149,train loss:0.0095
           TP   PRED   REAL         p         r        f1
sub     10809  12101  11759  0.893232  0.919211  0.906035
triple  11946  17192  16229  0.694858  0.736090  0.714880


100%|██████████| 700/700 [01:56<00:00,  5.99it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:56<00:00,  5.99it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:56<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]


epoch:3,step:1001,sub_precision:0.9316, sub_recall:0.9412, sub_f1:0.9364, triple_precision:0.6459, triple_recall:0.8059, triple_f1:0.7171,train loss:0.0047
           TP   PRED   REAL         p         r        f1
sub     11068  11880  11759  0.931650  0.941236  0.936419
triple  13072  20237  16220  0.645946  0.805919  0.717119


100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:56<00:00,  5.99it/s]


epoch:3,step:2001,sub_precision:0.8988, sub_recall:0.9445, sub_f1:0.9211, triple_precision:0.6533, triple_recall:0.8043, triple_f1:0.7210,train loss:0.0091
           TP   PRED   REAL         p         r        f1
sub     11106  12356  11759  0.898835  0.944468  0.921086
triple  13051  19978  16227  0.653269  0.804277  0.720950


100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:56<00:00,  6.00it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]


epoch:4,step:501,sub_precision:0.9287, sub_recall:0.9520, sub_f1:0.9402, triple_precision:0.6584, triple_recall:0.8011, triple_f1:0.7228,train loss:0.0087
           TP   PRED   REAL         p         r        f1
sub     11195  12054  11759  0.928737  0.952037  0.940243
triple  13004  19750  16232  0.658430  0.801134  0.722806


100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]


epoch:4,step:1501,sub_precision:0.9054, sub_recall:0.9656, sub_f1:0.9346, triple_precision:0.6696, triple_recall:0.8101, triple_f1:0.7331,train loss:0.0033
           TP   PRED   REAL         p         r        f1
sub     11355  12541  11759  0.905430  0.965643  0.934568
triple  13144  19631  16226  0.669553  0.810058  0.733134


100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:56<00:00,  5.99it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:58<00:00,  5.93it/s]
100%|██████████| 700/700 [01:57<00:00,  5.95it/s]
100%|██████████| 700/700 [01:57<00:00,  5.97it/s]
100%|██████████| 700/700 [01:57<00:00,  5.98it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]
100%|██████████| 700/700 [01:57<00:00,  5.96it/s]


In [None]:
bert_base_model, optimizer, sheduler, device = load_model(config)
bert_base_model.load_state_dict(torch.load('/content/drive/MyDrive/Final Project/base_best_f1.pth'))

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

### Test Result

In [None]:
sub_precision, sub_recall,sub_f1, triple_precision, triple_recall, triple_f1, df = test(bert_base_model, dev_iter, batch)

100%|██████████| 700/700 [02:40<00:00,  4.37it/s]


In [None]:
df

Unnamed: 0,TP,PRED,REAL,p,r,f1
sub,11174,11709,11759,0.954309,0.950251,0.952275
triple,13108,19446,16247,0.674072,0.806795,0.734486


## Bert Whole Word Mask CasRel

In [28]:
config = Config("hfl/chinese-bert-wwm",8)
train_data = MyDataset(config.train_file)
bert_wwn_model, optimizer, sheduler, device = load_model(config)
train_iter, dev_iter, test_iter = create_data_iter(config)
batch = Batch(config)
train(bert_wwn_model, train_iter, dev_iter, optimizer, batch, config, 'wwm_best_f1.pth')
torch.save(bert_wwn_model, 'wwm_model')

Some weights of the model checkpoint at hfl/chinese-bert-wwm were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Test Result

In [None]:
sub_precision, sub_recall,sub_f1, triple_precision, triple_recall, triple_f1, df = test(bert_wwn_model, dev_iter, batch)

100%|██████████| 700/700 [02:01<00:00,  5.78it/s]


In [None]:
df

Unnamed: 0,TP,PRED,REAL,p,r,f1
sub,11545,12097,11759,0.954369,0.981801,0.967891
triple,12922,18468,16224,0.699697,0.796474,0.744956
