In [4]:
# Import packages
import json
from huggingface_hub import login
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
import transformers
import random
import torch
import time
import re
from tqdm import tqdm
import pandas as pd

# Set seeds
random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fb6aff4a650>

In [2]:
# Log into Huggingface
with open("../../huggingface_token.txt", "r") as file:
    access_token = file.read().strip()
login(access_token)

# Load Huggingface Model
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token, low_cpu_mem_usage=True,
                    torch_dtype=torch.float16, device_map='auto')

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [01:13<00:00, 18.50s/it]


In [109]:
def get_message(note):
    system = 'You are a medical assistant. Your tasks are to generate a clinical note and create an Assessment and Plan section. Additionally, you will also determine the fate of the patient. The patient either survives for the next few years or succumbs to either sudden cardiac death or pump failure death. Only suggest death if there is strong evidence for it. Provide your confidence in survival, sudden cardiac death, and pump failure death such that the confidence percentages add up to 100 percent and format these results in a list. Please output a clinical note that has a section for demographics, medical history, lab results, LVEF, medication, and ECG impressions. In the end, put the Assessment and Plan section along with a prediction. Provide reasoning for the prediction.'

    prompt = f"Here is the patient data: \n{note}"

    messages = [
		{"role": "system", "content": system},
		{"role": "user", "content": prompt}
	]

    return messages

def extract_assistant_response(response):
    parts = response.split("assistant\n\n", 1)
    return parts[1].strip() if len(parts) > 1 else response

In [110]:
# Load in csv file with prompts
df = pd.read_csv("../Data/subject-info-cleaned-with-prompts.csv")
df['Prompts'][0]

'Generate a structured clinical note based on the following data:\n\nAge: 58\nGender: Male \nWeight: 74 kg\nHeight: 160 cm\nNYHA Class: II\nBlood Pressure: 130/80 mmHg\nPast Medical History: Ischemic dilated cardiomyopathy, Dyslipemia, Myocardial Infarction\nAlbumin (g/L): 404.0\nALT or GPT (IU/L): 20\nAST or GOT (IU/L): 20\nTotal Cholesterol (mmol/L): 618\nCreatinine (mmol/L): 121\nGamma-glutamil transpeptidase (IU/L): 44.0\nGlucose (mmol/L): 56\nHemoglobin (g/L): 126.0\nHDL (mmol/L): 0,98\nPotassium (mEq/L): 46.0\nLDL (mmol/L): 4,06\nSodium (mEq/L): 140.0\nPro-BNP (ng/L): 570.0\nProtein (g/L): 75.0\nT3 (pg/dL): 0,04\nT4 (ng/L): 12.0\nTroponin (ng/mL): 0,01\nTSH (mIU/L): 3,27\nUrea (mg/dL): 1047\nLVEF (%): 35\nMedications: Angiotensin II Receptor Blocker, Beta Blockers, Statins\nECG Impression:\n        - Ventricular Extrasystole: Monomorphic\n        - Ventricular Tachycardia: No\n        - Non-sustained ventricular tachycardia (CH>10): No\n        - Paroxysmal supraventricular tachy

In [111]:
message = get_message(df['Prompts'][0])
message

[{'role': 'system',
  'content': 'You are a medical assistant. Your tasks are to generate a clinical note and create an Assessment and Plan section. Additionally, you will also determine the fate of the patient. The patient either survives for the next few years or succumbs to either sudden cardiac death or pump failure death. Only suggest death if there is strong evidence for it. Provide your confidence in survival, sudden cardiac death, and pump failure death such that the confidence percentages add up to 100 percent and format these results in a list. Please output a clinical note that has a section for demographics, medical history, lab results, LVEF, medication, and ECG impressions. In the end, put the Assessment and Plan section along with a prediction. Provide reasoning for the prediction.'},
 {'role': 'user',
  'content': 'Here is the patient data: \nGenerate a structured clinical note based on the following data:\n\nAge: 58\nGender: Male \nWeight: 74 kg\nHeight: 160 cm\nNYHA C

In [112]:
# Put message into LLM
input_text = tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = True)
inputs = tokenizer(input_text, return_tensors = "pt").to(model.device)
output = model.generate(**inputs, max_new_tokens = 1000)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


In [113]:
# Get result
result

"Clinical Note\n\nDemographics\n\n* Patient Name: (Not provided)\n* Age: 58 years\n* Sex: Male\n* Weight: 74 kg\n* Height: 160 cm\n\nMedical History\n\n* Ischemic dilated cardiomyopathy\n* Dyslipemia\n* Myocardial Infarction\n* NYHA Class II\n\nLab Results\n\n* Albumin: 404.0 g/L\n* ALT: 20 IU/L\n* AST: 20 IU/L\n* Total Cholesterol: 618 mmol/L\n* Creatinine: 121 mmol/L\n* GGT: 44.0 IU/L\n* Glucose: 56 mmol/L\n* Hemoglobin: 126.0 g/L\n* HDL: 0.98 mmol/L\n* Potassium: 4.6 mEq/L\n* LDL: 4.06 mmol/L\n* Sodium: 140.0 mEq/L\n* Pro-BNP: 570.0 ng/L\n* Protein: 75.0 g/L\n* T3: 0.04 pg/dL\n* T4: 12.0 ng/L\n* Troponin: 0.01 ng/mL\n* TSH: 3.27 mIU/L\n* Urea: 104.7 mg/dL\n\nLVEF\n\n* Left Ventricular Ejection Fraction (LVEF): 35%\n\nMedications\n\n* Angiotensin II Receptor Blocker\n* Beta Blockers\n* Statins\n\nECG Impression\n\n* Ventricular Extrasystole: Monomorphic\n* Ventricular Tachycardia: No\n* Non-sustained ventricular tachycardia (CH>10): No\n* Paroxysmal supraventricular tachyarrhythmia: 

In [114]:
test_result = result.replace("**", "")
print(extract_assistant_response(test_result))

Clinical Note

Demographics

* Patient Name: (Not provided)
* Age: 58 years
* Sex: Male
* Weight: 74 kg
* Height: 160 cm

Medical History

* Ischemic dilated cardiomyopathy
* Dyslipemia
* Myocardial Infarction
* NYHA Class II

Lab Results

* Albumin: 404.0 g/L
* ALT: 20 IU/L
* AST: 20 IU/L
* Total Cholesterol: 618 mmol/L
* Creatinine: 121 mmol/L
* GGT: 44.0 IU/L
* Glucose: 56 mmol/L
* Hemoglobin: 126.0 g/L
* HDL: 0.98 mmol/L
* Potassium: 4.6 mEq/L
* LDL: 4.06 mmol/L
* Sodium: 140.0 mEq/L
* Pro-BNP: 570.0 ng/L
* Protein: 75.0 g/L
* T3: 0.04 pg/dL
* T4: 12.0 ng/L
* Troponin: 0.01 ng/mL
* TSH: 3.27 mIU/L
* Urea: 104.7 mg/dL

LVEF

* Left Ventricular Ejection Fraction (LVEF): 35%

Medications

* Angiotensin II Receptor Blocker
* Beta Blockers
* Statins

ECG Impression

* Ventricular Extrasystole: Monomorphic
* Ventricular Tachycardia: No
* Non-sustained ventricular tachycardia (CH>10): No
* Paroxysmal supraventricular tachyarrhythmia: No
* Bradycardia: No

Assessment and Plan

Based on the

In [115]:
# Test
df_test = df[[df.columns[1]]]
df_test['Reports'] = None


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_test['Reports'] = None


In [127]:
for i in tqdm(range(5), desc = "Test"):
    # Get message
    message = get_message(df['Prompts'][i])

    # Put message into LLM
    input_text = tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = True)
    inputs = tokenizer(input_text, return_tensors = "pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens = 1000)

    # Get result
    result = tokenizer.decode(output[0], skip_special_tokens = True)
    result = result.replace("**", "")
    result = extract_assistant_response(result)

    # Store result
    df_test.loc[i, 'Reports'] = result
    

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.                                                                                                              | 0/5 [00:00<?, ?it/s]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.                                                                                                      | 1/5 [00:54<03:36, 54.10s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.████▏                                                                                                 | 2/5 [01:38<02:26, 48.69s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.████████████████████████████████████▊                                                                 | 3/5 [02:22<01:32, 46.45s/it]
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.█████████████████████████████████████████████████████████████████████▍                                | 4/5 

In [126]:
print(df_test['Reports'][4])

Clinical Note

Demographics:

* Age: 58 years
* Gender: Male
* Weight: 74 kg
* Height: 160 cm

Medical History:

* Ischemic dilated cardiomyopathy
* Dyslipemia
* Myocardial Infarction

Lab Results:

* Albumin: 404.0 g/L
* ALT or GPT: 20 IU/L
* AST or GOT: 20 IU/L
* Total Cholesterol: 618 mmol/L
* Creatinine: 1.21 mmol/L
* Gamma-glutamil transpeptidase: 44.0 IU/L
* Glucose: 5.6 mmol/L
* Hemoglobin: 126.0 g/L
* HDL: 0.98 mmol/L
* Potassium: 4.6 mEq/L
* LDL: 4.06 mmol/L
* Sodium: 140.0 mEq/L
* Pro-BNP: 570.0 ng/L
* Protein: 75.0 g/L
* T3: 0.04 pg/dL
* T4: 12.0 ng/L
* Troponin: 0.01 ng/mL
* TSH: 3.27 mIU/L
* Urea: 104.7 mg/dL

LVEF (Left Ventricular Ejection Fraction):

* 35%

Medications:

* Angiotensin II Receptor Blocker
* Beta Blockers
* Statins

ECG Impressions:

* Ventricular Extrasystole: Monomorphic
* Ventricular Tachycardia: No
* Non-sustained ventricular tachycardia (CH>10): No
* Paroxysmal supraventricular tachyarrhythmia: No
* Bradycardia: No

Assessment and Plan:

Based on the