In [1]:
import json
import os

from transformers import AutoConfig, TrainerCallback, Trainer

from core.configure import get_recall_config
from core.models import SiameseClassificationModel
from core.datasets.recall import get_recall_datasets
from core.evaluate import compute_metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the recall-stage configuration: project-level ``args`` plus the
# training arguments handed to the Trainer, then the backbone model config
# named by ``args.plm_name``.
(args, training_args) = get_recall_config()
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=args.plm_name)

In [3]:
# Dataset splits (train/dev/test are used below), the list of standard names
# the model ranks against, and the tokenizer.
(recall_datasets, standard_name_list, tokenizer) = get_recall_datasets()

Map (num_proc=8): 100%|██████████| 14828/14828 [00:16<00:00, 907.92 examples/s] 
Map (num_proc=8): 100%|██████████| 4868/4868 [00:00<00:00, 16977.67 examples/s]


In [4]:
# Siamese recall model; ``code`` is the list of standard names that retrieved
# indices are mapped back to (via ``model.code[i]`` further down).
model = SiameseClassificationModel(
    config=config,
    args=args,
    code=standard_name_list,
)

100%|██████████| 8/8 [00:22<00:00,  2.78s/it]


In [5]:
class ModifyDatasetCallback(TrainerCallback):
    """Rebuild the training set with model-mined hard negatives after each epoch.

    At every epoch end the model's code embeddings are refreshed, every
    training term is re-embedded, and its nearest non-gold codes become the
    negatives paired with each gold code for the next round.
    """

    def __init__(self, trainer, train_list, args) -> None:
        super().__init__()
        self._trainer = trainer
        # NOTE(review): items are indexed as ``item[0]`` (term) and ``item[1]``
        # (collection of gold codes) below; confirm this matches what callers
        # pass (the notebook passes ``recall_datasets["train"]``).
        self.train_list = train_list
        self.args = args

    def gen_training_dataset(self, train_list, neg_num=None):
        """Build an encoded dataset of (term, positive, negatives...) rows.

        Args:
            train_list: sequence of items where ``item[0]`` is the term and
                ``item[1]`` holds its gold codes.
            neg_num: maximum number of hard negatives per row. ``None`` means
                ``self.args.neg_num``.

        Returns:
            The dataset with columns ``input`` and ``labels``, mapped through
            ``preprocess_function``.
        """
        if neg_num is None:
            # Resolve from this callback's own args at call time rather than
            # from the notebook-global ``args`` at class-definition time.
            neg_num = self.args.neg_num

        model = self._trainer.model
        terms = [item[0] for item in train_list]
        term_embed = model.get_term_embedding(terms)
        index_list, _ = model.faiss_distance(term_embed, model.code_embedding)

        input_data = []
        label_list = []
        for item, idx in zip(train_list, index_list):
            term, golds = item[0], item[1]
            # Nearest codes that are not gold serve as hard negatives.
            negatives = [model.code[i] for i in idx]
            negatives = [code for code in negatives if code not in golds][:neg_num]

            for pos in golds:
                input_data.append([term, pos] + negatives)
                # NOTE(review): the positive always sits right after the term,
                # yet the label is ``len(golds) - 1``; the two only agree when
                # there is a single gold code — confirm intended semantics.
                label_list.append(len(golds) - 1)

        # Built from a dict so no pandas dependency is needed.
        # NOTE(review): ``datasets`` and ``preprocess_function`` are not defined
        # in any visible cell; they must be in scope before the first epoch ends.
        dataset = datasets.Dataset.from_dict({"input": input_data, "labels": label_list})
        return dataset.map(preprocess_function, num_proc=8)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Refresh code embeddings and replace the trainer's training dataset."""
        self._trainer.model.update_code_embedding()
        # NOTE(review): Trainer may build its train dataloader once at the start
        # of ``train()``; verify the reassigned dataset is actually consumed by
        # later epochs of the same run.
        self._trainer.train_dataset = self.gen_training_dataset(self.train_list)
        return control

class SavePretrainedCallback(TrainerCallback):
    """Save the trainer's model at the end of every epoch."""

    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        # Must actually *call* the save method; a bare attribute access saves
        # nothing. ``save_model()`` writes to the Trainer's configured output dir.
        self._trainer.save_model()
        return control

In [7]:
from transformers import Trainer

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=recall_datasets["train"],
    eval_dataset=recall_datasets["dev"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
# The dataset-regenerating callback needs the trainer itself (to reach its
# model), so it is attached after construction instead of via ``callbacks=``.
a = ModifyDatasetCallback(trainer, recall_datasets["train"], args)
trainer.add_callback(a)

Detected kernel version 4.19.91, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Train from scratch (no checkpoint resume); the TrainOutput is displayed.
trainer.train(resume_from_checkpoint=None)

In [None]:
def gen_cand_dataset(data_list, model=model, train_flag=False, top_k=10, max_candidates=20):
    """Build recall-candidate records for every mention in ``data_list``.

    Each item is read as ``item[0]`` = mention and ``item[1]`` =
    ``(origin_mention, stds)``, where ``stds`` is a ``"##"``-joined string of
    gold standard names. NOTE(review): layout inferred from the indexing
    below — confirm against ``get_recall_datasets``.

    Args:
        data_list: sequence of items laid out as above.
        model: model exposing ``get_term_embedding``, ``faiss_distance``,
            ``code_embedding`` and ``code``.
        train_flag: when True, gold names missing from the retrieved
            candidates are prepended so training records contain their gold.
        top_k: number of retrieved candidates kept per mention.
        max_candidates: cap on the candidate list length when ``train_flag``.

    Returns:
        list of dicts with keys ``text``, ``candidates``,
        ``normalized_result`` (gold names present in ``candidates``) and
        ``origin_mention``.
    """
    if len(data_list) == 0:
        return []

    term_embed = model.get_term_embedding([item[0] for item in data_list])
    index_list, _ = model.faiss_distance(term_embed, model.code_embedding)

    cand_list = []
    for item, idx in zip(data_list, index_list):
        mention = item[0]
        o_mention = item[1][0]
        stds = item[1][1]
        # Gold names come from the "##"-joined string, not from the raw pair.
        golden = [std for std in stds.split("##") if std]
        cand = [model.code[i] for i in idx[:top_k]]

        if train_flag:
            # dict.fromkeys dedupes while keeping order.
            golden_add = [g for g in dict.fromkeys(golden) if g not in cand]
            cand = (golden_add + cand)[:max_candidates]

        normalized = [std for std in golden if std in cand]
        cand_list.append({
            "text": mention,
            "candidates": cand,
            "normalized_result": normalized,
            "origin_mention": o_mention,
        })

    return cand_list

In [None]:
import json

In [None]:
# Write one JSON file of recall candidates per split; gold names are injected
# into the candidate lists only for the train split.
RECALL_RESULT_DIR = "data/3_recall_result"
os.makedirs(RECALL_RESULT_DIR, exist_ok=True)  # open() fails if the dir is missing

for todo_set in ["train", "dev", "test"]:
    candidates = gen_cand_dataset(recall_datasets[todo_set], train_flag=(todo_set == "train"))
    out_path = os.path.join(RECALL_RESULT_DIR, todo_set + ".json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(candidates, f, indent=2, ensure_ascii=False)