In [95]:
import spacy
from spacy import displacy
import en_core_web_lg
import pprint
import collections

import dgl
from dgl import DGLGraph
from dgl.data import MiniGCDataset
import dgl.function as fn
from dgl.data.utils import save_graphs

import torch

import pandas as pd

import numpy as np
import statistics 

parser = en_core_web_lg.load()

In [2]:
pp = pprint.PrettyPrinter(indent=2)

In [3]:
def transfer_n_e(nodes, edges):

    num_nodes = len(nodes)
    new_edges = []
    for e1, e2 in edges:
        new_edges.append( [nodes[e1], nodes[e2]] ) 
    return num_nodes, new_edges

In [179]:
def bert_embedding_for_dp_token(token, bert_tokens, bert_embeddings, debug=False):
    try:
        idx = bert_tokens.index(token)
        return bert_embeddings[idx]
    except ValueError:
        temp = token
        start, end = 0, 0
        seq = False
        for i, bert_token in enumerate(bert_tokens):
            if debug:
                print("HA:", bert_token, token)
            if bert_token in temp and temp.index(bert_token) == 0:
                temp = temp[len(bert_token):]
                if not seq:
                    start = i
                    seq = True
            else:
                temp = token
                seq = False
                if bert_token in temp and temp.index(bert_token) == 0:
                    temp = temp[len(bert_token):]
                    if not seq:
                        start = i
                        seq = True
            if len(temp) == 0:
                end = i + 1
                break
        
        if (debug):
            print(start, end)

        bert_emb_tensor = bert_embeddings[start:end]
        return torch.mean(bert_emb_tensor, dim=0)

In [180]:
bt = ['<s>', 'Robert', 'checked', 'the', 'therm', 'ometer', 'before', 'reassuring', 'Randy', 'that', 'the', 'baby', 'would', 'be', 'fine', ',', 'because', 'Robert', 'worked', 'as', 'a', 'nurse', '.', '</s>']
be = torch.randn((len(bt), 3))
print(bert_embedding_for_dp_token("thermometer", bt, be))


tensor([-0.3612, -0.6157,  0.4424])


In [11]:
size = 'xl'

In [13]:
X_preprocessed = torch.load("../data/bert_preprocessed/X_{}.pt".format(size))
y_data = torch.load("../data/bert_preprocessed/y_{}.pt".format(size))

In [14]:
def is_target(dp_token, options, debug=False):
    
    if dp_token == options[0].split(' ')[0]:
        return True
    
    if dp_token == options[1].split(' ')[0]:
        return True
    
    return False

In [106]:
doc = parser("Sheila didn’t check the temperature of her flat iron before straightening hair. The iron was hot.")

In [115]:
db_tokens = []
db_head_tokens = []
for token in doc:
    db_tokens.append(str(token))
    db_head_tokens.append(str(token.head))
    
print(db_tokens)
print(db_head_tokens)

contractions = []
for i in range(len(db_tokens)):
    if db_tokens[i] == "n\'t" or db_tokens[i] == "n’t":
        db_tokens[i] = "\'t"
        contractions.append(db_tokens[i-1])
        db_tokens[i-1] = db_tokens[i-1] + "n"

for i in range(len(db_head_tokens)):
    if db_head_tokens[i] == "n\'t" or db_head_tokens[i] == "n’t":
        db_head_tokens[i] = "\'t"
    if db_head_tokens[i] in contractions:
        db_head_tokens[i] = db_head_tokens[i] + "n"
        
print(db_tokens)
print(db_head_tokens)

['Sheila', 'did', 'n’t', 'check', 'the', 'temperature', 'of', 'her', 'flat', 'iron', 'before', 'straightening', 'hair', '.', 'The', 'iron', 'was', 'hot', '.']
['check', 'check', 'check', 'check', 'temperature', 'check', 'temperature', 'iron', 'iron', 'of', 'check', 'before', 'straightening', 'check', 'iron', 'was', 'was', 'was', 'was']
Sheila n't
did n't
n’t n't
FUCK
check n't
the n't
temperature n't
of n't
her n't
flat n't
iron n't
before n't
straightening n't
hair n't
. n't
The n't
iron n't
was n't
hot n't
. n't
['Sheila', 'didn', "'t", 'check', 'the', 'temperature', 'of', 'her', 'flat', 'iron', 'before', 'straightening', 'hair', '.', 'The', 'iron', 'was', 'hot', '.']
['check', 'check', 'check', 'check', 'temperature', 'check', 'temperature', 'iron', 'iron', 'of', 'check', 'before', 'straightening', 'check', 'iron', 'was', 'was', 'was', 'was']


In [183]:
all_graphs = []
gcn_offsets = []
cls_tokens = []

iteration = 1
total = len(X_preprocessed) + 1
skipped = 0

for row in X_preprocessed:
    
    if iteration % 100 == 0:
        print(f"Progress: {iteration}/{total}...")
    iteration += 1
    
    sentence = row['sentence']
    bert_embeddings = row['encoding'][0]
    bert_tokens = row['tokens']
    options = row['options']
    
    doc = parser(sentence)
    nodes = collections.OrderedDict()
    edges = []
    edge_type = []
    
    offsets = []
    offset_words = []
    
    spacy_tokens = []
    
    for token in doc:
        
        spacy_tokens.append(token)
        
        # skip words that aren't targets or separated by one edge from target
        if not (is_target(token.text, options) or is_target(token.head.text, options)):
            continue
        
        if token.i not in nodes:
            nodes[token.i] = len(nodes)
            edges.append( [token.i, token.i])
            edge_type.append(0)
        
        if token.head.i not in nodes:
            nodes[token.head.i] = len(nodes)
            edges.append( [token.head.i, token.head.i] )
            edge_type.append(0)
            
        if token.dep_ != 'ROOT':
            edges.append( [ token.head.i, token.i ])
            edge_type.append(1)
            edges.append( [ token.i, token.head.i ])
            edge_type.append(2)
            
        if is_target(token.text, options):
            offsets.append(token.i)
            offset_words.append(token.text)
    
    num_nodes, tran_edges = transfer_n_e(nodes, edges)
    
    if (len(offsets) != 3):
        print("UNEXPECTED: at least 3 positions should be in offsets")
#         print(sentence, options, len(offsets))
#         print(offset_words)
        skipped += 1
        continue
    
    gcn_offset = [nodes[offset] for offset in offsets]
    gcn_offsets.append(gcn_offset)
#     print(f"gcn_offset size: {len(gcn_offset)}")
    
    G = dgl.DGLGraph()
    G.add_nodes(num_nodes)
    try:
        G.add_edges(list(zip(*tran_edges))[0], list(zip(*tran_edges))[1])
    except:
        print("Index of bound on zip G edges. Skipping...")
        skipped += 1
        continue
    
    # Transform tokens and head tokens into bert contractions
    db_tokens = []
    db_head_tokens = []
    for token in doc:
        db_tokens.append(str(token))
        db_head_tokens.append(str(token.head))

    contractions = []
    for i in range(len(db_tokens)):
        if db_tokens[i] == "n\'t" or db_tokens[i] == "n’t":
            db_tokens[i] = "\'t"
            contractions.append(db_tokens[i-1])
            db_tokens[i-1] = db_tokens[i-1] + "n"

    for i in range(len(db_head_tokens)):
        if db_head_tokens[i] == "n\'t" or db_head_tokens[i] == "n’t":
            db_head_tokens[i] = "\'t"
        if db_head_tokens[i] in contractions:
            db_head_tokens[i] = db_head_tokens[i] + "n"

    idx = 0
    prev_dp_token = None
    for token in doc:
        
        if not (is_target(token.text, options) or is_target(token.head.text, options)):
            continue
        
        embedding = bert_embedding_for_dp_token(db_tokens[idx], bert_tokens, bert_embeddings)
        if db_tokens[idx] == "\'t":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('t', bert_tokens, bert_embeddings)
        elif db_tokens[idx] == "can":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('cannot', bert_tokens, bert_embeddings)
        elif db_tokens[idx] == "not":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('cannot', bert_tokens, bert_embeddings)
        elif db_tokens[idx] == "wo":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('wont', bert_tokens, bert_embeddings)
        elif db_tokens[idx] == "nt":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('wont', bert_tokens, bert_embeddings)
                
        if(torch.isnan(embedding.unsqueeze(0)).any()):
            print("fuckBBBB: ", db_tokens[idx])
            print("UNEXPECTED: bert_embedding_for_dp_token returns NaN, skipping")
            print(sentence)
            print(spacy_tokens)
            print(bert_tokens)
            print(bert_embeddings)
            print('\n')
            G.nodes[ nodes[token.i] ].data['h'] = torch.randn(1024).unsqueeze(0)
        else:
            G.nodes[ nodes[token.i] ].data['h'] = embedding.unsqueeze(0)
        
        embedding = bert_embedding_for_dp_token(db_head_tokens[idx], bert_tokens, bert_embeddings)
        if db_head_tokens[idx] == "\'t":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('t', bert_tokens, bert_embeddings)
        elif db_head_tokens[idx] == "can":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('cannot', bert_tokens, bert_embeddings)
        elif db_head_tokens[idx] == "not":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('cannot', bert_tokens, bert_embeddings)
        elif db_head_tokens[idx] == "wo":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('wont', bert_tokens, bert_embeddings)
        elif db_head_tokens[idx] == "nt":
            if(torch.isnan(embedding.unsqueeze(0)).any()):
                embedding = bert_embedding_for_dp_token('wont', bert_tokens, bert_embeddings)
                
        if(torch.isnan(embedding.unsqueeze(0)).any()):
            print("fuckHEAD: ", db_head_tokens[idx])
            print("UNEXPECTED: bert_embedding_for_dp_token returns NaN, skipping")
            print(sentence)
            print(spacy_tokens)
            print(bert_tokens)
            print('\n')
            G.nodes[ nodes[token.head.i] ].data['h'] = torch.randn(1024).unsqueeze(0)
        else:
            G.nodes[ nodes[token.head.i] ].data['h'] = embedding.unsqueeze(0)
        
        prev_dp_token = token
        idx += 1
        
    edge_norm = []
    for e1, e2 in tran_edges:
        if e1 == e2:
            edge_norm.append(1)
        else:
            edge_norm.append( 1 / (G.in_degree(e2) - 1 ) )

    edge_type = torch.from_numpy(np.array(edge_type))
    edge_norm = torch.from_numpy(np.array(edge_norm)).unsqueeze(1).float()
    
    G.edata.update({'rel_type': edge_type,})
    G.edata.update({'norm': edge_norm})
    # todo: Add <s> token embedding to graph here.
    all_graphs.append(G)
    cls_tokens.append(bert_embeddings[0])
    
print(f"Done. Skipped {skipped} instances.")

Progress: 100/80797...
Progress: 200/80797...
Progress: 300/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 400/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 500/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 600/80797..

UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 4300/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 4400/80797...
Progress: 4500/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 4600/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should 

Progress: 9500/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 9600/80797...
Progress: 9700/80797...
Progress: 9800/80797...
Progress: 9900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 10000/80797...
Progress: 10100/80797...
Progress: 10200/80797...
Progress: 10300/80797...
Progress: 10400/80797...
Progress: 10500/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 10600/80797...
Progress: 10700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions

Progress: 14500/80797...
Progress: 14600/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 14700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 14800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 14900/80797...
Progress: 1

UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 18100/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 18200/80797...
Progress: 18300/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 18400/80797...
Progress: 18500/80797...
Progress: 1

Progress: 26500/80797...
Progress: 26600/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 26700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 26800/80797...
Progress: 26900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 2

Progress: 31700/80797...
Progress: 31800/80797...
Progress: 31900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 32000/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 32100/80797...
Progress: 32200/80797...
Progress: 32300/80797...
Progress: 32400/80797...
Progress: 32500/80797...
Progress: 32600/80797...
Progress: 32700/80797...
Progress: 32800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offset

Progress: 37200/80797...
Progress: 37300/80797...
Progress: 37400/80797...
Progress: 37500/80797...
Progress: 37600/80797...
Progress: 37700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 37800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 37900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be i

Progress: 43200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 43300/80797...
Progress: 43400/80797...
Progress: 43500/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 43600/80797...
Progress: 43700/80797...
Progress: 43800/80797...
Progress: 43900/80797...
Progress: 44000/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 44100/80797...
Progress: 44200/80797...
Progress: 44300/80797...
Progress: 44400/80797...
Progress: 44500/80797...
UN

Progress: 48900/80797...
Progress: 49000/80797...
Progress: 49100/80797...
Progress: 49200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 49300/80797...
Progress: 49400/80797...
Progress: 49500/80797...
Progress: 49600/80797...
Progress: 49700/80797...
Progress: 49800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 49900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 

Progress: 52100/80797...
Progress: 52200/80797...
Progress: 52300/80797...
Progress: 52400/80797...
Progress: 52500/80797...
Progress: 52600/80797...
Progress: 52700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 52800/80797...
Progress: 52900/80797...
Progress: 53000/80797...
Progress: 53100/80797...
Progress: 53200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
fuckBBBB:  ’s
UNEXPECTED: bert_embedding_for_dp_token returns NaN, skipping
Red was Monica ’s favourite colour, so it comes as no suprise that Monica’s house was full of rich warm scarlets and Kayla’s wardrobe was all cool metallic greys.
[Red, was, Monica, ’s, favour

Progress: 54800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 54900/80797...
Progress: 55000/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 55100/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 55200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 5

Progress: 60500/80797...
Progress: 60600/80797...
Progress: 60700/80797...
Progress: 60800/80797...
Progress: 60900/80797...
Progress: 61000/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 61100/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 61200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 61300/80797...
Progress: 61400/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in of

Progress: 66300/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 66400/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 66500/80797...
fuckHEAD:  gon
UNEXPECTED: bert_embedding_for_dp_token returns NaN, skipping
Chuck was gonna get an RV for the trip but decided to take his SUV instead because he was familiar with the SUV .
[Chuck, was, gon, na, get, an, RV, for, the, trip, but, decided, to, take, his, SUV, instead, because, he, was, familiar, with, the, SUV, .]
['<s>', 'Chuck', 'was', 'gonna', 'get', 'an', 'RV', 'for', 'the', 'trip', 'but', 'decided', 'to', 'take', 'his', 'SUV', 'instead', 'because', 'he', 'was', 'famil

UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 66600/80797...
Progress: 66700/80797...
Progress: 66800/80797...
Progress: 66900/80797...
Progress: 67000/80797...
Progress: 67100/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 67200/80797...
Progress: 67300/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be i

Progress: 69600/80797...
Progress: 69700/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 69800/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 69900/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 70000/80797...
Progress: 7

Progress: 75000/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 75100/80797...
UNEXPECTED: at least 3 positions should be in offsets
Progress: 75200/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
Progress: 75300/80797...
Progress: 75400/80797...
Progress: 75500/80797...
Progress: 75600/80797...
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at least 3 positions should be in offsets
UNEXPECTED: at 

KeyboardInterrupt: 

In [182]:
# TODO: Save all_graphs, gcn_offsets & cls_tokens
# cls_tokens = torch.tensor(cls_tokens)
cls_tokens = torch.stack(cls_tokens)
gcn_offsets = torch.tensor(gcn_offsets)

# https://docs.dgl.ai/en/0.4.x/generated/dgl.data.utils.load_graphs.html
save_graphs("data/Jack_X_train_graphs_{}.bin".format(size), all_graphs) 
torch.save(cls_tokens, "data/Jack_X_train_cls_tokens_{}.bin".format(size))
torch.save(gcn_offsets, "data/Jack_X_train_gcn_offsets_{}.bin".format(size))

ValueError: expected sequence of length 3 at dim 1 (got 4)

# NOTEBOOK OVER - ALL CODE BELOW IS SIMPLY BACKUP

> 

>

> 

> 

>

> 

> 

>

> 

> 

>

> 

> 

>

> 


In [None]:
print(bert_embedding_for_dp_token("bread", bt, be))
# print(torch.mean(torch.stack(be[2:5]), dim=0))

In [None]:
nodes = collections.OrderedDict()
edges = []
edge_type = []

#     for i_word, word in enumerate(parse_rst['tokens']):
#         # TODO: skip words that aren't targets or seperated by one edge from target
    
#         if i_word not in nodes:
#             nodes[i_word] = len(nodes)
#             edges.append( [i_word, i_word])
#             edge_type.append(0)
#         if word['head'] not in nodes:
#             nodes[word['head']] = len(nodes)
#             edges.append( [ word['head'], word['head'] ] )
#             edge_type.append(0)

#         if word['dep'] != 'ROOT':
#             edges.append([ word['head'], word['id'] ])
#             edge_type.append(1)
#             edges.append([ word['id'], word['head'] ])
#             edge_type.append(2)
    
num_nodes, tran_edges = transfer_n_e(nodes, edges)

In [None]:
G = dgl.DGLGraph()
G.add_nodes(num_nodes)
G.add_edges(list(zip(*tran_edges))[0], list(zip(*tran_edges))[1])

In [None]:
for token in doc:
    dp_token = token.text
    embedding = bert_embedding_for_dp_token(dp_token, bt, be)
    G.nodes[ nodes[token.i] ].data['h'] = embedding.unsqueeze(0)
    
    head_dp_token = token.head.text
    embedding = bert_embedding_for_dp_token(head_dp_token, bt, be)
    G.nodes[ nodes[token.head.i] ].data['h'] = embedding.unsqueeze(0)

In [None]:
edge_norm = []
for e1, e2 in tran_edges:
    if e1 == e2:
        edge_norm.append(1)
    else:
        edge_norm.append( 1 / (G.in_degree(e2) - 1 ) )

edge_type = torch.from_numpy(np.array(edge_type))
edge_norm = torch.from_numpy(np.array(edge_norm)).unsqueeze(1).float()

In [None]:
G.edata.update({'rel_type': edge_type,})
G.edata.update({'norm': edge_norm})

In [None]:
# TODO: repeat above steps in a loop for all input!