In [1]:
# Import packages
import json
from huggingface_hub import login
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, pipeline
import transformers
import random
import torch
import time
import re
from tqdm import tqdm
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Read in csv file
df = pd.read_csv("../Data/subject-info-cleaned.csv")
df.head()

Unnamed: 0.1,Unnamed: 0,Patient ID,Follow-up period from enrollment (days),days_4years,Exit of the study,Cause of death,Age,Gender (male=1),Weight (kg),Height (cm),...,Angiotensin-II receptor blocker (yes=1),Anticoagulants/antitrombotics (yes=1),Betablockers (yes=1),Digoxin (yes=1),Loop diuretics (yes=1),Spironolactone (yes=1),Statins (yes=1),Hidralazina (yes=1),ACE inhibitor (yes=1),Nitrovasodilator (yes=1)
0,1,P0001,2065,1460,0,0,58,1,83,163,...,0,1,1,1,1,0,0,0,1,0
1,2,P0002,2045,1460,0,0,58,1,74,160,...,1,1,1,0,0,0,1,0,0,0
2,3,P0003,2044,1460,0,0,69,1,83,174,...,1,1,1,1,1,0,0,0,0,0
3,4,P0004,2044,1460,0,0,56,0,84,165,...,1,1,1,0,1,1,0,0,0,0
4,5,P0005,2043,1460,0,0,70,1,97,183,...,0,1,1,0,1,0,1,0,1,1


In [5]:
# Add ECG Impressions to dataframe
impressions = ["Ventricular Extrasystole", "Ventricular Tachycardia", "Non-sustained ventricular tachycardia (CH>10)", 
               "Paroxysmal supraventricular tachyarrhythmia", "Bradycardia"]

# Create ECG reports
class ECGReport:
    def __init__(self, ventricular_extrasystole, ventricular_tachycardia, non_sustained_ventricular_tachycardia, paroxysmal_supraventricular_tachyarrhythmia, bradycardia):
        self.ventricular_extrasystole = ventricular_extrasystole
        self.ventricular_tachycardia = ventricular_tachycardia
        self.non_sustained_ventricular_tachycardia = non_sustained_ventricular_tachycardia
        self.paroxysmal_supraventricular_tachyarrhythmia = paroxysmal_supraventricular_tachyarrhythmia
        self.bradycardia = bradycardia
        
    def interpret_ventricular_extrasystole(self):
        ventricular_extrasystole_dict = {
            0: "No",
            1: "Monomorphic",
            2: "Polymorphic",
            3: "Couplets"
        }
        return ventricular_extrasystole_dict.get(self.ventricular_extrasystole, "Unknown ventricular extrasystole code")
    
    def interpret_ventricular_tachycardia(self):
        ventricular_tachycardia_dict = {
            0: "No",
            1: "Non-sustained VT",
            2: "Sustained VT", 
            3: "Torsade de Points"
        }
        return ventricular_tachycardia_dict.get(self.ventricular_tachycardia, "Unknown ventricular tachycardia code")
    
    def interpret_non_sustained_ventricular_tachycardia(self):
        non_sustained_ventricular_tachycardia_dict = {
            0: "No",
            1: "Yes"
        }
        return non_sustained_ventricular_tachycardia_dict.get(self.non_sustained_ventricular_tachycardia, "Unknown non sustained ventricular tachycardia code")
    
    def interpret_paroxysmal_supraventricular_tachyarrhythmia(self):
        paroxysmal_supraventricular_tachyarrhythmia_dict = {
            0: "No", 
            1: "TPSV", 
            2: "Parosysmal AF", 
            3: "Paroxismal flutter", 
            4: "Others"
        }
        return paroxysmal_supraventricular_tachyarrhythmia_dict.get(self.paroxysmal_supraventricular_tachyarrhythmia, "Unknown paroxysmal supraventricular tachyarrhythmia code")
    
    def interpret_bradycardia(self):
        bradycardia_dict = {
            0: "No",
            1: "Sinus Node Dysfunction", 
            2: "First-degree Atrioventricular block (AVB)",
            3: "Second-degree AVB - type I",
            4: "Second-degree AVB - type II", 
            5: "Third-degree AVB", 
            6: "Paroxysmal AVB"
        }
        return bradycardia_dict.get(self.bradycardia, "Unknown bradycardia code")
    
    def generate_report(self):
        return f"""ECG Impression:
        - Ventricular Extrasystole: {self.interpret_ventricular_extrasystole()}
        - Ventricular Tachycardia: {self.interpret_ventricular_tachycardia()}
        - Non-sustained ventricular tachycardia (CH>10): {self.interpret_non_sustained_ventricular_tachycardia()}
        - Paroxysmal supraventricular tachyarrhythmia: {self.interpret_paroxysmal_supraventricular_tachyarrhythmia()}
        - Bradycardia: {self.interpret_bradycardia()}
            """
            
# Generate ECG impressions for all patients
df['ECG_impressions'] = df.apply(lambda row: ECGReport(
    row["Ventricular Extrasystole"], row["Ventricular Tachycardia"], row["Non-sustained ventricular tachycardia (CH>10)"], row["Paroxysmal supraventricular tachyarrhythmia"], row["Bradycardia"]).generate_report(), axis = 1)

In [6]:
# Test dictionary 
def generate_dictionary(row):
    # Create a dictionary to store non-missing values
    patient_data = {col: row[col] for col in df.columns if pd.notna(row[col])}
    return patient_data
generate_dictionary(df.iloc[0])

{'Unnamed: 0': 1,
 'Patient ID': 'P0001',
 'Follow-up period from enrollment (days)': 2065,
 'days_4years': 1460,
 'Exit of the study': 0,
 'Cause of death': 0,
 'Age': '58',
 'Gender (male=1)': 1,
 'Weight (kg)': 83,
 'Height (cm)': 163,
 'Body Mass Index (Kg/m2)': 312,
 'NYHA class': 3,
 'Diastolic blood  pressure (mmHg)': 75,
 'Systolic blood pressure (mmHg)': 110,
 'HF etiology - Diagnosis': 1,
 'Diabetes (yes=1)': 0,
 'History of dyslipemia (yes=1)': 0,
 'Peripheral vascular disease (yes=1)': 0,
 'History of hypertension (yes=1)': 0,
 'Prior Myocardial Infarction (yes=1)': 0,
 'Prior implantable device': 0,
 'Prior Revascularization': 0,
 'Syncope': 0,
 'daily smoking (cigarretes/day)': 20,
 'smoke-free time (years)': 20,
 'cigarettes /year': 160600,
 'alcohol consumption (standard units)': 0,
 'Albumin (g/L)': 424.0,
 'ALT or GPT (IU/L)': 10,
 'AST or GOT (IU/L)': 20,
 'Normalized Troponin': '1',
 'Total Cholesterol (mmol/L)': 54,
 'Creatinine (?mol/L)': 106,
 'Gamma-glutamil tra

In [7]:
# Test prompt 
def generate_prompt(row):
    # Create a dictionary to store non-missing values
    patient_data = {col: row[col] for col in df.columns if pd.notna(row[col])}

    # Start the prompt
    prompt = "Generate a structured clinical note based on the following data:\n\n"

    # Add demographic information 
    if "Age" in patient_data:
        prompt += f"Age: {patient_data['Age']}\n"
    if "Gender (male=1)" in patient_data:
        if patient_data['Gender (male=1)'] == 1:
            prompt += f"Gender: Male \n"
        elif patient_data['Gender (male=1)'] == 0:
            prompt += f"Gender: Female \n"
    if "Weight (kg)" in patient_data:
        prompt += f"Weight: {patient_data['Weight (kg)']} kg\n"
    if "Height (cm)" in patient_data:
        prompt += f"Height: {patient_data['Height (cm)']} cm\n"

    # Add clinical features
    if "NYHA class" in patient_data:
        if patient_data['NYHA class'] == 2:
            prompt += f"NYHA Class: II\n"
        elif patient_data['NYHA class'] == 3:
            prompt += f"NYHA Class: III\n"
    if ("Systolic blood pressure (mmHg)" in patient_data) and ("Diastolic blood  pressure (mmHg)" in patient_data):
        prompt += f"Blood Pressure: {patient_data['Systolic blood pressure (mmHg)']}/{patient_data['Diastolic blood  pressure (mmHg)']} mmHg\n"

    # Past medical history
    past_medical_conditions = []
    for condition in ["HF etiology - Diagnosis", "Diabetes (yes=1)", "History of dyslipemia (yes=1)", "Peripheral vascular disease (yes=1)",
                     "History of hypertension (yes=1)", "Prior Myocardial Infarction (yes=1)"]:
        if (condition in patient_data) and (condition == "HF etiology - Diagnosis"):
            if patient_data['HF etiology - Diagnosis'] == 1:
                past_medical_conditions.append("Idiopathic dilated cardiomyopathy")
                # prompt += f"HF Etiology: Idiopathic dilated cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 2:
                past_medical_conditions.append("Ischemic dilated cardiomyopathy")
                # prompt += f"HF Etiology: Ischemic dilated cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 3:
                past_medical_conditions.append("Enolic dilated cardiomyopathy")
                # prompt += f"HF Etiology: Enolic dilated cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 4:
                past_medical_conditions.append("Valvular cardiomyopathy")
                # prompt += f"HF Etiology: Valvular cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 5:
                past_medical_conditions.append("Toxic dilated cardiomyopathy")
                #prompt += f"HF Etiology: Toxic dilated cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 6:
                past_medical_conditions.append("Post-myocardial dilated cardiomyopathy")
                # prompt += f"HF Etiology: Post-myocardial dilated cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 7:
                past_medical_conditions.append("Hypertropic cardiomyopathy")
                # prompt += f"HF Etiology: Hypertropic cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 8:
                past_medical_conditions.append("Hypertensive cardiomyopathy")
                # prompt += f"HF Etiology: Hypertensive cardiomyopathy\n"
            if patient_data['HF etiology - Diagnosis'] == 9:
                past_medical_conditions.append("Other HF etiology")
                # prompt += f"HF Etiology: Other\n"
        elif (condition in patient_data) and (condition == "Diabetes (yes=1)"):
            if patient_data['Diabetes (yes=1)'] == 1:
                past_medical_conditions.append("Diabetes")
        elif (condition in patient_data) and (condition == "History of dyslipemia (yes=1)"):
            if patient_data['History of dyslipemia (yes=1)'] == 1:
                past_medical_conditions.append("Dyslipemia")
        elif (condition in patient_data) and (condition == "Peripheral vascular disease (yes=1)"):
            if patient_data['Peripheral vascular disease (yes=1)'] == 1:
                past_medical_conditions.append("Peripheral vascular disease")
        elif (condition in patient_data) and (condition == "History of hypertension (yes=1)"):
            if patient_data['History of hypertension (yes=1)'] == 1:
                past_medical_conditions.append("Hypertension")
        elif (condition in patient_data) and (condition == "Prior Myocardial Infarction (yes=1)"):
            if patient_data['Prior Myocardial Infarction (yes=1)'] == 1:
                past_medical_conditions.append("Myocardial Infarction")
    if past_medical_conditions:
        prompt += "Past Medical History: " + ", ".join(past_medical_conditions) + "\n"
    else:
        prompt += "Past Medical History: None reported.\n"

    # Lab results
    lab_tests = ['Albumin (g/L)', 'ALT or GPT (IU/L)', 'AST or GOT (IU/L)', 'Total Cholesterol (mmol/L)', 'Creatinine (?mol/L)',
                 'Gamma-glutamil transpeptidase (IU/L)', 'Glucose (mmol/L)', 'Hemoglobin (g/L)', 'HDL (mmol/L)', 
                 'Potassium (mEq/L)', 'LDL (mmol/L)', 'Sodium (mEq/L)', 'Pro-BNP (ng/L)',  'Protein (g/L)', 'T3 (pg/dL)', 
                 'T4 (ng/L)', 'Troponin (ng/mL)', 'TSH (mIU/L)', 'Urea (mg/dL)']
    for test in lab_tests:
        # Special case for units of creatining
        if (test in patient_data) and (test == 'Creatinine (?mol/L)'):
            prompt += f"Creatinine (mmol/L): {patient_data[test]}\n"
        elif test in patient_data:
            prompt += f"{test}: {patient_data[test]}\n"
    
    # LVEF
    if "LVEF (%)" in patient_data:
        # prompt += f"LVEF (%): {patient_data["LVEF (%)"]}\n"
        prompt += f"LVEF (%): {patient_data['LVEF (%)']}\n"
        
    # Medication
    current_medications = []
    medications = {
    'Calcium channel blocker (yes=1)': "Calcium Channel Blocker",
    'Diabetes medication (yes=1)': "Diabetes Medication",
    'Amiodarone (yes=1)': "Amiodarone",
    'Angiotensin-II receptor blocker (yes=1)': "Angiotensin II Receptor Blocker",
    'Anticoagulants/antitrombotics (yes=1)': "Anticoagulants/Antithrombotics",
    'Betablockers (yes=1)': "Beta Blockers",
    'Digoxin (yes=1)': "Digoxin",
    'Loop diuretics (yes=1)': "Loop Diuretics",
    'Spironolactone (yes=1)': "Spironolactone",
    'Statins (yes=1)': "Statins",
    'Hidralazina (yes=1)': "Hydralazine",
    'ACE inhibitor (yes=1)': "ACE Inhibitor",
    'Nitrovasodilator (yes=1)': "Nitrovasodilator"
    }
    for key, value in medications.items():
        if (key in patient_data) and (patient_data[key]) == 1:
            current_medications.append(value)
    if current_medications:
        prompt += "Medications: " + ", ".join(current_medications) + "\n"
    else:
        prompt += "Medications: None reported.\n"

    # Holter ECG rhythm
    if "Holter  rhythm" in patient_data:
        if patient_data["Holter  rhythm"] == 0:
            prompt += "Holter rhythm: sinus \n"
        elif patient_data["Holter  rhythm"] == 1:
            prompt += "Holter rhythm: permanent atrial fibrillation \n"
        elif patient_data["Holter  rhythm"] == 2:
            prompt += "Holter rhythm: atrial flutter \n"
        elif patient_data["Holter  rhythm"] == 3:
            prompt += "Holter rhythm: pacemaker \n"
    
    # Holter ECG Features (Impressions only)
    prompt += patient_data['ECG_impressions']
    return prompt
    
print(generate_prompt(df.iloc[1]))


Generate a structured clinical note based on the following data:

Age: 58
Gender: Male 
Weight: 74 kg
Height: 160 cm
NYHA Class: II
Blood Pressure: 130/80 mmHg
Past Medical History: Ischemic dilated cardiomyopathy, Dyslipemia, Myocardial Infarction
Albumin (g/L): 404.0
ALT or GPT (IU/L): 20
AST or GOT (IU/L): 20
Total Cholesterol (mmol/L): 618
Creatinine (mmol/L): 121
Gamma-glutamil transpeptidase (IU/L): 44.0
Glucose (mmol/L): 56.0
Hemoglobin (g/L): 126.0
HDL (mmol/L): 0,98
Potassium (mEq/L): 46.0
LDL (mmol/L): 4,06
Sodium (mEq/L): 140.0
Pro-BNP (ng/L): 570.0
Protein (g/L): 75.0
T3 (pg/dL): 0,04
T4 (ng/L): 12.0
Troponin (ng/mL): 0,01
TSH (mIU/L): 3,27
Urea (mg/dL): 1047
LVEF (%): 35
Medications: Angiotensin II Receptor Blocker, Beta Blockers, Statins
Holter rhythm: sinus 
ECG Impression:
        - Ventricular Extrasystole: Monomorphic
        - Ventricular Tachycardia: No
        - Non-sustained ventricular tachycardia (CH>10): No
        - Paroxysmal supraventricular tachyarrhythmia:

In [8]:
# Create prompt dataframe
df_prompts = df[[df.columns[1]]]
df_prompts['Prompts'] = None
for i in range(len(df_prompts)):
    df_prompts.loc[i, 'Prompts'] = generate_prompt(df.iloc[i])
df_prompts.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_prompts['Prompts'] = None


Unnamed: 0,Patient ID,Prompts
0,P0001,Generate a structured clinical note based on t...
1,P0002,Generate a structured clinical note based on t...
2,P0003,Generate a structured clinical note based on t...
3,P0004,Generate a structured clinical note based on t...
4,P0005,Generate a structured clinical note based on t...


In [9]:
print(df_prompts.iloc[0,1])

Generate a structured clinical note based on the following data:

Age: 58
Gender: Male 
Weight: 83 kg
Height: 163 cm
NYHA Class: III
Blood Pressure: 110/75 mmHg
Past Medical History: Idiopathic dilated cardiomyopathy
Albumin (g/L): 424.0
ALT or GPT (IU/L): 10
AST or GOT (IU/L): 20
Total Cholesterol (mmol/L): 54
Creatinine (mmol/L): 106
Gamma-glutamil transpeptidase (IU/L): 20.0
Glucose (mmol/L): 57.0
Hemoglobin (g/L): 132.0
HDL (mmol/L): 1,29
Potassium (mEq/L): 46.0
LDL (mmol/L): 3,36
Sodium (mEq/L): 141.0
Pro-BNP (ng/L): 1834.0
Protein (g/L): 69.0
T3 (pg/dL): 0,05
T4 (ng/L): 15.0
Troponin (ng/mL): 0,01
TSH (mIU/L): 3,02
Urea (mg/dL): 712
LVEF (%): 35
Medications: Beta Blockers, Digoxin, Loop Diuretics, ACE Inhibitor
Holter rhythm: permanent atrial fibrillation 
ECG Impression:
        - Ventricular Extrasystole: Polymorphic
        - Ventricular Tachycardia: Non-sustained VT
        - Non-sustained ventricular tachycardia (CH>10): Yes
        - Paroxysmal supraventricular tachyarrhyth

In [10]:
# Save results
df_prompts.to_csv("../Data/subject-info-cleaned-with-prompts_BioMistral.csv")