In [None]:
import json
import re

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from core.configure import get_rerank_config
from core.datasets.selection import get_typecls_datasets
from core.evaluate import evalacc,evalf1,evalNumAcc
from core.models import TypeCLSModel

In [None]:
# Experiment configuration. `args` carries the settings read later in this
# notebook (batch_size, epoch_nums, learning_rate, threshold, model_path,
# output_path, version, plm_name); `logger` is used for metric/epoch logging.
args,logger = get_rerank_config()
# Dataset splits; indexed below with the keys "train", "dev" and "test".
data_list=get_typecls_datasets()

In [None]:
def save_model(save_path, epoch, best_accuracy, optimizer, model):
    """Serialize a training checkpoint to ``save_path`` with ``torch.save``.

    The checkpoint is a dict with the keys 'epoch' (stored as ``epoch + 1``),
    'best_accuracy', 'optimizer_dict' and 'model_dict' (the optimizer's and
    model's ``state_dict()`` respectively).
    """
    checkpoint = dict(
        epoch=epoch + 1,
        best_accuracy=best_accuracy,
        optimizer_dict=optimizer.state_dict(),
        model_dict=model.state_dict(),
    )
    torch.save(checkpoint, save_path)


def cleanup_state_dict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from each key.

    ``"module."`` is the prefix parameter names get when a model is wrapped in
    ``torch.nn.DataParallel``; stripping it lets such a checkpoint load into an
    unwrapped model. Only a *leading* ``"module."`` is removed: keys that merely
    contain the word "module" elsewhere (e.g. ``"encoder.submodule.w"``) are kept
    unchanged instead of having their first 7 characters chopped off.

    Args:
        state_dict: mapping of parameter/state names to values.

    Returns:
        A new ``dict`` with the same values (and order) and cleaned keys.
    """
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def load_model(save_name, model,optimizer):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    Both stored state dicts are passed through ``cleanup_state_dict`` before
    being loaded. The stored 'best_accuracy' value and a success message are
    printed.

    Returns:
        ``(model, optimizer, best_accuracy)``: the same model/optimizer objects,
        now holding the checkpoint state, plus the saved metric value.
    """
    checkpoint = torch.load(save_name)
    model_state = cleanup_state_dict(checkpoint['model_dict'])
    model.load_state_dict(model_state)
    optimizer_state = cleanup_state_dict(checkpoint['optimizer_dict'])
    optimizer.load_state_dict(optimizer_state)
    stored_metric = checkpoint['best_accuracy']
    print(stored_metric)
    print("model load success")
    return model, optimizer, stored_metric

In [None]:
def train_n_evaluate(model,loss_fn, optimizer, evaluate, tokenizer, best_f1=0,best_acc = 0):
    """Train ``model`` for ``args.epoch_nums`` epochs, evaluating after each epoch.

    After every epoch the weights are saved to ``<model_path>last_<version>.pth``
    (recording the best accuracy seen *before* this epoch's evaluation). Whenever
    the dev accuracy / F1 returned by ``evaluate`` strictly improves, the weights
    are also saved as ``best_acc_<version>.pth`` / ``best_f1_<version>.pth``; the
    F1 checkpoint stores the F1 value in its 'best_accuracy' field.

    Args:
        model: network called as ``model(X1, X2, X3, X4)`` with dicts of CUDA tensors.
        loss_fn: loss applied to ``(prediction, label.float())``.
        optimizer: optimizer over ``model``'s parameters.
        evaluate: callable ``evaluate(model, tokenizer) -> (acc, f1)``.
        tokenizer: forwarded unchanged to ``evaluate``.
        best_f1: starting best F1; a new best-F1 checkpoint is written only when exceeded.
        best_acc: starting best accuracy; same rule for the best-accuracy checkpoint.
    """
    # NOTE(review): the training loader is not shuffled (no shuffle=True);
    # confirm the "train" split is not ordered, e.g. grouped by mention.
    train_dataloader = DataLoader(data_list["train"], batch_size=args.batch_size)
    for epoch in range(args.epoch_nums):
        model.train()
        logger.info('#'*20 + 'Epoch{}'.format(epoch) + '#'*20)
        loop = tqdm(train_dataloader, desc=f'Training Epoch {epoch}')
        epoch_loss = []
        for X1, X2, X3, X4, y in loop:
            # X1..X4 are mappings of tensors; move every entry to the GPU.
            inputs = [{k: v.cuda() for k, v in X.items()} for X in (X1, X2, X3, X4)]
            y_cuda = y.cuda()
            # reshape (not squeeze) so a final batch of size 1 keeps the label's
            # shape; squeeze() would give a 0-d tensor and BCELoss rejects the
            # input/target size mismatch.
            y_pred = model(*inputs).reshape(y_cuda.shape)
            loss = loss_fn(y_pred, y_cuda.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Keep a Python float rather than the tensor so no autograd node /
            # GPU storage is retained for the whole epoch.
            loss_value = loss.item()
            epoch_loss.append(loss_value)
            loop.set_postfix(loss=f'Train loss: {loss_value:.6f}')
        if epoch_loss:
            logger.info('Epoch {} mean train loss: {:.6f}'.format(
                epoch, sum(epoch_loss) / len(epoch_loss)))
        save_model(args.model_path+"last_"+str(args.version)+".pth",
                   epoch, best_acc, optimizer, model)
        acc, f1 = evaluate(model,tokenizer)
        if acc > best_acc:
            best_acc = acc
            save_model(args.model_path+"best_acc_"+str(args.version)+".pth",
                       epoch, best_acc, optimizer, model)
        if f1 > best_f1:
            best_f1 = f1
            save_model(args.model_path+"best_f1_"+str(args.version)+".pth",
                       epoch, best_f1, optimizer, model)

In [None]:
def evaluate(model, tokenizer, gold=None,
             gold_path="/mnt/data/smart_health_02/zhuyansha/data/CHIP-CDN-SR/recall_top20/dev.json"):
    """Score ``model`` on the dev split, log the metrics and return ``(acc, f1)``.

    Each dev label is a string whose first two ``"###"``-separated fields are a
    mention and a candidate; the candidate is accepted for that mention when
    the model's score exceeds ``args.threshold``. For every gold item, the
    accepted candidates of its "text" and of each entry in its "SR" list are
    pooled, de-duplicated and compared against ``item["normalized_result"]``.

    Args:
        model: scorer called as ``model(X1, X2, X3, X4)``; assumed to yield one
            score per example (it is flattened to 1-D).
        tokenizer: unused; kept so callers can pass ``evaluate(model, tokenizer)``.
        gold: list of gold dicts with keys "text", "SR", "normalized_result".
            When None it is loaded from ``gold_path`` (previously ``gold`` was
            never defined, so every call raised NameError).
        gold_path: UTF-8 JSON file read when ``gold`` is not supplied.

    Returns:
        ``(acc, f1)`` as computed by ``evalacc`` / ``evalf1``.
    """
    if gold is None:
        # NOTE(review): machine-specific absolute default; consider moving it into `args`.
        with open(gold_path, "r", encoding="utf-8") as f:
            gold = json.load(f)
    test_dataloader = DataLoader(data_list["dev"], batch_size=args.batch_size)
    model.eval()
    pred_dict = {}
    with torch.no_grad():
        for X1, X2, X3, X4, y in tqdm(test_dataloader, desc='Evaluating'):
            inputs = [{k: v.cuda() for k, v in X.items()} for X in (X1, X2, X3, X4)]
            # reshape(-1) keeps a size-1 final batch iterable (squeeze() would give
            # a 0-d tensor); thresholding on-device then one .tolist() avoids a
            # GPU sync per element.
            above = (model(*inputs).reshape(-1) > args.threshold).tolist()
            for accepted, label in zip(above, y):
                fields = label.split("###")
                candidates = pred_dict.setdefault(fields[0], [])
                if accepted:
                    candidates.append(fields[1])

        y_preds = []
        y_trues = []
        for item in gold:
            pooled = []
            # .get: a mention absent from the dev candidates contributes nothing
            # instead of raising KeyError.
            for sr in item["SR"]:
                pooled.extend(pred_dict.get(sr, []))
            pooled.extend(pred_dict.get(item["text"], []))
            y_preds.append(list(set(pooled)))
            y_trues.append(item["normalized_result"])
        acc, mul_acc, uni_acc = evalacc(y_preds, y_trues)
        f1, p, r = evalf1(y_preds, y_trues)
        num_acc = evalNumAcc(y_preds, y_trues)
        logger.info('#'*20 + 'acc{}'.format(acc) + '#'*20)
        logger.info('#'*20 + 'mul_acc{}'.format(mul_acc) + '#'*20)
        logger.info('#'*20 + 'uni_acc{}'.format(uni_acc) + '#'*20)
        logger.info('#'*20 + 'f1{}'.format(f1) + '#'*20)
        logger.info('#'*20 + 'p{}'.format(p) + '#'*20)
        logger.info('#'*20 + 'r{}'.format(r) + '#'*20)
        logger.info('#'*20 + 'NumAcc{}'.format(num_acc) + '#'*20)
    return acc, f1

In [None]:
def predict_write(model, tokenizer):
    """Score the test split and group accepted candidates by mention.

    Each test label's first two ``"###"``-separated fields are a mention and a
    candidate; the candidate is kept when the model's score exceeds
    ``args.threshold``.

    Args:
        model: scorer called as ``model(X1, X2, X3, X4)``; assumed to yield one
            score per example (it is flattened to 1-D).
        tokenizer: unused; kept for interface compatibility.

    Returns:
        dict mapping every mention seen in the test split to the list of its
        accepted candidates (an empty list when none passed the threshold).
    """
    test_dataloader = DataLoader(data_list["test"], batch_size=args.batch_size)
    model.eval()
    pred_dict = {}
    with torch.no_grad():
        for X1, X2, X3, X4, y in tqdm(test_dataloader, desc='Evaluating'):
            inputs = [{k: v.cuda() for k, v in X.items()} for X in (X1, X2, X3, X4)]
            # reshape(-1) keeps a size-1 final batch iterable (squeeze() would give
            # a 0-d tensor); one .tolist() per batch avoids per-element GPU syncs.
            above = (model(*inputs).reshape(-1) > args.threshold).tolist()
            for accepted, label in zip(above, y):
                fields = label.split("###")
                candidates = pred_dict.setdefault(fields[0], [])
                if accepted:
                    candidates.append(fields[1])
    return pred_dict
def write_result(i, test_path="data/origin_data/CHIP-CDN_test.json"):
    """Write test-set predictions to ``args.output_path + "CHIP-CDN_test<i>.json"``.

    Predictions come from ``predict_write`` applied to the notebook-level
    ``model`` and ``tokenizer`` globals. For each record in ``test_path`` (a
    JSON list of dicts with "text" and "SR"), the accepted candidates of the
    text and of every SR entry (with ``[...]`` spans removed) are pooled,
    de-duplicated and joined with ``"##"``.

    Args:
        i: suffix inserted into the output file name.
        test_path: UTF-8 JSON file holding the original test records.
    """
    pred_dict = predict_write(model, tokenizer)
    with open(test_path, "r", encoding='utf-8') as f:
        original_data = json.load(f)
    result_list = []
    for item in original_data:
        pooled = []
        for sr in item["SR"]:
            # Raw string: '\[' in a normal literal is an invalid escape sequence.
            pooled.extend(pred_dict.get(re.sub(r'\[.*?\]', '', sr), []))
        pooled.extend(pred_dict.get(item["text"], []))
        result_list.append({
            "text": item['text'],
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # file is reproducible (set order of str varies with PYTHONHASHSEED).
            "normalized_result": "##".join(dict.fromkeys(pooled)),
        })
    with open(args.output_path+"CHIP-CDN_test"+str(i)+".json", "w", encoding='utf-8') as f:
        json.dump(result_list, f, ensure_ascii=False, indent=4)

In [None]:
# Build the model, train with dev-set checkpoint selection, then reload the
# best-F1 checkpoint and write the test-set predictions.
# NOTE(review): AutoTokenizer is not imported anywhere in this notebook; it is
# presumably `transformers.AutoTokenizer` -- add that import to the first cell.
model = TypeCLSModel()
# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model.cuda()
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
tokenizer = AutoTokenizer.from_pretrained(args.plm_name)
best_f1 = 0
train_n_evaluate(model, loss_fn, optimizer, evaluate, tokenizer, best_f1=best_f1)
# NOTE(review): the best_f1 checkpoint only exists if some epoch's F1 exceeded 0.
model, optimizer, best_f1 = load_model(
    args.model_path + "best_f1_" + str(args.version) + ".pth", model, optimizer)
write_result(args.version)
