### https://github.com/xiangwang1223/knowledge_graph_attention_network

In [None]:
!git clone https://github.com/xiangwang1223/knowledge_graph_attention_network.git

Cloning into 'knowledge_graph_attention_network'...
remote: Enumerating objects: 254, done.[K
remote: Total 254 (delta 0), reused 0 (delta 0), pack-reused 254[K
Receiving objects: 100% (254/254), 99.56 MiB | 24.09 MiB/s, done.
Resolving deltas: 100% (92/92), done.


### Load dataset

In [None]:
%cd

/root


In [None]:
%cd /content/knowledge_graph_attention_network/Model

/content/knowledge_graph_attention_network/Model


In [None]:
%cd /content/knowledge_graph_attention_network

/content/knowledge_graph_attention_network


In [None]:
import tensorflow as tf
from utility.helper import *

from time import time

from BPRMF import BPRMF
from CKE import CKE
from CFKG import CFKG
from NFM import NFM
from KGAT import KGAT
import os
import sys
import numpy as np
import argparse


In [None]:
'''
Created on Dec 18, 2018
Tensorflow Implementation of Knowledge Graph Attention Network (KGAT) model in:
Wang Xiang et al. KGAT: Knowledge Graph Attention Network for Recommendation. In KDD 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
'''



os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

def load_pretrained_data(args):
    """Load the pretrained embedding archive used to warm-start a model.

    The archive is read from ``<proj_path>pretrain/<dataset>/<name>.npz`` where
    ``<name>`` is ``'kgat'`` when ``args.pretrain == -2`` and ``'mf'`` otherwise.

    :param args: object exposing ``pretrain``, ``proj_path`` and ``dataset``.
    :return: the object returned by ``np.load``, or ``None`` when the archive
        could not be loaded.
    """
    if args.pretrain == -2:
        pre_model, label = 'kgat', 'kgat'
    else:
        pre_model, label = 'mf', 'bprmf'
    pretrain_path = '%spretrain/%s/%s.npz' % (args.proj_path, args.dataset, pre_model)
    try:
        pretrain_data = np.load(pretrain_path)
    except Exception as err:
        # Best effort: callers treat None as "no pretrained weights", so keep going,
        # but say why the archive was not used instead of failing silently.
        print('WARNING: could not load pretrained %s parameters from %s (%s).'
              % (label, pretrain_path, err))
        return None
    # Report the archive that was actually loaded (previously always said "bprmf").
    print('load the pretrained %s model parameters.' % label)
    return pretrain_data


# Script entry point. This first part seeds the RNGs, parses the run arguments, gathers
# dataset statistics into `config`, optionally loads pretrained embeddings and builds
# the model selected by args.model_type.
# NOTE(review): parse_args and data_generator are not defined in the visible cells;
# presumably they come from a project module that still has to be imported - confirm.
if __name__ == '__main__':
    # get argument settings.
    tf.set_random_seed(2019)
    np.random.seed(2019)
    args = parse_args()

    # Select the GPU for this process through the CUDA environment variable.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    """
    *********************************************************
    Load Data from data_generator function.
    """
    # Dataset sizes passed to every model constructor as data_config.
    config = dict()
    config['n_users'] = data_generator.n_users
    config['n_items'] = data_generator.n_items
    config['n_relations'] = data_generator.n_relations
    config['n_entities'] = data_generator.n_entities

    # Only the kgat / cfkg models additionally receive the graph structure.
    if args.model_type in ['kgat', 'cfkg']:
        "Load the laplacian matrix."
        # A_in is the sum of all entries of data_generator.lap_list.
        config['A_in'] = sum(data_generator.lap_list)

        "Load the KG triplets."
        config['all_h_list'] = data_generator.all_h_list
        config['all_r_list'] = data_generator.all_r_list
        config['all_t_list'] = data_generator.all_t_list
        config['all_v_list'] = data_generator.all_v_list

    # Start of the wall-clock measurement reported in the final summary.
    t0 = time()

    """
    *********************************************************
    Use the pretrained data to initialize the embeddings.
    """
    # pretrain == -1 / -2 requests an .npz archive; any other value starts without one.
    if args.pretrain in [-1, -2]:
        pretrain_data = load_pretrained_data(args)
    else:
        pretrain_data = None

    """
    *********************************************************
    Select one of the models.
    """
    # NOTE(review): there is no final else branch, so an unrecognised model_type leaves
    # `model` unbound and the first later use of it raises NameError.
    if args.model_type == 'bprmf':
        model = BPRMF(data_config=config, pretrain_data=pretrain_data, args=args)

    elif args.model_type == 'cke':
        model = CKE(data_config=config, pretrain_data=pretrain_data, args=args)

    elif args.model_type in ['cfkg']:
        model = CFKG(data_config=config, pretrain_data=pretrain_data, args=args)

    elif args.model_type in ['nfm', 'fm']:
        model = NFM(data_config=config, pretrain_data=pretrain_data, args=args)

    elif args.model_type in ['kgat']:
        model = KGAT(data_config=config, pretrain_data=pretrain_data, args=args)

    # Saver used further down to restore a checkpoint when args.pretrain == 1.
    saver = tf.train.Saver()

    # Build the checkpoint directory (only when saving is requested) and open the session.
    """
    *********************************************************
    Save the model parameters.
    """
    if args.save_flag == 1:
        # The directory name encodes the learning rate and regularisation weights; models in
        # the second group additionally insert their layer sizes.
        # NOTE(review): args.regs / args.layer_size are passed to eval(); that is only safe
        # while these strings come from a trusted command line.
        if args.model_type in ['bprmf', 'cke', 'fm', 'cfkg']:
            weights_save_path = '%sweights/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type,
                                                             str(args.lr), '-'.join([str(r) for r in eval(args.regs)]))

        elif args.model_type in ['ncf', 'nfm', 'kgat']:
            layer = '-'.join([str(l) for l in eval(args.layer_size)])
            weights_save_path = '%sweights/%s/%s/%s/l%s_r%s' % (
                args.weights_path, args.dataset, model.model_type, layer, str(args.lr), '-'.join([str(r) for r in eval(args.regs)]))

        ensureDir(weights_save_path)
        # Separate saver that keeps only the most recent checkpoint.
        save_saver = tf.train.Saver(max_to_keep=1)

    # `config` is rebound here from the data dictionary to the TensorFlow session
    # configuration; the model was already constructed with the dictionary.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Initialise the variables; with args.pretrain == 1 additionally try to restore a saved
    # checkpoint, evaluate it and optionally export its embeddings to an .npz archive.
    """
    *********************************************************
    Reload the model parameters to fine tune.
    """
    if args.pretrain == 1:
        # Same directory layout as the one used when saving checkpoints.
        if args.model_type in ['bprmf', 'cke', 'fm', 'cfkg']:
            pretrain_path = '%sweights/%s/%s/l%s_r%s' % (args.weights_path, args.dataset, model.model_type, str(args.lr),
                                                             '-'.join([str(r) for r in eval(args.regs)]))

        elif args.model_type in ['ncf', 'nfm', 'kgat']:
            layer = '-'.join([str(l) for l in eval(args.layer_size)])
            pretrain_path = '%sweights/%s/%s/%s/l%s_r%s' % (
                args.weights_path, args.dataset, model.model_type, layer, str(args.lr), '-'.join([str(r) for r in eval(args.regs)]))

        # dirname(pretrain_path + '/checkpoint') is pretrain_path itself.
        ckpt = tf.train.get_checkpoint_state(os.path.dirname(pretrain_path + '/checkpoint'))
        if ckpt and ckpt.model_checkpoint_path:
            # Initialise everything first, then overwrite with the restored values.
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('load the pretrained model parameters from: ', pretrain_path)

            # *********************************************************
            # get the performance from the model to fine tune.
            # NOTE(review): test and batch_test_flag are not defined in the visible cells;
            # presumably provided by a project module - confirm.
            if args.report != 1:
                users_to_test = list(data_generator.test_user_dict.keys())

                ret = test(sess, model, users_to_test, drop_flag=False, batch_test_flag=batch_test_flag)
                # Baseline for early stopping: first entry of the restored model's recall.
                cur_best_pre_0 = ret['recall'][0]

                pretrain_ret = 'pretrained model recall=[%.5f, %.5f], precision=[%.5f, %.5f], hit=[%.5f, %.5f],' \
                               'ndcg=[%.5f, %.5f], auc=[%.5f]' % \
                               (ret['recall'][0], ret['recall'][-1],
                                ret['precision'][0], ret['precision'][-1],
                                ret['hit_ratio'][0], ret['hit_ratio'][-1],
                                ret['ndcg'][0], ret['ndcg'][-1], ret['auc'])
                print(pretrain_ret)

                # *********************************************************
                # save the pretrained model parameters of mf (i.e., only user & item embeddings) for pretraining other models.
                if args.save_flag == -1:
                    user_embed, item_embed = sess.run(
                        [model.weights['user_embedding'], model.weights['item_embedding']],
                        feed_dict={})
                    # temp_save_path = '%spretrain/%s/%s/%s_%s.npz' % (args.proj_path, args.dataset, args.model_type, str(args.lr),
                    #                                                  '-'.join([str(r) for r in eval(args.regs)]))
                    temp_save_path = '%spretrain/%s/%s.npz' % (args.proj_path, args.dataset, model.model_type)
                    ensureDir(temp_save_path)
                    np.savez(temp_save_path, user_embed=user_embed, item_embed=item_embed)
                    # NOTE(review): the message says 'fm' although the file is named after
                    # model.model_type.
                    print('save the weights of fm in path: ', temp_save_path)
                    # Export-only run: stop before training.
                    exit()

                # *********************************************************
                # save the pretrained model parameters of kgat (i.e., user & item & kg embeddings) for pretraining other models.
                if args.save_flag == -2:
                    user_embed, entity_embed, relation_embed = sess.run(
                        [model.weights['user_embed'], model.weights['entity_embed'], model.weights['relation_embed']],
                        feed_dict={})

                    temp_save_path = '%spretrain/%s/%s.npz' % (args.proj_path, args.dataset, args.model_type)
                    ensureDir(temp_save_path)
                    np.savez(temp_save_path, user_embed=user_embed, entity_embed=entity_embed, relation_embed=relation_embed)
                    print('save the weights of kgat in path: ', temp_save_path)
                    # Export-only run: stop before training.
                    exit()

        else:
            # No checkpoint found: fall back to freshly initialised variables.
            sess.run(tf.global_variables_initializer())
            cur_best_pre_0 = 0.
            print('without pretraining.')
    else:
        sess.run(tf.global_variables_initializer())
        cur_best_pre_0 = 0.
        print('without pretraining.')

    # *********************************************************
    # Get the final performance w.r.t. different sparsity levels.
    # Evaluates every user group returned by data_generator.get_sparsity_split() plus the
    # complete test-user set, writes one entry per group to the report file and then
    # terminates the script (no training happens in report mode).
    if args.report == 1:
        assert args.test_flag == 'full'
        users_to_test_list, split_state = data_generator.get_sparsity_split()

        # Add the complete test-user set as a final group labelled 'all'.
        users_to_test_list.append(list(data_generator.test_user_dict.keys()))
        split_state.append('all')

        save_path = '%sreport/%s/%s.result' % (args.proj_path, args.dataset, model.model_type)
        ensureDir(save_path)

        # The context manager closes the report even when an evaluation call raises.
        with open(save_path, 'w') as f:
            f.write('embed_size=%d, lr=%.4f, regs=%s, loss_type=%s, \n' % (args.embed_size, args.lr, args.regs,
                                                                           args.loss_type))

            for group_name, users_to_test in zip(split_state, users_to_test_list):
                ret = test(sess, model, users_to_test, drop_flag=False, batch_test_flag=batch_test_flag)

                final_perf = "recall=[%s], precision=[%s], hit=[%s], ndcg=[%s]" % \
                             ('\t'.join(['%.5f' % r for r in ret['recall']]),
                              '\t'.join(['%.5f' % r for r in ret['precision']]),
                              '\t'.join(['%.5f' % r for r in ret['hit_ratio']]),
                              '\t'.join(['%.5f' % r for r in ret['ndcg']]))
                print(final_perf)

                f.write('\t%s\n\t%s\n' % (group_name, final_perf))
        exit()

    # Training loop: each epoch trains the recommender, for kgat additionally runs the KG
    # phase, and every show_step-th epoch evaluates, logs, checks early stopping and
    # optionally saves a checkpoint.
    """
    *********************************************************
    Train.
    """
    # One entry per evaluated epoch.
    loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
    stopping_step = 0
    should_stop = False

    for epoch in range(args.epoch):
        t1 = time()
        # Per-epoch accumulators of the batch losses.
        loss, base_loss, kge_loss, reg_loss = 0., 0., 0., 0.
        # floor(n_train / batch_size) + 1 batches per epoch.
        n_batch = data_generator.n_train // args.batch_size + 1

        """
        *********************************************************
        Alternative Training for KGAT:
        ... phase 1: to train the recommender.
        """
        for idx in range(n_batch):
            # NOTE(review): btime is assigned but never read in the visible code.
            btime= time()

            batch_data = data_generator.generate_train_batch()
            feed_dict = data_generator.generate_train_feed_dict(model, batch_data)

            _, batch_loss, batch_base_loss, batch_kge_loss, batch_reg_loss = model.train(sess, feed_dict=feed_dict)

            loss += batch_loss
            base_loss += batch_base_loss
            kge_loss += batch_kge_loss
            reg_loss += batch_reg_loss

        # Abort the whole run as soon as the accumulated loss is NaN.
        if np.isnan(loss) == True:
            print('ERROR: loss@phase1 is nan.')
            sys.exit()

        """
        *********************************************************
        Alternative Training for KGAT:
        ... phase 2: to train the KGE method & update the attentive Laplacian matrix.
        """
        if args.model_type in ['kgat']:

            # floor(#triplets / batch_size_kg) + 1 batches for the KG phase.
            n_A_batch = len(data_generator.all_h_list) // args.batch_size_kg + 1

            # `is True` only matches the bool True, not other truthy values.
            if args.use_kge is True:
                # using KGE method (knowledge graph embedding).
                for idx in range(n_A_batch):
                    btime = time()

                    A_batch_data = data_generator.generate_train_A_batch()
                    feed_dict = data_generator.generate_train_A_feed_dict(model, A_batch_data)

                    _, batch_loss, batch_kge_loss, batch_reg_loss = model.train_A(sess, feed_dict=feed_dict)

                    # The KG-phase losses are added to the same per-epoch accumulators.
                    loss += batch_loss
                    kge_loss += batch_kge_loss
                    reg_loss += batch_reg_loss

            if args.use_att is True:
                # updating attentive laplacian matrix.
                model.update_attentive_A(sess)

        if np.isnan(loss) == True:
            print('ERROR: loss@phase2 is nan.')
            sys.exit()

        # Evaluate only on every show_step-th epoch; on the other epochs optionally print
        # the training losses and move on to the next epoch.
        show_step = 10
        if (epoch + 1) % show_step != 0:
            if args.verbose > 0 and epoch % args.verbose == 0:
                perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f]' % (
                    epoch, time() - t1, loss, base_loss, kge_loss, reg_loss)
                print(perf_str)
            continue

        """
        *********************************************************
        Test.
        """
        t2 = time()
        users_to_test = list(data_generator.test_user_dict.keys())

        ret = test(sess, model, users_to_test, drop_flag=False, batch_test_flag=batch_test_flag)

        """
        *********************************************************
        Performance logging.
        """
        t3 = time()

        loss_loger.append(loss)
        rec_loger.append(ret['recall'])
        pre_loger.append(ret['precision'])
        ndcg_loger.append(ret['ndcg'])
        hit_loger.append(ret['hit_ratio'])

        if args.verbose > 0:
            # Timings are training time (t2 - t1) and evaluation time (t3 - t2); each metric
            # is printed for its first and last entry.
            perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \
                       'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \
                       (epoch, t2 - t1, t3 - t2, loss, base_loss, kge_loss, reg_loss, ret['recall'][0], ret['recall'][-1],
                        ret['precision'][0], ret['precision'][-1], ret['hit_ratio'][0], ret['hit_ratio'][-1],
                        ret['ndcg'][0], ret['ndcg'][-1])
            print(perf_str)

        # early_stopping returns the updated best value, the step counter and the stop flag.
        cur_best_pre_0, stopping_step, should_stop = early_stopping(ret['recall'][0], cur_best_pre_0,
                                                                    stopping_step, expected_order='acc', flag_step=10)

        # *********************************************************
        # early stopping when cur_best_pre_0 is decreasing for ten successive steps.
        if should_stop == True:
            break

        # *********************************************************
        # save the user & item embeddings for pretraining.
        # Saved only when this evaluation's first recall entry equals the current best value.
        if ret['recall'][0] == cur_best_pre_0 and args.save_flag == 1:
            save_saver.save(sess, weights_save_path + '/weights', global_step=epoch)
            print('save the weights in path: ', weights_save_path)

    # *********************************************************
    # Summarise the run: pick the logged evaluation whose first recall entry is highest,
    # print its metrics and append them to the per-model result file.
    if not rec_loger:
        # The loop above only evaluates on every 10th epoch, so a short run logs nothing;
        # indexing the empty arrays below would raise IndexError.
        print('no evaluation was recorded; skip the final report.')
    else:
        recs = np.array(rec_loger)
        pres = np.array(pre_loger)
        ndcgs = np.array(ndcg_loger)
        hit = np.array(hit_loger)

        # Index of the first evaluation that reached the best recall.
        best_rec_0 = max(recs[:, 0])
        idx = list(recs[:, 0]).index(best_rec_0)

        final_perf = "Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]" % \
                     (idx, time() - t0, '\t'.join(['%.5f' % r for r in recs[idx]]),
                      '\t'.join(['%.5f' % r for r in pres[idx]]),
                      '\t'.join(['%.5f' % r for r in hit[idx]]),
                      '\t'.join(['%.5f' % r for r in ndcgs[idx]]))
        print(final_perf)

        save_path = '%soutput/%s/%s.result' % (args.proj_path, args.dataset, model.model_type)
        ensureDir(save_path)

        # Append mode keeps the results of earlier runs; the context manager guarantees
        # the file is closed even if formatting the record fails.
        with open(save_path, 'a') as f:
            f.write('embed_size=%d, lr=%.4f, layer_size=%s, node_dropout=%s, mess_dropout=%s, regs=%s, adj_type=%s, use_att=%s, use_kge=%s, pretrain=%d\n\t%s\n'
                    % (args.embed_size, args.lr, args.layer_size, args.node_dropout, args.mess_dropout, args.regs, args.adj_type, args.use_att, args.use_kge, args.pretrain, final_perf))


### https://github.com/thunlp/KernelGAT.git

In [None]:
!git clone https://github.com/thunlp/KernelGAT

Cloning into 'KernelGAT'...
remote: Enumerating objects: 395, done.[K
remote: Counting objects: 100% (71/71), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 395 (delta 57), reused 49 (delta 48), pack-reused 324[K
Receiving objects: 100% (395/395), 2.32 MiB | 17.10 MiB/s, done.
Resolving deltas: 100% (213/213), done.


In [None]:
# The FEVER scorer is distributed on PyPI as "fever-scorer"; there is no "fever_score"
# distribution, which is why the original install command failed.
!pip install fever-scorer
!pip install torch
!pip install pytorch_pretrained_bert
!pip install transformers

[31mERROR: Could not find a version that satisfies the requirement fever_score (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for fever_score[0m[31m


In [None]:
!git clone https://github.com/thunlp/KernelGAT.git

Cloning into 'KernelGAT'...
remote: Enumerating objects: 395, done.[K
remote: Counting objects: 100% (71/71), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 395 (delta 57), reused 49 (delta 48), pack-reused 324[K
Receiving objects: 100% (395/395), 2.32 MiB | 1.53 MiB/s, done.
Resolving deltas: 100% (213/213), done.


In [None]:
%cd /content/drive/MyDrive/DSC/Data_CheckPoints

/root


In [None]:
#Download check points
!wget https://thunlp.oss-cn-qingdao.aliyuncs.com/KernelGAT/FEVER/KernelGAT.zip
!wget https://thunlp.oss-cn-qingdao.aliyuncs.com/KernelGAT/FEVER/KernelGAT_roberta_large.zip

--2023-09-22 14:29:39--  https://thunlp.oss-cn-qingdao.aliyuncs.com/KernelGAT/FEVER/KernelGAT.zip
Resolving thunlp.oss-cn-qingdao.aliyuncs.com (thunlp.oss-cn-qingdao.aliyuncs.com)... 47.104.38.189
Connecting to thunlp.oss-cn-qingdao.aliyuncs.com (thunlp.oss-cn-qingdao.aliyuncs.com)|47.104.38.189|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2505476826 (2.3G) [application/zip]
Saving to: ‘KernelGAT.zip’


2023-09-22 14:33:36 (10.1 MB/s) - ‘KernelGAT.zip’ saved [2505476826/2505476826]

--2023-09-22 14:33:36--  https://thunlp.oss-cn-qingdao.aliyuncs.com/KernelGAT/FEVER/KernelGAT_roberta_large.zip
Resolving thunlp.oss-cn-qingdao.aliyuncs.com (thunlp.oss-cn-qingdao.aliyuncs.com)... 47.104.38.189
Connecting to thunlp.oss-cn-qingdao.aliyuncs.com (thunlp.oss-cn-qingdao.aliyuncs.com)|47.104.38.189|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2154388119 (2.0G) [application/zip]
Saving to: ‘KernelGAT_roberta_large.zip’


2023-09-22 14:39:03

In [None]:
!mkdir Kgat
!unzip KernelGAT.zip -d /Kgat
!unzip KernelGAT_roberta_large.zip -d /Kgat

Archive:  KernelGAT.zip
   creating: /Kgat/KernelGAT/
  inflating: /Kgat/KernelGAT/.DS_Store  
   creating: /Kgat/__MACOSX/
   creating: /Kgat/__MACOSX/KernelGAT/
  inflating: /Kgat/__MACOSX/KernelGAT/._.DS_Store  
   creating: /Kgat/KernelGAT/checkpoint/
   creating: /Kgat/KernelGAT/checkpoint/retrieval_model/
  inflating: /Kgat/KernelGAT/checkpoint/retrieval_model/model.best.pt  
   creating: /Kgat/__MACOSX/KernelGAT/checkpoint/
   creating: /Kgat/__MACOSX/KernelGAT/checkpoint/retrieval_model/
  inflating: /Kgat/__MACOSX/KernelGAT/checkpoint/retrieval_model/._model.best.pt  
  inflating: /Kgat/KernelGAT/checkpoint/.DS_Store  
  inflating: /Kgat/__MACOSX/KernelGAT/checkpoint/._.DS_Store  
   creating: /Kgat/KernelGAT/checkpoint/kgat/
  inflating: /Kgat/KernelGAT/checkpoint/kgat/model.best.pt  
   creating: /Kgat/__MACOSX/KernelGAT/checkpoint/kgat/
  inflating: /Kgat/__MACOSX/KernelGAT/checkpoint/kgat/._model.best.pt  
   creating: /Kgat/KernelGAT/checkpoint/pretrain/
  inflating: /Kga

In [None]:
#read data files
import json

# Read a JSON-lines file: every non-empty line holds one JSON document.
# 'your_json_file.json' is a placeholder; replace it with the real data file.
with open('your_json_file.json', 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if not line:
            # Blank lines are not records; skip them instead of reporting a decode error.
            continue
        try:
            json_data = json.loads(line)
            # Process the JSON data here
        except json.JSONDecodeError as e:
            # Report the malformed record and keep reading the remaining lines.
            print(f"JSONDecodeError: {e}")

### Kgat inference

In [None]:
%cd /content/KernelGAT

/content/KernelGAT


In [None]:
from kgat.data_loader import DataLoaderTest
from kgat.bert_model import BertForSequenceEncoder

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [None]:
import random, os
import argparse
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torch.autograd import Variable
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer

from kgat.data_loader import DataLoaderTest
from kgat.bert_model import BertForSequenceEncoder
from pytorch_pretrained_bert.optimization import BertAdam
import logging
import json
import random, os
import argparse
import numpy as np
from kgat.models import inference_model


In [None]:
def eval_model(model, label_list, validset_reader, outdir, name):
    """Run the model over every batch and write the predicted labels as JSON lines.

    Each output line has the form ``{"id": ..., "predicted_label": ...}``.

    :param model: callable mapping a batch of inputs to a 2-D tensor of per-label scores.
    :param label_list: label names, indexed by the predicted class id.
    :param validset_reader: iterable yielding ``(inputs, ids)`` batches.
    :param outdir: output directory; it is concatenated with ``name`` as-is, so it
        has to end with a path separator.
    :param name: output file name.
    """
    outpath = outdir + name

    with open(outpath, "w") as f:
        # Inference only: disabling autograd avoids building graphs for every batch.
        with torch.no_grad():
            for inputs, ids in validset_reader:
                logits = model(inputs)
                # Index of the highest score along the label dimension.
                preds = logits.max(1)[1].tolist()
                assert len(preds) == len(ids)
                for sample_id, pred in zip(ids, preds):
                    instance = {"id": sample_id, "predicted_label": label_list[pred]}
                    f.write(json.dumps(instance) + "\n")


In [None]:
def kernal_mus(n_kernels):
    """
    Compute the mean (mu) of every Gaussian kernel; each mu sits in the middle of its bin.
    :param n_kernels: number of kernels, counting the exact-match kernel, which comes first.
    :return: list of mu values, starting with 1 for the exact-match kernel.
    """
    if n_kernels == 1:
        # Only the exact-match kernel was requested.
        return [1]

    # The similarity scores span [-1, 1], split into n_kernels - 1 equally wide bins.
    step = 2.0 / (n_kernels - 1)
    centers = [1, 1 - step / 2]
    # Every further centre lies one bin width below the previous one.
    for _ in range(1, n_kernels - 1):
        centers.append(centers[-1] - step)
    return centers


def kernel_sigmas(n_kernels):
    """
    Get the sigma for each Gaussian kernel.
    :param n_kernels: number of kernels (including the exact-match kernel).
    :return: l_sigma, a list of sigma; the first entry belongs to the exact-match kernel.
    """
    # The unused bin-size computation was removed: it divided by (n_kernels - 1) before
    # the single-kernel guard and therefore raised ZeroDivisionError for n_kernels == 1.
    l_sigma = [0.001]  # for exact match. small variance -> exact match
    if n_kernels == 1:
        return l_sigma

    # Every remaining kernel shares the same width.
    l_sigma += [0.1] * (n_kernels - 1)
    return l_sigma

In [None]:
# Claim-verification inference: build the tokenizer and the test reader, load the BERT
# encoder and the checkpointed inference model onto the GPU, then write the predictions.
# NOTE(review): `args` is only created by the argument-parsing cell further down in this
# notebook; that cell has to run first.
label_map = {'SUPPORTS': 0, 'REFUTES': 1, 'NOT ENOUGH INFO': 2}
# Same labels ordered by class id, used to turn predictions back into names.
label_list = ['SUPPORTS', 'REFUTES', 'NOT ENOUGH INFO']
args.num_labels = len(label_map)
tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)

validset_reader = DataLoaderTest(args.test_path, label_map, tokenizer, args, batch_size=args.batch_size)

bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
bert_model = bert_model.cuda()
bert_model.eval()
model = inference_model(bert_model, args)
# The checkpoint stores the weights under the 'model' key.
model.load_state_dict(torch.load(args.checkpoint)['model'])
model = model.cuda()
model.eval()
eval_model(model, label_list, validset_reader, args.outdir, args.name)
# NOTE(review): second eval() call; as the last expression of a notebook cell it mainly
# displays the model - confirm it is intentional.
model.eval()

### Retrieval model

In [None]:
def save_to_file(all_predict, outpath, top_k=5):
    """Write the highest-scoring evidence of every claim as JSON lines.

    Each output line has the form ``{"id": ..., "evidence": [...]}``.

    :param all_predict: mapping from claim id to a list of evidence entries whose last
        element is the score used for ranking.
    :param outpath: path of the file to write (overwritten if it exists).
    :param top_k: number of evidence entries kept per claim; defaults to 5, the value
        that used to be hard-coded.
    """
    with open(outpath, "w") as out:
        for key, values in all_predict.items():
            # Highest score first; sorted() is stable, so ties keep their input order.
            ranked = sorted(values, key=lambda x: x[-1], reverse=True)
            out.write(json.dumps({"id": key, "evidence": ranked[:top_k]}) + "\n")

In [None]:
#retrieval eval
def eval_model(model, validset_reader):
    """Score every candidate evidence and group the results by claim id.

    :param model: object with an ``eval()`` method that, when called with
        ``(inp_tensor, msk_tensor, seg_tensor)``, returns a tensor holding one score
        per candidate.
    :param validset_reader: iterable yielding
        ``(inp_tensor, msk_tensor, seg_tensor, ids, evi_list)`` batches.
    :return: dict mapping each claim id to a list of evidence entries, each extended
        with its score as the last element.
    """
    model.eval()
    all_predict = dict()
    # Inference only: disabling autograd avoids building graphs for every batch.
    with torch.no_grad():
        for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in validset_reader:
            probs = model(inp_tensor, msk_tensor, seg_tensor).tolist()
            assert len(probs) == len(evi_list)
            for i, (evidence, prob) in enumerate(zip(evi_list, probs)):
                # Concatenation builds a new list, so the reader's entries stay untouched.
                all_predict.setdefault(ids[i], []).append(evidence + [prob])
    return all_predict

In [None]:

# Placeholder settings for the retrieval run; the two paths still have to be filled in.
out_dir =" "
test_path =" "
batch_size = 64
# Use CUDA only when it is available and --no-cuda was not given.
# NOTE(review): `args` is only created by the argument-parsing cell that follows;
# that cell has to run before this line.
args.cuda = not args.no_cuda and torch.cuda.is_available()


In [None]:
#Load args
# Command-line options for the inference cells (help texts corrected: several were
# copy-pasted from unrelated options).
# NOTE(review): parse_args() reads sys.argv; outside a command-line run the three
# required options are missing, so pass an explicit list, e.g.
# parser.parse_args(['--outdir', ..., '--bert_pretrain', ..., '--checkpoint', ...]).
parser = argparse.ArgumentParser()
parser.add_argument('--test_path', help='path to the test data')
parser.add_argument('--name', help='name of the output file')
parser.add_argument("--batch_size", default=32, type=int, help="Total batch size.")
parser.add_argument('--outdir', required=True, help='path to output directory')
parser.add_argument('--bert_pretrain', required=True, help='pretrained BERT model to load')
parser.add_argument('--checkpoint', required=True, help='path to the model checkpoint')
parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA.')
parser.add_argument("--bert_hidden_dim", default=768, type=int, help="BERT hidden dimension.")
parser.add_argument("--layer", type=int, default=1, help='Graph Layer.')
parser.add_argument("--num_labels", type=int, default=3, help='Number of labels.')
parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
parser.add_argument("--threshold", type=float, default=0.0, help='Threshold value.')
parser.add_argument("--max_len", default=120, type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                         "longer than this will be truncated, and sequences shorter than this will be padded.")
args = parser.parse_args()

In [None]:
#From retrieval model test.py

# Evidence-retrieval inference: build the tokenizer and reader, load the model from the
# checkpoint, score all candidates and write the ranked evidence to <outdir>/<name>.
tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
#Change the arg or dataloader test
# NOTE(review): DataLoaderTest is called here without a label map, unlike the
# claim-verification cell; presumably the retrieval model's own data loader is meant -
# confirm which class is actually imported.
validset_reader = DataLoaderTest(test_path, tokenizer, args, batch_size= batch_size)


bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
bert_model = bert_model.cuda()
model = inference_model(bert_model, args)
# The checkpoint stores the weights under the 'model' key.
model.load_state_dict(torch.load(args.checkpoint)['model'])
model = model.cuda()

save_path = args.outdir + "/" + args.name
# Uses the two-argument eval_model defined for retrieval, which returns the predictions.
predict_dict = eval_model(model, validset_reader)
#save predictions
save_to_file(predict_dict, save_path)

In [None]:
!pip install huggingface
!pip install datasets

Collecting huggingface
  Downloading huggingface-0.0.1-py3-none-any.whl (2.5 kB)
Installing collected packages: huggingface
Successfully installed huggingface-0.0.1
Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32

In [None]:
from datasets import load_dataset
# Load the "wiki_pages" configuration of the "fever" dataset.
dataset = load_dataset("fever","wiki_pages")

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating wikipedia_pages split:   0%|          | 0/5416537 [00:00<?, ? examples/s]

In [None]:
# Alias the loaded dataset and show its 'wikipedia_pages' split as a pandas DataFrame.
wiki_pages = dataset
wiki_pages['wikipedia_pages'].to_pandas()

Unnamed: 0,id,text,lines
0,,,
1,1928_in_association_football,The following are the football -LRB- soccer -R...,0\tThe following are the football -LRB- soccer...
2,1986_NBA_Finals,The 1986 NBA Finals was the championship round...,0\tThe 1986 NBA Finals was the championship ro...
3,1901_Villanova_Wildcats_football_team,The 1901 Villanova Wildcats football team repr...,0\tThe 1901 Villanova Wildcats football team r...
4,1992_Northwestern_Wildcats_football_team,The 1992 Northwestern Wildcats team represente...,0\tThe 1992 Northwestern Wildcats team represe...
...,...,...,...
5416532,Yuto_Agarie,Yuto Agarie -LRB- born 6 July 1993 -RRB- is a ...,0\tYuto Agarie -LRB- born 6 July 1993 -RRB- is...
5416533,Yume_1_Go,is the eleventh single by the Japanese rock ba...,0\tis the eleventh single by the Japanese rock...
5416534,Yada_Yada_-LRB-album-RRB-,Yada Yada is the eighth studio album by Dutch ...,0\tYada Yada is the eighth studio album by Dut...
5416535,Xylorycta_bipunctella,Xylorycta bipunctella is a moth in the Xyloryc...,0\tXylorycta bipunctella is a moth in the Xylo...


In [None]:
# Convert the 'train' split to a pandas DataFrame.
# NOTE(review): the visible cells only load the "wiki_pages" configuration; presumably
# `dataset` was reloaded with a configuration that has a 'train' split before this cell
# was run - confirm.
train = dataset['train'].to_pandas()

In [None]:
train

Unnamed: 0,id,label,claim,evidence_annotation_id,evidence_id,evidence_wiki_url,evidence_sentence_id
0,75397,SUPPORTS,Nikolaj Coster-Waldau worked with the Fox Broa...,92206,104971,Nikolaj_Coster-Waldau,7
1,75397,SUPPORTS,Nikolaj Coster-Waldau worked with the Fox Broa...,92206,104971,Fox_Broadcasting_Company,-1
2,150448,SUPPORTS,Roman Atwood is a content creator.,174271,187498,Roman_Atwood,1
3,150448,SUPPORTS,Roman Atwood is a content creator.,174271,187499,Roman_Atwood,3
4,214861,SUPPORTS,"History of art includes architecture, dance, s...",255136,254645,History_of_art,2
...,...,...,...,...,...,...,...
311426,13114,SUPPORTS,J. R. R. Tolkien created Gimli.,28359,34669,Gimli_-LRB-Middle-earth-RRB-,-1
311427,13114,SUPPORTS,J. R. R. Tolkien created Gimli.,28359,34670,Gimli_-LRB-Middle-earth-RRB-,1
311428,152180,SUPPORTS,Susan Sarandon is an award winner.,176133,189101,Susan_Sarandon,1
311429,152180,SUPPORTS,Susan Sarandon is an award winner.,176133,189102,Susan_Sarandon,2


### Abstract retrieval

In [None]:
!pip install rank_bm25

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [None]:
from rank_bm25 import BM25Okapi

# Toy corpus used to try out BM25 ranking.
corpus = [
    "Hello there good man!",
    "It is quite windy in London",
    "How is the weather today?"
]

# Split each document on single spaces; no lowercasing or punctuation stripping is applied.
tokenized_corpus = [doc.split(" ") for doc in corpus]

# Build the BM25 index over the tokenized documents.
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
# Tokenize the query the same way as the corpus, then score it against the index.
query = "windy London"
tokenized_query = query.split(" ")
doc_scores = bm25.get_scores(tokenized_query)

In [None]:
bm25.get_top_n(tokenized_query, corpus, n=1)

['It is quite windy in London']