In [24]:
import spacy
from spacy import displacy
import en_core_web_lg
import pprint
import collections

import dgl
from dgl import DGLGraph
from dgl.data import MiniGCDataset
import dgl.function as fn
from dgl.data.utils import save_graphs

import torch

import pandas as pd

import numpy as np

parser = en_core_web_lg.load()

In [25]:
pp = pprint.PrettyPrinter(indent=2)

In [26]:
parser = en_core_web_lg.load()

In [27]:
def transfer_n_e(nodes, edges):

    num_nodes = len(nodes)
    new_edges = []
    for e1, e2 in edges:
        new_edges.append( [nodes[e1], nodes[e2]] ) 
    return num_nodes, new_edges

In [28]:
def bert_embedding_for_dp_token(token, bert_tokens, bert_embeddings, debug=False):
    try:
        idx = bert_tokens.index(token)
        return bert_embeddings[idx]
    except ValueError:
        temp = token
        start, end = 0, 0
        seq = False
        for i, bert_token in enumerate(bert_tokens):
            if debug:
                print("HA:", bert_token, token)
            if bert_token in temp:
                temp = temp[len(bert_token):]
                if not seq:
                    start = i
                    seq = True
            else:
                temp = token
                seq = False
            if len(temp) == 0:
                end = i + 1
                break
        
        if (debug):
            print(start, end)

        bert_emb_tensor = bert_embeddings[start:end]
        return torch.mean(bert_emb_tensor, dim=0)

In [29]:
size = 'xs'

In [30]:
X_preprocessed = torch.load("X_{}.pt".format(size))
y_data = torch.load("y_{}.pt".format(size))

In [31]:
def is_target(dp_token, options, debug=False):
    
    if dp_token == options[0].split(' ')[0]:
        return True
    
    if dp_token == options[1].split(' ')[0]:
        return True
    
    return False

In [33]:
all_graphs = []
gcn_offsets = []
cls_tokens = []
for row in X_preprocessed:
    sentence = row['sentence']
    bert_embeddings = row['encoding'][0]
    bert_tokens = row['tokens']
    options = row['options']
    
    doc = parser(sentence)
    nodes = collections.OrderedDict()
    edges = []
    edge_type = []
    
    offsets = []
    offset_words = []
    
    spacy_tokens = []
    
    for token in doc:
        
        spacy_tokens.append(token)
        
        # skip words that aren't targets or separated by one edge from target
        if not (is_target(token.text, options) or is_target(token.head.text, options)):
            continue
        
        if token.i not in nodes:
            nodes[token.i] = len(nodes)
            edges.append( [token.i, token.i])
            edge_type.append(0)
        
        if token.head.i not in nodes:
            nodes[token.head.i] = len(nodes)
            edges.append( [token.head.i, token.head.i] )
            edge_type.append(0)
            
        if token.dep_ != 'ROOT':
            edges.append( [ token.head.i, token.i ])
            edge_type.append(1)
            edges.append( [ token.i, token.head.i ])
            edge_type.append(2)
            
        if is_target(token.text, options):
            offsets.append(token.i)
            offset_words.append(token.text)
    
    num_nodes, tran_edges = transfer_n_e(nodes, edges)
    
    if (len(offsets) != 3):
        print("UNEXPECTED: at least 3 positions should be in offsets")
        print(sentence, options, len(offsets))
        print(offset_words)
    
    gcn_offset = [nodes[offset] for offset in offsets]
    gcn_offsets.append(gcn_offset)
    
    G = dgl.DGLGraph()
    G.add_nodes(num_nodes)
    G.add_edges(list(zip(*tran_edges))[0], list(zip(*tran_edges))[1])
    
    for i in range(len(doc))
        token = doc[i]
        if not (is_target(token.text, options) or is_target(token.head.text, options)):
            continue
            
        if token[i + 1] == "n\'t":
            token = token[:token.index("n\'t")+1]
        elif token[i] == "n\'t"
            token = token[token.index("n\'t")+1:]
            
        dp_token = token.text
        embedding = bert_embedding_for_dp_token(dp_token, bert_tokens, bert_embeddings)
        if(torch.isnan(embedding.unsqueeze(0)).any()):
            print("UNEXPECTED: bert_embedding_for_dp_token returns NaN")
            print(embedding.unsqueeze(0))
            print(token.i, token.text, sentence, bert_tokens)
            print(spacy_tokens)
            G.nodes[ nodes[token.i] ].data['h'] = torch.randn(1024).unsqueeze(0)
        else:
            G.nodes[ nodes[token.i] ].data['h'] = embedding.unsqueeze(0)
        
        head_dp_token = token.head.text
        embedding = bert_embedding_for_dp_token(head_dp_token, bert_tokens, bert_embeddings)
        if(torch.isnan(embedding.unsqueeze(0)).any()):
            print("UNEXPECTED: bert_embedding_for_dp_token returns NaN")
            print(embedding.unsqueeze(0))
            print(token.i, token.head.i, token.head.text, sentence, bert_tokens)
            print(spacy_tokens)
            G.nodes[ nodes[token.head.i] ].data['h'] = torch.randn(1024).unsqueeze(0)
        else:
            G.nodes[ nodes[token.head.i] ].data['h'] = embedding.unsqueeze(0)
        
    edge_norm = []
    for e1, e2 in tran_edges:
        if e1 == e2:
            edge_norm.append(1)
        else:
            edge_norm.append( 1 / (G.in_degree(e2) - 1 ) )

    edge_type = torch.from_numpy(np.array(edge_type))
    edge_norm = torch.from_numpy(np.array(edge_norm)).unsqueeze(1).float()
    
    G.edata.update({'rel_type': edge_type,})
    G.edata.update({'norm': edge_norm})
    # todo: Add <s> token embedding to graph here.
    all_graphs.append(G)
    cls_tokens.append(bert_embeddings[0])

UNEXPECTED: bert_embedding_for_dp_token returns NaN
tensor([[nan, nan, nan,  ..., nan, nan, nan]])
18 19 did Emily asked her sister Sarah if she needed any tampons or pads from the store, even though Emily didn't because she had switched to using menstrual cups. ['<s>', 'Emily', 'asked', 'her', 'sister', 'Sarah', 'if', 'she', 'needed', 'any', 'tamp', 'ons', 'or', 'pads', 'from', 'the', 'store', ',', 'even', 'though', 'Emily', 'didn', "'t", 'because', 'she', 'had', 'switched', 'to', 'using', 'menstrual', 'cups', '.', '</s>']
[Emily, asked, her, sister, Sarah, if, she, needed, any, tampons, or, pads, from, the, store, ,, even, though, Emily, did, n't, because, she, had, switched, to, using, menstrual, cups, .]
UNEXPECTED: bert_embedding_for_dp_token returns NaN
tensor([[nan, nan, nan,  ..., nan, nan, nan]])
18 19 did Emily asked her sister Sarah if she needed any tampons or pads from the store, even though Sarah didn't because she had switched to using menstrual cups. ['<s>', 'Emily', 'a

In [23]:
# TODO: Save all_graphs, gcn_offsets & cls_tokens
# cls_tokens = torch.tensor(cls_tokens)
cls_tokens = torch.stack(cls_tokens)
gcn_offsets = torch.tensor(gcn_offsets)

# https://docs.dgl.ai/en/0.4.x/generated/dgl.data.utils.load_graphs.html
save_graphs("data/X_train_graphs_{}.bin".format(size), all_graphs) 
torch.save(cls_tokens, "data/X_train_cls_tokens_{}.bin".format(size))
torch.save(gcn_offsets, "data/X_train_gcn_offsets_{}.bin".format(size))

# NOTEBOOK OVER - ALL CODE BELOW IS SIMPLY BACKUP

> 

>

> 

> 

>

> 

> 

>

> 

> 

>

> 

> 

>

> 


In [None]:
print(bert_embedding_for_dp_token("bread", bt, be))
# print(torch.mean(torch.stack(be[2:5]), dim=0))

In [None]:
nodes = collections.OrderedDict()
edges = []
edge_type = []

#     for i_word, word in enumerate(parse_rst['tokens']):
#         # TODO: skip words that aren't targets or seperated by one edge from target
    
#         if i_word not in nodes:
#             nodes[i_word] = len(nodes)
#             edges.append( [i_word, i_word])
#             edge_type.append(0)
#         if word['head'] not in nodes:
#             nodes[word['head']] = len(nodes)
#             edges.append( [ word['head'], word['head'] ] )
#             edge_type.append(0)

#         if word['dep'] != 'ROOT':
#             edges.append([ word['head'], word['id'] ])
#             edge_type.append(1)
#             edges.append([ word['id'], word['head'] ])
#             edge_type.append(2)
    
num_nodes, tran_edges = transfer_n_e(nodes, edges)

In [None]:
G = dgl.DGLGraph()
G.add_nodes(num_nodes)
G.add_edges(list(zip(*tran_edges))[0], list(zip(*tran_edges))[1])

In [None]:
for token in doc:
    dp_token = token.text
    embedding = bert_embedding_for_dp_token(dp_token, bt, be)
    G.nodes[ nodes[token.i] ].data['h'] = embedding.unsqueeze(0)
    
    head_dp_token = token.head.text
    embedding = bert_embedding_for_dp_token(head_dp_token, bt, be)
    G.nodes[ nodes[token.head.i] ].data['h'] = embedding.unsqueeze(0)

In [None]:
edge_norm = []
for e1, e2 in tran_edges:
    if e1 == e2:
        edge_norm.append(1)
    else:
        edge_norm.append( 1 / (G.in_degree(e2) - 1 ) )

edge_type = torch.from_numpy(np.array(edge_type))
edge_norm = torch.from_numpy(np.array(edge_norm)).unsqueeze(1).float()

In [None]:
G.edata.update({'rel_type': edge_type,})
G.edata.update({'norm': edge_norm})

In [None]:
# TODO: repeat above steps in a loop for all input!