In [1]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForTokenClassification,
    HfArgumentParser,
    set_seed,
)
from transformer_utils.data_utils_inference import *
import copy
from typing import Dict
import torch
import argparse
import datetime
import random
from transformer_utils.modeling_moe import *
import numpy as np
import time
import copy

In [2]:
# Registry mapping `args.model_type` -> (config class, multi-task model class, tokenizer class).
# `run_different_model` unpacks the entry for the configured model type.
MODEL_CLASSES = dict(
    bert=(BertConfig, BertForSequenceClassificationMT, BertTokenizer),
    # xlnet=(XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    xlm=(XLMConfig, XLMForSequenceClassificationMT, XLMTokenizer),
    roberta=(RobertaConfig, RobertaForSequenceClassificationMT, RobertaTokenizer),
    # distilbert=(DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
    # albert=(AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
    xlmroberta=(XLMRobertaConfig, XLMRobertaForSequenceClassificationMT, XLMRobertaTokenizerFast),
    xlmrobertafast=(XLMRobertaConfig, XLMRobertaForSequenceClassificationMT, XLMRobertaTokenizerFast),
    robertatextcnn=(RobertaConfig, RobertaTextCNNForSequenceClassificationMT, RobertaTokenizer),
    xlmrobertatextcnn=(XLMRobertaConfig, XLMRobertaTextCNNForSequenceClassificationMT, XLMRobertaTokenizer),
    robertatextlstm=(RobertaConfig, RobertaLSTMForSequenceClassificationMT, RobertaTokenizer),
    xlmrobertalstm=(XLMRobertaConfig, XLMRobertaLSTMForSequenceClassificationMT, XLMRobertaTokenizer),
    robertadpcnn=(RobertaConfig, RobertaDPCNNForSequenceClassificationMT, RobertaTokenizer),
    xlmrobertadpcnn=(XLMRobertaConfig, XLMRobertaDPCNNForSequenceClassificationMT, XLMRobertaTokenizer),
    deberta=(DebertaConfig, DebertaForSequenceClassificationMT, DebertaTokenizer),
)

In [4]:
class args:
    """Hard-coded inference options, presumably standing in for parsed command-line flags.

    Every option becomes a plain instance attribute; `add_config` copies most of them
    onto the model config.
    """

    def __init__(self):
        defaults = {
            # visual-feature embedding sizes
            "add_visual_features": True,
            "max_bounding_box_size": 200,
            "max_color_size": 100,
            "max_FontWeight_size": 100,
            "max_FontSize_size": 100,
            "equal_size": 2,
            "max_part_size": 10,
            # transformer architecture
            "output_hidden_states": True,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "intermediate_size": 3072,
            "batch_size": 16,
            "hidden_size": 768,
            "model_type": "xlmroberta",
            # inference behaviour
            "drop_hard_label": False,
            "test_onnx": False,
            "vocab_clip_mapping_file": None,
            "output_path": "result.tsv",
            "sliding_window_size": -1,
            "cache_dir": None,
            # device selection
            "no_cuda": False,
            "local_rank": -1,
            "inference_mt_all_tasks": False,
            "abandon_visual_features": False,
        }
        vars(self).update(defaults)

In [5]:
# Instantiate the default options. NOTE: this rebinds the name `args` from the class to
# its instance, so the class cannot be re-instantiated under that name. add_config reads
# this global instance and run_different_model writes attributes onto it.
args = args()

In [6]:
def add_config(config, options=None):
    """Copy inference options onto ``config`` and choose the torch device(s).

    Args:
        config: model config object; attributes are set on it in place.
        options: object carrying the option attributes (see the ``args`` class).
            Defaults to the module-level ``args`` instance.

    Side effects:
        Sets the copied option attributes plus ``device``, ``n_gpu`` and ``device_ids``
        on ``config``. In distributed mode (``local_rank != -1``) it initialises the NCCL
        process group.
    """
    opts = args if options is None else options

    for name in ("add_visual_features", "max_bounding_box_size", "max_color_size",
                 "max_FontWeight_size", "max_FontSize_size", "equal_size", "max_part_size",
                 "drop_hard_label", "sliding_window_size"):
        setattr(config, name, getattr(opts, name))

    if opts.no_cuda:
        device = torch.device("cpu")
        n_gpu = 0
        device_ids = []  # previously left undefined on this path -> NameError below
    elif opts.local_rank == -1:
        # With n_gpu > 1, inference_mt wraps the model in nn.DataParallel over device_ids.
        # To restrict the GPUs used, set CUDA_VISIBLE_DEVICES.
        n_gpu = torch.cuda.device_count()
        device_ids = list(range(n_gpu))
        # Random GPU order (presumably to spread concurrent runs over a shared machine);
        # device_ids[0] becomes the primary device.
        random.shuffle(device_ids)
        if device_ids and torch.cuda.is_available():
            device = torch.device("cuda:{}".format(device_ids[0]))
        else:
            device = torch.device("cpu")
    else:
        # Distributed: this process drives the single GPU given by local_rank.
        torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(0, 50000))
        device = torch.device("cuda", opts.local_rank)
        n_gpu = 1
        device_ids = [opts.local_rank]  # previously left undefined on this path
    config.device = device
    config.n_gpu = n_gpu
    config.device_ids = device_ids

    for name in ("num_attention_heads", "hidden_size", "intermediate_size", "num_hidden_layers",
                 "output_hidden_states", "batch_size", "test_onnx", "abandon_visual_features",
                 "inference_mt_all_tasks"):
        setattr(config, name, getattr(opts, name))

In [7]:
def trans_conll_file_to_overlap_file(conll_file_path, nodrop_no_label=False, stride=250, window=400, max_parts=10):
    """Split a CoNLL-style file into overlapping windows of token rows.

    The file is read as a sequence of documents:
    - A line with exactly 23 space-separated fields is a token row of the current
      document; field 1 is its label.
    - Any other line starts a new document and is kept verbatim as that document's header.
    - Rows before the first header, and documents with no rows, are dropped.

    Each kept document is cut into windows of ``window`` rows, one starting every
    ``stride`` rows, up to ``max_parts`` windows. Rows beyond
    ``(max_parts - 1) * stride + window`` are therefore dropped. Each window is emitted
    as the header line followed by its rows, with the window index appended to every
    row as an extra space-separated field.

    Args:
        conll_file_path: path of the UTF-8 input file.
        nodrop_no_label: if True, also keep documents whose rows are all labelled 'O'.
            By default those documents are skipped. This now applies to every document,
            not only the last one.
        stride, window, max_parts: windowing parameters (defaults 250 / 400 / 10).

    Returns:
        List of output lines. Header lines keep their original line ending; every row
        line ends with a newline.

    Prints the running count of header lines every 1000 documents as progress.
    """
    num_fields = 23
    overlap_conll_file = []

    def _emit(header, rows):
        # Append the windows of one document, honouring the label filter.
        if header is None or not rows:
            return
        has_label = any(row.split(' ')[1] != 'O' for row in rows)
        if not (has_label or nodrop_no_label):
            return
        for part in range(max_parts):
            start_index = part * stride
            overlap_conll_file.append(header)
            for row in rows[start_index: start_index + window]:
                overlap_conll_file.append(row + ' ' + str(part) + '\n')
            if start_index + window >= len(rows):
                break

    xml_path = None
    content = []
    number = 0
    with open(conll_file_path, 'r', encoding='utf-8') as conll_file:
        for line in conll_file:
            if len(line.split(' ')) != num_fields:
                number += 1
                if number % 1000 == 0:
                    print(number)
                # Flush the previous document (its own label status, not a stale one).
                _emit(xml_path, content)
                xml_path = line
                content = []
            else:
                content.append(line.replace('\n', ''))
    _emit(xml_path, content)
    return overlap_conll_file

In [8]:

@dataclass
class InputExample_pred_mt:
    """A document reassembled from its windows: its key and its full word list."""
    xml_path: str  # document key, copied from the windows' ``xml_path``
    words: List[str]  # words of the whole document after stitching the windows


def get_eval_dataloader_mt(eval_dataset: Dataset, batch_size=20) -> DataLoader:
    """Wrap ``eval_dataset`` in a sequential (unshuffled) DataLoader for evaluation.

    Batches are collated with ``default_data_collator``.
    """
    return DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=None,
        collate_fn=default_data_collator,
    )


def inference_mt(dataloader, model, task_index, args=None):
    """Run one task head of a multi-task model over ``dataloader``.

    Args:
        dataloader: yields dict batches of tensors. Only 'input_ids', 'attention_mask',
            'visual_features' and 'labels' are kept (the others are popped in place).
        model: the torch model or, when ``args.test_onnx`` is set, whatever
            ``rt.InferenceSession`` accepts (``rt`` is presumably onnxruntime).
        task_index: forwarded to the model as ``task_index=`` (torch path only).
        args: config-like object providing ``test_onnx``, ``device``, ``n_gpu`` and
            ``device_ids``. Required despite the ``None`` default.

    Returns:
        ``(label_ids, preds)``: the numpy arrays of all batches concatenated along axis 0.
        Each is ``None`` when nothing was collected (no batches, or no labels).

    Raises:
        ValueError: if ``args`` is None.
    """
    if args is None:
        raise ValueError("inference_mt requires args (test_onnx, device, n_gpu, device_ids)")

    pred_batches = []
    label_batches = []
    session = None
    if args.test_onnx:
        session = rt.InferenceSession(model, providers=['CPUExecutionProvider'])
    else:
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model, args.device_ids)
        model.to(args.device)
        model.eval()

    keep_keys = ('input_ids', 'attention_mask', 'visual_features', 'labels')
    for inputs in tqdm(dataloader, desc='Evaluate'):
        # Checked before filtering, so 'masked_lm_labels' also counts as "has labels".
        has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"])
        for k in list(inputs.keys()):
            if k not in keep_keys:
                inputs.pop(k)
        for k, v in inputs.items():
            inputs[k] = v.to(args.device)
        if args.test_onnx:
            onnx_meta = session.get_inputs()
            onnx_inputs = {onnx_meta[0].name: to_numpy(inputs['input_ids']),
                           onnx_meta[1].name: to_numpy(inputs['attention_mask']),
                           onnx_meta[2].name: to_numpy(inputs['visual_features'])}
            pred = session.run(['output'], onnx_inputs)[0]
        else:
            with torch.no_grad():
                outputs = model(**inputs, task_index=task_index)
                # With labels the model returns (loss, logits, ...); otherwise (logits, ...).
                logits = outputs[1] if has_labels else outputs[0]
                pred = logits.detach().cpu().numpy()
        pred_batches.append(pred)
        if inputs.get("labels") is not None:
            label_batches.append(inputs["labels"].detach().cpu().numpy())

    # One concatenate instead of repeated np.append (which copied the whole array per batch).
    preds = np.concatenate(pred_batches, axis=0) if pred_batches else None
    label_ids = np.concatenate(label_batches, axis=0) if label_batches else None
    return label_ids, preds


def get_real_label_mt(predictions: np.ndarray, label_ids: np.ndarray, label_map: Dict, label2id: Dict):
    """Decode argmax predictions and gold ids to label strings, skipping ignored positions.

    Args:
        predictions: logits indexed as ``[example, position, class]`` (argmax over axis 2).
        label_ids: gold label ids indexed as ``[example, position]``. Positions equal to
            the torch loss ``ignore_index`` (-100 by default) are dropped from both outputs.
        label_map: id -> label string.
        label2id: unused; kept for call compatibility.

    Returns:
        ``(preds_list, out_label_list)``: one list of label strings per example.
    """
    # Hoisted: the original built a new NLLLoss module for every token just to read this value.
    ignore_index = torch.nn.NLLLoss().ignore_index
    preds = np.argmax(predictions, axis=2)
    preds_list, out_label_list = [], []
    for pred_row, label_row in zip(preds, label_ids):
        keep = label_row != ignore_index
        out_label_list.append([label_map[label_id] for label_id in label_row[keep]])
        preds_list.append([label_map[pred_id] for pred_id in pred_row[keep]])
    return preds_list, out_label_list



def get_real_result_from_real_preds_add_mt(real_preds, examples, real_label_ids, stride=250):
    """Stitch per-window results back into one sequence per document.

    Windows are grouped by ``example.xml_path`` in first-seen order. Window ``k`` of a
    document is written at offset ``k * stride``, so a later window overwrites the
    overlap with earlier ones. The exception is predictions: wherever the later window
    predicts 'O' but an earlier window predicted an entity tag, the earlier tag is kept.

    Args:
        real_preds: per-window predicted tag lists (parallel to ``examples``).
        examples: per-window examples exposing ``xml_path``, ``words_ori`` and ``labels``.
        real_label_ids: per-window gold tag lists (parallel to ``examples``).
        stride: offset between consecutive windows. It must match the spacing used when
            the windows were built; it defaults to 250, the spacing used by
            trans_conll_file_to_overlap_file.

    Returns:
        ``(real_preds_ori, examples_ori, label_real_from_example_ori, real_label_ids_ori)``:
        per-document lists. ``examples_ori`` holds ``InputExample_pred_mt(xml_path, words)``.
    """
    windows_by_doc = {}
    for index, example in enumerate(examples):
        windows_by_doc.setdefault(example.xml_path, []).append(index)

    real_preds_ori, examples_ori, label_real_from_example_ori, real_label_ids_ori = [], [], [], []
    for xml_path, window_indices in windows_by_doc.items():
        # Pre-size up to the last window's start; its tail is appended by slice assignment.
        base_len = stride * (len(window_indices) - 1)
        words = [''] * base_len
        label_pred = ['O'] * base_len
        label_real_from_example = ['O'] * base_len
        label_real = ['O'] * base_len
        for part, index in enumerate(window_indices):
            offset = part * stride
            example = examples[index]
            words[offset: offset + len(example.words_ori)] = example.words_ori
            label_real_from_example[offset: offset + len(example.labels)] = example.labels
            label_real[offset: offset + len(real_label_ids[index])] = real_label_ids[index]
            window_pred = list(real_preds[index])
            # In the overlap with earlier windows, keep an earlier entity tag over a new 'O'.
            for pos, previous in enumerate(label_pred[offset: offset + len(window_pred)]):
                if window_pred[pos] == 'O' and previous != 'O':
                    window_pred[pos] = previous
            label_pred[offset: offset + len(window_pred)] = window_pred
        examples_ori.append(InputExample_pred_mt(xml_path=xml_path, words=words))
        real_preds_ori.append(label_pred)
        label_real_from_example_ori.append(label_real_from_example)
        real_label_ids_ori.append(label_real)
    return real_preds_ori, examples_ori, label_real_from_example_ori, real_label_ids_ori





def get_entities_bio_mt(seq):
    """Extract spans from a sequence of plain (un-prefixed) tags.

    Each maximal run of identical non-'O' tags becomes one span. If ``seq`` contains
    lists, it is first flattened with an 'O' appended after each sub-list, so spans
    never cross sub-list boundaries. The indices then refer to the flattened sequence.

    Args:
        seq: list of tags such as 'PER' / 'O', or a list of such lists.

    Returns:
        set of ``(tag, start, end)`` tuples with inclusive ``end``.

    Example:
        get_entities_bio_mt(['PER', 'PER', 'O', 'LOC', 'PER'])
        -> {('PER', 0, 1), ('LOC', 3, 3), ('PER', 4, 4)}
    """
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]
    spans = []
    run_tag, run_start = None, None
    for position, label in enumerate(seq):
        # Any change of tag (including to 'O') closes the open run.
        if run_start is not None and label != run_tag:
            spans.append((run_tag, run_start, position - 1))
            run_start = None
        if label != 'O' and run_start is None:
            run_tag, run_start = label, position
    if run_start is not None:
        spans.append((run_tag, run_start, len(seq) - 1))
    return set(spans)



# NOTE(review): duplicate of the InputExample_pred_mt dataclass defined earlier in this
# cell. This second definition rebinds the name, so instances created after the cell runs
# use this class. The two are field-for-field identical; one of them can be removed.
@dataclass
class InputExample_pred_mt:
    """A document reassembled from its windows: its key and its full word list."""
    xml_path: str  # document key, copied from the windows' ``xml_path``
    words: List[str]  # words of the whole document after stitching the windows


In [9]:

def get_entities_bio(seq):
    """Extract entity spans from a BIO-tagged sequence.

    Decoding rules:
    - 'B-X' opens a span of type X.
    - 'I-X' extends the open span only when that span is also of type X.
    - Any other tag ends the open span without starting a new one. This covers 'O',
      an 'I-' tag with no open span, an 'I-' tag of a different type, and unsupported
      prefixes such as 'S-'.

    The type is the text between the first and second '-'.

    If ``seq`` contains lists, it is first flattened with an 'O' appended after each
    sub-list, so spans never cross sub-list boundaries. The indices then refer to the
    flattened sequence.

    Args:
        seq: list of BIO tags, or a list of such lists.

    Returns:
        set of ``(type, start, end)`` tuples with inclusive ``end``.

    Example:
        get_entities_bio(['B-PER', 'I-PER', 'O', 'B-LOC'])
        -> {('PER', 0, 1), ('LOC', 3, 3)}
    """
    if any(isinstance(s, list) for s in seq):
        # Each sub-list is followed by 'O', so the flattened sequence always ends with 'O'.
        seq = [item for sublist in seq for item in sublist + ['O']]
    chunks = []
    chunk = None  # [type, start, end] of the open span, or None
    for indx, tag in enumerate(seq):
        if tag.startswith("B-"):
            if chunk is not None:
                chunks.append(chunk)
            chunk = [tag.split('-')[1], indx, indx]
        elif tag.startswith('I-') and chunk is not None and tag.split('-')[1] == chunk[0]:
            chunk[2] = indx
        else:
            # Previously an I- of another type left the span open, letting a later
            # matching I- extend it across the foreign token.
            if chunk is not None:
                chunks.append(chunk)
            chunk = None
    if chunk is not None:
        chunks.append(chunk)
    return set(tuple(c) for c in chunks)

In [10]:
def change_to_bio(list_of_list_seq):
    """Convert plain-tag sequences into BIO-prefixed sequences.

    Example: ['PER', 'PER', 'O', 'LOC', 'PER'] -> ['B-PER', 'I-PER', 'O', 'B-LOC', 'B-PER'].
    Each converted sequence is checked (via ``assert``) to decode back to the same spans.

    Args:
        list_of_list_seq: iterable of tag lists.

    Returns:
        A new list of converted sequences; the inputs are not modified.
    """
    def _to_bio(plain_seq):
        bio_seq = copy.copy(plain_seq)
        plain_spans = get_entities_bio_mt(plain_seq)
        for name, start, end in plain_spans:
            bio_seq[start] = f"B-{name}"
            bio_seq[start + 1:end + 1] = [f"I-{name}"] * (end - start)
        assert get_entities_bio(bio_seq) == plain_spans
        return bio_seq

    return [_to_bio(plain_seq) for plain_seq in list_of_list_seq]


In [11]:
def IsNumberWithOptionalCurrency(s, currency_list):
    """Return True if ``s`` is a number, optionally followed by one or more currency suffixes.

    A string counts as a number if ``float(s)`` succeeds, or if it is a single Unicode
    numeric character (``unicodedata.numeric``, e.g. '½'). Otherwise every currency in
    ``currency_list`` that ends ``s`` is stripped in turn and the rest is tested
    recursively.

    Args:
        s: token to test.
        currency_list: iterable of currency strings (a list or the keys of a dict).

    Returns:
        bool.
    """
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        import unicodedata
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
    if len(s) > 0:
        for key in currency_list:
            # Try every matching suffix: previously the first match decided the result,
            # so '5US$' failed when '$' was listed before 'US$'.
            if key and s.endswith(key) and IsNumberWithOptionalCurrency(s[:-len(key)], currency_list):
                return True
    return False

def generate_currency_dic(currencyListFilePath):
    """Load a one-currency-per-line UTF-8 file into a ``{currency: 1}`` dict.

    Lines are stripped of surrounding whitespace. Duplicate entries are printed once per
    repeat occurrence.

    Args:
        currencyListFilePath: path of the currency list file.

    Returns:
        dict mapping each stripped line to 1.
    """
    currencyListD = {}
    # `with` closes the file; the original left the handle open.
    with open(currencyListFilePath, 'r', encoding='utf-8') as currency_file:
        for line in currency_file:
            currency = line.strip()
            if currency in currencyListD:
                print(currency)
            currencyListD[currency] = 1
    return currencyListD

def compare_number_for_price(location1, location2, currency_list):
    """Return True if both token lists contain the same numeric tokens in the same order.

    Numeric tokens are those accepted by ``IsNumberWithOptionalCurrency``; all other
    tokens are ignored.
    """
    def numeric_tokens(tokens):
        return [tok for tok in tokens if IsNumberWithOptionalCurrency(tok, currency_list)]

    return numeric_tokens(location1) == numeric_tokens(location2)


def compare(entity, location1, location2, words, currency_list, isprint):
    """Decide whether a predicted span is close enough to a gold span to count as a match.

    Args:
        entity: entity type ('MainImage', 'Price', ...).
        location1: ``[start, end]`` (inclusive) of the predicted span.
        location2: ``[start, end]`` (inclusive) of the gold span.
        words: document tokens.
        currency_list: currencies used for the Price number check.
        isprint: if True, print which rule decided (tags '_iamge_', '_1_' .. '_4_').

    The rules are checked in order:
      1. MainImage spans whose starts differ by at most 5 tokens match.
      2. Identical token slices match.
      3. Overlap covering more than 60% of the gold span's length matches.
      4. Otherwise, if the longest common substring of the joined texts exceeds 60% of
         the gold text length (or the entity is Price and the prediction has exactly two
         tokens): a match, except that Price spans must also contain the same numbers.
    """
    start1, end1 = location1[0], location1[1]
    start2, end2 = location2[0], location2[1]
    span1 = words[start1: end1 + 1]
    span2 = words[start2: end2 + 1]

    def _report(tag):
        if isprint:
            print(tag, entity, span1, span2)

    if entity == 'MainImage' and abs(start2 - start1) <= 5:
        _report("_iamge_")
        return True
    if span1 == span2:
        _report("_1_")
        return True
    overlap = min(end1, end2) - max(start1, start2) + 1
    if overlap / (end2 - start2 + 1) > 0.6:
        _report("_2_")
        return True
    gold_text = ' '.join(span2)
    longest = longestCommonSubstring(' '.join(span1), gold_text)
    if longest / len(gold_text) > 0.6 or (entity == 'Price' and len(span1) == 2):
        if entity == 'Price' and compare_number_for_price(span1, span2, currency_list) is False:
            _report("_4_")
            return False
        _report("_3_")
        return True
    _report("_4_")
    return False

def get_metric(real_preds, real_label_ids, examples, currency_list, isprint=False):
    """Score BIO predictions after snapping "close enough" predicted spans onto gold spans.

    For each document, predicted and gold spans are collected per entity type. A
    predicted span that differs from the FIRST gold span of its type, but that
    ``compare`` judges equivalent, is rewritten in ``real_preds`` to exactly that gold
    span. The (possibly rewritten) predictions are then scored with
    ``compute_metric_mt_add_qiang``.

    Args:
        real_preds: per-document BIO tag lists; MUTATED IN PLACE by the snapping step.
        real_label_ids: per-document gold BIO tag lists, parallel to ``real_preds``.
        examples: per-document objects exposing ``.words`` (tokens passed to ``compare``).
        currency_list: currencies forwarded to ``compare``.
        isprint: if True, ``compare`` prints the rule that decided each comparison.

    Returns:
        The dict produced by ``compute_metric_mt_add_qiang``.
    """
    real_results, real_label_results = {}, {}
    res_pred, res_label = [], []
    for index_i, real_pred in enumerate(real_preds):
        # Collect predicted spans as {type: [[start, end], ...]}.
        for index_j, label in enumerate(real_pred):
            if 'O' == label:
                continue
            key_v = label.split('-')
            if real_results.__contains__(key_v[1]) is False:
                real_results[key_v[1]] = []
            if 'B' == key_v[0]:
                real_results[key_v[1]].append([index_j, index_j])
            if 'I' == key_v[0]:
                if len(real_results[key_v[1]]) == 0:
                    # An I- with no earlier span of this type is ignored for predictions.
                    # real_results[key_v[1]].append([index_j, index_j])
                    continue
                else:
                    # NOTE(review): extends the end by one instead of setting it to index_j
                    # (as the gold parser below does); the two agree only when I- tags are
                    # contiguous with their span, which change_to_bio output guarantees.
                    real_results[key_v[1]][-1][-1] += 1
        res_pred.append(real_results)
        real_results = {}

        # Collect gold spans the same way, except an orphan I- opens a new span.
        for index_j, label in enumerate(real_label_ids[index_i]):
            if 'O' == label:
                continue
            key_v = label.split('-')
            if real_label_results.__contains__(key_v[1]) is False:
                real_label_results[key_v[1]] = []
            if 'B' == key_v[0]:
                real_label_results[key_v[1]].append([index_j, index_j])
            if 'I' == key_v[0]:
                if not real_label_results[key_v[1]]:
                    real_label_results[key_v[1]].append([index_j, index_j])
                else:
                    real_label_results[key_v[1]][-1][-1] = index_j
        res_label.append(real_label_results)
        real_label_results = {}
    for index, pred in enumerate(res_pred):
        for key in pred:
            if res_label[index].__contains__(key) is False:
                continue
            # Only the first gold span of this type is used as the snapping target.
            labels = res_label[index][key][0]
            for number in range(len(pred[key])):
                if pred[key][number] != labels and compare(key, pred[key][number], labels, examples[index].words, currency_list, isprint):
                    # Erase the predicted span, then write the gold span in BIO form.
                    for offset in range(pred[key][number][0], pred[key][number][1] + 1):
                        real_preds[index][offset] = 'O'
                    for offset in range(labels[0], labels[1] + 1):
                        if offset == labels[0]:
                            real_preds[index][offset] = 'B-' + key
                        else:
                            real_preds[index][offset] = 'I-' + key
                    pred[key][number] = labels

                    

    return compute_metric_mt_add_qiang(real_preds, real_label_ids)

def f1_score_ee(true_entities, pred_entities):
    """Entity-level F1 for DeepEE: harmonic mean of exact-match precision and recall.

    Both arguments are set-like collections of entity tuples. Returns 0 when precision
    and recall are both 0, which includes empty inputs.
    """
    hits = len(true_entities & pred_entities)
    n_pred, n_true = len(pred_entities), len(true_entities)
    precision = 0 if n_pred == 0 else hits / n_pred
    recall = 0 if n_true == 0 else hits / n_true
    if precision + recall <= 0:
        return 0
    return 2 * precision * recall / (precision + recall)


def precision_score_ee(true_entities, pred_entities):
    """Entity-level precision for DeepEE: share of predicted entities that are also gold.

    Returns 0 when there are no predictions.
    """
    hits = len(true_entities & pred_entities)
    n_pred = len(pred_entities)
    return hits / n_pred if n_pred else 0


def recall_score_ee(true_entities, pred_entities):
    """Entity-level recall for DeepEE: share of gold entities that were predicted.

    Returns 0 when there are no gold entities.
    """
    hits = len(true_entities & pred_entities)
    n_true = len(true_entities)
    return hits / n_true if n_true else 0


def classification_report_mt(true_entities, pred_entities, digits=5):
    """Build a text report of per-type and averaged precision / recall / F1.

    Args:
        true_entities: set of gold ``(type, start, end)`` tuples.
        pred_entities: set of predicted ``(type, start, end)`` tuples.
        digits: decimals shown for the score columns.

    Returns:
        A multi-line string. It has one row per gold entity type, followed by:
        - ``micro avg``: pooled over all spans.
        - ``macro avg``: unweighted mean of the per-type scores. Previously this row
          was silently support-weighted.
        - ``weighted avg``: per-type scores weighted by support.
        Predicted types with no gold spans get no row of their own but still count
        toward ``micro avg``. With no gold entities the averaged rows show 0.
    """
    from collections import defaultdict

    def _prf(true_set, pred_set):
        # Exact-match precision, recall and F1; each is 0 when undefined.
        nb_correct = len(true_set & pred_set)
        p = nb_correct / len(pred_set) if len(pred_set) > 0 else 0
        r = nb_correct / len(true_set) if len(true_set) > 0 else 0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0
        return p, r, f1

    true_by_type = defaultdict(set)
    pred_by_type = defaultdict(set)
    for e in true_entities:
        true_by_type[e[0]].add((e[1], e[2]))
    for e in pred_entities:
        pred_by_type[e[0]].add((e[1], e[2]))

    avg_headings = ('micro avg', 'macro avg', 'weighted avg')
    name_width = max((len(name) for name in true_by_type), default=0)
    width = max(name_width, max(len(h) for h in avg_headings), digits)

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

    ps, rs, f1s, s = [], [], [], []
    for type_name, type_true_entities in true_by_type.items():
        p, r, f1 = _prf(type_true_entities, pred_by_type[type_name])
        nb_true = len(type_true_entities)
        report += row_fmt.format(type_name, p, r, f1, nb_true, width=width, digits=digits)
        ps.append(p)
        rs.append(r)
        f1s.append(f1)
        s.append(nb_true)

    report += u'\n'

    support = sum(s)
    micro = _prf(set(true_entities), set(pred_entities))
    if s:
        macro = (np.mean(ps), np.mean(rs), np.mean(f1s))
        weighted = (np.average(ps, weights=s), np.average(rs, weights=s), np.average(f1s, weights=s))
    else:
        # No gold entities: np.average would divide by a zero weight sum.
        macro = weighted = (0.0, 0.0, 0.0)
    for heading, scores in zip(avg_headings, (micro, macro, weighted)):
        report += row_fmt.format(heading, *scores, support, width=width, digits=digits)

    return report

def compute_metric_mt_add_qiang(prediction_onnx_origin, real_label_ids_ori):
    """Decode BIO predictions and gold tags into spans and score them.

    As debug output it prints the predicted spans that are not gold spans, a separator
    line, and then all gold spans.

    Returns:
        dict with keys 'f1', 'report', 'recall', 'precision'.
    """
    gold = get_entities_bio(real_label_ids_ori)
    predicted = get_entities_bio(prediction_onnx_origin)

    # Debug dump: false-positive spans, then every gold span.
    print([span for span in predicted if span not in gold])
    print('-' * 10)
    print(gold)

    f1 = f1_score_ee(gold, predicted)
    report = classification_report_mt(gold, predicted)
    recall = recall_score_ee(gold, predicted)
    precision = precision_score_ee(gold, predicted)
    return {"f1": f1, 'report': report, 'recall': recall, 'precision': precision}

In [None]:
# Currency suffixes passed (via run_different_model -> get_metric -> compare) to the Price
# number check.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, which can execute
# code. Only load this file from a trusted source.
currency_list = np.load('./data/currency.npy', allow_pickle=True).tolist()

In [13]:
def run_different_model(saved_dir, all_candidate, tokenizer_name_or_path=None):
    """Evaluate one saved checkpoint on a set of CoNLL test files.

    This matches how the notebook calls it: ``run_different_model(checkpoint_dir,
    conll_paths)``. Previously the conll-path list was loaded as a model, and the body
    read an undefined global and an undefined tokenizer placeholder.

    Args:
        saved_dir: checkpoint directory to load the config and weights from (and, by
            default, the tokenizer).
        all_candidate: one CoNLL test-file path, or an iterable of them. Each path must
            look like ``<root>/<language>/<EntityType>_...``, with EntityType one of
            "MainImage", "Price", "Name".
        tokenizer_name_or_path: tokenizer to load; defaults to ``saved_dir``.

    Returns:
        ``{saved_dir: {language: {"Price": {...}, "MainImage": {...}, "Name": {...}}}}``.
        Every evaluated entity dict holds 'recall' and 'precision'; entity types with
        no test file for a language stay ``{}``, which ``statistic`` treats as missing.

    Side effects:
        Sets ``args.model_name_or_path`` and ``args.conll_file_path`` on the global
        ``args``. Reads the globals ``vocab_clip_mapping`` and ``currency_list``.
        Prints progress and metrics.
    """
    if isinstance(all_candidate, str):
        conll_file_paths = [all_candidate]
    else:
        conll_file_paths = list(all_candidate)
    task_index_map = {"MainImage": 0, "Price": 2, "Name": 1}

    print(f"[Info] start for candidate {saved_dir}")
    args.model_name_or_path = saved_dir
    config_class, model_class, _tokenizer_class = MODEL_CLASSES[args.model_type]
    model_config = config_class.from_pretrained(saved_dir,
                                                finetuning_task="dummy",
                                                cache_dir=args.cache_dir if args.cache_dir else None)
    add_config(model_config)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path or saved_dir, use_fast=True)
    model = model_class.from_pretrained(saved_dir, from_tf=False, config=model_config, cache_dir=None)

    saved_lang_result_for_this_candidate = {}
    for one_conll_file_path in conll_file_paths:
        print("*" * 10)
        args.conll_file_path = one_conll_file_path
        language_name = os.path.basename(os.path.dirname(one_conll_file_path))
        entity_type = os.path.basename(one_conll_file_path).split('_')[0]
        print(f"[Info]: Metrics for language {language_name} entity type {entity_type}")
        label_this = ['O', entity_type]
        label2id_this = {"O": 0, entity_type: 1}
        label_map_this = {0: "O", 1: entity_type}

        overlap_conll = trans_conll_file_to_overlap_file(one_conll_file_path, True)
        test_dataset = LabelingDataset_mt(overlap_conll, tokenizer, label_this,
                                          args.vocab_clip_mapping_file is not None,
                                          vocab_clip_mapping, 512, model_config)
        test_dataloader = get_eval_dataloader_mt(test_dataset, batch_size=model_config.batch_size)

        task_index = task_index_map[entity_type]
        model_config.task_index = task_index
        test_label_ids, prediction = inference_mt(test_dataloader, model, task_index, model_config)
        real_prediction, real_label_ids = get_real_label_mt(prediction, test_label_ids,
                                                            label_map_this, label2id_this)
        prediction_origin, examples_origin, _label_real_from_example_ori, real_label_ids_ori = \
            get_real_result_from_real_preds_add_mt(real_prediction, test_dataset.__get_examples__(),
                                                   real_label_ids)

        out_res = get_metric(change_to_bio(prediction_origin), change_to_bio(real_label_ids_ori),
                             examples_origin, currency_list)
        print(out_res)
        print('*' * 10)

        # Pre-seed all three types with {} so statistic() can tell "not evaluated" apart.
        lang_results = saved_lang_result_for_this_candidate.setdefault(
            language_name, {"Price": {}, "MainImage": {}, "Name": {}})
        lang_results[entity_type]['recall'] = out_res['recall']
        lang_results[entity_type]['precision'] = out_res['precision']

    return {saved_dir: saved_lang_result_for_this_candidate}

In [14]:
from glob import glob

# Collect every "<language>/*_conll.tsv" file under ./test_data. Both levels are sorted
# because os.listdir/glob return an arbitrary order, and the evaluation order should be
# reproducible.
all_conll_file_full_paths = []
for lang_dir_name in sorted(os.listdir("./test_data")):
    lang_dir = os.path.join("./test_data", lang_dir_name)
    all_conll_file_full_paths.extend(sorted(glob(os.path.join(lang_dir, "*_conll.tsv"))))

In [None]:
def statistic(res_dict):
    """Print one summary line per checkpoint.

    Args:
        res_dict: ``{checkpoint: {language: {entity_type: {'precision': p, 'recall': r}}}}``
            as returned by ``run_different_model``. Missing entity types or metrics are
            treated as not evaluated.

    For each language the minimum of its evaluated precision/recall values puts it in a
    bucket: <0.7, 0.7-0.8, 0.8-0.9 or >=0.9. A language with nothing evaluated lands in
    >=0.9. The "avg micro" figures are the means of the per-(language, entity) scores,
    plus the F1 of those two means. The F1 is 0.0 when both means are 0, and nan when
    nothing was evaluated.
    """
    missing = 100  # sentinel for a not-evaluated metric; never lowers the minimum
    entity_types = ('Price', 'MainImage', 'Name')

    def _metric(lang_metrics, entity, name):
        return lang_metrics.get(entity, {}).get(name, missing)

    for one_ckp, this_ckp_result in res_dict.items():
        block_1, block_2, block_3, block_4, all_precision, all_recall = [], [], [], [], [], []
        for one_langs, this_lang_metrics in this_ckp_result.items():
            precisions = [_metric(this_lang_metrics, e, 'precision') for e in entity_types]
            recalls = [_metric(this_lang_metrics, e, 'recall') for e in entity_types]
            min_value = min(precisions + recalls)
            if min_value >= 0.9:
                block_4.append(one_langs)
            elif min_value >= 0.8:
                block_3.append(one_langs)
            elif min_value >= 0.7:
                block_2.append(one_langs)
            else:
                block_1.append(one_langs)
            all_precision.extend(p for p in precisions if p != missing)
            all_recall.extend(r for r in recalls if r != missing)

        mean_p = np.mean(all_precision) if all_precision else float('nan')
        mean_r = np.mean(all_recall) if all_recall else float('nan')
        f1 = 2 * mean_p * mean_r / (mean_p + mean_r) if (mean_p + mean_r) != 0 else 0.0
        print(f"For checkpoint {one_ckp}, <70 markets number: {len(block_1)}, 70-80: {len(block_2)}, 80-90: {len(block_3)}, >90: {len(block_4)}, avg micro precision: {mean_p}, avg micro recall: {mean_r}, avg micro f1: {f1}")

Output the final results

In [None]:
# Global read by run_different_model when it builds LabelingDataset_mt. With
# args.vocab_clip_mapping_file None, an empty mapping is passed (presumably unused then).
vocab_clip_mapping = {}
# Run the evaluation with saved_dir="./saved/test" and the collected CoNLL test files,
# then print the per-checkpoint summary.
res = run_different_model("./saved/test", all_conll_file_full_paths)
statistic(res)