<a href="https://colab.research.google.com/github/yg-li/MHQA-with-LP/blob/master/Entity_GCN_GCP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook applies Entity-GCN, a relational graph convolutional network (R-GCN) run over a graph of entity mentions, to multi-hop question answering on QAngaroo WikiHop.

In [0]:
import os
import json

# read in QAngaroo WikiHop
wh_data_path='./wikihop'

def load_wikihop_split(split, data_dir=wh_data_path):
  '''Load and return the parsed JSON of one WikiHop split file (`<split>.json`).

  The file is read as UTF-8 explicitly: WikiHop supports contain non-ASCII
  text, and the platform default encoding is not UTF-8 everywhere.
  '''
  with open(os.path.join(data_dir, split + '.json'), encoding='utf-8') as f:
    return json.load(f)

train_src = load_wikihop_split('train')
dev_src = load_wikihop_split('dev')

In [0]:
# Paths needed by torch-geometric
!export PATH=/usr/local/cuda/bin:$PATH
!export CPATH=/usr/local/cuda/include:$CPATH
!export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH

In [0]:
# # apex for mixed precision training
# ! (if ! [ "$(pip freeze | grep apex)" ]; then \
#      git clone https://github.com/NVIDIA/apex; \
#      pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex; \
#    fi)

In [0]:
# # neuralcoref works only with spacy<=2.1.3
# ! (if [ "$(pip freeze | grep spacy | cut -d'=' -f 3)" != "2.1.3" ]; then \
#      pip uninstall -y spacy; \
#      pip install spacy==2.1.3; \
#    fi)
# !pip install neuralcoref
# !pip install allennlp

# !pip install --no-cache-dir torch-scatter torch-sparse torch-cluster
# !pip install torch-geometric

In [0]:
import itertools
import random
from datetime import datetime

import spacy
from spacy.matcher import PhraseMatcher
import neuralcoref
from allennlp.commands.elmo import ElmoEmbedder

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam

from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv

from apex import amp

In [0]:
# spaCy English pipeline with neuralcoref attached (provides `doc._.coref_clusters`).
nlp = spacy.load("en_core_web_sm")
neuralcoref.add_to_pipe(nlp)

# Use the first GPU when present, otherwise the CPU, for both torch and ELMo.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_available else 'cpu')

# ElmoEmbedder takes a CUDA device index; -1 selects the CPU.
elmo = ElmoEmbedder(cuda_device=0 if _cuda_available else -1)

# Models

In [0]:
class QueryEncoder(nn.Module):
  '''Encode a query's ELMo features into a single (1, 256) vector.

  Two stacked bidirectional LSTMs. The first (3072 -> 2x256) starts from the
  learned initial states `h_0`/`c_0`. Its final hidden and cell states are
  mapped to 128 dims (Linear + ReLU) and seed the second LSTM (512 -> 2x128).
  The second LSTM's final hidden states of both directions are concatenated
  into the query embedding.
  '''
  def __init__(self, dropout=0):
    super().__init__()
    # Creation order is kept fixed so random initialisation, state_dict keys
    # and parameter order (used by saved optimizer state) do not change.
    self.dropout = nn.Dropout(p=dropout)
    self.lstm1 = nn.LSTM(3072, 256, batch_first=True, bidirectional=True)
    self.lstm2 = nn.LSTM(512, 128, batch_first=True, bidirectional=True)
    # learned initial states: (num_directions=2, batch=1, hidden=256)
    self.h_0 = nn.Parameter(torch.rand((2, 1, 256)))
    self.c_0 = nn.Parameter(torch.rand((2, 1, 256)))
    self.hidden_map = nn.Linear(256, 128)
    self.cell_map = nn.Linear(256, 128)

  def forward(self, x):
    # The initial states have batch dimension 1, so exactly one query
    # (x of shape (1, seq_len, 3072), batch_first) is encoded per call.
    seq_out, (h_first, c_first) = self.lstm1(x, (self.h_0, self.c_0))
    seq_out = self.dropout(seq_out)
    h_init = F.relu(self.hidden_map(h_first))
    c_init = F.relu(self.cell_map(c_first))
    _, (h_last, _) = self.lstm2(seq_out, (h_init, c_init))
    # (2, 1, 128) -> (1, 256): forward and backward final states side by side
    return self.dropout(h_last.reshape(1, -1))
  
  
class CandidateEncoder(nn.Module):
  '''Encode one mention's pooled ELMo vector conditioned on the query.

  The 3072-d mention vector is projected to 256 dims and placed after the
  (1, 256) query embedding. The 512-d concatenation then goes through two
  more ReLU layers, giving a 512-d node representation.
  '''
  def __init__(self, dropout=0):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.linear1 = nn.Linear(3072, 256)
    # the remaining layers act on [query; mention] (256 + 256 = 512 inputs)
    self.linear2 = nn.Linear(512, 1024)
    self.linear3 = nn.Linear(1024, 512)

  def forward(self, x, q):
    def dense(layer, inp):
      return self.dropout(F.relu(layer(inp)))

    mention = dense(self.linear1, x)
    joint = torch.cat((q, mention), dim=-1)
    return dense(self.linear3, dense(self.linear2, joint))
  
  
class PyG_RGCN(nn.Module):
  '''Gated R-GCN: `L` rounds of one shared RGCNConv (512 -> 512, 4 relations).

  After each round a per-node scalar gate (sigmoid of a linear map of
  [update; state]) mixes tanh(update) with the previous node state.
  '''
  def __init__(self, dropout=0, L=3):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.L = L
    # a single convolution reused by every round (weights are shared)
    self.conv = RGCNConv(512, 512, num_relations=4, num_bases=4)
    self.gating = nn.Linear(1024, 1)

  def forward(self, x, edge_index, edge_type):
    state = x
    for _round in range(self.L):
      update = self.conv(state, edge_index, edge_type)
      gate = torch.sigmoid(self.gating(torch.cat((update, state), dim=-1)))
      state = self.dropout(torch.tanh(update) * gate + state * (1 - gate))
    return state

class OutputLayer(nn.Module):
  '''Score every node of one graph given the query.

  Input per node is [query; node state] (256 + 512 = 768). It passes through
  256 -> 128 ReLU layers to a single logit, giving (num_nodes, 1) output.
  '''
  def __init__(self, dropout=0):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    self.linear1 = nn.Linear(768, 256)
    self.linear2 = nn.Linear(256, 128)
    self.linear3 = nn.Linear(128, 1)

  def forward(self, x, q):
    # One graph per call (graphs have different node counts): the single
    # query row is broadcast to every node.
    hidden = torch.cat((q.expand(x.shape[0], -1), x), dim=-1)
    for layer in (self.linear1, self.linear2):
      hidden = self.dropout(F.relu(layer(hidden)))
    return self.linear3(hidden)

# Build Graph

## Extract nodes and edges, and encode mentions with ELMo

In [0]:
def _elmo_layers_to_tensor(layers):
  '''Turn one ELMo output into a (1, num_tokens, 3 * 1024) tensor on `device`.

  `ElmoEmbedder.embed_sentence(s)` yields arrays of shape
  (num_layers=3, num_tokens, 1024). Each token's three layer vectors must be
  concatenated, so the layer axis is moved behind the token axis before
  flattening. A plain `.reshape(1, -1, 3072)` on the original layout would
  instead pack three consecutive tokens of a single layer into each row. The
  `[:, start:end, :]` mention slices would then pick the wrong tokens.
  '''
  t = torch.as_tensor(layers, device=device)
  num_layers, num_tokens, dim = t.shape
  return t.permute(1, 0, 2).reshape(1, num_tokens, num_layers * dim)


def _add_clique(members, edge_set, touched):
  '''Add every pair of `members` to `edge_set` as (smaller, larger) tuples and
  record the endpoints in `touched`.'''
  for pair in itertools.combinations(sorted(members), 2):
    edge_set.add(pair)
    touched.update(pair)


def extract_info(query, docs, whole_doc, cands, answer,
                 query_encoder, cand_encoder):
  ''' extract the information needed to build Entity-GCN's graph
  Args:
    query: the query, "<relation> <entity words...>"; the relation's words are
           joined by '_'
    docs: spacy annotated documents (with neuralcoref's `_.coref_clusters`)
    whole_doc: spacy annotated concatenated documents
    cands: candidates \\ dict{candidate text : candidate id}; not modified
    answer: the correct answer (compared verbatim with matched mention text)
    query_encoder: the encoder for query
    cand_encoder: the encoder for candidates given query embedding
  Returns:
    nodes: nodes of the graph \\ dict{id : candidate id (-1 if query entity)}
    node_types: one single-element list per node: [1] for the answer, [0] for
                other candidates, [-1] for the query entity
    node_embeddings: cand_encoder output per mention (ELMo mean-pooled over
                     the mention span, conditioned on the query embedding)
    query_embedding: the embedding of query
    edges: [doc_based_edges, match_edges, coref_edges, compl_edges], each a
           set((node1, node2)) with node1 < node2:
      doc_based_edges: mentions in the same document
      match_edges: mentions of the same candidate (exact match)
      coref_edges: mentions in the same within- or cross-document
                   coreference chain
      compl_edges: links the nodes that got no edge of the other three types
                   to each other.
                   NOTE(review): the Entity-GCN paper describes COMPLEMENT
                   edges between *any* two otherwise-unconnected nodes;
                   confirm which behaviour is intended.
  '''
  # extract the query entity; work on a copy so the caller's dict is untouched
  query_entity = ' '.join(query.split(' ')[1:])
  cands = dict(cands)
  cands[query_entity] = -1

  # matcher for candidates and query entity
  matcher = PhraseMatcher(nlp.vocab)
  patterns = [nlp.make_doc(cand) for cand in cands]
  matcher.add("CandList", None, *patterns)

  # get embedding of query, q (relation words are split on '_' for ELMo)
  query_tokens = query.split(' ')[0].split('_') + query.split(' ')[1:]
  query_embedding = query_encoder(
      _elmo_layers_to_tensor(elmo.embed_sentence(query_tokens)))
  # get elmo for all documents, token-aligned with the spacy docs
  docs_elmo = [_elmo_layers_to_tensor(layers)
               for layers in elmo.embed_sentences(
                   [[w.text for w in doc] for doc in docs])]

  nodes = {}
  node_types = []
  node_embeddings = []
  with_edges = set()

  # sets to store edges
  doc_based_edges = set()
  match_edges = set()
  coref_edges = set()
  compl_edges = set()

  # auxiliary variables for cross-document coreference.
  # One *independent* node set per chain: `[set()] * k` would alias a single
  # set k times and merge every chain into one clique.
  out_coref_clusters = [{m.text for m in c.mentions}
                        for c in whole_doc._.coref_clusters]
  out_coref_tmps = [set() for _ in out_coref_clusters]

  # accumulate nodes, add the doc_based & within-document coreference edges
  for doc_id, doc in enumerate(docs):
    in_coref_clusters = [{m.text for m in c.mentions}
                         for c in doc._.coref_clusters]
    in_coref_tmps = [set() for _ in in_coref_clusters]
    doc_tmp = set() # nodes in the same doc
    for _, start, end in matcher(doc):
      match = doc[start:end].text
      new_node = len(nodes)
      doc_tmp.add(new_node)
      nodes[new_node] = cands.get(match, -1)
      node_types.append([1 if match == answer
                         else -1 if match == query_entity
                         else 0])
      match_elmo = docs_elmo[doc_id][:, start:end, :].mean(dim=1) # mean pooling
      node_embeddings.append(cand_encoder(match_elmo, query_embedding))
      for cluster, members in zip(in_coref_clusters, in_coref_tmps):
        if match in cluster:
          members.add(new_node)
      for cluster, members in zip(out_coref_clusters, out_coref_tmps):
        if match in cluster:
          members.add(new_node)

    _add_clique(doc_tmp, doc_based_edges, with_edges)
    for members in in_coref_tmps:
      _add_clique(members, coref_edges, with_edges)

  # cross-document coref_edges
  for members in out_coref_tmps:
    _add_clique(members, coref_edges, with_edges)

  # exact match edges: group nodes by candidate id instead of testing all pairs
  nodes_by_cand = {}
  for node, cand_id in nodes.items():
    nodes_by_cand.setdefault(cand_id, []).append(node)
  for group in nodes_by_cand.values():
    _add_clique(group, match_edges, with_edges)

  # complement edges between nodes that have no other edge
  isolated_nodes = set(nodes) - with_edges
  compl_edges.update(itertools.combinations(sorted(isolated_nodes), 2))

  return nodes, node_types, node_embeddings, query_embedding, \
         [doc_based_edges, match_edges, coref_edges, compl_edges]

## PyG graph

In [0]:
def build_pyg_graph(nodes, node_types, node_embeddings, query_embedding, edges):
  '''Pack the output of `extract_info` into a torch_geometric `Data` object.

  Args:
    nodes: node id -> candidate id (kept for interface compatibility; unused)
    node_types: per-node single-element lists; a positive value marks the answer
    node_embeddings: list of per-node embeddings, each one row (e.g. (1, 512))
    query_embedding: query vector, stored on the graph as `query`
    edges: list of edge sets; the list position is used as the relation type
  Returns:
    Data with x (num_nodes, dim), y (num_nodes, 1) float targets, query,
    edge_index (2, E) and edge_type (E,). Every edge is stored in both
    directions so the graph is undirected.
  Raises:
    ValueError: if there are no nodes (no mention was matched).
  '''
  if not node_embeddings:
    raise ValueError('cannot build a graph without any mention nodes')
  # Concatenate rows and keep the node axis even for a single node
  # (a bare `.squeeze()` would turn a 1-node graph into a 1-D vector).
  x = torch.cat(node_embeddings, dim=0).reshape(len(node_embeddings), -1)
  labels = torch.tensor(node_types)
  y = (labels > 0).to(device, torch.float) # target

  index_chunks = []
  type_chunks = []
  for relation, pairs in enumerate(edges):
    if not pairs:
      continue
    # sorted() gives a deterministic edge order independent of set iteration
    forward = torch.tensor(sorted(pairs), dtype=torch.long, device=device)
    # add edges with swapped direction to make the graph undirected
    both = torch.cat((forward, forward.flip(1)), dim=0)
    index_chunks.append(both)
    type_chunks.append(torch.full((both.shape[0],), relation,
                                  dtype=torch.long, device=device))

  if index_chunks:
    edge_index = torch.cat(index_chunks, dim=0)
    edge_type = torch.cat(type_chunks, dim=0)
  else:
    edge_index = torch.zeros(0, 2, dtype=torch.long, device=device)
    edge_type = torch.zeros(0, dtype=torch.long, device=device)

  data = Data(x=x, query=query_embedding, y=y,
              edge_index=edge_index.t().contiguous(), edge_type=edge_type)
  return data

In [0]:
def build_graph(instance, query_encoder, cand_encoder, extract=False):
  '''Annotate one WikiHop instance with spaCy and build its Entity-GCN graph.

  Supports, candidates, the answer and the query are all lowercased and
  stripped the same way, so phrase matching and the answer label agree
  regardless of casing in the raw data.

  Args:
    instance: dict with 'query', 'supports', 'candidates' and (optionally)
              'answer'
    query_encoder, cand_encoder: encoders passed through to `extract_info`
    extract: if True, return the raw `extract_info` tuple instead of a graph
  Returns:
    a torch_geometric `Data` graph, or the `extract_info` tuple if `extract`.
  '''
  query = instance.get('query').lower().strip()
  supports = [text.lower() for text in instance.get('supports')]
  docs = [nlp(text) for text in supports]
  whole_doc = nlp(' '.join(supports))
  cands = {cand: i for i, cand in
           enumerate(c.lower().strip() for c in instance.get('candidates'))}
  # Normalise exactly like the candidates; otherwise an answer differing only
  # in case/whitespace never matches a node and the graph has no positive label.
  answer = (instance.get('answer') or '').lower().strip()

  # extract nodes, edges, and embeddings
  nodes, node_types, node_embeddings, query_embedding, edges = \
    extract_info(query, docs, whole_doc, cands, answer, query_encoder, cand_encoder)
  if extract:
    return nodes, node_types, node_embeddings, query_embedding, edges

  # build PyG graph
  return build_pyg_graph(nodes, node_types, node_embeddings, query_embedding, edges)

# Training & Testing

In [0]:
def save_models(epoch, step, loss_history, optimizer,
                query_encoder, cand_encoder, rgcn, output_layer, PATH):
  '''Write a training checkpoint to PATH with torch.save.

  The checkpoint is a dict with the epoch/step to resume from, the loss
  history so far, and the state_dict of each module and of the optimizer.
  '''
  modules = (
      ('query_encoder', query_encoder),
      ('cand_encoder', cand_encoder),
      ('rgcn', rgcn),
      ('output_layer', output_layer),
  )
  checkpoint = {'epoch': epoch, 'step': step, 'loss_history': loss_history}
  for key, module in modules:
    checkpoint[key] = module.state_dict()
  checkpoint['optimizer'] = optimizer.state_dict()
  torch.save(checkpoint, PATH)

In [0]:
def train(epochs, step, batch_size, optimizer, loss_fn, src, query_encoder,
          cand_encoder, rgcn, output_layer, PATH, loss_history=None, tol=1e-3):
  '''Train all four modules one instance (graph) at a time.

  Args:
    epochs: iterable of epoch indices to run, e.g. range(start, end)
    step: instance index to resume from in the first epoch (0 afterwards)
    batch_size: checkpoint interval in instances. There is no mini-batching:
                the optimizer steps after every instance.
    optimizer: amp-initialised optimizer over all module parameters
    loss_fn: called as loss_fn(logits (N, 1), targets (N, 1))
    src: list of WikiHop instances (shuffled in place each epoch)
    query_encoder, cand_encoder, rgcn, output_layer: the model parts
    PATH: checkpoint file written by save_models
    loss_history: list of per-instance losses to extend (new list if None)
    tol: after each epoch, stop early when the loss 10 instances ago minus
         the latest loss is below tol (needs at least 10 recorded losses)
  Returns:
    query_encoder, cand_encoder, rgcn, output_layer
  '''
  # avoid the shared mutable default: each call gets its own history list
  if loss_history is None:
    loss_history = []
  query_encoder.train()
  cand_encoder.train()
  rgcn.train()
  output_layer.train()
  for e in epochs:
    # NOTE(review): a new permutation is drawn every epoch, so resuming at
    # `step` continues in a different order than the interrupted run.
    random.shuffle(src)
    for i in range(step, len(src)):
      try:
        optimizer.zero_grad()
        g = build_graph(src[i], query_encoder, cand_encoder)
        # TODO: link prediction
        out = rgcn(g.x, g.edge_index, g.edge_type)
        pred = output_layer(out, g.query)
        loss = loss_fn(pred, g.y)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
          scaled_loss.backward()
        optimizer.step()
        print('Epoch: {:2d}  [{:d}/{:d}]\tloss: {:.4f}\t{}'.format(
            e, i+1, len(src), loss.item(), datetime.now()), flush=True)
        loss_history.append(loss.item())
        del loss
      except Exception as err:
        # Skip instances that cannot be processed (e.g. no mentions, CUDA OOM)
        # but let KeyboardInterrupt/SystemExit through and report the cause.
        print('Fail graph with id: {} ({!r})'.format(src[i].get('id'), err),
              flush=True)
        continue
      if (i != 0 and i % batch_size == 0) or i == len(src)-1:
        # periodic checkpoint; resume from the next instance
        save_models(e, i+1, loss_history, optimizer,
            query_encoder, cand_encoder, rgcn, output_layer,
            PATH)
        print('Model saved', flush=True)

    # end of epoch
    save_models(e+1, 0, loss_history, optimizer,
        query_encoder, cand_encoder, rgcn, output_layer,
        PATH)
    if len(loss_history) >= 10 and loss_history[-10] - loss_history[-1] < tol:
      return query_encoder, cand_encoder, rgcn, output_layer
    step = 0
  return query_encoder, cand_encoder, rgcn, output_layer

In [0]:
def test(loss_fn, src, query_encoder, cand_encoder, rgcn, output_layer,
         log_every=32):
  '''Evaluate on `src` without gradients.

  An instance counts as correct when its highest-scoring node is labelled as
  the answer. Instances whose graph cannot be processed are skipped and left
  out of the accuracy denominator.

  Args:
    loss_fn: called as loss_fn(logits (N, 1), targets (N, 1))
    src: list of WikiHop instances
    query_encoder, cand_encoder, rgcn, output_layer: the model parts
    log_every: print progress every `log_every` instances
  Returns:
    (accuracy over processed instances, list of per-instance losses)
  Raises:
    RuntimeError: if no instance could be evaluated.
  '''
  query_encoder.eval()
  cand_encoder.eval()
  rgcn.eval()
  output_layer.eval()

  num_processed_graphs = 0
  loss_history = []
  acc = 0.
  with torch.no_grad():
    for i, instance in enumerate(src):
      try:
        g = build_graph(instance, query_encoder, cand_encoder)
        out = rgcn(g.x, g.edge_index, g.edge_type)
        # OutputLayer.forward takes (node states, query); graphs carry no node_mask
        pred = output_layer(out, g.query)

        loss_history.append(loss_fn(pred, g.y).item())
        acc += (g.y[pred.argmax(), 0] == 1).item()
        num_processed_graphs += 1
      except Exception as err:
        print('Fail graph with id: {} ({!r})'.format(instance.get('id'), err),
              flush=True)
        continue
      if i % log_every == 0:
        # running accuracy over the instances processed so far
        print('[{:d}/{:d}]\tloss: {:.4f}\tacc: {:.4f}\t{}'.format(
            num_processed_graphs, len(src), loss_history[-1],
            acc / num_processed_graphs, datetime.now()), flush=True)

  if num_processed_graphs == 0:
    raise RuntimeError('no instance in src could be evaluated')
  return acc/num_processed_graphs, loss_history

In [0]:
## Training
# parameters
SEED = 0
NUM_EPOCHS = 4  # last epoch (exclusive) for both fresh and resumed runs
step = 0
batch_size = 32  # checkpoint every `batch_size` instances (updates are per instance)
L = 3 # number of R-GCN layers
lr = 1e-5
dropout = 0
save_path='./entity_gcn.tar'
epochs = range(NUM_EPOCHS)

# reproducible parameter initialisation and epoch shuffling
random.seed(SEED)
torch.manual_seed(SEED)

# models
query_encoder = QueryEncoder(dropout=dropout).to(device)
cand_encoder = CandidateEncoder(dropout=dropout).to(device)
rgcn = PyG_RGCN(dropout=dropout, L=L).to(device)
output_layer = OutputLayer(dropout=dropout).to(device)

optimizer = Adam(
    itertools.chain(
        query_encoder.parameters(),
        cand_encoder.parameters(),
        rgcn.parameters(),
        output_layer.parameters()),
    lr=lr,
    eps=1e-4)
loss_fn = nn.BCEWithLogitsLoss()
loss_history = []

# Apex requires amp.initialize to run before model/optimizer state is restored.
[query_encoder, cand_encoder, rgcn, output_layer], optimizer = \
  amp.initialize([query_encoder, cand_encoder, rgcn, output_layer], optimizer, opt_level='O1')

# load checkpoint (resume from the saved epoch/step)
if os.path.isfile(save_path):
  checkpoint = torch.load(save_path, map_location=device)
  epochs = range(checkpoint['epoch'], NUM_EPOCHS)
  step = checkpoint['step']
  loss_history = checkpoint['loss_history']
  query_encoder.load_state_dict(checkpoint['query_encoder'])
  cand_encoder.load_state_dict(checkpoint['cand_encoder'])
  rgcn.load_state_dict(checkpoint['rgcn'])
  output_layer.load_state_dict(checkpoint['output_layer'])
  optimizer.load_state_dict(checkpoint['optimizer'])

query_encoder, cand_encoder, rgcn, output_layer = train(
    epochs, step, batch_size, optimizer, loss_fn, train_src, query_encoder,
    cand_encoder, rgcn, output_layer, save_path, loss_history=loss_history)

In [0]:
## Testing

L = 3 # number of R-GCN layers
dropout = 0
save_path='./entity_gcn.tar'

query_encoder = QueryEncoder(dropout=dropout).to(device)
cand_encoder = CandidateEncoder(dropout=dropout).to(device)
rgcn = PyG_RGCN(dropout=dropout, L=L).to(device)
output_layer = OutputLayer(dropout=dropout).to(device)

loss_fn = nn.BCEWithLogitsLoss()

# map_location lets a GPU-trained checkpoint load on whatever `device` is now
checkpoint = torch.load(save_path, map_location=device)
query_encoder.load_state_dict(checkpoint['query_encoder'])
cand_encoder.load_state_dict(checkpoint['cand_encoder'])
rgcn.load_state_dict(checkpoint['rgcn'])
output_layer.load_state_dict(checkpoint['output_layer'])

[query_encoder, cand_encoder, rgcn, output_layer] = \
  amp.initialize([query_encoder, cand_encoder, rgcn, output_layer], opt_level='O1')

acc, loss_history = test(loss_fn, dev_src,
                         query_encoder, cand_encoder, rgcn, output_layer)
acc