In [None]:
import sys
sys.path.append('../src/utils')

In [None]:
from pathlib import Path
base_path = Path("/data1/ujiie/shinra/tohoku_bert/")

In [None]:
import json

for path in base_path.glob("**/"):
    if path.stem in ["tohoku_bert", "_temp_files", "tokens"]:
        continue
    if path.parent.stem == "tohoku_bert":
        continue

    if not (path / "tokens").exists():
        continue

    category = path.stem

    diff_cnt = 0
    total = 0

    vocab = load_vocab(path / "vocab.txt")
    anns = load_annotation(path / f"{category}_dist.json")

    for token_file in path.glob("tokens/*.txt"):
        page_id = int(token_file.stem)
        if page_id not in anns:
            continue
        ann = anns[page_id]

        tokens, text_offsets = load_tokens(token_file, vocab)
        print(tokens)
        for a in ann:
            if "token_offset" not in a:
                continue
            total += 1
            start_line = int(a["token_offset"]["start"]["line_id"])
            start_offset = int(a["token_offset"]["start"]["offset"])

            end_line = int(a["token_offset"]["end"]["line_id"])
            end_offset = int(a["token_offset"]["end"]["offset"])

            diff_cnt += int(tokens[start_line][start_offset].startswith("##")
                or (end_offset < len(tokens[end_line]) 
                    and tokens[end_line][end_offset].startswith("##")))
            
        break
    break
    
    print(f"{category}: {diff_cnt}/{total}")

In [None]:
import pickle

attributes = {}
attribute_path = Path("/data1/ujiie/shinra/tohoku_bert/attributes")
for path in attribute_path.glob("*.txt"):
    category = path.stem
    with open(path, "r") as f:
        attributes[category] = [l for l in f.read().split("\n") if l != ""]

with open("/data1/ujiie/shinra/tohoku_bert/attributes.pickle", "wb") as f:
    pickle.dump(attributes, f)

In [46]:
from collections import defaultdict

def load_tokens(path, vocab):
    tokens = []
    text_offsets = []
    with open(path, "r") as f:
        for line in f:
            line = line.rstrip().split()
            line = [l.split(",") for l in line]
            tokens.append([vocab[int(l[0])] for l in line])
            text_offsets.append([[l[1], l[2]] for l in line])

    return tokens, text_offsets

def load_vocab(path):
    vocab = []
    with open(path, "r") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            vocab.append(line)
    return vocab

def load_annotation(path):
    ann = defaultdict(list)
    with open(path, "r") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            line = json.loads(line)
            line["page_id"] = int(line["page_id"])
            ann[line["page_id"]].append(line)
    return ann

def find_word_alignment(tokens):
    word_idxs = []
    for idx, token in enumerate(tokens):
        if not token.startswith("##"):
            word_idxs.append(idx)

    return word_idxs

In [47]:
from pathlib import Path
import pickle
import json
from tqdm import tqdm

def find_word_alignment(tokens):
    word_idxs = []
    sub2word = {}
    for idx, token in enumerate(tokens):
        if not token.startswith("##"):
            word_idxs.append(idx)
        sub2word[idx] = len(word_idxs) - 1

    return word_idxs, sub2word


class ShinraData(object):
    def __init__(self, tokenizer, attributes_path, params={}):
        self.tokenizer = tokenizer
        with open(attributes_path, "rb") as f:
            self.attributes = pickle.load(f)
        self.attr2idx = {}
        for key, value in self.attributes.items():
            self.attr2idx[key] = {word: idx for idx, word in enumerate(value)}

        self.page_id = None
        self.page_title = None
        self.category = None
        self.plain_text = None
        self.tokens = None
        self.word_alignments = None
        self.sub2word = None
        self.text_offsets = None
        self.nes = None

        for key, value in params.items():
            setattr(self, key, value)

    @classmethod
    def from_shinra2020_format(
        cls, 
        tokenizer=None,
        attributes_path=None,
        input_path=None):

        input_path = Path(input_path)
        category = input_path.stem

        anns = load_annotation(input_path / f"{category}_dist.json")
        vocab = load_vocab(input_path / "vocab.txt")

        docs = []
        for token_file in tqdm(input_path.glob("tokens/*.txt")):
            page_id = int(token_file.stem)
            tokens, text_offsets = load_tokens(token_file, vocab)

            # find title
            title = "".join([t[2:] if t.startswith("##") else t for t in tokens[4]])
            pos = title.find("-jawiki")
            title = title[:pos]

            # find word alignments = start positions of words
            word_alignments = [find_word_alignment(t) for t in tokens]
            sub2word = [w[1] for w in word_alignments]
            word_alignments = [w[0] for w in word_alignments]

            data = {
                "page_id": page_id, 
                "page_title": title,
                "category": category,
                "tokens": tokens,
                "text_offsets": text_offsets,
                "word_alignments": word_alignments,
                "sub2word": sub2word,
            }

            if page_id in anns:
                data["nes"] = anns[page_id]
            else:
                continue

            docs.append(cls(tokenizer, attributes_path, params=data))

        return docs

    @property
    def words(self):
        all_words = []
        for tokens, word_alignments in zip(self.tokens, self.word_alignments):
            words = []
            prev_idx = 0
            for idx in word_alignments[1:] + [-1]:
                inword_subwords = tokens[prev_idx:idx]
                inword_subwords = [s[2:] if s.startswith("##") else s for s in inword_subwords]
                words.append("".join(inword_subwords))
                prev_idx = idx
            all_words.append(words)
        return all_words

    @property
    def iob(self):
        """
        %%% IOB for ** only word-level iob2 tag **
        iobs = [sent, sent, ...]
        sent = [[Token1_attr1_iob, Token2_attr1_iob, ...], [Token1_attr2_iob, Token2_attr2_iob, ...], ...]

        {"O": 0, "B": 1, "I": 2}
        """
        iobs = [[["O" for _ in range(len(tokens))] for _ in range(len(self.attributes[self.category]))] for tokens in self.word_alignments]
        for ne in self.nes:
            start_line = int(ne["token_offset"]["start"]["line_id"])
            start_offset = int(ne["token_offset"]["start"]["offset"])

            end_line = int(ne["token_offset"]["end"]["line_id"])
            end_offset = int(ne["token_offset"]["end"]["offset"])

            if start_line != end_line:
                continue

            # 正解となるsubwordを含むwordまでタグ付
            attr_idx = self.attr2idx[self.category][ne["attribute"]]
            ne_start = self.sub2word[start_line][start_offset]
            ne_end = self.sub2word[end_line][end_offset]
            for idx in range(ne_start, ne_end):
                iobs[start_line][attr_idx][idx] = "B" if idx == ne_start else "I"

        return iobs


In [48]:
dataset = ShinraData.from_shinra2020_format(None, "/data1/ujiie/shinra/tohoku_bert/attributes.pickle", Path("/data1/ujiie/shinra/tohoku_bert/Event/Event_Other"))

525it [00:10, 50.94it/s]


In [55]:
outputs = []
with open("event_other.iob", "w") as f:
    for tokens, iobs in zip(dataset[5].words, dataset[5].iob):
        output = []
        flatten_iobs = [["O" for j in range(len(iobs))] for i in range(len(tokens))]
        print(len(flatten_iobs), len(tokens), len(iobs))
        for attr_idx, iob in enumerate(iobs):
            for token_idx, tag in enumerate(iob):
                if tag != "O":
                    tag = tag + "-" + dataset[5].attributes[dataset[5].category][attr_idx]
                flatten_iobs[token_idx][attr_idx] = tag
        for token, iob in zip(tokens, flatten_iobs):
            iob = "\t".join(iob)
            output.append(f"{token}\t{iob}")
        print(output)
        outputs.append("\n".join(output))

    attributes = ["token"] + dataset[5].attributes[dataset[5].category]
    attributes = "\t".join(attributes)
    f.write(f"{attributes}\n")
    f.write("\n\n".join(outputs))

\tO\tO\tO\tO\tO\tO\tO', 'の\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '結論\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'で\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'ある\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'と\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'し\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'て\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'いる\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO']
60 60 11
['なお\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '、\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '宮沢\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '説\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'によって\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'も\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '、\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'そもそも\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '主権\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'を\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '制約\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'する\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '原理\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'が\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'ある\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', 'の\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO\tO', '