# utils

In [1]:
from nltk import Tree
from tqdm.auto import tqdm

def progress(iter_, verbose, **kwargs):
    """Optionally wrap an iterable in a tqdm progress bar.

    Args:
        iter_: Iterable to traverse.
        verbose: When truthy, ``iter_`` is wrapped in ``tqdm``; otherwise it
            is handed back untouched.
        **kwargs: Forwarded to ``tqdm`` (ignored when ``verbose`` is falsy).

    Returns:
        ``tqdm(iter_, **kwargs)`` or ``iter_`` itself.
    """
    return tqdm(iter_, **kwargs) if verbose else iter_

def is_terminal(tree: "Tree | str"):
    """Return True when ``tree`` is a terminal, i.e. a bare token string.

    The leaves of an NLTK ``Tree`` are plain ``str`` objects, so a terminal
    is recognised by the type of the node itself.  Testing the node (rather
    than its first child) keeps this predicate distinct from
    ``is_preterminal`` and also works for an empty token.

    Args:
        tree: A subtree or a leaf token.

    Returns:
        True for a ``str`` leaf, False for any tree node.
    """
    return isinstance(tree, str)

def is_preterminal(tree: Tree):
    """Return True when the first child of ``tree`` is a plain string.

    Only the first child is inspected; an empty ``tree`` raises IndexError.
    """
    first_child = tree[0]
    return isinstance(first_child, str)

def convert_index(i, j, n):
    """Map a span ``(i, j)`` to its index in the flattened upper triangle.

    The flattened layout lists, row by row, the cells strictly above the
    diagonal of an ``(n + 1) x (n + 1)`` chart: row ``i`` contributes the
    ``n - i`` cells ``(i, i + 1) .. (i, n)``.  This is the ordering produced
    by ``chart2oneline``.

    Args:
        i: Left boundary of the span.
        j: Right boundary of the span.
        n: Largest valid boundary.

    Returns:
        Zero-based position of cell ``(i, j)`` in that listing, as an ``int``.

    Raises:
        AssertionError: If ``0 <= i < j <= n`` does not hold.
    """
    assert 0 <= i < j <= n
    # i * (2n - i - 1) is always even (one of the two factors is even
    # whatever the parity of i), so floor division is exact and avoids the
    # float rounding that true division introduces for very large operands.
    return int(i * (2 * n - i - 1) // 2 + j - 1)

# processing pure tree

## read trees

In [2]:
from pathlib import Path
import re

from nltk.corpus import BracketParseCorpusReader


def read_parsed_corpus(root, dir_numbers, max_len, verbose=True, with_path=False):
    """Yield parsed sentences from selected sub-directories of a bracketed corpus.

    Args:
        root: Corpus root directory; its children are the candidate sections.
        dir_numbers: Section numbers to read. Each is zero-padded to two
            digits and compared with the child directory names.
        max_len: Sentences with more than this many leaves are skipped.
        verbose: Show a progress bar over the children of ``root``.
        with_path: Yield ``(path, tree)`` pairs instead of bare trees.

    Yields:
        Each retained tree (or ``(path, tree)``), sections and files being
        visited in sorted order.
    """
    wanted = [str(number).zfill(2) for number in dir_numbers]
    corpus = BracketParseCorpusReader(root, r"wsj_.*\.mrg", tagset="wsj")
    sections = sorted(Path(root).iterdir())
    for section in progress(sections, verbose):
        if section.name in wanted:
            for mrg_file in sorted(section.glob("wsj_*.mrg")):
                # File ids are given relative to the corpus root.
                fileid = f"{section.name}/{mrg_file.name}"
                for sentence in corpus.parsed_sents(fileid):
                    if len(sentence.leaves()) <= max_len:
                        yield (mrg_file, sentence) if with_path else sentence

In [3]:
# Longest sentence, counted in leaves, that is kept from the corpus.
MAX_LEN = 40

# Sections 2-21 feed raw_train, section 22 raw_valid and section 23 raw_test.
raw_train, raw_valid, raw_test = [
    list(read_parsed_corpus("/home/corpus/PTB3/treebank_3/parsed/mrg/wsj", sections, MAX_LEN))
    for sections in (range(2, 22), [22], [23])
]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=26.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=26.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=26.0), HTML(value='')))




## convert to Chomsky normal form

In [4]:
def clean(tree: Tree):
    """Normalise one parsed sentence into a binarised, relabelled tree.

    Steps, in order: trim the labels with ``trim_info``, binarise in place
    with ``chomsky_normal_form`` (no Markov smoothing), rename the generated
    intermediate nodes with ``replace_dummy_label``, then collapse unary
    chains in place (including the root and POS level).

    Returns:
        The resulting tree.
    """
    trimmed = trim_info(tree)
    trimmed.chomsky_normal_form(horzMarkov=0, vertMarkov=0)
    relabelled = replace_dummy_label(trimmed)
    relabelled.collapse_unary(collapseRoot=True, collapsePOS=True)
    return relabelled


def trim_info(tree: Tree) -> Tree:
    """Shorten labels and drop dash-delimited nodes from a tree.

    A preterminal is returned unchanged.  Otherwise every child whose label
    both starts and ends with ``-`` (e.g. ``-NONE-``) is discarded, each
    remaining label is cut at its first ``-`` or ``=``, and the children are
    processed recursively.

    Returns:
        A new ``Tree`` with the shortened labels, or ``None`` when no child
        survives the filtering.
    """
    if is_preterminal(tree):
        return tree

    def head(label):
        # Keep only the part of the label before the first "-" or "=".
        return re.split('-|=', label, 2)[0]

    renamed = [
        Tree(head(node.label()), node)
        for node in tree
        if not (node.label()[0] == node.label()[-1] == "-")
    ]

    survivors = []
    for candidate in renamed:
        if not candidate:
            # A node without children has nothing to recurse into.
            continue
        trimmed = trim_info(candidate)
        if trimmed:
            survivors.append(trimmed)

    if not survivors:
        return None

    return Tree(head(tree.label()), survivors)


def replace_dummy_label(tree: Tree, dummy_label="DUMMY"):
    """Rename every node whose label contains ``|`` to ``dummy_label``.

    Preterminals are returned unchanged; every other node is rebuilt, so the
    input tree is not modified above the preterminal level.

    Args:
        tree: Tree to relabel.
        dummy_label: Replacement label for nodes whose label contains ``|``.

    Returns:
        The relabelled tree.
    """
    if is_preterminal(tree):
        return tree
    label = dummy_label if "|" in tree.label() else tree.label()
    # Forward dummy_label so a caller-supplied value applies at every depth,
    # not only at the root.
    children = [replace_dummy_label(child, dummy_label) for child in tree]
    return Tree(label, children)

In [5]:
# Apply clean() to every tree of each split, with one progress bar per split.
train, valid, test = [
    list(map(clean, tqdm(split)))
    for split in (raw_train, raw_valid, raw_test)
]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=35475.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1525.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2160.0), HTML(value='')))




## save

In [6]:
from pathlib import Path

def save(trees, path):
    """Write trees to ``data/<path>``, one bracketed tree per line.

    Args:
        trees: Iterable of objects exposing ``pformat``.
        path: File name inside the ``data`` directory.
    """
    # An infinite margin keeps pformat from wrapping a tree over several lines.
    unlimited = float("inf")
    rendered = (tree.pformat(unlimited) for tree in tqdm(trees))
    destination = Path(f"data/{path}")
    destination.write_text("\n".join(rendered))

# Create the output directory if it does not exist yet, then write one file
# per split (data/train, data/valid, data/test).
Path("data").mkdir(exist_ok=True)
save(train, "train")
save(valid, "valid")
save(test, "test")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=35475.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1525.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2160.0), HTML(value='')))




In [10]:
# IPython shell escape: report the total on-disk size of the data directory.
!du -sh data

259M	data


# label → index

## label statistics

In [11]:
from collections import defaultdict
from nltk import Tree

def get_all_label(trees: list[Tree]):
    """Count how often each label occurs in the non-lexical productions.

    For every production that is not lexical, both its left-hand side and
    each right-hand-side symbol are counted once, keyed by ``str(symbol)``.

    Args:
        trees: Trees whose productions are tallied.

    Returns:
        A ``defaultdict(int)`` mapping label string to occurrence count.
    """
    tally = defaultdict(int)

    for tree in tqdm(trees):
        structural = (rule for rule in tree.productions() if not rule.is_lexical())
        for production in structural:
            for symbol in (production.lhs(), *production.rhs()):
                tally[str(symbol)] += 1
    return tally

In [12]:
# Label statistics are computed over all three splits together.
trees = [*train, *valid, *test]
counter = get_all_label(trees)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=39160.0), HTML(value='')))




In [13]:
import pandas as pd

# One row per label, most frequent first.
labels, counts = zip(*counter.items())
dist = pd.DataFrame(data={"count": counts}, index=labels)
dist = dist.sort_values(by="count", ascending=False)

In [14]:
# Turn the raw counts into relative frequencies and a running total.  dist is
# already sorted by descending count, so "cdf" accumulates from the most
# frequent label downwards.
all_count = dist["count"].sum()
dist["proba"] = dist["count"] / all_count 
dist["cdf"] = dist["proba"].cumsum()

# Printed captions (Japanese), in order: total number of distinct labels /
# labels occurring at least 10 times / labels within the first 99 % of
# cumulative probability / labels within the first 99.9 %.
print("ラベル全種類数:", len(dist))
print("10 個以上のラベル:", len(dist.query("count >= 10")))
print("累計確率が 99 %になるまでのラベル:", len(dist.query("cdf <= .99")))
print("累計確率が 99.9 %になるまでのラベル:", len(dist.query("cdf <= .999")))

ラベル全種類数: 394
10 個以上のラベル: 195
累計確率が 99 %になるまでのラベル: 57
累計確率が 99.9 %になるまでのラベル: 121


In [15]:
# Number of most frequent labels that receive an index of their own.
# NOTE(review): presumably chosen to sit just above the 57 labels that cover
# 99 % of all occurrences — confirm.
LABEL_COUNTS = 60

In [16]:
# Keep the LABEL_COUNTS most frequent labels (dist is sorted by count).
min_labels = list(dist.index[:LABEL_COUNTS])
index = range(LABEL_COUNTS)

# Two-way lookup between a label and its frequency rank.
index2label = dict(enumerate(min_labels))
label2index = {label: position for position, label in enumerate(min_labels)}

In [17]:
import json

# Persist both lookup tables.  JSON object keys are always strings, so the
# integer keys of index2label are written as "0", "1", ... and come back as
# strings when the file is loaded.
with open("data/index2label.json", "w") as f:
    json.dump(index2label, f)

with open("data/label2index.json", "w") as f:
    json.dump(label2index, f)

# processing chart

## convert trees to charts

In [18]:
import torch


def tree2chart(tree, max_len=MAX_LEN, empty_label="EMPTY", oneline=True):
    """Fill a span chart with the node labels of a binarised tree.

    Cell ``[left][right]`` receives the label of the node whose leaves span
    positions ``left`` to ``right``; all other cells keep ``empty_label``.

    Args:
        tree: Tree in which every node is a preterminal or has exactly two
            children.
        max_len: The chart has ``max_len + 1`` rows and columns.
        empty_label: Filler for cells that no node covers.
        oneline: When truthy, return the chart flattened by
            ``chart2oneline``; otherwise return the nested list.

    Returns:
        The flattened list of labels, or the chart as a list of rows.

    Raises:
        ValueError: If a non-preterminal node does not have exactly two
            children.
    """
    size = max_len + 1
    chart = [[empty_label] * size for _ in range(size)]

    def fill(node, start=0):
        # Records the label of `node` and returns the (left, right) span it covers.
        if not is_preterminal(node):
            try:
                first, second = node
            except ValueError:
                raise ValueError("this tree is not chomsky.")
            begin, middle = fill(first, start)
            # The right child starts where the left child ended.
            end = fill(second, middle)[1]

            chart[begin][end] = node.label()

            return begin, end

        stop = start + 1
        try:
            chart[start][stop] = node.label()
        except Exception as e:
            # Show the offending node before propagating (e.g. a sentence
            # longer than the chart).
            print(node)
            print(start, stop, node.label())
            raise e
        return start, stop

    fill(tree)
    return chart2oneline(chart) if oneline else chart


def chart2oneline(chart, max_len=MAX_LEN):
    """Flatten the cells strictly above the diagonal of ``chart``, row by row.

    Args:
        chart: Sequence of rows, each a sliceable sequence.
        max_len: Not used by the function body.

    Returns:
        A flat list holding, for each row ``r`` (0-based), ``row[r + 1:]``.
    """
    return [cell for start, row in enumerate(chart, 1) for cell in row[start:]]


def charts_as_tensor(trees):
    """Encode the flattened chart of every tree as one row of label indices.

    Each chart cell is looked up in ``label2index``; any label missing from
    that mapping is encoded as ``-1``.

    Args:
        trees: Trees accepted by ``tree2chart``.

    Returns:
        A ``torch.Tensor`` with one row per tree, created without gradient
        tracking.
    """
    encoded = [
        [label2index.get(cell, -1) for cell in tree2chart(tree)]
        for tree in tqdm(trees)
    ]
    return torch.tensor(encoded, requires_grad=False)

In [19]:
# Encode each split as a tensor of label indices.
train_chart, valid_chart, test_chart = map(charts_as_tensor, (train, valid, test))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=35475.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1525.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2160.0), HTML(value='')))




In [20]:
# One row per test sentence; with MAX_LEN = 40 a flattened chart has
# 40 * 41 / 2 = 820 cells above the diagonal.
test_chart.shape

torch.Size([2160, 820])

In [23]:
# Serialise the three encoded splits next to the text files in data/.
torch.save(train_chart, "data/train_chart.pt")
torch.save(valid_chart, "data/valid_chart.pt")
torch.save(test_chart, "data/test_chart.pt")

In [24]:
# IPython shell escape: size of the data directory after adding the tensors.
!du -sh data

260M	data


In [237]:
# Load the saved test tensor back and display it (presumably a sanity check
# of the file written above).  Entries of -1 are cells whose label is not in
# label2index, as encoded by charts_as_tensor.
tmp = torch.load("data/test_chart.pt")
tmp

tensor([[-1, -1, -1,  ..., -1, -1, -1],
        [ 7, -1, -1,  ..., -1, -1, -1],
        [ 7, -1, -1,  ..., -1, -1, -1],
        ...,
        [ 8,  1, -1,  ..., -1, -1, -1],
        [ 7, -1,  1,  ..., -1, -1, -1],
        [-1, -1, -1,  ..., -1, -1, -1]])