In [1]:
# Number of word slots per sentence: TreebankDataset.encode pads every sentence
# to this many token ids before adding the [CLS] and [SEP] ids around them.
MAX_LEN = 40
# Default size of the label dimension produced by Parser's output layer.
LABEL_COUNTS = 60

In [2]:
import torch

# `is_available` has to be *called*: the bare function object is always truthy,
# so asserting on it without parentheses can never fail.
assert torch.cuda.is_available(), "GPU が利用可能ではない。"

# Select the CUDA device with index 1; the dataset tensors and the parser are
# moved to this device further down the notebook.
device = torch.device("cuda", index=1)
device

device(type='cuda', index=1)

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoConfig, logging

# Checkpoint name used for both the tokenizer and the encoder.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Restrict transformers' logging to errors while the encoder weights load,
# then set the verbosity to the warning level for the rest of the notebook.
logging.set_verbosity_error()
encoder = AutoModel.from_pretrained(model_name)
logging.set_verbosity_warning()

## dataloader

In [5]:
from pathlib import Path

from nltk import Tree
import torch
from torch.utils.data import Dataset
from tqdm import tqdm


class TreebankDataset(Dataset):
    """Treebank sentences as fixed-length token-id rows paired with gold labels.

    Relies on the notebook-level ``device`` and ``MAX_LEN`` globals.
    """

    def __init__(self, tokenizer, tree_path, gold_path):
        """Read the trees, encode them, and move both tensors to ``device``.

        Args:
            tokenizer: object exposing ``encode(word, add_special_tokens=False)``
                and the ``pad_token_id`` / ``cls_token_id`` / ``sep_token_id`` /
                ``unk_token_id`` attributes.
            tree_path: text file holding one bracketed tree per line.
            gold_path: file readable by ``torch.load``; the loaded tensor is
                indexed along its first dimension in ``__getitem__``.
        """
        self.tokenizer = tokenizer
        trees = [Tree.fromstring(tree) for tree in Path(tree_path).read_text().splitlines()]
        self.trees, token_ids = zip(*self.encode(trees))
        self.token_ids = torch.as_tensor(token_ids).to(device)
        # NOTE: torch.load unpickles the file -- only load charts you produced yourself.
        self.gold_labels = torch.load(gold_path).to(device)

        self.token_ids.requires_grad = False
        self.gold_labels.requires_grad = False

    # NOTE: the method name is misleading -- it yields (tree, token_ids) pairs,
    # not only the encoded ids.
    def encode(self, trees, max_len=MAX_LEN):
        """Yield ``(tree, token_ids)`` pairs with one token id per word.

        Each word is represented by its first sub-word id only. Every row is
        laid out as ``[CLS] + word ids + padding + [SEP]`` and therefore holds
        exactly ``max_len + 2`` ids.

        Args:
            trees: iterable of trees exposing ``leaves()``.
            max_len: number of word slots per sentence.

        Raises:
            ValueError: if a sentence has more than ``max_len`` words. Such a
                row would be longer than the others and could not be stacked
                into one tensor.
        """
        for tree in tqdm(trees):
            words = tree.leaves()
            if len(words) > max_len:
                raise ValueError(
                    f"sentence has {len(words)} words but max_len is {max_len}"
                )

            token_ids = []
            for word in words:
                pieces = self.tokenizer.encode(word, add_special_tokens=False)
                # A word can tokenize to nothing; fall back to the unknown-token
                # id so the row keeps exactly one id per word.
                token_ids.append(pieces[0] if pieces else self.tokenizer.unk_token_id)

            # TODO: add attention mask
            padding = [self.tokenizer.pad_token_id] * (max_len - len(token_ids))
            token_ids = [self.tokenizer.cls_token_id] + token_ids + padding + [self.tokenizer.sep_token_id]

            yield tree, token_ids

    def __len__(self):
        """Return the number of sentences."""
        return len(self.trees)

    def __getitem__(self, index):
        """Return the ``(token_ids, gold_labels)`` pair stored at ``index``."""
        return self.token_ids[index], self.gold_labels[index]

## model

In [20]:
from itertools import combinations
from torch import nn


class Parser(nn.Module):
    """Label scorer for fencepost spans on top of an encoder run without gradients."""

    def __init__(self, encoder, label_counts=LABEL_COUNTS):
        """Store the encoder and build the scoring MLP.

        Args:
            encoder: module whose ``config.hidden_size`` gives the embedding
                width and whose output exposes ``"last_hidden_state"``.
            label_counts: size of the label dimension of the output layer.
        """
        super().__init__()

        self.encoder = encoder
        # TODO: support loading a fine-tuned checkpoint
        # if True or fine_tuned_model:
        #     config = AutoConfig(base_model)
        #     self.bert = AutoModel.from_config(config)
        #     self.load_state_dict(torch.load(fine_tuned_model))
        # else:
        #     self.bert = AutoModel.from_pretrained(base_model)

        # encoder hidden size -> 300 -> label_counts (768 -> 300 -> 60 in this notebook)
        hidden_dim = 300
        self.label_counts = label_counts
        self.mlp = nn.Sequential(nn.Linear(self.encoder.config.hidden_size,
                                           hidden_dim, bias=False),
                                 nn.LayerNorm(hidden_dim),
                                 nn.ReLU(),
                                 nn.Linear(hidden_dim, label_counts, bias=False)
                                 )

    # TODO: test with a small hand-made array
    def forward(self, token_ids):
        """Score every label for every fencepost span.

        Args:
            token_ids: integer tensor of shape ``(batch, sent_len)``.

        Returns:
            A tuple ``(indexes, labels, max_scores, all_scores)``:
              * ``indexes``: list of ``(start, end)`` fencepost pairs in
                ``itertools.combinations`` order, length ``n_spans``.
              * ``labels``: ``(batch, n_spans)`` index of the best label per span.
              * ``max_scores``: ``(batch, n_spans)`` score of that label.
              * ``all_scores``: ``(batch, n_spans, label_counts)``.

            With ``sent_len == 42`` there are 41 fenceposts and 820 spans.
        """
        with torch.no_grad():
            # NOTE(review): this could arguably be precomputed in the dataset.
            embed = self.encoder(token_ids, return_dict=True)["last_hidden_state"]

            _, sent_len, embed_dim = embed.shape
            # Fencepost k joins the first half of position k with the second
            # half of position k + 1.
            y_left = embed[:, 1:, embed_dim//2:]
            y_right = embed[:, :-1, :embed_dim//2]
            fence = torch.cat([y_right, y_left], dim=2)

        # TODO: handle the true sentence length instead of the padded one.
        indexes = list(combinations(range(sent_len-1), 2))
        # Build all span representations at once and run the MLP a single time.
        # Every layer of the MLP acts on the last dimension only, so this
        # scores the same spans as calling it once per (start, end) pair, and
        # the result lives on the encoder output's device rather than on a
        # notebook-level global.
        span_index = torch.tensor(indexes, dtype=torch.long,
                                  device=fence.device).reshape(-1, 2)
        spans = fence[:, span_index[:, 1], :] - fence[:, span_index[:, 0], :]
        all_scores = self.mlp(spans)

        max_score = torch.max(all_scores, dim=2)
        # The argmax indices are integers and never carry gradients.
        labels = max_score.indices
        max_scores = max_score.values

        return indexes, labels, max_scores, all_scores

    def parse(self):
        """Placeholder; decoding is not implemented yet."""
        pass

## loss

In [26]:
from functools import cache


def convert_index(i, j, n):
    """Map a span ``(i, j)`` to its position in ``combinations(range(n + 1), 2)``.

    Args:
        i: left fencepost, ``0 <= i < j``.
        j: right fencepost, ``j <= n``.
        n: largest fencepost index.

    Returns:
        The zero-based rank of ``(i, j)`` in lexicographic order, as an ``int``.
    """
    assert 0 <= i < j <= n
    # i * (2n - i - 1) is always even (either i or 2n - i - 1 is), so floor
    # division is exact and avoids the precision loss of float division.
    return int(i * (2 * n - i - 1) // 2 + j - 1)


def score_gold_tree(label_scores, max_len=MAX_LEN):
    """Return the best tree score over fenceposts ``0..max_len``.

    ``label_scores`` is indexed as ``label_scores[:, span]`` where ``span``
    comes from ``convert_index``. A span scores its own entry plus the best sum
    of two adjacent sub-spans; the split term starts at zero, so it is never
    negative.
    """
    # TODO: specialise this to the gold tree later.
    @cache
    def best(lo, hi):
        own = label_scores[:, convert_index(lo, hi, max_len)]
        best_split = torch.zeros_like(own, device=device)
        for mid in range(lo + 1, hi):
            candidate = best(lo, mid) + best(mid, hi)
            best_split = torch.maximum(best_split, candidate)
        return own + best_split

    return best(0, max_len)


def score_pred_tree(label_scores, _pred_labels, _gold_labels, max_len=MAX_LEN):
    """Score the best predicted tree over fenceposts ``0..max_len`` with a margin.

    Args:
        label_scores: scores indexed as ``label_scores[:, span]``.
        _pred_labels: predicted labels indexed the same way.
        _gold_labels: gold labels indexed the same way; ``-1`` marks a span
            that has no gold label.
        max_len: index of the last fencepost.

    Returns:
        A tuple ``(score, pred_labels)`` for the root span ``(0, max_len)``:
        the margin-augmented score and the predicted labels of that span.
    """
    @cache
    def score_subtree(i, j):
        single_index = convert_index(i, j, max_len)
        label_max = label_scores[:, single_index]
        # The split term starts at zero, so it is never negative.
        split_max = torch.zeros_like(label_max, device=device)
        for k in range(i+1, j):
            # Only the scores of the two sub-spans are needed here.
            left_score, _ = score_subtree(i, k)
            right_score, _ = score_subtree(k, j)
            split_max = torch.maximum(split_max, left_score + right_score)

        # NOTE(review): the original author asked whether only the root
        # node's margin is returned -- confirm the intended accumulation.
        gold_labels = _gold_labels[:, single_index]
        pred_labels = _pred_labels[:, single_index]
        # Spans without a gold label (-1) contribute no margin.
        margin = torch.where(gold_labels == -1, 0, (pred_labels != gold_labels).to(int))
        return label_max + split_max + margin, pred_labels
    return score_subtree(0, max_len)


def hamming(pred_labels, gold_labels):
    """Count the positions at which the two label tensors differ."""
    mismatches = pred_labels != gold_labels
    return mismatches.sum()


def loss_fn(max_scores, all_scores, pred_labels, gold_labels,
            zero=torch.tensor(0., requires_grad=False).to(device)):
    """Hinge loss between the best predicted tree and the gold tree.

    Args:
        max_scores: ``(batch, n_spans)`` best label score per span.
        all_scores: ``(batch, n_spans, n_labels)`` scores of every label.
        pred_labels: ``(batch, n_spans)`` predicted label per span.
        gold_labels: ``(batch, n_spans)`` gold label per span, ``-1`` where
            the gold tree has no such span.
        zero: scalar zero tensor used as the fill value and the hinge floor.

    Returns:
        A scalar tensor: ``max(0, sum(pred_tree_score - gold_tree_score))``.
    """
    # Pick the score of the gold label on every span in one gather. Spans
    # without a gold label carry -1; clamp them to a valid index here because
    # their value is overwritten with zero right below.
    gold_index = gold_labels.clamp(min=0).long().unsqueeze(-1)
    # NOTE(review): the gold scores are detached, so no gradient flows through
    # the gold tree score -- confirm this is intended.
    gold_scores = all_scores.gather(2, gold_index).squeeze(-1).detach()
    # Zero the spans that have NO gold label and keep the real gold scores.
    # (`gold_scores.where(gold_labels == -1, zero)` did the opposite.)
    gold_scores = torch.where(gold_labels == -1, zero, gold_scores)

    gold_tree_score = score_gold_tree(gold_scores)
    # score_pred_tree returns (score, root_labels); only the score enters the loss.
    pred_tree_score, _ = score_pred_tree(max_scores, pred_labels, gold_labels)
    loss = torch.maximum(zero, torch.sum(pred_tree_score - gold_tree_score))

    return loss

## train

In [22]:
from transformers import AdamW

# Wrap the pretrained encoder in the span-scoring parser and move it to the selected device.
parser = Parser(encoder).to(device)
# parser.parameters() also contains the encoder's weights, but Parser.forward
# runs the encoder under torch.no_grad(), so only the MLP head receives gradients.
# NOTE(review): transformers' AdamW is reportedly deprecated in favour of
# torch.optim.AdamW -- confirm against the installed transformers version.
optimizer = AdamW([{"params": parser.parameters()}], lr=1e-6)

In [15]:
from torch.utils.data import DataLoader

def make_loader(phase, batch_size=50):
    """Build a shuffling DataLoader over the treebank split named ``phase``.

    Trees are read from ``data/<phase>`` and the gold chart from
    ``data/<phase>_chart.pt``.
    """
    tree_path = f"data/{phase}"
    chart_path = f"data/{phase}_chart.pt"
    return DataLoader(
        TreebankDataset(tokenizer, tree_path, chart_path),
        batch_size=batch_size,
        shuffle=True,
    )

# Only the validation split is loaded for now (batch size 500); the train and
# test loaders below are left disabled.
# train = make_loader("train", 500)
valid = make_loader("valid", 500)
# test = make_loader("test")

100%|██████████| 1525/1525 [00:01<00:00, 1108.60it/s]


In [30]:
# Truncate the log file left behind by a previous run.
with open("log.txt", "w") as f:
    print(end="", file=f)

# torch.manual_seed seeds the CPU generator (which drives the DataLoader
# shuffle) as well as the CUDA generators; cuda.manual_seed_all alone left
# the batch order unseeded.
torch.manual_seed(0)
num_epochs = 10
for epoch in range(num_epochs):
    loss_sum = 0.0
    # NOTE: the loop currently iterates over the validation loader.
    for token_ids, gold_labels in tqdm(valid):
        indexes, pred_labels, max_scores, all_scores = parser(token_ids)
        loss = loss_fn(max_scores, all_scores, pred_labels, gold_labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        # Accumulate a plain Python float: adding the tensor itself would keep
        # every batch's autograd graph referenced for the whole epoch.
        loss_sum += loss.item()

    with open("log.txt", "a") as f:
        print(f"epoch: {epoch:2}, loss: {loss_sum:9.2f}, "
              f"ave. loss: {loss_sum/len(valid.dataset):5.3f}",
              file=f)

100%|██████████| 4/4 [00:10<00:00,  2.71s/it]
100%|██████████| 4/4 [00:11<00:00,  2.80s/it]
  0%|          | 0/4 [00:02<?, ?it/s]


KeyboardInterrupt: 