<a href="https://colab.research.google.com/github/yg-li/QA-KG-RL/blob/master/QA_KG_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Install the PyTorch Geometric extension packages and PyTorch Geometric itself.
!pip install --no-cache-dir torch-scatter torch-sparse torch-cluster
!pip install torch-geometric
# Re-install spaCy only when the installed version is not exactly 2.1.3
# (presumably pinned for neuralcoref compatibility -- TODO confirm).
# NOTE(review): none of the other packages are version-pinned, so a later
# install may pull in a different spaCy version -- verify after running.
! (if [ "$(pip freeze | grep spacy | cut -d'=' -f 3)" != "2.1.3" ]; then \
     pip uninstall -y spacy; \
     pip install spacy==2.1.3; \
   fi)
# Coreference resolution (neuralcoref) and ELMo embeddings (allennlp).
!pip install neuralcoref
!pip install allennlp

In [0]:
# Standard library
import os
import json
import itertools
import re
from datetime import datetime

# spaCy pipeline; neuralcoref is added to it so that documents expose
# `doc._.coref_clusters` (read by build_entity_graph).
import spacy
from spacy.matcher import PhraseMatcher
nlp = spacy.load("en_core_web_sm")
import neuralcoref
neuralcoref.add_to_pipe(nlp)

# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam

# PyTorch Geometric
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.nn import RGCNConv

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ELMo embedder; cuda_device=0 selects the first GPU, -1 selects the CPU.
from allennlp.commands.elmo import ElmoEmbedder
elmo = ElmoEmbedder(cuda_device=0 if torch.cuda.is_available() else -1)

In [0]:
# Make Google Drive visible under /gdrive (only works inside Colab).
from google.colab import drive
drive.mount('/gdrive')

# Load the QAngaroo WikiHop training split into `src`.
wh_data_path = '/gdrive/My Drive/Colab Notebooks/CSML/Project/data/qangaroo_v1.1/wikihop'
with open(os.path.join(wh_data_path, 'train.json')) as f:
  src = json.load(f)

# Alternative: WikiHop development split.
# with open(os.path.join(wh_data_path, 'dev.json')) as f:
#   src = json.load(f)

# Alternative: HotpotQA training split.
# hpqa_data_path = '/gdrive/My Drive/Colab Notebooks/CSML/Project/data/hotpotqa'
# with open(os.path.join(hpqa_data_path, 'hotpot_train_v1.1.json')) as f:
#   src = json.load(f)

# Encoders and Output Layer

In [0]:
class QueryEncoder(nn.Module):
  ''' Encode a query (sequence of 3072-d ELMo token vectors) into one 256-d vector.

  Two stacked bidirectional LSTMs are used. The final hidden and cell states
  of the first LSTM are projected from 256 to 128 dimensions and used to
  initialise the second LSTM, whose final hidden states (forward and
  backward) are concatenated to form the query embedding.

  Args:
    dropout: dropout probability applied after each stage (default 0)
  '''
  def __init__(self, dropout=0):
    super(QueryEncoder, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.lstm1 = nn.LSTM(3072, 256, batch_first=True, bidirectional=True)
    self.lstm2 = nn.LSTM(512, 128, batch_first=True, bidirectional=True)
    self.hidden_map = nn.Linear(256, 128)
    self.cell_map = nn.Linear(256, 128)

  def forward(self, x):
    ''' Encode a batch of queries.
    Args:
      x: tensor of shape (batch, n_tokens, 3072)
    Returns:
      tensor of shape (1, batch, 256); row b is the concatenation of the
      forward and backward final hidden states of query b
    '''
    x, (h_n, c_n) = self.lstm1(x)
    x = self.dropout(x)
    # h_n / c_n are (num_directions, batch, 256); project them to the
    # hidden size of the second LSTM
    h_n = self.dropout(F.relu(self.hidden_map(h_n)))
    c_n = self.dropout(F.relu(self.cell_map(c_n)))
    x, (q, c_n) = self.lstm2(x, (h_n, c_n))
    # q is (num_directions, batch, 128). Put the batch axis first before
    # flattening so that each output row joins the two directions of the
    # SAME query; a plain reshape would interleave different queries as soon
    # as batch > 1 (for batch == 1 the result is unchanged).
    q = q.transpose(0, 1).reshape(1, -1, 2 * self.lstm2.hidden_size)
    q = self.dropout(q)
    return q
  
class CandidateEncoder(nn.Module):
  ''' Encode candidate mentions conditioned on the query embedding.

  Each 3072-d mention vector is projected to 256 dimensions, concatenated
  with the 256-d query embedding and passed through two feed-forward layers,
  giving a 512-d query-aware mention representation.

  Args:
    dropout: dropout probability applied after each layer (default 0)
  '''
  def __init__(self, dropout=0):
    super(CandidateEncoder, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.linear1 = nn.Linear(3072, 256)
    # the following FF layers applied to concat of query and cand.
    self.linear2 = nn.Linear(512, 1024)
    self.linear3 = nn.Linear(1024, 512)

  def forward(self, x, q):
    ''' Encode mentions given the query.
    Args:
      x: mention features of shape (..., 3072)
      q: query embedding of shape (..., 256) whose leading dimensions are
         either equal to those of x or of size 1 (q must not have more
         dimensions than x)
    Returns:
      tensor of shape (..., 512) with the leading dimensions of x
    '''
    x = self.dropout(F.relu(self.linear1(x)))
    # repeat the query along the mention axes so that one query embedding
    # can be paired with any number of mentions
    q = q.expand(*x.shape[:-1], q.shape[-1])
    x = torch.cat((q, x), dim=-1)
    x = self.dropout(F.relu(self.linear2(x)))
    x = self.dropout(F.relu(self.linear3(x)))
    return x
  
class OutputLayer(nn.Module):
  ''' Score every node against the query and normalise over the nodes.

  The 256-d query embedding is concatenated with each 512-d node
  representation and reduced to a single score per node by three linear
  layers.

  NOTE(review): there is no non-linearity between the linear layers, so
  apart from dropout they collapse into one linear map -- confirm whether
  activations were intended.

  Args:
    dropout: dropout probability applied after the first two layers (default 0)
  '''
  def __init__(self, dropout=0):
    super(OutputLayer, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.linear1 = nn.Linear(768, 256)
    self.linear2 = nn.Linear(256, 128)
    self.linear3 = nn.Linear(128, 1)

  def forward(self, x, q):
    ''' Compute log-probabilities over the nodes.
    Args:
      x: node representations of shape (..., n_nodes, 512)
      q: query embedding of shape (..., 256) whose leading dimensions are
         either equal to those of x or of size 1 (q must not have more
         dimensions than x)
    Returns:
      tensor of shape (..., n_nodes, 1) holding log-probabilities that sum
      to one (after exp) along the n_nodes axis
    '''
    # batch_size is 1 as instances have different number of candidates
    # repeat the query so that it can be concatenated with every node
    q = q.expand(*x.shape[:-1], q.shape[-1])
    x = torch.cat((q, x), dim=-1)
    x = self.dropout(self.linear1(x))
    x = self.dropout(self.linear2(x))
    x = self.linear3(x)
    # normalise across the nodes (second to last axis); the last axis has
    # size 1, so a softmax over it would always return log(1) = 0
    a = F.log_softmax(x, dim=-2)
    return a

# Build Graph

## Extract nodes and edges & Encode mentions with ELMo

In [0]:
# Smoke test: build the entity graph of a single training instance.
# NOTE: build_entity_graph is defined in a later cell of this notebook; that
# cell has to be executed first, otherwise this cell raises a NameError on a
# fresh kernel.
instance = 0
query = src[instance].get('query')
supports = [text.lower() for text in src[instance].get('supports')]
docs = [nlp(text) for text in supports]
whole_doc = nlp(' '.join(supports))
# map every lower-cased candidate to its index in the candidate list
cands = dict([(v, i) for i, v in
              enumerate([cand.lower() for cand in src[instance].get('candidates')])])
# supports and candidates are lower-cased, so the answer has to be as well,
# otherwise no mention can ever be recognised as the answer
# (instances without an answer get the empty string)
answer = (src[instance].get('answer') or '').lower()

query_encoder = QueryEncoder().to(device)
cand_encoder = CandidateEncoder().to(device)

nodes, node_types, node_embeddings, query_embedding, edges = \
        build_entity_graph(query, docs, whole_doc, cands, answer, query_encoder, cand_encoder)

In [0]:
def _elmo_token_features(tokens):
  ''' Embed one tokenised sentence with ELMo, one 3072-d vector per token.
  Args:
    tokens: list of token strings
  Returns:
    tensor of shape (1, len(tokens), 3072) on `device`
  '''
  # embed_sentence returns an array of shape (3, n_tokens, 1024): one
  # 1024-d vector per ELMo layer and token (see the AllenNLP documentation)
  layers = elmo.embed_sentence(tokens)
  # bring the token axis to the front BEFORE flattening so that row t holds
  # the three layer vectors of token t; reshaping the (3, n, 1024) array
  # directly would mix vectors that belong to different tokens
  per_token = layers.transpose(1, 0, 2).reshape(1, -1, 3072)
  return torch.tensor(per_token).to(device)


def _sorted_pairs(node_ids):
  ''' Yield all unordered pairs of node ids as (smaller, larger) tuples. '''
  return itertools.combinations(sorted(node_ids), 2)


def build_entity_graph(query, docs, whole_doc, cands, answer,
                       query_encoder, cand_encoder):
  ''' build the entity graph used in Entity-GCN
  Args:
    query: the query; everything after the first space-separated token is
           taken as the query entity
    docs: spacy annotated documents
    whole_doc: spacy annotated concatenated documents
    cands: candidates \\ dict{candidate text : candidate id}, not modified
    answer: the correct answer
    query_encoder: the encoder for query
    cand_encoder: the encoder for candidates given query embedding
  Returns:
    nodes: nodes of the graph \\ dict{id : candidate id (-1 if query entity)}
    node_types: one single-element list per node;
                1 for answer, 0 for other candidates, -1 for query entity
    node_embeddings: list with the contextualized embedding (output of
                     cand_encoder) of every mention, in node id order
    query_embedding: the embedding of query
    edges: list of four sets of (node1, node2) pairs, in this order:
      doc_based_edges: edges that connect mentions in the same document
      match_edges: edges that connect exact match
      coref_edges: edges that connect mentions in the same coreference chain
      compl_edges: edges that connect the nodes that have not been connected
                   by any other types of edges
  '''
  # work on a copy so that the query entity does not leak into the
  # candidate dict owned by the caller
  cands = dict(cands)

  # extract the query entity
  # NOTE(review): mentions are compared with exact string equality, so the
  # query entity only matches if its casing agrees with the documents --
  # confirm that queries are already lower-cased like the supports
  query_entity = ' '.join(query.split(' ')[1:])
  cands[query_entity] = -1

  # matcher for candidates and query entity
  matcher = PhraseMatcher(nlp.vocab)
  patterns = [nlp.make_doc(cand) for cand in cands]
  matcher.add("CandList", None, *patterns)

  # get embedding of query, q
  query_embedding = query_encoder(_elmo_token_features(query.split(' ')))
  # get elmo for all documents (token positions follow the spacy tokens)
  docs_elmo = [_elmo_token_features([w.text for w in doc]) for doc in docs]

  nodes = {}
  node_types = []
  node_embeddings = []
  with_edges = set() # nodes that are an end point of at least one edge

  # sets to store edges
  doc_based_edges = set()
  match_edges = set()
  coref_edges = set()
  compl_edges = set()

  # auxiliary variables for cross-document coreference
  out_coref_clusters = [[m.text for m in c.mentions]
                        for c in (whole_doc._.coref_clusters or [])]
  # one independent set per chain; `[set()] * n` would repeat ONE shared set
  # and thereby merge all coreference chains into a single one
  out_coref_tmps = [set() for _ in out_coref_clusters]

  # accumulate nodes, add the doc_based & coreference edges
  for doc_id, doc in enumerate(docs):
    in_coref_clusters = [[m.text for m in c.mentions]
                         for c in (doc._.coref_clusters or [])]
    in_coref_tmps = [set() for _ in in_coref_clusters] # nodes in same coref chain
    doc_tmp = set() # nodes in the same doc
    for _, start, end in matcher(doc):
      match = doc[start:end].text
      new_node = len(nodes)
      doc_tmp.add(new_node)
      nodes[new_node] = cands[match]
      node_types.append([1 if match == answer
                         else -1 if match == query_entity
                         else 0])
      # a mention is represented by the mean ELMo vector of its tokens
      match_elmo = docs_elmo[doc_id][:, start:end, :].mean(dim=1, keepdim=True)
      node_embeddings.append(cand_encoder(match_elmo, query_embedding))
      # a mention belongs to a chain if its text equals a chain mention
      for i, cluster in enumerate(in_coref_clusters):
        if match in cluster:
          in_coref_tmps[i].add(new_node)
      for i, cluster in enumerate(out_coref_clusters):
        if match in cluster:
          out_coref_tmps[i].add(new_node)

    for pair in _sorted_pairs(doc_tmp):
      doc_based_edges.add(pair) # doc_based edges
      with_edges.update(pair)
    for coref_tmp in in_coref_tmps:
      for pair in _sorted_pairs(coref_tmp):
        coref_edges.add(pair) # within-document coref_edges
        with_edges.update(pair)

  # cross-document coref_edges
  for coref_tmp in out_coref_tmps:
    for pair in _sorted_pairs(coref_tmp):
      coref_edges.add(pair) # cross-document coref_edges
      with_edges.update(pair)

  # add exact match edges (nodes that refer to the same candidate id)
  for i, j in itertools.combinations(nodes, 2):
    if nodes[i] == nodes[j]:
      match_edges.add((i,j))
      with_edges.update((i,j))

  # add complement edges
  # NOTE(review): only nodes without ANY edge are linked to each other here;
  # confirm whether every PAIR of nodes that shares no other edge should be
  # linked instead
  isolated_nodes = set(nodes) - with_edges
  for pair in _sorted_pairs(isolated_nodes):
    compl_edges.add(pair)

  return nodes, node_types, node_embeddings, query_embedding, \
         [doc_based_edges, match_edges, coref_edges, compl_edges]

## Build PyG graph

In [0]:
# build graphs from train set of WikiHop
query_encoder = QueryEncoder().to(device)
cand_encoder = CandidateEncoder().to(device)

# number of instances to convert; raise to len(src) for the full data set
NUM_INSTANCES = 10

graphs = []
for instance in range(min(NUM_INSTANCES, len(src))):
  print('i:', instance, '|', datetime.now(), flush=True)
  query = src[instance].get('query')
  supports = [text.lower() for text in src[instance].get('supports')]
  docs = [nlp(text) for text in supports]
  whole_doc = nlp(' '.join(supports))
  # map every lower-cased candidate to its index in the candidate list
  cands = dict([(v, i) for i, v in
                enumerate([cand.lower() for cand in src[instance].get('candidates')])])
  # supports and candidates are lower-cased, so the answer has to be as well,
  # otherwise no node can ever be labelled as the answer
  # (instances without an answer get the empty string)
  answer = (src[instance].get('answer') or '').lower()

  nodes, node_types, node_embeddings, query_embedding, edges = \
    build_entity_graph(query, docs, whole_doc, cands, answer, query_encoder, cand_encoder)

  # build graph in PyG
  # nodes: stack the per-mention embeddings into one (n_nodes, dim) matrix
  if node_embeddings:
    x = torch.cat([emb.reshape(1, -1) for emb in node_embeddings], dim=0)
  else:
    # no candidate or query mention was found in the supports
    x = torch.zeros((0, cand_encoder.linear3.out_features), device=device)
  node_name = list(nodes.values())
  type_tensor = torch.tensor(node_types)
  node_mask = (type_tensor >= 0).to(torch.float) # whether the node is in candidate list
  y = (type_tensor > 0).to(torch.float) # target

  # edges: integer (long) tensors, as edge_index / edge_type are used as indices
  edge_index_parts = []
  edge_type_parts = []
  for rel_id, rel_edges in enumerate(edges):
    if not rel_edges:
      continue
    pairs = torch.tensor(sorted(rel_edges), dtype=torch.long) # (n_edges, 2)
    # add edges with swapped direction to make the graph undirected
    both_ways = torch.cat((pairs, pairs.flip(1)), 0)
    edge_index_parts.append(both_ways)
    edge_type_parts.append(
        torch.full((both_ways.shape[0],), rel_id, dtype=torch.long))
  if edge_index_parts:
    edge_index = torch.cat(edge_index_parts, 0)
    edge_type = torch.cat(edge_type_parts, 0)
  else:
    edge_index = torch.zeros((0, 2), dtype=torch.long)
    edge_type = torch.zeros(0, dtype=torch.long)
  # edge_type holds the position of the relation in `edges`, so the number
  # of relations is the number of edge sets even when some of them are empty
  num_relations = len(edges)

  data = Data(x=x, node_name=node_name, node_mask=node_mask,
              edge_index=edge_index.t().contiguous(), edge_type=edge_type,
              num_relations=num_relations, y=y)
  graphs.append(data)

In [0]:
# torch.save(graphs, '/gdrive/My Drive/Colab Notebooks/CSML/Project/data/qangaroo_train.pt')
# test = torch.load('/gdrive/My Drive/Colab Notebooks/CSML/Project/data/test.pt')

# RGCN

In [0]:
class Net(nn.Module):
  ''' Relational GCN over the entity graph (work in progress).

  Only the convolution is created so far; no forward pass is defined yet.

  Args:
    in_channels: size of the input node features; when None it is read from
                 the module-level graph `data` (original behaviour)
    out_channels: size of the output node features (default 512)
    num_relations: number of edge types; when None it is read from the
                   module-level graph `data` (original behaviour)
    num_bases: number of bases of the basis decomposition (default 30)
  '''
  def __init__(self, in_channels=None, out_channels=512, num_relations=None,
               num_bases=30):
    super(Net, self).__init__()
    # fall back to the graph built last in the notebook so that a plain
    # `Net()` keeps working; pass the sizes explicitly to avoid depending on
    # that leftover global
    if in_channels is None:
      in_channels = data.num_node_features
    if num_relations is None:
      num_relations = data.num_relations
    # three layers are sharing weights, hence a single convolution module
    self.conv = RGCNConv(in_channels, out_channels,
                         num_relations, num_bases=num_bases)

In [0]:
def train():
  ''' Train the model (not implemented yet).

  Planned procedure:
    for some epochs
      for each batch
        build graphs using encoders (and ELMO)
        run RGCN on graphs
        output result and backward loss

  Raises:
    NotImplementedError: always, until the training loop is written
  '''
  # a body consisting of comments only is a syntax error, hence the explicit
  # placeholder
  raise NotImplementedError('train() is not implemented yet')