In [1]:
import os
# import pprint
import json
import copy
import time
import argparse

import numpy as np

import torch
import torch.nn as nn

from tqdm import tqdm

from functions_openllm import use_api_base, sure_infer,use_api_background,use_fusion_document,use_api_base_retrieval,use_api_base_retrievalAndfusion
from data_utils import get_em_f1

from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


parser = argparse.ArgumentParser(description='Query QA Data to GPT API.')
    parser.add_argument('--data_name', type=str, default=None, help='Name of QA Dataset')
    parser.add_argument('--qa_data', type=str, default=None, help='Path to QA Dataset')
    parser.add_argument('--start', type=int, default=None, help='Start index of QA Dataset')
    parser.add_argument('--end', type=int, default=None, help='End index of QA Dataset')
    parser.add_argument('--lm_type', type=str, default='llama2', help='Type of LLM (llama2, gemma, mistral)')
    parser.add_argument('--n_retrieval', type=int, default=10, help='Number of retrieval-augmented passages')
    parser.add_argument('--infer_type', type=str, default='sure', help='Inference Method (base or sure)', choices=['base', 'sure'])
    parser.add_argument('--output_folder', type=str, default=None, help='Path for save output files')
    
    args = parser.parse_args()

In [2]:
from types import SimpleNamespace

# 模拟命令行参数
args = SimpleNamespace(
    data_name="webqa",
    qa_data='../data/bm25/nq-test-bm25.json',
    start=None,
    end=None,
    lm_type='llama',
    n_retrieval=10,
    infer_type='base',
    output_folder='../data/dpr/'
)

# 现在你可以像使用 argparse 一样使用 args 对象
print(args.lm_type)
print(args.qa_data)

llama
../data/bm25/nq-test-bm25.json


In [3]:
#加载检索数据集
print("=====> Data Load...")
#dataset = json.load(open("../data/contriever/wq-test-contriever.json"))
dataset = json.load(open(args.qa_data))
start_idx, end_idx = args.start, args.end
if start_idx is None:
    start_idx = 0
elif end_idx is None:
    end_idx = len(dataset)
else:
    if start_idx >= end_idx:
     raise ValueError
dataset = dataset[start_idx:end_idx]
print(dataset[0])
print("Number of QA Samples: {}".format(len(dataset)))

=====> Data Load...
{'question': 'who got the first nobel prize in physics', 'answers': ['Wilhelm Conrad Röntgen'], 'contexts': [{'docid': '628725', 'score': '16.331089', 'has_answer': False, 'title': 'Nobel Prize in Physics', 'text': 'receive a diploma, a medal and a document confirming the prize amount. Nobel Prize in Physics The Nobel Prize in Physics () is a yearly award given by the Royal Swedish Academy of Sciences for those who have made the most outstanding contributions for mankind in the field of physics. It is one of the five Nobel Prizes established by the will of Alfred Nobel in 1895 and awarded since 1901; the others being the Nobel Prize in Chemistry, Nobel Prize in Literature, Nobel Peace Prize, and Nobel Prize in Physiology or Medicine. The first Nobel Prize in Physics was'}, {'docid': '12584253', 'score': '16.265154', 'has_answer': False, 'title': 'Frances Arnold', 'text': 'peptides and antibodies." She is the first female graduate of Princeton to be awarded a Nobel P

In [4]:
#加载大模型生成文档
Ldataset = json.load(open("../data/webqa_start0_endNone_base_ret10/backgroundLma70B.json"))
print(Ldataset[0])
print("Number of QA Samples: {}".format(len(Ldataset)))

FileNotFoundError: [Errno 2] No such file or directory: '../data/webqa_start0_endNone_base_ret10/backgroundLma70B.json'

In [5]:
#加载融合后的文档
Fdataset = json.load(open("../data/dpr/webqa_start0_endNone_base_ret10/FusionLma70B.json"))
print("Number of QA Samples: {}".format(len(dataset)))

Number of QA Samples: 2032


In [6]:
if args.lm_type == "gemma":
    model = AutoModelForCausalLM.from_pretrained("google/gemma-1.1-7b-it")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it")
elif args.lm_type == "mistral":
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
elif args.lm_type == "llama":
    tokenizer = AutoTokenizer.from_pretrained("../pre/Llama3-70b/LLM-Research/Llama-3___3-70B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("../pre/Llama3-70b/LLM-Research/Llama-3___3-70B-Instruct",device_map="auto",torch_dtype=torch.float16)
else:
    raise ValueError
#model = model.cuda()   上面已经指定了device_map="auto",torch_dtype=torch.float16，所以不需要再从内存中把模型移动到GPU上

Loading checkpoint shards: 100%|██████████| 30/30 [00:44<00:00,  1.49s/it]


In [7]:
if not os.path.exists(args.output_folder):
    os.makedirs(args.output_folder)
method = f'{args.data_name}_start{start_idx}_end{end_idx}_{args.infer_type}_ret{str(args.n_retrieval)}'
method_folder = args.output_folder + '/{}'.format(method)
if not os.path.exists(method_folder):
    os.makedirs(method_folder)
method_folder

'../data/dpr//webqa_start0_endNone_base_ret10'

In [None]:
print("=====> Begin Inference (type: {})".format(args.infer_type))
#让大模型生成相关背景文档
#results = use_api_background(model, args.lm_type, tokenizer, dataset,n=1)

#让大模型生成融合后的文档
#results = use_fusion_document(model, args.lm_type, tokenizer, dataset,Ldataset)

#让大模型根据融合文档生成答案
#results = use_api_base(model, args.lm_type, tokenizer, dataset,Fdataset)

#results = use_api_base_retrieval(model, args.lm_type, tokenizer, dataset,n_articles=30)
results = use_api_base_retrievalAndfusion(model,args.lm_type, tokenizer, dataset,Fdataset,n_articles=10)
print(results)

=====> Begin Inference (type: base)
This is the relevant documentation:


Jamaican people speak multiple languages, with the primary spoken language being Jamaican Patois, also known as Patwa or Jamaican Creole. This language is an English-based creole language with West African influences and is spoken by the majority of Jamaicans as a native language. It is used in informal situations, at home, and in local popular music. 

In addition to Jamaican Patois, many Jamaicans also speak Standard English, which is the official language of Jamaica and is used in formal and professional contexts, such as education, government, and media. Some Jamaicans may also speak Jamaican English, which is a variety of English that resembles parts of both British and American English dialects.

Jamaican Patois and Standard English coexist and influence each other, with many Jamaicans speaking both languages and using them in different contexts. Jamaican Patois has its own distinct grammar, vocabulary, and

In [4]:
results = json.load(open("../data/bm25/nq_start0_endNone_base_ret10/resultLma70B.json"))

In [5]:
results = json.load(open("../data/webqa_start0_endNone_base_ret10/result(NoRetrieval)Lma70B.json"))
print("=====> All Procedure is finished!")
with open(f'./{method_folder}/backgroundLma70B.json', "w", encoding='utf-8') as writer:
    writer.write(json.dumps(results, indent=4, ensure_ascii=False) + "\n")

=====> All Procedure is finished!


In [5]:
#print("=====> Results of {}".format(method))
em, f1 = get_em_f1(dataset, results)
print("EM: {} F1: {}".format(em.mean(), f1.mean()))

EM: 0.2922437673130194 F1: 0.45289803456008443
