# 1. Getting Setup

In [5]:
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import set_seed
import torch
from typing import Callable
import random
import os
from src.get_pubmed_text import process_abstracts_data
import json
import vllm
from lmformatenforcer import RegexParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data
from typing import Optional

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [1]:
import numpy as np

### 1.1 Configuration Settings

In [142]:
# Configuration variables for this whole notebook
class config:
    model = "Mistral-7B-OpenOrca"
    top_k = 20
    top_p = 0.95
    max_new_tokens = 500
    temperature = 0.8
    repetition_penalty = 1.2
    frequency_penalty = 1.2
    max_tokens = 2048
    batch_size = 32


In [6]:
# !sudo apt-get install git-lfs
# !git lfs install
# !git clone https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca

### 1.2 Data Processing

In [76]:
km_output = pd.read_csv("./data.tsv", sep = "\t")
km_output

Unnamed: 0,a_count,a_term,ab_count,ab_pmid_intersection,ab_pred_score,ab_pvalue,ab_sort_ratio,b_count,b_term,total_count
0,989,warfarin&drug&interaction,52,"[15911722, 7944078, 24550106, 20002088, 256468...",0.330813,2.689688e-101,0.005769,9013,omeprazole,36618932
1,989,warfarin&drug&interaction,42,"[10709776, 12036392, 15871634, 15568889, 15260...",0.243054,5.577908000000001e-75,0.00391,10743,simvastatin,36618932
2,989,warfarin&drug&interaction,26,"[7944078, 10709776, 9512916, 8801057, 18685566...",0.121822,3.8424569999999995e-38,0.001779,14616,fluconazole,36618932
3,989,warfarin&drug&interaction,16,"[25646891, 7429002, 3395358, 9667024, 20489028...",0.068799,1.350138e-21,0.001245,12852,furosemide,36618932
4,989,warfarin&drug&interaction,8,"[10709776, 20002088, 6096071, 8793611, 8793602...",0.034793,7.571592e-11,0.001029,7778,metoprolol,36618932
5,989,warfarin&drug&interaction,4,"[22250655, 22406649, 32862668, 34691471]",0.017386,1.583282e-05,0.000751,5329,enoxaparin,36618932
6,989,warfarin&drug&interaction,3,"[22794158, 32982467, 34691471]",0.007539,0.006358695,0.000219,13669,ceftriaxone,36618932
7,989,warfarin&drug&interaction,16,"[8792056, 16697485, 21053990, 21253716, 112482...",0.027246,2.550025e-09,0.000191,83972,heparin,36618932


In [78]:
def getHypothesis(a_term: str, b_term: str) -> str:
    return f"{b_term} will have a drug-drug interaction with {a_term}"

In [79]:
def cot_prompt(sys_prompt: str, hyp: str, abstract: str) -> str:
  return f"""
    <|im_start|>system
    {sys_prompt}
    <|im_end|>
    <|im_start|>user
    Hypothesis: {{ hypothesis }}
    Abstract: {{abstract}}
    
    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.

    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.
	<|im_end|>
    <|im_start|>assistant
    """

In [80]:
def answer_prompt(sys_prompt: str, hypothesis: str, abstract: str, chain_of_thought: str) -> str:
    return f"""
    <|im_start|>system
    {sys_prompt}
    <|im_end|>
    <|im_start|>user
    Hypothesis: {hypothesis}
    Abstract: {abstract}
    
    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.

    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.
    {chain_of_thought}
    
    Classify the given abstract as either 0 (Not relevant) or 1 (Relevant) based on your reasoning above and any information in the abstract and hypothesis.
    Answer: 
    <|im_end|>
    <|im_start|>assistant
    """

In [172]:
# For a batched input generate output with flattened dimensions
def gen(batches: list[str], model: any, sampling_config: vllm.SamplingParams) -> list[str]:
    outputs = []
    for batch in batches:
        generated = model.generate(batch, sampling_params = sampling_config)
        outputs.extend([output.outputs[0].text for output in generated])
    return outputs

In [145]:
def get_batch(inp: list, batch_size: int) -> list:
    return [inp[i * batch_size:(i + 1) * batch_size] for i in range((len(inp) + batch_size - 1) // batch_size )]

In [192]:
# Redefined reshape function to work with ragged string arrays
def reshape(inp: list, shape: list) -> list:
    assert(len(inp) == sum(shape))
    output = []
    running_length = 0;
    for length in shape:
        output.append(inp[running_length: running_length + length])
        running_length = length
        
    return output
        

In [75]:
# Total of 65 abstracts here 
with open("/home/ubuntu/kmGPT/config.json") as file:
	job_config = json.load(file)
 
b_terms_pmids = km_output.ab_pmid_intersection.map(lambda pmid_list: pmid_list.strip('][').split(', '))
# Grab only the abstract from each list of pmids in the TSV
abstracts = [process_abstracts_data(job_config, pmid_list)[0] for pmid_list in b_terms_pmids] # Fetch abstracts from each b_term's PMID list

# There should only be one a_term, so it's safe to grab the first index
a_term = km_output.a_term.unique().tolist()[0].split("&")[0]
b_terms = km_output.b_term.unique().tolist()

In [81]:
sys_prompt = "You are an incredibly brilliant biomedical researcher who has spent their lifetime reading all the papers in PubMed. You are focused on uplifting other researchers in dire need to evaluate suggested hypotheses given abstracts in PubMed. The sole purpose of your existence is to help uncover hidden connections between the work of existing papers, examining the fully connected relationship between papers while maintaining a strict standard of truth."
hypotheses = [getHypothesis(a_term, b_term) for b_term in b_terms]

### 2) Model Inference
#### Techniques Used
1. Zero-Shot CoT Prompting
2. Constrained generation
3. Prompt chaining
4. GPU Batching
5. Paged Attention!
6. Special sampling


### 2.1) Chain of Thought Generation

In [2]:
mistral = vllm.LLM(model="Mistral-7B-OpenOrca", max_model_len=16832)

INFO 03-06 17:27:26 llm_engine.py:87] Initializing an LLM engine with config: model='Mistral-7B-OpenOrca', tokenizer='Mistral-7B-OpenOrca', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=16832, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 03-06 17:27:40 llm_engine.py:357] # GPU blocks: 2054, # CPU blocks: 2048
INFO 03-06 17:27:41 model_runner.py:684] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 03-06 17:27:41 model_runner.py:688] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 03-06 17:27:45 model_runner.py:756] Graph capturing finished in 4 secs.


In [157]:
cot_prompts = [cot_prompt(sys_prompt, hypotheses[i], abstract) for i, abstract_list in enumerate(abstracts) for abstract in abstract_list]
cot_batches = get_batch(cot_prompts, config.batch_size)

In [173]:
# %%time
sampling_cot = vllm.SamplingParams(
			temperature=config.temperature, 
			top_k = config.top_k, top_p=config.top_p, 
			max_tokens=config.max_tokens, 
			repetition_penalty=config.repetition_penalty)
cot_outputs = gen(cot_batches, mistral, sampling_cot)

Processed prompts:   0%|          | 0/32 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s]
Processed prompts: 100%|██████████| 32/32 [00:23<00:00,  1.38it/s]
Processed prompts: 100%|██████████| 1/1 [00:09<00:00,  9.14s/it]


### 2.2) Answer Generation

In [174]:
tokenizer_data = build_vllm_token_enforcer_tokenizer_data(mistral)
logits_processor = build_vllm_logits_processor(tokenizer_data, RegexParser(r"0|1"))

In [175]:
answer_prompts = []
total_idx = 0
for i, abstract_list in enumerate(abstracts):
    for j, abstract in enumerate(abstract_list):
        answer_prompts.append(answer_prompt(sys_prompt, hypotheses[i], abstract, cot_outputs[total_idx + j]))
    total_idx += len(abstract_list)
answer_batches = get_batch(answer_prompts, config.batch_size)


In [202]:
sampling_answer = vllm.SamplingParams(
			temperature=config.temperature, 
			top_k = config.top_k, top_p=config.top_p, 
			max_tokens=config.max_tokens, 
			repetition_penalty=config.repetition_penalty,
   			logits_processors=[logits_processor])
answers = gen(answer_batches, mistral, sampling_answer)

Processed prompts:   0%|          | 0/32 [00:00<?, ?it/s]

Processed prompts: 100%|██████████| 32/32 [00:09<00:00,  3.46it/s]
Processed prompts: 100%|██████████| 32/32 [00:08<00:00,  3.88it/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.17it/s]


In [203]:
answers = [eval(answer) for answer in answers]
shape = [len(abstract_list) for abstract_list in abstracts]
answers = reshape(answers, shape)

In [205]:
answers

[[0, 0, 1, 1, 0, 0, 0, 0, 0, 1],
 [0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
 [0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
 [0, 0, 1, 0, 0, 1, 0, 0, 0, 1],
 [0, 0, 1, 0, 0, 1, 0, 0],
 [0, 1, 0, 0],
 [0, 0, 0],
 [1, 0, 0, 0, 0, 0, 1, 0, 0, 1]]

## 3) Post Processing

### 3.1) Generate Chain of Thought TSV

In [235]:
cot_tsv = km_output.copy(deep = True)
cot_tsv["scores"] = answers
cot_tsv["chain_of_thought"] = reshape(cot_outputs, shape)

In [283]:
cot_tsv.to_csv("chain_of_thought.tsv", sep='\t')

### 3.2) Generate Output TSV

In [251]:
filtered_tsv = km_output.copy(deep = True)


In [284]:
filtered_abstracts = []
for i, abstract_list in enumerate(abstracts):
	for j, score in enumerate(answers[i]):
		filtered = []
		if score == 1:
			filtered.append(abstract_list[j])
	filtered_abstracts.append(filtered)

In [285]:
filtered_tsv["ab_pmid_intersection"] = filtered_abstracts
filtered_tsv.to_csv("filtered_output.tsv", sep="\t")