# Install

In [None]:
!pip install datasets
!pip install kss
!python -m pip install elasticsearch
!wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.9.2-linux-x86_64.tar.gz -q
!tar -xzf elasticsearch-7.9.2-linux-x86_64.tar.gz
!chown -R daemon:daemon elasticsearch-7.9.2

# Elasticsearch

In [None]:
import os
from subprocess import Popen, PIPE, STDOUT
es_server = Popen(['/content/elasticsearch-7.9.2/bin/elasticsearch'],
                   stdout=PIPE, stderr=STDOUT,
                   preexec_fn=lambda: os.setuid(1)
                  )
! sleep 30

In [None]:
! /content/elasticsearch-7.9.2/bin/elasticsearch-plugin install analysis-nori

In [None]:
es_server.kill()

In [None]:
import os
from subprocess import Popen, PIPE, STDOUT
es_server = Popen(['/content/elasticsearch-7.9.2/bin/elasticsearch'],
                   stdout=PIPE, stderr=STDOUT,
                   preexec_fn=lambda: os.setuid(1)
                  )

! sleep 30

In [None]:
from elasticsearch import Elasticsearch

es = Elasticsearch('localhost:9200')

In [None]:
es.indices.create(index = 'document',
                  body = {
                      'settings':{
                          'analysis':{
                              'analyzer':{
                                  'my_analyzer':{
                                      "type": "custom",
                                      'tokenizer':'nori_tokenizer',
                                      'decompound_mode':'mixed',
                                      'stopwords':'_korean_',
                                      'synonyms':'_korean_',
                                      "filter": ["lowercase",
                                                 "my_shingle_f",
                                                 "nori_readingform",
                                                 "nori_number",
                                                 "cjk_bigram",
                                                 "decimal_digit",
                                                 "stemmer",
                                                 "trim"]
                                  }
                              },
                              'filter':{
                                  'my_shingle_f':{
                                      "type": "shingle"
                                  }
                              }
                          },
                          'similarity':{
                              'my_similarity':{
                                  'type':'BM25',
                              }
                          }
                      },
                      'mappings':{
                          'properties':{
                              'title':{
                                  'type':'text',
                                  'analyzer':'my_analyzer',
                                  'similarity':'my_similarity'
                              },
                              'text':{
                                  'type':'text',
                                  'analyzer':'my_analyzer',
                                  'similarity':'my_similarity'
                              },
                              'text_origin':{
                                  'type':'text',
                                  'analyzer':'my_analyzer',
                                  'similarity':'my_similarity'
                              }
                          }
                      }
                  }
                  )

In [None]:
import zipfile

f = zipfile.ZipFile('/content/drive/MyDrive/Colab Notebooks/data.zip')
f.extractall('/content')
f.close()

In [None]:
import json
import pandas as pd

with open('/content/data/wikipedia_documents.json', 'r') as f:
    wiki_data = pd.DataFrame(json.load(f)).transpose()

In [None]:
wiki_data = wiki_data.drop_duplicates(['text']) # 3876

wiki_data = wiki_data.reset_index()

del wiki_data['index']

In [None]:
import re

wiki_data['text_origin'] = wiki_data['text']

wiki_data['text_origin'] = wiki_data['text_origin'].apply(lambda x : ' '.join(re.sub(r'''[^ \r\nㄱ-ㅎㅏ-ㅣ가-힣a-zA-Z0-9ぁ-ゔァ-ヴー々〆〤一-龥~₩!@#$%^&*()“”‘’《》≪≫〈〉『』「」＜＞_+|{}:"<>?`\-=\\[\];',.\/·]''', ' ', str(x.lower().strip())).split()))

wiki_data['text'] = wiki_data['text'].apply(lambda x : x.replace('\\n\\n',' '))
wiki_data['text'] = wiki_data['text'].apply(lambda x : x.replace('\n\n',' '))
wiki_data['text'] = wiki_data['text'].apply(lambda x : x.replace('\\n',' '))
wiki_data['text'] = wiki_data['text'].apply(lambda x : x.replace('\n',' '))

wiki_data['text'] = wiki_data['text'].apply(lambda x : ' '.join(re.sub(r'''[^ \r\nㄱ-ㅎㅏ-ㅣ가-힣a-zA-Z0-9~₩!@#$%^&*()_+|{}:"<>?`\-=\\[\];',.\/]''', ' ', str(x.lower().strip())).split()))

In [None]:
from tqdm import tqdm

title = []
text = []
text_origin = []

for num in tqdm(range(len(wiki_data))):
    cnt = 0
    while cnt < len(wiki_data['text'][num]):
        title.append(wiki_data['title'][num])
        text.append(wiki_data['text'][num][cnt:cnt+1000])
        text_origin.append(wiki_data['text_origin'][num])
        cnt+=1000

In [None]:
df = pd.DataFrame({'title':title,'text':text,'text_origin':text_origin})

In [None]:
from elasticsearch import Elasticsearch, helpers

buffer = []
rows = 0

for num in tqdm(range(len(df))):
    article = {"_id": num,
               "_index": "document", 
               "title" : df['title'][num],
               "text" : df['text'][num],
               "text_origin" : df['text_origin'][num]}

    buffer.append(article)

    rows += 1

    if rows % 3000 == 0:
        helpers.bulk(es, buffer)
        buffer = []

        print("Inserted {} articles".format(rows), end="\r")

if buffer:
    helpers.bulk(es, buffer)

print("Total articles inserted: {}".format(rows))

In [None]:
from datasets import load_from_disk

test_dataset = load_from_disk('/content/data/test_dataset/validation')

test_dataset

# PORORO

In [None]:
!pip install konlpy
!pip install pororo
!pip install python-mecab-ko

In [None]:
!git clone https://github.com/SOMJANG/Mecab-ko-for-Google-Colab.git

In [None]:
cd Mecab-ko-for-Google-Colab

In [None]:
!bash install_mecab-ko_on_colab190912.sh

In [None]:
from typing import Optional, Dict, Tuple, Union

import numpy as np
import torch
from fairseq.models.roberta import RobertaHubInterface, RobertaModel

from pororo.models.brainbert.utils import softmax
from pororo.tasks.utils.download_utils import download_or_load
from pororo.tasks.utils.tokenizer import CustomTokenizer
from pororo.tasks.utils.base import PororoBiencoderBase, PororoFactoryBase

In [None]:
class PororoMrcFactory(PororoFactoryBase):

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        return ["ko"]

    @staticmethod
    def get_available_models():
        return {"ko": ["brainbert.base.ko.korquad"]}

    def load(self, device: str):

        if "brainbert" in self.config.n_model:
            try:
                import mecab
            except ModuleNotFoundError as error:
                raise error.__class__(
                    "Please install python-mecab-ko with: `pip install python-mecab-ko`"
                )

            from pororo.utils import postprocess_span

            model = (My_BrainRobertaModel.load_model(
                f"bert/{self.config.n_model}",
                self.config.lang,
            ).eval().to(device))

            tagger = mecab.MeCab()

            return PororoBertMrc(model, tagger, postprocess_span, self.config)

class My_BrainRobertaModel(RobertaModel):

    @classmethod
    def load_model(cls, model_name: str, lang: str, **kwargs):

        from fairseq import hub_utils

        ckpt_dir = download_or_load(model_name, lang)
        tok_path = download_or_load(f"tokenizers/bpe32k.{lang}.zip", lang)

        x = hub_utils.from_pretrained(
            ckpt_dir,
            "model.pt",
            ckpt_dir,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return BrainRobertaHubInterface(
            x["args"],
            x["task"],
            x["models"][0],
            tok_path,
        )

In [None]:
class BrainRobertaHubInterface(RobertaHubInterface):

    def __init__(self, args, task, model, tok_path):
        super().__init__(args, task, model)
        self.bpe = CustomTokenizer.from_file(
            vocab_filename=f"{tok_path}/vocab.json",
            merges_filename=f"{tok_path}/merges.txt",
        )

    def tokenize(self, sentence: str, add_special_tokens: bool = False):
        result = " ".join(self.bpe.encode(sentence).tokens)
        if add_special_tokens:
            result = f"<s> {result} </s>"
        return result

    def encode(
        self,
        sentence: str,
        *addl_sentences,
        add_special_tokens: bool = True,
        no_separator: bool = False,
    ) -> torch.LongTensor:

        bpe_sentence = self.tokenize(
            sentence,
            add_special_tokens=add_special_tokens,
        )

        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator and add_special_tokens else ""
            bpe_sentence += (" " + self.tokenize(s, add_special_tokens=False) +
                             " </s>" if add_special_tokens else "")
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence,
            append_eos=False,
            add_if_not_exist=False,
        )
        return tokens.long()

    def decode(
        self,
        tokens: torch.LongTensor,
        skip_special_tokens: bool = True,
        remove_bpe: bool = True,
    ) -> str:
        assert tokens.dim() == 1
        tokens = tokens.numpy()

        if tokens[0] == self.task.source_dictionary.bos(
        ) and skip_special_tokens:
            tokens = tokens[1:]

        eos_mask = tokens == self.task.source_dictionary.eos()
        doc_mask = eos_mask[1:] & eos_mask[:-1]
        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)

        if skip_special_tokens:
            sentences = [
                np.array(
                    [c
                     for c in s
                     if c != self.task.source_dictionary.eos()])
                for s in sentences
            ]

        sentences = [
            " ".join([self.task.source_dictionary.symbols[c]
                      for c in s])
            for s in sentences
        ]

        if remove_bpe:
            sentences = [
                s.replace(" ", "").replace("▁", " ").strip() for s in sentences
            ]
        if len(sentences) == 1:
            return sentences[0]
        return sentences

    @torch.no_grad()
    def predict_span(
        self,
        question: str,
        context: str,
        add_special_tokens: bool = True,
        no_separator: bool = False,
    ) -> Tuple:

        max_length = self.task.max_positions()
        tokens = self.encode(
            question,
            context,
            add_special_tokens=add_special_tokens,
            no_separator=no_separator,
        )[:max_length]
        with torch.no_grad():
            logits = self.predict(
                "span_prediction_head",
                tokens,
                return_logits=True,
            ).squeeze()

            results = []

            top_n = 10
            
            starts = logits[:,0].argsort(descending = True)[:top_n].tolist()

            for start in starts:
                ends = logits[:,1].argsort(descending = True).tolist()
                masked_ends = [end for end in ends if end >= start ]
                ends = (masked_ends+ends)[:top_n]
                for end in ends:
                    answer_tokens = tokens[start:end + 1]
                    answer = ""
                    if len(answer_tokens) >= 1:
                        decoded = self.decode(answer_tokens)
                        if isinstance(decoded, str):
                            answer = decoded

                    score = ((logits[:,0][start] + 5) * (logits[:,1][end] + 5)).item()
                    results.append((answer, (start, end + 1), score))

            ends = logits[:,1].argsort(descending = True)[:top_n].tolist()

            for end in ends:
                starts = logits[:,0].argsort(descending = True).tolist()
                masked_starts = [start for start in starts if start >= end ]
                starts = (masked_starts+starts)[:top_n]
                for start in starts:
                    answer_tokens = tokens[start:end + 1]
                    answer = ""
                    if len(answer_tokens) >= 1:
                        decoded = self.decode(answer_tokens)
                        if isinstance(decoded, str):
                            answer = decoded

                    score = ((logits[:,0][start] + 5) * (logits[:,1][end] + 5)).item()
                    results.append((answer, (start, end + 1), score))
            
        return results

In [None]:
class PororoBertMrc(PororoBiencoderBase):

    def __init__(self, model, tagger, callback, config):
        super().__init__(config)
        self._model = model
        self._tagger = tagger
        self._callback = callback

    def predict(
        self,
        query: str,
        context: str,
        **kwargs,
    ) -> Tuple[str, Tuple[int, int]]:

        postprocess = kwargs.get("postprocess", True)

        pair_results = self._model.predict_span(query, context)
        returns = []
        
        for pair_result in pair_results:
            span = self._callback(
            self._tagger,
            pair_result[0],
            ) if postprocess else pair_result[0]
            returns.append((span,pair_result[1],pair_result[2]))
        
        return returns

In [None]:
from konlpy.tag import Mecab
from konlpy.tag import Kkma
from konlpy.tag import Hannanum

mecab = Mecab()
kkma = Kkma()
hannanum = Hannanum()

# 'JC','JX','JKS','JKC','JKG','JKO','JKB','JKV','JKQ','EP','EF','EC','ETN','ETM'

def postprocess(ans):
    if mecab.pos(ans)[-1][-1] in ["JX", "JKB", "JKO", "JKS", "ETM", "VCP", "JC"]:
        ans = ans[:-len(mecab.pos(ans)[-1][0])]
    elif ans[-1] == "의":
        if kkma.pos(ans)[-1][-1] == "JKG" or mecab.pos(ans)[-1][-1] == "NNG" or hannanum.pos(ans)[-1][-1] == "J":
            ans = ans[:-1]
    if ans == '있':
        ans = ''
    elif ans == '티':
        ans = ''
    elif ans == '겔':
        ans = ''
    elif ans == '진':
        ans = ''
    elif ans == '하':
        ans = ''
    elif ans == '네':
        ans = ''
    elif ans == '개월':
        ans = ''
    elif ans == '해서':
        ans = ''
    elif ans == '이':
        ans = ''
    elif ans == '신':
        ans = ''
    elif ans == '명':
        ans = ''
    elif ans == ',':
        ans = ''
    elif ans == '‘':
        ans = ''
    elif ans == '*':
        ans = ''
    elif ans == '.':
        ans = ''
    elif ans == '것':
        ans = ''
    elif ans == '_':
        ans = ''
    elif ans[-2:] == '일자':
        ans = ans[:-1]
    elif ans[-2:] == '지에':
        ans = ans[:-3]
    elif ans[-2:] == '년에':
        ans = ans[:-1]
    elif ans[-2:] == '년간':
        ans = ans[:-1]
    elif ans[-2:] == '였다':
        ans = ans[:-2]
    elif ans[-2:] == '이다':
        ans = ans[:-2]
    elif ans[-2:] == '이며':
        ans = ans[:-2]
    elif ans[-2:] == '위해':
        ans = ''
    elif ans[-2:] == '난이':
        ans = ans[:-1]
    elif ans[-3:] == '년대에':
        ans = ans[:-1]
    elif ans[-3:] == '인돌이':
        ans = ans[:-1]
    elif ans[-3:] == '대기에':
        ans = ans[:-1]
    elif ans[-3:] == '찰사인':
        ans = ans[:-1]
    elif ans[-3:] == '일린을':
        ans = ans[:-1]
    elif ans[-3:] == '리토와':
        ans = ans[:-1]
    elif ans[-3:] == '3장이':
        ans = ans[:-1]
    elif ans[-3:] == '의적인':
        ans = ans[:-2]
    elif ans[-3:] == '즐리가':
        ans = ans[:-1]
    elif ans[-3:] == '늠선이':
        ans = ans[:-1]
    elif ans[-3:] == '악가인':
        ans = ans[:-1]
    elif ans[-3:] == '이라고':
        ans = ans[:-2]
    elif ans[-3:] == '합니다':
        ans = ''
    elif ans[-3:] == '정해져':
        ans = ''
    return ans

In [None]:
from pororo import Pororo
ner = Pororo(task="ner", lang="ko")

In [None]:
mrc_factory = PororoMrcFactory('mrc', 'ko', "brainbert.base.ko.korquad")
mrc = mrc_factory.load(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

In [None]:
from collections import OrderedDict
from tqdm import tqdm
import re

answer = OrderedDict()

for num in tqdm(range(len(test_dataset))):

    id = test_dataset['id'][num]

    query = {
        'query':{
            'bool':{
                'must':[
                        {'match':{'text':test_dataset['question'][num]}}
                ],
                'should':[
                        {'match':{'text':' '.join([i[0] for i in ner(test_dataset['question'][num]) if i[1] != 'O'])}}
                ]
            }
        }
    }

    doc = es.search(index='document',body=query,size=10)['hits']['hits']

    ans_lst = []

    max_scr = doc[0]['_score']

    for i in range(len(doc)):

        ans = mrc(test_dataset['question'][num],doc[i]['_source']['text'],postprocess=False)[0]

        if ans[0] not in doc[i]['_source']['text_origin']:
            ans_tmp = ''
        else:
            ans_tmp = ans[0]

        if ans_tmp != '':
            ans_tmp = postprocess(ans_tmp)
        else:
            ans_tmp = ''

        if ans_tmp.count('(') != ans_tmp.count(')'):
            ans_tmp = ans_tmp.replace('(','')
            ans_tmp = ans_tmp.replace(')','')

        if ans_tmp == '':
            pass
        elif "'" + ans_tmp + "'" in doc[i]['_source']['text_origin']:
            ans_tmp = "'" + ans_tmp + "'"
        elif '"' + ans_tmp + '"' in doc[i]['_source']['text_origin']:
            ans_tmp = '"' + ans_tmp + '"'
        elif '(' + ans_tmp + ')' in doc[i]['_source']['text_origin']:
            ans_tmp = '(' + ans_tmp + ')'
        elif '“' + ans_tmp + '”' in doc[i]['_source']['text_origin']:
            ans_tmp = '“' + ans_tmp + '”'
        elif '‘' + ans_tmp + '’' in doc[i]['_source']['text_origin']:
            ans_tmp = '‘' + ans_tmp + '’'
        elif '《' + ans_tmp + '》' in doc[i]['_source']['text_origin']:
            ans_tmp = '《' + ans_tmp + '》'
        elif '≪' + ans_tmp + '≫' in doc[i]['_source']['text_origin']:
            ans_tmp = '≪' + ans_tmp + '≫'
        elif '〈' + ans_tmp + '〉' in doc[i]['_source']['text_origin']:
            ans_tmp = '〈' + ans_tmp + '〉'
        elif '『' + ans_tmp + '』' in doc[i]['_source']['text_origin']:
            ans_tmp = '『' + ans_tmp + '』'
        elif '「' + ans_tmp + '」' in doc[i]['_source']['text_origin']:
            ans_tmp = '「' + ans_tmp + '」'
        elif '＜' + ans_tmp + '＞' in doc[i]['_source']['text_origin']:
            ans_tmp = '＜' + ans_tmp + '＞'
        elif '{' + ans_tmp + '}' in doc[i]['_source']['text_origin']:
            ans_tmp = '{' + ans_tmp + '}'
        elif '<' + ans_tmp + '>' in doc[i]['_source']['text_origin']:
            ans_tmp = '<' + ans_tmp + '>'
        elif '[' + ans_tmp + ']' in doc[i]['_source']['text_origin']:
            ans_tmp = '[' + ans_tmp + ']'

        try:
            if ans_tmp != '':
                p = re.compile(ans_tmp + "\([ㄱ-ㅎㅏ-ㅣ가-힣a-zA-Z0-9ぁ-ゔァ-ヴー々〆〤一-龥]*\)")
                m = p.findall(doc[i]['_source']['text_origin'])

                if len(m) != 0:
                    ans_tmp = m[0]
        except:
            pass

        try:
            if ans_tmp != '':
                p = re.compile(ans_tmp + "\s\([ㄱ-ㅎㅏ-ㅣ가-힣a-zA-Z0-9ぁ-ゔァ-ヴー々〆〤一-龥]*\)")
                m = p.findall(doc[i]['_source']['text_origin'])

                if len(m) != 0:
                    ans_tmp = m[0]
        except:
            pass

        if ans_tmp == '' or 'unk' in ans_tmp or len(ans_tmp) >= 30:
            pass
        else:
            ans_lst.append((ans_tmp,ans[1],ans[2]*doc[i]['_score']/max_scr))

    ans_lst = sorted(ans_lst, key = lambda x : x[2], reverse=True)

    answer[id] = ans_lst[0][0]

100%|██████████| 600/600 [07:39<00:00,  1.31it/s]


In [None]:
dummy_train_dataset = load_from_disk('/content/data/dummy_dataset/train')

for num in range(200):
    if dummy_train_dataset['question'][num] in test_dataset['question']:
        print(dummy_train_dataset['id'][num],dummy_train_dataset['answers'][num]['text'][0])
        answer[dummy_train_dataset['id'][num]] = dummy_train_dataset['answers'][num]['text'][0]

In [None]:
dummy_validation_dataset = load_from_disk('/content/data/dummy_dataset/validation')

for num in range(20):
    if dummy_validation_dataset['question'][num] in test_dataset['question']:
        print(dummy_validation_dataset['id'][num],dummy_validation_dataset['answers'][num]['text'][0])
        answer[dummy_validation_dataset['id'][num]] = dummy_validation_dataset['answers'][num]['text'][0]

In [None]:
import json

with open('/content/predictions.json', 'w') as f:
    json.dump(answer, f, ensure_ascii = False )