In [None]:
## CALL API ENDPOINTS (LLM, EMBEDDING)
import os

# Restrict CUDA-aware libraries loaded later in this kernel to GPU index 3.
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

In [None]:
import os
from tqdm import tqdm
import copy
import numpy as np

# Blank out any inherited proxy settings for this process.
# NOTE(review): presumably so requests to the inference/embedding endpoints are not
# routed through a proxy — confirm.
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = ""

In [None]:
### CALL LLM
from transformers import AutoTokenizer
from langchain_community.llms import VLLMOpenAI
from langchain_openai import ChatOpenAI
# NOTE(review): the star import supplies MODEL_NAME and INFERENCE_SERVER_URL used below,
# and TRIPLET_MAP_PATH / TRIPLET_EMB_PATH used in later cells; any other names it exports
# are invisible from here.
from config import *

# Tokenizer of the chat model; later cells use its chat template to render prompts.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

### For Chat OpenAI template
# Chat client pointed at INFERENCE_SERVER_URL with a placeholder API key, temperature 0
# and at most 512 generated tokens per call.
llm = ChatOpenAI(
	model=MODEL_NAME,
	openai_api_key="EMPTY",
	max_tokens=512,
	openai_api_base=INFERENCE_SERVER_URL,
	temperature=0,
	streaming= False
)

In [None]:
# NOTE(review): format_T, build_undirected_graph and relevance_guided_path_addition_beam
# are used in later cells but not defined in this notebook — presumably they come from
# this star import; confirm.
from beam import *
from transformers import AutoTokenizer
# Tokenizer loaded from the "embedding_model" identifier; max_len() uses it to count and
# truncate tokens before text is sent to the embedding endpoint.
ebd_tok = AutoTokenizer.from_pretrained("embedding_model")

In [None]:
### Embedding

### Call API Endpoint Embedding
import json
import requests
from typing import List
from langchain_core.embeddings import Embeddings
from tqdm.notebook import tqdm

def max_len(query, max_tokens=512, keep_tokens=500):
	"""Truncate ``query`` so it fits the embedding model's input window.

	The text is tokenised with ``ebd_tok``. If the encoding is longer than
	``max_tokens`` ids, only the first ``keep_tokens`` ids are kept and decoded
	back to text; otherwise ``query`` is returned untouched.

	Args:
		query: Text to check.
		max_tokens: Encoded length above which truncation happens.
		keep_tokens: Number of leading token ids kept when truncating.

	Returns:
		The original string, or its truncated re-decoded form.
	"""
	token_ids = ebd_tok.encode(query)
	if len(token_ids) > max_tokens:
		# skip_special_tokens=True: the encoded prefix carries the tokenizer's special
		# markers; decoding them verbatim would inject their literal text into the
		# query that is subsequently embedded.
		query = ebd_tok.decode(token_ids[:keep_tokens], skip_special_tokens=True)
	return query
	
class CustomAPIEmbeddings(Embeddings):
	"""LangChain ``Embeddings`` adapter around an HTTP embedding endpoint.

	Each text is truncated with ``max_len`` and POSTed on its own as
	``{"inputs": <text>}``; the first element of the JSON response is taken as
	the embedding vector.

	Args:
		api_url: URL of the embedding endpoint.
		show_progress: When true, wrap the batch loop in a ``tqdm`` progress bar.
			It only affects display, never the embeddings themselves.
		timeout: Optional timeout (seconds) forwarded to ``requests``; ``None``
			keeps the previous wait-forever behaviour.
	"""

	def __init__(self, api_url: str, show_progress: bool, timeout=None):
		self.api_url = api_url
		self.show_progress = show_progress
		self.timeout = timeout

	def _embed_one(self, text: str) -> List[float]:
		"""Embed a single text, raising ``RuntimeError`` if the reply is unusable."""
		payload = json.dumps({
			"inputs": max_len(text)
		})
		headers = {
			'Content-Type': 'application/json'
		}
		response = requests.request(
			"POST", self.api_url, headers=headers, data=payload, timeout=self.timeout
		)
		try:
			return json.loads(response.text)[0]
		except Exception as e:
			# Report the body we already received instead of issuing a second request
			# just to print it.
			raise RuntimeError(
				f"Embedding request failed (HTTP {response.status_code}): {response.text[:500]}"
			) from e

	def embed_documents(self, texts: List[str]) -> List[List[float]]:
		"""Embed every text in ``texts``, preserving order and length.

		Raises:
			RuntimeError: If any single request cannot be turned into a vector.
				Skipping the failed item would silently shift every later vector
				against its source text (and against any index built from the
				result), so the batch fails loudly instead.
		"""
		iterator = tqdm(texts) if self.show_progress else texts
		return [self._embed_one(text) for text in iterator]

	def embed_query(self, text: str) -> List[float]:
		"""Embed one query string and return its vector."""
		return self.embed_documents([text])[0]

# Instantiate the embedding client used by every retrieval step below (no progress bar).
# NOTE(review): this reuses INFERENCE_SERVER_URL, the same constant given to ChatOpenAI
# as openai_api_base — confirm the embedding service really lives at that URL.
embeddings = CustomAPIEmbeddings(api_url=INFERENCE_SERVER_URL, show_progress=False)


### test here:

### 1. Load Graph Data

#### Load Data (Triplets, Triplets Relation Embeddings)

In [None]:

import pickle
# NOTE(security): pickle.load executes code embedded in the file; only load pickles
# produced by a trusted pipeline.
# Triplet records; later cells index them by position and read rec["r"] / rec["r.summary"].
with open(TRIPLET_MAP_PATH,'rb') as f:
	dct_mapping_triplet = pickle.load(f)

# Triplet embeddings.
# NOTE(review): assumed to be row-aligned with dct_mapping_triplet (FAISS ids are used as
# positions into it) — confirm.
with open(TRIPLET_EMB_PATH,'rb') as f:
	lst_embedding = pickle.load(f)

In [None]:
#### Convert the unpickled embeddings to a NumPy array (rebinds lst_embedding).

import numpy as np
lst_embedding = np.array(lst_embedding)

In [None]:
import pandas as pd
# Evaluation set; the "question" column is used here, and "documents", "response" and
# "id" are read by later cells.
df_test = pd.read_csv("final_data.csv")
test_data = df_test['question'].tolist()

In [None]:
import faiss
# Flat (exhaustive) inner-product index over the triplet embeddings, stored as float32.
# NOTE(review): inner product equals cosine similarity only if the stored vectors are
# L2-normalised — confirm the embedding endpoint returns unit-length vectors.
faiss_embeddings = lst_embedding.astype('float32')
d = faiss_embeddings.shape[1]  # embedding dimensionality
index = faiss.IndexFlatIP(d)
index.add(faiss_embeddings)  # row i of the index corresponds to row i of lst_embedding

### 2. Contextual Question Retrieval (CQR)

In [None]:

from collections import defaultdict, namedtuple

# Embedding client handed to the graph-expansion step.
model = embeddings
Triplet = namedtuple("Triplet", ["head", "relation", "tail", "ttr"])
# One Triplet per stored record: (head id, relation, tail id, relation summary).
KG_list = [
		Triplet(rec["r"][0]["id"], rec["r"][1], rec["r"][2]["id"], rec["r.summary"])
		for rec in dct_mapping_triplet
	]
KG = build_undirected_graph(KG_list)
# Distinct edge summaries found in the graph.
all_summaries = {edge["r"]["summary"] for edges in KG.values() for edge in edges}
# embed_query() embeds ONE string (it wraps its argument in a list), so passing the whole
# list produced a single vector that zip() then spread float-by-float across the
# summaries. Batch embedding must go through embed_documents().
summary_texts = list(all_summaries)
summary_vectors = model.embed_documents(summary_texts)
assert len(summary_vectors) == len(summary_texts), "embedding count does not match summary count"
# summary text -> embedding vector
summary_embeddings = dict(zip(summary_texts, summary_vectors))

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from langchain.prompts import PromptTemplate
from typing import Literal
import multiprocessing

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

def faiss_cosine(index, query_vector, k=10):
	"""Search ``index`` for the ``k`` best matches of ``query_vector``.

	The query is cast to float32 before the search. Only the ids returned by
	``index.search`` are used; they are flattened into a 1-D array.
	"""
	as_float32 = query_vector.astype('float32')
	search_result = index.search(as_float32, k)
	neighbour_ids = search_result[1]
	return neighbour_ids.flatten()

	
def max_len(query, max_tokens=512, keep_tokens=500):
	"""Truncate ``query`` so it fits the embedding model's input window.

	This re-definition shadows the identically named helper declared in the
	embedding cell earlier in the notebook; keep the two in sync.

	The text is tokenised with ``ebd_tok``. If the encoding is longer than
	``max_tokens`` ids, only the first ``keep_tokens`` ids are kept and decoded
	back to text; otherwise ``query`` is returned untouched.

	Args:
		query: Text to check.
		max_tokens: Encoded length above which truncation happens.
		keep_tokens: Number of leading token ids kept when truncating.

	Returns:
		The original string, or its truncated re-decoded form.
	"""
	token_ids = ebd_tok.encode(query)
	if len(token_ids) > max_tokens:
		# skip_special_tokens=True: the encoded prefix carries the tokenizer's special
		# markers; decoding them verbatim would inject their literal text into the
		# query that is subsequently embedded.
		query = ebd_tok.decode(token_ids[:keep_tokens], skip_special_tokens=True)
	return query

def query_triplet_topk(query, k=10):
	"""Return the ``k`` stored triplet records closest to ``query``.

	The query is truncated with ``max_len``, embedded through the module-level
	``embeddings`` client and searched against the module-level FAISS ``index``.

	Args:
		query: Natural-language question.
		k: Number of neighbours to request.

	Returns:
		List of entries from ``dct_mapping_triplet`` in rank order (shorter than
		``k`` if the index holds fewer vectors).
	"""
	query = max_len(query)
	query_emb = np.array(embeddings.embed_query(query)).reshape(1, -1)
	# faiss_cosine needs the index as its first argument; the previous call omitted it
	# (TypeError at runtime) and also ignored the requested k.
	topk_indices_sorted = faiss_cosine(index, query_emb, k).tolist()
	# FAISS pads missing neighbours with -1, which would otherwise select the last record.
	return [dct_mapping_triplet[x] for x in topk_indices_sorted if x >= 0]


# Structured-output schema bound to the LLM in check_grade_lst().
# NOTE(review): LangChain presumably forwards the class docstring and the Field
# description to the model as schema text, so both are prompt content and are kept
# verbatim — confirm before rewording them.
class GradeRelationList(BaseModel):
	"""List passage index check on retrieved text."""
	# Comma-separated passage numbers, e.g. "1,3"; parsed by add_triplet_context_to_question.
	passage_idx: str = Field(
		description="The passage index of relevant chunks, seperated by a comma"
	)

def check_grade_lst(question, text):
	"""Ask the LLM which of the numbered passages in ``text`` are relevant to ``question``.

	Returns the ``GradeRelationList`` produced by the structured-output chain;
	its ``passage_idx`` field holds the selected passage numbers as one string.
	"""
	grader_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance 
		of a list of retrieved passages to a user question. The goal is to filter out erroneous retrievals. \n
		Return only the passage index whether the passage is relevant to the question. \n
		Provide the output as a JSON with passage index seperated by a comma and no premable or explaination.
		 <|eot_id|><|start_header_id|>user<|end_header_id|>
		Here is the list of retrieved text: \n\n {text} \n\n
		Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
		"""
	grader_prompt = PromptTemplate(
		template=grader_template,
		input_variables=["question", "text"]
	)
	# Prompt piped into the LLM constrained to the GradeRelationList schema.
	grader_chain = grader_prompt | llm.with_structured_output(GradeRelationList)
	return grader_chain.invoke({"question": question, "text": text})


def format_relations(relations):
	"""Render each relation record as ``"<n.id> - <relation> -> <m.id>"``.

	Each record supplies the source node under ``'n'``, the target node under
	``'m'`` and the relation label at position 1 of ``'r'``.
	"""
	return [
		f"{record['n']['id']} - {record['r'][1]} -> {record['m']['id']}"
		for record in relations
	]

In [None]:
import traceback, time

# Error counter; nothing in the visible notebook reads or updates it after this line.
cnt_err = 0
def format_claim(relations):
	"""Number the ``['r']['summary']`` text of each relation (1-based), one blank line apart."""
	numbered = []
	for position, relation in enumerate(relations, start=1):
		numbered.append(f"{position}. {relation['r']['summary']}")
	return "\n\n".join(numbered)

def format_triplet(relations):
	"""Number each relation as ``"<n>. (<head id>, <relation>, <tail id>)"``, one blank line apart."""
	numbered = []
	for position, relation in enumerate(relations, start=1):
		head, label, tail = relation['r'][0], relation['r'][1], relation['r'][2]
		numbered.append(f"{position}. ({head['id']}, {label}, {tail['id']})")
	return "\n\n".join(numbered)


# Structured-output schema bound to the LLM in contextual_question_retrieval(); callers
# read its `summary` field.
# NOTE(review): the docstring and Field description are presumably forwarded to the
# model as schema text, so they are kept verbatim — confirm before rewording.
class contextual_output(BaseModel):
	"""contextual summarization for the input question."""
	summary: str = Field(
		description="Concise summary ocontextual information of the input question"
	)

# Structured-output schema with a single `context` string.
# Not referenced by any other cell in the visible notebook.
class contextual_triplets(BaseModel):
	"""contextual generation of knowledge subgraph."""
	context: str = Field(
		description="generate concise contextual information based on list of triplets. Output MUST be VALID JSON"
	)
	

def contextual_question_retrieval(claims):
	"""Summarise a formatted list of claims into one contextual paragraph via the LLM.

	The chat template of ``tokenizer`` renders a system/user conversation that
	still contains the ``{system}`` and ``{claims}`` placeholders; the prompt
	template fills them in at invoke time. Returns the ``contextual_output``
	object produced by the structured-output chain.
	"""
	instructions = (
		"You are a helpful assistant responsible for generating a comprehensive summary of the data provided below."
		" Given the list of claims that may relation with each other. Please write a Concise summary of claims that aim to provide a contextual information."
		" The output just generate a concise summary without any explaination."
		" Please note that if the provided claims are contradictory, please resolve the contradictions and provide a single, coherent summary (no need Here is part)"
	)
	conversation = [
		{"role":"system", "content":"{system}"},
		{"role":"user", "content":"\nHere is the list of claims {claims}\n"}
	]
	rendered_template = tokenizer.apply_chat_template(
		conversation, tokenize=False, add_generation_prompt=True
	)
	summary_prompt = PromptTemplate(template=rendered_template, input_variables=["system", "claims"])
	summary_chain = summary_prompt | llm.with_structured_output(contextual_output)
	return summary_chain.invoke({"system": instructions, "claims": claims})


def add_triplet_context_to_question(KG, model, summary_embeddings, question):
	"""Append a knowledge-graph summary to ``question`` when relevant triplets exist.

	Steps: retrieve the top-k triplets for the question, let the LLM pick the
	relevant ones (``check_grade_lst``), expand them over ``KG`` with
	``relevance_guided_path_addition_beam`` and summarise the result with
	``contextual_question_retrieval``.

	NOTE(review): ``format_T`` and ``relevance_guided_path_addition_beam`` are not
	defined in this notebook — presumably provided by ``from beam import *``.

	Args:
		KG: Graph passed through to the beam expansion.
		model: Embedding model passed through to the beam expansion.
		summary_embeddings: Mapping passed through to the beam expansion.
		question: The user question.

	Returns:
		``question`` unchanged, or ``question + " with some extra data: " + summary``
		when a non-empty summary was produced. Any failure falls back to the bare
		question (best effort) after reporting the error.
	"""
	relations = query_triplet_topk(question)
	T = format_T(relations)  # check all relations in one LLM call
	contextual_summary = ""
	try:
		raw_indices = check_grade_lst(question, format_claim(T)).passage_idx
		selected = []
		for token in raw_indices.split(","):
			token = token.strip()
			if not token:
				# tolerate stray separators such as "1,2,"
				continue
			position = int(token)
			# Passage numbers are 1-based. 0 or a negative number would wrap around to the
			# end of T, and an out-of-range one used to discard every valid pick.
			if 1 <= position <= len(T):
				selected.append(T[position - 1])
		if selected:
			H = relevance_guided_path_addition_beam(KG, selected, model, summary_embeddings, 20, 20, 2)
			contextual_summary = contextual_question_retrieval(format_claim(H)).summary
	except Exception as e:
		# Best effort: keep the pipeline running, but make the failure visible.
		print(f"add_triplet_context_to_question: falling back to the bare question: {e!r}")
		contextual_summary = ""
	if contextual_summary != "":
		question = question + " with some extra data: " + contextual_summary
	return question

In [None]:
# Top-k triplets for every test question, in test_data order (with a progress bar).
lst_triplet_top_k_cos = [query_triplet_topk(question) for question in tqdm(test_data)]

# question -> its retrieved triplets; a repeated question keeps its last occurrence.
map_triplet = dict(zip(test_data, lst_triplet_top_k_cos))

### 3. CQR for Multi-Step Questions

#### 3.1 Loading Data

In [None]:
# `random` is not imported by any earlier cell of this notebook; import it here so the
# cell does not depend on a name leaking from a star import.
import random

df_test.head()
docs_corpus = df_test["documents"].tolist()
# Show one randomly chosen entry of the "documents" column, parsed from its string form.
# NOTE(security): eval() executes arbitrary Python found in the CSV; if the column only
# holds Python literals, ast.literal_eval is the safe replacement.
eval(random.choice(docs_corpus))

In [None]:
# BM25
# Passages are stored back-to-back, each terminated by "<endofpassage>"; the final split
# element (text after the last terminator) is dropped.
with open("passages.txt","r") as f:
	lst_chunks = f.read().split("<endofpassage>")[:-1]
# De-duplicate while keeping first-seen file order. set() ordering of strings changes
# between interpreter runs, which made passage ids (and every index built from them)
# non-reproducible; the mapping is now derived from the very same list it describes.
lst_chunks = list(dict.fromkeys(lst_chunks))
# passage text -> position in lst_chunks
mapping_chunks = {chunk: idx for idx, chunk in enumerate(lst_chunks)}

In [None]:
### Visualise the length distribution of the passages
import matplotlib.pyplot as plt
import matplotlib

# %%matplotlib.inline()

# Passage length measured as the number of single-space-separated tokens.
length = [len(x.split(" ")) for x in lst_chunks]
plt.hist(length, bins=20)
plt.show()


#### 3.2 Executing Baseline - IRCoT
ref: https://github.com/stonybrooknlp/ircot

##### 3.2.1 Retrieval Module

In [None]:
### Retrieval
### Using BM25
from rank_bm25 import BM25Okapi
# Rebinds `tqdm` to the notebook widget implementation for the cells below.
from tqdm.notebook import tqdm

# Tokenise each passage by splitting on single spaces and build the BM25 index over them.
tokenized_corpus = [doc.split(" ") for doc in lst_chunks]
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
def retrieval_bm25(question, k):
	"""Return the ``k`` passages of ``lst_chunks`` that BM25 ranks highest for ``question``.

	The question is tokenised by splitting on single spaces.
	"""
	query_terms = question.split(" ")
	return bm25.get_top_n(query_terms, lst_chunks, n=k)

#### BGE

In [None]:
from FlagEmbedding import BGEM3FlagModel

# NOTE(review): this rebinds the global `model`, which an earlier cell set to
# `embeddings`; any later code that reads `model` now gets this object instead.
# NOTE(review): the checkpoint path is an absolute, machine-specific location.
# In the visible notebook the passage embeddings below are produced through
# `embeddings.embed_documents`, not through this object.
model = BGEM3FlagModel('/workspace/home/NLP_CORE/HUB_Embedding/bge-large-en-v1.5/',
					   use_fp16=True)

In [None]:
# Aliases of the passage list (same list object, no copy).
passages = lst_chunks
sentences_1 = passages
# One request per passage against the embedding endpoint.
embeddings_1 = embeddings.embed_documents(lst_chunks)

In [None]:
# Collect the per-passage vectors into a plain list, showing a progress bar while copying.
p_embd = [embeddings_1[i] for i in tqdm(range(len(embeddings_1)))]

In [None]:
# Stack the vectors into a 2-D array (rebinds p_embd) and display its shape.
p_embd = np.array(p_embd)
p_embd.shape

In [None]:
# Flat inner-product index over the passage embeddings.
# NOTE(review): inner product equals cosine similarity only when the stored vectors are
# L2-normalised — confirm the embedding endpoint returns unit-length vectors.
index_p = faiss.IndexFlatIP(p_embd.shape[1])
# FAISS stores float32; np.array() over JSON floats yields float64, so cast explicitly
# (as the triplet index does) instead of relying on the wrapper to convert or reject it.
index_p.add(p_embd.astype('float32'))  # row i corresponds to lst_chunks[i]

In [None]:
def retrieval_bge(query, k, alpha=0.7):
	"""Dense passage retrieval with optional question/context fusion.

	``query`` may be a plain question or the output of
	``add_triplet_context_to_question``, i.e.
	``"<question> with some extra data: <context>"``. When a non-empty context
	is present, the search vector is ``alpha * q + (1 - alpha) * c``; otherwise
	the question embedding is used alone.

	Args:
		query: Question, optionally followed by the separator and a context.
		k: Number of passages to request from ``index_p``.
		alpha: Weight of the question embedding in the fused vector.

	Returns:
		Passages from ``lst_chunks`` in rank order (fewer than ``k`` when the
		index holds fewer vectors).
	"""
	# Split on the first separator only, so a context that happens to contain the
	# separator text again is kept whole instead of being cut short.
	question, _, context = query.partition(" with some extra data: ")
	q_embd = np.array(embeddings.embed_query(max_len(question)))
	if context:
		c_embd = np.array(embeddings.embed_query(max_len(context)))
		v_fuse = alpha * q_embd + (1 - alpha) * c_embd
	else:
		v_fuse = q_embd
	# FAISS expects a float32 (n, d) matrix; same cast as faiss_cosine().
	v_fuse = v_fuse.reshape(1, -1).astype('float32')
	distances, indices = index_p.search(v_fuse, k)
	# Missing neighbours are reported as -1, which would otherwise pick lst_chunks[-1].
	return [lst_chunks[idx] for idx in indices[0] if idx >= 0]
	

##### 3.2.2 Interleaving Retrieval with Chain-of-Thought Reasoning

In [None]:
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser,JsonOutputParser

	
def format_docs(docs):
	"""Concatenate the string form of every item in ``docs``, separated by a blank line."""
	rendered = [f"{doc}" for doc in docs]
	return "\n\n".join(rendered)

# Structured-output schema bound to the LLM in check_response(); callers read
# `binary_score` to decide whether to answer now or issue another query.
# NOTE(review): the docstring and Field description are presumably forwarded to the
# model as schema text, so they are kept verbatim — confirm before rewording.
class GradeRespose(BaseModel):
	"""Binary score to determine if the passages provide sufficient information to answer the question directly."""
	binary_score: bool = Field(
		description="The relevant passages provide sufficient information to answer the question directly, 'yes' or 'no'"
	)

# Structured-output schema bound to the LLM in gen_question(); callers read `new_query`
# and use it as the next retrieval query.
# NOTE(review): the docstring and Field description are presumably forwarded to the
# model as schema text, so they are kept verbatim — confirm before rewording.
class gen_query(BaseModel):
	"""Generate  chain-of-thought query for futher research and exploration."""
	new_query: str = Field(
		description="Generate new chain-of-thought query for futher research and exploration"
	)

def check_response(question, context):
	"""Ask the LLM whether ``context`` holds enough information to answer ``question``.

	The conversation is rendered with the chat template of ``tokenizer`` while
	still containing placeholders, which the prompt template fills at invoke
	time. Returns the ``GradeRespose`` object from the structured-output chain.
	"""
	system_promt = (
		"You are an advanced AI assistant skilled in analyzing textual data."
		"\nBelow is a question and relevant passages that may contain information to answer it."
		"\nYour task is to determine if the provided passages contain enough relevant information to answer the question, even if not directly stated."
		"\nConsider both direct answers and implied or partially inferred information."
		"\nReturn a binary score: 'True' if the context provides sufficient information to answer the question; 'False' if it does not."
		"\nProvide only the binary score in JSON format with a single key 'score'. Do not include explanations."
	)
	conversation = [
		{"role":"system", "content":"{system_promt}"},
		{"role":"user", "content":"\nQuestion: {question}\nRelevan Passages: {context}"}
	]
	rendered_template = tokenizer.apply_chat_template(
		conversation, tokenize=False, add_generation_prompt=True
	)
	sufficiency_prompt = PromptTemplate(
		template=rendered_template,
		input_variables=["system_promt", "question","context"]
	)
	sufficiency_chain = sufficiency_prompt | llm.with_structured_output(GradeRespose)
	return sufficiency_chain.invoke(
		{"system_promt": system_promt, "question": question ,"context": context}
	)

def gen_question(question, context, previous_though):
	"""Ask the LLM for a new chain-of-thought retrieval query.

	Args:
		question: The original user question.
		context: Passages retrieved so far, already formatted as one string.
		previous_though: Earlier generated queries, newline-joined, so the model
			can avoid repeating them.

	Returns:
		The ``gen_query`` object from the structured-output chain; its
		``new_query`` field is the next query to retrieve with.
	"""
	system_promt_gen_answer = (
		"You are an advanced AI skilled in generating a concise insightful chain-of-thought query to guide further research and exploration."
		" Below is an input question and relevant context information and previous failed queries."
		"\nYour task is to :"
		"\n1. Analyze the input question to understand its intent and identify gaps in the provided context that prevent a complete answer."
		"\n2. Generate a new chain-of-thought query that is based on the input question, incorporating logical steps or deeper aspects of the topic."
		" This new query should be designed to guide further search or inquiry, aiming to bridge the identified gaps and refine the search for an answer."
		"\n3. Avoid repeating or rephrasing any of the previous failed queries. Instead, aim to expand the scope or explore different facets of the topic that have not been addressed yet."
		"All JSON MUST in correct format"
		"**DO NOT get information from 'Relevant context information' to create new input variables.**"
	)

	# Render the conversation with placeholders, as check_response() and final_answer()
	# do. The previous version f-string-injected the question/context directly into the
	# template text, so any "{" or "}" in a retrieved passage was parsed as a template
	# variable and broke prompt formatting.
	chat_gen_answer = tokenizer.apply_chat_template(
		[
			{"role": "system", "content": "{system_promt_gen_answer}"},
			{"role": "user", "content": "\nQuestion: {question}\nRelevant context information: {context}\nPrevious failed queries: {previous_though}"}
		],
		tokenize=False,
		add_generation_prompt=True
	)

	prompt_gen_answer = PromptTemplate(template=chat_gen_answer, input_variables=["system_promt_gen_answer", "question", "context", "previous_though"])
	structured_query_llm = llm.with_structured_output(gen_query)
	chain_gen_answer = prompt_gen_answer | structured_query_llm
	answer = chain_gen_answer.invoke({"system_promt_gen_answer": system_promt_gen_answer, "question": question, "context": context, "previous_though": previous_though})

	return answer


def final_answer(question, context):
	"""Generate the final answer to ``question`` from ``context`` with the LLM.

	The chat template is rendered with placeholders that the prompt template
	fills at invoke time; the raw model output is parsed to a string and
	stripped of surrounding whitespace before being returned.
	"""
	system_promt_gen_answer = (
		"You are an expert AI designed to analyze information from retrieval-augumented generation system."
		"\nYour task is to answer questions based on the input context. Below is a question along with the input context."
		"\nMake sure your repsonse is consice clear, and directly answer the question in maximum 5 sentences WITHOUT any explaination."
		"\nDO NOT use any external knowledge. "
		"\nIf the answer is not directly found, try to infer the best possible answer from the context."
	)
	conversation = [
		{"role":"system", "content":"{system_promt_gen_answer}"},
		{"role":"user", "content":"\nQuestion: {question}\nInput context: {context}"}
	]
	rendered_template = tokenizer.apply_chat_template(
		conversation, tokenize=False, add_generation_prompt=True
	)
	answer_prompt = PromptTemplate(
		template=rendered_template,
		input_variables=["system_promt_gen_answer", "question","context"]
	)
	answer_chain = answer_prompt | llm | StrOutputParser()
	raw_answer = answer_chain.invoke(
		{"system_promt_gen_answer": system_promt_gen_answer,"question":question, "context": context}
	)
	return raw_answer.strip()

def max_length_context(context,threshold=512):
	"""Clip every passage in ``context`` (a list of strings) to ``threshold`` words.

	Words are the pieces obtained by splitting on single spaces. Passages at or
	below the limit are returned unchanged; longer ones keep their first
	``threshold`` pieces, re-joined with single spaces.
	"""
	def _clip(passage):
		words = passage.split(" ")
		if len(words) <= threshold:
			return passage
		return " ".join(words[:threshold])

	return [_clip(passage) for passage in context]


# IRCoT Baseline

In [None]:
import uuid
import pickle
import traceback
from multiprocessing import Pool, Manager

def process_question(tasks):
	"""Run the IRCoT loop for one question using dense retrieval only (no KG context).

	Args:
		tasks: Tuple ``(question, label, k, n_loop, qid)``: the question text, its
			reference answer, the number of passages per retrieval, the maximum
			number of check/re-query rounds, and the question id.

	Returns:
		A dict with keys ``Question``, ``id``, ``Answer``, ``Label``, ``Context``,
		``CoT`` and ``n_CoT``, or ``None`` if any step raised. The same value is
		also pickled to ``llama3.2-3brb_bge/<uuid>.pkl``.
	"""
	question, label, k, n_loop, qid = tasks  # Unpack the arguments
	try:
		i = 0
		thought_q = ""
		pt = []  # queries generated so far, shown to the LLM to avoid repeats
		gen_answer = None  # Ensure it's always defined

		context = max_length_context(retrieval_bge(question, k))
		while i < n_loop:
			check = check_response(question, format_docs(context)).binary_score
			# Answer as soon as the context is judged sufficient, or on the last round.
			if check or i == n_loop - 1:
				gen_answer = final_answer(question, format_docs(context))
				break
			new_CoT_query = gen_question(question, format_docs(context), "\n".join(pt)).new_query
			pt.append(new_CoT_query)
			thought_q += f"\n{i}-{new_CoT_query}"
			new_context = max_length_context(retrieval_bge(new_CoT_query, k))
			# Order-preserving de-duplication: set() ordering of strings differs between
			# interpreter runs, which changed the prompt text from run to run.
			context = list(dict.fromkeys(context + new_context))
			i += 1

		res = {
			"Question": question,
			"id": qid,
			"Answer": gen_answer,
			"Label": label,
			"Context": context,
			"CoT": thought_q,
			"n_CoT": int(i),
		}
	except Exception as e:
		print(f"Error occurred during processing question '{question}': {e}")
		traceback.print_exc()
		res = None

	# Persist each result under a random file name. The directory is created on demand;
	# open() sits outside the try block, so a missing directory used to kill the worker.
	out_dir = "llama3.2-3brb_bge"
	os.makedirs(out_dir, exist_ok=True)
	fn = uuid.uuid4()
	with open(f"{out_dir}/{fn}.pkl", "wb") as f:
		pickle.dump(res, f)
	return res

	


def main(k=8, n_loop=5, num_procs=20, limit=600, output_path="IRCoT_CQR_Inference_1605_bge.xlsx"):
	"""Run ``process_question`` over the test set in parallel and save the results.

	Args:
		k: Passages retrieved per query.
		n_loop: Maximum check/re-query rounds per question.
		num_procs: Worker processes in the pool (20 is what the previous
			hard-coded ``Pool(20)`` used; the old ``num_procs = 8`` was never read).
		limit: Number of leading rows of ``df_test`` to process.
		output_path: Excel file the result table is written to.
	"""
	questions = df_test['question'].tolist()[:limit]
	labels = df_test["response"].tolist()[:limit]
	ids = df_test["id"].tolist()[:limit]
	tasks = [(q, label, k, n_loop, qid) for q, label, qid in zip(questions, labels, ids)]

	# imap keeps results in task order; the unused Manager() server process was dropped.
	with Pool(num_procs) as pool:
		results = list(tqdm(pool.imap(process_question, tasks), total=len(tasks)))

	# Failed questions come back as None.
	results = [res for res in results if res is not None]
	final_test = pd.DataFrame(results)

	# Save to an Excel file and report the path actually written.
	final_test.to_excel(output_path, index=False)
	print(f"Processing complete. Results saved to '{output_path}'.")


# Run the baseline experiment: calls main() when this code executes as the top-level
# module.
if __name__ == "__main__":
	main()

# IRCoT + KG

In [None]:
import uuid
import pickle
import traceback
from multiprocessing import Pool, Manager

def process_question(tasks):
	"""Run the IRCoT loop for one question with KG-augmented queries.

	Every retrieval query (the question and each generated follow-up) is first
	passed through ``add_triplet_context_to_question`` and then through dense
	retrieval. This definition replaces the baseline ``process_question``
	declared in the previous cell.

	Args:
		tasks: Tuple ``(question, label, k, n_loop, qid)``: the question text, its
			reference answer, the number of passages per retrieval, the maximum
			number of check/re-query rounds, and the question id.

	Returns:
		A dict with keys ``Question``, ``id``, ``Answer``, ``Label``, ``Context``,
		``CoT`` and ``n_CoT``, or ``None`` if any step raised. The same value is
		also pickled to ``llama3.2-3brb_bge_kcqr/<uuid>.pkl``.
	"""
	question, label, k, n_loop, qid = tasks  # Unpack the arguments
	try:
		i = 0
		thought_q = ""
		pt = []  # queries generated so far, shown to the LLM to avoid repeats
		gen_answer = None  # Ensure it's always defined

		# add_triplet_context_to_question takes (KG, model, summary_embeddings, question).
		# The previous one-argument call raised TypeError, which the except below turned
		# into a None result for every question. `embeddings` is passed explicitly
		# because the global `model` is rebound by the BGE cell after the graph cell
		# assigned `model = embeddings`.
		augmented_question = add_triplet_context_to_question(KG, embeddings, summary_embeddings, question)
		context = max_length_context(retrieval_bge(augmented_question, k))
		while i < n_loop:
			check = check_response(question, format_docs(context)).binary_score
			# Answer as soon as the context is judged sufficient, or on the last round.
			if check or i == n_loop - 1:
				gen_answer = final_answer(question, format_docs(context))
				break
			new_CoT_query = gen_question(question, format_docs(context), "\n".join(pt)).new_query
			pt.append(new_CoT_query)
			thought_q += f"\n{i}-{new_CoT_query}"
			augmented_query = add_triplet_context_to_question(KG, embeddings, summary_embeddings, new_CoT_query)
			new_context = max_length_context(retrieval_bge(augmented_query, k))
			# Order-preserving de-duplication: set() ordering of strings differs between
			# interpreter runs, which changed the prompt text from run to run.
			context = list(dict.fromkeys(context + new_context))
			i += 1

		res = {
			"Question": question,
			"id": qid,
			"Answer": gen_answer,
			"Label": label,
			"Context": context,
			"CoT": thought_q,
			"n_CoT": int(i),
		}
	except Exception as e:
		print(f"Error occurred during processing question '{question}': {e}")
		# traceback.print_exc()
		res = None

	# Persist each result under a random file name. The directory is created on demand;
	# open() sits outside the try block, so a missing directory used to kill the worker.
	out_dir = "llama3.2-3brb_bge_kcqr"
	os.makedirs(out_dir, exist_ok=True)
	fn = uuid.uuid4()
	with open(f"{out_dir}/{fn}.pkl", "wb") as f:
		pickle.dump(res, f)
	return res

	


def main(k=8, n_loop=5, num_procs=20, limit=600, output_path="IRCoT_CQR_Inference_1605_bge_kcqr.xlsx"):
	"""Run the KG-augmented ``process_question`` over the test set and save the results.

	Args:
		k: Passages retrieved per query (top-k).
		n_loop: Maximum check/re-query rounds per question.
		num_procs: Worker processes in the pool (20 is what the previous
			hard-coded ``Pool(20)`` used; the old ``num_procs = 8`` was never read).
		limit: Number of leading rows of ``df_test`` to process.
		output_path: Excel file the result table is written to.
	"""
	# Convert test data into a list of arguments for the worker function
	questions = df_test['question'].tolist()[:limit]
	labels = df_test["response"].tolist()[:limit]
	ids = df_test["id"].tolist()[:limit]
	tasks = [(q, label, k, n_loop, qid) for q, label, qid in zip(questions, labels, ids)]

	# Process questions in parallel with progress tracking; imap keeps task order.
	# The unused Manager() server process was dropped.
	with Pool(num_procs) as pool:
		results = list(tqdm(pool.imap(process_question, tasks), total=len(tasks)))

	# Filter out None results (in case of errors)
	results = [res for res in results if res is not None]

	# Convert the results into a DataFrame
	final_test = pd.DataFrame(results)

	# Save to an Excel file and report the path actually written.
	final_test.to_excel(output_path, index=False)
	print(f"Processing complete. Results saved to '{output_path}'.")


# Run the KG-augmented experiment: calls main() (the definition from this cell) when
# this code executes as the top-level module.
if __name__ == "__main__":
	main()