Neural Network Dependency Parser Demo

This notebook demonstrates how to use your trained dependency parser to analyze sentences and visualize dependency trees.

1. Setup and Imports

In [None]:
import sys
import os
import torch
import pickle
import spacy
import matplotlib.pyplot as plt
import networkx as nx

sys.path.append(os.path.abspath('.'))

2. Load Model and Vocabularies

In [None]:
def load_model_and_vocabs(proc_dir=os.path.join('data', 'processed'), model_path='best_model.pt'):
    """Load the trained dependency parser together with its vocabularies.

    Args:
        proc_dir: Directory containing ``word_vocab.pkl``, ``pos_vocab.pkl``
            and ``label_vocab.pkl``. Defaults to ``data/processed``.
        model_path: Path of the saved ``state_dict`` checkpoint. Defaults to
            ``best_model.pt`` in the current working directory.

    Returns:
        Tuple ``(model, word_vocab, pos_vocab, label_vocab, device)`` where
        ``model`` is in eval mode and already moved to ``device``.

    Raises:
        FileNotFoundError: If a vocabulary file or the checkpoint is missing.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _load_vocab(filename):
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only point proc_dir at files you produced yourself.
        with open(os.path.join(proc_dir, filename), 'rb') as f:
            return pickle.load(f)

    word_vocab = _load_vocab('word_vocab.pkl')
    pos_vocab = _load_vocab('pos_vocab.pkl')
    label_vocab = _load_vocab('label_vocab.pkl')

    # Project-local import, kept inside the function so that it is resolved
    # only when the model is actually loaded.
    from models.parser import DependencyParser

    # NOTE(review): the embedding / LSTM sizes are presumably the values used
    # at training time; load_state_dict below will reject a checkpoint whose
    # parameter shapes do not match.
    model = DependencyParser(
        vocab_sizes={'word': len(word_vocab), 'pos': len(pos_vocab)},
        emb_dims={'word': 100, 'pos': 32},
        lstm_dim=256,
        num_labels=len(label_vocab)
    ).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model, word_vocab, pos_vocab, label_vocab, device

# Load once at kernel level; the example and interactive cells below read
# these globals (model, vocabularies, device) directly.
model, word_vocab, pos_vocab, label_vocab, device = load_model_and_vocabs()
print("Model loaded successfully!")

3. Helper Functions

In [None]:
def _get_spacy_pipeline(model_name="en_core_web_sm"):
    """Return the spaCy pipeline ``model_name``, loading it at most once."""
    # The cache lives on the function object so no extra global is needed.
    cache = _get_spacy_pipeline.__dict__.setdefault("_pipelines", {})
    if model_name not in cache:
        cache[model_name] = spacy.load(model_name)
    return cache[model_name]


def tokenize_sentence(sentence):
    """Tokenize and POS-tag a raw sentence with spaCy.

    The ``en_core_web_sm`` pipeline is loaded on first use and reused on
    later calls instead of being reloaded from disk for every sentence.

    Args:
        sentence: Raw input text.

    Returns:
        Tuple ``(words, pos_tags)`` of two equal-length lists: the token
        texts and each token's ``pos_`` tag.
    """
    # NOTE(review): the tags come from spaCy's ``Token.pos_``; confirm this is
    # the same tag set the parser's pos_vocab was built from.
    doc = _get_spacy_pipeline()(sentence)
    words = [token.text for token in doc]
    pos_tags = [token.pos_ for token in doc]
    return words, pos_tags

def predict_dependencies(model, words, pos_tags, word_vocab, pos_vocab, label_vocab, device):
    """Predict a head index and a label index for each token of one sentence.

    Args:
        model: Callable taking ``(word_tensor, pos_tensor)`` and returning
            ``(head_scores, label_scores)``.
        words: List of token strings.
        pos_tags: List of POS tag strings, parallel to ``words``.
        word_vocab: Mapping with ``.get`` and an ``'<unk>'`` entry for words.
        pos_vocab: Mapping with ``.get`` and an ``'<unk>'`` entry for POS tags.
        label_vocab: Unused here; kept so the signature matches the callers.
        device: Device the input tensors are created on.

    Returns:
        Tuple ``(pred_heads, pred_labels)`` of 1-D NumPy integer arrays with
        one entry per token.
    """
    # Out-of-vocabulary words / tags fall back to the '<unk>' index.
    word_idx = [word_vocab.get(w, word_vocab['<unk>']) for w in words]
    pos_idx = [pos_vocab.get(p, pos_vocab['<unk>']) for p in pos_tags]
    # Wrap in an outer list to create a batch of size 1.
    word_tensor = torch.tensor([word_idx], dtype=torch.long, device=device)
    pos_tensor = torch.tensor([pos_idx], dtype=torch.long, device=device)

    with torch.no_grad():
        head_scores, label_scores = model(word_tensor, pos_tensor)

        # Keep the batch dimension until after the gather. The index passed to
        # gather must have the same rank as the permuted label_scores, and
        # expand() cannot use -1 for a dimension that does not exist yet, so
        # squeezing the heads first makes the expand call fail.
        batch_heads = head_scores.argmax(-1)

        # NOTE(review): the indexing below treats label_scores as
        # (batch, num_labels, dependent, head) - inferred from the
        # permute/gather pattern; confirm against the model definition.
        num_labels = label_scores.size(1)
        gather_index = batch_heads.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, num_labels)
        # For each token, keep only the label scores of its predicted head.
        scores_at_head = label_scores.permute(0, 2, 3, 1).gather(2, gather_index).squeeze(2)

        pred_heads = batch_heads.squeeze(0)
        pred_labels = scores_at_head.argmax(-1).squeeze(0)

    return pred_heads.cpu().numpy(), pred_labels.cpu().numpy()

def format_dependency_tree(words, pos_tags, heads, labels, label_vocab):
    """Print the parse as a fixed-width table, one row per token.

    Args:
        words: Token strings.
        pos_tags: POS tag strings, parallel to ``words``.
        heads: Predicted head index per token; an index that is not smaller
            than ``len(words)`` is displayed as ``ROOT``.
        labels: Predicted label index per token; an index that is not smaller
            than ``len(label_vocab.itos)`` is displayed as ``UNK``.
        label_vocab: Object whose ``itos`` sequence maps label index to name.
    """
    rule = "-" * 50
    print("\nDependency Tree:")
    print(rule)
    print(f"{'ID':<3} {'Word':<15} {'POS':<8} {'Head':<8} {'Label':<15}")
    print(rule)

    rows = zip(words, pos_tags, heads, labels)
    for row_id, (token, tag, head_idx, label_idx) in enumerate(rows):
        # Resolve the head index to the head's surface form.
        if head_idx < len(words):
            head_word = words[head_idx]
        else:
            head_word = "ROOT"

        # Resolve the label index to its string name.
        if label_idx < len(label_vocab.itos):
            label_name = label_vocab.itos[label_idx]
        else:
            label_name = "UNK"

        print(f"{row_id:<3} {token:<15} {tag:<8} {head_word:<8} {label_name:<15}")

def visualize_dependency_tree(words, pos_tags, heads, labels, label_vocab, title="Dependency Tree", seed=42):
    """Draw the predicted parse as a directed graph with head -> dependent edges.

    Args:
        words: Token strings; each becomes one node.
        pos_tags: POS tag strings, parallel to ``words``.
        heads: Predicted head index per token. Indices outside
            ``[0, len(words))`` produce no edge.
        labels: Predicted label index per token; indices beyond
            ``label_vocab.itos`` are shown as ``UNK``.
        label_vocab: Object whose ``itos`` sequence maps label index to name.
        title: Figure title.
        seed: Seed for the spring layout so that re-running the cell gives
            the same picture. Pass ``None`` for a random layout.
    """
    num_tokens = len(words)
    graph = nx.DiGraph()
    for node_id, (word, tag) in enumerate(zip(words, pos_tags)):
        graph.add_node(node_id, word=word, pos=tag)

    for dependent, (head, label) in enumerate(zip(heads, labels)):
        # Skip heads that do not refer to a token; adding them would create
        # nodes without 'word'/'pos' attributes and break the labelling below.
        if 0 <= head < num_tokens:
            label_name = label_vocab.itos[label] if label < len(label_vocab.itos) else "UNK"
            # int() so NumPy integer heads become plain Python node keys.
            graph.add_edge(int(head), dependent, label=label_name)

    # Named `layout`, not `pos`, to avoid confusion with the POS tags above.
    layout = nx.spring_layout(graph, k=3, iterations=50, seed=seed)

    plt.figure(figsize=(12, 8))
    # with_labels=False: the node captions are drawn explicitly below. Leaving
    # it on would print the integer node ids underneath the word labels.
    nx.draw(graph, layout, with_labels=False, node_color='lightblue',
            node_size=2000, arrows=True, arrowstyle='->', arrowsize=20)

    edge_labels = nx.get_edge_attributes(graph, 'label')
    nx.draw_networkx_edge_labels(graph, layout, edge_labels=edge_labels, font_size=8)

    node_labels = {
        node: f"{graph.nodes[node]['word']}\n({graph.nodes[node]['pos']})"
        for node in graph.nodes()
    }
    nx.draw_networkx_labels(graph, layout, node_labels, font_size=8)

    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

4. Example Sentence Test

In [None]:
# Fixed demo inputs; edit this list to try other sentences.
test_sentences = [
    "The cat sat on the mat.",
    "I love neural networks.",
    "She quickly ran to the store.",
    "The beautiful red car drove fast."
]

# For each sentence: tokenize, predict heads/labels, then show the result both
# as a text table and as a graph. Numbering starts at 1 for display.
for i, sentence in enumerate(test_sentences, 1):
    print(f"\n{'='*60}")
    print(f"Sentence {i}: {sentence}")
    print(f"{'='*60}")
    words, pos_tags = tokenize_sentence(sentence)
    print(f"Tokens: {words}")
    print(f"POS tags: {pos_tags}")
    # Uses the model, vocabularies and device loaded in the earlier cell.
    heads, labels = predict_dependencies(model, words, pos_tags, word_vocab, pos_vocab, label_vocab, device)
    format_dependency_tree(words, pos_tags, heads, labels, label_vocab)
    visualize_dependency_tree(words, pos_tags, heads, labels, label_vocab, f"Sentence {i}: {sentence}")

5. Interactive Parsing

In [None]:
def parse_custom_sentence():
    """Prompt for one sentence, parse it, and display the table and graph.

    Input that is empty or whitespace-only is ignored. Relies on the
    notebook-level ``model``, vocabularies and ``device``.
    """
    sentence = input("Enter a sentence to parse: ")
    if not sentence.strip():
        # Nothing to parse.
        return

    tokens, tags = tokenize_sentence(sentence)
    head_ids, label_ids = predict_dependencies(
        model, tokens, tags, word_vocab, pos_vocab, label_vocab, device
    )
    format_dependency_tree(tokens, tags, head_ids, label_ids, label_vocab)
    figure_title = f"Custom: {sentence}"
    visualize_dependency_tree(tokens, tags, head_ids, label_ids, label_vocab, figure_title)

# Uncomment to use:
# parse_custom_sentence()