<a href="https://colab.research.google.com/github/yg-li/QA-KG-RL/blob/master/QA_KG_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install --no-cache-dir torch-scatter torch-sparse torch-cluster
!pip install torch-geometric

!pip uninstall -y spacy
!pip install spacy==2.1.3
# !spacy download en_core_web_lg
!pip install neuralcoref

In [0]:
import os
import json
import itertools
import re

import spacy
# Load the small English pipeline by its installed package name.
nlp = spacy.load("en_core_web_sm")
# Alternative: the larger model (requires it to be downloaded first).
# import en_core_web_lg
# nlp = en_core_web_lg.load()
import neuralcoref
# Attach the neuralcoref component to the already-loaded pipeline; documents
# produced by `nlp` then expose `doc._.coref_clusters`, which
# build_entity_graph reads. Statement order here is left as-is on purpose.
neuralcoref.add_to_pipe(nlp)

# NOTE(review): nx, torch and Data are not used in the cells visible here —
# presumably for converting the entity graph later; confirm or remove.
import networkx as nx
import torch
from torch_geometric.data import Data

In [0]:
# mount Google Drive (Colab only) so the dataset stored there can be read
from google.colab import drive
drive.mount('/gdrive')
data_path = '/gdrive/My Drive/Colab Notebooks/CSML/Project/data/qangaroo_v1.1/wikihop'

# read in QAngaroo WikiHop (dev split) as one JSON document. The encoding is
# given explicitly so parsing does not depend on the platform's default
# locale encoding; json.load parses straight from the file handle.
with open(os.path.join(data_path, 'dev.json'), encoding='utf-8') as f:
  src = json.load(f)

## Build Graph

In [0]:
# Pick one WikiHop example and unpack its fields.
instance = 0
record = src[instance]

query = record.get('query')
# Support documents are lowercased before annotation; candidates are used as-is
# (NOTE(review): presumably already lowercase in WikiHop — confirm).
docs = [nlp(support.lower()) for support in record.get('supports')]
# Map each candidate string to its position in the candidate list.
cands = {cand: idx for idx, cand in enumerate(record.get('candidates'))}
answer = record.get('answer')

In [0]:
def build_entity_graph(cands, docs):
  ''' build the entity graph used in Entity-GCN

  Every occurrence of a candidate string in a support document becomes one
  node. Node ids are consecutive integers assigned in discovery order
  (documents in order, then candidates in iteration order, then matches
  left to right).

  Args:
    cands: candidates \\ iterable of str (a list, or a dict whose keys are the
           candidate names). Candidates are matched literally and
           case-sensitively against doc.text; empty strings are ignored.
    docs: spacy annotated documents carrying the neuralcoref extension
          (doc._.coref_clusters) \\ list[Doc]
  Returns:
    nodes: nodes of the graph \\ dict{id : candidate name}
    doc_based_edges: edges that connect mentions in the same document \\ set((node1, node2))
    match_edges: edges that connect mentions of the same candidate \\ set((node1, node2))
    coref_edges: edges that connect mentions whose candidate string is a
                 mention of the same coreference chain \\ set((node1, node2))
    compl_edges: edges that connect nodes which received no edge of any
                 other type \\ set((node1, node2))
    Every edge tuple is ordered so that node1 < node2.
  '''
  nodes = {}
  with_edges = set()  # ids of nodes that have at least one edge

  doc_based_edges = set()
  match_edges = set()
  coref_edges = set()
  compl_edges = set()

  # accumulate nodes, add the doc_based & coreference edges
  for doc in docs:
    coref_clusters = [[m.text for m in c.mentions] for c in doc._.coref_clusters]
    text = doc.text
    doc_nodes = set()
    # one independent set per cluster: `[set()] * k` would alias a single set
    # and merge every coreference chain of the document into one
    cluster_nodes = [set() for _ in coref_clusters]
    for cand in cands:
      if not cand:
        continue  # an empty pattern would "match" at every position
      # escape the candidate so '.', '+', '(' etc. are literal; the lookarounds
      # act like \b but also work for candidates that start/end with a
      # non-word character (e.g. "c++")
      pattern = r'(?<!\w)%s(?!\w)' % re.escape(cand)
      matches = re.findall(pattern, text)
      new_nodes = range(len(nodes), len(nodes) + len(matches))
      doc_nodes.update(new_nodes)
      nodes.update(dict.fromkeys(new_nodes, cand))

      for i, cluster in enumerate(coref_clusters):
        if cand in cluster:  # exact string match against mention texts
          cluster_nodes[i].update(new_nodes)

    for pair in itertools.combinations(sorted(doc_nodes), 2):
      doc_based_edges.add(pair)
      with_edges.update(pair)
    for members in cluster_nodes:
      for pair in itertools.combinations(sorted(members), 2):
        coref_edges.add(pair)
        with_edges.update(pair)

  # add exact match edges: group node ids by candidate name (ids are appended
  # in increasing order) and connect every pair inside each group
  nodes_by_cand = {}
  for node_id, cand in nodes.items():
    nodes_by_cand.setdefault(cand, []).append(node_id)
  for same_cand in nodes_by_cand.values():
    for pair in itertools.combinations(same_cand, 2):
      match_edges.add(pair)
      with_edges.update(pair)

  # add complement edges
  # NOTE(review): this only connects nodes that have no edge at all. The
  # Entity-GCN paper describes COMPLEMENT edges as every pair of nodes not
  # joined by another relation (making the graph complete) — confirm intent.
  isolated_nodes = set(nodes) - with_edges
  for pair in itertools.combinations(sorted(isolated_nodes), 2):
    compl_edges.add(pair)

  return nodes, doc_based_edges, match_edges, coref_edges, compl_edges

In [0]:
# Build the entity graph for the selected instance and unpack its node map
# and the four edge sets.
(nodes, doc_based_edges, match_edges,
 coref_edges, compl_edges) = build_entity_graph(cands, docs)