In [2]:
import pickle
import json
import os
import gzip
import math
import copy
from tqdm import tqdm
import numpy as np

import multiprocessing
from multiprocessing import Pool, cpu_count

import torch
import torch.nn.functional as F
from torch import nn

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForMultipleChoice,
    BertJapaneseTokenizer,
    PreTrainedTokenizer,
    AdamW,
    get_linear_schedule_with_warmup,
)

import logging
logger = logging.getLogger(__name__)

from collections import Counter, defaultdict
from typing import List, Dict
Candidate = Dict[str, str] 


import unidic
import MeCab
# Home directory of the current user (raises KeyError when HOME is unset).
# Not referenced anywhere else in the visible part of this notebook.
home_path = os.environ['HOME']
# Shared MeCab tagger, pointed at the dictionary directory exposed by the `unidic` package.
tagger = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
# Feature-string prefixes of the tokens that parse_text() discards (matched with str.startswith).
STOP_POSTAGS = ('BOS/EOS',"代名詞","接続詞","感動詞","動詞,非自立可能","助動詞",'助詞',"接頭辞","記号,一般","補助記号","空白")
# Japanese full stop.
# NOTE(review): the sentence-splitting helpers in this notebook hard-code "。" instead of using this constant.
SEPARATE_TOKEN = '。'

In [6]:
from apex import amp

In [7]:
def list_fm_jsonl(f_jsonl: str) -> List["Candidate"]:
    """Read a JSON Lines file into a list of dicts.

    Args:
        f_jsonl: path of the ``.jsonl`` file (one JSON object per line).

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # The context manager closes the handle deterministically (the original
    # left it to the garbage collector). UTF-8 is stated explicitly, like the
    # other readers in this notebook, so the result does not depend on locale.
    with open(f_jsonl, 'r', encoding='utf-8') as fin:
        # Blank lines (typically a trailing newline at EOF) are skipped.
        return [json.loads(line) for line in fin if line.strip()]


def list_fm_tsv(f_tsv: str, col=0) -> List[int]:
    """Read one integer column from a whitespace-separated file.

    Args:
        f_tsv: path of the file (e.g. two columns: pred, out_label_id).
        col: zero-based index of the column to extract.

    Returns:
        The selected column as ints, one per non-blank line, in file order.

    Raises:
        IndexError: if a non-blank line has fewer than ``col + 1`` columns.
        ValueError: if the selected field is not an integer.
    """
    # The context manager closes the handle deterministically; blank lines
    # (e.g. a trailing newline) are skipped instead of raising IndexError.
    with open(f_tsv, 'r') as fin:
        return [int(line.split()[col]) for line in fin if line.strip()]

def search_entity_ignore_answer(queries, topk=1000,ignore_answers=True):
    """BM25 entity search that, by default, drops the answer candidates' own pages.

    This is ``search_entity`` with ``ignore_answers`` defaulting to True. It is
    kept as a separate module-level function so it can be handed directly to
    ``Pool.imap`` with a single positional argument.

    Args:
        queries: ``(query, answer_candidates)`` pair; ``query`` is the question
            text and ``answer_candidates`` an iterable of entity titles.
        topk: maximum number of hits to return.
        ignore_answers: when True, documents whose title is one of
            ``answer_candidates`` are removed from the ranking.

    Returns:
        List of ``(doc_id, score)`` tuples sorted by descending score.
    """
    # The body used to be a line-for-line copy of search_entity; delegating
    # keeps the two rankings from drifting apart.
    return search_entity(queries, topk=topk, ignore_answers=ignore_answers)

def search_entity(queries, topk=1000,ignore_answers=False):
    """Rank entity documents against a question with BM25 (OR search).

    Reads the module-level retrieval state: ``token2pointer``,
    ``doc_id2token_count``, ``entities``, ``entitie2id`` and the on-disk
    inverted index at ``./ir_dump/inverted_index``.

    Args:
        queries: ``(query, answer_candidates)`` pair; ``query`` is the question
            text and ``answer_candidates`` an iterable of entity titles.
        topk: maximum number of hits to return.
        ignore_answers: when True, documents whose title is one of
            ``answer_candidates`` are removed from the ranking.

    Returns:
        List of ``(doc_id, score)`` tuples sorted by descending score, at most
        ``topk`` long. Documents with a zero score are never returned.
    """
    query, answer_candidates = queries
    avgdl = sum(doc_id2token_count) / len(doc_id2token_count)

    # Fetch the postings list of each distinct query token that is indexed.
    # A repeated token is read only once; the mapping is the same either way.
    target_posting = {}
    with open('./ir_dump/inverted_index', 'r', encoding='utf-8') as index_file:
        for token in parse_text(query):
            if token in target_posting or token not in token2pointer:
                continue
            pointer, offset = token2pointer[token]
            index_file.seek(pointer)
            index_line = index_file.read(offset - pointer).rstrip()
            target_posting[token] = load_index_line(index_line)

    # BM25 scoring, OR semantics: every document containing any token scores.
    k1 = 2.0
    b = 0.75
    all_docs = len(entities)
    doc_id2tfidf = [0] * all_docs
    for postings_list in target_posting.values():
        doc_freq = len(postings_list)
        idf = math.log2((all_docs - doc_freq + 0.5) / (doc_freq + 0.5))
        # A non-positive idf means the token is too common to be useful.
        if idf <= 0:
            continue
        for doc_id, tf in postings_list:
            dl = doc_id2token_count[doc_id]
            doc_id2tfidf[doc_id] += idf * ((tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl))))

    if ignore_answers:
        # Titles are resolved only when they are actually needed. A candidate
        # with no indexed page cannot appear in the ranking, so it is skipped
        # rather than aborting the whole query with a KeyError.
        for answer in answer_candidates:
            ignore_id = entitie2id.get(answer)
            if ignore_id is not None:
                doc_id2tfidf[ignore_id] = 0

    docs = [(doc_id, tfidf) for doc_id, tfidf in enumerate(doc_id2tfidf) if tfidf != 0]
    docs.sort(key=lambda hit: hit[1], reverse=True)
    return docs[:topk]

def _rank_sentences_bm25(text, query_texts, topk):
    """Shared implementation of the sentence re-rankers below.

    Splits ``text`` on "。", scores every sentence with BM25 against the tokens
    of all strings in ``query_texts`` and returns the best ``topk`` sentences
    joined with "。", highest score first.
    """
    sentences = text.split("。")

    # Build a tiny in-memory inverted index: token -> [(sentence_id, tf), ...].
    inverted_index = defaultdict(list)
    sentence_id2token_count = []
    for sentence_id, sentence in enumerate(sentences):
        tokens = parse_text(sentence)
        sentence_id2token_count.append(len(tokens))
        for token, count in Counter(tokens).items():
            inverted_index[token].append((sentence_id, count))

    # str.split always yields at least one element, so the divisor is >= 1.
    avgdl = sum(sentence_id2token_count) / len(sentence_id2token_count)

    parsed_query = []
    for query_text in query_texts:
        parsed_query += parse_text(query_text)
    target_posting = {
        token: inverted_index[token] for token in parsed_query if token in inverted_index
    }

    # BM25 scoring, OR semantics.
    k1 = 2.0
    b = 0.75
    all_docs = len(sentences)
    sentence_id2tfidf = [0] * all_docs
    for postings_list in target_posting.values():
        doc_freq = len(postings_list)
        idf = math.log2((all_docs - doc_freq + 0.5) / (doc_freq + 0.5))
        # Tokens present in half or more of the sentences get a non-positive
        # idf and are ignored as too common.
        if idf <= 0:
            continue
        for sentence_id, tf in postings_list:
            dl = sentence_id2token_count[sentence_id]
            sentence_id2tfidf[sentence_id] += idf * (
                (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl)))
            )

    ranked = [(sentence_id, tfidf) for sentence_id, tfidf in enumerate(sentence_id2tfidf) if tfidf != 0]
    ranked.sort(key=lambda hit: hit[1], reverse=True)
    return "。".join(sentences[sentence_id] for sentence_id, _ in ranked[:topk])


def get_contexts_bm25_add_answer(sentence_list,query,answer,topk=1000):
    """Re-rank the sentences of a text by BM25 against the question plus a candidate answer.

    Args:
        sentence_list: text whose sentences are separated by "。" (a single
            string, despite the name).
        query: question text.
        answer: candidate answer text; its tokens are appended to the query.
        topk: maximum number of sentences to keep.

    Returns:
        The kept sentences joined with "。", best match first. Sentences that
        score zero are dropped, so the result may be the empty string.
    """
    return _rank_sentences_bm25(sentence_list, (query, answer), topk)


def get_contexts_bm25(sentence_list,query,topk=1000):
    """Re-rank the sentences of a text by BM25 against the question.

    Args:
        sentence_list: text whose sentences are separated by "。" (a single
            string, despite the name).
        query: question text.
        topk: maximum number of sentences to keep.

    Returns:
        The kept sentences joined with "。", best match first. Sentences that
        score zero are dropped, so the result may be the empty string.
    """
    return _rank_sentences_bm25(sentence_list, (query,), topk)


def _read_postings(index_file, tokens):
    """Return ``{token: postings_list}`` for every distinct token found in ``token2pointer``."""
    postings = {}
    for token in tokens:
        if token in postings or token not in token2pointer:
            continue
        pointer, offset = token2pointer[token]
        index_file.seek(pointer)
        index_line = index_file.read(offset - pointer).rstrip()
        postings[token] = load_index_line(index_line)
    return postings


def search_entity_candidates(queries, topk=10):
    """Per-candidate BM25 search: question score plus candidate score.

    For every answer candidate, the result set is restricted to the documents
    that contain all of the candidate's scored tokens. Each of those documents
    is ranked by the BM25 score of the question plus the BM25 score of the
    candidate's tokens.

    Reads the module-level retrieval state: ``token2pointer``,
    ``doc_id2token_count``, ``entities`` and ``./ir_dump/inverted_index``.

    Args:
        queries: ``(query, answer_candidates)`` pair.
        topk: maximum number of hits kept per candidate.

    Returns:
        One list per candidate, in candidate order, of ``(doc_id, score)``
        tuples sorted by descending score (possibly empty).
    """
    query, answer_candidates = queries
    avgdl = sum(doc_id2token_count) / len(doc_id2token_count)

    k1 = 2.0
    b = 0.75
    all_docs = len(entities)

    def idf_of(postings_list):
        doc_freq = len(postings_list)
        return math.log2((all_docs - doc_freq + 0.5) / (doc_freq + 0.5))

    def bm25(idf, tf, doc_id):
        dl = doc_id2token_count[doc_id]
        return idf * ((tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (dl / avgdl))))

    search_results = []
    with open('./ir_dump/inverted_index', 'r', encoding='utf-8') as index_file:
        # Question-only scores; shared by all candidates and never mutated below.
        doc_id2tfidf = [0] * all_docs
        for postings_list in _read_postings(index_file, parse_text(query)).values():
            idf = idf_of(postings_list)
            # A non-positive idf means the token is too common to be useful.
            if idf <= 0:
                continue
            for doc_id, tf in postings_list:
                doc_id2tfidf[doc_id] += bm25(idf, tf, doc_id)

        # One search per candidate.
        for candidate in answer_candidates:
            candidate_posting = _read_postings(index_file, parse_text(candidate))

            # doc_id -> question score + candidate score. Kept in a separate
            # mapping so the shared question scores are not modified; the old
            # add-then-subtract left floating-point residue between candidates.
            combined = {}
            # Documents containing every scored candidate token. None means
            # "not seeded yet". The original seeded only on the token at
            # position 0, so a too-common first token left the set empty and
            # the candidate never got any hit.
            candidate_doc_ids = None
            for postings_list in candidate_posting.values():
                idf = idf_of(postings_list)
                if idf <= 0:
                    continue
                token_doc_ids = set()
                for doc_id, tf in postings_list:
                    combined[doc_id] = combined.get(doc_id, doc_id2tfidf[doc_id]) + bm25(idf, tf, doc_id)
                    token_doc_ids.add(doc_id)
                if candidate_doc_ids is None:
                    candidate_doc_ids = token_doc_ids
                else:
                    candidate_doc_ids &= token_doc_ids

            docs = [(doc_id, combined[doc_id]) for doc_id in (candidate_doc_ids or ())]
            docs.sort(key=lambda hit: hit[1], reverse=True)
            search_results.append(docs[:topk])

    return search_results

def parse_text(text):
    """Tokenize ``text`` with the shared MeCab tagger and return lower-cased tokens.

    Nodes whose feature string starts with one of ``STOP_POSTAGS`` are dropped.
    For the rest, the eighth comma-separated feature field is used when the
    feature string has more than seven fields (presumably the UniDic lemma;
    confirm against the dictionary in use); otherwise the surface form is used.
    """
    tokens = []
    node = tagger.parseToNode(text)
    while node:
        if not node.feature.startswith(STOP_POSTAGS):
            fields = node.feature.split(",")
            token = fields[7] if len(fields) > 7 else node.surface
            tokens.append(token.lower())
        node = node.next
    return tokens

In [8]:
class InputExample:
    """One multiple-choice example (train or test) before tokenization."""

    def __init__(self, example_id, question, contexts, endings,ctx1,ctx2,ctx3,label=None):
        """Store the raw fields of an example.

        Args:
            example_id: unique id of the example.
            question: str, the untokenized question text.
            contexts: list of str, one untokenized context per option.
            endings: list of str, the options; same length as ``contexts``.
            ctx1: extra retrieval context, stored as given.
            ctx2: extra retrieval context, stored as given.
            ctx3: extra retrieval context, stored as given.
            label: optional label; given for train/dev examples and left as
                None for test examples.
        """
        self.example_id, self.question = example_id, question
        self.contexts, self.endings = contexts, endings
        self.label = label
        self.ctx1, self.ctx2, self.ctx3 = ctx1, ctx2, ctx3


class InputFeatures:
    """Tokenized features of one example for the four input variants.

    Each ``choices_featuresN`` argument is an iterable of
    ``(input_ids, input_mask, segment_ids)`` triples, one per option. Each
    triple is stored as a dict with the keys ``"input_ids"``, ``"input_mask"``
    and ``"segment_ids"``.
    """

    def __init__(self, example_id, choices_features1,choices_features2,choices_features3,choices_features4, label):
        def as_dicts(choices):
            # Explicit three-way unpacking: a malformed triple still raises.
            packed = []
            for ids, mask, segments in choices:
                packed.append(
                    {"input_ids": ids, "input_mask": mask, "segment_ids": segments}
                )
            return packed

        self.example_id = example_id
        self.choices_features1 = as_dicts(choices_features1)
        self.choices_features2 = as_dicts(choices_features2)
        self.choices_features3 = as_dicts(choices_features3)
        self.choices_features4 = as_dicts(choices_features4)
        self.label = label


class DataProcessor:
    """Abstract base for converters of multiple-choice data sets.

    Both methods raise NotImplementedError here and are meant to be overridden.
    """

    def get_examples(self, mode, data_dir, fname, entities_fname):
        """Return the collection of `InputExample`s for the data set."""
        raise NotImplementedError

    def get_labels(self):
        """Return the list of labels of the data set."""
        raise NotImplementedError
        
        
class JaqketProcessor(DataProcessor):
    """Builds `InputExample`s from already-parsed JAQKET-style quiz records."""

    def _get_entities(self, data_dir, entities_fname):
        """Load a gzipped JSON Lines entity dump into a ``{title: text}`` dict."""
        logger.info("LOOKING AT {} entities".format(data_dir))
        entities = dict()
        for line in self._read_json_gzip(os.path.join(data_dir, entities_fname)):
            entity = json.loads(line.strip())
            entities[entity["title"]] = entity["text"]

        return entities

    def get_examples(self, mode, data_dir, json_data, entities, num_options=20):
        """Create the examples for ``mode`` from in-memory records.

        Args:
            mode: split name, e.g. "train"; enables extra sanity checks.
            data_dir: only used for the log message.
            json_data: iterable of already-parsed record dicts.
            entities: ``{title: text}`` mapping used to look up option contexts.
            num_options: number of answer options each example must have.

        Returns:
            List of `InputExample`.
        """
        logger.info("LOOKING AT {} [{}]".format(data_dir, mode))
        return self._create_examples(json_data, mode, entities, num_options)

    def get_labels(self):
        """Return the label strings "0" .. "19"."""
        return [str(i) for i in range(20)]

    def _read_json(self, input_file):
        """Pass-through: the input is returned unchanged, nothing is read from disk."""
        return input_file

    def _read_json_gzip(self, input_file):
        """Return all lines of a UTF-8 gzip text file."""
        with gzip.open(input_file, "rt", encoding="utf-8") as fin:
            return fin.readlines()

    def _create_examples(self, lines, t_type, entities, num_options):
        """Create `InputExample`s from parsed records.

        Records with fewer than ``num_options`` answer candidates are skipped
        and counted. Every example gets ``label=0``; the record's answer field
        is not consulted.

        Raises:
            KeyError: if a record lacks a required key or an option has no
                entry in ``entities``.
        """
        examples = []
        skip_examples = 0

        logger.info("read jaqket data: {}".format(len(lines)))
        for data_raw in lines:
            # Truncate to the first num_options candidates.
            options = data_raw["answer_candidates"][:num_options]
            if len(options) != num_options:
                skip_examples += 1
                continue

            examples.append(
                InputExample(
                    example_id=data_raw["qid"],
                    # "_" marks the blank of a cloze question; strip it.
                    question=data_raw["question"].replace("_", ""),
                    contexts=[entities[option] for option in options],
                    endings=options,
                    ctx1=data_raw["ctx1"],
                    ctx2=data_raw["ctx2"],
                    ctx3=data_raw["ctx3"],
                    label=0,
                )
            )

        if t_type == "train":
            assert len(examples) > 1
            assert examples[0].label is not None

        logger.info("len examples: {}".format(len(examples)))
        logger.info("skip examples: {}".format(skip_examples))

        return examples
    
    
def convert_examples_to_features(example):
    """Tokenize one example into the four BERT input variants.

    Uses the module-level ``entities``, ``doc_id2title`` and ``tokenizer``
    (through the ``make_bert_input*`` helpers).

    Args:
        example: 8-tuple ``(contexts, endings, question, label, example_id,
            ctx_add1, ctx_add2, ctx_add3)``. ``ctx_add1`` and ``ctx_add2`` are
            lists of ``(doc_id, score)`` search hits. ``ctx_add3`` holds one
            such list per option. The incoming ``label`` is ignored and the
            feature label is always 0.

    Returns:
        A one-element list containing the `InputFeatures` of the example.
    """
    # Padding / length settings shared by all variants.
    pad_token_segment_id = 0
    pad_on_left = False
    pad_token = 0
    mask_padding_with_zero = True
    max_length = 768
    pad_args = (mask_padding_with_zero, max_length, pad_on_left, pad_token, pad_token_segment_id)

    contexts, endings, question, _label, example_id, ctx_add1, ctx_add2, ctx_add3 = example

    def entity_texts(hits, limit):
        # Concatenate the page texts of the first `limit` hits.
        return "。".join(entities[doc_id2title[hit[0]]] for hit in hits[:limit])

    # Top-1 hit of the search that excluded the answer candidates' own pages.
    context2_1 = get_contexts_bm25(entity_texts(ctx_add1, 1), question)
    # Top-5 hits of the search that kept the answer candidates' pages.
    context2_3 = get_contexts_bm25(entity_texts(ctx_add2, 5), question)

    # 1: option's own page + the top-1 retrieved passage (candidate pages excluded)
    choices_features1 = []
    # 2: option's own page only
    choices_features2 = []
    # 3: top-5 retrieved passages only
    choices_features3 = []
    # 4: top-5 retrieved passages, with the option added to the query both
    #    when searching and when re-ranking sentences
    choices_features4 = []
    # zip() keeps the original iteration count (the shorter of the two lists).
    for ending_idx, (_context, ending) in enumerate(zip(contexts, endings)):
        choices_features1.append(make_bert_input1(ending, question, context2_1, *pad_args))
        # make_bert_input2 receives context2_1 only to keep the shared signature.
        choices_features2.append(make_bert_input2(ending, question, context2_1, *pad_args))
        choices_features3.append(make_bert_input3(ending, question, context2_3, *pad_args))

        context2_4 = get_contexts_bm25_add_answer(
            entity_texts(ctx_add3[ending_idx], 5), question, ending
        )
        choices_features4.append(make_bert_input3(ending, question, context2_4, *pad_args))

    return [
        InputFeatures(
            example_id=example_id,
            choices_features1=choices_features1,
            choices_features2=choices_features2,
            choices_features3=choices_features3,
            choices_features4=choices_features4,
            label=0,
        )
    ]

def _encode_and_pad(text_a, text_b, mask_padding_with_zero, max_length, pad_on_left, pad_token, pad_token_segment_id):
    """Encode a text pair with the module-level ``tokenizer`` and pad it to ``max_length``.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)`` as plain lists.
    """
    inputs = tokenizer.encode_plus(
        text_a,
        text_b,
        add_special_tokens=True,
        max_length=max_length,
        truncation="only_first",  # always truncate the context, never the question/option
    )
    input_ids = inputs["input_ids"]
    token_type_ids = inputs["token_type_ids"]

    # The mask marks real tokens with 1 and padding with 0 (or the reverse
    # when mask_padding_with_zero is False). Only real tokens are attended to.
    real_flag, pad_flag = (1, 0) if mask_padding_with_zero else (0, 1)
    attention_mask = [real_flag] * len(input_ids)

    # Pad up to max_length. If the encoding is already longer than max_length
    # the count is negative, nothing is added and the lists are returned
    # over-long. NOTE(review): confirm the tokenizer can never return more
    # than max_length tokens here.
    padding_length = max_length - len(input_ids)
    ids_padding = [pad_token] * padding_length
    mask_padding = [pad_flag] * padding_length
    segment_padding = [pad_token_segment_id] * padding_length
    if pad_on_left:
        input_ids = ids_padding + input_ids
        attention_mask = mask_padding + attention_mask
        token_type_ids = segment_padding + token_type_ids
    else:
        input_ids = input_ids + ids_padding
        attention_mask = attention_mask + mask_padding
        token_type_ids = token_type_ids + segment_padding
    return input_ids, attention_mask, token_type_ids


def make_bert_input1(ending,question,context2,mask_padding_with_zero,max_length,pad_on_left,pad_token,pad_token_segment_id):
    """Variant 1: the option's own page (re-ranked, first 768 characters) + ``context2``.

    Sequence A is ``<option page>[SEP]<context2>``; sequence B is
    ``<question>[SEP]<option>``, using ``tokenizer.sep_token`` as the separator.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)`` padded to ``max_length``.
    """
    context1 = get_contexts_bm25(entities[ending], question)
    text_a = context1[:768] + tokenizer.sep_token + context2
    text_b = question + tokenizer.sep_token + ending
    return _encode_and_pad(
        text_a, text_b, mask_padding_with_zero, max_length, pad_on_left, pad_token, pad_token_segment_id
    )


def make_bert_input2(ending,question,context2,mask_padding_with_zero,max_length,pad_on_left,pad_token,pad_token_segment_id):
    """Variant 2: the option's own page only (``context2`` is accepted but unused).

    Sequence A is the re-ranked option page; sequence B is
    ``<question>[SEP]<option>``.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)`` padded to ``max_length``.
    """
    text_a = get_contexts_bm25(entities[ending], question)
    text_b = question + tokenizer.sep_token + ending
    return _encode_and_pad(
        text_a, text_b, mask_padding_with_zero, max_length, pad_on_left, pad_token, pad_token_segment_id
    )


def make_bert_input3(ending,question,context2,mask_padding_with_zero,max_length,pad_on_left,pad_token,pad_token_segment_id):
    """Variant 3: the supplied ``context2`` only.

    Sequence A is ``context2``; sequence B is ``<question>[SEP]<option>``.

    Returns:
        ``(input_ids, attention_mask, token_type_ids)`` padded to ``max_length``.
    """
    text_a = context2
    text_b = question + tokenizer.sep_token + ending
    return _encode_and_pad(
        text_a, text_b, mask_padding_with_zero, max_length, pad_on_left, pad_token, pad_token_segment_id
    )


def get_qus_answers(input_file):
    """Read a JSON Lines question file into ``(question, answer_candidates)`` pairs.

    Args:
        input_file: path of a UTF-8 JSON Lines file whose records have the
            keys "question" and "answer_candidates".

    Returns:
        List of ``(question, answer_candidates)`` tuples in file order, where
        ``question`` has every "_" removed.
    """
    with open(input_file, "r", encoding="utf-8") as fin:
        lines = fin.readlines()

    queries = []
    for line in tqdm(lines):
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        if not line.strip():
            continue
        data_raw = json.loads(line)
        # "_" marks the blank of a cloze question; strip it.
        question = data_raw["question"].replace("_", "")
        queries.append((question, data_raw['answer_candidates']))
    return queries

def load_index_line(index_line):
    """Parse one inverted-index line into a list of integer tuples.

    The line holds space-separated entries; each entry is a colon-separated
    group of integers, e.g. ``"3:1 10:2"`` -> ``[(3, 1), (10, 2)]``.
    """
    postings = []
    for entry in index_line.split(' '):
        postings.append(tuple(int(field) for field in entry.split(':')))
    return postings

def read_json(x):
    """Read a file holding one record per line and return the parsed records.

    Each non-blank line is parsed first as a Python literal (what the previous
    ``eval`` accepted for plain data) and, failing that, as JSON. JSON-only
    spellings such as ``true``/``false``/``null`` therefore work too.

    SECURITY: the original used ``eval`` on every line, which runs arbitrary
    code from the file. Lines that are neither a literal nor JSON now raise
    instead of being executed.

    Args:
        x: path of the UTF-8 text file.

    Returns:
        List of parsed records in file order.

    Raises:
        ValueError: if a non-blank line is neither a Python literal nor JSON
            (``json.JSONDecodeError`` is a ``ValueError``).
    """
    import ast  # local import: only this function needs it

    records = []
    with open(x, "r", encoding="utf-8") as fin:
        for line in fin:
            text = line.strip()
            if not text:
                continue
            try:
                records.append(ast.literal_eval(text))
            except (ValueError, SyntaxError):
                records.append(json.loads(text))
    return records

In [9]:
def select_field1(features, field):
    """Collect ``field`` from every choice in ``choices_features1`` of each feature.

    Returns a nested list indexed as ``[feature][choice]``.
    """
    selected = []
    for feature in features:
        selected.append([choice[field] for choice in feature.choices_features1])
    return selected

def select_field2(features, field):
    """Collect ``field`` from every choice in ``choices_features2`` of each feature.

    Returns a nested list indexed as ``[feature][choice]``.
    """
    selected = []
    for feature in features:
        selected.append([choice[field] for choice in feature.choices_features2])
    return selected

def select_field3(features, field):
    """Collect ``field`` from every choice in ``choices_features3`` of each feature.

    Returns a nested list indexed as ``[feature][choice]``.
    """
    selected = []
    for feature in features:
        selected.append([choice[field] for choice in feature.choices_features3])
    return selected

def select_field4(features, field):
    """Collect ``field`` from every choice in ``choices_features4`` of each feature.

    Returns a nested list indexed as ``[feature][choice]``.
    """
    selected = []
    for feature in features:
        selected.append([choice[field] for choice in feature.choices_features4])
    return selected

def get_inputs(features):
    """Turn a list of features into long tensors for the four input variants.

    Returns:
        Four tuples, one per variant in order 1..4, each
        ``(input_ids, input_mask, segment_ids, label_ids)``. The label tensor
        is built from ``feature.label`` separately for every variant.
    """
    def pack(select):
        # One variant: the three per-choice fields plus a fresh label tensor.
        input_ids = torch.tensor(select(features, "input_ids"), dtype=torch.long)
        input_mask = torch.tensor(select(features, "input_mask"), dtype=torch.long)
        segment_ids = torch.tensor(select(features, "segment_ids"), dtype=torch.long)
        label_ids = torch.tensor([feature.label for feature in features], dtype=torch.long)
        return input_ids, input_mask, segment_ids, label_ids

    variant1 = pack(select_field1)
    variant2 = pack(select_field2)
    variant3 = pack(select_field3)
    variant4 = pack(select_field4)
    return variant1, variant2, variant3, variant4

In [7]:
import time
t1 = time.time()  # start of the wall-clock measurement; elapsed seconds end up in t2

# ---- Configuration ---------------------------------------------------------
device = "cuda:0"
root_path = "./data/"
test_path = "./data/aio_leaderboard.json"

# ---- Retrieval index dumps -------------------------------------------------
# search_entity_ignore_answer reads entitie2id, doc_id2token_count and
# token2pointer as module-level globals, so they must exist before the pool runs.
# NOTE: pickle.load can execute arbitrary code; only load dumps you created yourself.
with open('./ir_dump/doc_id2title.pickle', 'rb') as f:
    doc_id2title = pickle.load(f)
with open('./ir_dump/doc_id2token_count.pickle', 'rb') as f:
    doc_id2token_count = pickle.load(f)
with open('./ir_dump/token2pointer.pickle', 'rb') as f:
    token2pointer = pickle.load(f)

input_file = './data/all_entities.json.gz'
# Reverse lookup: title -> position of that title when iterating doc_id2title.
entitie2id = {k:v for v,k in enumerate(doc_id2title)}

# title -> text for every entity; streamed line by line so the whole file is
# never held in memory as a list of raw lines.
entities = dict()
with gzip.open(input_file, "rt", encoding="utf-8") as fin:
    for line in fin:
        entity = json.loads(line.strip())
        entities[entity["title"]] = entity["text"]

# ---- Retrieval: three searches per question, run in a process pool ---------
dev1_queries = get_qus_answers(test_path)
with Pool(cpu_count()) as p:
    dev1_results = list(tqdm(p.imap(search_entity, dev1_queries), total=len(dev1_queries)))
    dev1_results_ignore = list(tqdm(p.imap(search_entity_ignore_answer, dev1_queries), total=len(dev1_queries)))
    dev1_results_answer_q = list(tqdm(p.imap(search_entity_candidates, dev1_queries), total=len(dev1_queries)))

# ---- Base model ------------------------------------------------------------
processor = JaqketProcessor()
label_list = processor.get_labels()
num_labels = len(label_list)
task_name = "jaqket"
MODEL_CLASSES = {"bert": (BertConfig, BertForMultipleChoice, BertJapaneseTokenizer)}
config_class, model_class, tokenizer_class = MODEL_CLASSES["bert"]
path_name = "cl-tohoku/bert-base-japanese-v2"
config = config_class.from_pretrained(path_name,num_labels=num_labels,finetuning_task=task_name,)
tokenizer = tokenizer_class.from_pretrained(path_name)
model = model_class.from_pretrained(path_name,config=config)

# Stretch the position-embedding table from 512 to 768 rows by bicubic
# interpolation (the table is treated as a 1x1x512x768 image) so that the
# fine-tuned checkpoints loaded below, which presumably use the same 768-row
# table, fit the module — TODO confirm against the training notebook.
param = model.bert.embeddings.position_embeddings.weight.data
param2 = F.interpolate(param.view(1,1,512,768),size=(768,768),mode='bicubic',align_corners=True)[0,0]
model.bert.embeddings.position_embeddings.weight = nn.Parameter(param2)


def _load_finetuned(checkpoint_path):
    """Deep-copy the resized base `model`, load the state dict stored at
    `checkpoint_path` (on CPU first), move the copy to `device` and wrap it
    with apex amp at opt level O2. Returns the wrapped model."""
    clone = copy.deepcopy(model)
    clone.load_state_dict(torch.load(checkpoint_path,map_location="cpu"))
    clone = clone.to(device)
    return amp.initialize(clone, opt_level="O2",verbosity=0)


model1 = _load_finetuned("./params/mix-model-alldata.pt")
model2 = _load_finetuned("./params/title_only-model-alldata.pt")
model3 = _load_finetuned("./params/question_only-model-alldata.pt")
# NOTE(review): model4 loads the same checkpoint file as model2. It is fed a
# different input group (batch4), so this may be deliberate — confirm that no
# separate fourth checkpoint was intended.
model4 = _load_finetuned("./params/title_only-model-alldata.pt")

# ---- Attach retrieved contexts to the test records -------------------------
test_data = read_json(test_path)
for data,ctx1,ctx2,ctx3 in zip(test_data,dev1_results,dev1_results_ignore,dev1_results_answer_q):
    # NOTE(review): the assignment is crosswise — "ctx1" receives the
    # search_entity_ignore_answer results and "ctx2" the search_entity results.
    # Kept as-is; verify this matches what the models saw during training.
    data["ctx1"] = ctx2
    data["ctx2"] = ctx1
    data["ctx3"] = ctx3


# ---- Feature conversion ----------------------------------------------------
test_ex  = processor.get_examples("dev",root_path,test_data,entities)
test_values = [(ex.contexts,ex.endings,ex.question,ex.label,ex.example_id,ex.ctx1,ex.ctx2,ex.ctx3) for ex in test_ex]
with Pool(multiprocessing.cpu_count()) as p:
    test_features = list(tqdm(p.imap(convert_examples_to_features,test_values), total=len(test_values)))
# Each worker call returns a sequence; only its first element is kept.
test_features = [f[0] for f in test_features]

batch1,batch2,batch3,batch4 = get_inputs(test_features)

# ---- Mini-batching ---------------------------------------------------------
# Every group is (input_ids, input_mask, segment_ids, label_ids); each tensor
# is split along dim 0 into chunks of at most `batch_size` rows.
batch_size = 4
input_batch1, att_batch1, typeid_batch1, label_batch1 = (t.split(batch_size) for t in batch1)
input_batch2, att_batch2, typeid_batch2, label_batch2 = (t.split(batch_size) for t in batch2)
input_batch3, att_batch3, typeid_batch3, label_batch3 = (t.split(batch_size) for t in batch3)
input_batch4, att_batch4, typeid_batch4, label_batch4 = (t.split(batch_size) for t in batch4)

# Explicit position ids 0..767, matching the 768-row position table built above.
position_ids = torch.LongTensor([i for i in range(768)]).to(device)


def _predict_logits(net, input_batches, att_batches, typeid_batches, label_batches):
    """Run `net` in eval mode over the given mini-batches without gradients
    and return all logits stacked into one numpy array (one row per example).

    Labels are still passed to the model, as before; only the "logits" entry
    of the output is used, so any loss the model computes is discarded.
    """
    net.eval()  # once per model instead of once per batch
    collected = []
    with torch.no_grad():
        for batch in tqdm(zip(input_batches,att_batches,typeid_batches,label_batches),total=len(input_batches)):
            batch = tuple(t.to(device) for t in batch)
            outputs = net(
                input_ids=batch[0],
                attention_mask=batch[1],
                token_type_ids=batch[2],
                position_ids=position_ids,
                labels=batch[3],
            )
            collected.extend(outputs["logits"].cpu().numpy())
    return np.array(collected)


# ---- Inference with the four models ----------------------------------------
pred1 = _predict_logits(model1, input_batch1, att_batch1, typeid_batch1, label_batch1)
pred2 = _predict_logits(model2, input_batch2, att_batch2, typeid_batch2, label_batch2)
pred3 = _predict_logits(model3, input_batch3, att_batch3, typeid_batch3, label_batch3)
pred4 = _predict_logits(model4, input_batch4, att_batch4, typeid_batch4, label_batch4)

# ---- Unweighted ensemble and submission file -------------------------------
answers = list_fm_jsonl(test_path)    # List[Candidate]

# Average the four logit arrays and take the best-scoring candidate index.
pred_labels = ((pred1+pred2+pred3+pred4)/4).argmax(axis=-1)

# `with` guarantees the file is closed even if a record fails to serialise;
# UTF-8 is set explicitly because ensure_ascii=False writes raw non-ASCII text.
with open("./submission.json", 'w', encoding='utf-8') as fo:
    for answer_info, pred_label in zip(answers, pred_labels):
        result = {
                  'qid': answer_info['qid'],
                  'answer_entity': answer_info['answer_candidates'][pred_label]
                 }

        json.dump(result, fo, ensure_ascii=False)
        fo.write('\n')

t2 = time.time() - t1  # elapsed seconds for the whole cell

100%|██████████| 1000/1000 [00:00<00:00, 106381.52it/s]
100%|██████████| 1000/1000 [01:14<00:00, 13.38it/s]
100%|██████████| 1000/1000 [01:25<00:00, 11.73it/s]
100%|██████████| 1000/1000 [02:45<00:00,  6.03it/s]
Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-v2 were not used when initializing BertForMultipleChoice: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a mode

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.8.6/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3427, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-7-a9cea1c0e0b0>", line 117, in <module>
    pred = outputs["logits"].cpu().numpy()
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.8.6/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 2054, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.8.6/lib/python3.8/site-packages/IPython/core/ultratb.py", line 1101, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/root/.py

TypeError: object of type 'NoneType' has no len()

In [77]:
# Weighted ensemble of the four models' logits ("BEST LB" weights).
# pred1..pred4 must already be numpy arrays at this point.
w1,w2,w3,w4  = np.array([0.30265671, 0.28305784, 0.17957585, 0.2347096 ])

scaled_logits = []
for logits, weight in zip((pred1, pred2, pred3, pred4), (w1, w2, w3, w4)):
    scaled = logits.copy()   # leave the raw predictions untouched
    scaled *= weight         # in-place scaling keeps the dtype of the copy
    scaled_logits.append(scaled)
pred1_1,pred1_2,pred1_3,pred1_4 = scaled_logits

# Index of the best-scoring candidate per question after weighting.
ensemble_scores = pred1_1+pred1_2+pred1_3+pred1_4
pred_labels = ensemble_scores.argmax(axis=1)

In [78]:
# Write the submission as JSON lines: one {"qid", "answer_entity"} object per
# question, where the entity is the candidate at index `pred_label`.
answers = list_fm_jsonl(test_path)    # List[Candidate]

# NOTE(review): the file name mentions "softmax-t03", but no softmax is applied
# in this cell — confirm `pred_labels` was produced the way the name suggests.
# `with` closes the file even if a record fails to serialise; UTF-8 is explicit
# because ensure_ascii=False writes raw non-ASCII characters.
with open("./submission-searchv3-4model-weight-ensemble-softmax-t03.json", 'w', encoding='utf-8') as fo:
    for answer_info, pred_label in zip(answers, pred_labels):
        result = {
                  'qid': answer_info['qid'],
                  'answer_entity': answer_info['answer_candidates'][pred_label]
                 }

        json.dump(result, fo, ensure_ascii=False)
        fo.write('\n')