In [None]:
# EDA

import sys, sqlite3, os
import pandas as pd
import pprint
from collections import namedtuple
import ipadic
import MeCab
from math import ceil
import random

# To use the chikkarpy library, uncomment the two imports below
#from chikkarpy import Chikkar
#from chikkarpy.dictionarylib import Dictionary

class EDA:
    """Easy Data Augmentation (EDA) for Japanese sentences.

    Sentences are tokenized with MeCab + ipadic, and augmented variants are
    produced by synonym replacement, random insertion, random swap and random
    deletion.

    Subclasses must override :meth:`get_synonyms` and are expected to set
    ``self.stop_words`` (words that are never used for synonym lookup).
    """

    def __init__(self):
        # Empty default so wakati_text also works on a bare EDA instance.
        # NOTE: subclasses that do not call super().__init__() must assign
        # their own stop_words.
        self.stop_words = []

    def get_synonyms(self, word):
        """Return a list of synonyms of ``word``.

        Must be overridden; the base class has no synonym source.
        """
        raise NotImplementedError("subclasses must implement get_synonyms()")

    # Tokenization (wakati-gaki)
    def wakati_text(self, text, hinshi=('名詞', '動詞')):
        """Tokenize ``text`` and collect the base forms used for synonym lookup.

        Args:
            text (str): sentence to tokenize.
            hinshi (sequence of str, optional): parts of speech (first ipadic
                feature) whose base form is kept. Defaults to ('名詞', '動詞'),
                i.e. nouns and verbs.

        Returns:
            tuple[list, list]: ``(raw_words, original_words)`` of equal length.
            ``raw_words`` are the surface tokens. ``original_words[i]`` is the
            base form of token i, or "" when its part of speech is not in
            ``hinshi`` or its base form is a stop word.
        """
        tagger = MeCab.Tagger(ipadic.MECAB_ARGS)
        parsed = tagger.parse(text)
        # Each line is "surface\tfeature,feature,..."; the trailing "EOS" line
        # and the empty string after the final newline are dropped.
        rows = [line.split("\t") for line in parsed.split("\n")][:-2]

        # Surface tokens of the original text.
        raw_words = [row[0] for row in rows]

        # Base forms (feature index 6) restricted to the requested parts of speech.
        stop_words = set(self.stop_words)
        original_words = []
        for row in rows:
            features = row[1].split(",")
            base = features[6] if features[0] in hinshi else ""
            original_words.append("" if base in stop_words else base)

        return raw_words, original_words

    def synonym_replacement(self, raw_words, original_words, n):
        """Replace up to ``n`` randomly chosen words with one of their synonyms.

        Every occurrence of a chosen surface form is replaced, so more than
        ``n`` tokens can change when a word appears several times.

        Args:
            raw_words (list): tokenized sentence (surface forms).
            original_words (list): base forms aligned with ``raw_words``
                ("" marks tokens that are not candidates).
            n (int): maximum number of distinct words to replace.

        Returns:
            list: a new token list; ``raw_words`` is not modified.
        """
        new_words = raw_words.copy()

        # Visit candidate positions in random order.
        candidate_idx = [i for i, x in enumerate(original_words) if x != ""]
        random.shuffle(candidate_idx)

        num_replaced = 0
        for idx in candidate_idx:
            # Checked before the lookup so that n <= 0 replaces nothing.
            if num_replaced >= n:
                break
            synonyms = self.get_synonyms(original_words[idx])
            if synonyms:
                synonym = random.choice(synonyms)
                raw_word = raw_words[idx]
                new_words = [synonym if word == raw_word else word for word in new_words]
                num_replaced += 1

        return new_words

    def random_insertion(self, raw_words, original_words, n):
        """Insert ``n`` synonyms of words from the sentence at random positions.

        Args:
            raw_words (list): tokenized sentence (surface forms).
            original_words (list): base forms aligned with ``raw_words``.
            n (int): number of insertion attempts.

        Returns:
            list: a new token list; ``raw_words`` is not modified. Attempts
            that find no synonym insert nothing.
        """
        new_words = raw_words.copy()
        for _ in range(n):
            self.add_word(new_words, original_words)
        return new_words

    def add_word(self, new_words, original_words):
        """Insert one synonym of a random candidate word into ``new_words`` in place.

        Gives up silently when there is no candidate word or when 10 lookups
        in a row return no synonyms.
        """
        candidates = [x for x in original_words if x]
        if not candidates:
            # Nothing to look up (e.g. only stop words / other parts of speech).
            return
        for _ in range(10):
            synonyms = self.get_synonyms(random.choice(candidates))
            if synonyms:
                break
        else:
            return
        # Insert somewhere before the last token (index 0 for an empty list).
        random_idx = random.randint(0, max(len(new_words) - 1, 0))
        new_words.insert(random_idx, synonyms[0])

    def random_deletion(self, words, p):
        """Delete each word independently with probability ``p``.

        Args:
            words (list): tokenized sentence.
            p (float): deletion probability per word.

        Returns:
            list: the surviving words. Inputs with zero or one word are
            returned unchanged (same object). If every word would be deleted,
            a single randomly chosen word is returned instead.
        """
        # Nothing to delete from, and a lone word is always kept.
        if len(words) <= 1:
            return words

        new_words = [word for word in words if random.uniform(0, 1) > p]

        # Never return an empty sentence.
        if not new_words:
            return [random.choice(words)]

        return new_words

    def random_swap(self, words, n):
        """Swap the positions of two random words ``n`` times.

        Args:
            words (list): tokenized sentence.
            n (int): number of swaps.

        Returns:
            list: a new token list; ``words`` is not modified.
        """
        new_words = words.copy()
        for _ in range(n):
            self.swap_word(new_words)
        return new_words

    def swap_word(self, new_words):
        """Swap two distinct random positions of ``new_words`` in place and return it.

        Lists with fewer than two elements are returned unchanged.
        """
        if len(new_words) < 2:
            return new_words
        idx_1, idx_2 = random.sample(range(len(new_words)), 2)
        new_words[idx_1], new_words[idx_2] = new_words[idx_2], new_words[idx_1]
        return new_words

    # Run all techniques together
    def eda(self, sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
        """Run every enabled EDA technique and return augmented sentences.

        A technique is enabled when its parameter is > 0. Augmented token
        lists are joined with no separator.

        Args:
            sentence (str): original sentence.
            alpha_sr (float, optional): synonym replacement; replaces
                max(1, int(alpha_sr * num_words)) words. Defaults to 0.1.
            alpha_ri (float, optional): random insertion; inserts
                max(1, int(alpha_ri * num_words)) words. Defaults to 0.1.
            alpha_rs (float, optional): random swap; performs
                max(1, int(alpha_rs * num_words)) swaps. Defaults to 0.1.
            p_rd (float, optional): random deletion probability per word.
                Defaults to 0.1.
            num_aug (int, optional): number of augmented sentences. Defaults to 9.

        Returns:
            list[str] | None: ``num_aug`` augmented sentences in random order
            followed by the original ``sentence`` as the last element, or
            None when no technique is enabled.
        """
        # Count the techniques that will actually run (parameter > 0) so the
        # requested sentences are split evenly among them.
        techniques = sum(1 for a in (alpha_sr, alpha_ri, alpha_rs, p_rd) if a > 0)
        if techniques == 0:
            return None

        raw_words, original_words = self.wakati_text(sentence)
        num_words = len(raw_words)

        augmented_sentences = []
        num_new_per_technique = int(num_aug / techniques) + 1

        # Replace n random words with synonyms.
        if alpha_sr > 0:
            n_sr = max(1, int(alpha_sr * num_words))
            for _ in range(num_new_per_technique):
                a_words = self.synonym_replacement(raw_words, original_words, n_sr)
                augmented_sentences.append(''.join(a_words))

        # Insert n synonyms of words that occur in the sentence.
        if alpha_ri > 0:
            n_ri = max(1, int(alpha_ri * num_words))
            for _ in range(num_new_per_technique):
                a_words = self.random_insertion(raw_words, original_words, n_ri)
                augmented_sentences.append(''.join(a_words))

        # Swap word positions n times.
        if alpha_rs > 0:
            n_rs = max(1, int(alpha_rs * num_words))
            for _ in range(num_new_per_technique):
                a_words = self.random_swap(raw_words, n_rs)
                augmented_sentences.append(''.join(a_words))

        # Delete each word with probability p_rd.
        if p_rd > 0:
            for _ in range(num_new_per_technique):
                a_words = self.random_deletion(raw_words, p_rd)
                augmented_sentences.append(''.join(a_words))

        # Keep a random subset of the requested size.
        random.shuffle(augmented_sentences)
        augmented_sentences = augmented_sentences[:num_aug]

        # The original sentence always comes last.
        augmented_sentences.append(sentence)

        return augmented_sentences


class WordnetEDA(EDA):
    """EDA variant that looks synonyms up in the Japanese WordNet database."""

    def __init__(self, wordnet_path=None, stopwords_path=None):
        """Load Japanese (synset, lemma) pairs and the stop-word list.

        Args:
            wordnet_path (str, optional): path to the Japanese WordNet SQLite
                file. Defaults to ``<RESOURCE_ROOT>/synonym/wn_jpn.db``.
            stopwords_path (str, optional): stop-word file, read as a
                headerless CSV whose first column holds the words. Defaults to
                ``<RESOURCE_ROOT>/stopwords/slothlib_stopwords.txt``.
        """
        ROOT: str = 'Z:\\' if os.name == 'nt' else '/Z'
        RESOURCE_ROOT: str = os.path.join(ROOT, 'ae_share02', '128_TotalLifePlanningSalesTraining', 'smartphone_nlp', 'dev', 'resources')
        if wordnet_path is None:
            wordnet_path = os.path.join(RESOURCE_ROOT, "synonym", "wn_jpn.db")
        if stopwords_path is None:
            stopwords_path = os.path.join(RESOURCE_ROOT, "stopwords", "slothlib_stopwords.txt")

        # DataFrame of (synset = concept ID, lemma = word) pairs for Japanese senses.
        # 'jpn' is a single-quoted string literal: double quotes denote
        # identifiers in SQL, and SQLite accepts "jpn" as a string only via a
        # legacy fallback that builds may disable.
        query = "SELECT synset, lemma FROM sense JOIN word USING (wordid) WHERE sense.lang = 'jpn'"
        conn = sqlite3.connect(wordnet_path)
        try:
            self.sense_word = pd.read_sql(query, conn)
        finally:
            # The rows are fully loaded into the DataFrame; release the file handle.
            conn.close()

        # Stop words: excluded from synonym lookup by EDA.wakati_text.
        self.stop_words = pd.read_csv(stopwords_path, header=None)[0].to_list()

    # Return synonyms as a list
    def get_synonyms(self, word):
        """Look up synonyms of ``word`` in Japanese WordNet.

        Every lemma that shares at least one synset with ``word`` is returned.

        Args:
            word (str): word to look up.

        Returns:
            list: synonyms of ``word`` (``word`` itself excluded, no duplicates,
            unspecified order).
        """
        synsets = self.sense_word.loc[self.sense_word.lemma == word, "synset"]
        synset_words = set(self.sense_word.loc[self.sense_word.synset.isin(synsets), "lemma"])
        synset_words.discard(word)
        return list(synset_words)




class SudachiEDA(EDA):
    """EDA variant that looks synonyms up in the Sudachi synonym dictionary."""

    def __init__(self, synonym_path=None, stopwords_path=None):
        """Load the Sudachi synonym CSV and the stop-word list.

        Args:
            synonym_path (str, optional): Sudachi ``synonyms.txt`` (CSV with
                no header). Defaults to ``<PROJ_ROOT>/resources/synonym/synonyms.txt``.
            stopwords_path (str, optional): stop-word file, read as a
                headerless CSV whose first column holds the words. Defaults to
                ``<PROJ_ROOT>/resources/stopwords/slothlib_stopwords.txt``.
        """
        ROOT: str = 'Z:\\' if os.name == 'nt' else '/Z'
        PROJ_ROOT: str = os.path.join(ROOT, 'ae_share02', '128_TotalLifePlanningSalesTraining', 'smartphone_nlp', 'poc')
        RESOURCE_ROOT: str = os.path.join(PROJ_ROOT, 'resources')
        if synonym_path is None:
            synonym_path = os.path.join(RESOURCE_ROOT, 'synonym', 'synonyms.txt')
        if stopwords_path is None:
            stopwords_path = os.path.join(RESOURCE_ROOT, "stopwords", "slothlib_stopwords.txt")

        self.df = pd.read_csv(synonym_path, skip_blank_lines=True,
                              names=('group_id', 'type', 'expand', 'vocab_id', 'relation', 'abbreviation', 'spelling', 'domain', 'surface', 'reserve1', 'reserve2'))
        # Stop words: excluded from synonym lookup by EDA.wakati_text.
        self.stop_words = pd.read_csv(stopwords_path, header=None)[0].to_list()

    def get_synonyms(self, word: str) -> list:
        """Get synonyms from the Sudachi synonym dictionary.

        Returns the surface forms of every synonym group that contains
        ``word`` as a surface form, in dictionary row order.

        Args:
            word (str): the word to look up.

        Returns:
            list: synonym surfaces, excluding ``word`` itself.
        """
        # One vectorized selection instead of concatenating a DataFrame per
        # matching row; each matching group is included once.
        group_ids = self.df.loc[self.df.surface == word, 'group_id'].unique()
        surfaces = self.df.loc[self.df.group_id.isin(group_ids), 'surface'].tolist()
        return [s for s in surfaces if s != word]

class WordnetProxy:
    """Query helper over the Japanese WordNet SQLite database (``wn_jpn.db``).

    Can be used as a context manager to close the connection automatically.
    """

    # Row shapes of the WordNet tables; assumed to match the column order
    # returned by `select *` on each table.
    _Word = namedtuple('Word', 'wordid lang lemma pron pos')
    _Sense = namedtuple('Sense', 'synset wordid lang rank lexid freq src')
    _Synset = namedtuple('Synset', 'synset pos name src')

    def __init__(self, db_path=None):
        """Open the WordNet database.

        Args:
            db_path (str, optional): path to the SQLite file. Defaults to
                ``<PROJ_ROOT>/resources/synonym/wn_jpn.db``.
        """
        if db_path is None:
            ROOT: str = 'Z:\\' if os.name == 'nt' else '/Z'
            PROJ_ROOT: str = os.path.join(ROOT, 'ae_share02', '128_TotalLifePlanningSalesTraining', 'smartphone_nlp', 'poc')
            db_path = os.path.join(PROJ_ROOT, 'resources', 'synonym', 'wn_jpn.db')
        self.conn = sqlite3.connect(db_path)

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __getWords(self, lemma):
        """Return all `word` rows whose lemma equals ``lemma``."""
        cur = self.conn.execute("select * from word where lemma=?", (lemma,))
        return [self._Word(*row) for row in cur]

    def __getSenses(self, word):
        """Return all `sense` rows of ``word``."""
        cur = self.conn.execute("select * from sense where wordid=?", (word.wordid,))
        return [self._Sense(*row) for row in cur]

    def __getSynset(self, synset):
        """Return the `synset` row with ID ``synset`` (the row must exist)."""
        cur = self.conn.execute("select * from synset where synset=?", (synset,))
        return self._Synset(*cur.fetchone())

    def __getWordsFromSynset(self, synset, lang):
        """Return the words in ``lang`` that belong to ``synset``."""
        cur = self.conn.execute("select word.* from sense, word where synset=? and word.lang=? and sense.wordid = word.wordid;", (synset, lang))
        return [self._Word(*row) for row in cur]

    def __getWordsFromSenses(self, sense, lang="jpn"):
        """Map each sense's synset name to the lemmas in ``lang`` of that synset."""
        synonym = {}
        for s in sense:
            lemmas = [w.lemma for w in self.__getWordsFromSynset(s.synset, lang)]
            synonym[self.__getSynset(s.synset).name] = lemmas
        return synonym

    def get_synonym(self, word: str) -> dict:
        """Get synonyms using the WordNet DB.

        Args:
            word (str): the word you want to search.

        Raises:
            TypeError: if ``word`` is not a string (a subclass of Exception,
                so existing ``except Exception`` handlers still catch it).

        Returns:
            dict: KEY: synset name (English), VALUE: list of Japanese lemmas in
            that synset. Empty when the word is not in the database. If two
            senses share a synset name, the later one wins.
        """
        if not isinstance(word, str):
            raise TypeError(f'accept only string. current variable type is {type(word)}.')

        synonym = {}
        for w in self.__getWords(word):
            synonym.update(self.__getWordsFromSenses(self.__getSenses(w)))
        return synonym