## CS310 Natural Language Processing
## Assignment 4. Dependency Parsing

**Total points**: 50

In this assignment, you will train feed-forward neural network-based dependency parser and evaluate its performance on the provided treebank dataset.

### 0. Import Necessary Libraries

In [10]:
import torch
import torch.nn as nn
from dep_utils import conll_reader, DependencyTree
import copy
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import random
from tqdm import tqdm

### 1. Read Data and Generate Training Instances

In [11]:
# Re-use the code from Lab 7
class RootDummy(object):
    def __init__(self):
        self.head = None
        self.id = 0
        self.deprel = None

    def __repr__(self):
        return "<ROOT>"


class State(object):
    def __init__(self, sentence=[]):
        self.stack = []
        self.buffer = []
        if sentence:
            self.buffer = list(reversed(sentence))
        self.deps = set()

    def shift(self):
        ### START YOUR CODE ###
        self.stack.append(self.buffer.pop())
        ### END YOUR CODE ###

    def left_arc(self, label: str):
        assert len(self.stack) >= 2
        ### START YOUR CODE ###
        self.deps.add((self.stack[-1], self.stack[-2], label))
        self.stack.pop(-2)
        ### END YOUR CODE ###

    def right_arc(self, label: str):
        assert len(self.stack) >= 2
        ### START YOUR CODE ###
        self.deps.add((self.stack[-2], self.stack[-1], label))
        self.stack.pop(-1)
        ### END YOUR CODE ###

    def __repr__(self):
        return "({},{},{})".format(self.stack, self.buffer, self.deps)


def get_training_instances(dep_tree) -> List[Tuple[State, Tuple[str, str]]]:
    deprels = dep_tree.deprels

    word_ids = list(deprels.keys())
    state = State(word_ids)
    state.stack.append(0)  # ROOT

    childcount = defaultdict(int)
    for _, rel in deprels.items():
        childcount[rel.head] += 1

    seq = []
    while len(state.buffer) > 0 or len(state.stack) > 1:
        if state.stack[-1] == 0:
            seq.append((copy.deepcopy(state), ("shift", None)))
            state.shift()
            continue

        stack_top1 = deprels[state.stack[-1]]
        if state.stack[-2] == 0:
            stack_top2 = RootDummy()
        else:
            stack_top2 = deprels[state.stack[-2]]

        # Decide transition action
        if stack_top2.head == stack_top1.id:  # Left-Arc
            seq.append((copy.deepcopy(state), ("left_arc", stack_top2.deprel)))
            state.left_arc(stack_top2.deprel)
            childcount[stack_top1.id] -= 1
        elif stack_top1.head == stack_top2.id:  # Right-Arc
            if childcount[stack_top1.id] != 0:
                seq.append((copy.deepcopy(state), ("shift", None)))
                state.shift()
            else:
                seq.append((copy.deepcopy(state), ("right_arc", stack_top1.deprel)))
                state.right_arc(stack_top1.deprel)
                childcount[stack_top2.id] -= 1
        else:  # Shift
            seq.append((copy.deepcopy(state), ("shift", None)))
            state.shift()

        # print(seq[-1])

    seq.append((copy.deepcopy(state), ("done", None)))

    return seq


def process(
    dep_trees: List[DependencyTree],
    word_vocab: dict, word_vectors: dict,
    pos_vocab: dict, pos_vectors: dict,
    action_vocab: dict) -> torch.Tensor:
    tensor_data = []
    tensor_truth = []
    for tree in dep_trees:
        instances = get_training_instances(tree)
        for state, action in instances:
            if action[0] == "done":
                continue
            # convert to torch tensor and append to tensor_data
            # use stack[-3:] and buffer[-3:] to get the top 3 elements of the stack and buffer
            # if the stack or buffer has less than 3 elements, use the <NULL> to pad
            stack = state.stack[-3:] if len(state.stack) >= 3 else state.stack + [-1] * (3 - len(state.stack))
            buffer = state.buffer[-3:] if len(state.buffer) >= 3 else state.buffer + [-1] * (3 - len(state.buffer))
            # use word_vocab, pos_vocab to convert the word, pos, action to index
            stack_idxes = []
            stack_pos_idxes = []
            buffer_idxes = []
            buffer_pos_idxes = []
            for s in stack:
                if s == -1:
                    stack_idxes.extend(word_vectors[word_vocab["<NULL>"]])
                    stack_pos_idxes.extend(pos_vectors[pos_vocab["<NULL>"]])
                elif s == 0:
                    stack_idxes.extend(word_vectors[word_vocab["<ROOT>"]])
                    stack_pos_idxes.extend(pos_vectors[pos_vocab["<NONE>"]])
                else:
                    stack_idxes.extend(
                        word_vectors[
                            word_vocab.get(tree.deprels[s].word, word_vocab["<NULL>"])
                        ]
                    )
                    stack_pos_idxes.extend(
                        pos_vectors[
                            pos_vocab.get(tree.deprels[s].pos, pos_vocab["<NULL>"])
                        ]
                    )
            for b in buffer:
                if b == -1:
                    buffer_idxes.extend(word_vectors[word_vocab["<NULL>"]])
                    buffer_pos_idxes.extend(pos_vectors[pos_vocab["<NULL>"]])
                elif b == 0:
                    buffer_idxes.extend(word_vectors[word_vocab["<ROOT>"]])
                    buffer_pos_idxes.extend(pos_vectors[pos_vocab["<NONE>"]])
                else:
                    buffer_idxes.extend(
                        word_vectors[
                            word_vocab.get(tree.deprels[b].word, word_vocab["<NULL>"])
                        ]
                    )
                    buffer_pos_idxes.extend(
                        pos_vectors[
                            pos_vocab.get(tree.deprels[b].pos, pos_vocab["<NULL>"])
                        ]
                    )
            # concatenate all index to get word vectors
            data_vector = torch.tensor(stack_idxes + buffer_idxes + stack_pos_idxes + buffer_pos_idxes)
            tensor_data.append(data_vector)
            tensor_truth.append(action_vocab[action])

    return (torch.stack(tensor_data), torch.tensor(tensor_truth))

### 2. Build the Model

In [12]:
class Parser(nn.Module):
    def __init__(self, vec_dim, hidden_dim, action_size, dropout=0.2):
        super(Parser, self).__init__()
        self.input_dim = vec_dim
        self.hidden_dim = hidden_dim
        self.output_dim = action_size
        self.dropout = dropout
        self.W_1w = nn.Linear(self.input_dim, self.hidden_dim)
        self.W_1t = nn.Linear(self.input_dim, self.hidden_dim)
        self.W_2 = nn.Linear(self.hidden_dim, self.output_dim)
        self.ReLU = nn.ReLU()
        # self.softmax = nn.Softmax()

    def forward(self, x):
        return self.W_2(self.ReLU(self.W_1w(x[:x.shape[0]//2]) + self.W_1t(x[x.shape[0]//2:])))

    def parse_sentence(
        self, sentence,
        word_vocab: dict, word_vectors: dict,
        pos_vocab: dict, pos_vectors: dict,
        action_vocab: dict,
        device: torch.device
    ):
        state = State(sentence)
        state.stack.append(0)
        while len(state.buffer) > 0 or len(state.stack) > 1:
            stack = state.stack[-3:] if len(state.stack) >= 3 else state.stack + [-1] * (3 - len(state.stack))
            buffer = state.buffer[-3:] if len(state.buffer) >= 3 else state.buffer + [-1] * (3 - len(state.buffer))
            stack_idxes = []
            stack_pos_idxes = []
            buffer_idxes = []
            buffer_pos_idxes = []
            for s in stack:
                if s == -1:
                    stack_idxes.extend(word_vectors[word_vocab["<NULL>"]])
                    stack_pos_idxes.extend(pos_vectors[pos_vocab["<NULL>"]])
                elif s == 0:
                    stack_idxes.extend(word_vectors[word_vocab["<ROOT>"]])
                    stack_pos_idxes.extend(pos_vectors[pos_vocab["<NONE>"]])
                else:
                    stack_idxes.extend(
                        word_vectors[
                            word_vocab.get(sentence[s].word, word_vocab["<NULL>"])
                        ]
                    )
                    stack_pos_idxes.extend(
                        pos_vectors[pos_vocab.get(sentence[s].pos, pos_vocab["<NULL>"])]
                    )
            for b in buffer:
                if b == -1:
                    buffer_idxes.extend(word_vectors[word_vocab["<NULL>"]])
                    buffer_pos_idxes.extend(pos_vectors[pos_vocab["<NULL>"]])
                elif b == 0:
                    buffer_idxes.extend(word_vectors[word_vocab["<ROOT>"]])
                    buffer_pos_idxes.extend(pos_vectors[pos_vocab["<NONE>"]])
                else:
                    buffer_idxes.extend(
                        word_vectors[
                            word_vocab.get(sentence[b].word, word_vocab["<NULL>"])
                        ]
                    )
                    buffer_pos_idxes.extend(
                        pos_vectors[pos_vocab.get(sentence[b].pos, pos_vocab["<NULL>"])]
                    )
            data_vector = torch.tensor(
                stack_idxes + buffer_idxes + stack_pos_idxes + buffer_pos_idxes
            ).to(device)
            actions = torch.argsort(self.forward(data_vector))
            action_idx = 0
            while action_idx < len(actions):
                action = action_vocab[actions[action_idx].item()]
                if action[0] == "shift":
                    if len(state.buffer) == 0:
                        action_idx += 1
                        continue
                    state.shift()
                    break
                elif action[0] == "left_arc":
                    if len(state.stack) < 2 or state.stack[-2] == 0:
                        action_idx += 1
                        continue
                    state.left_arc(action[1])
                    break
                elif action[0] == "right_arc":
                    if len(state.stack) < 2 or (state.stack[-2] == 0 and len(state.buffer) > 0):
                        action_idx += 1
                        continue
                    state.right_arc(action[1])
                    break
                action_idx += 1
        return state.deps

In [13]:
print("In train.conll:")
with open("data/train.conll") as f:
    train_trees = list(conll_reader(f))
print(f"{len(train_trees)} trees read.")

print("In dev.conll:")
with open("data/dev.conll") as f:
    dev_trees = list(conll_reader(f))
print(f"{len(dev_trees)} trees read.")

print("In test.conll:")
with open("data/test.conll") as f:
    test_trees = list(conll_reader(f))
print(f"{len(test_trees)} trees read.")

rel_counter = Counter()
for tree in train_trees:
    for item in tree.deprels.values():
        rel_counter[item.deprel] += 1

print(f"Found {len(rel_counter)} unique dependency relations in the training set.")

In train.conll:
39832 trees read.
In dev.conll:
1700 trees read.
In test.conll:
2416 trees read.
Found 39 unique dependency relations in the training set.


In [14]:
# load word embeddings
emb_path = "./data/glove.6B.50d.txt"
word_vocab = {"<NULL>": -1, "<ROOT>": 0}
word_vectors = [[-0.01] * 50, [0] * 50]
with open(emb_path) as f:
    for line in f:
        parts = line.strip().split()
        word = parts[0]
        vector = list(map(float, parts[1:]))
        word_vocab[word] = len(word_vocab)
        word_vectors.append(vector)

# word_vocab = {"<NULL>": -1, "<ROOT>": 0}
# random pos embeddings within (−0.01, 0.01)
pos_vocab = {"<NULL>": -1, "<NONE>": 0}
pos_vectors = [[-0.01] * 50, [0] * 50]

for tree in train_trees:
    for pos in tree.pos():
        if pos not in pos_vocab:
            pos_vocab[pos] = len(pos_vocab)
            rand_vec = [random.uniform(-0.01, 0.01) for _ in range(50)]
            while rand_vec in pos_vectors:
                rand_vec = [random.uniform(-0.01, 0.01) for _ in range(50)]
            pos_vectors.append(rand_vec)

# build action vocab
action_vocab = {}
action_rev_vocab = {}
action_vocab[("right_arc", "root")] = 0
action_vocab[("shift", None)] = 1
for rel in rel_counter.keys():
    if rel == "root":
        continue
    action_vocab[("left_arc", rel)] = len(action_vocab)
    action_vocab[("right_arc", rel)] = len(action_vocab)

for k, v in action_vocab.items():
    action_rev_vocab[v] = k

print(f"Word vocab size: {len(word_vocab)}")
print(f"POS vocab size: {len(pos_vocab)}")
print(f"Action vocab size: {len(action_vocab)}")

Word vocab size: 400002
POS vocab size: 48
Action vocab size: 78
[[-0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01, -0.01], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0.418, 0.24968, -0.41242, 0.1217, 0.34527, -0.044457, -0.49688, -0.17862, -0.00066023, -0.6566, 0.27843, -0.14767, -0.55677, 0.14658, -0.0095095, 0.011658, 0.10204, -0.12792, -0.8443, -0.12181, -0.016801, -0.33279, -0.1552, -0.23131, -0.19181, -1.8823, -0.76746, 0.099051, -0.42125, -0.19526, 4.0071, -0.18594, -0.52287, -0.31681, 0.00059213, 0.0074449, 0.17778, -0.15897, 0.012041, -0.054223, -0.29871, -0.15749, -0.34758, -0.04

In [15]:
train_data = []
train_truth = []
train_data, train_truth = process(
    train_trees,
    word_vocab,
    word_vectors,
    pos_vocab,
    pos_vectors,
    action_vocab,
)

IndexError: pop from empty list

In [None]:
test_data = []
test_truth = []
test_data, test_truth = process(
    test_trees,
    word_vocab,
    word_vectors,
    pos_vocab,
    pos_vectors,
    action_vocab,
)

### 3. Train and Evaluate

In [None]:
# train
model = Parser(300, 128, 78)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
epochs = 1
display_step = 1000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
model.to(device)
train_data = train_data.to(device)
train_truth = train_truth.to(device)

for epoch in range(epochs):
    total_loss = 0
    pbar = tqdm(range(0, len(train_data)))
    for i in pbar:
        optimizer.zero_grad()
        output = model(torch.Tensor(train_data[i]))
        loss = loss_fn(output, train_truth[i])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if i % display_step == 0:
            pbar.set_description(f"loss: {total_loss / display_step}")
            total_loss = 0

In [None]:
# evaluate
# example = model.parse_sentence(test_trees[0].deprels, word_vocab, word_vectors, pos_vocab, pos_vectors, action_rev_vocab, device)
# print(example)
# example_truth = set([(x.head, x.id, x.deprel) for x in test_trees[0].deprels.values()])
# print(example_truth)

las_correct = 0
uas_correct = 0
total = 0

for tree in test_trees[:10]:
    example = model.parse_sentence(tree.deprels, word_vocab, word_vectors, pos_vocab, pos_vectors, action_rev_vocab, device)
    example_truth = set([(x.head, x.id, x.deprel) for x in tree.deprels.values()])
    for edge in example:
        if edge in example_truth:
            las_correct += 1
        if (edge[0], edge[1]) in [(x.head, x.id) for x in tree.deprels.values()]:
            uas_correct += 1
        total += 1

print(f"LAS: {las_correct / total}")
print(f"UAS: {uas_correct / total}")