In [6]:
!pip install sentencepiece protobuf datasets transformers trl textstat peft bitsandbytes nltk --quiet
!pip install -U bitsandbytes accelerate --quiet

In [7]:

with open("hf.token", "r") as f:
    hftoken = f.read().strip()  

## import

In [8]:
# Standard library imports
import csv
import json
import os
import random
import re
import sys
import time
from collections import Counter
from collections import defaultdict

# Third-party data and ML libraries
import pandas as pd
import torch
import nltk
from nltk.tokenize import word_tokenize

# Hugging Face ecosystem
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    pipeline,
    logging,
)

# Fine-tuning and optimization
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Text analysis and readability metrics
from textstat import (
    flesch_kincaid_grade, 
    flesch_reading_ease,
    smog_index, 
    gunning_fog, 
    dale_chall_readability_score,
    text_standard, 
    syllable_count
)

# Optional: Uncomment if needed
from huggingface_hub import login
login(token=hftoken)  # Move token to environment variable

  from .autonotebook import tqdm as notebook_tqdm


## load model

In [9]:
# Model Configuration
# model_name = "NousResearch/Llama-2-7b-chat-hf"
# Recommended upgrade - Llama 3.1 8B
model_name = "meta-llama/Llama-3.1-8B-Instruct"

# # Or try Mistral 7B
# model_name = "mistralai/Mistral-7B-Instruct-v0.3"

device_map = {"": 0}

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load model with quantization
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map=device_map,
#     quantization_config=bnb_config
# )

# full precision
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    torch_dtype=torch.float16  # Use half precision instead
)

# # 8 precision
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map=device_map,
#     load_in_8bit=True
# )

Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.42s/it]


## data

In [None]:

def calculate_readability_scores(text, term=None):
    """Calculate readability scores for a piece of text, optionally removing a term"""
    scoring_text = text
    
    # Remove the term and its variations before scoring if provided
    if term:
        # Create variations of the term to remove (capitalized, lowercase)
        term_variations = [term, term.lower(), term.capitalize()]
        
        for variation in term_variations:
            scoring_text = scoring_text.replace(variation, "")
        
        # Clean up any double spaces created
        scoring_text = " ".join(scoring_text.split())
    
    fk_grade = flesch_kincaid_grade(scoring_text)
    readability = flesch_reading_ease(scoring_text)
    return {
        "fk_grade": fk_grade,
        "readability": readability,
    } 
def txt_to_dict(file_path):
    data_dict = {}
    with open(file_path, 'r') as file:
        lines = file.readlines()
        for i in range(0, len(lines) - 1, 2):
            key = lines[i].strip()    # Odd line are key
            value = lines[i + 1].strip()  # Even line are value
            data_dict[key] = value

    return data_dict


txt_file_path = 'formaldef.txt' 
formaldic = txt_to_dict(txt_file_path)
len(formaldic)

meddict={}
for k,v in formaldic.items():
    meddict[k.split('Listen to pronunciation')[0].split('(')[0]]=v

In [12]:
# load data
file_name = "Discharge Sum_Sample_MIMIC.csv"
df = pd.read_csv(file_name)

# promt, single term


chatgpt baseline.

A33
A33 is a protein found on the surface of certain cells in the intestines and some types of cancer. Doctors can target it to help find or treat these cancers. It acts like a flag that helps identify where the cancer cells are.

A6
A6 is a small part of a protein that can help stop cancer from spreading. It works by blocking signals that tell cancer cells to move. Scientists are studying it to see if it can be used as a treatment.

AAP
AAP stands for American Academy of Pediatrics. It’s a group of doctors who focus on keeping children healthy. They create guidelines to help parents and doctors care for kids.

AAT deficiency
AAT deficiency is a condition where the body doesn't make enough of a protein that protects the lungs. Without it, the lungs can get damaged more easily, especially by smoking or pollution. It can also affect the liver in some people.

Abarelix
Abarelix is a medicine that lowers certain hormones in the body. It is mainly used to treat prostate cancer by stopping the cancer from growing. It works by blocking signals from the brain that tell the body to make these hormones.

# prompt, paragraph


In [None]:
# Modular Medical Text Explanation System
# Returns clean dict with configurable steps: extract, generate, clean

import re
import torch
from collections import defaultdict

class MedicalTextExplainer:
    def __init__(self, model, tokenizer, meddict):
        self.model = model
        self.tokenizer = tokenizer
        self.meddict = meddict
        
        # Improved generation config for better outputs
        self.generation_config = {
            "max_new_tokens": 999,              # Increased for fuller explanations
            "temperature": 0.3,
            "top_p": 0.85,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,  # Ensure proper stopping
            "early_stopping": True,              # Stop when EOS token is generated
        }
        
        # Define available steps
        self.available_steps = ['determine_audience', 'extract', 'generate', 'clean'] # 增加新步骤
        self.step_functions = {
            'determine_audience': self._step_determine_audience, # 增加新映射
            'extract': self._step_extract_terms,
            'generate': self._step_generate_explanation,
            'clean': self._step_clean_response
        }
    def create_audience_prompt(self, original_text):
            """
            Creates a simple, direct prompt to classify the audience.
            """
            prompt = f"""
        [INST] You are an expert medical text classifier. Read the following medical text and determine the appropriate audience for a summary letter.

        **Medical Text:**
        "{original_text}"

        **Instructions:**
        Based on the text, who is the audience for the explanation letter?
        - If the text describes a patient recovering, Discharge Condition says much improved, or having a positive or follow ongoing treatment plan, the audience is the **patient**.
        - If the text mentions "died", "passed away," "deceased," or describes a fatal outcome such "comfort care" "hospice care" "pallliative care", "palliative extubate", the audience is the **patient's family**.

        Respond with a single word ONLY: **patient** or **family**.
        [/INST]
        """
            return prompt
    

    def _step_determine_audience(self, result):
        """Step 0: Determine the correct audience (patient or family)."""
        text = result['original_text']
        
        # 1. 创建分类任务的Prompt
        audience_prompt = self.create_audience_prompt(text)
        
        # 2. 调用模型进行分类 (使用一个轻量的生成配置)
        try:
            inputs = self.tokenizer(audience_prompt, return_tensors="pt").to(self.model.device)
            input_length = inputs["input_ids"].shape[1]
            
            # 使用一个非常短的生成配置，因为我们只需要一个词
            audience_config = {
                "max_new_tokens": 5,
                "temperature": 0.01, # 温度极低，追求最确定的答案
                "do_sample": False, # 不进行采样
                "pad_token_id": self.tokenizer.eos_token_id
            }

            with torch.no_grad():
                outputs = self.model.generate(inputs["input_ids"], **audience_config)
            
            response_tokens = outputs[0][input_length:]
            raw_label = self.tokenizer.decode(response_tokens, skip_special_tokens=True).lower().strip()
            
            # 3. 清理和验证标签
            if "family" in raw_label:
                audience = "family"
            else:
                audience = "patient" # 默认为 patient
                
        except Exception as e:
            print(f"   ⚠️ Audience determination failed: {e}. Defaulting to 'patient'.")
            audience = "patient"

        result['determined_audience'] = audience
        print(f"   🎯 Determined audience: '{audience}'")
        return result


    def explain_medical_text(self, text, steps=['extract', 'generate', 'clean']):
        """
        Modular main function that runs configurable steps.
        
        Args:
            text (str): Input medical text
            steps (list): List of steps to run. Options: ['extract', 'generate', 'clean']
                         - 'extract': Extract medical terms from text
                         - 'generate': Generate explanation from model  
                         - 'clean': Clean the raw output
        
        Returns:
            dict: Contains results from each step that was run
        """
        
        # Validate steps
        invalid_steps = [step for step in steps if step not in self.available_steps]
        if invalid_steps:
            raise ValueError(f"Invalid steps: {invalid_steps}. Available steps: {self.available_steps}")
        
        # Initialize result with original text
        result = {
            'original_text': text,
            'steps_run': steps.copy(),
            'available_steps': self.available_steps.copy()
        }
        
        # Run each step in sequence, passing result between steps
        for step in steps:
            print(f"🔄 Running step: {step}")
            result = self.step_functions[step](result)
        
        print(f"✅ Completed {len(steps)} steps: {steps}")
        return result
    
    def _step_extract_terms(self, result):
        """Step 1: Extract medical terms from text."""
        text = result['original_text']
        found_terms = self.extract_medical_terms(text)
        
        result.update({
            'found_terms': found_terms,
            'terms_count': len(found_terms)
        })
        
        print(f"   📋 Found {len(found_terms)} medical terms")
        return result
    
    
    def _step_generate_explanation(self, result):
        """Step 2: Generate explanation using model."""
        text = result['original_text']
        if 'determined_audience' not in result:
            raise ValueError("Audience not determined. Please run 'determine_audience' step before 'generate'.")
        
        audience = result['determined_audience']
        # Use found terms if extract step was run, otherwise extract them now
        if 'found_terms' in result:
            found_terms = result['found_terms']
        else:
            print("   ⚠️  Terms not extracted yet, extracting now...")
            found_terms = self.extract_medical_terms(text)
            result['found_terms'] = found_terms
        
        # Create prompt and generate
        prompt = self.create_prompt(text, found_terms, audience)
        raw_output = self._generate_raw_response(prompt)
        
        result.update({
            'prompt': prompt,
            'raw_output': raw_output,
            'prompt_length': len(prompt),
        })
        
        print(f"   🤖 Generated {len(raw_output)} character response")
        return result
    
    
    def _step_clean_response(self, result):
        """Step 3: Clean the raw model output."""
        
        # Check if we have raw output to clean
        if 'raw_output' not in result:
            raise ValueError("No raw_output found. Must run 'generate' step before 'clean' step.")
        
        raw_output = result['raw_output']
        
        cleaned_output = self.model_clean_response(raw_output, result['determined_audience'])
        
        result.update({
            'cleaned_output': cleaned_output,
            'cleaned_length': len(cleaned_output)
        })
        
        print(f"   🧹 Cleaned to {len(cleaned_output)} characters")
        return result
    
    def extract_medical_terms(self, text):
        """
        Enhanced medical term extraction to catch more terms from complex texts.
        """
        found_terms = {}
        
        # Strategy 1: Single words (improved regex for medical terms)
        words = re.findall(r'\b[A-Za-z]+(?:[-\'][A-Za-z]+)*\b', text)
        for word in words:
            definition = self.find_term_in_dict(word)
            if definition:
                found_terms[word] = definition
        
        # Strategy 2: Multi-word terms (expanded range for complex medical phrases)
        for n in range(2, 6):  # Increased from 5 to 6 for longer medical phrases
            n_grams = self.get_n_grams(text, n)
            for phrase in n_grams:
                definition = self.find_term_in_dict(phrase)
                if definition:
                    found_terms[phrase] = definition
        
        # Strategy 3: Medical abbreviations (enhanced pattern)
        abbreviations = re.findall(r'\b[A-Z]{2,8}\b', text)  # Increased from 6 to 8
        for abbrev in abbreviations:
            definition = self.find_term_in_dict(abbrev)
            if definition:
                found_terms[abbrev] = definition
        
        # Strategy 4: Medical procedures and conditions with specific patterns
        medical_patterns = [
            r'\b\w+oscopy\b',          # bronchoscopy, endoscopy, etc.
            r'\b\w+ectomy\b',          # appendectomy, etc.
            r'\b\w+itis\b',            # bronchitis, arthritis, etc.
            r'\b\w+osis\b',            # fibrosis, stenosis, etc.
            r'\b\w+emia\b',            # anemia, septicemia, etc.
            r'\b\w+pathy\b',           # myopathy, neuropathy, etc.
            r'\b\w+malacia\b',         # tracheomalacia, etc.
        ]
        
        for pattern in medical_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                definition = self.find_term_in_dict(match)
                if definition:
                    found_terms[match] = definition
        
        # Strategy 5: Medication names (often end in specific suffixes)
        medication_patterns = [
            r'\b\w+cillin\b',          # penicillin, amoxicillin, etc.
            r'\b\w+mycin\b',           # streptomycin, etc.
            r'\b\w+floxacin\b',        # levofloxacin, ciprofloxacin, etc.
            r'\b\w+sone\b',            # prednisone, cortisone, etc.
            r'\b\w+pam\b',             # lorazepam, etc.
        ]
        
        for pattern in medication_patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                definition = self.find_term_in_dict(match)
                if definition:
                    found_terms[match] = definition
        
        return found_terms
    
    
    def get_n_grams(self, text, n):
        """Generate n-grams from text."""
        words = re.findall(r'\b[A-Za-z]+\b', text.lower())
        n_grams = []
        for i in range(len(words) - n + 1):
            phrase = ' '.join(words[i:i+n])
            n_grams.append(phrase)
        return n_grams
    
    def find_term_in_dict(self, term):
        """Find term in medical dictionary."""
        search_formats = [
            term, term.lower(), term.upper(), term.title(), term.capitalize()
        ]
        
        for search_term in search_formats:
            if search_term in self.meddict:
                return self.meddict[search_term]
        
        # Partial matching
        for key in self.meddict.keys():
            if key.lower() == term.lower():
                return self.meddict[key]
        
        return None
    
    def create_prompt(self, original_text, found_terms, determined_audience):
        """
        Create the exact prompt that will be sent to the model.
        Enhanced with better instructions for patient-focused explanations.
        """
        
        # Format found terms for inclusion in prompt
        terms_section = ""
        if found_terms:
            terms_section = "Found medical terms and their definitions:\n"
            for term, definition in found_terms.items():
                terms_section += f"- {term}: {definition}\n"
        else:
            terms_section = "No medical terms found in dictionary.\n"
        
        if determined_audience == 'family':
            audience_instruction = "The determined audience for this letter is the **patient's family**. You must address them directly as 'you' and refer to the patient in the third person (e.g., 'your loved one,' 'he/she')."
        else: # patient
            audience_instruction = "The determined audience for this letter is the **patient**. You must address them directly as 'you' throughout the entire letter."

        prompt = f"""
    <s>[INST] You are a medical doctor and educator who writes clear, caring, and situationally appropriate medical explanations. Your persona is consistently that of a medical professional.

    Original medical text:
    "{original_text}"

    {terms_section}

    ### Your Task:
    Your primary goal is to generate a single, complete, and polished letter.

    1.  **Audience Directive:** {audience_instruction} **This is a strict rule.** Never mix the audience.
    2.  Explain all medical content in simple, everyday language, defining every term clearly.
    3.  Organize the letter logically and strive for genuine, varied language. Avoid overusing clichés.
    4.  Maintain your professional persona as a doctor. Use "we" for the medical team. Never use personal terms like "my mother."

    ### Strict Output Formatting:
    * Produce ONLY the letter text.
    * NO meta-commentary, notes, or alternative versions.
    * NO technical markers (`<s>`, `</INST>`, etc.).
    * Start with a natural, professional salutation.

    Please provide the patient-focused explanation now:

    [/INST]
    """
        return prompt
    
    def _generate_raw_response(self, prompt):
        """
        Generate raw response from model (separated from cleaning).
        """
        try:
            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.model.device)
            input_length = inputs["input_ids"].shape[1]
            
            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **self.generation_config
                )
            
            # Extract only the NEW tokens (response), not the input prompt
            response_tokens = outputs[0][input_length:]  # Skip input tokens
            raw_output = self.tokenizer.decode(response_tokens, skip_special_tokens=False)
            
            return raw_output
            
        except Exception as e:
            return f"Error generating response: {str(e)}"
    
    def model_clean_response(self, raw_response, aduience):
        """
        Enhanced cleaning with better instructions for patient-focused, structured output.
        """
        
        # If response is very short or error, return as-is
        if len(raw_response.strip()) < 10 or raw_response.startswith("❌"):
            return raw_response
        
        cleaning_prompt = f"""

        
[INST] You are a compassionate and professional editor. Your task is to revise the following draft of a letter from a doctor to a patient's family.

**Draft Letter:**
"{raw_response}"

**Your Instructions:**
1.  **Fix Audience Inconsistency:** Ensure the entire letter is addressed consistently to a single audience. In this case, the audience is the {aduience}.
2.  **Improve Tone and Word Choice:** Revise the text to make it sound more genuine and less formulaic. Replace clichés (like "despite our best efforts," "with heavy hearts") with more heartfelt and personal language where appropriate.
3.  **Ensure Professional Persona:** Correct any text that breaks the doctor's persona (e.g., remove any personal anecdotes or incorrect points of view).
4.  **Final Output:** Provide only the final, polished letter. Do not include any notes or explanations about your edits.

Please provide the revised letter now:
[/INST]
"""
        
        try:
            inputs = self.tokenizer(cleaning_prompt, return_tensors="pt", truncation=True).to(self.model.device)
            input_length = inputs["input_ids"].shape[1]
            
            # Use enhanced generation config for cleaning
            cleaning_config = {
                "max_new_tokens": 600,  # Increased for more complete cleaning
                "temperature": 0.1,
                "top_p": 0.9,
                "repetition_penalty": 1.1,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
                "early_stopping": True,
            }
            
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    **cleaning_config
                )
            
            # Extract only the cleaned response
            response_tokens = outputs[0][input_length:]
            cleaned = self.tokenizer.decode(response_tokens, skip_special_tokens=True)
            

            
            return cleaned
            
        except Exception as e:
     
            return raw_response
    

In [40]:
explainer = MedicalTextExplainer(model, tokenizer, meddict)

In [41]:
df.iloc[1].TEXT

'Admission Date:  [**2119-5-4**]              Discharge Date:   [**2119-5-25**]\n\n\nService: CARDIOTHORACIC\n\nAllergies:\nAmlodipine\n\nAttending:[**Last Name (NamePattern1) 1561**]\nChief Complaint:\n81 yo F smoker w/ COPD, severe TBM, s/p tracheobronchoplasty [**5-5**]\ns/p perc trach [**5-13**]\n\nMajor Surgical or Invasive Procedure:\nbronchoscopy 3/31,4/2,3,[**6-12**], [**5-17**], [**5-19**]\ns/p trachealplasty [**5-5**]\npercutaneous tracheostomy [**5-13**] after failed extubation\ndown size trach on [**5-25**] to size 6 cuffless\n\n\nHistory of Present Illness:\nThis 81 year old woman has a history of COPD. Over the past five\n\nyears she has had progressive difficulties with her breathing.\nIn\n[**2118-6-4**] she was admitted to [**Hospital1 18**] for respiratory failure\ndue\nto a COPD exacerbation. Due to persistent hypoxemia, she\nrequired\nintubation and a eventual bronchoscopy on [**2118-6-9**] revealed marked\n\nnarrowing of the airways on expiration consistent with\ntr

In [42]:
allres=[]

In [43]:
df.columns

Index(['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CATEGORY',
       'DESCRIPTION', 'TEXT'],
      dtype='object')

In [44]:
for index, row in df.iterrows():
    text=row['TEXT'] 
    result = explainer.explain_medical_text(text, steps=['determine_audience','extract', 'generate','clean'])
    allres.append(result)

🔄 Running step: determine_audience




   🎯 Determined audience: 'patient'
🔄 Running step: extract
   📋 Found 131 medical terms
🔄 Running step: generate




   🤖 Generated 3069 character response
🔄 Running step: clean
   🧹 Cleaned to 3054 characters
✅ Completed 4 steps: ['determine_audience', 'extract', 'generate', 'clean']
🔄 Running step: determine_audience
   🎯 Determined audience: 'patient'
🔄 Running step: extract
   📋 Found 111 medical terms
🔄 Running step: generate
   🤖 Generated 3566 character response
🔄 Running step: clean
   🧹 Cleaned to 2874 characters
✅ Completed 4 steps: ['determine_audience', 'extract', 'generate', 'clean']
🔄 Running step: determine_audience
   🎯 Determined audience: 'patient'
🔄 Running step: extract
   📋 Found 153 medical terms
🔄 Running step: generate
   🤖 Generated 4641 character response
🔄 Running step: clean
   🧹 Cleaned to 3157 characters
✅ Completed 4 steps: ['determine_audience', 'extract', 'generate', 'clean']
🔄 Running step: determine_audience
   🎯 Determined audience: 'family'
🔄 Running step: extract
   📋 Found 167 medical terms
🔄 Running step: generate
   🤖 Generated 2897 character response
🔄 Runnin

In [45]:
results_df = pd.DataFrame(allres)
results_df.columns

Index(['original_text', 'steps_run', 'available_steps', 'determined_audience',
       'found_terms', 'terms_count', 'prompt', 'raw_output', 'prompt_length',
       'cleaned_output', 'cleaned_length'],
      dtype='object')

In [46]:


base_filename = 'output'
extension = '.csv'
output_filename = base_filename + extension
counter = 1

# Check if the file exists and find a new name if it does
while os.path.exists(output_filename):
    output_filename = f"{base_filename}_{counter}{extension}"
    counter += 1

# Save the DataFrame to the new, unique filename
results_df.to_csv(output_filename, index=False)

print(f"DataFrame saved to '{output_filename}'")

DataFrame saved to 'output_4.csv'


In [47]:
results_df.to_csv('output.csv')