In [328]:
from __future__ import annotations

import argparse
import ast
import json

import numpy as np
import pandas as pd
from Bio import Entrez
from tqdm import tqdm
from transformers import set_seed
# import vllm
# from lmformatenforcer import RegexParser
# from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data

from abstract_comprehension import read_tsv_to_dataframe

class RaggedTensor:
    """List wrapper for ragged (variable-length row) data.

    ``data`` is either a flat list (1D) or a list of lists (2D). ``shape`` is
    kept in sync with ``data``: an ``int`` length for 1D data, or a list of
    per-row lengths for 2D data. ``break_point`` optionally records the index
    at which two concatenated 1D tensors meet so that ``split`` can separate
    them again.
    """

    def __init__(self, data, break_point = None):
        self.data = data
        self.break_point = break_point
        self.index = 0  # not read by any method of this class; kept as an attribute
        self.getShape()

    def getShape(self) -> None:
        """Recompute ``self.shape`` from the current ``self.data``."""
        if self.is2D():
            self.shape = [len(row) for row in self.data]
        else:
            self.shape = len(self.data)

    def is2D(self) -> bool:
        """Return True when the first element of ``data`` is a list."""
        # Empty data counts as 1D instead of raising IndexError on data[0].
        return len(self.data) > 0 and isinstance(self.data[0], list)

    def expand(self, shape_list: list) -> "RaggedTensor":
        """Return a new tensor with element ``i`` repeated ``shape_list[i]`` times."""
        assert not self.is2D(), "Data must be 1D before calling expand. Call flatten first?"
        assert self.shape == len(shape_list), "The length of shape list must equal the length of data"

        expanded = []
        for item, count in zip(self.data, shape_list):
            expanded.extend([item] * count)

        return RaggedTensor(expanded, self.break_point)

    def flatten(self) -> "RaggedTensor":
        """Return a 1D tensor; 2D data is concatenated row by row.

        A 1D tensor is returned as-is (same object). Flattening 2D data yields
        a new tensor without a break point.
        """
        if self.is2D():
            output = []
            for row in self.data:
                output.extend(row)
            return RaggedTensor(output)
        else:
            return self

    def compress(self, shape_list: list) -> None:
        """Invert ``expand`` in place.

        Keeps the first element of every run described by ``shape_list``, so
        the original order is preserved and equal values that came from
        different source elements are not merged. Runs of length 0 left no
        trace in the expanded data and are therefore skipped.
        """
        assert self.shape == sum(shape_list), "The length of data must equal the sum of shape list"
        compressed = []
        offset = 0
        for count in shape_list:
            if count > 0:
                compressed.append(self.data[offset])
            offset += count
        self.data = compressed
        self.getShape()

    def split(self) -> "tuple[list, list]":
        """Return ``(data[:break_point], data[break_point:])``."""
        # Compare against None so that a break point of 0 (empty left half) is accepted.
        assert self.break_point is not None, "You need to set the breakpoint."
        return self.data[:self.break_point], self.data[self.break_point:]

    def reshape(self, shape: list) -> None:
        """Regroup 1D data in place into consecutive rows of the given lengths."""
        assert not self.is2D(), "Reshape only works with 1D tensors."
        assert self.shape == sum(shape), "The shape of the tensor should be equal to the sum of the wanted shape."
        output = []
        running_length = 0
        for length in shape:
            output.append(self.data[running_length: running_length + length])
            running_length += length

        self.data = output
        self.getShape()

    def applyFilter(self, mask: "RaggedTensor") -> None:
        """Keep, in place, only the elements whose mask entry equals 1.

        ``mask`` must have the same shape as this tensor. For 2D data each row
        is filtered by the matching mask row, so rows may shrink (even to
        empty) but the number of rows is unchanged.
        """
        assert self.shape == mask.shape, "Filtering only works when the shapes are the same"
        if self.is2D():
            self.data = [
                [item for item, keep in zip(row, mask_row) if keep == 1]
                for row, mask_row in zip(self.data, mask.data)
            ]
        else:
            self.data = [item for item, keep in zip(self.data, mask.data) if keep == 1]
        # The element counts changed, so the cached shape has to be refreshed.
        self.getShape()

    def map(self, func: callable, *args) -> "RaggedTensor":
        """Return a new 1D tensor holding ``func(element, *args)`` for each element."""
        assert not self.is2D(), "Map only works with 1D tensors"
        return RaggedTensor([func(item, *args) for item in self.data], self.break_point)

    def __add__(self, other: "RaggedTensor") -> "RaggedTensor":
        """Concatenate two tensors; the break point is the length of the left operand."""
        assert not self.is2D(), "Adding only works with flattened tensors"
        break_point = self.shape
        return RaggedTensor(self.data + other.data, break_point)

    def __str__(self):
        return str(self.data)

    def __iter__(self):
        """Iterate over the flattened elements."""
        return iter(self.flatten().data)

    def __len__(self) -> int:
        """Number of flattened elements (consistent with ``__iter__``)."""
        return len(self.flatten().data)

    def __getitem__(self, index):
        """Index (or slice) into the flattened elements (consistent with ``__iter__``)."""
        return self.flatten().data[index]
        

# Returns either the AB or the BC hypothesis depending on the input. If a_term and b_term are passed in, getHypothesis will retrieve the AB hypothesis.
# Only two arguments should be specified at once.
def getHypothesis(config, a_term: str = None, b_term: str = None, c_term: str = None) -> str:
    """Build the hypothesis text for the job type named in ``config["JOB_TYPE"]``.

    Args:
        config: Mapping holding ``JOB_TYPE`` and the hypothesis templates
            (``KM_hypothesis``, ``POSITION_KM_hypothesis`` or
            ``SKIM_hypotheses`` with ``AB`` / ``BC`` entries).
        a_term, b_term, c_term: Terms substituted into the template. Exactly
            two are expected: A+B everywhere, or B+C for ``skim_with_gpt``.

    Returns:
        The formatted hypothesis string. For ``position_km_with_gpt`` the
        return value is the tuple ``(hypothesis, None)``. For an unknown job
        type a fixed explanatory string is returned. A missing template is
        treated as the empty string.

    Raises:
        AssertionError: If the combination of terms does not fit the job type.
    """
    job_type = config.get("JOB_TYPE", "").lower()

    if job_type == "km_with_gpt":
        assert a_term and b_term and not c_term
        hypothesis_template = config.get("KM_hypothesis", "")

        return hypothesis_template.format(a_term=a_term, b_term=b_term)

    elif job_type == "position_km_with_gpt":
        assert a_term and b_term and not c_term

        hypothesis_template = config.get("POSITION_KM_hypothesis", "")
        # This branch returns a 2-tuple, unlike the others; kept for existing callers.
        return hypothesis_template.format(a_term=a_term, b_term=b_term), None

    elif job_type == "skim_with_gpt":
        assert (a_term and b_term and not c_term) or (b_term and c_term and not a_term)

        # The fallback has to be a dict: the previous "" default had no .get(),
        # so a missing SKIM_hypotheses key raised AttributeError.
        skim_templates = config.get("SKIM_hypotheses", {})

        if a_term and b_term and not c_term:
            hypothesis_template = skim_templates.get("AB", "")
            return hypothesis_template.format(a_term=a_term, b_term=b_term)

        elif b_term and c_term and not a_term:
            hypothesis_template = skim_templates.get("BC", "")
            return hypothesis_template.format(b_term=b_term, c_term=c_term)

    else:
        return "No valid hypothesis for the provided JOB_TYPE."
    
    

def cot_prompt(sys_prompt: str, hyp: str, abstract: str) -> str:
    """Build the chain-of-thought prompt for one hypothesis/abstract pair.

    The prompt is a system turn containing ``sys_prompt``, a user turn with
    the hypothesis, the abstract and the reasoning instructions, and an open
    assistant turn. Every line carries a four-space indent, and the text both
    starts with a newline and ends with an indented empty line.
    """
    prompt_lines = [
        "",
        "    <|im_start|>system",
        f"    {sys_prompt}",
        "    <|im_end|>",
        "    <|im_start|>user",
        f"    Hypothesis: {hyp}",
        f"    Abstract: {abstract}",
        "    ",
        "    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.",
        "",
        "    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.",
        "    <|im_end|>",
        "    <|im_start|>assistant",
        "    ",
    ]
    return "\n".join(prompt_lines)

def answer_prompt(sys_prompt: str, hypothesis: str, abstract: str, chain_of_thought: str) -> str:
    """Build the classification prompt that follows a chain-of-thought pass.

    The user turn repeats the hypothesis, the abstract and the reasoning
    instructions, appends ``chain_of_thought``, and then asks for a 0/1
    relevance label. Every line carries a four-space indent, and the text
    both starts with a newline and ends with an indented empty line.
    """
    prompt_lines = [
        "",
        "    <|im_start|>system",
        f"    {sys_prompt}",
        "    <|im_end|>",
        "    <|im_start|>user",
        f"    Hypothesis: {hypothesis}",
        f"    Abstract: {abstract}",
        "    ",
        "    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.",
        "",
        "    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.",
        f"    {chain_of_thought}",
        "    ",
        "    Classify the given abstract as either 0 (Not relevant) or 1 (Relevant) based on your reasoning above and any information in the abstract and hypothesis.",
        "    Answer: ",
        "    <|im_end|>",
        "    <|im_start|>assistant",
        "    ",
    ]
    return "\n".join(prompt_lines)

# def gen(prompts: list[str], model: any, sampling_config: vllm.SamplingParams) -> list[str]:
# 	generated = model.generate(prompts, sampling_params = sampling_config)
# 	outputs = [output.outputs[0].text for output in generated]
# 	return outputs

def getCoTPrompts(abstracts: RaggedTensor, sys_prompt: str, hypotheses: RaggedTensor) -> list[str]:
    """Build one chain-of-thought prompt per (hypothesis, abstract) pair.

    Args:
        abstracts: Iterable of abstract texts (a list or a RaggedTensor).
        sys_prompt: System prompt placed in every generated prompt.
        hypotheses: Iterable of hypotheses aligned one-to-one with ``abstracts``.

    Returns:
        A plain list of prompt strings, in input order.

    Raises:
        ValueError: If the two inputs do not contain the same number of items.
    """
    # Materialise both inputs: only iteration is required of them, so objects
    # without len() or integer indexing are accepted as well as lists.
    abstract_list = list(abstracts)
    hypothesis_list = list(hypotheses)
    if len(abstract_list) != len(hypothesis_list):
        raise ValueError(
            f"Got {len(abstract_list)} abstracts but {len(hypothesis_list)} hypotheses; they must be aligned."
        )
    return [
        cot_prompt(sys_prompt, hypothesis, abstract)
        for hypothesis, abstract in zip(hypothesis_list, abstract_list)
    ]

def getAnswerPrompts(abstracts: RaggedTensor, sys_prompt: str, hypotheses: RaggedTensor, cot_outputs: RaggedTensor) -> list[str]:
    """Build one answer prompt per (hypothesis, abstract, chain-of-thought) triple.

    Args:
        abstracts: Iterable of abstract texts (a list or a RaggedTensor).
        sys_prompt: System prompt placed in every generated prompt.
        hypotheses: Iterable of hypotheses aligned one-to-one with ``abstracts``.
        cot_outputs: Iterable of chain-of-thought texts aligned with ``abstracts``.

    Returns:
        A plain list of prompt strings, in input order.

    Raises:
        ValueError: If the three inputs do not contain the same number of items.
    """
    # Materialise the inputs: only iteration is required of them, so objects
    # without len() or integer indexing are accepted as well as lists.
    abstract_list = list(abstracts)
    hypothesis_list = list(hypotheses)
    cot_list = list(cot_outputs)
    if not (len(abstract_list) == len(hypothesis_list) == len(cot_list)):
        raise ValueError(
            f"Got {len(abstract_list)} abstracts, {len(hypothesis_list)} hypotheses and "
            f"{len(cot_list)} chain-of-thought outputs; they must be aligned."
        )
    return [
        answer_prompt(sys_prompt, hypothesis, abstract, chain_of_thought)
        for hypothesis, abstract, chain_of_thought in zip(hypothesis_list, abstract_list, cot_list)
    ]

    
# Returns a dictionary for each PMID & Abstract Pair
# This method is needed since Entrez automatically removes duplicates in the pmid list
def getAbstractMap(config: dict, pmids: list[str]) -> dict:
    """Fetch abstracts via Entrez and map each returned PMID to its abstract text.

    Args:
        config: Parsed configuration. Reads ``GLOBAL_SETTINGS`` (``MAX_RETRIES``,
            ``RETRY_DELAY``) and ``GLOBAL_SETTINGS["PUBMED_PARAMS"]``
            (``api_key``, ``db``, ``rettype``).
        pmids: PubMed ids to fetch; any iterable is accepted and each id is
            converted with ``str``.

    Returns:
        ``{pmid: abstract_text}`` for every article in the response. Articles
        without an abstract map to the empty string. An empty ``pmids`` input
        yields an empty dict without contacting the service.
    """
    global_config = config["GLOBAL_SETTINGS"]
    pmid_config = global_config["PUBMED_PARAMS"]

    # NOTE(review): the contact address is hard-coded; consider moving it into the config.
    Entrez.email = 'leoxu27@gmail.com'
    Entrez.api_key = pmid_config["api_key"]
    Entrez.max_tries = global_config["MAX_RETRIES"]
    Entrez.sleep_between_tries = global_config["RETRY_DELAY"]

    # Normalise to a real list of strings so integer ids and non-list iterables work.
    pmid_list = [str(pmid) for pmid in pmids]
    if not pmid_list:
        return {}

    efetch = Entrez.efetch(db=pmid_config["db"], id=pmid_list, rettype=pmid_config["rettype"])
    try:
        output = Entrez.read(efetch)
    finally:
        # Release the handle even when parsing the response fails.
        efetch.close()

    abstract_map = {}
    for paper in output["PubmedArticle"]:
        citation = paper["MedlineCitation"]
        # Not every record carries an abstract; fall back to empty text instead of raising KeyError.
        abstract_parts = citation.get("Article", {}).get("Abstract", {}).get("AbstractText", [])
        abstract_map[str(citation["PMID"])] = " ".join(abstract_parts)
    return abstract_map

In [293]:
class args:
    """Stand-in for parsed command-line arguments, holding the notebook's file paths."""

    # Input TSV that is read into ``km_output``.
    km_output = "./data.tsv"
    # JSON configuration file.
    config = "../config.json"
    # File name for the filtered results TSV.
    filtered_tsv_name = "filtered.tsv"
    # File name for the chain-of-thought TSV.
    cot_tsv_name = "cot.tsv"

In [294]:
###################### Data Loading & Processing ############################
# Read the input table and the JSON config from the paths held on `args`.
km_output = read_tsv_to_dataframe(args.km_output)

with open(args.config) as f:
    config = json.load(f)

# Output file names, copied from `args`.
filtered_tsv_name, cot_tsv_name = args.filtered_tsv_name, args.cot_tsv_name

In [325]:
# Settings for the abstract filter, and the system prompt used in every LLM prompt.
filter_config = config["abstract_filter"]
sys_prompt = filter_config["SYS_PROMPT"]

In [317]:
# A term: the part of the first unique a_term value that precedes any "&".
a_term = km_output.a_term.unique().tolist()[0].split("&")[0]
# Every distinct B term in the table.
b_terms = km_output.b_term.unique().tolist()
# NOTE(review): c_term is read from the *a_term* column (the full, unsplit first value) — confirm this is intended.
c_term = km_output.a_term.unique().tolist()[0]

In [322]:
# Each cell of the *_pmid_intersection columns is parsed from text into a Python
# value. ast.literal_eval only accepts literals, so unlike eval it cannot execute
# code that happens to be embedded in the TSV.
ab_pmids = RaggedTensor([ast.literal_eval(lst) for lst in km_output.ab_pmid_intersection])
bc_pmids = RaggedTensor([ast.literal_eval(lst) for lst in km_output.bc_pmid_intersection])

# Flatten both groups and concatenate them; the break point of the result marks
# where the AB ids end and the BC ids begin.
all_pmids = ab_pmids.flatten() + bc_pmids.flatten()

In [323]:
# getAbstractMap is annotated to take list[str], so hand it a real list of
# string ids rather than the RaggedTensor itself.
abstract_map = getAbstractMap(config, [str(pmid) for pmid in all_pmids])
# Look every PMID up again (duplicates included); ids missing from the map get "".
abstracts = all_pmids.map(lambda pmid: abstract_map.get(str(pmid), ""))

In [324]:
# One AB and one BC hypothesis per B term.
ab_hypotheses = RaggedTensor(
    [getHypothesis(config, a_term=a_term, b_term=term) for term in b_terms]
)
bc_hypotheses = RaggedTensor(
    [getHypothesis(config, c_term=c_term, b_term=term) for term in b_terms]
)

# Repeat each hypothesis once per PMID in its row, then join the AB and BC parts.
expanded_ab_hypotheses = ab_hypotheses.expand(ab_pmids.shape)
all_hypotheses = expanded_ab_hypotheses + bc_hypotheses.expand(bc_pmids.shape)

In [None]:
# getCoTPrompts takes len() of and indexes into its inputs, so pass plain lists.
cot_prompts = getCoTPrompts(list(abstracts), sys_prompt, list(all_hypotheses))

In [None]:
# Uses all_hypotheses (the name defined in this notebook) and plain lists, because
# getAnswerPrompts takes len() of and indexes into its inputs.
# NOTE(review): cot_outputs is not assigned anywhere in the visible notebook (the
# generation helper is commented out) — it must be produced before this cell runs.
answer_prompts = getAnswerPrompts(list(abstracts), sys_prompt, list(all_hypotheses), cot_outputs)

In [87]:
# Separate the concatenated hypotheses back into their AB and BC halves using the
# break point that RaggedTensor.__add__ recorded.
ab, bc = all_hypotheses.split()

In [57]:
# Placeholder relevance labels: 32 zeros followed by 33 ones, shuffled in place.
answers = np.concatenate([np.zeros(32), np.ones(33)])
np.random.shuffle(answers)

In [58]:
# Display the shuffled placeholder labels.
answers

array([1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1.,
       0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0.,
       0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0.,
       0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0.])

In [61]:
# NOTE(review): neither `reshape` (a free function) nor `shape` is defined in the
# visible notebook — this cell appears to rely on earlier kernel state; confirm.
answers = reshape(answers, shape)

In [62]:
# NOTE(review): rebinds `abstracts`, so running this cell twice feeds the already
# reshaped value back in; `reshape` and `shape` are not defined in the visible notebook.
abstracts = reshape(abstracts, shape)

In [63]:
# Display the reshaped abstracts.
# NOTE(review): this prints every abstract in full; consider showing only a small slice.
abstracts

[["Warfarin is a highly efficacious oral anticoagulant, but its use is limited by a well-founded fear of bleeding. Drug and food interactions are frequently cited as causes of adverse events with warfarin. We provide an updated systematic overview of the quality, clinical effect, and importance of these reported interactions. MEDLINE, TOXLINE, IPA, and EMBASE databases from October 1993 to March 2004. Database searches combined the keyword warfarin with drug interactions, herbal medicines, Chinese herbal drugs, and food-drug interactions. Eligible articles contained original reports of warfarin drug or food interactions in human subjects. Non-English articles were included if sufficient information could be abstracted. Reports were rated independently by 2 investigators for interaction direction, clinical severity, and quality of evidence. Quality of evidence was based on previously validated causation criteria and study design. Of 642 citations retrieved, 181 eligible articles contain

In [64]:
# Display the labels after reshaping.
answers

[array([1., 0., 1., 1., 0., 1., 0., 1., 1., 0.]),
 array([0., 0., 1., 1., 1., 1., 1., 0., 1., 0.]),
 array([1., 0., 1., 0., 1., 0., 1., 0., 1., 1.]),
 array([0., 1., 1., 0., 0., 0., 1., 0., 0., 1.]),
 array([0., 1., 0., 1., 1., 1., 0., 1.]),
 array([0., 0., 0., 0.]),
 array([1., 1., 0.]),
 array([0., 1., 0., 1., 0., 0., 1., 0., 1., 0.])]

In [66]:
# answers = reshape([eval(answer) for answer in answers], shape)
# # cot_outputs = reshape(cot_outputs, shape)
# abstracts = reshape(abstracts, shape)

# cot_tsv["scores"] = answers
# cot_tsv["chain_of_thought"] = cot_outputs
# cot_tsv["hypothesis"] = hypotheses
# cot_tsv.to_csv(f"{cot_tsv_name}", sep='\t')

# Filter out the abstracts according to the scores: keep an abstract only when
# its score equals 1. tqdm wraps the zipped rows with an explicit total so the
# bar can show overall progress, and the filtering is done with plain lists so
# the kept abstracts stay ordinary Python strings.
filtered_abstracts = []
for abstract_list, scores in tqdm(zip(abstracts, answers), total=len(abstracts), desc = "Post-processing abstracts..."):
    if len(abstract_list) != len(scores):
        raise ValueError(
            f"Row has {len(abstract_list)} abstracts but {len(scores)} scores; they must be aligned."
        )
    filtered_abstracts.append(
        [abstract for abstract, score in zip(abstract_list, scores) if score == 1]
    )

# filtered_tsv["ab_pmid_intersection"] = filtered_abstracts
# filtered_tsv.to_csv(f"{filtered_tsv_name}", sep="\t")

Post-processing abstracts...: 8it [00:00, 3773.98it/s]


In [67]:
# Display the abstracts that passed the filter.
# NOTE(review): this prints every kept abstract in full; consider showing only a small slice.
filtered_abstracts

[["Warfarin is a highly efficacious oral anticoagulant, but its use is limited by a well-founded fear of bleeding. Drug and food interactions are frequently cited as causes of adverse events with warfarin. We provide an updated systematic overview of the quality, clinical effect, and importance of these reported interactions. MEDLINE, TOXLINE, IPA, and EMBASE databases from October 1993 to March 2004. Database searches combined the keyword warfarin with drug interactions, herbal medicines, Chinese herbal drugs, and food-drug interactions. Eligible articles contained original reports of warfarin drug or food interactions in human subjects. Non-English articles were included if sufficient information could be abstracted. Reports were rated independently by 2 investigators for interaction direction, clinical severity, and quality of evidence. Quality of evidence was based on previously validated causation criteria and study design. Of 642 citations retrieved, 181 eligible articles contain