# In [None]:
import pandas as pd
import json
import time
import os
import re
from tqdm import tqdm
from openai import OpenAI, RateLimitError, BadRequestError, APIError
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import sys

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Source CSV. process_mimic() requires the CATEGORY and TEXT columns and
# records ROW_ID in the output metadata when that column is present.
INPUT_CSV = "NOTEEVENTS.csv"
# JSONL output: one {"input", "output", "metadata"} object per line.
OUTPUT_FILE = "mimic_dialogue_soap.jsonl"
# Plain-text file holding a single integer: the number of samples saved so far.
CHECKPOINT_FILE = "progress.checkpoint"

# process_mimic() stops once this many samples have been saved.
MAX_SAMPLES = 100   # set None for full run
# Number of CSV rows per chunk passed to pd.read_csv(chunksize=...).
CHUNK_SIZE = 1000

# Only rows whose CATEGORY is exactly one of these values are processed.
# NOTE(review): the match is exact (Series.isin) -- confirm these strings match
# the CATEGORY values in the CSV, including case and surrounding whitespace.
TARGET_CATEGORIES = ["Progress Note", "Discharge summary"]

# Model used for note -> SOAP extraction (JSON mode).
EXTRACTOR_MODEL = "gpt-4.1-mini"
# Model used for SOAP -> dialogue generation (free text).
SIMULATOR_MODEL = "gpt-4o-mini"

# Retry configuration
# Default number of validation-level attempts in extract_soap_from_note() and
# generate_dialogue_from_soap(). These are separate from the transport-level
# retries (stop_after_attempt(5)) applied by the decorator on
# call_llm_with_retry().
MAX_SOAP_RETRIES = 3
MAX_DIALOGUE_RETRIES = 3

# ==========================================
# 2. CLIENT SETUP
# ==========================================
# Module-level OpenAI client shared by every LLM call in this file. The API key
# is read from the OPENAI_API_KEY environment variable (os.getenv returns None
# when it is unset).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# ==========================================
# 3. PROMPTS
# ==========================================
# System prompt for the extraction step (used by extract_soap_from_note).
# It asks for a JSON object with the four SOAP keys, each holding a string.
# NOTE(review): the prompt tells the model to answer "Not documented" for a
# missing section, while validate_soap_structure() treats "not documented" as
# an empty placeholder and rejects it -- confirm which behavior is intended.
SYSTEM_PROMPT_EXTRACTOR = """
You are an expert Clinical Documentation Improvement (CDI) specialist.

Extract factual clinical information from the note into SOAP format.

Return STRICT JSON with exactly these 4 keys (no others):
- Subjective
- Objective
- Assessment
- Plan

Rules:
- Do NOT summarize or infer beyond the text
- Copy facts verbatim where possible
- If a section is missing, return "Not documented"
- Preserve medical abbreviations and terminology
- Include vital signs and lab values with units
- Each value must be a string (not array or object)
"""

# System prompt for the dialogue step (used by generate_dialogue_from_soap).
# The "Doctor:" / "Patient:" line prefixes it demands are what
# validate_dialogue_structure() checks for.
# NOTE(review): the characters between "doctor" and "patient" in the first
# sentence look like mis-decoded UTF-8 (probably an en dash) -- confirm the
# file encoding, since this text is sent to the model verbatim.
SYSTEM_PROMPT_SIMULATOR = """
You are generating a realistic doctor‚Äìpatient conversation for medical training.

Rules:
1. Patient describes symptoms from Subjective section
2. Doctor asks clarifying questions based on Objective findings
3. Doctor explains Assessment and Plan at the end
4. Use 6-12 conversation turns total
5. EVERY line must start with exactly "Doctor:" or "Patient:" followed by a space
6. Make dialogue natural with pauses, acknowledgments
7. End with doctor summarizing next steps clearly

Example format:
Doctor: Good morning, how are you feeling today?
Patient: I've been having chest pain since yesterday.
Doctor: Can you describe the pain for me?
... (more turns)
"""

# ==========================================
# 4. HELPER FUNCTIONS
# ==========================================
def clean_clinical_text(text):
    """Normalize a raw note: collapse whitespace and mask redaction tags.

    Anything that is not a ``str`` yields an empty string. Otherwise every run
    of whitespace becomes a single space, each ``[** ... **]`` span is replaced
    by the literal ``[REDACTED]``, and the result is stripped at both ends.
    """
    if isinstance(text, str):
        # Whitespace is collapsed first, so a tag that originally spanned
        # several lines sits on one line by the time it is matched.
        collapsed = re.sub(r'\s+', ' ', text)
        masked = re.sub(r'\[\*\*.*?\*\*\]', '[REDACTED]', collapsed)
        return masked.strip()
    return ""

def validate_soap_structure(soap_dict):
    """Check that a parsed SOAP object is complete and substantive.

    Returns a ``(is_valid, message)`` tuple. The input passes when it is a
    dict containing the four SOAP keys (matched case-insensitively), every one
    of those values is a string that is not an empty/placeholder marker, and
    the four values together are at least 200 characters long.
    """
    if not isinstance(soap_dict, dict):
        return False, "Not a dictionary"

    required_keys = {"subjective", "objective", "assessment", "plan"}

    # Lower-cased key -> key exactly as spelled in the input.
    key_map = {}
    for spelled in soap_dict:
        key_map[spelled.lower()] = spelled

    missing_keys = required_keys - set(key_map)
    if missing_keys:
        return False, f"Missing keys: {missing_keys}"

    placeholders = ("", "n/a", "none", "not documented", "not available", "na", "[]", "{}")
    total_length = 0

    for req_key in required_keys:
        original_key = key_map[req_key]
        value = soap_dict[original_key]

        if not isinstance(value, str):
            return False, f"{original_key} is not a string"

        if value.strip().lower() in placeholders:
            return False, f"{original_key} is empty or placeholder"

        total_length += len(value)

    # Reject responses too thin to be a useful training target.
    if total_length < 200:
        return False, f"Insufficient content length: {total_length}"

    return True, "Valid"

def validate_dialogue_structure(dialogue_text):
    """Check that generated text is a well-formed Doctor/Patient dialogue.

    Returns a ``(is_valid, message)`` tuple. Blank lines are ignored; every
    other line must begin with ``Doctor:`` or ``Patient:`` and carry text after
    the prefix. At least two turns per speaker and 150 characters overall are
    required.
    """
    if not (dialogue_text and isinstance(dialogue_text, str)):
        return False, "No dialogue text"

    raw_lines = dialogue_text.strip().split('\n')
    if len(raw_lines) < 4:  # fewer than four lines cannot hold two turns each
        return False, f"Too few lines: {len(raw_lines)}"

    speakers = (("Doctor:", "doctor"), ("Patient:", "patient"))
    turn_counts = {"doctor": 0, "patient": 0}
    problems = []

    for idx, raw in enumerate(raw_lines):
        stripped = raw.strip()
        if not stripped:
            continue

        for prefix, label in speakers:
            if stripped.startswith(prefix):
                turn_counts[label] += 1
                # A bare prefix with nothing after it is an empty turn.
                if len(stripped) <= len(prefix):
                    problems.append(f"Line {idx}: Empty {label} line")
                break
        else:
            problems.append(f"Line {idx}: Doesn't start with Doctor:/Patient:")

    doctor_count = turn_counts["doctor"]
    patient_count = turn_counts["patient"]

    if doctor_count < 2:
        return False, f"Insufficient doctor turns: {doctor_count}"
    if patient_count < 2:
        return False, f"Insufficient patient turns: {patient_count}"
    if problems:
        return False, f"Malformed lines: {problems[:3]}"

    if len(dialogue_text) < 150:
        return False, f"Dialogue too short: {len(dialogue_text)} chars"

    return True, f"Valid: {doctor_count} doctor, {patient_count} patient turns"

def extract_relevant_text_chunk(text, max_chars=4000):
    """Return at most ``max_chars`` characters taken from the middle of ``text``.

    Text that already fits is returned unchanged. Longer text is cut to a
    window centred on its midpoint, which drops the leading and trailing
    portions (typically headers and signatures).
    """
    total = len(text)
    if total > max_chars:
        begin = max(0, total // 2 - max_chars // 2)
        return text[begin:begin + max_chars]
    return text

def load_checkpoint():
    """Return the sample count stored in CHECKPOINT_FILE, or 0 if unavailable.

    Returns:
        The integer written by ``save_checkpoint``; 0 when the file does not
        exist, cannot be read, or does not contain an integer.
    """
    try:
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return int(f.read().strip())
    except FileNotFoundError:
        # No checkpoint yet: a fresh run.
        return 0
    except (OSError, ValueError) as e:
        # Unreadable or corrupt checkpoint: start from zero, but say so.
        # Catching only these types (rather than a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate.
        print(f"Ignoring unreadable checkpoint {CHECKPOINT_FILE}: {e}")
        return 0

def save_checkpoint(processed_count):
    """Persist ``processed_count`` to CHECKPOINT_FILE atomically.

    The value is written to a sibling temporary file which is then swapped
    into place. Opening CHECKPOINT_FILE directly with ``"w"`` would truncate
    it first, so an interruption between truncate and write would leave an
    empty checkpoint that the next run reads as 0.

    Args:
        processed_count: Number of samples saved so far; stored as ``str(...)``.
    """
    tmp_path = f"{CHECKPOINT_FILE}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(str(processed_count))
    # os.replace overwrites the destination in a single rename step.
    os.replace(tmp_path, CHECKPOINT_FILE)

# ==========================================
# 5. LLM HELPERS WITH RETRY & VALIDATION
# ==========================================
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type((RateLimitError, APIError)),
    reraise=True
)
def call_llm_with_retry(system_prompt, user_prompt, model, json_mode=False):
    """Send one system+user exchange to the Responses API and return its text.

    The decorator retries up to 5 attempts with exponential backoff when a
    ``RateLimitError`` or ``APIError`` escapes this function.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Model name passed to the API.
        json_mode: When true, request a JSON-object response and allow a
            larger output budget (2000 tokens instead of 1500).

    Returns:
        The response text, or ``None`` when the API rejects the request with
        ``BadRequestError`` (not retried).

    Raises:
        Any other exception from the client, after the decorator's retries
        are exhausted for the retryable types.
    """
    # The Responses API takes ``max_output_tokens`` and a ``text.format``
    # setting; ``max_tokens`` / ``response_format`` belong to Chat Completions
    # and are not accepted by ``responses.create``.
    request_kwargs = {
        "model": model,
        "input": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.6,
        "max_output_tokens": 2000 if json_mode else 1500,
    }
    if json_mode:
        # Added only when needed, so plain-text calls send no format override.
        request_kwargs["text"] = {"format": {"type": "json_object"}}

    try:
        response = client.responses.create(**request_kwargs)
        return response.output_text
    except BadRequestError as e:
        # Don't retry on bad requests (content filter, invalid input)
        print(f"‚ùå Bad request (won't retry): {e}")
        return None
    except Exception as e:
        print(f"‚ö†Ô∏è API error: {e}")
        raise  # Let tenacity handle retryable errors

def extract_soap_from_note(note_text, max_retries=MAX_SOAP_RETRIES):
    """Extract a SOAP dict from a note, re-prompting when the reply is invalid.

    Each attempt makes exactly one LLM call. After a reply fails validation,
    the following attempt uses a stricter system and user prompt. On the final
    attempt, a reply that is a dict with at least two of the four SOAP keys is
    salvaged: missing sections are filled with "Not documented" and present
    values are passed through as returned by the model.

    Args:
        note_text: Cleaned clinical note text to send to the model.
        max_retries: Maximum number of attempts.

    Returns:
        ``(soap, error)`` where ``soap`` is a dict with the keys Subjective,
        Objective, Assessment and Plan, or ``None`` on failure. ``error`` is
        ``None`` for a fully valid result, ``"Partially salvaged"`` for a
        salvaged one, and a failure description otherwise.
    """
    canonical_keys = ("Subjective", "Objective", "Assessment", "Plan")

    # Prompts for the next call; tightened after each validation failure.
    system_prompt = SYSTEM_PROMPT_EXTRACTOR
    user_prompt = f"Extract SOAP from this clinical note:\n\n{note_text}"

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            soap_str = call_llm_with_retry(
                system_prompt,
                user_prompt,
                model=EXTRACTOR_MODEL,
                json_mode=True
            )

            if not soap_str:
                if not is_last_attempt:
                    print(f"‚ö†Ô∏è Empty SOAP response, retry {attempt + 1}/{max_retries}")
                    continue
                return None, "Empty response"

            soap = json.loads(soap_str)
            is_valid, message = validate_soap_structure(soap)

            # Valid JSON need not be an object; an empty map then disables
            # salvage instead of raising on ``.keys()``.
            key_map = {k.lower(): k for k in soap} if isinstance(soap, dict) else {}

            if is_valid:
                # Standardize key spelling.
                return {key: soap[key_map[key.lower()]] for key in canonical_keys}, None

            print(f"‚ö†Ô∏è Invalid SOAP (attempt {attempt + 1}/{max_retries}): {message}")

            if is_last_attempt:
                # Out of attempts: keep whatever sections the model did return.
                salvageable_keys = {"subjective", "objective", "assessment", "plan"} & set(key_map.keys())
                if len(salvageable_keys) >= 2:
                    print(f"‚ö†Ô∏è Salvaging partial SOAP with keys: {salvageable_keys}")
                    soap_standard = {}
                    for key in canonical_keys:
                        if key.lower() in key_map:
                            soap_standard[key] = soap[key_map[key.lower()]]
                        else:
                            soap_standard[key] = "Not documented"
                    return soap_standard, "Partially salvaged"
            else:
                # Escalate the guidance for the NEXT attempt. Issuing a second
                # call here would be discarded when the loop calls again.
                if attempt == 0:
                    guidance = " Remember: return JSON with exactly 4 keys: Subjective, Objective, Assessment, Plan."
                else:
                    guidance = " Each value must be a non-empty string. No extra keys."
                system_prompt = SYSTEM_PROMPT_EXTRACTOR + guidance
                user_prompt = f"Extract SOAP from this clinical note. IMPORTANT: Return JSON with exactly 4 keys (Subjective, Objective, Assessment, Plan). Each value must be a non-empty string.\n\n{note_text}"

        except json.JSONDecodeError as e:
            print(f"‚ö†Ô∏è JSON decode error (attempt {attempt + 1}/{max_retries}): {e}")
            if not is_last_attempt:
                time.sleep(1)
                continue
            return None, f"JSON decode error: {e}"
        except Exception as e:
            print(f"‚ö†Ô∏è Unexpected error in SOAP extraction (attempt {attempt + 1}/{max_retries}): {e}")
            if not is_last_attempt:
                time.sleep(2 ** attempt)  # Exponential backoff
                continue
            return None, f"Unexpected error: {e}"

    return None, f"Failed after {max_retries} attempts"

def generate_dialogue_from_soap(soap_dict, max_retries=MAX_DIALOGUE_RETRIES):
    """Generate a Doctor/Patient dialogue from SOAP facts, re-prompting on bad format.

    Each attempt makes exactly one LLM call. After a reply fails
    ``validate_dialogue_structure``, later attempts use a prompt that restates
    the line-prefix requirement.

    Args:
        soap_dict: SOAP sections to base the dialogue on; serialized to JSON
            and embedded in the user prompt.
        max_retries: Maximum number of attempts.

    Returns:
        ``(dialogue, note)``. On first-attempt success ``note`` is ``None``.
        When a later attempt succeeds, ``note`` is a string containing the
        word "retry", which the caller in this file uses to count retry
        successes and tag the sample. On failure ``dialogue`` is ``None`` and
        ``note`` describes the failure.
    """
    soap_json_str = json.dumps(soap_dict, indent=2)

    # Prompts for the next call; tightened after a validation failure.
    system_prompt = SYSTEM_PROMPT_SIMULATOR
    user_prompt = f"Create a doctor-patient dialogue based on these clinical facts:\n{soap_json_str}"

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            dialogue = call_llm_with_retry(
                system_prompt,
                user_prompt,
                model=SIMULATOR_MODEL
            )

            if not dialogue:
                if not is_last_attempt:
                    print(f"‚ö†Ô∏è Empty dialogue response, retry {attempt + 1}/{max_retries}")
                    continue
                return None, "Empty response"

            is_valid, message = validate_dialogue_structure(dialogue)

            if is_valid:
                if attempt == 0:
                    return dialogue, None
                return dialogue, f"Succeeded on retry {attempt}"

            print(f"‚ö†Ô∏è Invalid dialogue (attempt {attempt + 1}/{max_retries}): {message}")

            if not is_last_attempt:
                # Tighten the prompts for the NEXT attempt. Issuing a second
                # call here would be discarded when the loop calls again.
                guidance = "\n\nIMPORTANT FORMAT: Every line MUST start with exactly 'Doctor:' or 'Patient:' followed by a space. No other prefixes allowed."
                system_prompt = SYSTEM_PROMPT_SIMULATOR + guidance
                user_prompt = f"Create a doctor-patient dialogue based on these clinical facts. CRITICAL: Every line must start with 'Doctor:' or 'Patient:' followed by a space.\n\n{soap_json_str}"

        except Exception as e:
            print(f"‚ö†Ô∏è Unexpected error in dialogue generation (attempt {attempt + 1}/{max_retries}): {e}")
            if not is_last_attempt:
                time.sleep(2 ** attempt)  # Exponential backoff
                continue

    return None, f"Failed after {max_retries} attempts"

# ==========================================
# 6. CORE PIPELINE
# ==========================================
def _load_saved_row_ids(path):
    """Return ``(entry_count, row_ids)`` for the samples already written to ``path``.

    ``entry_count`` is the number of parseable JSON lines; ``row_ids`` is the
    set of non-null ``metadata.row_id`` values found in them. A missing file
    yields ``(0, set())``. Lines that are blank or not valid JSON are ignored.
    """
    entry_count = 0
    row_ids = set()
    try:
        with open(path, "r", encoding="utf-8") as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    # e.g. a line cut short by an interrupted write
                    continue
                entry_count += 1
                metadata = entry.get("metadata") if isinstance(entry, dict) else None
                row_id = metadata.get("row_id") if isinstance(metadata, dict) else None
                if row_id is not None:
                    row_ids.add(row_id)
    except FileNotFoundError:
        pass
    return entry_count, row_ids


def process_mimic():
    """Run the note -> SOAP -> dialogue pipeline and append samples to OUTPUT_FILE.

    Reads INPUT_CSV in chunks, keeps rows whose CATEGORY is in
    TARGET_CATEGORIES, skips notes shorter than 300 characters after cleaning,
    and writes one JSON line per note for which both LLM steps succeed.

    Resuming: when ``load_checkpoint()`` returns a positive value, OUTPUT_FILE
    is opened in append mode and scanned first. ``stats["saved"]`` starts at
    the number of entries already present, and rows whose ROW_ID is already in
    the file are skipped. Notes that failed in an earlier run are attempted
    again.

    Returns:
        The stats dict (also on interrupt or unexpected error), or ``None``
        when INPUT_CSV is missing, unreadable, or lacks the required columns.
    """
    # Validate input file and columns
    if not os.path.exists(INPUT_CSV):
        print(f"‚ùå {INPUT_CSV} not found")
        return None

    try:
        df_sample = pd.read_csv(INPUT_CSV, nrows=1)
        required_cols = ["CATEGORY", "TEXT"]
        if not all(col in df_sample.columns for col in required_cols):
            print(f"‚ùå Missing required columns. Found: {df_sample.columns.tolist()}")
            print(f"   Required: {required_cols}")
            return None
    except Exception as e:
        print(f"‚ùå Error reading CSV: {e}")
        return None

    has_row_id = "ROW_ID" in df_sample.columns

    stats = {
        "total_processed": 0,
        "skipped_short": 0,
        "soap_fail": 0,
        "soap_retry_success": 0,
        "dialogue_fail": 0,
        "dialogue_retry_success": 0,
        "saved": 0,
        "soap_validation_errors": {},
        "dialogue_validation_errors": {}
    }

    # Load checkpoint
    checkpoint = load_checkpoint()
    resuming = checkpoint > 0
    saved_row_ids = set()
    if resuming:
        print(f"üîÑ Resuming from checkpoint: {checkpoint} samples processed")
        # The output file, not the checkpoint number, is the source of truth:
        # the checkpoint is only written every 5 samples and can lag behind.
        existing_count, saved_row_ids = _load_saved_row_ids(OUTPUT_FILE)
        stats["saved"] = existing_count
        print(f"Found {existing_count} existing samples in {OUTPUT_FILE}")
        if existing_count > len(saved_row_ids):
            print("Some existing samples carry no row_id; those notes cannot be de-duplicated")

    try:
        with open(OUTPUT_FILE, "a" if resuming else "w", encoding="utf-8") as fout:
            chunk_iterator = pd.read_csv(INPUT_CSV, chunksize=CHUNK_SIZE)

            for chunk_idx, chunk in enumerate(chunk_iterator):
                print(f"\nüìÇ Processing chunk {chunk_idx + 1}...")

                # Filter by category
                df = chunk[chunk["CATEGORY"].isin(TARGET_CATEGORIES)]
                if df.empty:
                    continue

                for _, row in tqdm(df.iterrows(), total=len(df), desc="Notes"):
                    # Stop if MAX_SAMPLES reached
                    if MAX_SAMPLES and stats["saved"] >= MAX_SAMPLES:
                        print(f"\n‚úÖ Reached MAX_SAMPLES ({MAX_SAMPLES})")
                        return stats

                    stats["total_processed"] += 1

                    row_id = None
                    if has_row_id and pd.notna(row["ROW_ID"]):
                        row_id = int(row["ROW_ID"])

                    # Already written by an earlier run: count it, don't redo it.
                    if row_id is not None and row_id in saved_row_ids:
                        continue

                    # Get and clean text
                    raw_text = clean_clinical_text(str(row["TEXT"]))

                    # Skip short notes
                    if len(raw_text) < 300:
                        stats["skipped_short"] += 1
                        continue

                    # Extract relevant portion
                    note_chunk = extract_relevant_text_chunk(raw_text, max_chars=4000)

                    # ---- STEP A: RAW NOTE -> SOAP (with retry) ----
                    soap, error_msg = extract_soap_from_note(note_chunk, max_retries=MAX_SOAP_RETRIES)

                    if not soap:
                        stats["soap_fail"] += 1
                        if error_msg:
                            stats["soap_validation_errors"][error_msg] = stats["soap_validation_errors"].get(error_msg, 0) + 1
                        continue

                    if error_msg == "Partially salvaged":
                        stats["soap_retry_success"] += 1

                    # ---- STEP B: SOAP -> DIALOGUE (with retry) ----
                    dialogue, dialogue_error = generate_dialogue_from_soap(soap, max_retries=MAX_DIALOGUE_RETRIES)

                    if not dialogue:
                        stats["dialogue_fail"] += 1
                        if dialogue_error:
                            stats["dialogue_validation_errors"][dialogue_error] = stats["dialogue_validation_errors"].get(dialogue_error, 0) + 1
                        continue

                    # dialogue_error is None on a clean success, so guard
                    # before calling .lower() on it.
                    if dialogue_error and "retry" in dialogue_error.lower():
                        stats["dialogue_retry_success"] += 1

                    # ---- STEP C: FORMAT DATASET ENTRY ----
                    target = (
                        f"Subjective: {soap['Subjective']}\n"
                        f"Objective: {soap['Objective']}\n"
                        f"Assessment: {soap['Assessment']}\n"
                        f"Plan: {soap['Plan']}"
                    )

                    entry = {
                        "input": f"generate soap note:\n{dialogue}",
                        "output": target,
                        "metadata": {
                            "category": row["CATEGORY"],
                            "row_id": row_id,
                            "note_length": len(raw_text),
                            "chunk_length": len(note_chunk),
                            "soap_quality": "full" if error_msg is None else "partial",
                            "dialogue_quality": "full" if dialogue_error is None else "retry_success",
                            "processing_timestamp": pd.Timestamp.now().isoformat()
                        }
                    }

                    # Write and flush immediately so an interrupted run loses
                    # at most the entry being written.
                    fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                    fout.flush()

                    stats["saved"] += 1
                    if row_id is not None:
                        saved_row_ids.add(row_id)

                    # Save checkpoint every 5 samples
                    if stats["saved"] % 5 == 0:
                        save_checkpoint(stats["saved"])

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Process interrupted by user")
        if stats["saved"] > 0:
            save_checkpoint(stats["saved"])
            print(f"üíæ Progress saved: {stats['saved']} samples")
        return stats
    except Exception as e:
        print(f"‚ùå Unexpected error: {e}")
        import traceback
        traceback.print_exc()
        # Record progress so the next run appends rather than overwrites.
        if stats["saved"] > 0:
            save_checkpoint(stats["saved"])
        return stats

    print("\n‚úÖ Processing completed!")
    return stats

# ==========================================
# 7. MAIN EXECUTION
# ==========================================
if __name__ == "__main__":
    # Announce the run configuration.
    for banner_line in (
        "üöÄ Starting MIMIC-III to Dialogue-SOAP conversion",
        f"Output file: {OUTPUT_FILE}",
        f"Max samples: {MAX_SAMPLES or 'All'}",
        f"Target categories: {TARGET_CATEGORIES}",
        f"SOAP retries: {MAX_SOAP_RETRIES}",
        f"Dialogue retries: {MAX_DIALOGUE_RETRIES}",
    ):
        print(banner_line)

    start_time = time.time()
    final_stats = process_mimic()

    # The checkpoint is removed only when a MAX_SAMPLES target exists and was met.
    target_met = (
        bool(final_stats)
        and "saved" in final_stats
        and MAX_SAMPLES
        and final_stats["saved"] >= MAX_SAMPLES
    )
    if target_met and os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
        print("üßπ Checkpoint file removed (completed)")

    # Summary table.
    if final_stats:
        rule = "-" * 50
        print("\nüìä FINAL STATISTICS")
        print("=" * 50)
        print(f"{'Metric':30} {'Count':>10} {'%':>8}")
        print(rule)

        total = final_stats.get("total_processed", 0)

        table_rows = (
            ("Total processed", "total_processed"),
            ("Skipped (short)", "skipped_short"),
            ("SOAP failed", "soap_fail"),
            ("SOAP retry success", "soap_retry_success"),
            ("Dialogue failed", "dialogue_fail"),
            ("Dialogue retry success", "dialogue_retry_success"),
            ("Successfully saved", "saved"),
        )
        # These two rows are printed as bare counts, without a percentage.
        count_only = ("Total processed", "Successfully saved")

        for name, stat_key in table_rows:
            value = final_stats.get(stat_key, 0)
            if total <= 0 or name in count_only:
                print(f"{name:30} {value:10}")
            else:
                percentage = (value / total) * 100
                print(f"{name:30} {value:10} {percentage:7.1f}%")

        if total > 0:
            success_rate = (final_stats.get("saved", 0) / total) * 100
            print(rule)
            print(f"{'Success rate':30} {'':10} {success_rate:7.1f}%")

            # Up to three most frequent failure messages per stage.
            print("\nüîç TOP VALIDATION ERRORS:")
            for heading, errors_key in (
                ("SOAP errors:", "soap_validation_errors"),
                ("Dialogue errors:", "dialogue_validation_errors"),
            ):
                error_counts = final_stats.get(errors_key, {})
                if not error_counts:
                    continue
                print(heading)
                ranked = sorted(error_counts.items(), key=lambda pair: pair[1], reverse=True)
                for error, count in ranked[:3]:
                    print(f"  - {error}: {count}")

    # Timing summary.
    elapsed = time.time() - start_time
    saved = final_stats.get("saved", 0) if final_stats else 0
    print(f"\n‚è±Ô∏è  Total time: {elapsed:.1f} seconds")
    if saved > 0:
        print(f"   Avg time per sample: {elapsed/saved:.1f}s")
        print(f"   Samples per hour: {(saved/elapsed)*3600:.1f}")