# Setup

## Libraries and parameters

In [1]:
import json
import argparse
import os
import sys
import random
import logging
import time
from tqdm import tqdm
import numpy as np

from shared.data_structures import Dataset
from shared.const import task_ner_labels, get_labelmap
from entity.utils import convert_dataset_to_samples, batchify, NpEncoder
from entity.models import EntityModel

from transformers import AdamW, get_linear_schedule_with_warmup
import torch

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger("root")

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
def setseed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def check_and_mkdirs(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

In [3]:
class ArgClass:
    def __init__(self):
        self.task = "scierc"
        self.data_dir = (
            "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/data/scierc_data/processed_data/json"
        )
        self.output_dir = "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/PURE/models/ent_scierc_01"

        self.use_albert = False  # albert requires different class from the transformers library
        self.model = "allenai/scibert_scivocab_uncased"
        # if args.bert_model_dir is not None, bert_model_name = str(args.bert_model_dir) + "/", elif bert_model_name = args.model
        self.bert_model_dir = None

        self.do_train = True
        self.do_eval = True
        self.eval_test = True
        self.dev_pred_filename = "ent_pred_dev.json"
        self.test_pred_filename = "ent_pred_test.json"

        self.seed = 0
        self.context_window = 300
        self.max_span_length = 8
        self.train_batch_size = 16
        self.eval_batch_size = 32
        self.learning_rate = 1e-05
        self.task_learning_rate = 5e-04
        self.warmup_proportion = 0.1
        self.num_epoch = 100
        self.print_loss_step = 100
        self.eval_per_epoch = 1
        self.bertadam = False
        self.train_shuffle = False

    def __repr__(self):
        return json.dumps(vars(self))

    def __str__(self):
        return self.__repr__()


args = ArgClass()
print(args)

args.train_data = os.path.join(args.data_dir, "train.json")
args.dev_data = os.path.join(args.data_dir, "dev.json")
args.test_data = os.path.join(args.data_dir, "test.json")

if "albert" in args.model:
    logger.info("Use Albert: %s" % args.model)
    args.use_albert = True

setseed(args.seed)

check_and_mkdirs(args.output_dir)

if args.do_train:
    logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "train.log"), "w"))
else:
    logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), "w"))

logger.info(sys.argv)
logger.info(args)

07/15/2023 04:15:07 - INFO - root - ['/home/yuxiangliao/anaconda3/envs/cxr_graph/lib/python3.10/site-packages/ipykernel_launcher.py', '--ip=127.0.0.1', '--stdin=9008', '--control=9006', '--hb=9005', '--Session.signature_scheme="hmac-sha256"', '--Session.key=b"16d9b86d-5d24-4fed-896e-3b7d0a1e8f1d"', '--shell=9007', '--transport="tcp"', '--iopub=9009', '--f=/home/yuxiangliao/.local/share/jupyter/runtime/kernel-v2-91488E1LCHlPUKEI7.json']
07/15/2023 04:15:07 - INFO - root - {"task": "scierc", "data_dir": "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/data/scierc_data/processed_data/json", "output_dir": "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/PURE/models/ent_scierc_01", "use_albert": false, "model": "allenai/scibert_scivocab_uncased", "bert_model_dir": null, "do_train": true, "do_eval": true, "eval_test": true, "dev_pred_filename": "ent_pred_dev.json", "test_pred_filename": "ent_pred_test.json", "seed": 0, "context_window": 300, "max_spa

{"task": "scierc", "data_dir": "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/data/scierc_data/processed_data/json", "output_dir": "/home/yuxiangliao/PhD/workspace/VSCode_workspace/cxr_graph/resources/PURE/models/ent_scierc_01", "use_albert": false, "model": "allenai/scibert_scivocab_uncased", "bert_model_dir": null, "do_train": true, "do_eval": true, "eval_test": true, "dev_pred_filename": "ent_pred_dev.json", "test_pred_filename": "ent_pred_test.json", "seed": 0, "context_window": 300, "max_span_length": 8, "train_batch_size": 16, "eval_batch_size": 32, "learning_rate": 1e-05, "task_learning_rate": 0.0005, "warmup_proportion": 0.1, "num_epoch": 100, "print_loss_step": 100, "eval_per_epoch": 1, "bertadam": false, "train_shuffle": false}


## Model and label

In [4]:
ner_label2id, ner_id2label = get_labelmap(task_ner_labels[args.task])
num_ner_labels = len(task_ner_labels[args.task]) + 1  # including null

model = EntityModel(args, num_ner_labels=num_ner_labels)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForEntity: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForEntity from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForEntity from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForEntity were not initialized from the model checkpoint at allenai/scibert_

# Model training

## Train and dev data

In [5]:
train_data = Dataset(args.train_data)
train_samples, train_ner = convert_dataset_to_samples(
    train_data, args.max_span_length, ner_label2id=ner_label2id, context_window=args.context_window
)
train_batches = batchify(train_samples, args.train_batch_size)

07/15/2023 04:15:11 - INFO - root - # Overlap: 0
07/15/2023 04:15:11 - INFO - root - Extracted 1861 samples from 350 documents, with 5598 NER labels, 140.335 avg input length, 300 max length
07/15/2023 04:15:11 - INFO - root - Max Length: 101, max NER: 13


In [6]:
dev_data = Dataset(args.dev_data)
dev_samples, dev_ner = convert_dataset_to_samples(
    dev_data, args.max_span_length, ner_label2id=ner_label2id, context_window=args.context_window
)
dev_batches = batchify(dev_samples, args.eval_batch_size)

07/15/2023 04:15:11 - INFO - root - # Overlap: 0
07/15/2023 04:15:11 - INFO - root - Extracted 275 samples from 50 documents, with 811 NER labels, 138.455 avg input length, 226 max length
07/15/2023 04:15:11 - INFO - root - Max Length: 68, max NER: 11


## Training

In [7]:
def evaluate(model, batches, tot_gold):
    """
    Evaluate the entity model
    """
    logger.info('Evaluating...')
    c_time = time.time()
    cor = 0
    tot_pred = 0
    l_cor = 0
    l_tot = 0

    for i in range(len(batches)):
        output_dict = model.run_batch(batches[i], training=False)
        pred_ner = output_dict['pred_ner']
        for sample, preds in zip(batches[i], pred_ner):
            for gold, pred in zip(sample['spans_label'], preds):
                l_tot += 1
                if pred == gold:
                    l_cor += 1
                if pred != 0 and gold != 0 and pred == gold:
                    cor += 1
                if pred != 0:
                    tot_pred += 1

    acc = l_cor / l_tot
    logger.info('Accuracy: %5f'%acc)
    logger.info('Cor: %d, Pred TOT: %d, Gold TOT: %d'%(cor, tot_pred, tot_gold))
    p = cor / tot_pred if cor > 0 else 0.0
    r = cor / tot_gold if cor > 0 else 0.0
    f1 = 2 * (p * r) / (p + r) if cor > 0 else 0.0
    logger.info('P: %.5f, R: %.5f, F1: %.5f'%(p, r, f1))
    logger.info('Used time: %f'%(time.time()-c_time))
    return f1

def save_model(model, args):
    """
    Save the model to the output directory
    """
    logger.info('Saving model to %s...'%(args.output_dir))
    model_to_save = model.bert_model.module if hasattr(model.bert_model, 'module') else model.bert_model
    model_to_save.save_pretrained(args.output_dir)
    model.tokenizer.save_pretrained(args.output_dir)

def output_ner_predictions(model, batches, dataset, output_file):
    """
    Save the prediction as a json file
    """
    ner_result = {}
    span_hidden_table = {}
    tot_pred_ett = 0
    for i in range(len(batches)):
        output_dict = model.run_batch(batches[i], training=False)
        pred_ner = output_dict['pred_ner']
        for sample, preds in zip(batches[i], pred_ner):
            off = sample['sent_start_in_doc'] - sample['sent_start']
            k = sample['doc_key'] + '-' + str(sample['sentence_ix'])
            ner_result[k] = []
            for span, pred in zip(sample['spans'], preds):
                span_id = '%s::%d::(%d,%d)'%(sample['doc_key'], sample['sentence_ix'], span[0]+off, span[1]+off)
                if pred == 0:
                    continue
                ner_result[k].append([span[0]+off, span[1]+off, ner_id2label[pred]])
            tot_pred_ett += len(ner_result[k])

    logger.info('Total pred entities: %d'%tot_pred_ett)

    js = dataset.js
    for i, doc in enumerate(js):
        doc["predicted_ner"] = []
        doc["predicted_relations"] = []
        for j in range(len(doc["sentences"])):
            k = doc['doc_key'] + '-' + str(j)
            if k in ner_result:
                doc["predicted_ner"].append(ner_result[k])
            else:
                logger.info('%s not in NER results!'%k)
                doc["predicted_ner"].append([])
            
            doc["predicted_relations"].append([])

        js[i] = doc

    logger.info('Output predictions to %s..'%(output_file))
    with open(output_file, 'w') as f:
        f.write('\n'.join(json.dumps(doc, cls=NpEncoder) for doc in js))

In [8]:
best_result = 0.0

# Returns an iterator over module parameters,
# yielding both the name of the parameter as well as the parameter itself.
param_optimizer = list(model.bert_model.named_parameters())
optimizer_grouped_parameters = [
    {"params": [p for n, p in param_optimizer if "bert" in n]},
    {"params": [p for n, p in param_optimizer if "bert" not in n], "lr": args.task_learning_rate},
]

# Creating the optimizer and defining hyperparameters
# 文章中提到，原优化器adam它的数学公式中是带bias-correct，
# 而在官方的bert模型中，实现的优化器bertadam是不带bias-correction的。
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=not(args.bertadam))

# Define a scheduler and associate it with an optimiser
t_total = len(train_batches) * args.num_epoch
# 模型迭代前期用较大的lr进行warmup,后期随着迭代，用较小的lr。
# Transformer在训练的初始阶段，输出层附近的期望梯度非常大，
# warmup可以避免前向FC层的不稳定的剧烈改变，所以没有warm-up的话模型优化过程就会非常不稳定
scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*args.warmup_proportion), t_total)

tr_loss = 0
tr_examples = 0
global_step = 0
eval_step = len(train_batches) // args.eval_per_epoch



然后，在每次训练迭代中，你需要执行以下步骤

1. 将输入数据和标签数据传递给模型，计算输出和损失：
   1. 前向传播 outputs = model(inputs)
   2. 计算损失 loss = criterion(outputs, labels)
2. 清零优化器的梯度缓存：optimizer.zero_grad()
3. 反向传播，计算参数的梯度：loss.backward()
4. 更新模型参数：optimizer.step()

In [9]:
if args.do_train:
    for epoch_id in range(args.num_epoch):
        logger.info('Training epoch = %d', epoch_id)
        if args.train_shuffle:
            random.shuffle(train_batches)
        for i in tqdm(range(len(train_batches))):
            # ner_loss, ner_llh = torch.Size([16, 652, 7]) batch_size,span_num,label_num
            output_dict = model.run_batch(train_batches[i], training=True)
            loss = output_dict['ner_loss']
            loss.backward()
            
            tr_loss += loss.item() # total loss for this epoch
            tr_examples += len(train_batches[i]) # total trained examples for this epoch
            global_step += 1
            
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            if global_step % args.print_loss_step == 0:
                logger.info('Epoch=%d, iter=%d, loss=%.5f'%(epoch_id, i, tr_loss / tr_examples))
                tr_loss = 0
                tr_examples = 0

            if global_step % eval_step == 0:
                f1 = evaluate(model, dev_batches, dev_ner)
                if f1 > best_result:
                    best_result = f1
                    logger.info('!!! Best valid (epoch=%d): %.2f' % (epoch_id, f1*100))
                    save_model(model, args)

07/15/2023 04:15:14 - INFO - root - Training epoch = 0
 85%|████████▍ | 99/117 [07:12<01:19,  4.40s/it]07/15/2023 04:22:29 - INFO - root - Epoch=0, iter=99, loss=301.07064
 99%|█████████▉| 116/117 [08:37<00:05,  5.08s/it]07/15/2023 04:23:52 - INFO - root - Evaluating...
07/15/2023 04:24:17 - INFO - root - Accuracy: 0.981875
07/15/2023 04:24:17 - INFO - root - Cor: 0, Pred TOT: 0, Gold TOT: 811
07/15/2023 04:24:17 - INFO - root - P: 0.00000, R: 0.00000, F1: 0.00000
07/15/2023 04:24:17 - INFO - root - Used time: 25.157210
100%|██████████| 117/117 [09:03<00:00,  4.65s/it]
07/15/2023 04:24:17 - INFO - root - Training epoch = 1
 70%|███████   | 82/117 [05:56<02:36,  4.47s/it]07/15/2023 04:30:19 - INFO - root - Epoch=1, iter=82, loss=37.46203
 99%|█████████▉| 116/117 [08:40<00:05,  5.10s/it]07/15/2023 04:32:59 - INFO - root - Evaluating...
07/15/2023 04:33:24 - INFO - root - Accuracy: 0.981875
07/15/2023 04:33:24 - INFO - root - Cor: 0, Pred TOT: 0, Gold TOT: 811
07/15/2023 04:33:24 - INFO -

# Model inference

In [None]:
if args.do_eval:
    args.bert_model_dir = args.output_dir
    model = EntityModel(args, num_ner_labels=num_ner_labels)
    if args.eval_test:
        test_data = Dataset(args.test_data)
        prediction_file = os.path.join(args.output_dir, args.test_pred_filename)
    else:
        test_data = Dataset(args.dev_data)
        prediction_file = os.path.join(args.output_dir, args.dev_pred_filename)
    test_samples, test_ner = convert_dataset_to_samples(test_data, args.max_span_length, ner_label2id=ner_label2id, context_window=args.context_window)
    test_batches = batchify(test_samples, args.eval_batch_size)
    evaluate(model, test_batches, test_ner)
    output_ner_predictions(model, test_batches, test_data, output_file=prediction_file)