# Commonsense QA: Graph Attention Based Reasoning

EECS 595 Final Project, Task 1: Commonsense QA

* Team ID: 2
* Credit: Ziqiao Ma
* Last update: 2020.12.16

# Setup

## Colab setup

Run this cell to load the autoreload extension.

In [1]:
%load_ext autoreload
%autoreload 2

Run the following cell to mount your Google Drive.

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Fill in the Google Drive path where you uploaded the file.

In [3]:
GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = 'Colab Notebooks/eecs595/commonsense_qa'

Check that the files can be found.

In [4]:
import os
import sys

GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
sys.path.append(GOOGLE_DRIVE_PATH)
print(os.listdir(GOOGLE_DRIVE_PATH))

['roberta_fairseq.ipynb', 'csqa-baseline.ipynb', 'csqa-new.ipynb', 'csqa-graph-reasoning.ipynb']


## Dependency installation

Import libraries.

In [5]:
from __future__ import absolute_import

import argparse
import os
import re
import random
import sys
from io import open
import csv
import json
from collections import defaultdict

In [6]:
import logging
from tqdm import tqdm, trange
from itertools import cycle
from pprint import pprint

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO)

logger = logging.getLogger(__name__)

In [7]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler

Install `datasets`.

In [8]:
!pip install datasets
from datasets import load_dataset

Collecting datasets
  Downloading https://files.pythonhosted.org/packages/1a/38/0c24dce24767386123d528d27109024220db0e7a04467b658d587695241a/datasets-1.1.3-py3-none-any.whl (153kB)
Collecting pyarrow>=0.17.1
  Downloading https://files.pythonhosted.org/packages/d7/e1/27958a70848f8f7089bff8d6ebe42519daf01f976d28b481e1bfd52c8097/pyarrow-2.0.0-cp36-cp36m-manylinux2014_x86_64.whl (17.7MB)
Collecting xxhash
  Downloading https://files.pythonhosted.org/packages/f7/73/826b19f3594756cb1c6c23d2fbd8ca6a77a9cd3b650c9dec5acc85004c38/xxhash-2.0.0-cp36-cp36m-manylinux2010_x86_64.whl (242kB)
Installing collected packages: pyarrow, xxhash, datasets
  Found existing installation: pyarrow 0.14.1
    Uninstalling pyarrow-0.14.1:
      Successfully uninstalled pyarrow-0.14.1
Successfully installed datasets-1.

Install `sentencepiece`, which `XLNetTokenizer` requires.

In [9]:
!pip install sentencepiece
import sentencepiece

Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/e5/2d/6d4ca4bef9a67070fa1cac508606328329152b1df10bdf31fb6e4e727894/sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl (1.1MB)
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.94


Install `transformers`.

In [10]:
!pip install transformers

from transformers import (AdamW, get_linear_schedule_with_warmup,
                          XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer)

from transformers.models.xlnet.modeling_xlnet import XLNetLayer, XLNetPreTrainedModel
from transformers.modeling_utils import SequenceSummary

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/ed/db/98c3ea1a78190dac41c0127a063abf92bd01b4b0b6970a6db1c2f5b66fa0/transformers-4.0.1-py3-none-any.whl (1.4MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
Collecting tokenizers==0.9.4
  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=d86f5a45ec300faa030

Define global helper functions.

In [11]:
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def select_field(features, field):
    return [
        [
            choice[field]
            for choice in feature.choices_features
        ]
        for feature in features
    ]

# Benchmark

## Dataset

As a question answering benchmark, [Commonsense QA](https://arxiv.org/abs/1811.00937) presents a natural language question $Q$ of $m$ tokens $\{q_1,q_2,\cdots,q_m\}$ and 5 choices $\{a_1,a_2,\cdots,a_5\}$ labeled with $\{A,B,\cdots,E\}$ for each question [1]. Notably, the questions do not contain the evidence needed to infer the answer, so the model must draw on a broad understanding of commonsense knowledge and strong reasoning ability to make the right choice.
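Each question is therefore a five-way multiple-choice problem. The following is a minimal sketch of how a prediction is read off, assuming a model has already produced one real-valued score per choice (the scores below are made up for illustration, not model output): a softmax turns the five scores into probabilities, the argmax is mapped back to a letter in $\{A,\cdots,E\}$, and correct predictions are counted the same way as the `accuracy` helper defined above.

```python
import numpy as np

LABELS = ["A", "B", "C", "D", "E"]

def softmax(scores):
    """Numerically stable softmax over the last (choice) axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def predict_letters(scores):
    """Map each row of per-choice scores to the letter of the best choice."""
    return [LABELS[i] for i in np.argmax(scores, axis=-1)]

# Hypothetical scores for 2 questions x 5 choices (illustrative only).
scores = np.array([[0.1, -1.2, 0.3, 2.5, -0.4],
                   [1.7, 0.2, 0.9, -0.5, 0.0]])
gold = np.array([3, 2])  # gold answers as indices, i.e. "D" and "C"

probs = softmax(scores)
pred = np.argmax(probs, axis=-1)
num_correct = int(np.sum(pred == gold))  # same count that `accuracy` returns
print(predict_letters(scores), num_correct)
```

Because the softmax is monotonic, taking the argmax of the raw scores gives the same prediction; the probabilities are only needed for the training loss.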

In [12]:
def load_data(dataset='commonsense_qa', preview=-1):
    """
    Load a dataset. If `preview` is positive, print the example at that index.
    """
    assert dataset in {'commonsense_qa', 'conv_entail', 'eat'}

    if dataset == 'commonsense_qa':
        ds = load_dataset('commonsense_qa')

        if preview > 0:
            print('\nLoading an example...')
            # indexing a `datasets.Dataset` returns one row as a plain dict
            example = ds['train'][preview]
            print(example['question'])
            for label, text in zip(example['choices']['label'], example['choices']['text']):
                print(label, text)
            print('Ans:', example['answerKey'])

    elif dataset == 'conv_entail':
        with open('data/conv_entail/dev_set.json', 'r', encoding='utf-8') as f:
            dev_set = json.load(f)
        with open('data/conv_entail/act_tag.json', 'r', encoding='utf-8') as f:
            act_tag = json.load(f)
        ds = dev_set, act_tag

        if preview > 0:
            print('Preview not yet implemented for this dataset.')

    else:
        with open('data/eat/eat_train.json', 'r', encoding='utf-8') as f:
            ds = json.load(f)

        if preview > 0:
            print('\nLoading an example...')
            story = ds[preview]['story']
            label = ds[preview]['label']
            bp = ds[preview]['breakpoint']
            for line in story:
                print(line)
            print(label)
            print(bp)

    return ds

Run the following code to preview the dataset:

In [13]:
ds = load_data(dataset='commonsense_qa', preview=5)
print('\nDataset statistics:')
print(ds)

Using custom data configuration default



Downloading and preparing dataset commonsense_qa/default (download: 4.46 MiB, generated: 2.08 MiB, post-processed: Unknown size, total: 6.54 MiB) to /root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29...


Dataset commonsense_qa downloaded and prepared to /root/.cache/huggingface/datasets/commonsense_qa/default/0.1.0/0e60f0ee8c8509e854ed897f65eb5b2e6ca22578d64cbc3812c79b527d7a7a29. Subsequent calls will reuse this data.

Loading an example...
What home entertainment equipment requires cable?
A radio shack
B substation
C cabinet
D television
E desk
Ans: D

Dataset statistics:
DatasetDict({
    train: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 9741
    })
    validation: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 1221
    })
    test: Dataset({
        features: ['answerKey', 'question', 'choices'],
        num_rows: 1140
    })
})


## Commonsense Knowledge

Many knowledge sources are available, including structured knowledge such as [ConceptNet](https://conceptnet.io/) and unstructured knowledge such as Wikipedia plain text [3]. Graph-structured knowledge in particular has proven effective in many applications, because it represents words (concepts) as individual nodes and the relationships between them as edges.
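To make the node/edge view concrete, here is a minimal sketch that turns a few ConceptNet-style triples into a node list and a directed adjacency matrix. The triples are hypothetical, made up for illustration, and are not taken from ConceptNet or from the preprocessed data of [2]. Relation labels are kept in a side dictionary, whereas the adjacency matrices built later in this notebook store only 0/1 connectivity.

```python
import numpy as np

# Hypothetical (head, relation, tail) triples; illustrative only.
triples = [
    ("television", "RelatedTo", "cable"),
    ("cable", "UsedFor", "television"),
    ("television", "AtLocation", "living room"),
]

def build_graph(triples):
    """Return (nodes, adjacency matrix, edge relation labels) for a list of triples."""
    nodes, index = [], {}
    for head, _, tail in triples:
        for word in (head, tail):
            if word not in index:          # each distinct word becomes one node
                index[word] = len(nodes)
                nodes.append(word)
    adj = np.zeros((len(nodes), len(nodes)), dtype=int)
    relations = {}
    for head, rel, tail in triples:
        i, j = index[head], index[tail]
        adj[i, j] = 1                      # directed edge head -> tail
        relations[(i, j)] = rel            # remember which relation the edge carries
    return nodes, adj, relations

nodes, adj, relations = build_graph(triples)
print(nodes)
print(adj)
```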

We will use the preprocessed knowledge data from [2] for this notebook.

In [14]:
!wget -cO - https://6ipv4q.dm.files.1drv.com/y4m6omS9iWd4efq5NhX-gZSX2MDniu9p0ZyPcxzGOKAMM_OGuKnUlONExIzWlUx_tkU2w_-oWzwnODBGk-StxlP_V4WPSlmnRqFLj4V88V9sYyZWUoH5ZoZC8Ul-rRMK7kprx4jkN87PVnQ_SUu3yxNT1S5GlCwsmzENE3zzNnFKJpTVavpyGJJqTCuH3TQu8L6hK4vVpk3jgOl3rEjgsQz-Q > data.zip
!unzip data.zip

!rm -f data.zip
!mv AAAI2020-data/data/ .
!rm -r AAAI2020-data/

--2020-12-16 05:49:01--  https://6ipv4q.dm.files.1drv.com/y4m6omS9iWd4efq5NhX-gZSX2MDniu9p0ZyPcxzGOKAMM_OGuKnUlONExIzWlUx_tkU2w_-oWzwnODBGk-StxlP_V4WPSlmnRqFLj4V88V9sYyZWUoH5ZoZC8Ul-rRMK7kprx4jkN87PVnQ_SUu3yxNT1S5GlCwsmzENE3zzNnFKJpTVavpyGJJqTCuH3TQu8L6hK4vVpk3jgOl3rEjgsQz-Q
Resolving 6ipv4q.dm.files.1drv.com (6ipv4q.dm.files.1drv.com)... 13.107.42.12
Connecting to 6ipv4q.dm.files.1drv.com (6ipv4q.dm.files.1drv.com)|13.107.42.12|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 312773826 (298M) [application/zip]
Saving to: ‘STDOUT’


2020-12-16 05:49:18 (19.1 MB/s) - written to stdout [312773826/312773826]

Archive:  data.zip
   creating: AAAI2020-data/
   creating: AAAI2020-data/data/
  inflating: AAAI2020-data/data/dev.jsonl.concept  
  inflating: AAAI2020-data/data/dev.jsonl.wikigraphnew  
  inflating: AAAI2020-data/data/dev.jsonl_NL  
  inflating: AAAI2020-data/data/dev_rand_split.jsonl  
  inflating: AAAI2020-data/data/test.jsonl.concept  
  inflating: AAAI

## Data Preprocessing

Define helper classes `Example`, `InputFeatures` and `KnowledgeGraph`.

In [15]:
class Example(object):
    """
    A single training/test example for the SWAG dataset.
    """

    def __init__(self,
                 idx,
                 context_sentence,
                 ending_0,
                 ending_1,
                 ending_2,
                 ending_3,
                 ending_4,
                 nodes,
                 adj_matrix,
                 label = None):
        self.idx = idx
        self.context_sentence = context_sentence
        self.endings = [
            ending_0,
            ending_1,
            ending_2,
            ending_3,
            ending_4,
        ]
        self.label = label
        self.nodes=nodes
        self.adj_matrixs=adj_matrix

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        l = [
            "id: {}".format(self.idx),
            "context_sentence: {}".format(self.context_sentence),
            "ending_0: {}".format(self.endings[0]),
            "ending_1: {}".format(self.endings[1]),
            "ending_2: {}".format(self.endings[2]),
            "ending_3: {}".format(self.endings[3]),
            "ending_4: {}".format(self.endings[4]),
        ]

        if self.label is not None:
            l.append("label: {}".format(self.label))

        return "\n".join(l)


class InputFeatures(object):
    """
    A single feature converted from an example.
    """

    def __init__(self,
                 example_id,
                 choices_features,
                 label):
        self.example_id = example_id
        self.choices_features = [
            {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'node_ids': nodes_ids,
                'adj_mask':adj_mask,
            }
            for _, input_ids, input_mask, segment_ids,nodes_ids,adj_mask in choices_features
        ]
        self.label = label


class KnowledgeGraph:
    """
    A knowledge graph.
    """

    def __init__(self, directed=False):
        self.graph = defaultdict(list)
        self.directed = directed

    def addEdge(self, head, tail):
        self.graph[head].append(tail)
        if self.directed:
            self.graph[tail] = self.graph[tail]
        elif head == tail:
            self.graph[tail] = []

    def topologySortHelper(self, s, visited, sortlist):
        visited[s] = True
        for i in self.graph[s]:
            if not visited[i]:
                self.topologySortHelper(i, visited, sortlist)
        sortlist.insert(0, s)

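    def topologySort(self):
        """
        Return the vertices in DFS reverse post-order. This is a valid topological
        order only when the graph is acyclic; self-loops are harmless because a
        vertex is marked visited before its successors are explored.
        """
        visited = {i: False for i in self.graph}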
        sortlist = []
        for key in self.graph:
            self.graph[key] = sorted(self.graph[key])
        keys=list(self.graph)
        for v in sorted(keys):
            if not visited[v]:
                self.topologySortHelper(v, visited, sortlist)
        return sortlist 
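To illustrate what `topologySort` produces and how `read_examples` below uses it, here is a standalone sketch that re-implements the same DFS reverse post-order (vertices and successors visited in sorted order), and then re-orders an adjacency matrix with `sorted_adj[i][j] = adj[order[i]][order[j]]`, the pattern used for `sorted_wiki_adj`. The toy graph and the helper names `topo_order` and `reorder` are hypothetical and exist only for this illustration.

```python
from collections import defaultdict

def topo_order(edges, num_nodes):
    """DFS reverse post-order; a valid topological order when the graph is a DAG.

    Self-loops are harmless: a vertex is marked visited before its successors
    are explored, so the loop edge is never followed.
    """
    graph = defaultdict(list)
    for head, tail in edges:
        graph[head].append(tail)
    visited = [False] * num_nodes
    order = []

    def visit(v):
        visited[v] = True
        for u in sorted(graph[v]):
            if not visited[u]:
                visit(u)
        order.insert(0, v)  # prepend v once all of its successors are placed

    for v in range(num_nodes):
        if not visited[v]:
            visit(v)
    return order

def reorder(adj, order):
    """Permute rows and columns so that new index i refers to old node order[i]."""
    n = len(order)
    return [[adj[order[i]][order[j]] for j in range(n)] for i in range(n)]

# Hypothetical evidence graph 2 -> 0 -> 1, plus a self-loop on every node.
edges = [(2, 0), (0, 1)] + [(i, i) for i in range(3)]
order = topo_order(edges, 3)
adj = [[0] * 3 for _ in range(3)]
for h, t in edges:
    adj[h][t] = 1
print(order)
print(reorder(adj, order))
```

After re-ordering, every edge of the acyclic part points from an earlier position to a later one, so the re-ordered matrix is upper triangular. This is what places connected evidence next to each other in the input sequence.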

Read the preprocessed knowledge files into `Example` objects.

In [16]:
def read_examples(input_file, is_training):
    """
    Preprocess the knowledge files into examples.
    """
    cont=0
    examples=[]

    with open(input_file+'.concept') as f1, open(input_file+'_NL') as f2, open(input_file+'.wikigraphnew') as f3:

        for line, line2, line3 in zip(f1, f2, f3):

            if cont % 1000==0:
                logger.info('read cont:{}'.format(cont))

            js=json.loads(line.strip())
            js2=json.loads(line2.strip())
            js3=json.loads(line3.strip())

            qa_list=[]
            context=js['question']['stem']
            node_lists=[]
            adj_matrix_lists=[]

            # temp: .concept file | temp2: _NL file | temp3: Wikipedia (.wikigraphnew) file
            for temp, temp2, temp3 in zip(js['question']['choices'],
                                          js2['question']['choices'],
                                          js3['question']['choices']):

                #############################################################
                #                ConceptNet graph construction              #
                #############################################################

                # max 20 nodes in Concept-Graph
                k=20
                g = KnowledgeGraph(directed=True)
                temp["node"] = temp["node"][:k]

                # add edges according to the edges provided by Jingjing
                for r in temp['relation']:
                    if r[0] < k and r[1] < k:
                        g.addEdge(r[0],r[1])
                        
                # add one edge pointing to itself
                for i in range(len(temp["node"])):
                    g.addEdge(i, i)

                # get the sequence according to topology sort algorithm
                topsort_seq = g.topologySort()
                sorted_node = []

                # topsort_seq contains the idx of nodes, only containing numbers
                for i in topsort_seq:
                    # temp['node'][i]: get the evidence corresponding to node i
                    sorted_node.append(temp['node'][i])

                # sorted_node contains the sorted ConceptNet evidence
                sorted_adj_matrix=[]    
                adj_matrix_list=[]
                for tmp in temp['evidence_edges']:
                    # adj_matrix_list: the list of edges. each 1 or 0 represents two nodes are connected or not.
                    adj_matrix_list.append(tmp)
                    # sorted_adj_matrix: each line represents the nodes linked to the node i
                    sorted_adj_matrix.append([0 for i in range(len(tmp))])

                for i in range(min(k,len(adj_matrix_list))):
                    for j in range(min(k,len(adj_matrix_list[i]))):
                        # sorted_adj_matrix[i][j]=1 means node i is connected to j, 0 means not.
                        sorted_adj_matrix[i][j]=adj_matrix_list[topsort_seq[i]][topsort_seq[j]]

                #############################################################
                #                 Wikipedia graph construction              #
                #############################################################

                # Wiki-Graph has at most 10 sentences.
                k=10
                wiki_nodes=[]

                # temp3['node'] holds the SRL (semantic role labeling) arguments
                for n in temp3['node']:
                    wiki_nodes.append(''.join(n.split()))
                wiki_evidences=[]

                # wiki_evidences holds the original Wikipedia evidence with whitespace removed
                for n in temp3['searched_evidence']['basic']:
                    wiki_evidences.append(''.join(n['text'].split()))

                # each node belongs to which evidence
                node2evidences={}

                for idx,n in enumerate(wiki_nodes):
                    for idx1,e in enumerate(wiki_evidences):
                        # both strings had whitespace removed by ''.join(), so spacing does not affect this check
                        if n in e:
                            if idx not in node2evidences: 
                                node2evidences[idx]=[]
                            node2evidences[idx].append(idx1)
    
                g = KnowledgeGraph(directed=True)                
                wiki_adj, sorted_wiki_adj = [], []

                # wiki_adj and sorted_wiki_adj: both k*k, 10*10 matrix
                for i in range(min(k,len(wiki_evidences))):
                    wiki_adj.append([0 for j in range(len(wiki_evidences))])
                    sorted_wiki_adj.append([0 for j in range(len(wiki_evidences))])

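                # connect two evidence sentences if SRL nodes they contain are linked by a relation;
                # nodes that match no evidence sentence are skipped instead of raising a KeyError
                for r in temp3['relation']:
                    for i in node2evidences.get(r[0], []):
                        for j in node2evidences.get(r[1], []):
                            if i < k and j < k:
                                # directed edge for the topological sort,
                                # symmetric entries for the adjacency matrix
                                g.addEdge(i, j)
                                wiki_adj[i][j] = 1
                                wiki_adj[j][i] = 1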
                wiki_evidences=[]

                for tmp in temp3['searched_evidence']['basic'][:k]:
                    wiki_evidences.append(tmp['text']) 

                for i in range(len(wiki_evidences)):
                    g.addEdge(i,i)

                topsort_seq = g.topologySort()

                sorted_evidences=[]
                for i in topsort_seq:
                    sorted_evidences.append(wiki_evidences[i])

                # re-order the Wikipedia adjacency matrix to follow the topological order
                for i in range(min(k, len(wiki_adj))):
                    for j in range(min(k, len(wiki_adj[i]))):
                        sorted_wiki_adj[i][j] = wiki_adj[topsort_seq[i]][topsort_seq[j]]

                #############################################################
                #                    Example construction                   #
                #############################################################
                
                t = sorted_evidences 
                t1 = sorted_node

                # fall back to a placeholder if no Wikipedia evidence was retrieved
                if len(t) == 0:
                    t = ["None"]

                # fall back to a placeholder if no ConceptNet node was retrieved
                if len(t1) == 0:
                    t1 = ["None"]

                # Add Wiki-Graph triple nodes
                # 10 sentences, one sentence contains which triples
                srl_triples_index = [[] for i in range(len(wiki_evidences))]
                
                index = 0
                evidence2arguments = {}
                for srl in temp3['searched_evidence']['basic']:
                    srl_triple_index = []
                    srl_triples = srl['srl_triple']
                    # print('srl_triples:{}'.format(srl_triples))

                    srl_verbs = []
                    if len(srl['srl']['verbs']) > 0:
                        srl_verbs = srl['srl']['verbs']
                    else:
                        srl_verbs = [srl['srl']['verbs']]

                    # for triple, verb, evidence in zip(srl_triples, srl_verbs, wiki_evidences):
                    for verb in srl_verbs:
                        if len(verb) == 0:
                            continue
                        text = verb['description']
                        res = re.findall(r"\[(.*?)\]", text)

                        for temp_res in res:
                            if len(temp_res.split(':'))>=2:
                                temp_res2 = temp_res.split(':')[1].strip()
                                text = text.replace('[' + temp_res + ']', temp_res2)
                        if len(text.split(' ')) != len(verb['tags']):
                            continue

                        # add nodes and edges
                        if text not in evidence2arguments:
                            evidence2arguments[text] = []

                        tags = verb['tags']
                        if not ('B-ARG0' in tags and 'B-V' in tags and 'B-ARG1' in tags):
                            continue

                        # for ARG0
                        start = tags.index('B-ARG0')
                        end = start
                        for temp_i in range(start+1, len(tags)):
                            if tags[temp_i] == 'I-ARG0':
                                end += 1
                            else:
                                break
                        text = text.split(' ')
                        # print('ARG0 {} {} {}'.format(text[start:end+1], start, end))
                        evidence2arguments[' '.join(text)].append([' '.join(text[start:end+1]), start, end])

                        # for verb
                        start = tags.index('B-V')
                        end = start
                        for temp_i in range(start+1, len(tags)):
                            if tags[temp_i] == 'I-V':
                                end += 1
                            else:
                                break
                        # print('VERB {} {}'.format(text[start:end+1],start, end))
                        evidence2arguments[' '.join(text)].append([' '.join(text[start:end + 1]), start, end])

                        # for ARG1
                        start = tags.index('B-ARG1')
                        end = start
                        for temp_i in range(start+1, len(tags)):
                            if tags[temp_i] == 'I-ARG1':
                                end += 1
                            else:
                                break
                        # print('ARG1 {} {}'.format(text[start:end+1],start, end))
                        evidence2arguments[' '.join(text)].append([' '.join(text[start:end + 1]), start, end])

                qa_list.append(([temp['text'],'##'.join(t),'##'.join([temp2['text']] + t1),sorted_wiki_adj,evidence2arguments]))

                # node_lists contains the nodes in ConceptNet
                # adj_matrix_lists contains the adjacent matrix of Wiki-Graph
                node_lists.append(sorted_node)
                adj_matrix_lists.append(sorted_wiki_adj)

            cont += 1
            examples.append(
                Example(
                        idx = cont,
                        context_sentence = context,
                        ending_0 = qa_list[0],
                        ending_1 = qa_list[1],
                        ending_2 = qa_list[2],
                        ending_3 = qa_list[3],
                        ending_4 = qa_list[4],
                        nodes=node_lists,
                        adj_matrix=adj_matrix_lists,
                        label = ord(js['answerKey'])-ord('A') if is_training else None
                        ) 
            )

    return examples
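The ARG0, verb and ARG1 blocks in `read_examples` repeat the same scan over BIO tags: find the `B-<role>` tag, then extend the span while the following tags are `I-<role>`. Below is a sketch of a helper that returns the same inclusive `[text, start, end]` triple. `bio_span` is a hypothetical name that is not part of the notebook, and the sentence and tags are made up rather than taken from the SRL data.

```python
def bio_span(words, tags, role):
    """Return [text, start, end] (inclusive word indices) for the first `role` span.

    Mirrors the loops in `read_examples`: start at 'B-<role>' and extend while
    the next tags are 'I-<role>'. Returns None if the role is absent.
    """
    begin = "B-" + role
    if begin not in tags:
        return None
    start = tags.index(begin)
    end = start
    while end + 1 < len(tags) and tags[end + 1] == "I-" + role:
        end += 1
    return [" ".join(words[start:end + 1]), start, end]

# Hypothetical SRL output for one sentence (illustrative only).
words = "The cable company installs cable boxes".split(" ")
tags = ["B-ARG0", "I-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1"]
arguments = [bio_span(words, tags, role) for role in ("ARG0", "V", "ARG1")]
print(arguments)
```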

In [31]:
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """
    Truncates a sequence pair in place to the maximum length.
  
    This is a simple heuristic which will always truncate the longer sequence
    one token at a time. This makes more sense than truncating an equal percent
    of tokens from each, since if one sequence is very short then each token
    that's truncated likely contains more information than a longer sequence.
    """

    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

            
def convert_examples_to_features(examples, tokenizer, max_seq_length, is_training):
    """
    Loads a data file into a list of `InputBatch`s.
    
    Swag is a multiple choice task. To perform this task using Bert,
    we will use the formatting proposed in "Improving Language
    Understanding by Generative Pre-Training" and suggested by
    @jacobdevlin-google in this issue
    https://github.com/google-research/bert/issues/38.
    
    Each choice will correspond to a sample on which we run the
    inference. For a given Swag example, we will create the 4
    following inputs:
    - [CLS] context [SEP] choice_1 [SEP]
    - [CLS] context [SEP] choice_2 [SEP]
    - [CLS] context [SEP] choice_3 [SEP]
    - [CLS] context [SEP] choice_4 [SEP]
    The model will output a single value for each input. 
    To get the final decision of the model, 
    we will run a softmax over these 4 outputs.
    """

    features = []
    for example_index, example in enumerate(examples):

        choices_features = []
        if example_index % 1000 == 0 and example_index > 0:
            logger.info('convert example to feature:{}'.format(example_index))

        # change example:
        for ending_index, (ending,node,adj_matrix) in enumerate(zip(example.endings,example.nodes,example.adj_matrixs)):

            # Question + 'The answer is' + answer choice
            ending_tokens = tokenizer.tokenize(example.context_sentence) + tokenizer.tokenize(
                'The answer is') + tokenizer.tokenize(ending[0])
            evidence2arguments = ending[-1]

            # Conceptnet evidence
            concept_context_tokens_choice = tokenizer.tokenize(ending[2])
            _truncate_seq_pair(concept_context_tokens_choice, [], 128)
            concept_context_tokens_choice.append("<sep>")

            tokens_choice = concept_context_tokens_choice
            tokens_choice.extend(tokenizer.tokenize('# Wikipedia #'))
            wiki_nodes_choice = []

            for key in evidence2arguments:
                # one key, one evidence
                words2tokens = {}
                evidence_tokens = []
                key = key.split(' ')

                for temp_i in range(len(key)):
                    tokens = tokenizer.tokenize(key[temp_i])
                    words2tokens[temp_i] = (len(evidence_tokens), len(tokens) + len(evidence_tokens))
                    evidence_tokens.extend(tokens)

                if len(tokens_choice) + len(evidence_tokens) > max_seq_length - len(ending_tokens) - 3 - 2: # 2 for ##
                    break

                origin_length = len(tokens_choice)
                tokens_choice.extend(evidence_tokens)
                tokens_choice.extend(tokenizer.tokenize('##'))

                # update the start and end index of each argument
                for item in evidence2arguments[' '.join(key)]:
                    start = words2tokens[item[1]][0] + origin_length
                    end = words2tokens[item[2]][1] + origin_length
                    wiki_nodes_choice.append([item[0], start, end])

            tokens_choice.append("<sep>")
            tokens_choice.extend(ending_tokens)
            tokens_choice.append('<cls>')

            tokens = tokens_choice
            segment_ids = [0] * (len(tokens) - len(ending_tokens) - 2) + [1] * (len(ending_tokens) + 1)+[2]
            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            #update the index because padding is at the beginning of the input
            padding_length = max_seq_length - len(input_ids)
            for node in wiki_nodes_choice:
                node[1] += padding_length
                node[2] += padding_length

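            # Recover segment lengths from the token ids. The hard-coded ids are
            # specific to the XLNet vocabulary: '17 7967 20631 17 7967' is presumably
            # the '# Wikipedia #' marker and '7967 7967' the '##' separator, so `temp`
            # holds the lengths of the '##'-separated segments before the marker.
            temp=' '.join(str(x) for x in input_ids)
            temp=temp.split('17 7967 20631 17 7967')[0]
            temp=temp.split('7967 7967')
            temp=[len(x.split()) for x in temp]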
            input_mask = [1] * len(input_ids)
            padding_length = max_seq_length - len(input_ids)

            input_ids = ([0] * padding_length) + input_ids
            input_mask = ([0] * padding_length) + input_mask
            segment_ids = ([4] * padding_length) + segment_ids

            skip=padding_length+temp[0]+2
            node=[]
            for t in temp[1:]:
                vector=np.zeros(max_seq_length)
                vector[skip:t+skip]=1
                skip+=t+2
                node.append(vector[None,:])

            node=node[:50]
            for i in range(50-len(node)):
                vector=np.zeros(max_seq_length)
                node.append(vector[None,:])

            node_size=len(temp)-1
            matrix=np.zeros((150,150))
            for i,val in enumerate(adj_matrix):
                for j,v in enumerate(adj_matrix[i]):
                    if v==1 and i<50 and j<50 and i<node_size and j<node_size:
                        matrix[i,j]=1
       
            #############################################################
            #        Wikipedia SRL argument nodes (rows 50 to 149)      #
            #############################################################

            for item in wiki_nodes_choice:
                # mark the token span [start, end) covered by this SRL argument
                vector = np.zeros(max_seq_length)
                vector[item[1]:item[2]] = 1
                node.append(vector[None, :])

            node = node[:150]
            for i in range(150 - len(node)):
                vector = np.zeros(max_seq_length)
                node.append(vector[None, :])
            node = np.concatenate(node, 0)

            # link Wikipedia SRL nodes whose text matches (case-insensitive);
            # only the first 100 fit into rows/columns 50..149 of the matrix
            for idx1, node1 in enumerate(wiki_nodes_choice):
                for idx2, node2 in enumerate(wiki_nodes_choice):
                    if idx1 < 100 and idx2 < 100 and node1[0].lower() == node2[0].lower():
                        matrix[idx1 + 50, idx2 + 50] = 1

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            choices_features.append((tokens, input_ids, input_mask, segment_ids,node,matrix))


        label = example.label
        if example_index < 1 and is_training:
            logger.info("*** Example ***")
            logger.info("idx: {}".format(example.idx))
            for choice_idx, (tokens, input_ids, input_mask, segment_ids,_,_) in enumerate(choices_features):
                logger.info("choice: {}".format(choice_idx))
                logger.info("tokens: {}".format(' '.join(tokens).replace('\u2581','_')))
                logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
                logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
                logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
                logger.info("label: {}".format(label))

        features.append(
            InputFeatures(
                example_id = example.idx,
                choices_features = choices_features,
                label = label
            )
        )

    return features
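The conversion code above encodes each choice's knowledge graph as two arrays: a node mask whose rows mark the token span of each node, and a 150 × 150 adjacency matrix whose first 50 slots hold ConceptNet nodes and whose remaining 100 slots hold Wikipedia nodes (Wikipedia nodes with the same lowercased text are linked). The following is a minimal sketch of that layout with tiny, made-up sizes. The helper names `span_node_mask` and `block_adjacency` and the toy inputs are hypothetical, and the sketch omits the `node_size` cut-off used in the notebook code.

```python
import numpy as np


def span_node_mask(spans, seq_len, num_slots):
    """Build a (num_slots, seq_len) 0/1 mask; row k marks the tokens of node k.

    `spans` is a list of (start, end) token offsets (end exclusive).
    Extra nodes are dropped; unused slots stay all-zero (padding nodes).
    """
    mask = np.zeros((num_slots, seq_len))
    for k, (start, end) in enumerate(spans[:num_slots]):
        mask[k, start:end] = 1
    return mask


def block_adjacency(cn_adj, wiki_texts, n_cn, n_wiki):
    """Copy ConceptNet edges into the top-left block and link Wikipedia nodes
    with identical (case-insensitive) text in the bottom-right block."""
    adj = np.zeros((n_cn + n_wiki, n_cn + n_wiki))
    cn_adj = np.asarray(cn_adj)[:n_cn, :n_cn]
    adj[:cn_adj.shape[0], :cn_adj.shape[1]] = cn_adj
    texts = [t.lower() for t in wiki_texts[:n_wiki]]
    for i, a in enumerate(texts):
        for j, b in enumerate(texts):
            if a == b:  # includes self-loops, as in the notebook code
                adj[n_cn + i, n_cn + j] = 1
    return adj


# Toy example: 2 ConceptNet slots + 3 Wikipedia slots over a 10-token input.
cn_mask = span_node_mask([(1, 3), (4, 5)], seq_len=10, num_slots=2)
wiki_mask = span_node_mask([(6, 7), (7, 8), (8, 10)], seq_len=10, num_slots=3)
node_mask = np.concatenate([cn_mask, wiki_mask], axis=0)
adj = block_adjacency([[1, 1], [1, 1]], ["Dog", "cat", "dog"], n_cn=2, n_wiki=3)
print(node_mask.shape, adj.shape)  # (5, 10) (5, 5)
```

Keeping a fixed number of slots per source lets every choice be batched as same-shaped tensors; empty slots are later masked out of the graph attention.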


# Model

## Graph-based Model

The paper [Graph-Based Reasoning over Heterogeneous External Knowledge for Commonsense Question Answering](https://arxiv.org/abs/1909.05311) addresses commonsense question answering and reports experiments on CommonsenseQA [2].

One major contribution of the work is that the authors are the first to propose a model that leverages evidence from multiple knowledge sources. In the experiments, both ConceptNet and plain-text Wikipedia are preprocessed into knowledge graphs.

The graph-based reasoning model consists of two modules: a graph-based contextual representation learning module and a graph-based inference module.

* The graph-based contextual representation learning module is built on XLNet. It uses the graph structure to shorten the distance between related words that appear in different evidence sentences. Algorithmically, a topological sort is applied to re-order the input evidence according to the constructed knowledge graphs.

* The graph-based inference module aggregates evidence at the graph level for prediction. Specifically, a Graph Convolutional Network (GCN) encodes neighbor information into the node representations, and a graph attention layer aggregates the nodes for the final prediction.
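The re-ordering step in the first bullet relies on a topological sort. The sketch below illustrates only that sorting step with Kahn's algorithm, under the assumption that evidence sentences are graph nodes and a hypothetical `links` list states which sentence should precede which; the helper `topo_order` and the example sentences are made up for illustration and are not taken from the official implementation.

```python
from collections import deque


def topo_order(num_items, edges):
    """Kahn's algorithm: return an order of items 0..num_items-1 in which every
    edge (u, v) places u before v. A FIFO queue yields one valid order among possibly several."""
    indegree = [0] * num_items
    successors = [[] for _ in range(num_items)]
    for u, v in edges:
        successors[u].append(v)
        indegree[v] += 1
    queue = deque(i for i in range(num_items) if indegree[i] == 0)
    order = []
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in successors[u]:
            indegree[v] -= 1
            if indegree[v] == 0:
                queue.append(v)
    if len(order) != num_items:
        raise ValueError("evidence graph contains a cycle")
    return order


# Hypothetical evidence sentences and "should precede" links between them.
evidence = [
    "A revolving door is a kind of door.",
    "A bank has a revolving door.",
    "A door is used for security.",
]
links = [(1, 0), (0, 2)]
for i in topo_order(len(evidence), links):
    print(evidence[i])
```

After sorting, sentences that share graph neighbours sit next to each other in the XLNet input, which is the intuition behind "shortening the distance" between related words.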

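See the official implementation [here](https://github.com/DecstionBack/AAAI_2020_CommonsenseQA). Printing the `GraphBasedXLNetModel` defined below with an XLNet-large configuration (24 layers, hidden size 1024) gives the following structure: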

```
GraphBasedXLNetModel(
  (word_embedding): Embedding(32000, 1024)
  (layer): ModuleList(
    (0-23): 24 x XLNetLayer(
      (rel_attn): XLNetRelativeAttention(
        (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): XLNetFeedForward(
        (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (layer_1): Linear(in_features=1024, out_features=4096, bias=True)
        (layer_2): Linear(in_features=4096, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (map): Linear(in_features=64, out_features=1024, bias=True)
  (map_node_emb): Linear(in_features=1024, out_features=64, bias=True)
  (GCN_W): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
  )
  (GCN_W_self): ModuleList(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
  )
)
```
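Apart from the 24 XLNet layers, the graph-specific heads at the end of this printout are small. The arithmetic below recomputes their parameter counts from the listed `Linear` shapes (weight matrix plus bias); it is a worked check derived from the printout, not output produced by the notebook. The helper `linear_params` is hypothetical.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in an nn.Linear layer: in*out weights plus out biases."""
    return in_features * out_features + (out_features if bias else 0)


graph_heads = {
    "map": linear_params(64, 1024),
    "map_node_emb": linear_params(1024, 64),
    "GCN_W": 2 * linear_params(64, 64),        # two 64 -> 64 layers
    "GCN_W_self": 2 * linear_params(64, 64),   # two 64 -> 64 layers
}
for name, count in graph_heads.items():
    print(f"{name}: {count:,}")
print(f"total: {sum(graph_heads.values()):,}")  # total: 148,800
```

About 0.15M parameters are added on top of the XLNet backbone, so the graph reasoning head is cheap relative to the encoder.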

Define the graph-based transformer module.

In [29]:
class GraphBasedXLNetModel(XLNetPreTrainedModel):
    """
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
    """

    def __init__(self, config):
        super(GraphBasedXLNetModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(32000, config.d_model)
        self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)
        
        # Projections between the XLNet hidden size and the 64-d graph space.
        self.map = nn.Linear(64, config.d_model)
        self.map_node_emb = nn.Linear(config.d_model, 64)

        # Two GCN layers are allocated; forward() currently applies only the first.
        self.GCN_W = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])
        self.GCN_W_self = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)])

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)
        return self.word_embedding

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
        Args:
            qlen: length of the current (query) segment.
            mlen: length of the cached memory prepended to it.
        ::
                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(next(self.parameters()))
        return ret

    def cache_mem(self, curr_out, prev_mem):
        """
        cache hidden states into memory.
        """
        if self.mem_len is None or self.mem_len == 0:
            return None
        else:
            if self.reuse_len is not None and self.reuse_len > 0:
                curr_out = curr_out[:self.reuse_len]

            if prev_mem is None:
                new_mem = curr_out[-self.mem_len:]
            else:
                new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]

        return new_mem.detach()

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)

        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

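    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None, head_mask=None, node_mask=None, adj_mask=None):
        """
        The original code for XLNet uses shapes [len, bsz] with the batch dimension at the end,
        but we want a unified interface in the library with the batch size on the first dimension,
        so we move here the first dimension (batch) to the end.

        Graph inputs (required despite the None defaults):
            node_mask: [bsz, num_choices, num_nodes, seq_len] 0/1 mask; row k marks the tokens of node k.
            adj_mask: [bsz, num_choices, num_nodes, num_nodes] 0/1 adjacency between nodes.
            The token-level inputs are expected flattened to [bsz * num_choices, seq_len].
        """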

        input_ids = input_ids.transpose(0, 1).contiguous()
        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        mlen = mems[0].shape[0] if mems is not None else 0
        klen = mlen + qlen

        dtype_float = next(self.parameters()).dtype
        device = next(self.parameters()).device

        ##### Attention mask
        # causal attention mask
        if self.attn_type == 'uni':
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: {}'.format(self.attn_type))

        # data mask: input mask & perm mask
        assert input_mask is None or attention_mask is None, (
            "You can only use one of input_mask (uses 1 for padding) "
            "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one.")
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
            data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)

        if attn_mask is not None:
            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
            non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        ##### Word embeddings and prepare h & g hidden states
        word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
        # else:  # We removed the inp_q input which was same as target mapping
        #     inp_q_ext = inp_q[:, :, None]
        #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        ##### Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
            cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = torch.cat([1-seg_mat.unsqueeze(-1),seg_mat.unsqueeze(-1)],-1).to(dtype_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = []
        hidden_states = []
        for i, layer_module in enumerate(self.layer):
            # cache new mems
            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if self.output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(output_h, output_g, attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
                                   r=pos_emb, seg_mat=seg_mat, mems=mems[i], target_mapping=target_mapping,
                                   head_mask=head_mask[i])
            output_h, output_g = outputs[:2]
            if self.output_attentions:
                attentions.append(outputs[2])

        # Add last hidden state
        if self.output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

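        ##### Graph reasoning head
        output = self.dropout(output_g if output_g is not None else output_h)

        # [qlen, bsz*choices, d_model] -> [bsz, choices, qlen, d_model]
        output = output.permute(1, 0, 2).contiguous().view(
            node_mask.shape[0], node_mask.shape[1], output.shape[0], output.shape[-1])

        node_mask = node_mask.float()  # [bsz, choices, num_nodes, qlen]
        adj_mask = adj_mask.float()    # [bsz, choices, num_nodes, num_nodes]

        # Node embedding: mean of the token states covered by each node, projected to 64-d.
        node_emb = torch.einsum('abcd,abde->abce', node_mask, output)
        node_emb = node_emb / (node_mask.sum(-1) + 1e-30)[:, :, :, None]
        node_emb = torch.tanh(self.map_node_emb(node_emb))

        # GCN: mean over neighbours plus a self transform (only the first of the two layers is applied).
        for i in range(1):
            node_emb_avg = torch.einsum('abcd,abde->abce', adj_mask, self.GCN_W[i](node_emb))
            node_emb_avg = node_emb_avg / (adj_mask.sum(-1) + 1e-30)[:, :, :, None]
            node_emb = torch.tanh(self.GCN_W_self[i](node_emb) + node_emb_avg)

        # Graph attention: score nodes against the last token's state (where XLNet places <cls>);
        # padding nodes (no tokens) are masked out before the softmax.
        output = output[:, :, -1]
        mask = node_mask.max(-1)[0]
        attn_score = torch.einsum('abc,abdc->abd', output, torch.tanh(self.map(node_emb)))
        attn_score = attn_score - 1e30 * (1 - mask)
        PR = nn.functional.softmax(attn_score, -1)
        kb_emb = PR[:, :, :, None] * node_emb
        kb_emb = kb_emb.sum(2).view(-1, kb_emb.shape[-1])  # [bsz * choices, 64]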
        
        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = self.dropout(output_g if output_g is not None else output_h)
        outputs = (output.permute(1, 0, 2).contiguous(), self.dropout(kb_emb),new_mems)
        if self.output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
            outputs = outputs + (hidden_states,)
        if self.output_attentions:
            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs = outputs + (attentions,)

        return outputs  # output, kb_emb, new_mems, (hidden_states), (attentions)
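The graph head at the end of `forward()` is hard to follow through the `einsum` strings alone. The following NumPy sketch replays the same three steps on toy sizes with random weights: mean-pool token states into node embeddings, run one GCN step (the loop above applies only `GCN_W[0]` and `GCN_W_self[0]`), and attention-pool the nodes against the last token's state. The sizes, the random matrices `W_node`, `W_gcn`, `W_self`, `W_map` (hypothetical stand-ins for `map_node_emb`, `GCN_W[0]`, `GCN_W_self[0]`, `map`), and the omission of biases and dropout are simplifications for illustration, not the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes: batch a=1, choices b=2, nodes c=3, tokens s=5,
# hidden size h=8 (1024 in the model), graph size g=4 (64 in the model).
a, b, c, s, h, g = 1, 2, 3, 5, 8, 4

hidden = rng.normal(size=(a, b, s, h))     # last-layer token states per choice
node_mask = np.zeros((a, b, c, s))
node_mask[:, :, 0, 0:2] = 1                # node 0 covers tokens 0-1
node_mask[:, :, 1, 2:4] = 1                # node 1 covers tokens 2-3; node 2 is padding
adj = np.zeros((a, b, c, c))
adj[:, :, :2, :2] = 1                      # nodes 0 and 1 linked (with self-loops)

W_node = rng.normal(size=(h, g))           # stand-in for map_node_emb
W_gcn = rng.normal(size=(g, g))            # stand-in for GCN_W[0]
W_self = rng.normal(size=(g, g))           # stand-in for GCN_W_self[0]
W_map = rng.normal(size=(g, h))            # stand-in for map

# 1) Node embedding = mean of the node's token states, projected to g dims.
node_emb = np.einsum('abcs,absh->abch', node_mask, hidden)
node_emb = node_emb / (node_mask.sum(-1) + 1e-30)[..., None]
node_emb = np.tanh(node_emb @ W_node)

# 2) One GCN step: mean over neighbours plus a self transform.
neigh = np.einsum('abcd,abdg->abcg', adj, node_emb @ W_gcn)
neigh = neigh / (adj.sum(-1) + 1e-30)[..., None]
node_emb = np.tanh(node_emb @ W_self + neigh)

# 3) Graph attention: score nodes against the last token, mask padding nodes.
query = hidden[:, :, -1]                                        # (a, b, h)
scores = np.einsum('abh,abch->abc', query, np.tanh(node_emb @ W_map))
scores = scores - 1e30 * (1 - node_mask.max(-1))
weights = np.exp(scores - scores.max(-1, keepdims=True))        # stable softmax
weights = weights / weights.sum(-1, keepdims=True)
kb_emb = (weights[..., None] * node_emb).sum(2).reshape(-1, g)  # (a*b, g)

print(kb_emb.shape)  # (2, 4)
```

The large negative offset drives the padding node's softmax weight to zero, so fixed-size node slots do not leak into `kb_emb`.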


Define the corresponding `MultipleChoice` module.

In [19]:
class GraphBasedXLNetForMultipleChoice(XLNetPreTrainedModel):

    def __init__(self, config):
        super(GraphBasedXLNetForMultipleChoice, self).__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GraphBasedXLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.sequence_summary1 = SequenceSummary(config)
        self.ln = nn.Linear(64, 64)
        self.ln1 = nn.Linear(config.d_model, config.d_model)
        self.last_dropout = nn.Dropout(config.summary_last_dropout)
        self.classifier = nn.Linear(config.d_model, 1)
        self.classifier1 = nn.Linear(64, 1)
        self.init_weights()


    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
                labels=None, head_mask=None, node_mask=None, adj_mask=None):
        num_choices = self.num_labels

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_mems = mems.view(-1, mems.size(-1)) if mems is not None else None
        flat_perm_mask = perm_mask.view(-1, perm_mask.size(-1)) if perm_mask is not None else None
        flat_target_mapping = target_mapping.view(-1, target_mapping.size(-1)) if target_mapping is not None else None
        flat_head_mask = head_mask.view(-1, head_mask.size(-1)) if head_mask is not None else None
        
        transformer_outputs = self.transformer(flat_input_ids, token_type_ids=flat_token_type_ids,
                                               input_mask=flat_input_mask, attention_mask=flat_attention_mask,
                                               mems=flat_mems, perm_mask=flat_perm_mask, target_mapping=flat_target_mapping,
                                               head_mask=flat_head_mask,node_mask=node_mask,adj_mask=adj_mask)
        
        output,output_kb = transformer_outputs[0][:,-1],transformer_outputs[1]
        
        output = self.sequence_summary(torch.cat((output, torch.tanh(self.ln(output_kb))), -1))
        logits = self.classifier(output)

        reshaped_logits = logits.view(-1, num_choices)

        # Return the logits at inference time, and only the loss when labels are given
        outputs = reshaped_logits

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = loss

        return outputs  # loss if labels is not None, else reshaped_logits

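The head scores every (question, choice) pair as an independent sequence and only lets the choices compete at the softmax. A minimal sketch of that flatten, score and reshape pattern follows, where the hypothetical `ToyScorer` (embedding plus mean pooling) stands in for XLNet with the graph branch. Like the class above, the sketch returns the loss when labels are given and the logits otherwise.

```python
import torch
from torch import nn


class ToyScorer(nn.Module):
    """Hypothetical encoder: embed tokens, mean-pool, project to one score."""

    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.out = nn.Linear(dim, 1)

    def forward(self, ids):
        # ids: [N, seq_len] -> one score per sequence: [N, 1]
        return self.out(self.emb(ids).mean(1))


def multiple_choice_forward(scorer, input_ids, labels=None):
    # input_ids: [batch, num_choices, seq_len]
    batch, num_choices, seq_len = input_ids.shape
    # Score every (question, choice) pair as an independent sequence
    flat_ids = input_ids.view(-1, seq_len)  # [batch * num_choices, seq_len]
    logits = scorer(flat_ids).view(batch, num_choices)
    if labels is None:
        return logits
    # Softmax over the choices of each question, then negative log-likelihood
    return nn.functional.cross_entropy(logits, labels)
```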
Define a wrapper for model loading.

In [20]:
def load_model(model='all'):

    print('Loading model', model)
    
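    if model == 'graph_xlnet':
        # main() passes labels and expects a scalar loss, which the multiple-choice wrapper returns
        return XLNetConfig, GraphBasedXLNetForMultipleChoice, XLNetTokenizer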
    elif model == 'xlnet':
        return XLNetConfig, XLNetForMultipleChoice, XLNetTokenizer
    raise NotImplementedError('Unknown model type: {}'.format(model))

# Experiment

## Runtime

In [21]:
def main(args):

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
    
    # Set seed
    set_seed(args)

    os.makedirs(args.output_dir, exist_ok=True)

    config_class, model_class, tokenizer_class = load_model(args.model_type)

    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name,
        num_labels=5, finetuning_task=args.task_name)
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name,
        do_lower_case=args.do_lower_case)
    
    model = model_class.from_pretrained(
        args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)

    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Prepare data loader
    train_examples = read_examples(os.path.join(args.data_dir, 'train.jsonl'), is_training = True)
    train_features = convert_examples_to_features(
        train_examples, tokenizer, args.max_seq_length, True)
    all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
    all_node_ids = torch.tensor(select_field(train_features, 'node_ids'), dtype=torch.long)
    all_adj_mask = torch.tensor(select_field(train_features, 'adj_mask'), dtype=torch.long) 
    all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label,all_node_ids,all_adj_mask)
    logger.info(all_input_ids.size())
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size//args.gradient_accumulation_steps)

    num_train_optimization_steps =  args.train_steps

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps,
                                     t_total=num_train_optimization_steps // args.gradient_accumulation_steps)
        
    global_step = 0

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)
        
    model.train()
    best_acc=0
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0        
    bar = tqdm(range(num_train_optimization_steps),total=num_train_optimization_steps, disable=False, leave=True, position=1)
    train_dataloader=cycle(train_dataloader)
        
    for step in bar:

        model.train()  # the periodic evaluation below switches to eval mode; re-enable dropout
        batch = next(train_dataloader)
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_ids,node_ids,adj_mask = batch
        loss = model(input_ids=input_ids, 
                     token_type_ids=segment_ids, 
                     attention_mask=input_mask, 
                     labels=label_ids,
                     node_mask=node_ids,
                     adj_mask=adj_mask)
        if args.n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        tr_loss += loss.item()
        train_loss=round(tr_loss*args.gradient_accumulation_steps/(nb_tr_steps+1),4)
        bar.set_description("loss {}".format(train_loss))
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        # With --fp16 the weights are only cast by model.half(); no loss scaling is applied
        loss.backward()

        # Count accumulation with the batch index: nb_tr_steps is reset after every evaluation
        if (step + 1) % args.gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()  # PyTorch >= 1.1 expects optimizer.step() before scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        if (step + 1) % (args.eval_steps * args.gradient_accumulation_steps) == 0:
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            logger.info("***** Report result *****")
            logger.info("  %s = %s", 'global_step', str(global_step))
            logger.info("  %s = %s", 'train loss', str(train_loss))

            for file in ['dev.jsonl']:
                inference_labels=[]
                gold_labels=[]
                eval_examples = read_examples(os.path.join(args.data_dir, file), is_training = True)
                eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,False)
                all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
                all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
                all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)

                all_node_ids = torch.tensor(select_field(eval_features, 'node_ids'), dtype=torch.long)
                all_adj_mask = torch.tensor(select_field(eval_features, 'adj_mask'), dtype=torch.long)                     

                all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)

                eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label,all_node_ids,all_adj_mask)
                        
                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)  
                        
                # Run prediction for full data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                model.eval()

                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0
                for input_ids, input_mask, segment_ids, label_ids,node_ids,adj_mask in eval_dataloader:
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)
                    node_ids = node_ids.to(device)
                    adj_mask = adj_mask.to(device) 

                    with torch.no_grad():
                        # One forward pass: without labels the model returns the logits
                        logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, node_mask=node_ids, adj_mask=adj_mask)
                        tmp_eval_loss = nn.functional.cross_entropy(logits, label_ids)

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    tmp_eval_accuracy = accuracy(logits, label_ids)
                    inference_labels.append(np.argmax(logits, axis=1))
                    gold_labels.append(label_ids)
                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
                eval_accuracy = eval_accuracy / nb_eval_examples

                result = {'eval_loss': eval_loss,
                          'eval_accuracy': eval_accuracy,
                          'global_step': global_step+1,
                          'loss': train_loss}

                output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
                with open(output_eval_file, "a") as writer:
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))
                    writer.write('*'*80)
                    writer.write('\n')
                if eval_accuracy>best_acc and 'dev' in file:
                    print("="*80)
                    print("Best Acc",eval_accuracy)
                    print("Saving Model......")
                    best_acc=eval_accuracy

                    # Save a trained model
                    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
                    torch.save(model_to_save.state_dict(), output_model_file)
                    print("="*80)
                    inference_labels=np.concatenate(inference_labels,0)
                    gold_labels=np.concatenate(gold_labels,0)

                    with open(os.path.join(args.output_dir, "error_output.txt"),'w') as f:
                        for i in range(len(eval_examples)):
                            if inference_labels[i]!=gold_labels[i]:
                                try:
                                    f.write(str(repr(eval_examples[i]))+'\n')
                                    f.write(str(inference_labels[i])+'\n')
                                    f.write("="*80+'\n')
                                except Exception as e:
                                    print('Failed to write error example {}: {}'.format(i, e))
                else:
                    print("="*80) 

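`itertools.cycle`, used in `main` to turn the training DataLoader into an endless stream, saves a copy of every batch from the first pass and replays those saved batches afterwards. As a result, the `RandomSampler` order of the first epoch repeats for the rest of training, and all training batches stay in memory. A small generator that re-iterates the DataLoader avoids both. `infinite_batches` is a hypothetical helper name, shown here on a toy six-example dataset:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def infinite_batches(dataloader):
    """Yield batches forever, re-iterating (and so re-shuffling) on each pass."""
    while True:
        for batch in dataloader:
            yield batch


data = TensorDataset(torch.arange(6))
loader = DataLoader(data, sampler=RandomSampler(data), batch_size=2)
batches = infinite_batches(loader)
# Each batch is a one-element list holding a tensor of 2 example indices
first_pass = torch.cat([next(batches)[0] for _ in range(3)])
second_pass = torch.cat([next(batches)[0] for _ in range(3)])
```

Each pass still visits every example exactly once; only the order is drawn afresh.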

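Two details of the update step are easy to get wrong. Since PyTorch 1.1, `optimizer.step()` should be called before `scheduler.step()`, and the gradient-accumulation counter is safest when derived from the batch index, because `nb_tr_steps` is reset at every evaluation. The sketch below shows both on a toy linear model. `linear_warmup_decay` is a hypothetical helper that rebuilds a linear-warmup, linear-decay schedule with `torch.optim.lr_scheduler.LambdaLR`, assuming that is the shape `WarmupLinearSchedule` implements; recent `transformers` releases offer the same shape as `get_linear_schedule_with_warmup`.

```python
import torch
from torch import nn


def linear_warmup_decay(optimizer, warmup_steps, t_total):
    """LambdaLR version of linear warmup followed by linear decay to zero."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


model = nn.Linear(3, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
# t_total counts optimizer updates, not batches
scheduler = linear_warmup_decay(optimizer, warmup_steps=2, t_total=6)
accumulation_steps = 2
lrs = []
for step in range(12):  # 12 batches -> 6 optimizer updates
    loss = model(torch.randn(4, 3)).pow(2).mean() / accumulation_steps
    loss.backward()  # gradients add up across the accumulated batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()   # update the weights first ...
        scheduler.step()   # ... then advance the schedule
        optimizer.zero_grad()
        lrs.append(optimizer.param_groups[0]['lr'])
```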

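The evaluation loop sums `accuracy(logits, label_ids)` over batches and divides the total by `nb_eval_examples`, so the helper, whose definition is not part of this section, must return a count of correct predictions rather than a ratio. A sketch consistent with that usage, assuming the helper only compares the highest-scoring choice with the gold label:

```python
import numpy as np


def accuracy(logits, labels):
    """Number of examples whose highest-scoring choice matches the gold label."""
    return int(np.sum(np.argmax(logits, axis=1) == labels))


logits = np.array([[0.1, 2.0, 0.3, 0.0, -1.0],
                   [1.5, 0.2, 0.1, 0.0, 0.4]])
labels = np.array([1, 2])
correct = accuracy(logits, labels)  # predictions are [1, 0]: one of two is right
```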
In [32]:
def run(data_dir, output_dir,
        model_type='graph_xlnet',
        model_name='xlnet-large-cased',
        task_name=None):

    parser = argparse.ArgumentParser(description="Common sense question answering")

    # Required parameters
    parser.add_argument("--data_dir", default=data_dir, type=str,
                        help="The input data dir. Should contain train.jsonl and dev.jsonl.")
    parser.add_argument("--output_dir", default=output_dir, type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    
    # Training parameters
    parser.add_argument("--model_type", type=str, default=model_type,
                        help="Model: <str> [ graph_xlnet | xlnet ]")
    parser.add_argument("--task_name", default=task_name, type=str, required=False,
                        help="The name of the task to train: <str> [ commonqa ]")
    parser.add_argument("--model_name", type=str,
                        default=model_name,
                        help="Path to pre-trained model or shortcut name. "
                              "See https://huggingface.co/models")
    parser.add_argument("--config_name", type=str,
                        default=model_name,
                        help="Pre-trained config name or path")
    parser.add_argument("--tokenizer_name", default=model_name, type=str,
                        help="Pre-trained tokenizer name or path if not the same as model_name")

    # Other parameters
    parser.add_argument("--max_seq_length", default=256, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=2, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of batches to accumulate gradients over before each optimizer update.")
    parser.add_argument("--learning_rate", default=5e-6, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay, if any.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Unused; the training length is set by --train_steps.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="Unused; the training length is set by --train_steps.")
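    parser.add_argument("--eval_steps", default=200, type=int,
                        help="Evaluate on the dev set every X optimizer updates.")
    parser.add_argument("--train_steps", default=10000, type=int,
                        help="Total number of training batches (forward/backward passes).")  # Original Paper: 40000
    parser.add_argument("--report_steps", default=1000, type=int,
                        help="Currently unused.")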
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X update steps (currently unused).")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save a checkpoint every X update steps (currently unused; the best dev model is saved instead).")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--seed', type=int, default=0,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    args, unknown = parser.parse_known_args()
    
    main(args)

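`run` builds its defaults from the function arguments and calls `parse_known_args` rather than `parse_args`. In a notebook, `sys.argv` typically holds the kernel's own launch flags (such as `-f <connection file>`); `parse_args` would reject them and exit, whereas `parse_known_args` hands them back in a separate list. A reduced sketch, with a hypothetical `build_args` and an explicit argument list standing in for the kernel's `sys.argv`:

```python
import argparse


def build_args(argv, seed=0, data_dir='data/'):
    # Defaults come from the Python call, not from the command line
    parser = argparse.ArgumentParser(description="toy config")
    parser.add_argument("--data_dir", default=data_dir, type=str)
    parser.add_argument("--seed", type=int, default=seed)
    # parse_known_args returns (namespace, leftovers) instead of exiting
    # on arguments the parser does not recognise
    return parser.parse_known_args(argv)


# Stand-in for the flags a Jupyter kernel is launched with
args, unknown = build_args(['-f', 'kernel.json'], seed=42)
```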

## Finetuning

In [23]:
!mkdir csqa_out

Run the graph-based model (`graph_xlnet`).

In [None]:
data_dir = 'data/'
output_dir = 'csqa_out/'
model_type = 'graph_xlnet'
model_name = 'xlnet-large-cased'
task_name = 'commonqa'

run(data_dir, output_dir, model_type, model_name, task_name)



Loading model graph_xlnet


Some weights of GraphBasedXLNetModel were not initialized from the model checkpoint at xlnet-large-cased and are newly initialized: ['transformer.map.weight', 'transformer.map.bias', 'transformer.map_node_emb.weight', 'transformer.map_node_emb.bias', 'transformer.GCN_W.0.weight', 'transformer.GCN_W.0.bias', 'transformer.GCN_W.1.weight', 'transformer.GCN_W.1.bias', 'transformer.GCN_W_self.0.weight', 'transformer.GCN_W_self.0.bias', 'transformer.GCN_W_self.1.weight', 'transformer.GCN_W_self.1.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
12/16/2020 06:06:46 - INFO - __main__ -   read cont:0
12/16/2020 06:06:57 - INFO - __main__ -   read cont:1000
12/16/2020 06:07:08 - INFO - __main__ -   read cont:2000
12/16/2020 06:07:17 - INFO - __main__ -   read cont:3000
12/16/2020 06:07:27 - INFO - __main__ -   read cont:4000
12/16/2020 06:07:39 - INFO - __main__ -   read cont:5000
12/16/2020 06:07:49 - INFO - __main__ -   read 

# Results

The original paper [2] ran experiments on 2 P100 GPUs with 50 GB RAM and trained for 40000 steps. This hardware is not available on Colab, so we tested the code at a reduced scale (10000 training steps by default).

Results reported by the paper:
* Validation Accuracy: 79.3%
* Test Accuracy: 75.3%

Our experiment:
* Validation Accuracy: 73.0%
* Test Accuracy: Not Available

# References

[1] Talmor, A., Herzig, J., Lourie, N., & Berant, J. (2019). CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge. ArXiv, abs/1811.00937.

[2] Lv, S., Guo, D., Xu, J., Tang, D., Duan, N., Gong, M., Shou, L., Jiang, D., Cao, G., & Hu, S. (2020). Graph-Based Reasoning over Heterogeneous External Knowledge for Commonsense Question Answering. AAAI.

[3] Speer, R., Chin, J., & Havasi, C. (2017). ConceptNet 5.5: An Open Multilingual Graph of General Knowledge. ArXiv, abs/1612.03975.

```
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```