# 中文三元组联合抽取

## 介绍

在这个 notebook 中，我们将使用 OpenUE 库的代码来训练自己的三元组联合抽取模型，使用的基础模型是`bert-base-chinese`。训练分为两步：首先训练关系分类模型，然后训练实体抽取模型，最后对两者进行联合验证。

## 数据集

本教程使用 SKE 数据集，具体例子如下。我们用代码读取`train.json`中的第一条数据，来分析一下数据格式。

In [None]:
import json

# Peek at the first training example to see the record format.
# The file is parsed one JSON object per line, so only the first line is read
# (the previous version loaded the whole file with readlines() just to use it).
# encoding="utf-8" is explicit because the records contain Chinese text and the
# platform default encoding (e.g. GBK on Chinese Windows) may differ.
with open("../dataset/ske/train.json", "r", encoding="utf-8") as file:
    line = file.readline()
if not line:
    # Fail with a clear message instead of a NameError on `example` below.
    raise ValueError("../dataset/ske/train.json is empty")
example = json.loads(line)
for k, v in example.items():
    print(f"{k}: {v}")

# 训练

## `seq model`关系分类模型

如下方的模型结构图所示，我们需要先训练一个关系分类模型，识别出句子中实体之间存在的关系（属性）类别。

<div  align="center">
    <img src="./imgs/architecture.png" width="600" height="400" alt="OpenUE 模型架构图" align="center" />
</div>


In [None]:
import argparse
import importlib

import numpy as np
import torch
import pytorch_lightning as pl
import openue.lit_models as lit_models
import yaml
import time
from transformers import AutoConfig
import os
# Set before any tokenizer is created. Presumably this silences the Hugging Face
# `tokenizers` parallelism/fork warning when DataLoader workers fork — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# 设置一些参数和动态调用包
def _import_class(module_and_class_name: str) -> type:
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    class_ = getattr(module, class_name)
	
    return class_


def _setup_parser():
    """Build the ArgumentParser shared by the training cells.

    Registers the basic run options first. It then peeks at the command line
    (``parse_known_args()``, which ignores unknown arguments) to learn which
    ``openue.data.<data_class>`` and ``openue.models.<model_class>`` to import.
    Each of those classes, and ``lit_models.BaseLitModel``, contributes its own
    argument group via ``add_to_argparse``. The parser is created with
    ``add_help=False``, so ``--help``/``-h`` is added explicitly at the end,
    after every group exists.
    """
    parser = argparse.ArgumentParser(add_help=False)

    # Trainer-specific arguments (--max_epochs, --gpus, ...) used to be merged in
    # via pl.Trainer.add_argparse_args; that step is intentionally disabled here.

    # Basic arguments.
    parser.add_argument("--wandb", action="store_true", default=False)
    basic_options = (
        ("--seed", int, 42),
        ("--litmodel_class", str, "SEQLitModel"),
        ("--data_class", str, "REDataset"),
        ("--model_class", str, "BertForRelationClassification"),
        ("--load_checkpoint", str, None),
    )
    for flag, value_type, default in basic_options:
        parser.add_argument(flag, type=value_type, default=default)

    # Discover the configured data/model classes so they can add their own options.
    peeked, _unknown = parser.parse_known_args()
    argument_owners = (
        ("Data Args", _import_class(f"openue.data.{peeked.data_class}")),
        ("Model Args", _import_class(f"openue.models.{peeked.model_class}")),
        ("LitModel Args", lit_models.BaseLitModel),
    )
    for group_title, owner in argument_owners:
        owner.add_to_argparse(parser.add_argument_group(group_title))

    parser.add_argument("--help", "-h", action="help")
    return parser

def _save_model(litmodel, tokenizer, path):
    os.system(f"mkdir -p {path}")
    litmodel.model.save_pretrained(path)
    tokenizer.save_pretrained(path)

In [None]:
parser = _setup_parser()
args = parser.parse_args(args=[])  # start from parser defaults, ignoring sys.argv

path = "./config/run_seq.yaml"
# Load hyper-parameters from the YAML config and overlay them on the defaults.
# yaml.safe_load: PyYAML >= 6 makes yaml.load's Loader argument mandatory, and
# safe_load never constructs arbitrary Python objects. The `with` block closes
# the config file instead of leaking the handle.
with open(path, encoding="utf-8") as config_file:
    opt = yaml.safe_load(config_file)
args.__dict__.update(opt or {})  # an empty YAML file loads as None

np.random.seed(args.seed)
torch.manual_seed(args.seed)
data_class = _import_class(f"openue.data.{args.data_class}")
model_class = _import_class(f"openue.models.{args.model_class}")
litmodel_class = _import_class(f"openue.lit_models.{args.litmodel_class}")

data = data_class(args)

lit_model = litmodel_class(args=args, data_config=data.get_config())

# TensorBoard by default; Weights & Biases when args.wandb is set.
logger = pl.loggers.TensorBoardLogger("training/logs")
if args.wandb:
    logger = pl.loggers.WandbLogger(project="openue demo")
    logger.log_hyperparams(vars(args))

# Early stopping (patience=5) and checkpointing both track the maximum "Eval/f1".
# NOTE(review): assumes the LitModel logs a metric named "Eval/f1" — confirm.
early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=5)
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max",
    filename='{epoch}-{Eval/f1:.2f}',
    dirpath="output",
    save_weights_only=True
)

callbacks = [early_callback, model_checkpoint]

# NOTE(review): Trainer.from_argparse_args exists only in PyTorch Lightning 1.x.
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")

trainer.fit(lit_model, datamodule=data)

trainer.test(lit_model, datamodule=data)

# Relation-classification model + tokenizer, for the later joint step.
_save_model(litmodel=lit_model, tokenizer=data.tokenizer, path="seq_model")

In [None]:
parser = _setup_parser()
args = parser.parse_args(args=[])  # start from parser defaults, ignoring sys.argv

path = "./config/run_ner.yaml"
# Load hyper-parameters from the YAML config and overlay them on the defaults.
# yaml.safe_load: PyYAML >= 6 makes yaml.load's Loader argument mandatory, and
# safe_load never constructs arbitrary Python objects. The `with` block closes
# the config file instead of leaking the handle.
with open(path, encoding="utf-8") as config_file:
    opt = yaml.safe_load(config_file)
args.__dict__.update(opt or {})  # an empty YAML file loads as None

np.random.seed(args.seed)
torch.manual_seed(args.seed)
data_class = _import_class(f"openue.data.{args.data_class}")
model_class = _import_class(f"openue.models.{args.model_class}")
litmodel_class = _import_class(f"openue.lit_models.{args.litmodel_class}")

data = data_class(args)

lit_model = litmodel_class(args=args, data_config=data.get_config())

# TensorBoard by default; Weights & Biases when args.wandb is set.
logger = pl.loggers.TensorBoardLogger("training/logs")
if args.wandb:
    logger = pl.loggers.WandbLogger(project="openue demo")
    logger.log_hyperparams(vars(args))

# Early stopping (patience=5) and checkpointing both track the maximum "Eval/f1".
# NOTE(review): assumes the LitModel logs a metric named "Eval/f1" — confirm.
early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=5)
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max",
    filename='{epoch}-{Eval/f1:.2f}',
    dirpath="output",
    save_weights_only=True
)

callbacks = [early_callback, model_checkpoint]

# NOTE(review): Trainer.from_argparse_args exists only in PyTorch Lightning 1.x.
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")

trainer.fit(lit_model, datamodule=data)

trainer.test(lit_model, datamodule=data)
# Entity-extraction model + tokenizer, for the later joint step.
_save_model(litmodel=lit_model, tokenizer=data.tokenizer, path="ner_model")

In [None]:
# Debug check: display the type of the `model_name_or_path` setting in `args`.
# When cells run in order, `args` here is the one loaded from run_ner.yaml.
# The notebook echoes the value of a cell's last expression.
type(args.model_name_or_path)

In [None]:
parser = _setup_parser()
args = parser.parse_args(args=[])  # start from parser defaults, ignoring sys.argv

path = "./config/run_inter.yaml"
# Load hyper-parameters from the YAML config and overlay them on the defaults.
# yaml.safe_load: PyYAML >= 6 makes yaml.load's Loader argument mandatory, and
# safe_load never constructs arbitrary Python objects. The `with` block closes
# the config file instead of leaking the handle.
with open(path, encoding="utf-8") as config_file:
    opt = yaml.safe_load(config_file)
args.__dict__.update(opt or {})  # an empty YAML file loads as None

np.random.seed(args.seed)
torch.manual_seed(args.seed)
data_class = _import_class(f"openue.data.{args.data_class}")
model_class = _import_class(f"openue.models.{args.model_class}")
litmodel_class = _import_class(f"openue.lit_models.{args.litmodel_class}")

data = data_class(args)

lit_model = litmodel_class(args=args, data_config=data.get_config())

# TensorBoard by default; Weights & Biases when args.wandb is set.
logger = pl.loggers.TensorBoardLogger("training/logs")
if args.wandb:
    logger = pl.loggers.WandbLogger(project="openue demo")
    logger.log_hyperparams(vars(args))

# Early stopping (patience=5) and checkpointing both track the maximum "Eval/f1".
# NOTE(review): assumes the LitModel logs a metric named "Eval/f1" — confirm.
early_callback = pl.callbacks.EarlyStopping(monitor="Eval/f1", mode="max", patience=5)
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor="Eval/f1", mode="max",
    filename='{epoch}-{Eval/f1:.2f}',
    dirpath="output",
    save_weights_only=True
)

callbacks = [early_callback, model_checkpoint]

# NOTE(review): Trainer.from_argparse_args exists only in PyTorch Lightning 1.x.
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")

trainer.fit(lit_model, datamodule=data)

trainer.test(lit_model, datamodule=data)
# Save to a dedicated directory. This cell previously wrote to "ner_model",
# overwriting the entity-extraction model trained in the previous cell.
_save_model(litmodel=lit_model, tokenizer=data.tokenizer, path="inter_model")