In [38]:
# Import packages
import json
from huggingface_hub import login
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
import transformers
import random
import torch
import time
import re
from tqdm import tqdm
import pandas as pd

# Set GPUs
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

print("Available GPUs:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

# Set seeds
random.seed(0)
torch.manual_seed(0)


Available GPUs: 2
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB


<torch._C.Generator at 0x7feb2515a650>

In [40]:
# Log into Huggingface
with open("../../huggingface_token.txt", "r") as file:
    access_token = file.read().strip()
login(access_token)

# Load Huggingface Model
# model_name = "GanjinZero/biomistral-7B"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token, low_cpu_mem_usage=True,
#                     torch_dtype=torch.float16, device_map='auto')

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("PharMolix/BioMedGPT-LM-7B")
model = AutoModelForCausalLM.from_pretrained("PharMolix/BioMedGPT-LM-7B", device_map = 'auto')

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████| 3/3 [00:07<00:00,  2.38s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


In [50]:
def get_message(note):
    
    system = 'You are a medical assistant. Your tasks are to generate a clinical note and create an Assessment and Plan section. Additionally, you will also determine the fate of the patient. The patient either survives for the next few years or succumbs to either sudden cardiac death or pump failure death. Only suggest death if there is strong evidence for it. Provide your confidence in survival, sudden cardiac death, and pump failure death such that the confidence percentages add up to 100 percent and format these results in a list. Please output a clinical note that has a section for demographics, medical history, lab results, LVEF, medication, and ECG impressions. In the end, put the Assessment and Plan section along with a prediction. Provide reasoning for the prediction.'

    # patient_info = f"Here is the patient data: \n{note}"

    messages = "<<SYS>>\n" + system + "\n<</SYS>> \n" + "[INST]\n" + note + "\n [/INST]"

    return messages

def extract_assistant_response(response):
    parts = response.split("\n [/INST]", 1)
    return parts[1].strip() if len(parts) > 1 else response

In [51]:
# Load in csv file with prompts
df = pd.read_csv("../Data/subject-info-cleaned-with-prompts.csv")
print(df['Prompts'][0])

Generate a structured clinical note based on the following data:

Age: 58
Gender: Male 
Weight: 83 kg
Height: 163 cm
NYHA Class: III
Blood Pressure: 110/75 mmHg
Past Medical History: Idiopathic dilated cardiomyopathy
Albumin (g/L): 424.0
ALT or GPT (IU/L): 10
AST or GOT (IU/L): 20
Total Cholesterol (mmol/L): 54
Creatinine (mmol/L): 106
Gamma-glutamil transpeptidase (IU/L): 20.0
Glucose (mmol/L): 57.0
Hemoglobin (g/L): 132.0
HDL (mmol/L): 1,29
Potassium (mEq/L): 46.0
LDL (mmol/L): 3,36
Sodium (mEq/L): 141.0
Pro-BNP (ng/L): 1834.0
Protein (g/L): 69.0
T3 (pg/dL): 0,05
T4 (ng/L): 15.0
Troponin (ng/mL): 0,01
TSH (mIU/L): 3,02
Urea (mg/dL): 712
LVEF (%): 35
Medications: Beta Blockers, Digoxin, Loop Diuretics, ACE Inhibitor
Holter rhythm: permanent atrial fibrillation 
ECG Impression:
        - Ventricular Extrasystole: Polymorphic
        - Ventricular Tachycardia: Non-sustained VT
        - Non-sustained ventricular tachycardia (CH>10): Yes
        - Paroxysmal supraventricular tachyarrhyth

In [52]:
message = get_message(df['Prompts'][0])
print(message)

<<SYS>>
You are a medical assistant. Your tasks are to generate a clinical note and create an Assessment and Plan section. Additionally, you will also determine the fate of the patient. The patient either survives for the next few years or succumbs to either sudden cardiac death or pump failure death. Only suggest death if there is strong evidence for it. Provide your confidence in survival, sudden cardiac death, and pump failure death such that the confidence percentages add up to 100 percent and format these results in a list. Please output a clinical note that has a section for demographics, medical history, lab results, LVEF, medication, and ECG impressions. In the end, put the Assessment and Plan section along with a prediction. Provide reasoning for the prediction.
<</SYS>> 
[INST]
Generate a structured clinical note based on the following data:

Age: 58
Gender: Male 
Weight: 83 kg
Height: 163 cm
NYHA Class: III
Blood Pressure: 110/75 mmHg
Past Medical History: Idiopathic dilated

In [None]:
# Put message into LLM
# input_text = tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = True)
input_text = message
inputs = tokenizer(input_text, return_tensors = "pt").to(model.device)
output = model.generate(**inputs, max_new_tokens = 500, temperature = 0.0)



In [33]:
# Get result
result = tokenizer.decode(output[0], skip_special_tokens = True)
result = result.replace("**", "")
result = extract_assistant_response(result)
print(result)

Clinical Note:
The patient is a 58-year-old male with a past medical history of idiopathic dilated cardiomyopathy. He presented to the clinic with complaints of shortness of breath and fatigue. The patient's vital signs were within normal limits. Physical examination revealed a regular heart rate and rhythm with no murmurs, rubs, or gallops. The patient's lungs were clear to auscultation bilaterally. The patient's abdomen was soft and nontender with normal bowel sounds. The patient's extremities were warm and well perfused.
The patient's laboratory results were significant for a low potassium level of 3.2 mmol/L, a high creatinine level of 1.4 mg/dL, and a high pro-BNP level of 1834 pg/mL. The patient's ECG showed a normal sinus rhythm with a heart rate of 70 beats per minute, a left axis deviation, and a left ventricular hypertrophy pattern. The patient's echocardiogram showed a left ventricular ejection fraction of 35%.
The patient's medical history, physical examination, laboratory 

In [32]:
test_result = result.replace("**", "")
print(extract_assistant_response(test_result))

Clinical Note:
The patient is a 58-year-old male with a past medical history of idiopathic dilated cardiomyopathy. He presented to the clinic with complaints of shortness of breath and fatigue. The patient's vital signs were within normal limits. Physical examination revealed a regular heart rate and rhythm with no murmurs, rubs, or gallops. The patient's lungs were clear to auscultation bilaterally. The patient's abdomen was soft and nontender with normal bowel sounds. The patient's extremities were warm and well perfused.
The patient's laboratory results were significant for a low potassium level of 3.2 mmol/L, a high creatinine level of 1.4 mg/dL, and a high pro-BNP level of 1834 pg/mL. The patient's ECG showed a normal sinus rhythm with a heart rate of 70 beats per minute, a left axis deviation, and a left ventricular hypertrophy pattern. The patient's echocardiogram showed a left ventricular ejection fraction of 35%.
The patient's medical history, physical examination, laboratory 

In [34]:
# Test
df_test = df[[df.columns[1]]]
df_test['Reports'] = None


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_test['Reports'] = None


In [37]:
for i in tqdm(range(5), desc = "Test"):
    # Get message
    message = get_message(df['Prompts'][i])

    # Put message into LLM
    # input_text = tokenizer.apply_chat_template(message, tokenize = False, add_generation_prompt = True)
    input_text = message
    inputs = tokenizer(input_text, return_tensors = "pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens = 1000)

    # Get result
    result = tokenizer.decode(output[0], skip_special_tokens = True)
    result = result.replace("**", "")
    result = extract_assistant_response(result)

    # Store result
    df_test.loc[i, 'Reports'] = result
    

Test:   0%|                                                                                           | 0/5 [03:58<?, ?it/s]


KeyboardInterrupt: 

In [40]:
print(df_test['Reports'][1])

Clinical Note

Demographics

- Patient Name: [Not provided]
- Date of Birth: [Not provided]
- Age: 58
- Gender: Male
- Weight: 74 kg
- Height: 160 cm

Medical History

- Ischemic dilated cardiomyopathy
- Dyslipemia
- Myocardial Infarction
- NYHA Class II

Lab Results

- Albumin: 404.0 g/L
- ALT: 20 IU/L
- AST: 20 IU/L
- Total Cholesterol: 618 mmol/L
- Creatinine: 121 mmol/L
- Gamma-glutamil transpeptidase: 44.0 IU/L
- Glucose: 56.0 mmol/L
- Hemoglobin: 126.0 g/L
- HDL: 0.98 mmol/L
- Potassium: 4.6 mEq/L
- LDL: 4.06 mmol/L
- Sodium: 140.0 mEq/L
- Pro-BNP: 570.0 ng/L
- Protein: 75.0 g/L
- T3: 0.04 pg/dL
- T4: 12.0 ng/L
- Troponin: 0.01 ng/mL
- TSH: 3.27 mIU/L
- Urea: 104.7 mg/dL

Medications

- Angiotensin II Receptor Blocker
- Beta Blockers
- Statins

ECG Impressions

- Ventricular Extrasystole: Monomorphic
- Ventricular Tachycardia: No
- Non-sustained ventricular tachycardia (CH>10): No
- Paroxysmal supraventricular tachyarrhythmia: No
- Bradycardia: No

LVEF

- Left Ventricular Ejecti