In [7]:
import json
import os
import random
import re
import time
from typing import Any, Dict, List, Optional

import pandas as pd
import tiktoken
from openai import OpenAI

# Configuration
# The API key is read from the environment so that no secret is ever stored in
# the source file. Any key that was previously committed here must be treated
# as leaked and rotated.
openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it in the environment before running."
    )
gpt_model = "gpt-4o"

# Shared API client and tokenizer for the configured model.
client = OpenAI(api_key=openai_key)
encoding = tiktoken.encoding_for_model(gpt_model)

# Demographic categories in their canonical order. Label lists built by
# iterating this sequence have one entry per category, in this order.
category_order = [
    "gender",
    "age",
    "disability status",
    "race",
    "country",
    "state",
    "region",
    "languages spoken",
    "education level",
    "social media usage",
    "religion",
    "marital status",
    "profession",
    "household income classification",
    "housing situation",
]

# Allowed labels for every category, as a flat category -> labels mapping that
# is rendered into the LLM prompt and used to validate the model's answers.
allowed_labels_per_category = {
    "gender": ["female", "male", "othergender"],
    "age": ["adult", "senior", "young"],
    "disability status": [
        "assistive_supports",
        "general_disability",
        "has_no_disability",
        "mental_health_conditions",
        "neurodevelopmental_disability",
        "physical_disability",
        "sensory_disability",
        "speech_and_cognitive_disability",
    ],
    "race": [
        "african", "african_american", "afro_latino",
        "australian_new_zealand", "caribbean", "caribbean_hispanic",
        "central_american", "east_asian", "european", "hispanic",
        "indigenous_arctic", "indigenous_central_south_american",
        "indigenous_oceanian", "jewish", "latino", "melanesian",
        "micronesian", "middle_eastern", "multiracial",
        "native_north_american", "north_african", "north_american",
        "polynesian", "roma", "south_american", "south_asian",
        "southeast_asian",
    ],
    "country": [
        "australia", "brazil", "canada", "china", "france", "germany",
        "india", "italy", "japan", "mexico", "other_country", "russia",
        "south_africa", "spain", "uk", "usa",
    ],
    "state": [
        "alabama", "alaska", "arizona", "arkansas", "california",
        "colorado", "connecticut", "delaware", "florida", "georgia",
        "guam", "hawaii", "idaho", "illinois", "indiana", "iowa",
        "kansas", "kentucky", "louisiana", "maine", "maryland",
        "massachusetts", "michigan", "minnesota", "mississippi",
        "missouri", "montana", "nebraska", "nevada", "new_hampshire",
        "new_jersey", "new_mexico", "new_york", "north_carolina",
        "north_dakota", "ohio", "oklahoma", "oregon", "pennsylvania",
        "puerto_rico", "rhode_island", "south_carolina", "south_dakota",
        "tennessee", "texas", "u.s._virgin_islands", "utah", "vermont",
        "virginia", "washington", "west_virginia", "wisconsin", "wyoming",
    ],
    "region": [
        "midwest_usa",
        "northeast_usa",
        "southeast_usa",
        "southwest_usa",
        "us_territories",
        "west_usa",
    ],
    "languages spoken": [
        "african", "arabic", "bengali", "chinese", "creole", "danish",
        "dutch", "english", "finnish", "french", "german", "greek",
        "hebrew", "hindi", "indigenous", "indonesian", "italian",
        "japanese", "korean", "malayalam", "norwegian", "persian",
        "polish", "portuguese", "punjabi", "romanian", "russian",
        "sign_languages", "spanish", "swedish", "tagalog", "tamil",
        "telugu", "thai", "turkish", "urdu", "vietnamese",
    ],
    "education level": [
        "bachelor_degree",
        "diploma/certificate",
        "doctoral_degree",
        "high_school_degree",
        "masters_degree",
    ],
    "social media usage": ["has_social_media", "no_social_media"],
    "religion": [
        "agnosticism", "atheism", "buddhism", "christianity", "hinduism",
        "indigenous_&_animistic_beliefs", "islam", "jainism", "judaism",
        "other_religions", "shinto", "sikhism", "spirituality", "taoism",
    ],
    "marital status": [
        "cohabiting",
        "complicated_relationship",
        "divorced_or_separated",
        "engaged",
        "married",
        "single",
        "widowed",
    ],
    "profession": [
        "art_&design", "construction&technical", "engineering",
        "entrepreneurship", "finance", "hospitality",
        "information_technology", "law", "medical", "military",
        "other_professions", "police&security", "retail&customer_service",
        "scientist", "self-employment&gig_economy",
        "social_work&counseling", "student_roles", "teaching",
        "unemployed", "writing&_journalism",
    ],
    "household income classification": [
        "lower-middle_class",
        "lower_class",
        "middle_class",
        "upper-middle_class",
        "upper_class",
    ],
    "housing situation": [
        "apartment",
        "homeless",
        "other_housing",
        "shared_housing",
        "single-family_home",
        "three-family_home",
        "two-family_home",
        "unspecified_housing",
    ],
}

def create_label_detection_prompt(description: str, category_order: List[str], allowed_labels: Dict[str, List[str]]) -> str:
    """Build the label-detection prompt sent to the LLM.

    Args:
        description: Free-text patient description to analyse.
        category_order: Categories the model must answer for, in output order.
        allowed_labels: Mapping of category name to its permitted labels.
            Categories missing from this mapping are omitted from the
            "ALLOWED LABELS" section but still appear in the JSON template.

    Returns:
        The full prompt text, ending with a JSON template that has exactly one
        key per entry of ``category_order``.
    """
    # One "category: label, label, ..." line per category that has an allow-list
    labels_text = "".join(
        f"{category}: {', '.join(allowed_labels[category])}\n"
        for category in category_order
        if category in allowed_labels
    )

    # Derive the answer template from category_order so the keys requested
    # from the model always match the keys the caller will read back.
    json_fields = ",\n".join(
        f'    "{category}": "label_or_n/a"' for category in category_order
    )
    json_template = "{\n" + json_fields + "\n}"

    prompt = f"""You are an expert at analyzing patient descriptions to extract demographic information.

TASK: Analyze the patient description below and determine which demographic categories are mentioned or can be reasonably inferred. For each category that IS mentioned, select the most appropriate label from the allowed options. If a category is NOT mentioned or cannot be inferred, respond with "n/a".

PATIENT DESCRIPTION: "{description}"

ALLOWED LABELS PER CATEGORY:
{labels_text}

INSTRUCTIONS:
1. Only identify categories that are explicitly mentioned or clearly implied in the text
2. Use exact label names from the allowed options above
3. If multiple labels could apply, choose the most specific/accurate one
4. If a category is not mentioned at all, use "n/a"
5. Be conservative - only identify what you're confident about

RESPOND IN JSON FORMAT:
{json_template}"""

    return prompt

def _normalize_detected_label(category: str, raw_label: Any, allowed_labels: Dict[str, List[str]]) -> str:
    """Validate one raw value from the model's JSON and return a usable label.

    Returns the accepted label, or "n/a" when the value is missing, of an
    unexpected type, or not in the allow-list for ``category``. Categories
    without an allow-list are passed through unvalidated.
    """
    allowed = allowed_labels.get(category)

    if category == "languages spoken":
        # The model may answer with a comma-separated string or a JSON array.
        if isinstance(raw_label, str):
            parts = raw_label.split(",")
        elif isinstance(raw_label, (list, tuple)):
            parts = [str(item) for item in raw_label]
        else:
            return "n/a"

        languages = [part.strip().lower() for part in parts if part.strip()]
        if not languages or languages == ["n/a"]:
            return "n/a"
        if allowed is None:
            return "_".join(dict.fromkeys(languages))

        valid_languages = []
        for lang in languages:
            if lang in allowed:
                # Keep first-seen order and drop duplicates
                if lang not in valid_languages:
                    valid_languages.append(lang)
            elif len(languages) == 1:
                print(f"Warning: Invalid label '{raw_label}' for category '{category}'. Using 'n/a'.")
            else:
                print(f"Warning: Invalid language '{lang}' found in '{raw_label}'. Skipping.")

        # Multiple languages are combined into a single underscore-joined label
        return "_".join(valid_languages) if valid_languages else "n/a"

    if not isinstance(raw_label, str):
        if raw_label is not None:
            print(f"Warning: Invalid label '{raw_label}' for category '{category}'. Using 'n/a'.")
        return "n/a"

    label = raw_label.strip()
    if not label or label.lower() == "n/a":
        return "n/a"
    if allowed is None or label in allowed:
        return label
    # Accept a case variant of an allowed label, stored in its canonical form
    if label.lower() in allowed:
        return label.lower()

    print(f"Warning: Invalid label '{label}' for category '{category}'. Using 'n/a'.")
    return "n/a"


def detect_labels_with_llm(description: str, category_order: List[str], allowed_labels: Dict[str, List[str]], max_retries: int = 2) -> List[str]:
    """Use the LLM to detect demographic labels in a patient description.

    Args:
        description: Patient description to analyse.
        category_order: Categories to report, in output order.
        allowed_labels: Mapping of category name to its permitted labels.
        max_retries: Number of additional attempts after the first one when
            the API call fails or the reply contains no parseable JSON object.

    Returns:
        One label per entry of ``category_order`` (same order); "n/a" marks a
        category that was not detected or whose value failed validation. If
        every attempt fails, a list of all "n/a" is returned.
    """
    prompt = create_label_detection_prompt(description, category_order, allowed_labels)

    for attempt in range(max_retries + 1):
        result_dict = None
        try:
            # NOTE: the model name duplicates the module-level gpt_model setting.
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,  # Low temperature for consistency
                max_tokens=1000
            )
            # content may be None; treat that as an empty reply, not a crash
            result_text = (response.choices[0].message.content or "").strip()
        except Exception as e:
            print(f"API error on attempt {attempt + 1}: {e}")
        else:
            # Grab the outermost {...} span so surrounding prose or code
            # fences around the JSON are ignored.
            json_match = re.search(r'\{.*\}', result_text, re.DOTALL)
            if json_match is None:
                print(f"No JSON found in response on attempt {attempt + 1}")
            else:
                try:
                    parsed = json.loads(json_match.group())
                except json.JSONDecodeError as e:
                    print(f"JSON parsing error on attempt {attempt + 1}: {e}")
                else:
                    if isinstance(parsed, dict):
                        result_dict = parsed

        if result_dict is not None:
            # Ordered list with every category present; missing keys are "n/a"
            return [
                _normalize_detected_label(category, result_dict.get(category, "n/a"), allowed_labels)
                for category in category_order
            ]

        # Wait before retry
        if attempt < max_retries:
            time.sleep(2)

    print("Max retries reached. Using fallback.")
    return ["n/a"] * len(category_order)

def generate_prompt_with_random_categories(category_order: List[str], allowed_labels: Dict[str, List[str]], 
                                         min_categories: int = 4, max_categories: int = 8) -> tuple:
    """Build a description-generation prompt for a random subset of categories.

    Args:
        category_order: Pool of category names to sample from.
        allowed_labels: Mapping of category name to permitted labels. Not used
            by this function; kept for interface compatibility.
        min_categories: Smallest number of categories to select.
        max_categories: Largest number of categories to select.

    Returns:
        A ``(prompt, selected_categories)`` tuple: the prompt text and the list
        of categories the description is asked to mention.
    """
    # random.sample raises ValueError when asked for more items than exist,
    # so clamp the requested range to the size of the pool.
    upper = max(0, min(max_categories, len(category_order)))
    lower = max(0, min(min_categories, upper))
    selected_categories = random.sample(category_order, random.randint(lower, upper))
    
    # Per-category writing hints. Each hint should steer the generator towards
    # wording that the label detector can map onto its allowed labels.
    category_examples = {
        "gender": "Use appropriate pronouns (he/she/they) naturally in the description",
        "age": "Mention age directly (e.g., '45-year-old') or use age-related terms (elderly, young adult, teenager)",
        "disability status": "Mention any disabilities, conditions, or assistive devices if relevant",
        "race": "Include ethnic or racial background naturally",
        "country": "Mention country of origin or citizenship",
        "state": "Mention specific US state or territory",
        "region": "Mention the US region (Midwest, Northeast, Southeast, Southwest, West) or a US territory",
        "languages spoken": "Mention primary language or bilingual status",
        "education level": "Reference education background (degree, schooling level)",
        "social media usage": "Mention social media habits if relevant to medical context",
        "religion": "Include religious affiliation if relevant",
        "marital status": "Mention spouse, partner, or single status",
        "profession": "Include occupation or work status",
        "household income classification": "Reference financial situation if relevant",
        "housing situation": "Mention type of housing (apartment, single-family home, shared housing, homeless, etc.)"
    }
    
    category_guidance = "\n".join(
        f"- {cat}: {category_examples.get(cat, 'Include naturally')}"
        for cat in selected_categories
    )
    
    prompt = f"""You are creating a realistic patient description for medical training purposes.

Write a medically realistic and concise patient description (3-5 sentences) that naturally incorporates the following demographic details:

{category_guidance}

Requirements:
- Write in a clinical yet natural style
- Include specific, identifiable mentions of these categories
- Make it realistic for a medical setting
- Don't use lists or bullet points
- Only include the specified categories above

Create a coherent patient scenario that would realistically include these demographic details."""

    return prompt.strip(), selected_categories

def analyze_detection_accuracy(rows: List[Dict], categories: Optional[List[str]] = None) -> Dict:
    """Summarise how often intended categories were actually detected.

    A category counts as "detected" in a row when its entry in
    ``row["ground_truth_labels"]`` is not "n/a", and as "correct" when it was
    both detected and listed in ``row["intended_categories"]``. Only category
    presence is compared, not label values.

    Args:
        rows: Records with an "intended_categories" list and a
            "ground_truth_labels" list aligned position-by-position with
            ``categories``.
        categories: Category names in label order. Defaults to the
            module-level ``category_order``.

    Returns:
        A dict with "total_descriptions", per-category "category_stats"
        (intended, detected, correct, precision, recall) and aggregated
        "overall_stats". Ratios are 0 when their denominator is 0.
    """
    if categories is None:
        categories = category_order

    category_stats = {}
    total_intended = 0
    total_detected = 0
    total_correct_detections = 0

    # enumerate gives each category's label position directly, avoiding a
    # list.index() scan for every row.
    for position, category in enumerate(categories):
        intended_count = 0
        detected_count = 0
        correct_detections = 0
        for row in rows:
            was_intended = category in row["intended_categories"]
            was_detected = row["ground_truth_labels"][position] != "n/a"
            if was_intended:
                intended_count += 1
            if was_detected:
                detected_count += 1
            if was_intended and was_detected:
                correct_detections += 1

        category_stats[category] = {
            "intended": intended_count,
            "detected": detected_count,
            "correct": correct_detections,
            "precision": correct_detections / detected_count if detected_count > 0 else 0,
            "recall": correct_detections / intended_count if intended_count > 0 else 0
        }

        total_intended += intended_count
        total_detected += detected_count
        total_correct_detections += correct_detections

    return {
        "total_descriptions": len(rows),
        "category_stats": category_stats,
        "overall_stats": {
            "total_intended": total_intended,
            "total_detected": total_detected,
            "total_correct": total_correct_detections,
            "overall_precision": total_correct_detections / total_detected if total_detected > 0 else 0,
            "overall_recall": total_correct_detections / total_intended if total_intended > 0 else 0
        }
    }

def main():
    """Generate synthetic patient descriptions, label them, and report stats.

    For each of N iterations a description is generated for a random set of
    categories and then labelled by the detector. Successful rows are written
    to a CSV file, and per-category precision/recall of detected versus
    intended categories is printed.
    """
    N = 100  # Number of synthetic descriptions to generate
    rows = []

    for i in range(N):
        print(f"\nGenerating description {i+1}/{N}...")

        # Generate prompt and get LLM response for patient description
        prompt_text, selected_categories = generate_prompt_with_random_categories(
            category_order, allowed_labels_per_category
        )

        try:
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[{"role": "user", "content": prompt_text}],
                temperature=0.7,
                max_tokens=300
            )

            patient_description = response.choices[0].message.content.strip()

            # Detect labels using the LLM
            detected_labels = detect_labels_with_llm(patient_description, category_order, allowed_labels_per_category)
        except Exception as e:
            # Best effort: a failed item is reported and skipped
            print(f"❌ Error generating description {i+1}: {e}")
        else:
            rows.append({
                "question": patient_description,
                "intended_categories": selected_categories,
                "ground_truth_labels": detected_labels  # one label per category, in category_order
            })

        # Rate limiting; applied after failures too so that repeated errors
        # do not turn into a tight request loop.
        time.sleep(1)

    if not rows:
        print("❌ No descriptions were generated successfully.")
        return

    # Save to CSV
    df = pd.DataFrame(rows)
    filename = "synthetic_patient_descriptions_and_ground_truth.csv"
    df.to_csv(filename, index=False)
    print(f"\n✅ Generated and saved {len(rows)} descriptions to '{filename}'")

    analysis = analyze_detection_accuracy(rows)

    print(f"\nPER-CATEGORY PERFORMANCE:")
    print(f"{'Category':<25} {'Intended':<9} {'Detected':<9} {'Correct':<8} {'Precision':<10} {'Recall':<8}")
    print("-" * 75)

    for category in category_order:
        stats = analysis["category_stats"][category]
        print(f"{category:<25} {stats['intended']:<9} {stats['detected']:<9} {stats['correct']:<8} "
              f"{stats['precision']:<10.2%} {stats['recall']:<8.2%}")

    # Show best and worst performing categories (by recall), considering only
    # categories that were requested at least once.
    category_recall_scores = [(cat, stats['recall']) for cat, stats in analysis["category_stats"].items()
                              if stats['intended'] > 0]
    category_recall_scores.sort(key=lambda x: x[1], reverse=True)

    if category_recall_scores:
        best_category, best_recall = category_recall_scores[0]
        worst_category, worst_recall = category_recall_scores[-1]
        print(f"\nBest recall:  {best_category} ({best_recall:.2%})")
        print(f"Worst recall: {worst_category} ({worst_recall:.2%})")

# Run the pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()


Generating description 1/100...

Generating description 2/100...

Generating description 3/100...

Generating description 4/100...

Generating description 5/100...

Generating description 6/100...

Generating description 7/100...

Generating description 8/100...

Generating description 9/100...

Generating description 10/100...

Generating description 11/100...

Generating description 12/100...

Generating description 13/100...

Generating description 14/100...

Generating description 15/100...

Generating description 16/100...

Generating description 17/100...

Generating description 18/100...

Generating description 19/100...

Generating description 20/100...

Generating description 21/100...

Generating description 22/100...

Generating description 23/100...

Generating description 24/100...

Generating description 25/100...

Generating description 26/100...

Generating description 27/100...

Generating description 28/100...

Generating description 29/100...

Generating description