In [1]:
from __future__ import annotations

import argparse
import ast
import json
import os
from itertools import chain

import jinja2
import numpy as np
import pandas as pd
import vllm
from Bio import Entrez
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data
from tqdm import tqdm
from transformers import set_seed

from abstract_comprehension import read_tsv_to_dataframe
from classifier import process_single_row, write_to_json, test_openai_connection
from utils import Config, RaggedTensor

# Returns the hypothesis text for the configured JOB_TYPE with the given terms filled in.
# Pass exactly two terms: A and B for KM jobs; A/B, B/C or A/C for SKiM jobs.
def getHypothesis(config, a_term: str = None, b_term: str = None, c_term: str = None) -> str:
	"""Build the hypothesis sentence for a pair of terms from the job configuration.

	Args:
		config: Mapping-like job configuration. Keys read: "JOB_TYPE" (case-insensitive),
			"KM_hypothesis", "POSITION_KM_hypothesis" and "SKIM_hypotheses" (a mapping
			holding "AB", "BC" and "AC" templates).
		a_term, b_term, c_term: Terms substituted into the template's
			``{a_term}`` / ``{b_term}`` / ``{c_term}`` placeholders. Exactly two must be
			given; KM job types only accept the A/B pair.

	Returns:
		The formatted hypothesis. NOTE: for "position_km_with_gpt" a
		``(hypothesis, None)`` tuple is returned rather than a bare string; this is
		kept for existing callers. A missing template formats to ``""``. An
		unrecognised JOB_TYPE returns a fixed explanatory message.

	Raises:
		AssertionError: if the supplied combination of terms is not allowed.
		ValueError: for an invalid SKiM term combination when asserts are disabled.
	"""
	job_type = config.get("JOB_TYPE", "").lower()

	if job_type == "km_with_gpt":
		assert a_term and b_term and not c_term
		hypothesis_template = config.get("KM_hypothesis", "")
		return hypothesis_template.format(a_term=a_term, b_term=b_term)

	elif job_type == "position_km_with_gpt":
		assert a_term and b_term and not c_term
		hypothesis_template = config.get("POSITION_KM_hypothesis", "")
		return hypothesis_template.format(a_term=a_term, b_term=b_term), None

	elif job_type == "skim_with_gpt":
		assert (a_term and b_term and not c_term) or (b_term and c_term and not a_term) or (a_term and c_term and not b_term)
		# Default to an empty mapping (a "" default has no .get) and to an empty template
		# per pair, matching how the KM branches treat a missing template.
		skim_templates = config.get("SKIM_hypotheses", {})

		if a_term and b_term and not c_term:
			return skim_templates.get("AB", "").format(a_term=a_term, b_term=b_term)
		elif b_term and c_term and not a_term:
			return skim_templates.get("BC", "").format(b_term=b_term, c_term=c_term)
		elif a_term and c_term and not b_term:
			return skim_templates.get("AC", "").format(a_term=a_term, c_term=c_term)

		# Only reachable when asserts are stripped (python -O); fail loudly instead of returning None.
		raise ValueError("skim_with_gpt needs exactly two of a_term, b_term, c_term.")

	else:
		return "No valid hypothesis for the provided JOB_TYPE."
	
	

def prompt(abstract, hyp) -> str:
	"""Render the relevance-classification prompt for one abstract/hypothesis pair.

	The text ends with ``"Score: "`` so that generation continues directly with the label.
	"""
	instructions = (
		"Classify this abstract as either 0 (Not Relevant) or 1 (Relevant) "
		"for evaluating the provided hypothesis."
	)
	return "Abstract: {}\nHypothesis: {}\nInstructions: {}\nScore: ".format(abstract, hyp, instructions)

In [2]:
def gen(prompts: RaggedTensor, model: any, sampling_config: vllm.SamplingParams) -> RaggedTensor:
	"""Run the model on every prompt and keep the first completion of each.

	Args:
		prompts: Flat prompts; ``prompts.data`` is what is handed to ``model.generate``.
		model: An object exposing ``generate(prompts, sampling_params=...)`` (a vLLM ``LLM`` here).
		sampling_config: Sampling parameters forwarded unchanged to ``generate``.

	Returns:
		A RaggedTensor of completion texts that carries ``prompts.break_point``,
		so the result keeps the same grouping as the prompts.
	"""
	request_outputs = model.generate(prompts.data, sampling_params=sampling_config)
	texts = []
	for request_output in request_outputs:
		# Only the first candidate completion of each request is kept.
		texts.append(request_output.outputs[0].text)
	return RaggedTensor(texts, prompts.break_point)

def getPrompts(abstracts: RaggedTensor, hypotheses: RaggedTensor) -> RaggedTensor:
	"""Pair each abstract with the hypothesis at the same position and render prompts.

	Both inputs must be flat (not 2-D) and aligned index-by-index. The result reuses
	``hypotheses.break_point`` so it stays grouped like the hypotheses.
	"""
	assert not abstracts.is2D(), "abstracts should be flattened."
	assert not hypotheses.is2D(), "hypotheses should be flattened."
	rendered = []
	# For a flat tensor `shape` is used as an element count here.
	for idx in range(abstracts.shape):
		rendered.append(prompt(abstracts[idx], hypotheses[idx]))
	return RaggedTensor(rendered, hypotheses.break_point)
	
# Returns a dictionary mapping each PMID (as a string) to its abstract text.
# A mapping is needed because Entrez removes duplicate PMIDs from the request, so the
# results cannot be lined back up with the input list by position.
def getAbstractMap(config: dict, pmids: list[str]) -> dict:
	"""Fetch abstracts for ``pmids`` from PubMed via Entrez.

	Args:
		config: Job configuration. Reads ``GLOBAL_SETTINGS`` -> ``MAX_RETRIES``,
			``RETRY_DELAY`` and ``PUBMED_PARAMS`` (``db``, ``rettype``), plus the optional
			``ENTREZ_EMAIL`` / ``ENTREZ_API_KEY`` fallbacks.
		pmids: PMIDs to fetch; passed straight through to ``Entrez.efetch(id=...)``.

	Returns:
		``{pmid: abstract}`` with PMIDs as strings. Abstract sections are joined with a
		single space; records that have no abstract map to ``""``.

	Credentials are taken from the ``ENTREZ_EMAIL`` / ``ENTREZ_API_KEY`` environment
	variables (falling back to the same keys in ``GLOBAL_SETTINGS``) instead of being
	hard-coded in source.
	"""
	global_config = config["GLOBAL_SETTINGS"]
	pmid_config = global_config["PUBMED_PARAMS"]

	# NOTE(review): an API key used to be hard-coded here; it is exposed and should be rotated.
	Entrez.email = os.environ.get("ENTREZ_EMAIL", global_config.get("ENTREZ_EMAIL"))
	Entrez.api_key = os.environ.get("ENTREZ_API_KEY", global_config.get("ENTREZ_API_KEY"))
	Entrez.max_tries = global_config["MAX_RETRIES"]
	Entrez.sleep_between_tries = global_config["RETRY_DELAY"]

	handle = Entrez.efetch(db=pmid_config["db"], id=pmids, rettype=pmid_config["rettype"])
	try:
		records = Entrez.read(handle)
	finally:
		# Release the HTTP handle even if parsing fails.
		handle.close()

	abstract_map = {}
	for paper in records["PubmedArticle"]:
		citation = paper["MedlineCitation"]
		# Not every PubMed record has an abstract; treat a missing one as no sections
		# instead of raising KeyError and losing the whole batch.
		sections = citation["Article"].get("Abstract", {}).get("AbstractText", [])
		abstract_map[str(citation["PMID"])] = " ".join(sections)
	return abstract_map


def _parse_score(text):
	"""Parse a relevance label emitted by the model (e.g. "1" or " 0") into a number.

	ast.literal_eval only accepts Python literals, so unlike eval() it cannot execute
	arbitrary code if the model emits something unexpected; it raises instead.
	"""
	return ast.literal_eval(text.strip())


# Packages all the inputted data into the provided dataframes
def postProcess(config: Config, outputs: RaggedTensor, abstracts: RaggedTensor, hypotheses: RaggedTensor, out_df: pd.DataFrame, terms: str, shape: list):
	"""Parse the model answers for one relation type and write them into ``out_df``.

	Args:
		config: Run configuration; only ``config.debug`` is read.
		outputs: Flat model completions, one per abstract.
		abstracts: Flat abstracts aligned with ``outputs``. Reshaped in place to
			``shape`` and, outside debug mode, filtered in place by the parsed masks.
		hypotheses: Hypotheses for this relation type (written out in debug mode only).
		out_df: DataFrame that is updated in place.
		terms: Column-name prefix for the relation ("ab", "bc" or "ac").
		shape: Shape used to regroup the flat tensors per row.

	Side effects:
		Writes ``{terms}_mask`` and ``{terms}_pmid_intersection`` (plus ``{terms}_cot``
		and ``{terms}_hypothesis`` in debug mode) into ``out_df``. For "ac" the single
		AC result is repeated for every row.
	"""
	abstracts.reshape(shape)

	columns = {}
	if config.debug:
		# In debug mode the first character is the label and the rest is the chain of thought.
		answer_masks = RaggedTensor([_parse_score(answer[0]) for answer in outputs])
		answer_masks.reshape(shape)
		cot = RaggedTensor([answer[1:] for answer in outputs])
		cot.reshape(shape)
		columns[f"{terms}_mask"] = answer_masks.data
		columns[f"{terms}_cot"] = cot.data
		columns[f"{terms}_hypothesis"] = hypotheses.data
	else:
		# Outside debug mode each completion is just the label, so it parses directly into a mask.
		answer_masks = outputs.map(_parse_score)
		answer_masks.reshape(shape)
		abstracts.applyFilter(answer_masks)
		columns[f"{terms}_mask"] = answer_masks.data

	# Debug output keeps every abstract; only the non-debug path filtered them above.
	columns[f"{terms}_pmid_intersection"] = abstracts.data

	# There is only ever one AC relation per TSV, so its values are repeated for every row.
	repeat = len(out_df) if terms == "ac" else None
	for column, values in columns.items():
		out_df[column] = values * repeat if repeat is not None else values

In [3]:
class args:
    """Stand-in for parsed command-line arguments when running this notebook interactively.

    Passed to ``Config`` in place of an argparse namespace (presumably mirroring the
    CLI flags of the script version — confirm against ``Config``).
    """
    # Path to the JSON run configuration.
    config = "../config.json"
    # Input TSV for this run.
    km_output = "../test_tsvs/skim_with_ac/skim_with_ac.tsv"

In [4]:
# Build the run configuration from the stand-in `args`; `config.data` (used below) holds the input rows.
config = Config(args)

Job type detected. Running skim_with_gpt.


In [5]:
# Deep copy so the result columns added below never modify config.data itself.
out_df = config.data.copy(deep = True)

In [6]:
# Keep only the first "&"-separated component of the A term (presumably a synonym list — confirm).
a_term = config.data.a_term.unique().tolist()[0].split("&")[0]
# One B term per row (not .unique()), so the hypotheses stay aligned with the per-row PMID lists
# and with the per-row hypothesis/mask columns written by postProcess.
b_terms = config.data.b_term.tolist()

# The TSV stores each intersection as a Python-literal list; literal_eval parses it without executing code.
ab_pmids = RaggedTensor([ast.literal_eval(lst) for lst in config.data.ab_pmid_intersection])
ab_hypotheses = RaggedTensor([getHypothesis(config.job_config, a_term = a_term, b_term = b_term) for b_term in b_terms])

# Flatten to one entry per (row, PMID) and repeat each row's hypothesis to match.
all_pmids = ab_pmids.flatten()
all_hypotheses = ab_hypotheses.expand(ab_pmids.shape)

In [7]:
# NOTE: this cell appends to all_pmids / all_hypotheses, so it is not idempotent —
# re-run the previous cell before re-running this one.
if config.is_skim_gpt:
    c_term = config.data.c_term.unique().tolist()[0]
    bc_pmids = RaggedTensor([ast.literal_eval(lst) for lst in config.data.bc_pmid_intersection])
    bc_hypotheses = RaggedTensor([getHypothesis(config.job_config, c_term = c_term, b_term = b_term) for b_term in b_terms])

    all_pmids += bc_pmids.flatten()
    all_hypotheses += bc_hypotheses.expand(bc_pmids.shape)

    if config.has_ac:
        # For each atomic run there should only be one unique ac_pmid intersection.
        # .iloc[0] is positional, so this works even if the index does not start at label 0.
        ac_pmids = RaggedTensor(ast.literal_eval(config.data.ac_pmid_intersection.iloc[0]))
        ac_hypothesis = RaggedTensor([getHypothesis(config.job_config, a_term = a_term, c_term = c_term)])

        all_pmids += ac_pmids
        all_hypotheses += ac_hypothesis.expand([ac_pmids.shape])

# Fetch every abstract in one request; PMIDs PubMed returns no abstract for become "".
abstract_map = getAbstractMap(config.job_config, all_pmids)
abstracts = all_pmids.map(lambda pmid: abstract_map.get(str(pmid), ""))

In [8]:
##################### Model Loading & Generation ############################
model_name = config.filter_config["MODEL"]
# max_model_len sets vLLM's context length for this model instance.
model = vllm.LLM(model=model_name, max_model_len=4000)

INFO 06-06 14:21:31 llm_engine.py:87] Initializing an LLM engine with config: model='lexu14/porpoise1', tokenizer='lexu14/porpoise1', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4000, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 06-06 14:21:34 weight_utils.py:163] Using model weights format ['*.safetensors']
INFO 06-06 14:21:37 llm_engine.py:357] # GPU blocks: 10835, # CPU blocks: 682
INFO 06-06 14:21:38 model_runner.py:684] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-06 14:21:38 model_runner.py:688] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-06 14:21:44 model_runner.py:756] Graph capturing finished in 6 secs.


In [9]:
# Debug runs generate a chain of thought after the label; normal runs generate only one token (the label).
max_tokens = config.filter_config["MAX_COT_TOKENS"] if config.debug else 1
sampling_config = vllm.SamplingParams(
    temperature=config.filter_config["TEMPERATURE"],
    top_k=config.filter_config["TOP_K"],
    top_p=config.filter_config["TOP_P"],
    max_tokens=max_tokens,
)

prompts = getPrompts(abstracts, all_hypotheses)

In [10]:
# Generate one completion per prompt; `answers` keeps the prompts' grouping and is split into AB/BC/AC below.
answers = gen(prompts, model, sampling_config)

Processed prompts: 100%|██████████| 50/50 [00:12<00:00,  3.97it/s]


In [11]:
# Debug: inspect the raw completions (can be very long in debug mode).
print(answers)

["1\nExplanation: The abstract provided is highly relevant for evaluating the hypothesis that there exists an interaction between the disease diabetes and the gene COX-2. Here's a detailed explanation supporting the score:  1. **Mention of Diabetes and COX-2**: The abstract explicitly mentions both diabetes and the gene COX-2. It states that chronic inflammation can lead to various diseases, including diabetes. This establishes a direct link between the disease and the gene in question.  2. **Role of COX-2 in Inflammation**: COX-2 (Cyclooxygenase-2) is a member of the cyclooxygenase family of enzymes that are involved in the production of prostaglandins, which are key mediators of inflammation. The abstract discusses the role of pro-inflammatory gene products, including COX-2, in the context of chronic inflammation and its potential implications in diseases like diabetes.  3. **Connection to Disease Mechanisms**: The abstract highlights the role of chronic inflammation in the pathogene

In [12]:
# Pad the split results so the unpacking always succeeds: when there is no BC or AC relation,
# the missing parts become empty RaggedTensors. Each unpack gets its own fresh instances —
# `3 * [RaggedTensor([])]` would alias one mutable object across every slot.
ab_outputs, bc_outputs, ac_outputs, *_ = chain(answers.split(), [RaggedTensor([]) for _ in range(3)])
ab_abstracts, bc_abstracts, ac_abstracts, *_ = chain(abstracts.split(), [RaggedTensor([]) for _ in range(3)])

postProcess(config, ab_outputs, ab_abstracts, ab_hypotheses, out_df, terms = "ab", shape = ab_pmids.shape)

##################### Post process BC answers ############################
if config.is_skim_gpt:
    postProcess(config, bc_outputs, bc_abstracts, bc_hypotheses, out_df, terms = "bc", shape = bc_pmids.shape)
    if config.has_ac:
        postProcess(config, ac_outputs, ac_abstracts, ac_hypothesis, out_df, terms = "ac", shape = [ac_pmids.shape])

output_path = str(config.debug_tsv_name if config.debug else config.filtered_tsv_name)
# NOTE(review): the index is written too, and the input already carries an "Unnamed: 0" column,
# so each round trip may add another one — consider index=False once downstream readers are checked.
out_df.to_csv(output_path, sep="\t")

In [13]:
# Scratch check: re-parse the BC labels (first character of each completion) and regroup them per row.
# literal_eval instead of eval so model output can never execute code.
test = RaggedTensor([ast.literal_eval(answer[0]) for answer in bc_outputs.data])
test.reshape(bc_pmids.shape)

In [14]:
# Show the regrouped BC labels from the scratch check above.
print(test)

[[1, 1, 1, 1, 1, 1, 1, 1, 0, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]


In [15]:
# Preview only; the frame holds full abstract text and chain-of-thought lists per row.
out_df.head()

Unnamed: 0,a_count,a_term,ab_count,ab_pmid_intersection,ab_pred_score,ab_pvalue,ab_sort_ratio,b_count,b_term,bc_count,...,ac_pmid_intersection,ab_mask,ab_cot,ab_hypothesis,bc_mask,bc_cot,bc_hypothesis,ac_mask,ac_cot,ac_hypothesis
0,657735,Diabetes,646,[Although inflammation has long been known as ...,0.061449,0.537207,0.01787,36149,COX-2,96,...,[The National High Blood Pressure Education Pr...,"[1, 1, 0, 1, 1, 1, 1, 1, 1, 0]",[\nExplanation: The abstract provided is highl...,There exists an interaction between the diseas...,"[1, 1, 1, 1, 1, 1, 1, 1, 0, 1]",[\nExplanation: The abstract provided is highl...,There exists an interaction between the drug F...,"[0, 1, 1, 1, 0, 1, 1, 1, 0, 1]",[\nExplanation: The abstract provided is not d...,The drug Fish oil has an interaction with the ...
1,657735,Diabetes,2544,[Thiazolidinedione derivatives are antidiabeti...,1.569446,0.0,0.155835,16325,PPAR,95,...,[The National High Blood Pressure Education Pr...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",[\nExplanation: The abstract provided is highl...,There exists an interaction between the diseas...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]",[\nExplanation: The abstract provided is highl...,There exists an interaction between the drug F...,"[0, 1, 1, 1, 0, 1, 1, 1, 0, 1]",[\nExplanation: The abstract provided is not d...,The drug Fish oil has an interaction with the ...


In [16]:
# Number of AB abstracts stored for the first row (filtered unless in debug mode).
len(out_df.loc[0, "ab_pmid_intersection"])

10

In [18]:
# Number of AB chain-of-thought entries for the first row (column exists in debug runs only).
len(out_df.loc[0, "ab_cot"])

10

In [19]:
# How many of the first row's AB abstracts were labelled relevant (1).
sum(out_df.loc[0, "ab_mask"])

8