In [1]:
import os
from tqdm import tqdm
import pandas as pd
import csv
import numpy as np
import concurrent
import time
import json
from common_string import common_lenient_performance

from Llamasgard import CodeLlama
from openai import OpenAI, AzureOpenAI
import json

## Set Up

### Functions

In [2]:
def perform_extraction(system_content, prompt, text, temperature):
    """Ask Code Llama to extract adverse event terms from one piece of label text.

    Parameters
    ----------
    system_content : system prompt handed to the ``CodeLlama`` wrapper.
    prompt : template with a single positional ``{}`` placeholder for ``text``.
    text : label text substituted into ``prompt``.
    temperature : sampling temperature forwarded to the wrapper.

    Returns
    -------
    Whatever the ``CodeLlama`` instance returns when called with the filled prompt.
    """
    model = CodeLlama(
        system=system_content,
        temperature=temperature,
        max_new_tokens=2048,
    )
    filled_prompt = prompt.format(text)
    return model(prompt=filled_prompt)

def perform_cleanup(extraction, openai_api):
    """Normalise a raw extraction into a comma separated list via the OpenAI chat API.

    Parameters
    ----------
    extraction : raw text produced by the extraction step.
    openai_api : API key used to construct the ``OpenAI`` client.

    Returns
    -------
    The message content of the first choice in the chat completion.
    """
    client = OpenAI(api_key=openai_api)

    # Adjacent literals keep the trailing space before the line break explicit.
    user_message = (
        "The following text is an extraction of adverse event terms from a drug label. "
        "Please remove any preamble or postamble from the list and turn the list of ADEs "
        "into a comma separated list. \n"
        "The text: {}"
    ).format(extraction)

    conversation = [
        {"role": "system", "content": ""},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        messages=conversation,
        model="gpt-3.5-turbo-16k",
        temperature=0,
    )
    return completion.choices[0].message.content

In [3]:
def extract_ade_terms(config, system_content, prompt, text, temperature):
    """Extract ADE terms from ``text`` and clean them up into a comma separated list.

    Runs ``perform_extraction`` and then passes its result to ``perform_cleanup``,
    using the key stored at ``config['OpenAI']['openai_api_key']``.
    """
    raw_extraction = perform_extraction(system_content, prompt, text, temperature)
    api_key = config['OpenAI']['openai_api_key']
    return perform_cleanup(raw_extraction, api_key)


In [4]:
def evaluation_subtype(manual_ades, gpt_output, drug, section='adverse reactions', subtype = 'all', lenient=False):
    '''
    For a given drug and label section, evaluate the performance of GPT on a given subtype of ADEs.

    Parameters
    ----------
    manual_ades : DataFrame of reference annotations. Columns read here: drug_name,
        section_name, reaction_string, and (depending on ``subtype``) meddra_exact_term,
        negated_term, discontinuous_term.
    gpt_output : DataFrame of model output with columns drug_name, section_name, gpt_output.
        Only the first row matching (drug, section) is evaluated.
    drug : drug name to evaluate.
    section : label section to evaluate; applied to BOTH the reference terms and the
        model output.
    subtype : 'all', 'exact-meddra', 'non-meddra', 'negated' or 'discontinuous'. Any other
        value leaves the reference terms unfiltered.
    lenient : when True, the counts and scores come from ``common_lenient_performance``
        instead of exact set matching.

    Returns
    -------
    list of 11 values:
        [drug, section, subtype, n_manual, n_gpt, TP, FP, FN, precision, recall, f1].
    When there is no model output row for (drug, section), n_gpt is 0 and the six metric
    fields are NaN. For any subtype other than 'all', FP, precision and f1 are NaN.
    '''
    # Boolean masks instead of a formatted .query() string, so names containing quotes work.
    in_scope = (manual_ades['drug_name'] == drug) & (manual_ades['section_name'] == section)
    drug_df = manual_ades[in_scope]
    if subtype == 'exact-meddra':
        drug_df = drug_df[drug_df.meddra_exact_term == 1]
    elif subtype == 'non-meddra':
        drug_df = drug_df[drug_df.meddra_exact_term == 0]
    elif subtype == 'negated':
        drug_df = drug_df[drug_df.negated_term == 1]
    elif subtype == 'discontinuous':
        drug_df = drug_df[drug_df.discontinuous_term == 1]

    manual = set(drug_df['reaction_string'].to_list())

    # The model output must come from the same section as the reference terms.
    model_rows = gpt_output.loc[
        (gpt_output['drug_name'] == drug) & (gpt_output['section_name'] == section),
        'gpt_output',
    ]
    if model_rows.empty:
        # Nothing to score: keep the same 11-field layout as the normal return.
        return [drug, section, subtype, len(manual), 0] + [np.nan] * 6

    raw_output = model_rows.iloc[0]
    if pd.isna(raw_output):
        # A missing cell is an empty extraction, not the literal term 'nan'.
        gpt_drug = set()
    else:
        pieces = str(raw_output).lower().replace('\n-', ', ').split(',')
        # Strip first, then drop blanks, so whitespace-only fragments are not counted as terms.
        gpt_drug = {piece.strip() for piece in pieces} - {''}

    if not lenient:
        # overall, exact string matching
        TP = len(manual & gpt_drug)
        FP = len(gpt_drug - manual)
        FN = len(manual - gpt_drug)
        precision = TP / (TP + FP) if (TP + FP) else np.nan
        recall = TP / (TP + FN) if (TP + FN) else np.nan
        if precision != 0 and recall != 0:
            # NaN inputs propagate to a NaN f1 here.
            f1 = (2 * precision * recall) / (precision + recall)
        else:
            # Kept from the original design: f1 is reported as NaN (not 0) when either score is 0.
            f1 = np.nan
    else:
        [TP, FP, FN, precision, recall, f1] = common_lenient_performance(gpt_drug, manual)

    if subtype != 'all':
        # these can't be computed for the subtypes
        precision = np.nan
        f1 = np.nan
        FP = np.nan

    return [drug, section, subtype, len(manual), len(gpt_drug), TP, FP, FN, precision, recall, f1]


In [5]:
def evaluation(manual_ades, gpt_output, lenient=False, limit = 1000):
    """Score every drug in ``gpt_output`` on the default section and the 'all' subtype.

    Parameters
    ----------
    manual_ades : reference annotations, passed through to ``evaluation_subtype``.
    gpt_output : model output; its unique ``drug_name`` values define the drugs scored.
    lenient : forwarded to ``evaluation_subtype`` as the ``lenient`` keyword.
    limit : maximum number of drugs to score (the first ``limit`` unique names).

    Returns
    -------
    DataFrame with one row per drug and columns
    drug_name, section, exclude, n_manual, n_gpt, tp, fp, fn, precision, recall, f1.
    The ``exclude`` column keeps its historical name and holds the subtype label
    returned by ``evaluation_subtype``.
    """
    drugs = gpt_output['drug_name'].unique()
    results = [
        # ``lenient`` must go by keyword: the fourth positional slot is ``section``.
        evaluation_subtype(manual_ades, gpt_output, drug, lenient=lenient)
        for drug in tqdm(drugs[:limit])
    ]
    # evaluation_subtype returns 11 fields, so 11 column names are required.
    columns = ['drug_name', 'section', 'exclude', 'n_manual', 'n_gpt', 'tp', 'fp', 'fn', 'precision', 'recall', 'f1']
    return pd.DataFrame(results, columns=columns)

In [6]:
def evaluation_granular(manual_ades, gpt_output, limit = 1000, lenient=False):
    """Score every drug for each label section and each ADE subtype.

    Parameters
    ----------
    manual_ades : reference annotations, passed through to ``evaluation_subtype``.
    gpt_output : model output; its unique ``drug_name`` values define the drugs scored.
    limit : maximum number of drugs to score (the first ``limit`` unique names).
    lenient : forwarded to ``evaluation_subtype``.

    Returns
    -------
    DataFrame with one row per (drug, section, subtype) and columns
    drug_name, section, ade_type, n_manual, n_gpt, tp, fp, fn, precision, recall, f1.
    """
    sections = ['adverse reactions', 'warnings and precautions', 'boxed warnings']
    # Labels must match the ones evaluation_subtype filters on ('exact-meddra', not 'exact-medra');
    # an unrecognised label is silently scored against the unfiltered reference terms.
    subtypes = ['all', 'exact-meddra', 'non-meddra', 'negated', 'discontinuous']

    drugs = gpt_output['drug_name'].unique()
    results = []
    for drug in tqdm(drugs[:limit]):
        for section in sections:
            for subtype in subtypes:
                results.append(
                    evaluation_subtype(manual_ades, gpt_output, drug, section, subtype, lenient=lenient)
                )

    columns = ['drug_name', 'section', 'ade_type', 'n_manual', 'n_gpt', 'tp', 'fp', 'fn', 'precision', 'recall', 'f1']
    return pd.DataFrame(results, columns=columns)

### Variables

In [7]:
# CSV with the label section text that is sent to the model (read below via pd.read_csv)
drug_file = 'data/train_drug_label_text.csv'
# CSV with the manually annotated ADE terms used as the reference for evaluation
manual_file = 'data/train_drug_label_text_manual_ades.csv'
# NOTE(review): not referenced by any cell visible in this notebook — confirm whether it is still needed
my_max = 10000

In [8]:
drugs = pd.read_csv(drug_file)
manual_ades = pd.read_csv(manual_file)
# Split label ("train", "test", ...) is the part of the file NAME before the first underscore.
# Taking the basename first makes this independent of how many directories precede the file.
set_type = os.path.basename(drug_file).split('_')[0]

## Run Llama

In [20]:
# In-memory cache of model outputs, keyed by run key; re-running this cell discards it
outputs = {}

In [21]:
# Read the configuration inside a context manager so the file handle is closed right away.
with open('./config.json') as config_file:
    config = json.load(config_file)

# Label for the model; it becomes the first component of every output file name.
gpt_model = 'code-llama-34b'

In [25]:
# How many times the whole extraction is repeated, and the sampling temperature used.
nruns = 1
temperature = 0

# System prompts to choose from; the key is embedded in the output file name.
system_options = {
    "no-system-prompt": "",
    "pharmexpert-v0": "You are an expert in pharmacology.",
    "pharmexpert-v1": (
        "You are an expert in medical natural language processing, "
        "adverse drug reactions, pharmacology, and clinical trials."
    ),
}

# User prompt templates; '{}' is replaced with the label text.
# The wording is part of the experiment definition and is kept exactly as-is.
prompt_options = {
    "fatal-prompt-v2": (
        "\n"
        "Extract all adverse reactions as they appear, including all synonyms.\n"
        "mentioned in the text and provide them as a comma-separated list.\n"
        "If a fatal event is listed add 'death' to the list.\n"
        "The text is :'{}' \n"
    ),
}

system_name = "pharmexpert-v0"
system_content = system_options[system_name]

prompt_name = "fatal-prompt-v2"
prompt = prompt_options[prompt_name]

gpt_params = ["temp{}".format(temperature)]

# <model>_<prompt>_<system>_<params>_<split>
output_file_basename = '_'.join(
    [gpt_model, prompt_name, system_name, '-'.join(gpt_params), set_type]
)
output_file_basename

'code-llama-34b_fatal-prompt-v2_pharmexpert-v0_temp0_train'

In [26]:
# run Llama
# Make sure the save location exists before any model calls are spent.
os.makedirs('results', exist_ok=True)

for i in range(nruns):
    run_key = "{}_run{}".format(output_file_basename, i)
    results_path = 'results/{}.csv'.format(run_key)
    print(run_key)
    if run_key in outputs:
        print(f"Run {run_key} already started will pick up from where it was left off.")
    elif os.path.exists(results_path):
        gpt_output = pd.read_csv(results_path)
        outputs[run_key] = gpt_output
        print(f"Run {run_key} started, loading from disk and pick up from where it was left off.")

    # Index earlier answers by (drug, section) once, instead of building a .query() string
    # per row (which breaks on names containing quotes and rescans the frame every time).
    # As before, an answer is only reused when exactly one earlier row matches.
    previous_answers = {}
    if run_key in outputs:
        earlier = outputs[run_key]
        is_ambiguous = earlier.duplicated(subset=['drug_name', 'section_name'], keep=False)
        unique_rows = earlier[~is_ambiguous]
        previous_answers = {
            (prev_name, prev_section): prev_out
            for prev_name, prev_section, prev_out in zip(
                unique_rows['drug_name'], unique_rows['section_name'], unique_rows['gpt_output']
            )
        }

    start = time.time()
    results = list()
    for _, row in tqdm(drugs.iterrows(), total=drugs.shape[0]):
        name, section = row['drug_name'], row['section_name']

        if (name, section) in previous_answers:
            results.append([name, section, previous_answers[(name, section)]])
            continue

        try:
            # Slicing sits inside the try so a non-string section_text (e.g. a missing value)
            # is reported and skipped like any other per-row failure instead of aborting the run.
            text = row['section_text'][:15000]  # cap the amount of label text sent to the model
            gpt_out = extract_ade_terms(config, system_content, prompt, text, temperature)
            results.append([name, section, gpt_out])
        except Exception as err:
            print(f"Encountered an exception for row: {name} {section}. Error message below:")
            print(err)
            continue

    gpt_output = pd.DataFrame(results, columns=['drug_name', 'section_name', 'gpt_output'])
    end = time.time()

    # Only cache/save when something was produced, so a fully failed run does not
    # overwrite earlier results with an empty file.
    if gpt_output.shape[0] > 0:
        outputs[run_key] = gpt_output
        gpt_output.to_csv(results_path)

    print(f"Run: {run_key}, time elapsed: {end-start}s.")

code-llama-34b_fatal-prompt-v2_pharmexpert-v0_temp0_train_run0


 79%|███████▊  | 188/239 [00:47<00:07,  6.84it/s]

Encountered an exception for row: KYPROLIS adverse reactions. Error message below:
('Connection aborted.', ConnectionResetError(54, 'Connection reset by peer'))
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1564ffee0>: Failed to establish a new connection: [Errno 61] Connection refused'))
Encountered an exception for row: MULTAQ adverse reactions. Error message below:
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1573111e0>: Failed to establish a new connection: [Errno 61] Connection refused'))
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1573137f0>: Failed to establish a new connection: [Errno 61] Connection refused'))
HTTPConn

100%|██████████| 239/239 [00:47<00:00,  5.03it/s]

HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1564ffee0>: Failed to establish a new connection: [Errno 61] Connection refused'))
Encountered an exception for row: ONFI adverse reactions. Error message below:
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x156572800>: Failed to establish a new connection: [Errno 61] Connection refused'))
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predict (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x156570190>: Failed to establish a new connection: [Errno 61] Connection refused'))
Encountered an exception for row: TECFIDERA adverse reactions. Error message below:
HTTPConnectionPool(host='127.0.0.1', port=8000): Max retries exceeded with url: /predic




## Evaluation

In [24]:
def _add_micro_metrics(frame):
    """Add micro precision/recall/F1 columns computed from the frame's OWN tp/fp/fn sums."""
    frame['micro_precision'] = frame['tp'] / (frame['tp'] + frame['fp'])
    frame['micro_recall'] = frame['tp'] / (frame['tp'] + frame['fn'])
    frame['micro_f1'] = (
        (2 * frame['micro_precision'] * frame['micro_recall'])
        / (frame['micro_precision'] + frame['micro_recall'])
    )  # 2*tp_total/(2*tp_total+fp_total+fn_total)
    return frame


for eval_method in ('strict', 'lenient'):
    for run_key, output in outputs.items():
        granular_save_filename = 'results/{}_{}_granular.csv'.format(run_key, eval_method)
        overall_save_filename = 'results/{}_{}_overall.csv'.format(run_key, eval_method)

        print(run_key)

        results_granular = evaluation_granular(manual_ades, output, lenient=(eval_method=='lenient'))

        # Per (section, ade_type): micro scores from summed counts, macro scores from per-drug means.
        by_section = results_granular.groupby(['section', 'ade_type'])
        overall_results = _add_micro_metrics(
            by_section[['tp', 'fp', 'fn']].sum(min_count = 1).reset_index()
        )
        # Both frames come from the same grouping + reset_index, so their row indexes line up.
        macro_results = by_section[['precision', 'recall', 'f1']].mean(numeric_only=True).reset_index()
        overall_results['macro_precision'] = macro_results['precision']
        overall_results['macro_recall'] = macro_results['recall']
        overall_results['macro_f1'] = macro_results['f1']

        # All sections pooled, 'all' subtype only. The micro F1 here must use this frame's own
        # recall; .copy() avoids assigning into a view of the .query() result.
        by_type = results_granular.groupby(['ade_type'])
        allsections_results = _add_micro_metrics(
            by_type[['tp', 'fp', 'fn']].sum(min_count = 1).reset_index().query("ade_type == 'all'").copy()
        )
        allsections_macro_results = (
            by_type[['precision', 'recall', 'f1']].mean(numeric_only=True).reset_index().query("ade_type == 'all'")
        )
        allsections_results['macro_precision'] = allsections_macro_results['precision']
        allsections_results['macro_recall'] = allsections_macro_results['recall']
        allsections_results['macro_f1'] = allsections_macro_results['f1']
        # Scalar assignment also works when the frame has no rows.
        allsections_results['section'] = 'all'

        overall_results = pd.concat([overall_results, allsections_results])
        overall_results.to_csv(overall_save_filename)
        results_granular.to_csv(granular_save_filename)

code-llama-34b_fatal-prompt-v2_pharmexpert-v1_temp0_train_run0


100%|██████████| 101/101 [00:01<00:00, 51.39it/s]


code-llama-34b_fatal-prompt-v2_pharmexpert-v1_temp0_train_run0


100%|██████████| 101/101 [00:15<00:00,  6.72it/s]
