In [None]:
from copy import deepcopy
from random import shuffle, seed

import spacy
from spacy import displacy
from datasets import load_dataset
from tqdm import tqdm

# Small English spaCy pipeline; Root.from_sentence uses it to parse both the
# original and the compressed sentence of every pair.
nlp = spacy.load("en_core_web_sm")

In [None]:
# Deterministic 70 / 10 / 20 train / dev / test split of the (sentence, compressed) pairs.
seed('482')
raw_splits = dict(load_dataset("embedding-data/sentence-compression"))
pairs = [row['set'] for row in raw_splits['train']]
shuffle(pairs)

total = len(pairs)
train_count, dev_count = int(total * 0.7), int(total * 0.8)
dataset_1 = {
    'train': pairs[:train_count],
    'dev': pairs[train_count:dev_count],
    'test': pairs[dev_count:],
}

print('Dataset 1 - embedding-data/sentence-compression')
for caption, split in (('Train', 'train'), ('Development', 'dev'), ('Test', 'test')):
    print(f'\t{caption}:', len(dataset_1[split]))



  0%|          | 0/1 [00:00<?, ?it/s]

Dataset 1 - embedding-data/sentence-compression
	Train: 125999
	Development: 18001
	Test: 36000


In [None]:
# Feature layout produced by Node.get_features, one label per position:
#   'tag', then depa_1 .. depa_{ABOVE_PADDING-1} (own dep, then the deps above it),
#   then depb_{ABOVE_PADDING} .. depb_{BELOW_PADDING-1} (deps of the direct children).
# NOTE(review): get_features reads features[ix] for EVERY label, so with
# ENABLE_PADDING = False any node with fewer than len(LABELS) features raises
# IndexError; padding effectively has to stay enabled.
ENABLE_PADDING = True
ABOVE_PADDING = 250
BELOW_PADDING = ABOVE_PADDING + 100

LABELS = (
    ['tag']
    + [f'depa_{pos}' for pos in range(1, ABOVE_PADDING)]
    + [f'depb_{pos}' for pos in range(ABOVE_PADDING, BELOW_PADDING)]
)

In [None]:
def save_svg(svg, path='image.svg'):
    """Write an SVG markup string (e.g. displacy output) to a file.

    Args:
        svg: SVG document as a string.
        path: destination file; defaults to 'image.svg' in the working directory.
    """
    # Explicit UTF-8: the platform default encoding may not be able to represent
    # non-ASCII sentence text embedded in the SVG.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(svg)

In [None]:
def label_data(data_input):
    """Parse, collapse and featurize every (sentence, compressed) pair.

    For each pair a Root is built, its sentence tree is collapsed, a fake ROOT
    node is added, and per-node features are extracted with the global LABELS.

    Args:
        data_input: iterable of (sentence, compressed) pairs.

    Returns:
        dict mapping each Root to the list of feature dicts returned by
        ``Root.get_features(LABELS)``; every dict also carries 'idx' and 'stem'.
    """
    labeled = {}
    print('\nLabeling data')
    for sentence, compressed in tqdm(data_input):
        root = Root.from_sentence(sentence, compressed)
        root.collapse()
        root.add_fake()
        # Root defines no __eq__/__hash__, so every Root is a distinct key
        labeled[root] = root.get_features(LABELS)
    return labeled


def get_train_data(data_input):
    """Flatten labeled trees into (features, keep) pairs for an NLTK classifier.

    ``{root: [node_features, ...]} -> [(features, keep), ...]``

    Args:
        data_input: mapping of Root -> list of per-node feature dicts, as returned
            by label_data. It is not modified, so it can be flattened repeatedly.

    Returns:
        list of (features, keep) tuples. ``features`` is the node's feature dict
        without the 'idx' / 'stem' bookkeeping keys. ``keep`` is True when every
        stem merged into the node also appears in the compressed sentence's tree.
    """
    out = []
    for root, node_features in data_input.items():
        for node in node_features:
            # The fake ROOT (idx -1) is not a real word, so it is never an example
            if node['idx'] == -1:
                continue
            keep = all(root.comp_node.contains(stem) for stem in node['stem'])
            # Copy instead of popping the bookkeeping keys: popping corrupted the
            # labeled data and made a second call fail with KeyError: 'idx'
            features = {key: value for key, value in node.items() if key not in ('idx', 'stem')}
            out.append((features, keep))
    return out

In [None]:
class Node:
    """Mutable dependency-tree node built from a spaCy token.

    Besides plain traversal the node supports:

    * ``collapse`` / ``unfurl``: merge function-word children (see COLLAPSE_DEPS)
      into their head, and later undo that merge;
    * ``get_features``: per-node feature dicts for the keep/drop classifier;
    * ``prune``: drop the subtrees the classifier rejected.
    """

    # Dependency labels whose nodes collapse() merges into their parent
    COLLAPSE_DEPS = ('pobj', 'det', 'case', 'neg', 'auxpass', 'aux', 'poss')

    @staticmethod
    def parse_spacy(spacy_node):
        """Recursively convert a spaCy token and its syntactic children into Nodes."""
        node = Node()
        node.text = spacy_node.text
        # NOTE(review): the names look swapped. ``pos`` holds spaCy's fine-grained
        # ``tag_`` and ``tag`` holds the coarse ``pos_``. The rest of this class
        # (e.g. get_tagged_in_tree('VERB')) reads ``tag`` as the coarse POS, so the
        # assignment is kept as-is.
        node.pos = spacy_node.tag_
        node.tag = spacy_node.pos_
        node.idx = spacy_node.idx
        node.dep = spacy_node.dep_
        node.stem = [spacy_node.lemma_]
        node.children = [Node.parse_spacy(child) for child in spacy_node.children]
        for child in node.children:
            child.parent = node
        return node

    @staticmethod
    def make_fake_root(existing_root):
        """Create a synthetic ROOT (idx -1) above ``existing_root`` and every VERB in its tree.

        The verbs keep their real parent as well, so they are reachable twice from
        the returned node: once through the real tree and once directly under it.
        """
        node = Node()
        node.text = 'root'
        node.pos = 'ROOT'
        node.tag = 'ROOT'
        node.idx = -1  # sorts before every real character offset
        node.dep = 'root'  # not a real dependency label; should never be used
        node.children = [existing_root]
        existing_root.fake_parents.append(node)
        for verb in existing_root.get_tagged_in_tree('VERB'):
            if verb is existing_root:
                continue  # already linked above
            verb.fake_parents.append(node)
            node.children.append(verb)
        return node

    def __init__(self):
        self.text = ''          # current (possibly collapsed) text
        self._text = None       # text before the first collapse; None when not collapsed
        self.pos = ''
        self.tag = ''
        self.dep = ''           # dependency label leading to this node
        self._dep = None        # original 'prep' label while dep holds the preposition word
        self.idx = 0            # this node's character offset in the sentence
        self.stem = []
        self._stem = None       # stems before the first collapse
        self.parent = None      # the primary (spaCy) parent, used for traversal/arcs
        self.fake_parents = []  # synthetic ROOT parents added by make_fake_root
        self.children = []
        self.collapsed = []     # children merged into this node by collapse()
        self.grafted = False    # True while this node hangs off a collapsed ancestor's head
        self._grafted = []      # children grafted onto this node by _collapse()

    def __str__(self, level=1):
        """Indented multi-line dump of this subtree."""
        base = f'<Node: {self.text} {self.dep} {self.idx}>'
        for child in self.children:
            base += '\n' + '\t' * level + child.__str__(level + 1)
        return base

    def count_nodes(self):
        """Number of nodes reachable through ``children`` (shared nodes counted per visit)."""
        return 1 + sum(child.count_nodes() for child in self.children)

    def get_index_dict(self):
        """Map character offset -> node for every node reachable through ``children``."""
        out = {self.idx: self}
        for child in self.children:
            out.update(child.get_index_dict())
        return out

    @staticmethod
    def _make_arc(idx, parent_idx, dep):
        """displaCy-style arc dict between two character offsets."""
        start = min(idx, parent_idx)
        end = max(idx, parent_idx)
        direction = 'right' if idx > parent_idx else 'left'
        return {'start': start, 'end': end, 'label': dep, 'dir': direction}

    def get_arcs(self):
        """Arcs to the real parent and every fake parent, for this whole subtree."""
        arcs = []
        if self.parent is not None:
            arcs.append(self._make_arc(self.idx, self.parent.idx, self.dep))
        # Fake parents are always VERB -> ROOT
        for fake_parent in self.fake_parents:
            arcs.append(self._make_arc(self.idx, fake_parent.idx, 'root'))
        for child in self.children:
            arcs += child.get_arcs()
        return arcs

    def get_tagged_in_tree(self, tag):
        """Pre-order list of nodes in this subtree whose ``tag`` equals ``tag``."""
        out = []
        if self.tag == tag:
            out.append(self)
        for child in self.children:
            out += child.get_tagged_in_tree(tag)
        return out

    def contains(self, stem):
        """True if ``stem`` is the primary stem of this node or any descendant."""
        # Slice instead of stem[0] so stem-less nodes (the fake ROOT) don't raise
        return self.stem[:1] == [stem] or any(child.contains(stem) for child in self.children)

    def _merged_text(self):
        """Join this node's original text with its collapsed children's texts in sentence order."""
        pieces = [(self.idx, self._text, None)]
        pieces += [(child.idx, child.text, child.dep) for child in self.collapsed]
        # Order by offset; the original head-relative ordering reversed multiple
        # same-side children ("has not been seen" -> "been not has seen")
        pieces.sort(key=lambda piece: piece[0])
        text = ''
        for position, (_, piece_text, piece_dep) in enumerate(pieces):
            # 'case' words (e.g. a possessive marker) attach without a leading space
            if position > 0 and piece_dep != 'case':
                text += ' '
            text += piece_text
        return text

    def _collapse(self, node):
        """Merge the collapsible child ``node`` into this node (text, stem, children)."""
        # Save the pre-collapse state only once so unfurl() can restore it even
        # after several children have been merged in
        if self._text is None:
            self._text = self.text
            self._stem = self.stem
        # A preposition's dependency label is replaced by the preposition's own word
        if self.dep == 'prep' and self._dep is None:
            self._dep = self.dep
            self.dep = self._text
        self.children.remove(node)
        self.collapsed.append(node)
        # Graft the merged child's own children onto this node. They also stay in
        # node.children, so unfurl() only has to drop them from here again.
        for child in node.children:
            self.children.append(child)
            self._grafted.append(child)
            child.grafted = True
        self.stem = self.stem + node.stem
        self.text = self._merged_text()

    def collapse(self):
        """Bottom-up merge of collapsible descendants.

        Returns:
            True if this node's (possibly updated) dep is in COLLAPSE_DEPS, i.e. the
            parent should merge this node into itself.
        """
        to_collapse = [child for child in self.children if child.collapse()]
        for child in to_collapse:
            self._collapse(child)
        return self.dep in self.COLLAPSE_DEPS

    def unfurl(self):
        """Undo collapse() for this subtree, keeping any prune() decisions."""
        for child in self._grafted:
            if child in self.children:
                self.children.remove(child)
            elif child.parent is not None and child in child.parent.children:
                # Pruned while grafted here: also detach it from its original
                # parent, otherwise restoring that parent would resurrect it
                child.parent.children.remove(child)
            child.grafted = False
        self._grafted = []
        if self.collapsed:
            self.children.extend(self.collapsed)
            self.collapsed = []
        if self._text is not None:
            self.text, self._text = self._text, None
            self.stem, self._stem = self._stem, None
        if self._dep is not None:
            self.dep, self._dep = self._dep, None
        for child in self.children:
            child.unfurl()

    def get_features(self, labels, deps_above=None):
        """Feature dicts for every node of this subtree.

        Feature vector for one node: ``[tag, own dep, parent dep, grandparent dep, ...]``,
        padded with None to ABOVE_PADDING when ENABLE_PADDING, followed by the deps of
        its direct children, padded to BELOW_PADDING. The vector is zipped with
        ``labels`` into a dict that also carries 'idx' and 'stem'.

        Args:
            labels: feature names, one per vector position (e.g. LABELS).
            deps_above: deps of this node's ancestors, nearest last. Not modified.

        Returns:
            (list of feature dicts for the subtree in post-order, self.dep)
        """
        # A fresh path list per node. The old shared list was appended to and never
        # popped, so every previously visited node's dep leaked into later siblings'
        # "above" features.
        path = list(deps_above or []) + [self.dep]
        features = [self.tag] + path[::-1]
        if ENABLE_PADDING:
            if len(features) > ABOVE_PADDING:
                print('ERROR: Too many Above Dep features! Add more padding', len(features), '>', ABOVE_PADDING)
            # Pad up to ABOVE_PADDING
            features.extend([None] * (ABOVE_PADDING - len(features)))
        out = []
        for child in self.children:
            child_features, child_dep = child.get_features(labels, path)
            out += child_features
            features.append(child_dep)
        if ENABLE_PADDING:
            # Pad up to BELOW_PADDING
            if len(features) > BELOW_PADDING:
                print('ERROR: Too many Below Dep features! Add more padding', len(features), '>', BELOW_PADDING)
            features.extend([None] * (BELOW_PADDING - len(features)))
        node_features = {'idx': self.idx, 'stem': self.stem}
        for ix, label in enumerate(labels):
            node_features[label] = features[ix]
        out.append(node_features)  # add this node after its descendants
        return out, self.dep

    def prune(self, results):
        """Drop children whose ``results[idx]`` is falsy; the fake ROOT (idx -1) is always kept.

        Returns:
            Whether this node itself is kept.
        """
        keep = self.idx == -1 or results[self.idx]
        if keep:
            to_remove = [child for child in self.children if not child.prune(results)]
            for child in to_remove:
                self.children.remove(child)
        return keep

In [None]:
class Root:
    """One (sentence, compressed sentence) pair with the dependency tree of each.

    ``root_node`` is the working tree of the original sentence: it is collapsed,
    given a fake ROOT, pruned and unfurled. ``comp_node`` is the tree of the
    compressed sentence and is only read.
    """

    @staticmethod
    def from_sentence(sentence, compressed):
        """Parse both strings with the module-level spaCy pipeline ``nlp``.

        Only the first spaCy sentence of each string is kept.
        """
        root = Root()
        root.sentence = sentence
        root.compressed = compressed
        root.root_node = Node.parse_spacy(list(nlp(sentence).sents)[0].root)
        root.comp_node = Node.parse_spacy(list(nlp(compressed).sents)[0].root)
        root.node_count = root.root_node.count_nodes()
        return root

    def __init__(self):
        self.sentence = ''
        self.compressed = ''
        self.root_node = None   # working tree (may be topped by a fake ROOT)
        self.comp_node = None   # tree of the compressed sentence
        self._root_node = None  # real root saved by add_fake(); None when no fake root
        self.node_count = 0     # node count of the freshly parsed sentence tree

    def __str__(self):
        return f'Sentence: {self.sentence}\nCompressed: {self.compressed}\nTree:\n{self.root_node}\n'

    def to_spacy_dict(self):
        """Build a displaCy manual-mode dict ``{'words': [...], 'arcs': [...]}`` for root_node.

        Words are ordered by character offset. Arc endpoints are converted from
        character offsets to word positions.

        Raises:
            ValueError: if an arc points at a node that is not in the tree.
        """
        nodes = sorted(self.root_node.get_index_dict().items())
        position = {char_idx: word_pos for word_pos, (char_idx, _) in enumerate(nodes)}

        def to_position(char_idx):
            # Same exception type as the previous list.index() lookup
            if char_idx not in position:
                raise ValueError(f'{char_idx} is not in list')
            return position[char_idx]

        words = [{'text': node.text, 'tag': node.tag} for _, node in nodes]
        arcs = self.root_node.get_arcs()
        for arc in arcs:
            arc['start'] = to_position(arc['start'])
            arc['end'] = to_position(arc['end'])
        return {'words': words, 'arcs': arcs}

    def to_sentence(self):
        """Rebuild text from root_node's words in offset order; 'case' words attach without a space."""
        pieces = []
        for _, node in sorted(self.root_node.get_index_dict().items()):
            pieces.append(node.text if node.dep == 'case' else ' ' + node.text)
        return ''.join(pieces).strip()

    def collapse(self):
        """Collapse function words in the sentence tree (see Node.collapse)."""
        self.root_node.collapse()

    def add_fake(self):
        """Put a synthetic ROOT above the tree and its verbs, remembering the real root."""
        self._root_node = self.root_node
        self.root_node = Node.make_fake_root(self.root_node)

    def pop_fake(self):
        """Remove the synthetic ROOT added by add_fake(); a no-op when none is installed."""
        if self._root_node is None:
            # Previously this replaced root_node with None
            return
        fake = self.root_node
        self.root_node = self._root_node
        self._root_node = None
        # Detach the fake ROOT from every remaining node; stale fake_parents made
        # get_arcs() emit arcs to idx -1, which broke to_spacy_dict()
        stack, visited = [self.root_node], set()
        while stack:
            node = stack.pop()
            if id(node) in visited:
                continue
            visited.add(id(node))
            node.fake_parents[:] = [parent for parent in node.fake_parents if parent is not fake]
            stack.extend(node.children)

    def unfurl(self):
        """Undo the collapse on the working tree (see Node.unfurl)."""
        self.root_node.unfurl()

    def prune(self, results):
        """Drop nodes rejected in ``results`` (char offset -> keep) from the working tree."""
        self.root_node.prune(results)

    def get_features(self, labels):
        """Per-node feature dicts for the working tree (see Node.get_features)."""
        out, _ = self.root_node.get_features(labels, [])
        return out

In [None]:
# Featurize 10k training pairs and 2k test pairs, then fit a Naive Bayes
# keep/drop classifier on the flattened training examples.
from nltk.classify import NaiveBayesClassifier

train_data = get_train_data(label_data(dataset_1['train'][:10000]))
test_data = label_data(dataset_1['test'][:2000])

classifier = NaiveBayesClassifier.train(train_data)
classifier.show_most_informative_features()


Labeling data


 51%|█████     | 5106/10000 [02:27<19:47,  4.12it/s]

In [None]:
# Walk through the first N_DEMO test trees: classify every node as keep/drop,
# prune the dropped nodes, undo the collapse, remove the fake ROOT, and rebuild
# the sentence. Each Root is deep-copied and the feature dicts are copied, so
# `test_data` is left untouched and this cell can be re-run (previously the
# in-place pop('idx') made a re-run fail with KeyError).
N_DEMO = 2

test_output = []
for labeled_root, node_features in list(test_data.items())[:N_DEMO]:
    root = deepcopy(labeled_root)
    results = {}
    for node in node_features:
        # Skip the fake ROOT (idx -1); prune() always keeps it
        if node['idx'] == -1:
            continue
        features = {key: value for key, value in node.items() if key not in ('idx', 'stem')}
        results[node['idx']] = classifier.classify(features)
    print('== Collapsed ==')
    print(root)
    root.prune(results)
    print('== Pruned ==')
    print(root)
    root.unfurl()
    print('== Unfurled ==')
    print(root)
    root.pop_fake()
    print('== Without fake ROOT ==')
    print(root)
    test_output.append([
        root,                # The final output (For metrics or whatever)
        root.sentence,       # Our raw sentence
        root.compressed,     # Our target sentence
        root.to_sentence(),  # Our classifiers output
    ])



Labeling data


100%|██████████| 10000/10000 [04:49<00:00, 34.52it/s]



Labeling data


100%|██████████| 2000/2000 [00:57<00:00, 34.70it/s]


Most Informative Features
                  depa_1 = 'on Monday'     False : True   =     48.4 : 1.0
                  depa_1 = 'for bankruptcy'   True : False  =     33.0 : 1.0
                depb_251 = 'in Iraq'        True : False  =     27.0 : 1.0
                  depa_1 = 'as coach'       True : False  =     25.7 : 1.0
                depb_255 = 'on Wednesday'   True : False  =     25.3 : 1.0
                  depa_1 = 'on Wednesday'  False : True   =     25.2 : 1.0
                  depa_1 = 'at gunpoint'    True : False  =     23.2 : 1.0
                depb_251 = 'in 2013'        True : False  =     21.9 : 1.0
                  depa_5 = 'at a station'   True : False  =     20.9 : 1.0
                  depa_2 = 'as part'       False : True   =     20.8 : 1.0
