In [86]:
from __future__ import annotations
import pandas as pd
from transformers import set_seed
import json
from Bio import Entrez
# import vllm
# from lmformatenforcer import RegexParser
# from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data
import argparse
from abstract_comprehension import read_tsv_to_dataframe
from tqdm import tqdm
import numpy as np
import os
import jinja2
from classifier import process_single_row, write_to_json, test_openai_connection
from itertools import chain

class RaggedTensor:
	"""A small list-backed container for 1D or "ragged" 2D data.

	``data`` is a plain Python list. When its first element is itself a list the
	tensor is treated as 2D and ``shape`` is the list of row lengths; otherwise
	``shape`` is ``len(data)``.

	``break_point`` holds flat offsets recorded by ``+`` (concatenation); ``split``
	cuts a 1D tensor at those offsets to recover the concatenated segments.
	"""

	def __init__(self, data, break_point = None):
		"""Wrap ``data``.

		Args:
			data: A list of elements (1D) or a list of lists (2D).
			break_point: Optional list of flat offsets used by ``split``. Defaults
				to a fresh empty list (a shared mutable default would leak break
				points between unrelated instances).
		"""
		self.data = data
		self.break_point = [] if break_point is None else break_point
		self.index = 0
		self.getShape()

	def getShape(self) -> None:
		"""Recompute ``self.shape`` from ``self.data``."""
		if self.is2D():
			self.shape = [len(row) for row in self.data]
		else:
			self.shape = len(self.data)

	def is2D(self) -> bool:
		"""Return True when ``data`` is non-empty and its first element is a list."""
		return len(self.data) > 0 and isinstance(self.data[0], list)

	def expand(self, shape_list: list) -> RaggedTensor:
		"""Repeat element ``i`` of ``data`` ``shape_list[i]`` times.

		Returns a new 1D tensor (without break points); ``self`` is unchanged.

		Raises:
			AssertionError: If ``self`` is 2D or ``len(shape_list) != len(data)``.
		"""
		assert not self.is2D(), "Data must be 1D before calling expand. Call flatten first?"
		assert self.shape == len(shape_list), "The length of shape list must equal the length of data"

		expanded = []
		for item, count in zip(self.data, shape_list):
			expanded.extend([item] * count)

		return RaggedTensor(expanded)

	def flatten(self) -> RaggedTensor:
		"""Return a new 1D tensor of all rows concatenated, or ``self`` if already 1D."""
		if self.is2D():
			output = []
			for row in self.data:
				output.extend(row)
			return RaggedTensor(output)
		else:
			return self

	def compress(self, shape_list: list) -> None:
		"""Invert ``expand`` in place by keeping the first element of each group.

		``shape_list`` must be the list that was passed to ``expand``. Order is
		preserved and equal values from different groups are kept separately.
		Groups of length 0 have no surviving element and are dropped.
		"""
		assert self.shape == sum(shape_list)
		compressed = []
		offset = 0
		for count in shape_list:
			if count > 0:
				compressed.append(self.data[offset])
			offset += count
		self.data = compressed
		self.getShape()

	def split(self) -> list[RaggedTensor]:
		"""Cut the data at each offset in ``break_point`` into a list of tensors.

		With no break points a warning is printed and ``[self, RaggedTensor([])]``
		is returned.
		"""
		if len(self.break_point) == 0:
			print("Warning: No breakpoint was specified.")
			return [self, RaggedTensor([])]
		past_break_point = 0
		output = []
		for break_point in self.break_point:
			output.append(RaggedTensor(self.data[past_break_point:break_point]))
			past_break_point = break_point
		output.append(RaggedTensor(self.data[past_break_point:]))
		return output

	def reshape(self, shape: list) -> None:
		"""Regroup a 1D tensor in place into rows of the given lengths.

		Raises:
			AssertionError: If the tensor is already 2D or ``sum(shape) != len(data)``.
		"""
		assert not self.is2D(), "Reshape only works with 1D tensors."
		assert self.shape == sum(shape), "The shape of the tensor should be equal to the sum of the wanted shape."
		output = []
		running_length = 0
		for length in shape:
			output.append(self.data[running_length: running_length + length])
			running_length += length

		self.data = output
		self.getShape()

	def applyFilter(self, mask: RaggedTensor) -> None:
		"""Keep, in place, only elements whose mask value equals 1 (True counts as 1).

		Works row by row for 2D tensors. Elements keep their original Python
		types and ``shape`` is recomputed afterwards.

		Raises:
			AssertionError: If ``mask.shape`` differs from ``self.shape``.
		"""
		assert self.shape == mask.shape, "Filtering only works when the shapes are the same"
		if self.is2D():
			self.data = [
				[item for item, keep in zip(row, row_mask) if keep == 1]
				for row, row_mask in zip(self.data, mask.data)
			]
		else:
			self.data = [item for item, keep in zip(self.data, mask.data) if keep == 1]
		self.getShape()

	def map(self, func: callable, *args) -> RaggedTensor:
		"""Return ``func(x, *args)`` for every element, keeping ``break_point``."""
		assert not self.is2D(), "Map only works with 1D tensors"
		return RaggedTensor([func(i, *args) for i in self.data], self.break_point)

	def getTopKMask(self, array, k: int) -> list:
		"""Return a boolean list marking the ``k`` largest values of ``array``.

		``k <= 0`` yields an all-False mask and ``k >= len(array)`` marks every
		element. Ties are resolved by ``np.argsort`` ordering.
		"""
		mask = np.zeros_like(array, dtype=bool)
		if k > 0:
			# argsort is ascending, so the last k indices point at the largest values.
			# (Guarded because [-0:] would select every index.)
			top_k_indices = np.argsort(array, axis=None)[-k:]
			mask[top_k_indices] = True
		return mask.tolist()

	def getFullKArgMax(self, k: int) -> RaggedTensor:
		"""Return a same-shaped mask tensor marking the top ``k`` values (per row if 2D)."""
		if self.is2D():
			output = [self.getTopKMask(row, k) for row in self.data]
		else:
			output = self.getTopKMask(self.data, k)

		output = RaggedTensor(output, self.break_point)
		assert output.shape == self.shape

		return output

	def __add__(self, other: RaggedTensor) -> RaggedTensor:
		"""Concatenate two 1D tensors and record where ``other`` starts.

		The result's break points are ``self``'s, then ``len(self)``, then
		``other``'s own break points shifted by ``len(self)``, so nested sums
		still split into every original segment.
		"""
		assert not self.is2D(), "Adding only works with flattened tensors"
		assert not other.is2D(), "Adding only works with flattened tensors"
		offset = self.shape
		shifted = [offset + point for point in other.break_point]
		return RaggedTensor(self.data + other.data, self.break_point + [offset] + shifted)

	def __str__(self):
		return str(self.data)

	def __iter__(self):
		"""Iterate over the flattened elements."""
		return self.flatten().data.__iter__()

	def __getitem__(self, index: int) -> any:
		return self.data[index]
	
class Config:
	"""Settings and input data for one abstract-filtering run.

	Loads the KM/SKiM output TSV and the JSON job config, derives the output
	file names next to the TSV, and validates that the TSV has the columns the
	selected job type needs.
	"""

	def __init__(self, args: argparse.Namespace):
		"""Load the TSV and job config.

		Args:
			args: Any object exposing ``km_output`` (path to the input TSV) and
				``config`` (path to the JSON job config), e.g. an
				``argparse.Namespace``.

		Raises:
			AssertionError: If ``MAX_ABSTRACTS`` is not positive or a required
				TSV column is missing.
		"""
		self.data = read_tsv_to_dataframe(args.km_output)
		with open(args.config, 'r') as config_file:
			self.job_config = json.load(config_file)

		self.global_settings = self.job_config["GLOBAL_SETTINGS"]
		self.k = self.global_settings["MAX_ABSTRACTS"]
		assert self.k > 0
		self.km_output_dir = os.path.dirname(args.km_output)
		self.km_output_base_name = os.path.splitext(os.path.basename(args.km_output))[0]

		# Ensure the output directory exists ('' means the current directory).
		# exist_ok avoids the race between an exists() check and makedirs().
		if self.km_output_dir:
			os.makedirs(self.km_output_dir, exist_ok=True)

		self.filtered_tsv_name = os.path.join(self.km_output_dir, f"filtered_{self.km_output_base_name}.tsv")
		self.cot_tsv_name = os.path.join(self.km_output_dir, f"cot_{self.km_output_base_name}.tsv")
		self.job_type = self.job_config.get('JOB_TYPE')
		self.filter_config = self.job_config["abstract_filter"]

		self.sys_prompt = self.filter_config['SYS_PROMPT']
		# Case-insensitive, matching how getHypothesis interprets JOB_TYPE.
		self.is_skim_gpt = (self.job_type or "").lower() == "skim_with_gpt"
		self.has_ac = "ac_pmid_intersection" in self.data.columns
		self.continuous = self.filter_config["CONTINUOUS_SCORE"]

		# Score format the model must emit: "0.ddddd" (continuous) or a single 0/1.
		self.regex = r'[0][.]\d{5}' if self.continuous else r'0|1'
		self.max_score_tokens = 7 if self.continuous else 1
		self.max_cot_tokens = self.filter_config["MAX_COT_TOKENS"]

		print(f"Job type detected. Running {self.job_type}.")
		if self.is_skim_gpt:
			assert "c_term" in self.data.columns, "Input TSV must have c_term if running skim_with_gpt"
			assert "bc_pmid_intersection" in self.data.columns, "Input TSV must have an bc_pmid_intersection."

		assert "ab_pmid_intersection" in self.data.columns, "Input TSV must have an ab_pmid_intersection."
		assert "a_term" in self.data.columns, "Input TSV must have an a_term."
		assert "b_term" in self.data.columns, "Input TSV must have a b_term"
		
# Returns an AB, BC or AC hypothesis depending on which terms are passed in. If A and B
# are passed in, getHypothesis retrieves the AB hypothesis.
# Exactly two arguments should be specified at once.
def getHypothesis(config, a_term: str = None, b_term: str = None, c_term: str = None) -> str:
	"""Format the hypothesis template for the configured job type.

	Args:
		config: Parsed job config. ``JOB_TYPE`` (compared case-insensitively)
			selects the template: ``KM_hypothesis`` (km_with_gpt),
			``POSITION_KM_hypothesis`` (position_km_with_gpt) or
			``SKIM_hypotheses["AB" | "BC" | "AC"]`` (skim_with_gpt).
		a_term, b_term, c_term: The two terms substituted into the template.

	Returns:
		The formatted hypothesis; a missing template formats as "". For
		position_km_with_gpt a ``(hypothesis, None)`` tuple is returned. For an
		unknown job type the message "No valid hypothesis for the provided
		JOB_TYPE." is returned.

	Raises:
		AssertionError: If the term combination is not valid for the job type.
	"""
	job_type = (config.get("JOB_TYPE") or "").lower()

	if job_type == "km_with_gpt":
		assert a_term and b_term and not c_term
		return config.get("KM_hypothesis", "").format(a_term=a_term, b_term=b_term)

	if job_type == "position_km_with_gpt":
		assert a_term and b_term and not c_term
		# NOTE(review): unlike the other branches this returns a tuple; kept as-is
		# in case callers unpack it.
		return config.get("POSITION_KM_hypothesis", "").format(a_term=a_term, b_term=b_term), None

	if job_type == "skim_with_gpt":
		assert (a_term and b_term and not c_term) or (b_term and c_term and not a_term) or (a_term and c_term and not b_term)
		# Default to an empty mapping/template, like the KM branches, instead of
		# calling .get on "" (AttributeError) when the key is missing.
		skim_templates = config.get("SKIM_hypotheses") or {}

		if a_term and b_term and not c_term:
			return skim_templates.get("AB", "").format(a_term=a_term, b_term=b_term)
		if b_term and c_term and not a_term:
			return skim_templates.get("BC", "").format(b_term=b_term, c_term=c_term)
		if a_term and c_term and not b_term:
			return skim_templates.get("AC", "").format(a_term=a_term, c_term=c_term)
		# Only reachable when assertions are disabled (python -O).
		return None

	return "No valid hypothesis for the provided JOB_TYPE."
	
	

def cot_prompt(sys_prompt: str, hyp: str, abstract: str) -> str:
	"""Build the ChatML prompt that asks the model to reason about an abstract.

	The user turn asks whether ``abstract`` is relevant for scientifically
	evaluating ``hyp`` and requests step-by-step reasoning. The prompt ends with
	an open assistant turn for the model to complete.

	Args:
		sys_prompt: Text of the system turn.
		hyp: Hypothesis text.
		abstract: Abstract text.

	Returns:
		The rendered prompt string.
	"""
	context = {
		"sys_prompt": sys_prompt,
		"hyp": hyp,
		"abstract": abstract,
	}

	# Built from explicit lines rather than an indented triple-quoted string so
	# the rendered prompt carries no source indentation before the ChatML
	# markers and no trailing whitespace after the final "assistant" marker.
	template = jinja2.Template(
		"<|im_start|>system\n"
		"{{sys_prompt}}\n"
		"<|im_end|>\n"
		"<|im_start|>user\n"
		"Hypothesis: {{hyp}}\n"
		"Abstract: {{abstract}}\n"
		"\n"
		"Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. "
		"A relevant abstract must directly comment on the hypothesis and either support the given hypothesis "
		"or have evidence to refute the hypothesis.\n"
		"\n"
		"Analyze the abstract above, and thoroughly describe your thought process for evaluating the hypothesis. "
		"Pay attention to particular details in the abstract as it relates to the hypothesis. "
		"Make sure to stay focused on what the hypothesis is specifically saying. "
		"Ignore redacted terms and make sure to look at the terms provided. "
		"Let's work this out in a step by step way to be sure we have the right answer. "
		"As a first step, use context clues to figure out the meaning of the terms given.\n"
		"<|im_end|>\n"
		"<|im_start|>assistant\n",
		# Jinja drops a single trailing newline by default; keep it so the
		# assistant turn starts on a fresh line.
		keep_trailing_newline=True,
	)

	return template.render(context)

def answer_prompt(sys_prompt: str, hyp: str, abstract: str, chain_of_thought: str, continuous: bool) -> str:
	"""Build the ChatML prompt that asks the model for a final relevance answer.

	Repeats the reasoning request, appends ``chain_of_thought``, then asks for
	either a score between 0 and 1 (``continuous=True``) or a 0/1 label. The
	prompt ends with an open assistant turn.

	Args:
		sys_prompt: Text of the system turn.
		hyp: Hypothesis text.
		abstract: Abstract text.
		chain_of_thought: Reasoning text inserted before the classification request.
		continuous: Selects the continuous-score wording instead of the binary one.

	Returns:
		The rendered prompt string.
	"""
	context = {
		"sys_prompt": sys_prompt,
		"hyp": hyp,
		"abstract": abstract,
		"chain_of_thought": chain_of_thought,
		"continuous": continuous
	}

	# Built from explicit lines rather than an indented triple-quoted string so
	# the rendered prompt carries no source indentation before the ChatML
	# markers and no trailing whitespace. The if/else is kept on one logical
	# line so the block tags do not leave stray blank lines.
	template = jinja2.Template(
		"<|im_start|>system\n"
		"{{sys_prompt}}\n"
		"<|im_end|>\n"
		"<|im_start|>user\n"
		"Hypothesis: {{hyp}}\n"
		"Abstract: {{abstract}}\n"
		"\n"
		"Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. "
		"A relevant abstract must directly comment on the hypothesis and either support the given hypothesis "
		"or have evidence to refute the hypothesis.\n"
		"\n"
		"Analyze the abstract above, and thoroughly describe your thought process for evaluating the hypothesis. "
		"Pay attention to particular details in the abstract as it relates to the hypothesis. "
		"Make sure to stay focused on what the hypothesis is specifically saying. "
		"Ignore redacted terms and make sure to look at the terms provided. "
		"Let's work this out in a step by step way to be sure we have the right answer. "
		"As a first step, use context clues to figure out the meaning of the terms given.\n"
		"{{chain_of_thought}}\n"
		"\n"
		"{% if continuous %}"
		"Classify the given abstract with a score between 0 (Not relevant for scientifically assessing the hypothesis) "
		"and 1 (Relevant for scientifically assessing the hypothesis) based on the reasoning above and other useful "
		"pieces of information in the abstract and hypothesis."
		"{% else %}"
		"Classify the given abstract as either 0 (Not relevant for scientifically assessing the hypothesis) "
		"or 1 (Relevant for scientifically assessing the hypothesis) based on the reasoning above and other useful "
		"pieces of information in the abstract and hypothesis."
		"{% endif %}\n"
		"Answer:\n"
		"<|im_end|>\n"
		"<|im_start|>assistant\n",
		# Jinja drops a single trailing newline by default; keep it so the
		# assistant turn starts on a fresh line.
		keep_trailing_newline=True,
	)

	return template.render(context)
	
# def gen(prompts: RaggedTensor, model: any, sampling_config: vllm.SamplingParams) -> RaggedTensor:
# 	generated = model.generate(prompts.data, sampling_params = sampling_config)
# 	outputs = RaggedTensor([output.outputs[0].text for output in generated], prompts.break_point)
# 	return outputs

def getCoTPrompts(abstracts: RaggedTensor, sys_prompt: str, hypotheses: RaggedTensor) -> RaggedTensor:
	"""Build one chain-of-thought prompt per (abstract, hypothesis) pair.

	Args:
		abstracts: Flat tensor of abstract texts.
		sys_prompt: System prompt shared by every prompt.
		hypotheses: Flat tensor of hypotheses aligned element-wise with ``abstracts``.

	Returns:
		A flat tensor of prompts carrying ``hypotheses.break_point``.

	Raises:
		AssertionError: If either input is 2D or their lengths differ.
	"""
	assert not abstracts.is2D(), "abstracts should be flattened."
	assert not hypotheses.is2D(), "hypotheses should be flattened."
	# Without this check a mismatch silently truncates (more hypotheses than
	# abstracts) or raises IndexError (fewer).
	assert abstracts.shape == hypotheses.shape, "abstracts and hypotheses must have the same length."
	prompts = [cot_prompt(sys_prompt, hyp, abstract) for abstract, hyp in zip(abstracts.data, hypotheses.data)]
	return RaggedTensor(prompts, hypotheses.break_point)

def getAnswerPrompts(abstracts: RaggedTensor, sys_prompt: str, hypotheses: RaggedTensor, cot_outputs: RaggedTensor, continuous: bool) -> RaggedTensor:
	"""Build one answer prompt per (abstract, hypothesis, chain-of-thought) triple.

	Args:
		abstracts: Flat tensor of abstract texts.
		sys_prompt: System prompt shared by every prompt.
		hypotheses: Flat tensor of hypotheses aligned with ``abstracts``.
		cot_outputs: Flat tensor of reasoning texts aligned with ``abstracts``.
		continuous: Passed to ``answer_prompt`` to choose score vs. 0/1 wording.

	Returns:
		A flat tensor of prompts carrying ``hypotheses.break_point``.

	Raises:
		AssertionError: If any input is 2D or the three lengths differ.
	"""
	assert not abstracts.is2D(), "abstracts should be flattened."
	assert not hypotheses.is2D(), "hypotheses should be flattened."
	assert not cot_outputs.is2D(), "cot outputs should be flattened"
	# Guard against silent truncation / IndexError on misaligned inputs.
	assert abstracts.shape == hypotheses.shape == cot_outputs.shape, "abstracts, hypotheses and cot outputs must have the same length."
	prompts = [
		answer_prompt(sys_prompt, hyp, abstract, reasoning, continuous)
		for abstract, hyp, reasoning in zip(abstracts.data, hypotheses.data, cot_outputs.data)
	]
	return RaggedTensor(prompts, hypotheses.break_point)
	
# Returns a dictionary mapping each PMID to its "PMID <id>: <abstract>" text.
# This method is needed since Entrez automatically removes duplicates in the pmid list,
# so callers look every (possibly repeated) PMID back up in the returned map.
def getAbstractMap(config: dict, pmids: list[str]) -> dict:
	"""Fetch abstracts for ``pmids`` from PubMed via Entrez.

	Args:
		config: Parsed job config. Reads ``PUBMED_API_KEY``, an optional
			``PUBMED_EMAIL``, and ``GLOBAL_SETTINGS`` (``PUBMED_PARAMS`` with
			``db``/``rettype``, ``MAX_RETRIES``, ``RETRY_DELAY``).
		pmids: Iterable of PMIDs (ints or strings).

	Returns:
		Dict from PMID string to ``"PMID <id>: <abstract text>"``. Articles
		without an abstract map to ``"PMID <id>: "``; an empty ``pmids``
		returns ``{}`` without contacting PubMed.
	"""
	global_config = config["GLOBAL_SETTINGS"]
	pmid_config = global_config["PUBMED_PARAMS"]

	pmid_list = [str(pmid) for pmid in pmids]
	if not pmid_list:
		return {}

	# NCBI requires a contact email; allow overriding the previously hard-coded one.
	Entrez.email = config.get("PUBMED_EMAIL", 'leoxu27@gmail.com')
	Entrez.api_key = config["PUBMED_API_KEY"]
	Entrez.max_tries = global_config["MAX_RETRIES"]
	Entrez.sleep_between_tries = global_config["RETRY_DELAY"]

	efetch = Entrez.efetch(db=pmid_config["db"], id=pmid_list, rettype=pmid_config["rettype"])
	try:
		output = Entrez.read(efetch)
	finally:
		# Close the handle even if parsing fails.
		efetch.close()

	abstract_map = {}
	for paper in output.get("PubmedArticle", []):
		citation = paper["MedlineCitation"]
		pmid = str(citation["PMID"])
		# Some PubMed records (letters, editorials, ...) have no Abstract element.
		abstract_parts = citation["Article"].get("Abstract", {}).get("AbstractText", [])
		abstract_map[pmid] = f'PMID {pmid}: {" ".join(abstract_parts)}'
	return abstract_map

def postProcess(abstracts: RaggedTensor, cot: RaggedTensor, hypotheses: RaggedTensor, cot_df: pd.DataFrame, filtered_df: pd.DataFrame, terms: str, shape: list):
	"""Store one segment (AB, BC or AC) of results in the output DataFrames.

	Args:
		abstracts: Flat tensor of abstract texts for this segment.
		cot: Flat tensor of chain-of-thought texts aligned with ``abstracts``.
		hypotheses: One hypothesis per row (for "ac", a single hypothesis).
		cot_df: Receives ``<terms>_cot`` and ``<terms>_hypothesis`` (in place).
		filtered_df: Receives ``<terms>_pmid_intersection`` (in place).
		terms: Column prefix, e.g. "ab", "bc" or "ac".
		shape: Number of entries per group, used to regroup the flat tensors.

	``abstracts`` and ``cot`` themselves are left unchanged, so the call can be
	repeated (e.g. when a notebook cell is re-run).
	"""
	# Regroup copies: RaggedTensor.reshape works in place and rejects data that
	# is already 2D, so reshaping the caller's tensors made a repeat call fail.
	abstract_rows = RaggedTensor(list(abstracts.data))
	abstract_rows.reshape(shape)
	cot_rows = RaggedTensor(list(cot.data))
	cot_rows.reshape(shape)

	# This is needed because there will only be one AC abstract list per TSV:
	# its single group is repeated for every row.
	if terms == "ac":
		n_rows = len(filtered_df)
		# list(group) gives each row its own list rather than n references to one object.
		filtered_df[f"{terms}_pmid_intersection"] = [list(group) for group in abstract_rows.data * n_rows]
		cot_df[f"{terms}_cot"] = [list(group) for group in cot_rows.data * n_rows]
		cot_df[f"{terms}_hypothesis"] = hypotheses.data * n_rows
	else:
		filtered_df[f"{terms}_pmid_intersection"] = abstract_rows.data
		cot_df[f"{terms}_cot"] = cot_rows.data
		cot_df[f"{terms}_hypothesis"] = hypotheses.data


In [87]:
# Stand-in for the parsed command-line arguments that Config reads (km_output, config).
args = argparse.Namespace(
    km_output = "../test_tsvs/skim_no_ac/skim_no_ac.tsv",
    config = "../config.json",
)

In [98]:
# Load the TSV + JSON config and make independent copies for the two output tables.
config = Config(args)
cot_df = config.data.copy(deep = True)
filtered_df = config.data.copy(deep = True)

# First unique A term; if it is an "&"-joined list, keep its first component.
a_term = config.data.a_term.unique().tolist()[0].split("&")[0]
# One B term per ROW (not per unique value) so hypotheses built from it line up with
# the per-row PMID lists (expand() asserts this) and with the rows of cot_df.
b_terms = config.data.b_term.tolist()

Job type detected. Running skim_with_gpt.


In [99]:
# Per-row AB PMID lists, parsed from their string form in the TSV.
# NOTE(review): eval() executes arbitrary code from the TSV; ast.literal_eval would be safer for list literals.
ab_pmids = RaggedTensor([eval(pmid_list) for pmid_list in config.data.ab_pmid_intersection])
# One AB hypothesis per B term.
ab_hypotheses = RaggedTensor([getHypothesis(config.job_config, a_term = a_term, b_term = term) for term in b_terms])

# Repeat each row's hypothesis once per PMID so the flat hypotheses and PMIDs stay aligned.
all_hypotheses = ab_hypotheses.expand(ab_pmids.shape)
all_pmids = ab_pmids.flatten()

In [100]:
if config.is_skim_gpt:
	c_term = config.data.c_term.unique().tolist()[0]
	# NOTE(review): eval() executes arbitrary code from the TSV; ast.literal_eval would be safer for list literals.
	bc_pmids = RaggedTensor([eval(lst) for lst in config.data.bc_pmid_intersection])
	bc_hypotheses = RaggedTensor([getHypothesis(config.job_config, c_term = c_term, b_term = b_term) for b_term in b_terms])

	all_pmids += bc_pmids.flatten()
	all_hypotheses += bc_hypotheses.expand(bc_pmids.shape)

	if config.has_ac:
		# For each atomic run there should only be one unique ac_pmid intersection;
		# .iloc takes the first row regardless of the frame's index labels.
		ac_pmids = RaggedTensor(eval(config.data.ac_pmid_intersection.iloc[0]))
		ac_hypothesis = RaggedTensor([getHypothesis(config.job_config, a_term = a_term, c_term = c_term)])
		# The AC PMIDs must be appended as well; otherwise all_hypotheses is longer than
		# all_pmids and the AC segment never receives abstracts.
		all_pmids += ac_pmids
		all_hypotheses += ac_hypothesis.expand([ac_pmids.shape])

In [93]:
# Every PMID must be paired with exactly one hypothesis before the prompts are built.
assert all_hypotheses.shape == all_pmids.shape, "all_hypotheses and all_pmids are misaligned."
all_hypotheses.shape

50

In [94]:
# Fetch all abstracts in one Entrez request, then look every PMID back up so repeated
# PMIDs keep their positions; PMIDs PubMed did not return map to "".
abstract_map = getAbstractMap(config.job_config, all_pmids)
abstracts = all_pmids.map(str).map(abstract_map.get, "")

In [95]:
# Flat offsets at which each appended segment starts (recorded by RaggedTensor.__add__); split() cuts here.
abstracts.break_point

[20, 40]

In [63]:
# One chain-of-thought prompt per (abstract, hypothesis) pair; break points come from all_hypotheses.
cot_prompts = getCoTPrompts(
	abstracts=abstracts,
	sys_prompt=config.sys_prompt,
	hypotheses=all_hypotheses,
)

In [64]:
# Sanity check: should equal abstracts.break_point so both split into the same segments.
cot_prompts.break_point

[20, 40]

In [65]:
# Compare with cot_prompts.break_point from the previous cell.
abstracts.break_point

[20, 40]

In [66]:
# Three independent empty tensors for padding missing BC/AC segments.
# (`3 * [RaggedTensor([])]` would repeat ONE object three times.)
defaults = [RaggedTensor([]) for _ in range(3)]

In [67]:
# Cut the flat tensors back into AB / BC / AC segments. Missing segments are padded with
# fresh empty tensors per unpacking, so e.g. ac_abstracts and ac_cot are never the same object.
ab_abstracts, bc_abstracts, ac_abstracts, *_ = chain(abstracts.split(), [RaggedTensor([]) for _ in range(3)])
# TODO(review): cot_prompts holds the CoT *prompts*; the generation step (gen) is commented out.
ab_cot, bc_cot, ac_cot, *_ = chain(cot_prompts.split(), [RaggedTensor([]) for _ in range(3)])

In [68]:
# Write the AB segment into filtered_df / cot_df.
# NOTE(review): the next cell repeats this exact call; if postProcess reshapes its inputs in place
# (RaggedTensor.reshape refuses already-2D data), running both cells in order fails. Keep only one.
postProcess(ab_abstracts, ab_cot, ab_hypotheses, cot_df, filtered_df, terms = "ab", shape = ab_pmids.shape)

In [None]:
# Write every segment into filtered_df / cot_df. postProcess takes no config argument,
# and the BC segment must be paired with the BC hypotheses.
postProcess(ab_abstracts, ab_cot, ab_hypotheses, cot_df, filtered_df, terms = "ab", shape = ab_pmids.shape)
if config.is_skim_gpt:
	postProcess(bc_abstracts, bc_cot, bc_hypotheses, cot_df, filtered_df, terms = "bc", shape = bc_pmids.shape)
	if config.has_ac:
		postProcess(ac_abstracts, ac_cot, ac_hypothesis, cot_df, filtered_df, terms = "ac", shape = [ac_pmids.shape])