In [1]:
# !pip install -r requirements.txt

In [1]:
import numpy as np
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import set_seed
import torch
from typing import Callable
import random
import os
from src.get_pubmed_text import process_abstracts_data
import json

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [2]:
# Configuration variables for this whole notebook
class config:
    seed = 42
    model = "Mistral-7B-OpenOrca"
    bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    top_k = 20
    top_p = 0.95
    do_sample = True
    num_return_sequences = 1
    max_new_tokens = 500
    temperature = 0.8
    repetition_penalty = 1.2
    penalty_alpha=0.6


In [3]:
# !sudo apt-get install git-lfs
# !git lfs install
# !git clone https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca

In [4]:
# km_output = pd.read_csv("./km_output.tsv", sep = "\t")
# km_output
km_output = pd.read_csv("./data.csv")
km_output

Unnamed: 0.1,Unnamed: 0,pmid,abstract,label
0,0,23956253,PMID: 23956253 Text: The first aim was to crit...,0
1,1,23444397,PMID: 23444397 Text: Niacin has potentially fa...,0
2,2,28886926,"PMID: 28886926 Text: In 2016, the American Col...",0
3,3,27701660,PMID: 27701660 Text: Low-density lipoprotein c...,1
4,4,19095139,PMID: 19095139 Text: This secondary analysis f...,0
5,5,21095263,PMID: 21095263 Text: Lowering low-density lipo...,0


In [5]:
with open("/home/ubuntu/kmGPT/config.json") as file:
	job_config = json.load(file)

In [6]:
# Total of 65 abstracts here 

# b_terms_pmids = km_output.ab_pmid_intersection.map(lambda pmid_list: pmid_list.strip('][').split(', '))
# # Grab only the abstract from each list of pmids in the TSV
# abstracts = [process_abstracts_data(job_config, pmid_list)[0] for pmid_list in b_terms_pmids] # Fetch abstracts from each b_term's PMID list
# # There should only be one a_term, so it's safe to grab the first index
# a_term = km_output.a_term.unique().tolist()[0].split("&")[0]
# b_terms = km_output.b_term.unique().tolist()
# b_terms_pmids = km_output.ab_pmid_intersection.map(lambda pmid_list: pmid_list.strip('][').split(', '))
# # Grab only the abstract from each list of pmids in the TSV
# abstracts = [process_abstracts_data(job_config, pmid_list)[0] for pmid_list in b_terms_pmids] # Fetch abstracts from each b_term's PMID list
# # There should only be one a_term, so it's safe to grab the first index
# a_term = km_output.a_term.unique().tolist()[0].split("&")[0]
# b_terms = km_output.b_term.unique().tolist()

In [78]:
hyp = "ezetimibe may effectively alleviate or target key pathogenic mechanisms of diabetes potentially offering therapeutic benefits or slowing disease progression."
sys_prompt = "You are an incredibly brilliant biomedical researcher who has spent their lifetime reading all the papers in PubMed. You are focused on uplifting other researchers in dire need to evaluate suggested hypotheses given abstracts in PubMed. The sole purpose of your existence is to help uncover hidden connections between the work of existing papers, examining the fully connected relationship between papers while maintaining a strict standard of truth."

In [79]:
def retrieveZeroShotCoTPrompt(hyp: str, abstract: str) -> str:
  zero_shot_prompt = f"""
    Hypothesis: {hyp}
    Abstract: {abstract}

    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis.
    A relevant abstract should either support the given hypothesis or have evidence to refute the hypothesis.
    A relevant abstract must directly comment on the hypothesis.

    Let us think through this step by step.
  """
  return zero_shot_prompt

In [80]:
output = km_output.copy(deep = True)
chain_of_thought = km_output.copy(deep=True)


In [81]:
output

Unnamed: 0.1,Unnamed: 0,pmid,abstract,label
0,0,23956253,PMID: 23956253 Text: The first aim was to crit...,0
1,1,23444397,PMID: 23444397 Text: Niacin has potentially fa...,0
2,2,28886926,"PMID: 28886926 Text: In 2016, the American Col...",0
3,3,27701660,PMID: 27701660 Text: Low-density lipoprotein c...,1
4,4,19095139,PMID: 19095139 Text: This secondary analysis f...,0
5,5,21095263,PMID: 21095263 Text: Lowering low-density lipo...,0


# Outlines Test
### Techniques Used
1. Zero-Shot CoT Prompting
2. llama.cpp GPU inference
3. Constrained generation
4. Prompt chaining
### Unfortunately no batching since batching is pretty difficult with constrained generation


In [82]:
import outlines
from llama_cpp import Llama
from pydantic import BaseModel, field_validator, Field

In [83]:
@outlines.prompt
def cot_prompt(sys_prompt, hypothesis, abstract):
    """
    <|im_start|>system
    {{ sys_prompt }}
    <|im_end|>
    <|im_start|>user
    Hypothesis: {{ hypothesis }}
    Abstract: {{abstract}}
    
    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.

    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.
	<|im_end|>
    <|im_start|>assistant
    """

In [84]:
@outlines.prompt
def answer_prompt(sys_prompt, hypothesis, abstract, chain_of_thought):
    """
    <|im_start|>system
    {{ sys_prompt }}
    <|im_end|>
    <|im_start|>user
    Hypothesis: {{ hypothesis }}
    Abstract: {{abstract}}
    
    Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. A relevant abstract must directly comment on the hypothesis and either support the given hypothesis or have evidence to refute the hypothesis.

    Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Let's work this out in a step by step way to be sure we have the right answer.
    {{chain_of_thought}}
    
    Classify the given abstract as either 0 (Not relevant) or 1 (Relevant) based on your reasoning above and any information in the abstract and hypothesis.
    Answer: 
    <|im_end|>
    <|im_start|>assistant
    """

In [85]:
model = outlines.models.llamacpp(model_name="/home/ubuntu/kmGPT/mistral-7b-openorca.Q5_K_M.gguf", device = "cuda", model_kwargs = {"n_gpu_layers":-1, "n_ctx":4096, "n_threads": 30})

llama_model_loader: loaded meta data with 20 key-value pairs and 291 tensors from /home/ubuntu/kmGPT/mistral-7b-openorca.Q5_K_M.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = open-orca_mistral-7b-openorca
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 l

In [86]:
# prompts = [[zero_shot_cot_prompt(sys_prompt, hyp, abstract) for abstract in abstract_list] for abstract_list in abstracts]
cot_prompts = [cot_prompt(sys_prompt, hyp, abstract) for abstract in km_output["abstract"]]

In [87]:
%time
cot_generator = outlines.generate.text(model, sampler=outlines.samplers.MultinomialSampler(temperature=0.8, top_k=20, top_p=0.95))
cot_outputs = cot_generator(cot_prompts)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.01 µs



llama_print_timings:        load time =     197.89 ms
llama_print_timings:      sample time =      81.25 ms /   349 runs   (    0.23 ms per token,  4295.65 tokens per second)
llama_print_timings: prompt eval time =     352.63 ms /   790 tokens (    0.45 ms per token,  2240.28 tokens per second)
llama_print_timings:        eval time =    5198.25 ms /   348 runs   (   14.94 ms per token,    66.95 tokens per second)
llama_print_timings:       total time =    6252.79 ms /  1138 tokens

llama_print_timings:        load time =     197.89 ms
llama_print_timings:      sample time =      87.05 ms /   374 runs   (    0.23 ms per token,  4296.18 tokens per second)
llama_print_timings: prompt eval time =     409.53 ms /   920 tokens (    0.45 ms per token,  2246.51 tokens per second)
llama_print_timings:        eval time =    5640.85 ms /   373 runs   (   15.12 ms per token,    66.12 tokens per second)
llama_print_timings:       total time =    6809.70 ms /  1293 tokens

llama_print_timings:     

In [88]:
%time
answer_prompts = [answer_prompt(sys_prompt, hyp, km_output["abstract"][i], cot_outputs[i]) for i in range(len(km_output["abstract"]))]
answer_generator = outlines.generate.choice(model, ["0", "1"])
answers = answer_generator(answer_prompts)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.29 µs



llama_print_timings:        load time =     197.89 ms
llama_print_timings:      sample time =       0.44 ms /     2 runs   (    0.22 ms per token,  4535.15 tokens per second)
llama_print_timings: prompt eval time =     554.06 ms /  1180 tokens (    0.47 ms per token,  2129.72 tokens per second)
llama_print_timings:        eval time =      16.60 ms /     1 runs   (   16.60 ms per token,    60.23 tokens per second)
llama_print_timings:       total time =     577.46 ms /  1181 tokens

llama_print_timings:        load time =     197.89 ms
llama_print_timings:      sample time =       0.45 ms /     2 runs   (    0.22 ms per token,  4454.34 tokens per second)
llama_print_timings: prompt eval time =     621.12 ms /  1335 tokens (    0.47 ms per token,  2149.33 tokens per second)
llama_print_timings:        eval time =      15.84 ms /     1 runs   (   15.84 ms per token,    63.14 tokens per second)
llama_print_timings:       total time =     643.88 ms /  1336 tokens

llama_print_timings:     

In [91]:
answers

['0', '0', '0', '1', '1', '0']

In [90]:
cot_outputs

["\n To determine if this abstract is relevant for evaluating the hypothesis, let's first understand the hypothesis and break it down into its components:\n\nHypothesis: Ezetimibe may effectively alleviate or target key pathogenic mechanisms of diabetes potentially offering therapeutic benefits or slowing disease progression.\n\nNow, let's analyze the abstract for any direct mention or evidence that supports or refutes this hypothesis:\n\nAbstract Summary: The primary focus of the abstract revolves around the diagnosis, screening, and treatment of familial hypercholesterolemia (FH) and its associated risks of coronary heart disease (CHD). It provides specific recommendations for screening, cholesterol targets, and treatment options, including ezetimibe, for both children and adults with FH.\n\nOur thought process:\n1. We look for any mention or discussion of diabetes in the abstract.\n2. We search for evidence related to ezetimibe's potential role in alleviating pathogenic mechanisms o

In [None]:
with system():
    lm = mistral + sys_prompt

with user():
    lm += prompt

with assistant():
    lm += gen(max_tokens = 500, temperature = config.temperature, name = "chain_of_thought")

with user():
    lm += "Give a score of either 0: (Not relevant) or 1: (Relevant) for the above abstract. Answer: " + select([0, 1], name = "answer")

KeyboardInterrupt: 