In [1]:
import os
import time
import json
import requests
import pandas as pd
from tqdm import tqdm
from openai import OpenAI
from evaluation_functions import evaluate

## Set Up

### Functions

In [15]:
def perform_extraction(model, prompt, text, temperature, max_length,
                       endpoint="http://localhost:5000/predict", timeout=None):
    """Send one extraction request to the local inference endpoint.

    The endpoint is expected to be reachable through an SSH tunnel, e.g.
        ssh -N -f -L localhost:5000:localhost:5000 username@10.19.2.120

    Parameters
    ----------
    model : str
        Model identifier sent as ``model_id``
        (e.g. ``"codellama/CodeLlama-34b-Instruct-hf"``).
    prompt : str
        Template containing one ``{}`` placeholder; it is filled with ``text``
        via ``str.format``.
    text : str
        Text inserted into the prompt.
    temperature, max_length
        Forwarded unchanged under the ``parameters`` key of the payload.
    endpoint : str, optional
        URL the request is posted to. Defaults to the tunnelled local server.
    timeout : float or None, optional
        Passed to ``requests.post``; ``None`` (the default) waits indefinitely,
        which matches the previous behaviour.

    Returns
    -------
    The value stored under ``'response'`` in the JSON body, or ``None`` when
    the body is not a JSON object or has no such key.
    """
    payload = {
        "input": prompt.format(text),
        "model_id": model,
        "parameters": {
            "temperature": temperature,
            "max_length": max_length
        }
    }

    response = requests.post(endpoint, json=payload, timeout=timeout)

    # Decode the body once. Error responses are not guaranteed to be JSON
    # (e.g. an HTML gateway page); JSON decode errors derive from ValueError.
    try:
        body = response.json()
    except ValueError:
        body = None

    if response.status_code != 200:
        # Fall back to the raw text so the real server error stays visible.
        print("Error:", response.status_code, body if body is not None else response.text)

    if not isinstance(body, dict):
        return None
    return body.get('response')

def perform_cleanup(extraction, openai_api):
    """Normalise a raw model extraction into a comma separated ADE list.

    Sends ``extraction`` to the OpenAI chat completions API
    (``gpt-3.5-turbo-16k``, temperature 0) with an instruction to strip any
    preamble/postamble, and returns the content of the first choice.

    Parameters
    ----------
    extraction : str
        Raw text returned by the extraction model.
    openai_api : str
        OpenAI API key used to build the client.
    """
    instruction = (
        "The following text is an extraction of adverse event terms from a drug label. "
        "Please remove any preamble or postamble from the list and turn the list of ADEs "
        "into a comma separated list. \n"
        "The text: {}"
    )
    conversation = [
        {"role": "system", "content": ""},
        {"role": "user", "content": instruction.format(extraction)},
    ]

    api_client = OpenAI(api_key=openai_api)
    completion = api_client.chat.completions.create(
        messages=conversation,
        model="gpt-3.5-turbo-16k",
        temperature=0,
    )
    return completion.choices[0].message.content

In [16]:
# Example single-call smoke test (arguments follow the perform_extraction signature):
# perform_extraction("google/gemma-7b", prompt, row['section_text'], 0, max_length)


In [4]:
# function for extracting 
def extract_ade_terms(config, model, prompt, text, temperature, max_length):
    """Extract ADE terms from ``text`` and return them as a cleaned-up list.

    Runs ``perform_extraction`` with the given model/prompt settings, then
    passes the raw result through ``perform_cleanup`` using the API key found
    at ``config['OpenAI']['openai_api_key']``.

    Raises
    ------
    Exception
        If ``perform_extraction`` returns ``None``.
    """
    raw_extraction = perform_extraction(model, prompt, text, temperature, max_length)

    # Guard clause: nothing to clean up when the extraction step produced no response.
    if raw_extraction is None:
        raise Exception(f"perform_extraction() return None for {model}")

    api_key = config['OpenAI']['openai_api_key']
    return perform_cleanup(raw_extraction, api_key)


### Variables

In [7]:
# Input files: the drug label section text and the manually annotated ADEs
# that the evaluation cell compares the model output against.
drug_file = 'data/TAC2017/train_drug_label_text.csv'
manual_file = 'data/TAC2017/train_drug_label_text_manual_ades.csv'
# NOTE(review): my_max is not referenced anywhere else in the visible notebook
# (the extraction loop truncates text with a literal 15000) — confirm whether
# it is still needed.
my_max = 10000

In [23]:
# Load the label text and the manual annotations.
drugs = pd.read_csv(drug_file)
manual_ades = pd.read_csv(manual_file)

# The split name ("train"/"test") is the prefix of the file name. Taking it
# from the basename keeps this correct regardless of how many directories
# precede the file (the previous split('/')[2] assumed exactly two).
set_type = os.path.basename(drug_file).split('_')[0] # assuming file follows format "train_..." or "test...."

# Build one extra "all-concat" row per drug by joining the text of all of its
# individual sections with a space, then append those rows to the per-section
# rows. Existing 'all-concat' rows are excluded from the join.
all_sections = (
    drugs.query("section_name != 'all-concat'")
    .groupby('drug_name')['section_text']
    .apply(' '.join)
    .reset_index()
)
all_sections.insert(1, "section_name", ["all-concat" for _ in range(all_sections.shape[0])])
drugs = pd.concat([drugs, all_sections])

## Run Llama

In [24]:
# In-memory cache of results, keyed by run key; the extraction loop checks it
# to resume a run. Re-running this cell empties the cache.
outputs = {}

In [25]:
# Configuration file; it supplies the OpenAI API key used by the clean-up step.
# The context manager closes the file handle once the JSON has been parsed.
with open('./config.json') as config_file:
    config = json.load(config_file)

# gpt_model = 'code-llama-34b'
# model_id = "codellama/CodeLlama-34b-Instruct-hf"
# model_id = "google/gemma-7b"
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Short model name (the part after the organisation prefix), used in output file names.
model_name = model_id.split('/')[1]

# Forwarded to the inference endpoint as the "max_length" parameter.
max_length = 5000

In [26]:
# Number of repeated runs performed by the extraction loop, and the sampling
# temperature forwarded to the model.
nruns = 1
temperature = 0

# Available system prompts, keyed by the short name that ends up in the output file name.
system_options = {
    "no-system-prompt": "",
    "pharmexpert-v0": "You are an expert in pharmacology.",
    "pharmexpert-v1": "You are an expert in medical natural language processing, adverse drug reactions, pharmacology, and clinical trials."
}

# Available user prompts. Each must contain exactly one '{}' placeholder:
# perform_extraction() fills it with the section text via prompt.format(text).
prompt_options = {
    "fatal-prompt-v2": """
Extract all adverse reactions as they appear, including all synonyms.
mentioned in the text and provide them as a comma-separated list.
If a fatal event is listed add 'death' to the list.
The text is :'{}' 
"""
}

# Select the system and user prompt for this experiment.
system_name = "pharmexpert-v0"
system_content = system_options[system_name]

user_prompt_name = "fatal-prompt-v2"
user_prompt = prompt_options[user_prompt_name]

# Parameter tags that become part of the output file name.
gpt_params = [f"temp{temperature}"]

# Models from these organisations get the prompt wrapped in instruction tags;
# every other model receives the system and user prompt joined by a newline.
if model_id.split('/')[0] in ("codellama", "mistralai"):
    print("Modifying the prompt to include instruction tags.")
    prefix = ""
    # NOTE(review): "\\n" inside this f-string yields a literal backslash
    # followed by "n", not a newline character (the else branch below uses a
    # real '\n'). Confirm the endpoint expects escaped newlines; changing it
    # would alter the prompt and therefore the results.
    prompt = f"<s>[INST] <<SYS>>\\n{system_content}\\n<</SYS>>\\n\\n{user_prompt}[/INST]{prefix}"
else:
    prompt = system_content + '\n' + user_prompt

# Base name shared by every result file of this configuration:
# <model>_<user prompt>_<system prompt>_<params>_<split>.
output_file_basename = '{}_{}_{}_{}_{}'.format(model_name, user_prompt_name, system_name, '-'.join(gpt_params), set_type)
output_file_basename

Modifying the prompt to include instruction tags.


'Mixtral-8x7B-Instruct-v0.1_fatal-prompt-v2_pharmexpert-v0_temp0_train'

In [31]:
# run Llama
#
# For every run index, extract ADE terms for each (drug, section) row of
# `drugs`. Rows already present in `outputs` (or in the run's CSV on disk) are
# reused, so an interrupted run can be resumed without repeating finished rows.
result_columns = ['drug_name', 'section_name', 'gpt_output']

# DataFrame.to_csv does not create missing parent directories.
os.makedirs('results/extract', exist_ok=True)

for i in range(nruns):
    run_key = "{}_run{}".format(output_file_basename, i)
    run_path = 'results/extract/{}.csv'.format(run_key)
    print(run_key)
    if run_key in outputs:
        print(f"Run {run_key} already started will pick up from where it was left off.")
    elif os.path.exists(run_path):
        gpt_output = pd.read_csv(run_path)
        outputs[run_key] = gpt_output
        print(f"Run {run_key} started, loading from disk and pick up from where it was left off.")

    # Results of an earlier (possibly partial) attempt at this run, if any.
    previous = outputs.get(run_key)

    start = time.time()
    results = list()
    try:
        for _, row in tqdm(drugs.iterrows(), total=drugs.shape[0]):
            name, section = row['drug_name'], row['section_name']

            if previous is not None:
                # Boolean masks instead of DataFrame.query() with interpolated
                # values: a drug or section name containing a quote would
                # otherwise produce an invalid query expression.
                prev_run_results = previous[
                    (previous['drug_name'] == name) & (previous['section_name'] == section)
                ]
                # Reuse the cached output only when the match is unambiguous.
                if prev_run_results.shape[0] == 1:
                    results.append([name, section, prev_run_results.gpt_output.values[0]])
                    continue

            # Cap the amount of section text sent to the model.
            text = row['section_text'][:15000]
            try:
                gpt_out = extract_ade_terms(config, model_id, prompt, text, temperature, max_length)
            except Exception:
                print(f"Encountered an exception for row: {name} {section}. Error message below:")
                print(f"Failed for prompt: {prompt.format(text)}")
                # Bare raise keeps the original traceback.
                raise
            results.append([name, section, gpt_out])
    except (Exception, KeyboardInterrupt):
        # Checkpoint before propagating, so the rows finished in this session
        # are not lost. Cached rows that were not reached yet are carried over
        # so the checkpoint never holds less than what was already stored.
        checkpoint_rows = list(results)
        if previous is not None:
            finished = {(done[0], done[1]) for done in results}
            checkpoint_rows.extend(
                list(record)
                for record in previous[result_columns].itertuples(index=False, name=None)
                if (record[0], record[1]) not in finished
            )
        if checkpoint_rows:
            outputs[run_key] = pd.DataFrame(checkpoint_rows, columns=result_columns)
            outputs[run_key].to_csv(run_path)
        raise

    gpt_output = pd.DataFrame(results, columns=result_columns)
    end = time.time()

    # Only store and write non-empty results.
    if gpt_output.shape[0] > 0:
        outputs[run_key] = gpt_output
        gpt_output.to_csv(run_path)

    print(f"Run: {run_key}, time elapsed: {end-start}s.")

Mixtral-8x7B-Instruct-v0.1_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0
Run Mixtral-8x7B-Instruct-v0.1_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0 already started will pick up from where it was left off.


100%|██████████| 340/340 [45:24<00:00,  8.01s/it]  

Run: Mixtral-8x7B-Instruct-v0.1_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0, time elapsed: 2724.8942399024963s.





## Evaluation

In [28]:
# Score every run stored in `outputs` against the manual annotations, first
# with strict and then with lenient matching (evaluate comes from
# evaluation_functions).
# NOTE(review): the saved output of this cell lists a code-llama run key, so it
# predates the Mixtral run above — re-run before relying on it.
evaluate(outputs, manual_ades, 'strict')
evaluate(outputs, manual_ades, 'lenient')

Running strict evaluation and saving results to disk.
code-llama-34b_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0


100%|██████████| 101/101 [00:02<00:00, 50.13it/s]


Running lenient evaluation and saving results to disk.
code-llama-34b_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0


100%|██████████| 101/101 [00:15<00:00,  6.50it/s]


In [14]:
# NOTE(review): leftover inspection cell. `text` is only assigned inside the
# extraction loop (the last section processed), so this fails on a fresh
# kernel until that loop has run — consider removing it.
text

