## Running inference with the pre-trained model on the OTTQA dataset

In [1]:
# All settings for the inference run, collected in one place.
# HfArgumentParser.parse_dict (cell 3) maps each key onto a field of one of the
# argument dataclasses, so key names must match those field names exactly.
# The section comments below group keys by the stage their names suggest they
# configure; the grouping is a reading aid only — confirm against arguments_utils.
config = dict(
    # --- row retriever (presumably) ---
    per_device_train_batch_size_rr=8,
    per_device_eval_batch_size_rr=8,
    rr_model_name="bert-base-uncased",
    row_retriever_model_name_path="data/ottqa/row_retriever/rr.bin",
    pos_frac_per_epoch=[0.3, 0.3, 0.1, 0.0001, 0.0001],
    group_frac_per_epoch=[0.0, 0.5, 1.0, 1.0, 1.0],
    # --- answer extractor (presumably) ---
    max_seq_length=512,
    per_gpu_train_batch_size=8,
    train_batch_size=8,
    per_gpu_eval_batch_size=8,
    eval_batch_size=8,
    max_query_length=64,
    threads=1,
    null_score_diff_threshold=0.0,
    n_best_size=20,
    do_predict_ae=True,
    n_gpu=1,
    max_answer_length=30,
    model_name_or_path_ae="bert-base-uncased",
    output_dir="data/ottqa/models/answer_extractor/",
    model_type="bert",
    doc_stride=128,
    pred_ans_file="data/ottqa/predictions/answer_extractor_output_test.json",
    eval_file="data/ottqa/ae_input_test.json",
    # --- link generator (presumably; note the *_lg suffixes) ---
    model="gpt2",
    top_k=0,
    top_p=0.9,
    seed_lg=42,
    batch_size_lg=2,
    linker_model="data/ottqa/models/link_generator/model-ep9.pt",
    max_source_len=32,
    max_target_len=16,
    do_all_lg=True,
    # --- dataset / run-level settings ---
    data_path_root="data/ottqa/",
    dataset_name="ottqa",
    test_data_path="data/ottqa/released_data/toy.json",
    collections_file="linearized_tables.tsv",
    test=True,
)

In [2]:
from transformers import (
    HfArgumentParser,
    TrainingArguments,
)
from primeqa.mitqa.utils.arguments_utils import HybridQAArguments,LinkPredictorArguments, RRArguments,AEArguments


# One parser covering every argument group the pipeline needs. parse_dict()
# returns one dataclass instance per type listed here, in this same order,
# which is the order cell 3 unpacks them into hqa_args, lp_args, rr_args, ae_args.
hqa_parser = HfArgumentParser(
    (
        HybridQAArguments,
        LinkPredictorArguments,
        RRArguments,
        AEArguments,
    )
)


  from .autonotebook import tqdm as notebook_tqdm


## Parse the config dict and run the full inference pipeline

In [3]:
import json

from primeqa.mitqa.utils.model_utils.row_retriever_MITQA import RowRetriever
from primeqa.mitqa.utils.model_utils.reranker import re_rank_ae_output
from primeqa.mitqa.utils.link_predictor import predict_link_for_tables,train_link_generator
from primeqa.mitqa.utils.model_utils.table_retriever import train_table_retriever,predict_table_retriever
from primeqa.mitqa.utils.model_utils.process_row_retriever_output import preprocess_data_using_row_retrieval_scores,create_dataset_for_answer_extractor
from primeqa.mitqa.utils.model_utils.answer_extractor_multi_Answer import run_answer_extractor
from primeqa.mitqa.processors.preprocessors.preprocess_raw_data import preprocess_data
import logging
import torch
import os
import sys
from primeqa.mitqa.utils.arguments_utils import HybridQAArguments,LinkPredictorArguments, RRArguments,AEArguments

# Read arguments and load the dataset.
# Take the test-mode flag from `config` so there is a single source of truth;
# a separately hard-coded flag here could silently drift from config["test"].
test = config.get("test", True)

# parse_dict returns one dataclass instance per type given to HfArgumentParser,
# in the same order: HybridQAArguments, LinkPredictorArguments, RRArguments, AEArguments.
hqa_args, lp_args, rr_args, ae_args = hqa_parser.parse_dict(config)

# Close the file handle deterministically. JSON text is UTF-8 (RFC 8259), so do
# not depend on the platform's default locale encoding.
with open(hqa_args.test_data_path, encoding="utf-8") as test_file:
    raw_test_data = json.load(test_file)

# Table retrieval over the collection named by `collections_file` under data_path_root.
retrieved_data = predict_table_retriever(hqa_args.data_path_root, hqa_args.collections_file, raw_test_data)
# Link prediction on the retrieved tables, driven by the link-generator arguments.
linked_data = predict_link_for_tables(lp_args, retrieved_data)
test_data_processed = preprocess_data(
    hqa_args.data_path_root,
    hqa_args.dataset_name,
    linked_data,
    split="test",
    test=test,
)
print("Initial preprocessing done")

# Row retrieval: score rows of the preprocessed data
# (result presumably keyed by question id, per its name — confirm in RowRetriever.predict).
rr = RowRetriever(hqa_args, rr_args)
qid_scores_dict = rr.predict(test_data_processed)
print("Row retrieval predictions Done")

# NOTE(review): this step receives raw_test_data, not test_data_processed —
# confirm that is the input preprocess_data_using_row_retrieval_scores expects.
test_processed_data = preprocess_data_using_row_retrieval_scores(raw_test_data, qid_scores_dict, test)
print("Row retrieval output processed")
answer_extraction_data = create_dataset_for_answer_extractor(test_processed_data, hqa_args.data_path_root, test)
print("Answer extraction data generated")

# run_answer_extractor returns two output paths; by name, the top-prediction
# file and the n-best candidates file.
ae_output_path, ae_output_path_nbest = run_answer_extractor(ae_args, answer_extraction_data)
print(ae_output_path)
print(ae_output_path_nbest)

# Re-rank the n-best answer candidates using the row-retrieval scores;
# ae_args.pred_ans_file is passed as the output target and the returned
# value is reported as the saved location.
re_ranked_output = re_rank_ae_output(qid_scores_dict, ae_output_path_nbest, ae_args.pred_ans_file)
print("re-ranked output saved at ", re_ranked_output)

{"time":"2022-12-23 05:42:41,777", "name": "sentence_transformers.SentenceTransformer", "level": "INFO", "message": "Load pretrained SentenceTransformer: msmarco-distilbert-base-tas-b"}
{"time":"2022-12-23 05:42:43,593", "name": "sentence_transformers.SentenceTransformer", "level": "INFO", "message": "Use pytorch device: cpu"}
{"time":"2022-12-23 05:42:44,287", "name": "sentence_transformers.SentenceTransformer", "level": "INFO", "message": "Load pretrained SentenceTransformer: msmarco-distilbert-base-tas-b"}
{"time":"2022-12-23 05:42:45,739", "name": "sentence_transformers.SentenceTransformer", "level": "INFO", "message": "Use pytorch device: cpu"}


Batches: 100%|██████████| 1/1 [00:00<00:00,  3.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.91it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.97it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.23it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.20it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  4.20it/s]
