<a href="https://colab.research.google.com/github/ulisesfm-py/XCS224N-Handouts/blob/main/parser_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Dependency Parser visualization

In [None]:
import argparse
import math
import os
import sys
import time
from datetime import datetime

from torch import nn, optim
from tqdm import tqdm

from submission import (AverageMeter, ParserModel, load_and_preprocess_data,
                        minibatches, train)

In [None]:
!conda install pygraphviz

In [None]:
# Notebook stand-ins for the run options: full (non-debug) data, CPU execution.
args_debug, args_device = False, "cpu"

In [None]:
# Returns the parser, the embeddings handed to ParserModel below, and the
# train/dev/test splits.
_loaded = load_and_preprocess_data(args_debug)
parser, embeddings, train_data, dev_data, test_data = _loaded

In [None]:
# Build the network on the chosen device (nn.Module.to moves parameters in
# place and returns the module itself), then attach it to the parser so that
# parser.parse can use it.
model = ParserModel(embeddings).to(args_device)
parser.model = model

model

In [None]:
%%time

output_dir = "run_results_jupyter/{:%Y%m%d_%H%M%S}/".format(datetime.now())
output_path = output_dir + "model.weights"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

train(
    parser,
    train_data,
    dev_data,
    output_path,
    batch_size=1024,
    n_epochs=15,
    lr=0.0005,
    device=args_device,
)

In [None]:
# Run the trained parser over the whole test split and display the result.
test_results = parser.parse(dataset=test_data, device=args_device)
test_results

In [None]:
sample: int = 150  # position in test_data of the sentence inspected below

In [None]:
# Raw example; its "word" entry holds token ids (decoded via parser.id2tok below).
example = test_data[sample]
example

In [None]:
# Pair each token position with its decoded word.
listing = list(
    enumerate([parser.id2tok[w_] for w_ in test_data[sample]["word"]])
)

# Node labels must be unique: the graph cell keys nodes by label, so without a
# suffix every occurrence of a repeated word (e.g. "the") would collapse into
# one node and their edges would merge.  The k-th repeat (k >= 2) becomes
# "word[k]".
_occurrences = {}
nodedict = {}
for i_, w_ in listing:
    _occurrences[w_] = _occurrences.get(w_, 0) + 1
    count = _occurrences[w_]
    nodedict[i_] = w_ if count == 1 else f"{w_}[{count}]"

# Human-readable "index:word" legend, used as the figure title.
legend = " ".join([str(i_) + ":" + w_ for i_, w_ in listing])

legend, nodedict

In [None]:
# Re-parse just this sentence (a one-element slice) and keep its edge list.
# NOTE(review): assumes parse() returns a sequence whose element 1 is a
# per-sentence list of (n1, n2) index pairs; confirm in submission.
_parse_out = parser.parse(
    dataset=test_data[sample : sample + 1], device=args_device
)
edges = _parse_out[1][0]
edges

In [None]:
import networkx as nx

# One node per token label, one directed edge per (n1, n2) index pair from the
# parser.  The add_*_from calls replace list comprehensions that were used only
# for their side effects (and built throw-away lists of None).
G = nx.DiGraph()
G.add_nodes_from(nodedict[i_] for i_, _ in listing)
G.add_edges_from((nodedict[n1_], nodedict[n2_]) for n1_, n2_ in edges)
list(G.edges())

In [None]:
import textwrap

import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout


# Lay the graph out top-down with Graphviz "dot" and draw it.
pos = graphviz_layout(G, prog="dot", root=0)
nx.draw(
    G, with_labels=True, font_weight="bold", pos=pos, node_color="lightblue"
)
# Escape "$" before using the sentence as the title.  matplotlib renders any
# string with an even number of unescaped "$" as mathtext, so a sentence with
# two dollar tokens would be garbled or fail to render.  Wrapping at 60
# characters keeps a long sentence from running off the figure.
plt.title(textwrap.fill(legend, width=60).replace("$", r"\$"))
plt.show()

In [None]:
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout
import networkx as nx
import textwrap


def _unique_labels(words):
    """Return one display label per word, unique within the sentence.

    The first occurrence of a word keeps its text; the k-th repeat (k >= 2)
    becomes ``"word[k]"``.  Graph nodes are keyed by label, so without this
    every occurrence of e.g. "the" would collapse into a single node.
    """
    counts = {}
    labels = []
    for w_ in words:
        counts[w_] = counts.get(w_, 0) + 1
        labels.append(w_ if counts[w_] == 1 else f"{w_}[{counts[w_]}]")
    return labels


def render_sample(parser, dataset, sample, device):
    """Draw the dependency graph the parser predicts for one example.

    Parses ``dataset[sample]`` on its own and turns every predicted
    ``(n1, n2)`` index pair into a directed edge ``n1 -> n2`` between word
    nodes.  The graph is shown with matplotlib, with the (wrapped) sentence
    as the title.

    Args:
        parser: Object exposing ``id2tok`` (token id -> word) and
            ``parse(dataset=..., device=...)``.
        dataset: Indexable, sliceable collection of examples; each example's
            ``"word"`` entry is a sequence of token ids.
        sample: Index of the example in ``dataset`` to render.
        device: Device forwarded to ``parser.parse``.

    Returns:
        None.  The figure is displayed via ``plt.show()``.

    The Graphviz "dot" layout is preferred.  If it raises ``ImportError``
    (e.g. pygraphviz is not installed) or ``TypeError``, the graph is drawn
    with networkx's default layout instead.
    """
    words = [parser.id2tok[w_] for w_ in dataset[sample]["word"]]

    # Parse a one-element slice; [1][0] picks that single sentence's edges.
    # NOTE(review): assumes element 1 of parse()'s return value is the list
    # of per-sentence (n1, n2) index pairs; confirm in submission.
    edges = parser.parse(
        dataset=dataset[sample : sample + 1], device=device
    )[1][0]

    labels = _unique_labels(words)
    nodedict = dict(enumerate(labels))  # token position -> node label
    legend = " ".join(labels)

    G = nx.DiGraph()
    G.add_nodes_from(labels)
    G.add_edges_from((nodedict[n1_], nodedict[n2_]) for n1_, n2_ in edges)

    draw_kwargs = dict(
        with_labels=True,
        font_weight="bold",
        node_color="lightblue",
        font_size=8,
    )
    try:
        pos = graphviz_layout(G, prog="dot", root=0)
        nx.draw(G, pos=pos, **draw_kwargs)
    except (ImportError, TypeError):
        # graphviz_layout raises ImportError when pygraphviz is missing; fall
        # back to networkx's built-in layout rather than failing the cell.
        nx.draw(G, **draw_kwargs)

    # Raw string: "\$" is an invalid escape sequence in a normal literal.
    # Escaping "$" stops matplotlib from parsing dollar tokens as mathtext.
    plt.title(
        textwrap.fill(legend, width=60).replace("$", r"\$"),
        fontweight="bold",
    )
    plt.show()

In [None]:
# Test sentence #5.
render_sample(parser, test_data, 5, args_device)

In [None]:
# Test sentence #210.
render_sample(parser, test_data, 210, args_device)

In [None]:
# Test sentence #291.
render_sample(parser, test_data, 291, args_device)

In [None]:
# Test sentence #667.
render_sample(parser, test_data, 667, args_device)

In [None]:
# Test sentence #999.
render_sample(parser, test_data, 999, args_device)