In [25]:
import re
import sys
import json
from collections import Counter, defaultdict
from statistics import median

#from dee.event_types import get_event_template


def load_line_json_iterator(filepath):
    """Lazily yield one decoded JSON object per line of a JSON Lines file.

    Args:
        filepath: path of a UTF-8 text file holding one JSON document per line.

    Yields:
        The object decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filepath, "rt", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            # A blank line (typically a trailing newline at the end of the
            # file) carries no record; skip it instead of letting
            # json.loads("") abort the whole iteration.
            if not line:
                continue
            yield json.loads(line)


def load_json(filepath):
    """Read a whole UTF-8 encoded JSON file and return the decoded object."""
    with open(filepath, "rt", encoding="utf-8") as fin:
        content = fin.read()
    return json.loads(content)


def sent_seg(
    text,
    special_seg_indicators=None,
    lang="zh",
    punctuations=None,
    quotation_seg_mode=True,
) -> list:
    """
    Cut a text into sentences at sentence-ending punctuation.

    Args:
        text <str>: text to be cut.
        special_seg_indicators <list>: optional `[indicator, replacement]`
            pairs applied with `re.sub` before segmentation (the indicator is
            a regular expression), e.g.
            `[('###', '\n'), ('%%%', ' '), ('%%', ' ')]`. A replacement that
            inserts `\n` also creates a sentence boundary.
        lang <str>: `zh` splits on `。？！…`, `en` splits on `.?!`.
        punctuations <iterable of str>: extra characters to split on (for
            example `；`, which is not a boundary by default).
        quotation_seg_mode <bool>: if True, closing quotation marks that
            directly follow the punctuation stay with the former sentence.
            e.g. `我说：“翠花，上酸菜。”，她说：“欸，好嘞。”` is cut into
            ['我说：“翠花，上酸菜。”', '，她说：“欸，好嘞。”'] rather than
            ['我说：“翠花，上酸菜。', '”，她说：“欸，好嘞。', '”'].
    Returns:
        <list>: the stripped, non-empty sentences in their original order;
            an empty list for empty text.
    Raises:
        ValueError: if `text` is not a string or `lang` is not supported.
    """
    if not isinstance(text, str):
        raise ValueError(
            "`text` must be a str, got {}".format(type(text).__name__)
        )

    # nothing to segment
    if len(text) == 0:
        return []

    text_return = text

    # rewrite the caller-specified indicators first,
    # e.g. [('###', '\n'), ('%%%', '\t'), ('\s', '')]
    if special_seg_indicators:
        for indicator in special_seg_indicators:
            text_return = re.sub(indicator[0], indicator[1], text_return)

    if lang == "zh":
        punkt = {"。", "？", "！", "…"}
    elif lang == "en":
        punkt = {".", "?", "!"}
    else:
        # previously `punkt` was left unbound and an UnboundLocalError
        # surfaced a few lines further down
        raise ValueError("unsupported lang {!r}, expected 'zh' or 'en'".format(lang))
    if punctuations:
        punkt = punkt | set(punctuations)

    # Escape every character so that `]`, `^`, `-` or `\` supplied by the
    # caller cannot change the meaning of the character class; sorting makes
    # the pattern independent of set iteration order.
    char_class = "".join(re.escape(p) for p in sorted(punkt))

    if quotation_seg_mode:
        pattern = "([%s]+[’”`'\"]*)" % char_class
    else:
        pattern = "([%s])" % char_class
    text_return = re.sub(pattern, "\\1\n", text_return)

    # drop pieces that are empty once surrounding whitespace is removed
    pieces = (piece.strip() for piece in text_return.split("\n"))
    return [piece for piece in pieces if piece]


def stat_sent_len(filepath):
    """Print sentence-count and sentence-length statistics of a corpus.

    Every record of the JSON Lines file at `filepath` must have a "text" key.
    The text is segmented with `sent_seg` (default arguments) and the function
    prints min / median / max of the number of sentences per record and of the
    sentence length in characters, followed by the sentence-length frequency
    table (most common first).

    Args:
        filepath: path of the JSON Lines file to analyse.
    """
    num_sents = []
    sent_len = []
    for d in load_line_json_iterator(filepath):
        sents = sent_seg(d["text"])
        num_sents.append(len(sents))
        sent_len.extend(len(sent) for sent in sents)

    # min() / max() / median() raise on empty input; report the situation
    # instead of failing with an opaque error.
    if not num_sents or not sent_len:
        print(f"num_sents: {len(num_sents)} records, no sentences to summarise")
        return

    sent_len_counter = Counter(sent_len)
    print(
        (
            f"num_sents: min: {min(num_sents)}, median: {median(num_sents)}, max: {max(num_sents)}\n"
            # the trailing newline keeps the frequency table from being glued
            # onto the max value
            f"sent_len: min: {min(sent_len)}, median: {median(sent_len)}, max: {max(sent_len)}\n"
            f"{sent_len_counter.most_common()}"
        )
    )

def get_span_drange(sents, span):
    """Locate the occurrences of the literal string `span` inside `sents`.

    An occurrence is rejected when it is only a fragment of a longer number,
    i.e. when `span` starts with an ASCII digit and is directly preceded by a
    digit or by "<digit>.", or when `span` ends with an ASCII digit and is
    directly followed by a digit or by ".<digit>".

    Args:
        sents: list of sentence strings.
        span: the text to look for; it is matched literally.

    Returns:
        A list of `[sent_idx, start, end]` triples (end exclusive), ordered by
        sentence and then by position. Empty when `span` is empty or never
        occurs.
    """
    # an empty span would match at every position and has no first/last char
    if not span:
        return []

    def _is_digit(ch):
        # ASCII digits only, on purpose: str.isdigit() would also accept
        # full-width and other Unicode digits
        return "0" <= ch <= "9"

    # re.escape covers every metacharacter (the previous hand-written chain of
    # replace() calls missed `\`, `^`, `$`, `{`, `}` and `|`)
    pattern = re.escape(span)
    starts_with_digit = _is_digit(span[0])
    ends_with_digit = _is_digit(span[-1])

    drange = []
    for sent_idx, sent in enumerate(sents):
        # compare against the raw span: the escaped pattern is longer than the
        # text it matches and would wrongly skip short sentences
        if len(sent) < len(span):
            continue
        for ocurr in re.finditer(pattern, sent):
            start, end = ocurr.span()
            if starts_with_digit:
                # "15" contains "5", but it is not the number "5"
                if start >= 1 and _is_digit(sent[start - 1]):
                    continue
                # "1.5" contains "5" as its fractional part
                if (
                    start >= 2
                    and sent[start - 1] == "."
                    and _is_digit(sent[start - 2])
                ):
                    continue
            if ends_with_digit:
                # "51" contains "5"
                if end < len(sent) and _is_digit(sent[end]):
                    continue
                # "5.1" contains "5" as its integer part
                if (
                    end + 1 < len(sent)
                    and sent[end] == "."
                    and _is_digit(sent[end + 1])
                ):
                    continue
            drange.append([sent_idx, start, end])
    return drange

def reorganise_sents(sents, max_seq_len, concat=False, final_cut=False, concat_str=" "):
    """Regroup sentences so that short ones are merged and long ones are cut.

    Args:
        sents: list of sentence strings.
        max_seq_len: length budget (in characters) of one output sentence.
        concat: if True, a sentence that still fits the budget is glued onto
            the pending group; if False it is emitted unchanged.
        final_cut: if True, a sentence longer than `max_seq_len` is truncated
            to its first `max_seq_len` characters; if False it is re-split on
            `，、|,` (in addition to the default sentence punctuation) and the
            pieces are regrouped with `concat=True, final_cut=True`.
        concat_str: separator inserted when the pending group has more than
            one character and ends with a CJK ideograph (U+4E00..U+9FA5).

    Returns:
        The regrouped sentences, stripped; entries that are empty before
        stripping are dropped.
    """
    merged = []
    pending = ""
    for sent in sents:
        fits = len(sent) + len(pending) < max_seq_len

        if fits and not concat:
            merged.append(sent)
            continue

        if fits:
            ends_with_cjk = len(pending) > 1 and "\u4e00" <= pending[-1] <= "\u9fa5"
            pending += (concat_str + sent) if ends_with_cjk else sent
            continue

        # the sentence does not fit: flush whatever has been collected so far
        if pending:
            merged.append(pending)
            pending = ""

        if len(sent) <= max_seq_len:
            pending = sent
        elif final_cut:
            pending = sent[:max_seq_len]
        else:
            pieces = sent_seg(sent, punctuations={"，", "、", "|", ","})
            merged.extend(
                reorganise_sents(pieces, max_seq_len, concat=True, final_cut=True)
            )

    if pending:
        merged.append(pending)
    return [item.strip() for item in merged if len(item) > 0]


def build(
    event_type2event_class,
    filepath,
    dump_filepath,
    max_seq_len=128,
    inference=False,
    add_trigger=False,
):
    """Convert a JSON Lines corpus into per-document annotation records.

    Each input record is read with `load_line_json_iterator`; its "text" is
    segmented, regrouped to `max_seq_len`, filtered, and (when non-empty) its
    "title" is placed in front as sentence 0. Unless `inference` is set, the
    triggers and arguments of its "event_list" are located in those sentences.
    The list of `[id, info]` pairs is written to `dump_filepath` as one JSON
    document and the number of skipped records is printed.

    In this variant every role keeps a single argument string: a later
    argument with the same role replaces the earlier one.

    Args:
        event_type2event_class: mapping from event type name to an object
            exposing a `FIELDS` list of role names.
        filepath: input JSON Lines file; records need "id", "text", "title"
            and, when not in inference mode, "event_list".
        dump_filepath: output JSON file.
        max_seq_len: length budget passed to `reorganise_sents`.
        inference: if True, event annotations are ignored and every record is
            emitted with empty annotation fields and doc_type "unk".
        add_trigger: if True, the trigger is recorded as an extra "Trigger"
            role and mention.
    """
    not_valid = 0
    data = []
    # one document per input line
    for d in load_line_json_iterator(filepath):
        # sentence segmentation, additionally splitting on "；"
        sents = sent_seg(d["text"], punctuations={"；"})
        # merge short sentences / cut long ones to the length budget
        sents = reorganise_sents(sents, max_seq_len, concat=True)
        # sents = d['map_sentences']
        # sentence length filtering: drop sentences shorter than 5 characters
        sents = list(filter(lambda x: len(x) >= 5, sents))
        # a non-empty title becomes sentence 0 (and is exempt from the filter)
        if(len(d['title'])>0):
            sents.insert(0, d["title"])
        # sents.insert(0, d['map_title'])
        ann_valid_mspans = []
        ann_valid_dranges = []
        ann_mspan2dranges = defaultdict(list)
        ann_mspan2guess_field = {}
        # one [event_idx, event_type, role2arg] entry per kept event
        recguid_eventname_eventdict_list = []

        event_types = []
        if not inference:
            # a training record without events is counted and skipped
            if "event_list" not in d or len(d["event_list"]) == 0:
                not_valid += 1
                continue

            for event_idx, ins in enumerate(d["event_list"]):
                # recorded before the trigger check below, so events that are
                # later dropped still count towards doc_type
                event_types.append(ins["event_type"])

                roles = event_type2event_class[ins["event_type"]].FIELDS
                role2arg = {x: None for x in roles}
                # take trigger into consideration
                trigger = ins["trigger"]
                trigger_ocurr = get_span_drange(sents, trigger)

                # the whole event is dropped when its trigger cannot be found
                if len(trigger_ocurr) <= 0:
                    continue
                if add_trigger:
                    role2arg["Trigger"] = trigger
                    ann_mspan2guess_field[trigger] = "Trigger"
                    ann_valid_mspans.append(trigger)
                    ann_mspan2dranges[trigger] = trigger_ocurr
                for arg_pair in ins["arguments"]:
                    ocurr = get_span_drange(sents, arg_pair["argument"])
                    # arguments that do not occur in the sentences are ignored
                    if len(ocurr) <= 0:
                        continue
                    # each role holds one argument; a later one overwrites it
                    role2arg[arg_pair["role"]] = arg_pair["argument"]
                    # NOTE(review): a span shared by several events is appended
                    # here once per event, while the dranges below are keyed by
                    # span — confirm the consumer tolerates the duplicates
                    ann_valid_mspans.append(arg_pair["argument"])
                    ann_mspan2guess_field[arg_pair["argument"]] = arg_pair["role"]
                    ann_mspan2dranges[arg_pair["argument"]] = ocurr
                ann_valid_dranges = list(ann_mspan2dranges.values())
                recguid_eventname_eventdict_list.append(
                    [event_idx, ins["event_type"], role2arg]
                )

        # o2o: a single event; o2m: several events of one type;
        # m2m: events of several types; unk: no event information
        doc_type = "unk"
        if len(event_types) > 0:
            et_counter = Counter(event_types).most_common()
            if len(et_counter) == 1 and et_counter[0][1] == 1:
                doc_type = "o2o"
            elif len(et_counter) == 1 and et_counter[0][1] > 1:
                doc_type = "o2m"
            elif len(et_counter) > 1:
                doc_type = "m2m"

        data.append(
            [
                d["id"],
                {
                    "doc_type": doc_type,
                    "sentences": sents,
                    "ann_valid_mspans": ann_valid_mspans,
                    "ann_valid_dranges": ann_valid_dranges,
                    "ann_mspan2dranges": dict(ann_mspan2dranges),
                    "ann_mspan2guess_field": ann_mspan2guess_field,
                    "recguid_eventname_eventdict_list": recguid_eventname_eventdict_list,
                },
            ]
        )
    print("not valid:", not_valid)
    with open(dump_filepath, "wt", encoding="utf-8") as fout:
        json.dump(data, fout, ensure_ascii=False)


def build_m2m(
    event_type2event_class,
    filepath,
    dump_filepath,
    max_seq_len=128,
    inference=False,
    add_trigger=False,
):
    """Convert a JSON Lines corpus into records with multi-valued roles.

    Same pipeline as `build`, except that every role maps to the list of all
    its arguments found in the sentences (or to None when there is none), and
    the title is always inserted as sentence 0.

    Args:
        event_type2event_class: mapping from event type name to an object
            exposing a `FIELDS` list of role names.
        filepath: input JSON Lines file; records need "id", "text", "title"
            and, when not in inference mode, "event_list".
        dump_filepath: output JSON file receiving the list of `[id, info]`
            pairs.
        max_seq_len: length budget passed to `reorganise_sents`.
        inference: if True, event annotations are ignored and every record is
            emitted with empty annotation fields and doc_type "unk".
        add_trigger: if True, the trigger is recorded under a "Trigger" role
            and as a mention.

    Raises:
        KeyError: if an argument carries a role that is not in the event
            class's `FIELDS`, or a record lacks a required key.
    """
    not_valid = 0
    data = []
    for d in load_line_json_iterator(filepath):
        sents = sent_seg(d["text"], punctuations={"；"})
        sents = reorganise_sents(sents, max_seq_len, concat=True)
        # sentence length filtering: drop sentences shorter than 5 characters
        sents = list(filter(lambda x: len(x) >= 5, sents))
        # NOTE(review): unlike build(), the title is inserted even when it is
        # an empty string — confirm which behaviour downstream code expects.
        sents.insert(0, d["title"])

        ann_valid_mspans = []
        ann_valid_dranges = []
        ann_mspan2dranges = defaultdict(list)
        ann_mspan2guess_field = {}
        # one [event_idx, event_type, role2args] entry per kept event
        recguid_eventname_eventdict_list = []

        event_types = []
        if not inference:
            # a training record without events is counted and skipped
            if "event_list" not in d or len(d["event_list"]) == 0:
                not_valid += 1
                continue

            for event_idx, ins in enumerate(d["event_list"]):
                # recorded before the trigger check below, so events that are
                # later dropped still count towards doc_type
                event_types.append(ins["event_type"])

                roles = event_type2event_class[ins["event_type"]].FIELDS
                role2arg = {x: [] for x in roles}
                # take trigger into consideration
                trigger = ins["trigger"]
                trigger_ocurr = get_span_drange(sents, trigger)

                # the whole event is dropped when its trigger cannot be found
                if len(trigger_ocurr) <= 0:
                    continue
                if add_trigger:
                    # "Trigger" need not be one of the template's FIELDS;
                    # setdefault avoids the KeyError a plain lookup raised
                    role2arg.setdefault("Trigger", []).append(trigger)
                    ann_mspan2guess_field[trigger] = "Trigger"
                    ann_valid_mspans.append(trigger)
                    ann_mspan2dranges[trigger] = trigger_ocurr

                for arg_pair in ins["arguments"]:
                    ocurr = get_span_drange(sents, arg_pair["argument"])
                    # arguments that do not occur in the sentences are ignored
                    if len(ocurr) <= 0:
                        continue
                    role2arg[arg_pair["role"]].append(arg_pair["argument"])
                    ann_valid_mspans.append(arg_pair["argument"])
                    ann_mspan2guess_field[arg_pair["argument"]] = arg_pair["role"]
                    ann_mspan2dranges[arg_pair["argument"]] = ocurr
                ann_valid_dranges = list(ann_mspan2dranges.values())

                # roles without any located argument are stored as None
                new_role2arg = {x: None for x in roles}
                for role, args in role2arg.items():
                    new_role2arg[role] = args if len(args) > 0 else None

                recguid_eventname_eventdict_list.append(
                    [event_idx, ins["event_type"], new_role2arg]
                )

        # o2o: a single event; o2m: several events of one type;
        # m2m: events of several types; unk: no event information
        et_counter = Counter(event_types).most_common()
        if len(et_counter) == 1 and et_counter[0][1] == 1:
            doc_type = "o2o"
        elif len(et_counter) == 1 and et_counter[0][1] > 1:
            doc_type = "o2m"
        elif len(et_counter) > 0:
            doc_type = "m2m"
        else:
            doc_type = "unk"

        data.append(
            [
                d["id"],
                {
                    "doc_type": doc_type,
                    "sentences": sents,
                    "ann_valid_mspans": ann_valid_mspans,
                    "ann_valid_dranges": ann_valid_dranges,
                    "ann_mspan2dranges": dict(ann_mspan2dranges),
                    "ann_mspan2guess_field": ann_mspan2guess_field,
                    "recguid_eventname_eventdict_list": recguid_eventname_eventdict_list,
                },
            ]
        )
    print("not valid:", not_valid)
    with open(dump_filepath, "wt", encoding="utf-8") as fout:
        json.dump(data, fout, ensure_ascii=False)


def stat_roles(filepath):
    """Print every event type with the number and list of roles it uses.

    Records without an "event_list" key are ignored. A role is counted for an
    event type as soon as one argument of such an event carries it.
    """
    type2roles = defaultdict(set)
    for d in load_line_json_iterator(filepath):
        if "event_list" in d:
            for ins in d["event_list"]:
                for arg_pair in ins["arguments"]:
                    type2roles[ins["event_type"]].add(arg_pair["role"])

    for event_type, roles in type2roles.items():
        print(event_type, len(roles), list(roles))


def merge_pred_ents_to_inference(pred_filepath, inference_filepath, dump_filepath):
    """Overwrite the mentions of an inference file with predicted entities.

    Args:
        pred_filepath: JSON Lines file whose records provide "id",
            "entity_pred", "map_sentences" and "map_title".
        inference_filepath: JSON file holding a list of `[id, info]` entries.
        dump_filepath: where the updated list is written as JSON.

    For every entry the sentences are replaced by the predicted sentences with
    the predicted title inserted in front, and the four `ann_*` fields are
    rebuilt from the predicted entities. Entities whose type contains
    "trigger" (case-insensitive) are ignored. The first two updated entries
    are printed at the end.
    """
    inference_data = load_json(inference_filepath)

    pred_data, pred_sents, pred_titles = {}, {}, {}
    for pred in load_line_json_iterator(pred_filepath):
        pred_data[pred["id"]] = pred["entity_pred"]
        pred_sents[pred["id"]] = pred["map_sentences"]
        pred_titles[pred["id"]] = pred["map_title"]

    for d in inference_data:
        guid, doc = d[0], d[1]
        doc["sentences"] = pred_sents[guid]
        doc["sentences"].insert(0, pred_titles[guid])

        span2field = {}
        span2dranges = defaultdict(list)
        for ent in pred_data[guid]:
            if "trigger" in ent[1].lower():
                continue
            # keep only the part after the last "-" of the predicted type
            span2field[ent[0]] = ent[1].split("-")[-1]
            # sentence index shifted by one for the title inserted above;
            # presumably ent[4] is an inclusive end offset — TODO confirm
            span2dranges[ent[0]].append([ent[2] + 1, ent[3], ent[4] + 1])

        span2dranges = dict(span2dranges)
        doc["ann_valid_mspans"] = list(span2dranges.keys())
        doc["ann_valid_dranges"] = list(span2dranges.values())
        doc["ann_mspan2guess_field"] = span2field
        doc["ann_mspan2dranges"] = span2dranges

    with open(dump_filepath, "wt", encoding="utf-8") as fout:
        json.dump(inference_data, fout, ensure_ascii=False)

    print(json.dumps(inference_data[:2], ensure_ascii=False, indent=2))


def merge_pred_ents_with_pred_format_to_inference(
    pred_filepath, inference_filepath, dump_filepath
):
    """Overwrite the mentions of an inference file from "new_comments" records.

    Args:
        pred_filepath: JSON Lines file whose records provide "id" and a
            "new_comments" object with "sentences" and "mspans".
        inference_filepath: JSON file holding a list of `[id, info]` entries.
        dump_filepath: where the updated list is written as JSON.

    For every entry the sentences are replaced by the predicted ones and the
    four `ann_*` fields are rebuilt from the predicted mentions. Mentions
    whose "mtype" contains "trigger" (case-insensitive) are ignored. The first
    two updated entries are printed at the end.
    """
    inference_data = load_json(inference_filepath)

    pred_data = {}
    for pred in load_line_json_iterator(pred_filepath):
        pred_data[pred["id"]] = pred["new_comments"]

    for d in inference_data:
        guid, doc = d[0], d[1]
        doc["sentences"] = pred_data[guid]["sentences"]

        span2field = {}
        span2dranges = defaultdict(list)
        for ent in pred_data[guid]["mspans"]:
            if "trigger" in ent["mtype"].lower():
                continue
            # NOTE(review): the key is spelled "msapn" — confirm it matches
            # the file that produces the predictions
            span2field[ent["msapn"]] = ent["mtype"].split("-")[-1]
            span2dranges[ent["msapn"]].append(ent["drange"])

        span2dranges = dict(span2dranges)
        doc["ann_valid_mspans"] = list(span2dranges.keys())
        doc["ann_valid_dranges"] = list(span2dranges.values())
        doc["ann_mspan2guess_field"] = span2field
        doc["ann_mspan2dranges"] = span2dranges

    with open(dump_filepath, "wt", encoding="utf-8") as fout:
        json.dump(inference_data, fout, ensure_ascii=False)

    print(json.dumps(inference_data[:2], ensure_ascii=False, indent=2))


def multi_role_stat(filepath):
    """Print how often an event carries several arguments for the same role.

    For each event only its most frequent role is inspected. When that role
    occurs more than once the event is counted, per event type and per
    (event type, role); the per-role figure is the total number of arguments
    involved. Records without an "event_list" key are ignored.

    Args:
        filepath: path of the JSON Lines file to analyse.
    """
    num_ins = 0
    num_multi_role_doc = 0
    type2num_multi_role = defaultdict(lambda: 0)
    type2role2num_multi_role = defaultdict(lambda: defaultdict(list))

    for d in load_line_json_iterator(filepath):
        if "event_list" not in d:
            continue
        for ins in d["event_list"]:
            num_ins += 1
            roles = [x["role"] for x in ins["arguments"]]
            # an event without arguments has no most common role;
            # most_common(1)[0] would raise IndexError
            if not roles:
                continue
            role, role_cnt = Counter(roles).most_common(1)[0]
            if role_cnt > 1:
                num_multi_role_doc += 1
                type2num_multi_role[ins["event_type"]] += 1
                type2role2num_multi_role[ins["event_type"]][role].append(role_cnt)

    print("num_ins", num_ins)
    print("num_multi_role_doc", num_multi_role_doc)
    print("type2num_multi_role", type2num_multi_role)
    # collapse the per-event counts into one total per (event type, role)
    for event_type in type2role2num_multi_role:
        for role in type2role2num_multi_role[event_type]:
            type2role2num_multi_role[event_type][role] = sum(
                type2role2num_multi_role[event_type][role]
            )
    print("type2role2num_multi_role", type2role2num_multi_role)


def stat_shared_triggers(filepath):
    """Print how many event records share their trigger within one document.

    Every line of `filepath` is one JSON document. An event record counts as
    "shared" when at least one other record of the same document has the same
    trigger string.
    """
    # train: 3400 / 9498
    num_records = 0
    num_share_trigger_records = 0
    with open(filepath, "rt", encoding="utf-8") as fin:
        for line in fin:
            events = json.loads(line).get("event_list", [])
            per_trigger = Counter(ins["trigger"] for ins in events)
            num_records += sum(per_trigger.values())
            num_share_trigger_records += sum(
                cnt for cnt in per_trigger.values() if cnt > 1
            )
    print(
        f"num_records: {num_records}, num_share_trigger_records: {num_share_trigger_records}"
    )



In [6]:
class BaseEvent(object):
    """A record of one event: an ordered list of fields and their contents.

    Attributes set by the constructor:
        recguid: identifier of the record (any value, None by default).
        name: event name.
        fields: list of field names, in display / tuple order.
        field2content: dict mapping every field to its content (None = empty).
        nonempty_count: number of fields whose content is not None.
        nonempty_ratio: nonempty_count / number of fields (0.0 when the event
            has no field).
        key_fields: set of fields that must be filled for the event to be
            key-complete.
    """

    def __init__(self, fields, event_name='Event', key_fields=(), recguid=None):
        """Create an event whose fields are all empty.

        Args:
            fields: iterable of field names.
            event_name: name of the event.
            key_fields: iterable of field names, each of which must belong to
                `fields` (checked with `assert`).
            recguid: optional record identifier.
        """
        self.recguid = recguid
        self.name = event_name
        # Materialise once and reuse: `fields` may be a one-shot iterator,
        # which a second pass would find already exhausted.
        self.fields = list(fields)
        self.field2content = {f: None for f in self.fields}
        self.nonempty_count = 0
        self.nonempty_ratio = self._nonempty_ratio()

        self.key_fields = set(key_fields)
        for key_field in self.key_fields:
            assert key_field in self.field2content

    def _nonempty_ratio(self):
        """Share of filled fields; 0.0 for an event without fields."""
        if not self.fields:
            return 0.0
        return self.nonempty_count / len(self.fields)

    def __repr__(self):
        event_str = "\n{}[\n".format(self.name)
        event_str += "  {}={}\n".format("recguid", self.recguid)
        event_str += "  {}={}\n".format("nonempty_count", self.nonempty_count)
        event_str += "  {}={:.3f}\n".format("nonempty_ratio", self.nonempty_ratio)
        event_str += "] (\n"
        for field in self.fields:
            key_str = " (key)" if field in self.key_fields else ""
            event_str += "  " + field + "=" + str(self.field2content[field]) + ", {}\n".format(key_str)
        event_str += ")\n"
        return event_str

    def update_by_dict(self, field2text, recguid=None):
        """Replace all contents from `field2text` and reset `recguid`.

        Fields missing from `field2text`, or mapped to None, become empty;
        keys of `field2text` that are not fields of the event are ignored.
        """
        self.nonempty_count = 0
        self.recguid = recguid

        for field in self.fields:
            if field in field2text and field2text[field] is not None:
                self.nonempty_count += 1
                self.field2content[field] = field2text[field]
            else:
                self.field2content[field] = None

        self.nonempty_ratio = self._nonempty_ratio()

    def field_to_dict(self):
        """Return a shallow copy of the field -> content mapping."""
        return dict(self.field2content)

    def set_key_fields(self, key_fields):
        """Replace the key fields by `set(key_fields)` (no validation)."""
        self.key_fields = set(key_fields)

    def is_key_complete(self):
        """Return True when every key field has a content other than None."""
        return all(
            self.field2content[key_field] is not None
            for key_field in self.key_fields
        )

    def get_argument_tuple(self):
        """Return the contents as a tuple, in the order of `fields`."""
        return tuple(self.field2content[field] for field in self.fields)

    def is_good_candidate(self, min_match_count=2):
        """Return True when the event is key-complete and has at least
        `min_match_count` filled fields."""
        return self.is_key_complete() and self.nonempty_count >= min_match_count
class event_0(BaseEvent):
    """Template of the '股票事件' event type: its fields and trigger groups."""

    NAME = '股票事件'
    FIELDS = ['股票代码', '股票名称', '股票评级', '评级变化']
    # field groups keyed by 1..4, plus 'all' listing every field
    TRIGGERS = {
        1: ['股票代码'],
        2: ['股票名称', '评级变化'],
        3: ['股票代码', '股票名称', '评级变化'],
        4: ['股票代码', '股票名称', '股票评级', '评级变化'],
        'all': ['股票代码', '股票名称', '股票评级', '评级变化'],
    }

    def __init__(self, recguid=None):
        super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid)
        # set_key_fields() wraps its argument in set(), so passing the dict
        # stores its keys (1, 2, 3, 4, 'all') as key fields.
        # NOTE(review): BaseEvent.is_key_complete() looks key fields up in
        # field2content, where these keys do not exist — confirm whether a
        # list of field names was intended here.
        self.set_key_fields(self.TRIGGERS)



# registry: event type name -> event template class
event_type2event_class = {
    event_0.NAME: event_0,
}
# one (name, fields, triggers, 2) tuple per event type
# NOTE(review): the trailing 2 is presumably a minimum number of filled
# fields — confirm against the code that consumes this list
event_type_fields_list = [
    (event_0.NAME, event_0.FIELDS, event_0.TRIGGERS, 2),
]

In [17]:
# Parameters for the interactive run below; the names mirror the arguments
# of build(). No dump_filepath is set in this cell.
filepath = "submit_test.json"
max_seq_len = 128
inference = False
add_trigger = False

In [27]:
not_valid = 0
data = []

# Debug aid: run the first two steps of build() on every document and show
# the resulting sentences, their lengths and their number for document 269.
for d in load_line_json_iterator(filepath):
    sents = sent_seg(d["text"], punctuations={"；"})
    sents = reorganise_sents(sents, max_seq_len, concat=True)
    if d["id"] != 269:
        continue
    print(sents)
    for s in sents:
        print(len(s))
    print(f"len:{len(sents)}")

['中信建投证券CHINASECURITIES证券研究报告·A股公司简评报告中药II持续完善品种布局,健康产业引领者起航华润三九(000999)事件维持买入公司拟收购澳诺中国100%股权及引入抗肿瘤创新药物11月27日晚,', '公司公告:1)拟收购誉衡药业持有的澳诺(中国)制药有限公司(以下简称"澳诺制药"或"目标公司")100%股权;2)拟购买沈阳药科大学所持有的QBH-196项目所有技术成果及知识产权,以及在相关专利所涉及的有关国家和地区开发经营该产品的权利。', '简评拟收购澳诺制药,补充儿科产品线公司拟拟收购誉衡药业持有的澳诺制药100%股权,交易价款共计人民币14.2亿元,资金来源为公司自有资金。', '澳诺制药是集生产、研发、销售于一体的高新技术企业,核心产品为葡萄糖酸钙锌口服溶液、维生素C咀嚼片、参芝石斜颗粒,其"澳诺"、"金辛金丐特"是儿童补钙知名品牌。"澳诺"、"金辛金丐特"牌葡萄糖酸钙锌口服溶液是儿童补钙大产品,具备良好的市场规模和成长性。', '近年来,该产品及品牌连续位列零售市场钙补充剂第二位,钙补充剂药品市场第一位。零售渠道及医疗渠道市场份额均持续增长。', '自我诊疗是重要战略方向之一,儿童健康是自我诊疗业务重点发|我们认为,公司致力于成为"大众医药健康产业的引领者",|展的领域,葡萄糖酸钙锌口服溶液是儿童补钙大品种,具备良好预测和比率2018|2019E|2020E|2021E|营业收入(百万)|13,', '427.7|15,098.2|16,831.4|18,767.0|贺菊颖hejuying@csc.com.cn010-86451162执业证书编号:|S1440517050001|', '刘若飞liuruofei@csc.com.cn010-85130388执业证书编号:|S1440519080003|发布日期:|2019年11月28日|当前股价:|29.38元|主要数据股票价格绝对/相对市场表现(%|(%)|1个月|3个月|12个月|', '-9.04/-7.29|-4.55/-5.93|22.31/9.6|12月最高/最低价(元)|33.6/21.93|总股本(万股)|97,890.0|流通A股(万股)|97,839.53|总市值(亿元)|287.6|流通市值(亿元)|287.45|', '近3月日均成交量(万)|524.3