In [None]:
import torch
# Set TorchDynamo's `disable` flag before any model code is imported or run.
# NOTE(review): presumably this switches off torch.compile / Dynamo graph
# capture for the whole session — confirm against the installed PyTorch version.
torch._dynamo.config.disable = True

In [None]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import os
import gc
import psutil
import time
import json
import re

# --- Configuration ---
INPUT_CSV_PATH = "training_dataset_distilled_cleaned.csv"
OUTPUT_CSV_PATH = "Datasets-note-validation.csv"
BASE_MODEL = "google/medgemma-27b-text-it"
BATCH_SAVE_SIZE = 20  # Flush results to the output CSV after this many rows
MAX_ROWS_TO_PROCESS = 2000  # Total: 14541
MEMORY_CLEANUP_FREQUENCY = 10  # Run a memory cleanup every this many processed rows
OOM_RETRY_ATTEMPTS = 3  # Generation attempts per row (used as the loop bound, so the first try is included)
OOM_COOLDOWN_SECONDS = 5  # Seconds to sleep before each retry

# --- JSON Validation Function ---
def validate_json_format(text: str) -> str:
    """
    Check that `text` is JSON of the expected shape; describe the problem otherwise.

    Expected shape:
    {
        "relevant_text": [
            "string1",
            "string2",
            ...
        ]
    }

    Args:
        text: Candidate JSON text. Surrounding whitespace is ignored.

    Returns:
        The stripped text when it is valid, or when it already starts with
        "ERROR" (passed through untouched). Otherwise a message starting with
        "ERROR:" that names the violated rule and quotes the first 100
        characters of the input.
    """
    candidate = text.strip()

    # An upstream stage already reported a failure; do not re-validate it.
    if candidate.startswith("ERROR"):
        return candidate

    excerpt = candidate[:100]

    # Only the parse itself can raise; the structural checks below cannot.
    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError as exc:
        return f"ERROR: Invalid JSON format - {str(exc)}: {excerpt}..."
    except Exception as exc:
        return f"ERROR: Unexpected validation error - {str(exc)}: {excerpt}..."

    if not isinstance(payload, dict):
        return f"ERROR: Root must be a JSON object, got {type(payload).__name__}: {excerpt}..."

    key_count = len(payload)
    if key_count != 1:
        return f"ERROR: JSON must have exactly one key 'relevant_text', got {key_count} keys: {excerpt}..."

    if "relevant_text" not in payload:
        return f"ERROR: Missing required key 'relevant_text': {excerpt}..."

    entries = payload["relevant_text"]
    if not isinstance(entries, list):
        return f"ERROR: 'relevant_text' must be an array, got {type(entries).__name__}: {excerpt}..."

    # Report the first non-string element, if any.
    offender = next(
        ((position, entry) for position, entry in enumerate(entries) if not isinstance(entry, str)),
        None,
    )
    if offender is not None:
        position, entry = offender
        return f"ERROR: All items in 'relevant_text' must be strings, item {position} is {type(entry).__name__}: {excerpt}..."

    return candidate

def clean_model_output(raw_output: str) -> str:
    """
    Strip wrapper noise from a model reply and keep only its outermost {...} span.

    Args:
        raw_output: Decoded model text, possibly wrapped in markdown code
            fences or surrounded by extra prose.

    Returns:
        The substring from the first '{' to the last '}' (inclusive) after the
        fence markers are removed. Input that already starts with "ERROR" is
        returned stripped but otherwise unchanged. If no such brace pair
        exists, an "ERROR:" message quoting the first 100 characters.
    """
    body = raw_output.strip()

    # Pass earlier failures through untouched.
    if body.startswith("ERROR"):
        return body

    # Drop markdown fence markers; the longer marker must be removed first.
    for fence in ('```json', '```'):
        body = body.replace(fence, '')
    body = body.strip()

    opening = body.find('{')
    closing = body.rfind('}')

    # A usable span needs an opening brace that precedes the final closing brace.
    if 0 <= opening < closing:
        return body[opening:closing + 1]

    return f"ERROR: No valid JSON structure found in model output: {body[:100]}..."

# --- Memory Management Utilities ---
def get_gpu_memory_usage():
    """
    Report current CUDA memory usage.

    Returns:
        A pair (allocated, reserved): the values of
        torch.cuda.memory_allocated() and torch.cuda.memory_reserved(), each
        divided by 1024**3. (0, 0) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0, 0

    bytes_per_gib = 1024**3
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
    return allocated_gib, reserved_gib

def aggressive_memory_cleanup():
    """
    Run the Python garbage collector, then release cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before torch.cuda.empty_cache() is asked to release
    unused cached blocks. In the opposite order those blocks would stay in
    the cache until the next call.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def print_memory_stats(stage: str, *, verbose: bool = False):
    """
    Optionally print GPU and CPU memory usage for a named pipeline stage.

    Reporting is off by default, matching the previous behaviour where every
    print statement was commented out. With it off, the function returns
    immediately instead of collecting numbers that are never shown.

    Args:
        stage: Label placed in square brackets at the start of each line.
        verbose: When True, print the GPU figures (if CUDA is available) and
            the resident set size of the current process.
    """
    if not verbose:
        return

    if torch.cuda.is_available():
        allocated, reserved = get_gpu_memory_usage()
        print(f"[{stage}] GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

    # CPU memory of this process
    cpu_memory = psutil.Process().memory_info().rss / 1024**3
    print(f"[{stage}] CPU Memory: {cpu_memory:.2f}GB")

# --- 1. Model and Tokenizer Loading ---
# Load the causal LM and its tokenizer once at module level; the functions
# below receive them as `model` / `tokenizer` arguments.
print(f"Loading base model: {BASE_MODEL}...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, 
        torch_dtype=torch.bfloat16, 
        attn_implementation='eager', 
        device_map="auto",
        # low_cpu_mem_usage=True  # reduce CPU memory usage
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    
    # Fall back to the EOS token when the tokenizer defines no padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print("Model and tokenizer loaded successfully.")
    print_memory_stats("After Model Loading")
    
except Exception as e:
    print(f"Error loading model: {e}")
    # NOTE(review): the hint mentions bitsandbytes, but no quantization config
    # is passed to from_pretrained in this block.
    print("Please ensure you have enough VRAM and the necessary libraries (accelerate, bitsandbytes) are installed.")
    # NOTE(review): exit() is called without a status, so a failed load looks
    # like a successful exit to a calling shell — consider a non-zero status.
    exit()

# --- 2. Distillation Prompt Template ---
def get_distillation_prompt(raw_xml: str, human_summary: str) -> str:
    """
    Build the chat prompt for the source-traceability audit.

    The fixed instruction text asks the model to reply with a single JSON
    object whose "relevant_text" array holds verbatim quotes from the patient
    encounter that back claims made in the discharge summary.

    Args:
        raw_xml: Inserted as-is directly before the <Discharge_Summary> block.
            The instructions call it the <PatientEncounter> document
            (presumably the text already carries that tag; verify against
            the dataset).
        human_summary: Text placed between the <Discharge_Summary> tags.

    Returns:
        The full prompt: one user turn followed by an opened model turn.
    """
    # The instruction part is a plain string (it contains literal braces);
    # only the trailing part is an f-string that interpolates the two inputs.
    return """<start_of_turn>user
You are a **Clinical Data Traceability Auditor**. Your sole function is to pinpoint the **direct and explicit source text** from a `<PatientEncounter>` document that directly corresponds to specific claims made in the `<Discharge_Summary>`.

You will be provided with two documents: `<PatientEncounter>` and `<Discharge_Summary>`.

**Core Principle: Directness and Verifiability**

Your primary goal is to establish a clear, verifiable link. For each piece of information in the summary, find its *source* in the encounter.

*   **Example:**
    *   If the summary says "bilateral pneumonia", your target is the exact text confirming this, like `X光顯示肺炎`.
    *   If the summary mentions "leukocytosis (27100)", your targets are `Leukocytes [#/volume] in Blood` and `27100`.

**Output Requirements:**

1.  **Format:** Your output MUST be a single JSON object.
2.  **Structure:** The JSON object must contain one key: `"relevant_text"`. The value for this key must be an array of strings.
3.  **Content:** Each string in the array must be an **exact, verbatim quote** from the `<PatientEncounter>` that serves as the **direct source** for a claim in the summary.
4.  **Special Handling for Lab Results:** For laboratory results found in `<Item>` tags, you must extract the `name` attribute's value as one string, and the tag's inner text (the lab value) as the *next* string in the array. **Do not** combine them.
    *   **Correct:** `"Lactate [Mass/volume] in Serum or Plasma"`, `"93.6"`
    *   **Incorrect:** `"Lactate [Mass/volume] in Serum or Plasma: 93.6"`
5.  **Strict Formatting:** Your entire response MUST begin directly with JSON format `{` and end with `}`. Do not include `json` markers, code fences (```), or any explanatory text.
"""+f"""

{raw_xml}

<Discharge_Summary>
{human_summary}
</Discharge_Summary>

<end_of_turn>
<start_of_turn>model
"""
# *   **AVOID extracting general background information**: The patient's detailed cancer history or past medical conditions should only be extracted if the *specific condition* is explicitly mentioned in the `<Discharge_Summary>`.
    
# --- 3. Core Distillation Function with Memory Management ---
def distill_summary(input_text: str, output_text: str, model, tokenizer) -> str:
    """
    Ask the model which encounter text backs the summary, and validate the reply.

    Builds the audit prompt, samples a completion, strips wrapper noise from
    it and validates the JSON shape. Generation is attempted up to
    OOM_RETRY_ATTEMPTS times; every failed attempt (out-of-memory or any other
    exception) is followed by a memory cleanup, and each retry is preceded by
    a cool-down of OOM_COOLDOWN_SECONDS.

    Args:
        input_text: Patient-encounter document passed to the prompt builder.
        output_text: Discharge summary passed to the prompt builder.
        model: Loaded causal LM; must provide `.device` and `.generate`.
        tokenizer: Tokenizer matching `model`.

    Returns:
        The validated JSON text on success, otherwise a string starting with
        "ERROR:" describing the last failure. Never returns None.
    """
    if not isinstance(input_text, str) or not isinstance(output_text, str):
        return "ERROR: Invalid input type."

    # Returned only if the loop body never runs (OOM_RETRY_ATTEMPTS < 1), so
    # callers that use str methods on the result never receive None.
    failure_message = "ERROR: No generation attempt was made (OOM_RETRY_ATTEMPTS < 1)."

    for attempt in range(OOM_RETRY_ATTEMPTS):
        inputs = outputs = None
        try:
            if attempt > 0:
                # Memory was already cleaned right after the previous failure.
                print(f"  Retry attempt {attempt + 1}/{OOM_RETRY_ATTEMPTS}")
                time.sleep(OOM_COOLDOWN_SECONDS)

            prompt = get_distillation_prompt(input_text, output_text)

            # No gradients are needed for inference.
            with torch.no_grad():
                inputs = tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    padding=False,
                ).to(model.device)
                input_length = inputs["input_ids"].shape[1]

                outputs = model.generate(
                    inputs["input_ids"],
                    # Forward the mask so generate() need not infer it from
                    # the pad token, which may equal the EOS token.
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=4096,  # upper bound on generated tokens
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

                # Decode only the continuation, not the echoed prompt.
                raw_distilled_text = tokenizer.decode(
                    outputs[0, input_length:], skip_special_tokens=True
                )

            # Drop tensor references before asking CUDA to release its cache.
            inputs = outputs = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return validate_json_format(clean_model_output(raw_distilled_text))

        except torch.cuda.OutOfMemoryError as e:
            print(f"  CUDA OOM Error on attempt {attempt + 1}: {e}")
            failure_message = f"ERROR: CUDA OOM after {OOM_RETRY_ATTEMPTS} attempts - {str(e)[:100]}..."

        except Exception as e:
            print(f"  Unexpected error on attempt {attempt + 1}: {e}")
            failure_message = f"ERROR: {str(e)[:200]}..."

        # Recovery runs outside the except clauses on purpose: while a handler
        # is active the exception keeps the failing stack frames (and the
        # tensors they reference) alive, so cleanup there frees very little.
        inputs = outputs = None
        aggressive_memory_cleanup()

    return failure_message

# --- 4. Batch Save Function ---
def save_batch_to_csv(distilled_data: list, output_path: str, is_first_batch: bool = False):
    """
    Write one batch of result rows to the output CSV.

    Args:
        distilled_data: Row dicts to write; an empty batch is a no-op.
        output_path: Destination CSV path.
        is_first_batch: When True the file is (re)created with a header row.
            Otherwise rows are appended, and a header is written only if the
            file does not exist yet.
    """
    if not distilled_data:
        return

    if is_first_batch:
        file_mode, include_header = 'w', True
    else:
        file_mode, include_header = 'a', not os.path.exists(output_path)

    frame = pd.DataFrame(distilled_data)
    frame.to_csv(
        output_path,
        mode=file_mode,
        header=include_header,
        index=False,
        encoding='utf-8',
    )

    print(f"✓ 已保存 {len(distilled_data)} 筆資料到 {output_path}")

    # Release the temporary DataFrame right away.
    del frame
    gc.collect()

# --- 5. Resume Function ---
def get_resume_index(output_path: str) -> int:
    """
    Work out how many rows an earlier run already wrote to the output CSV.

    Args:
        output_path: Path of the (possibly missing) output CSV.

    Returns:
        The number of data rows in the existing file, or 0 when the file is
        missing or cannot be read.
    """
    if os.path.exists(output_path):
        try:
            rows_done = len(pd.read_csv(output_path))
            print(f"發現現有進度檔案，將從第 {rows_done + 1} 筆開始繼續處理")
            gc.collect()  # the temporary DataFrame is no longer referenced
            return rows_done
        except Exception as read_error:
            print(f"讀取現有檔案時發生錯誤: {read_error}")
            print("將從頭開始處理")

    # No usable progress file: start from the beginning.
    return 0

# --- 6. Main Execution Block ---
# Driver: read the input CSV, resume after rows already present in the output
# CSV, run the model row by row, and flush results in batches.
if __name__ == "__main__":
    print(f"Reading data from {INPUT_CSV_PATH}...")
    try:
        df = pd.read_csv(INPUT_CSV_PATH, nrows=MAX_ROWS_TO_PROCESS)

        # Ensure the required columns exist
        if 'input_text' not in df.columns or 'distilled_output' not in df.columns:
            print("Error: CSV must contain 'input_text' and 'distilled_output' columns.")
            exit()

        print(f"Loaded {len(df)} rows for processing")

    except FileNotFoundError:
        print(f"Error: The file {INPUT_CSV_PATH} was not found.")
        exit()

    # Resume support: rows already in the output file are skipped.
    start_index = get_resume_index(OUTPUT_CSV_PATH)

    if start_index > 0:
        df = df.iloc[start_index:].reset_index(drop=True)
        print(f"跳過前 {start_index} 筆已處理的資料，剩餘 {len(df)} 筆待處理")

    # Rows collected since the last flush to disk
    current_batch = []
    total_processed = start_index
    error_count = 0
    json_validation_error_count = 0

    print("Starting distillation process...")
    print_memory_stats("Process Start")

    # Using tqdm for a progress bar
    for _, row in tqdm(df.iterrows(), total=df.shape[0], desc="Distilling Summaries"):
        input_text = row['input_text']
        output_text = row['distilled_output']

        # Periodic memory cleanup
        if total_processed % MEMORY_CLEANUP_FREQUENCY == 0 and total_processed > 0:
            aggressive_memory_cleanup()

        # Only the model call is guarded. Previously the batch save sat inside
        # the same try, so a save failure appended a second (placeholder) row
        # for the same input and advanced the counter twice.
        try:
            distilled_output = distill_summary(input_text, output_text, model, tokenizer)

            if distilled_output.startswith("ERROR:"):
                error_count += 1
                if "JSON" in distilled_output or "format" in distilled_output.lower():
                    json_validation_error_count += 1
                print(f"⚠️  處理失敗: {distilled_output}")

        except Exception as e:
            print(f"處理第 {total_processed + 1} 筆資料時發生未預期錯誤: {e}")
            error_count += 1
            # Placeholder so the output stays row-aligned with the input.
            distilled_output = f"ERROR: Unexpected - {e}"

        current_batch.append({
            'input_text': input_text,
            'distilled_output': output_text,
            'validation': distilled_output
        })
        total_processed += 1

        # Flush once BATCH_SAVE_SIZE rows have accumulated (failed rows too).
        if len(current_batch) >= BATCH_SAVE_SIZE:
            print(f"\n=== 保存批次 (已處理 {total_processed} 筆) ===")
            is_first_batch = (total_processed == BATCH_SAVE_SIZE and start_index == 0)
            save_batch_to_csv(current_batch, OUTPUT_CSV_PATH, is_first_batch)
            current_batch = []
            aggressive_memory_cleanup()
            print_memory_stats("After Batch Save")
            print(f"總錯誤數量: {error_count} (其中JSON格式錯誤: {json_validation_error_count})")

    # Save whatever is left (the last batch may be smaller than BATCH_SAVE_SIZE).
    if current_batch:
        print(f"\n=== 保存最後批次 ({len(current_batch)} 筆) ===")
        is_first_batch = (start_index == 0 and total_processed <= BATCH_SAVE_SIZE)
        save_batch_to_csv(current_batch, OUTPUT_CSV_PATH, is_first_batch)
        aggressive_memory_cleanup()

    print("\n" + "="*50)
    print("Distillation process complete.")
    print(f"Final dataset saved to {OUTPUT_CSV_PATH}")
    print(f"Total processed records: {total_processed}")
    print(f"Total errors: {error_count}")
    print(f"JSON validation errors: {json_validation_error_count}")
    # Guard against an empty input: the rates are undefined with zero rows.
    # NOTE(review): total_processed includes rows from earlier runs, while the
    # error counters cover this run only, so resumed runs overstate the rates.
    if total_processed > 0:
        print(f"Success rate: {((total_processed - error_count) / total_processed * 100):.2f}%")
        print(f"JSON validation success rate: {((total_processed - json_validation_error_count) / total_processed * 100):.2f}%")
    else:
        print("No records were processed; success rates are not available.")
    print_memory_stats("Process Complete")

In [None]:
import pandas as pd

# --- Settings ---
# Path of the CSV produced by the validation run
DISTILLED_CSV_PATH = "kaggle_20250801_validation.csv"
# Number of records to display
NUM_RECORDS_TO_SHOW = 12

# --- Main ---
# Print the 'validation' column of the first NUM_RECORDS_TO_SHOW rows.
print(f"正在讀取檔案: {DISTILLED_CSV_PATH}")

try:
    df = pd.read_csv(DISTILLED_CSV_PATH)

    print(f"檔案讀取成功，共 {len(df)} 筆資料。")
    print(f"以下顯示前 {NUM_RECORDS_TO_SHOW} 筆資料的詳細對比：\n")

    for index, row in df.head(NUM_RECORDS_TO_SHOW).iterrows():
        print(f"==================  資料索引 {index}  ==================")

        # Model output stored for this row by the validation run
        print("\n【蒸餾後摘要 (Distilled Summary)】:")
        print(row['validation'])

        print("\n" + "=" * 60 + "\n")

except FileNotFoundError:
    print(f"錯誤：找不到檔案 '{DISTILLED_CSV_PATH}'。請確認檔案名稱和路徑是否正確。")
except KeyError as e:
    # 'validation' is the only column this cell reads; the message used to
    # name two columns that are never accessed here.
    print(f"錯誤：CSV 檔案中找不到欄位 {e}。請確認你的 CSV 檔案包含 'validation' 欄位。")
except Exception as e:
    print(f"發生未知錯誤: {e}")

In [None]:
#!/usr/bin/env python3
"""
Dataset Cleaner for Emergency Discharge Summaries
Removes rows where distilled_output starts with "ERROR"
"""

import json
import pandas as pd
import os
from typing import Union, List, Dict

def clean_json_dataset(input_file: str, output_file: str = None) -> int:
    """
    Copy a JSON dataset, dropping records whose 'validation' value is an error.

    A record is dropped when its 'validation' entry is a string that, after
    stripping whitespace and ignoring case, starts with "ERROR". Records
    without the key, or with a non-string value, are kept.

    Args:
        input_file: Path to the input JSON file (an array of objects).
        output_file: Destination path; defaults to the input path with
            "_cleaned" inserted before the extension.

    Returns:
        Number of records removed.
    """
    if output_file is None:
        stem, suffix = os.path.splitext(input_file)
        output_file = f"{stem}_cleaned{suffix}"

    with open(input_file, 'r', encoding='utf-8') as source:
        records = json.load(source)

    print(f"Original dataset size: {len(records)} rows")

    kept = []
    for record in records:
        value = record.get('validation', '')
        is_error = isinstance(value, str) and value.strip().upper().startswith('ERROR')
        if not is_error:
            kept.append(record)
            continue
        print(f"Removing row with ERROR: {value[:100]}...")

    # Everything that was not kept was removed.
    removed_count = len(records) - len(kept)

    with open(output_file, 'w', encoding='utf-8') as sink:
        json.dump(kept, sink, indent=2, ensure_ascii=False)

    print(f"Cleaned dataset size: {len(kept)} rows")
    print(f"Removed {removed_count} rows with ERROR")
    print(f"Cleaned dataset saved to: {output_file}")

    return removed_count

def clean_csv_dataset(input_file: str, output_file: str = None) -> int:
    """
    Copy a CSV dataset, dropping rows whose 'validation' value is an error.

    A row is dropped when its 'validation' cell, converted to text, stripped
    and upper-cased, starts with "ERROR".

    Args:
        input_file: Path to the input CSV file.
        output_file: Destination path; defaults to the input path with
            "_cleaned" inserted before the extension.

    Returns:
        Number of rows removed, as a plain Python int. Returns 0 without
        writing any file when the 'validation' column is missing.
    """
    if output_file is None:
        name, ext = os.path.splitext(input_file)
        output_file = f"{name}_cleaned{ext}"

    # Load the dataset
    df = pd.read_csv(input_file)
    original_count = len(df)
    print(f"Original dataset size: {original_count} rows")

    # The filter needs the 'validation' column
    if 'validation' not in df.columns:
        print("Warning: 'validation' column not found!")
        print(f"Available columns: {list(df.columns)}")
        return 0

    # Flag rows whose 'validation' text starts with "ERROR" (case-insensitive)
    error_mask = df['validation'].astype(str).str.strip().str.upper().str.startswith('ERROR')
    # Series.sum() yields a NumPy integer; convert so the declared `int`
    # return type holds (e.g. for json serialisation by callers).
    removed_count = int(error_mask.sum())

    # Show a few examples of what is being removed
    if removed_count > 0:
        print(f"\nExamples of rows being removed:")
        error_rows = df.loc[error_mask, 'validation'].head(3)
        for i, row in enumerate(error_rows):
            print(f"  {i+1}. {str(row)[:100]}...")

    # Keep only the rows that are not flagged
    cleaned_df = df[~error_mask].copy()

    # Save cleaned dataset
    cleaned_df.to_csv(output_file, index=False)

    print(f"\nCleaned dataset size: {len(cleaned_df)} rows")
    print(f"Removed {removed_count} rows with ERROR")
    print(f"Cleaned dataset saved to: {output_file}")

    return removed_count

def clean_dataset_auto(input_file: str, output_file: str = None) -> int:
    """
    Clean a dataset, choosing the JSON or CSV cleaner from the file extension.

    Args:
        input_file: Path to the input file (.json or .csv, case-insensitive).
        output_file: Optional destination path, forwarded to the cleaner.

    Returns:
        Number of rows removed, as reported by the selected cleaner.

    Raises:
        FileNotFoundError: If `input_file` does not exist.
        ValueError: If the extension is neither .json nor .csv.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    ext = os.path.splitext(input_file)[1].lower()

    # Reject unknown formats up front, then dispatch.
    if ext not in ('.json', '.csv'):
        raise ValueError(f"Unsupported file format: {ext}. Please use .json or .csv files.")

    if ext == '.csv':
        return clean_csv_dataset(input_file, output_file)
    return clean_json_dataset(input_file, output_file)

def preview_errors(input_file: str, max_preview: int = 100) -> None:
    """
    Print the rows a cleaning pass would remove, without modifying anything.

    Args:
        input_file: Path to a .json or .csv dataset; any other extension
            prints only the header line.
        max_preview: Maximum number of error rows to print.
    """
    extension = os.path.splitext(input_file)[1].lower()

    print(f"Previewing ERROR rows in {input_file}:")

    if extension == '.csv':
        frame = pd.read_csv(input_file)
        if 'validation' not in frame.columns:
            print("Warning: 'validation' column not found!")
            return

        normalized = frame['validation'].astype(str).str.strip().str.upper()
        flagged = frame[normalized.str.startswith('ERROR')]

        # CSV output lists the total first, then the sample rows.
        print(f"Total ERROR rows found: {len(flagged)}")

        for label, value in flagged['validation'].head(max_preview).items():
            print(f"  Row {label}: {str(value)[:100]}...")
        return

    if extension == '.json':
        with open(input_file, 'r', encoding='utf-8') as source:
            records = json.load(source)

        # JSON output lists the sample rows first, then the total.
        total = 0
        for position, record in enumerate(records):
            value = record.get('validation', '')
            if not (isinstance(value, str) and value.strip().upper().startswith('ERROR')):
                continue
            total += 1
            if total > max_preview:
                continue
            print(f"  Row {position}: {value[:100]}...")

        print(f"\nTotal ERROR rows found: {total}")

# Example usage
# Script entry point: preview the ERROR rows of one dataset, then clean it.
if __name__ == "__main__":
    # Replace with the dataset to clean (.json or .csv)
    input_file = "kaggle_20250801_validation.csv"  # or "your_dataset.csv"

    try:
        # Optional dry run: list what is going to be removed
        print("=== PREVIEW MODE ===")
        preview_errors(input_file)

        print(f"\n{'=' * 50}")

        # Actual cleaning pass
        print("=== CLEANING DATASET ===")
        removed_count = clean_dataset_auto(input_file)

        outcome = (
            f"\n✅ Successfully cleaned dataset! Removed {removed_count} ERROR rows."
            if removed_count > 0
            else "\n✅ Dataset is already clean - no ERROR rows found."
        )
        print(outcome)

    except FileNotFoundError:
        print("❌ Please update the 'input_file' variable with your actual file path.")
    except Exception as e:
        print(f"❌ Error: {str(e)}")

# Quick function for direct use
def quick_clean(file_path: str) -> None:
    """
    Clean a dataset in one call and print the outcome instead of raising.

    Usage:
        quick_clean("my_dataset.json")

    Args:
        file_path: Path handed to clean_dataset_auto.
    """
    try:
        removed = clean_dataset_auto(file_path)
    except Exception as exc:
        outcome = f"❌ Quick clean failed: {str(exc)}"
    else:
        outcome = f"✅ Quick clean completed! Removed {removed} ERROR rows."
    print(outcome)