In [None]:
import pandas as pd
import json
import re
import spacy
from collections import Counter

# Load spaCy medical model
print("Loading spaCy medical model...")
try:
    nlp = spacy.load("en_core_sci_md")
    print("Medical spaCy model loaded successfully!")
except OSError:
    print("Medical spaCy model not found. Please install it using:")
    print("pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_md-0.5.1.tar.gz")
    # For fallback, use regular English model
    nlp = spacy.load("en_core_web_sm")

# Load your dataset
print("Loading dataset...")
df = pd.read_csv('/dataset_summary.csv')

# Define urgency classification rules
URGENCY_RULES = {
    'high': [
        # Cardiac emergencies
        r'\b(severe|crushing|sharp) chest pain\b',
        r'\bheart attack\b',
        r'\bcardiac arrest\b',
        r'\bshortness of breath\b',
        r'\bdifficulty breathing\b',
        r'\btrouble breathing\b',
        r'\blightheaded\b',
        r'\bfainting\b',
        r'\bloss of consciousness\b',
        r'\bpassing out\b',
        r'\bradiating pain\b',
        r'\bpain spreading\b',
        r'\bsevere pain\b',
        r'\b10/10 pain\b',
        r'\b9/10 pain\b',
        r'\b8/10 pain\b',
        r'\bworst pain\b',

        # Other emergencies
        r'\bsevere bleeding\b',
        r'\buncontrolled bleeding\b',
        r'\bstroke\b',
        r'\bnumbness\b',
        r'\bweakness\b',
        r'\bparalysis\b',
        r'\bvision loss\b',
        r'\bsevere headache\b',
        r'\bseizure\b'
    ],
    'medium': [
        r'\bchest pain\b',
        r'\bfever\b',
        r'\bpersistent cough\b',
        r'\bworsening symptoms\b',
        r'\babdominal pain\b',
        r'\bvomiting\b',
        r'\bdiarrhea\b',
        r'\bdehydration\b',
        r'\bmoderate pain\b',
        r'\b7/10 pain\b',
        r'\b6/10 pain\b',
        r'\binfection\b',
        r'\binflammatory\b',
        r'\bswelling\b',
        r'\bredness\b',
        r'\bdizziness\b',
        r'\bnausea\b'
    ],
    'low': [
        r'\broutine\b',
        r'\bfollow.up\b',
        r'\bcheck.up\b',
        r'\bmanagement\b',
        r'\breview\b',
        r'\bmild pain\b',
        r'\bchronic condition\b',
        r'\bstable\b',
        r'\bpreventive care\b',
        r'\bvaccination\b',
        r'\bscreening\b',
        r'\bcold symptoms\b',
        r'\bmild cough\b',
        r'\brunny nose\b',
        r'\b1/10 pain\b',
        r'\b2/10 pain\b',
        r'\b3/10 pain\b'
    ]
}

# Specialty mapping based on symptoms and keywords
SPECIALTY_KEYWORDS = {
    'Cardiology': [
        'chest', 'heart', 'cardiac', 'breathing', 'palpitations',
        'blood pressure', 'hypertension', 'cholesterol', 'ecg'
    ],
    'Gastroenterology': [
        'stomach', 'abdominal', 'vomiting', 'diarrhea', 'nausea',
        'bowel', 'digestive', 'constipation', 'indigestion'
    ],
    'Musculoskeletal': [
        'pain', 'joint', 'muscle', 'elbow', 'shoulder', 'knee',
        'back', 'swelling', 'tendon', 'ligament', 'fracture'
    ],
    'Dermatology': [
        'rash', 'skin', 'itching', 'redness', 'lesion', 'acne',
        'eczema', 'dermatitis', 'psoriasis'
    ],
    'Respiratory': [
        'cough', 'breathing', 'wheezing', 'lungs', 'respiratory',
        'asthma', 'pneumonia', 'bronchitis'
    ],
    'General Medicine': [
        'fever', 'fatigue', 'general', 'routine', 'check.up'
    ]
}

def extract_key_information(text):
    """Extract key medical information using spaCy"""
    doc = nlp(text.lower())

    # Extract medical entities
    symptoms = []
    conditions = []
    severity_indicators = []

    for ent in doc.ents:
        if ent.label_ in ["DISEASE", "SYMPTOM", "SIGN"]:
            symptoms.append(ent.text)
        elif ent.label_ in ["PROBLEM", "CONDITION"]:
            conditions.append(ent.text)

    # Extract severity words
    severity_words = ['severe', 'mild', 'moderate', 'sharp', 'chronic', 'acute', 'worsening']
    for token in doc:
        if token.text in severity_words:
            severity_indicators.append(token.text)

    # Extract pain levels
    pain_levels = re.findall(r'(\d+)/10 pain', text.lower())
    severity_indicators.extend([f"{level}/10 pain" for level in pain_levels])

    return {
        'symptoms': list(set(symptoms)),
        'conditions': list(set(conditions)),
        'severity_indicators': list(set(severity_indicators))
    }

def classify_urgency(text, extracted_info):
    """Classify urgency based on rules and extracted information"""
    text_lower = text.lower()

    # Check high urgency rules
    for pattern in URGENCY_RULES['high']:
        if re.search(pattern, text_lower):
            return 'high'

    # Check medium urgency rules
    for pattern in URGENCY_RULES['medium']:
        if re.search(pattern, text_lower):
            return 'medium'

    # Check low urgency rules
    for pattern in URGENCY_RULES['low']:
        if re.search(pattern, text_lower):
            return 'low'

    # Default based on severity indicators
    if any(word in text_lower for word in ['severe', 'emergency', 'urgent']):
        return 'high'
    elif any(word in text_lower for word in ['moderate', 'worsening']):
        return 'medium'
    else:
        return 'low'

def predict_specialty(text, extracted_info):
    """Predict medical specialty based on keywords"""
    text_lower = text.lower()
    specialty_scores = {}

    for specialty, keywords in SPECIALTY_KEYWORDS.items():
        score = 0
        for keyword in keywords:
            # Count occurrences of keyword
            score += len(re.findall(r'\b' + re.escape(keyword) + r'\b', text_lower))
        specialty_scores[specialty] = score

    # Return specialty with highest score
    predicted_specialty = max(specialty_scores, key=specialty_scores.get)

    # If no strong match, use General Medicine
    if specialty_scores[predicted_specialty] == 0:
        return 'General Medicine'

    return predicted_specialty

def create_concise_summary(text, extracted_info, urgency, specialty):
    """Create a concise summary in the required format"""
    symptoms = extracted_info['symptoms']
    severity = extracted_info['severity_indicators']

    # Build the summary text
    if symptoms:
        main_symptoms = symptoms[:3]  # Take first 3 symptoms
        symptom_text = ", ".join(main_symptoms)

        if severity:
            severity_text = severity[0]  # Take the most prominent severity indicator
            summary = f"Patient experiencing {severity_text} {symptom_text}"
        else:
            summary = f"Patient experiencing {symptom_text}"
    else:
        # Fallback: use keywords from text
        if 'chest pain' in text.lower():
            summary = "Patient experiencing chest pain"
        elif 'fever' in text.lower() and 'cough' in text.lower():
            summary = "Patient with fever and persistent cough"
        elif 'pain' in text.lower():
            summary = "Patient reporting pain symptoms"
        else:
            summary = "Patient with medical symptoms requiring attention"

    # Add duration if mentioned
    duration_pattern = r'(\d+\s*(?:hour|day|week|month)s?)'
    duration_match = re.search(duration_pattern, text.lower())
    if duration_match:
        summary += f" for {duration_match.group(1)}"

    # Add urgency context
    if urgency == 'high':
        summary += " with emergency symptoms"
    elif urgency == 'medium':
        summary += " with concerning symptoms"
    else:
        summary += " for evaluation"

    return summary

def process_medical_conversations(df):
    """Process all medical conversations and create the training dataset"""
    training_data = []

    for idx, row in df.iterrows():
        text = row['transcription']
        original_specialty = row['specialty']

        print(f"Processing {row['filename']} - {original_specialty}")

        # Extract information using spaCy
        extracted_info = extract_key_information(text)

        # Classify urgency
        urgency = classify_urgency(text, extracted_info)

        # Predict specialty
        specialty = predict_specialty(text, extracted_info)

        # Create concise summary
        concise_text = create_concise_summary(text, extracted_info, urgency, specialty)

        # Create the training example
        training_example = {
            "text": concise_text,
            "metadata": {
                "specialty": specialty,
                "urgency": urgency
            }
        }

        training_data.append(training_example)

        # Print sample for verification
        if idx < 3:  # Show first 3 examples
            print(f"Sample {idx + 1}:")
            print(f"  Text: {concise_text}")
            print(f"  Specialty: {specialty}, Urgency: {urgency}")
            print()

    return training_data

# Process the dataset
print("Processing medical conversations...")
training_dataset = process_medical_conversations(df)

# Save the dataset
output_file = '/spacy_medical_training_data.json'
with open(output_file, 'w') as f:
    json.dump(training_dataset, f, indent=2)

print(f"\nDataset saved to: {output_file}")
print(f"Total training examples created: {len(training_dataset)}")

# Create a sample preview file with first 10 examples
sample_output = training_dataset[:10]
sample_file = '/sample_training_data.json'
with open(sample_file, 'w') as f:
    json.dump(sample_output, f, indent=2)

print(f"Sample preview saved to: {sample_file}")

# Print statistics
urgency_counts = Counter([item['metadata']['urgency'] for item in training_dataset])
specialty_counts = Counter([item['metadata']['specialty'] for item in training_dataset])

print("\nDataset Statistics:")
print(f"Urgency distribution: {dict(urgency_counts)}")
print(f"Specialty distribution: {dict(specialty_counts)}")

# Display first 5 examples
print("\nFirst 5 training examples:")
for i, example in enumerate(training_dataset[:5]):
    print(f"{i+1}. Text: {example['text']}")
    print(f"   Metadata: {example['metadata']}")
    print()

Loading spaCy medical model...
Medical spaCy model not found. Please install it using:
pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_md-0.5.1.tar.gz
Loading dataset...
Processing medical conversations...
Processing CAR0001.mp3 - Cardiology
Sample 1:
  Text: Patient experiencing chest pain with emergency symptoms
  Specialty: Musculoskeletal, Urgency: high

Processing MSK0005.mp3 - Musculoskeletal
Sample 2:
  Text: Patient experiencing chest pain with emergency symptoms
  Specialty: Musculoskeletal, Urgency: high

Processing GAS0004.mp3 - Gastroenterology
Sample 3:
  Text: Patient reporting pain symptoms with concerning symptoms
  Specialty: Gastroenterology, Urgency: medium

Processing DER0001.mp3 - Dermatology
Processing MSK0004.mp3 - Musculoskeletal
Processing MSK0010.mp3 - Musculoskeletal
Processing GAS0007.mp3 - Gastroenterology
Processing RES0006.mp3 - Respiratory
Processing MSK0001.mp3 - Musculoskeletal
Processing MSK0006.mp3 - Musculo

In [None]:
import pandas as pd
import json
import re
import spacy
from collections import Counter
import random

# Load spaCy medical model
print("Loading spaCy medical model...")
try:
    nlp = spacy.load("en_core_sci_md")
    print("Medical spaCy model loaded successfully!")
except OSError:
    print("Medical spaCy model not found. Using basic English model...")
    nlp = spacy.load("en_core_web_sm")

# Load your dataset
print("Loading dataset...")
df = pd.read_csv('/dataset_summary.csv')

# Enhanced urgency classification rules
URGENCY_RULES = {
    'high': [
        # Cardiac emergencies
        r'\b(severe|crushing|sharp) chest pain\b',
        r'\bheart attack\b',
        r'\bcardiac arrest\b',
        r'\bshortness of breath\b',
        r'\bdifficulty breathing\b',
        r'\btrouble breathing\b',
        r'\blightheaded\b',
        r'\bfainting\b',
        r'\bloss of consciousness\b',
        r'\bpassing out\b',
        r'\bradiating pain\b',
        r'\bpain spreading\b',
        r'\bsevere pain\b',
        r'\b10/10 pain\b',
        r'\b9/10 pain\b',
        r'\b8/10 pain\b',
        r'\bworst pain\b',

        # Other emergencies
        r'\bsevere bleeding\b',
        r'\buncontrolled bleeding\b',
        r'\bstroke\b',
        r'\bnumbness\b',
        r'\bweakness\b',
        r'\bparalysis\b',
        r'\bvision loss\b',
        r'\bsevere headache\b',
        r'\bseizure\b',
        r'\bunresponsive\b',
        r'\bcardiac\b',
        r'\bemergency\b',
        r'\bcritical\b'
    ],
    'medium': [
        r'\bchest pain\b',
        r'\bfever\b',
        r'\bpersistent cough\b',
        r'\bworsening symptoms\b',
        r'\babdominal pain\b',
        r'\bvomiting\b',
        r'\bdiarrhea\b',
        r'\bdehydration\b',
        r'\bmoderate pain\b',
        r'\b7/10 pain\b',
        r'\b6/10 pain\b',
        r'\b5/10 pain\b',
        r'\binfection\b',
        r'\binflammatory\b',
        r'\bswelling\b',
        r'\bredness\b',
        r'\bdizziness\b',
        r'\bnausea\b',
        r'\bheadache\b',
        r'\bfatigue\b',
        r'\bweakness\b',
        r'\bconcern\b',
        r'\bworsening\b'
    ],
    'low': [
        r'\broutine\b',
        r'\bfollow.up\b',
        r'\bcheck.up\b',
        r'\bmanagement\b',
        r'\breview\b',
        r'\bmild pain\b',
        r'\bchronic condition\b',
        r'\bstable\b',
        r'\bpreventive care\b',
        r'\bvaccination\b',
        r'\bscreening\b',
        r'\bcold symptoms\b',
        r'\bmild cough\b',
        r'\brunny nose\b',
        r'\b1/10 pain\b',
        r'\b2/10 pain\b',
        r'\b3/10 pain\b',
        r'\b4/10 pain\b',
        r'\bannual\b',
        r'\bphysical\b',
        r'\bexam\b',
        r'\basymptomatic\b',
        r'\bwellness\b'
    ]
}

# Enhanced specialty mapping
SPECIALTY_KEYWORDS = {
    'Cardiology': [
        'chest', 'heart', 'cardiac', 'breathing', 'palpitations',
        'blood pressure', 'hypertension', 'cholesterol', 'ecg',
        'heartbeat', 'arrhythmia', 'angina', 'cardiovascular'
    ],
    'Gastroenterology': [
        'stomach', 'abdominal', 'vomiting', 'diarrhea', 'nausea',
        'bowel', 'digestive', 'constipation', 'indigestion',
        'acid reflux', 'gerd', 'ibs', 'gallbladder', 'liver'
    ],
    'Musculoskeletal': [
        'pain', 'joint', 'muscle', 'elbow', 'shoulder', 'knee',
        'back', 'swelling', 'tendon', 'ligament', 'fracture',
        'arthritis', 'osteoporosis', 'sprain', 'strain'
    ],
    'Dermatology': [
        'rash', 'skin', 'itching', 'redness', 'lesion', 'acne',
        'eczema', 'dermatitis', 'psoriasis', 'hives', 'blister',
        'mole', 'skin cancer', 'sunburn'
    ],
    'Respiratory': [
        'cough', 'breathing', 'wheezing', 'lungs', 'respiratory',
        'asthma', 'pneumonia', 'bronchitis', 'copd', 'shortness',
        'oxygen', 'inhaler'
    ],
    'General Medicine': [
        'fever', 'fatigue', 'general', 'routine', 'check.up',
        'wellness', 'physical', 'annual', 'preventive'
    ]
}

def create_synthetic_medical_cases():
    """Create additional synthetic medical cases for better balance"""
    synthetic_cases = []

    # LOW URGENCY CASES
    low_cases = [
        # General Medicine
        "Patient presents for routine annual physical examination and preventive health screening",
        "Follow-up visit for stable hypertension management and medication review",
        "Wellness check for asymptomatic patient with no current complaints",
        "Routine diabetes management with stable blood glucose levels",
        "Preventive care consultation and vaccination update",
        "Annual health maintenance visit with normal vital signs",
        "Medication refill request for chronic stable condition",
        "General health consultation for lifestyle modifications",

        # Musculoskeletal - Low
        "Mild occasional back stiffness after long periods of sitting",
        "Minor joint discomfort that resolves with rest and over-the-counter pain relief",
        "Chronic stable arthritis with well-controlled symptoms",
        "Follow-up for previous muscle strain with significant improvement",

        # Respiratory - Low
        "Mild seasonal allergies with occasional sneezing and runny nose",
        "Stable asthma with infrequent inhaler use and good control",
        "Routine pulmonary function test for monitoring purposes",

        # Cardiology - Low
        "Stable blood pressure readings during routine monitoring",
        "Follow-up for well-controlled hyperlipidemia on medication",
        "Routine cardiac assessment with normal findings"
    ]

    # MEDIUM URGENCY CASES
    medium_cases = [
        # General Medicine
        "Moderate fever of 101¬∞F with body aches and fatigue for 2 days",
        "Persistent cough with yellow phlegm and mild chest discomfort",
        "Worsening headache with sensitivity to light but no neurological symptoms",
        "Abdominal pain with nausea and decreased appetite for 24 hours",
        "Urinary symptoms with burning sensation and increased frequency",
        "Skin infection with localized redness, swelling and mild pain",
        "Moderate dehydration after gastrointestinal illness",

        # Musculoskeletal - Medium
        "Moderate back pain limiting daily activities but no neurological deficits",
        "Joint pain with swelling and stiffness affecting mobility",
        "Muscle strain with moderate pain and functional limitation",
        "Worsening arthritis symptoms with increased pain levels",

        # Respiratory - Medium
        "Bronchitis symptoms with productive cough and mild shortness of breath",
        "Asthma exacerbation with increased inhaler use and wheezing",
        "Sinus infection with facial pressure and colored discharge",

        # Cardiology - Medium
        "Palpitations with mild dizziness but no chest pain or fainting",
        "Elevated blood pressure readings with mild headache",
        "Chest discomfort with anxiety but normal cardiac workup",

        # Gastroenterology - Medium
        "Gastroenteritis with vomiting and diarrhea for 12 hours",
        "Moderate abdominal pain with bloating and gas",
        "Food poisoning symptoms with nausea and stomach cramps"
    ]

    # HIGH URGENCY CASES (additional to existing ones)
    high_cases = [
        "Severe crushing chest pain radiating to left arm with sweating and nausea",
        "Sudden onset of severe shortness of breath with blue lips and confusion",
        "Patient collapsed and unresponsive with no pulse or breathing",
        "Severe headache with vision loss and difficulty speaking",
        "Uncontrolled bleeding from deep laceration with signs of shock",
        "Severe allergic reaction with swelling and breathing difficulty",
        "Stroke symptoms with facial droop and arm weakness",
        "Severe abdominal pain with rigidity and fever suggesting appendicitis"
    ]

    # Add low urgency cases
    for case in low_cases:
        synthetic_cases.append({
            'text': case,
            'metadata': {
                'specialty': predict_specialty_synthetic(case),
                'urgency': 'low'
            }
        })

    # Add medium urgency cases
    for case in medium_cases:
        synthetic_cases.append({
            'text': case,
            'metadata': {
                'specialty': predict_specialty_synthetic(case),
                'urgency': 'medium'
            }
        })

    # Add high urgency cases
    for case in high_cases:
        synthetic_cases.append({
            'text': case,
            'metadata': {
                'specialty': predict_specialty_synthetic(case),
                'urgency': 'high'
            }
        })

    return synthetic_cases

def predict_specialty_synthetic(text):
    """Predict specialty for synthetic cases"""
    text_lower = text.lower()
    specialty_scores = {}

    for specialty, keywords in SPECIALTY_KEYWORDS.items():
        score = 0
        for keyword in keywords:
            if keyword in text_lower:
                score += 1
        specialty_scores[specialty] = score

    predicted_specialty = max(specialty_scores, key=specialty_scores.get)
    return predicted_specialty if specialty_scores[predicted_specialty] > 0 else 'General Medicine'

def extract_key_information(text):
    """Extract key medical information using spaCy"""
    doc = nlp(text.lower())

    symptoms = []
    conditions = []
    severity_indicators = []

    for ent in doc.ents:
        if ent.label_ in ["DISEASE", "SYMPTOM", "SIGN", "PROBLEM", "CONDITION"]:
            symptoms.append(ent.text)

    # Extract severity words
    severity_words = ['severe', 'mild', 'moderate', 'sharp', 'chronic', 'acute', 'worsening', 'stable']
    for token in doc:
        if token.text in severity_words:
            severity_indicators.append(token.text)

    # Extract pain levels
    pain_levels = re.findall(r'(\d+)/10 pain', text.lower())
    severity_indicators.extend([f"{level}/10 pain" for level in pain_levels])

    return {
        'symptoms': list(set(symptoms)),
        'conditions': list(set(conditions)),
        'severity_indicators': list(set(severity_indicators))
    }

def classify_urgency(text, extracted_info):
    """Classify urgency based on rules and extracted information"""
    text_lower = text.lower()

    # Check high urgency rules
    for pattern in URGENCY_RULES['high']:
        if re.search(pattern, text_lower):
            return 'high'

    # Check medium urgency rules
    for pattern in URGENCY_RULES['medium']:
        if re.search(pattern, text_lower):
            return 'medium'

    # Check low urgency rules
    for pattern in URGENCY_RULES['low']:
        if re.search(pattern, text_lower):
            return 'low'

    # Default based on severity indicators
    if any(word in text_lower for word in ['severe', 'emergency', 'urgent', 'critical']):
        return 'high'
    elif any(word in text_lower for word in ['moderate', 'worsening', 'persistent']):
        return 'medium'
    else:
        return 'low'

def predict_specialty(text, extracted_info):
    """Predict medical specialty based on keywords"""
    text_lower = text.lower()
    specialty_scores = {}

    for specialty, keywords in SPECIALTY_KEYWORDS.items():
        score = 0
        for keyword in keywords:
            score += len(re.findall(r'\b' + re.escape(keyword) + r'\b', text_lower))
        specialty_scores[specialty] = score

    predicted_specialty = max(specialty_scores, key=specialty_scores.get)

    if specialty_scores[predicted_specialty] == 0:
        return 'General Medicine'

    return predicted_specialty

def create_concise_summary(text, extracted_info, urgency, specialty):
    """Create a concise summary in the required format"""
    symptoms = extracted_info['symptoms']
    severity = extracted_info['severity_indicators']

    # Build the summary text
    if symptoms:
        main_symptoms = symptoms[:3]
        symptom_text = ", ".join(main_symptoms)

        if severity:
            severity_text = severity[0]
            summary = f"Patient experiencing {severity_text} {symptom_text}"
        else:
            summary = f"Patient experiencing {symptom_text}"
    else:
        # Enhanced fallback with better context
        if any(word in text.lower() for word in ['chest pain', 'heart', 'cardiac']):
            summary = "Patient with cardiac symptoms"
        elif any(word in text.lower() for word in ['fever', 'cough', 'breathing']):
            summary = "Patient with respiratory symptoms"
        elif any(word in text.lower() for word in ['pain', 'swelling', 'joint']):
            summary = "Patient with musculoskeletal symptoms"
        elif any(word in text.lower() for word in ['rash', 'skin', 'itching']):
            summary = "Patient with dermatological symptoms"
        elif any(word in text.lower() for word in ['vomiting', 'diarrhea', 'abdominal']):
            summary = "Patient with gastrointestinal symptoms"
        else:
            summary = "Patient requiring medical evaluation"

    # Add duration if mentioned
    duration_pattern = r'(\d+\s*(?:hour|day|week|month)s?)'
    duration_match = re.search(duration_pattern, text.lower())
    if duration_match:
        summary += f" for {duration_match.group(1)}"

    # Add urgency context
    if urgency == 'high':
        summary += " with emergency symptoms requiring immediate attention"
    elif urgency == 'medium':
        summary += " with concerning symptoms requiring evaluation"
    else:
        summary += " for routine assessment"

    return summary

def process_medical_conversations(df):
    """Process all medical conversations and create the training dataset"""
    training_data = []

    for idx, row in df.iterrows():
        text = row['transcription']
        original_specialty = row['specialty']

        # Extract information using spaCy
        extracted_info = extract_key_information(text)

        # Classify urgency
        urgency = classify_urgency(text, extracted_info)

        # Predict specialty
        specialty = predict_specialty(text, extracted_info)

        # Create concise summary
        concise_text = create_concise_summary(text, extracted_info, urgency, specialty)

        # Create the training example
        training_example = {
            "text": concise_text,
            "metadata": {
                "specialty": specialty,
                "urgency": urgency
            }
        }

        training_data.append(training_example)

    return training_data

# Process the dataset
print("Processing medical conversations...")
training_dataset = process_medical_conversations(df)

# Add synthetic cases for better balance
print("Adding synthetic medical cases for better dataset balance...")
synthetic_cases = create_synthetic_medical_cases()
training_dataset.extend(synthetic_cases)

# Shuffle the dataset for better training
random.shuffle(training_dataset)

# Save the enhanced dataset
output_file = '/enhanced_medical_training_data.json'
with open(output_file, 'w') as f:
    json.dump(training_dataset, f, indent=2)

print(f"\nEnhanced dataset saved to: {output_file}")
print(f"Total training examples created: {len(training_dataset)}")

# Print detailed statistics
urgency_counts = Counter([item['metadata']['urgency'] for item in training_dataset])
specialty_counts = Counter([item['metadata']['specialty'] for item in training_dataset])

print("\nüìä ENHANCED DATASET STATISTICS:")
print("=" * 50)
print(f"Urgency distribution:")
for urgency, count in urgency_counts.items():
    percentage = (count / len(training_dataset)) * 100
    print(f"  {urgency.upper()}: {count} cases ({percentage:.1f}%)")

print(f"\nSpecialty distribution:")
for specialty, count in specialty_counts.items():
    percentage = (count / len(training_dataset)) * 100
    print(f"  {specialty}: {count} cases ({percentage:.1f}%)")

# Display examples from each urgency level
print("\nüìù SAMPLE CASES FROM EACH URGENCY LEVEL:")
print("=" * 50)

# Show 2 examples from each urgency level
for urgency_level in ['low', 'medium', 'high']:
    print(f"\n{urgency_level.upper()} URGENCY EXAMPLES:")
    urgency_cases = [item for item in training_dataset if item['metadata']['urgency'] == urgency_level]
    for i, case in enumerate(urgency_cases[:2]):
        print(f"  {i+1}. {case['text']}")
        print(f"     Specialty: {case['metadata']['specialty']}")

# Create a balanced dataset report
print(f"\nüéØ DATASET BALANCE REPORT:")
print("=" * 50)
total_cases = len(training_dataset)
print(f"Total cases: {total_cases}")
print(f"Low urgency: {urgency_counts['low']} ({urgency_counts['low']/total_cases*100:.1f}%)")
print(f"Medium urgency: {urgency_counts['medium']} ({urgency_counts['medium']/total_cases*100:.1f}%)")
print(f"High urgency: {urgency_counts['high']} ({urgency_counts['high']/total_cases*100:.1f}%)")

# Save sample file
sample_output = training_dataset[:15]
sample_file = '/enhanced_sample_training_data.json'
with open(sample_file, 'w') as f:
    json.dump(sample_output, f, indent=2)

print(f"\nSample preview saved to: {sample_file}")
print("\n‚úÖ Enhanced dataset ready for model training!")

Loading spaCy medical model...
Medical spaCy model not found. Using basic English model...
Loading dataset...
Processing medical conversations...
Adding synthetic medical cases for better dataset balance...

Enhanced dataset saved to: /enhanced_medical_training_data.json
Total training examples created: 76

üìä ENHANCED DATASET STATISTICS:
Urgency distribution:
  HIGH: 33 cases (43.4%)
  LOW: 18 cases (23.7%)
  MEDIUM: 25 cases (32.9%)

Specialty distribution:
  Cardiology: 15 cases (19.7%)
  Musculoskeletal: 25 cases (32.9%)
  Respiratory: 9 cases (11.8%)
  General Medicine: 18 cases (23.7%)
  Gastroenterology: 9 cases (11.8%)

üìù SAMPLE CASES FROM EACH URGENCY LEVEL:

LOW URGENCY EXAMPLES:
  1. Minor joint discomfort that resolves with rest and over-the-counter pain relief
     Specialty: Musculoskeletal
  2. Medication refill request for chronic stable condition
     Specialty: General Medicine

MEDIUM URGENCY EXAMPLES:
  1. Gastroenteritis with vomiting and diarrhea for 12 hours

In [None]:
pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m84.1/84.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [None]:
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup, TrainingArguments, Trainer
)
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report, confusion_matrix
import pandas as pd
import numpy as np
import json
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os
import random
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from tqdm.auto import tqdm
from datasets import Dataset
import evaluate

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class EnhancedMedicalBERTTrainer:
    def __init__(self):
        self.model_name = 'emilyalsentzer/Bio_ClinicalBERT'
        self.tokenizer = None
        self.model = None
        self.urgency_labels = ['low', 'medium', 'high']
        self.class_weights = None

    def load_and_preprocess_data(self, data_path):
        """Load and preprocess data with enhanced balancing"""
        print("Loading training data...")
        with open(data_path, 'r') as f:
            data = json.load(f)

        print(f"Total samples: {len(data)}")

        # Analyze data distribution
        urgencies = [item['metadata']['urgency'] for item in data]
        urgency_counts = Counter(urgencies)
        print(f"Urgency distribution: {dict(urgency_counts)}")

        # Calculate class weights for imbalance handling
        self.calculate_class_weights(urgencies)

        # Enhanced stratified split
        train_data, val_data, test_data = self.enhanced_stratified_split(data)

        print(f"\nData splits:")
        print(f"Train: {len(train_data)}")
        print(f"Validation: {len(val_data)}")
        print(f"Test: {len(test_data)}")

        # Print distribution for each split
        for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
            split_urgencies = [item['metadata']['urgency'] for item in split_data]
            print(f"{split_name} distribution: {dict(Counter(split_urgencies))}")

        return train_data, val_data, test_data

    def calculate_class_weights(self, urgencies):
        """Calculate class weights for handling imbalance"""
        urgency_counts = Counter(urgencies)
        total_samples = len(urgencies)

        # Inverse frequency weighting
        weights = {}
        for label in self.urgency_labels:
            if label in urgency_counts:
                weights[label] = total_samples / (len(self.urgency_labels) * urgency_counts[label])
            else:
                weights[label] = 1.0  # Default weight if class missing

        # Convert to tensor for training
        self.class_weights = torch.tensor([weights[label] for label in self.urgency_labels]).float().to(device)
        print(f"Class weights: {weights}")

    def enhanced_stratified_split(self, data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        """Enhanced stratified split ensuring all classes in all splits"""
        # Group by urgency
        urgency_groups = {}
        for item in data:
            urgency = item['metadata']['urgency']
            if urgency not in urgency_groups:
                urgency_groups[urgency] = []
            urgency_groups[urgency].append(item)

        train_data, val_data, test_data = [], [], []

        for urgency, items in urgency_groups.items():
            # Shuffle items
            random.shuffle(items)
            n_items = len(items)

            # Calculate split indices
            train_end = int(train_ratio * n_items)
            val_end = train_end + int(val_ratio * n_items)

            # Ensure at least 1 sample in each split for small classes
            if n_items < 3:
                train_data.extend(items)
                continue

            train_data.extend(items[:train_end])
            val_data.extend(items[train_end:val_end])
            test_data.extend(items[val_end:])

        return train_data, val_data, test_data

    def create_balanced_dataloader(self, data, batch_size=8):
        """Create balanced dataloader with weighted sampling"""
        texts = [item['text'] for item in data]
        urgencies = [item['metadata']['urgency'] for item in data]
        labels = [self.urgency_labels.index(urg) for urg in urgencies]

        # Tokenize
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=256,
            return_tensors='pt'
        )

        dataset = TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            torch.tensor(labels)
        )

        # Calculate sample weights for balanced sampling
        class_counts = Counter(labels)
        sample_weights = [1.0 / class_counts[label] for label in labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    def prepare_standard_dataloader(self, data, batch_size=8):
        """Prepare standard dataloader without balancing"""
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        texts = [item['text'] for item in data]
        urgencies = [item['metadata']['urgency'] for item in data]
        labels = [self.urgency_labels.index(urg) for urg in urgencies]

        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=256,
            return_tensors='pt'
        )

        dataset = TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            torch.tensor(labels)
        )

        return DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def train_with_advanced_techniques(self, train_loader, val_loader, test_loader, output_dir):
        """Train with advanced techniques for better class handling"""
        print("Initializing Bio+Clinical BERT...")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=3,
            id2label={0: 'low', 1: 'medium', 2: 'high'},
            label2id={'low': 0, 'medium': 1, 'high': 2},
            attention_probs_dropout_prob=0.1,
            hidden_dropout_prob=0.1
        )
        self.model.to(device)
        print("‚úÖ Bio+Clinical BERT model loaded")

        # Advanced training parameters
        epochs = 15
        learning_rate = 2e-5
        warmup_ratio = 0.1

        # Optimizer with different learning rates for different layers
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                          if not any(nd in n for nd in no_decay) and 'classifier' not in n],
                'weight_decay': 0.01,
                'lr': learning_rate
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                          if any(nd in n for nd in no_decay) and 'classifier' not in n],
                'weight_decay': 0.0,
                'lr': learning_rate
            },
            {
                'params': [p for n, p in self.model.named_parameters() if 'classifier' in n],
                'weight_decay': 0.01,
                'lr': learning_rate * 2  # Higher LR for classifier
            }
        ]

        optimizer = AdamW(optimizer_grouped_parameters)
        total_steps = len(train_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(warmup_ratio * total_steps),
            num_training_steps=total_steps
        )

        # Training loop with early stopping
        training_losses = []
        val_accuracies = []
        val_f1_scores = []
        best_f1 = 0
        patience = 5
        patience_counter = 0

        print("Starting advanced training...")
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            total_train_loss = 0
            all_predictions = []
            all_true_labels = []

            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
            for batch in progress_bar:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                self.model.zero_grad()
                outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)

                # Custom loss with class weights
                loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
                logits = outputs.logits
                loss = loss_fct(logits.view(-1, 3), labels.view(-1))

                total_train_loss += loss.item()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                # Collect predictions for training metrics
                predictions = torch.argmax(logits, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_true_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

            avg_train_loss = total_train_loss / len(train_loader)
            training_losses.append(avg_train_loss)

            # Calculate training metrics
            train_accuracy = accuracy_score(all_true_labels, all_predictions)
            train_f1 = precision_recall_fscore_support(all_true_labels, all_predictions, average='weighted')[2]

            # Validation phase
            val_metrics = self.comprehensive_evaluation(val_loader)
            val_accuracy = val_metrics['accuracy']
            val_f1 = val_metrics['f1_weighted']
            val_accuracies.append(val_accuracy)
            val_f1_scores.append(val_f1)

            print(f'\nEpoch {epoch+1}:')
            print(f'  Train Loss: {avg_train_loss:.4f}')
            print(f'  Train Accuracy: {train_accuracy:.4f}, Train F1: {train_f1:.4f}')
            print(f'  Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}')
            print(f'  Val Precision: {val_metrics["precision_weighted"]:.4f}')
            print(f'  Val Recall: {val_metrics["recall_weighted"]:.4f}')

            # Print per-class metrics
            print(f'  Per-class F1: Low={val_metrics["f1_per_class"][0]:.4f}, '
                  f'Medium={val_metrics["f1_per_class"][1]:.4f}, '
                  f'High={val_metrics["f1_per_class"][2]:.4f}')

            # Early stopping based on F1 score
            if val_f1 > best_f1:
                best_f1 = val_f1
                patience_counter = 0
                self.model.save_pretrained(output_dir)
                self.tokenizer.save_pretrained(output_dir)
                print(f'  ‚úÖ Saved best model (F1: {val_f1:.4f})')
            else:
                patience_counter += 1
                print(f'  ‚è≥ No improvement ({patience_counter}/{patience})')

            if patience_counter >= patience:
                print(f'  üõë Early stopping at epoch {epoch+1}')
                break

        # Final evaluation on test set
        print("\nüî¨ Final Evaluation on Test Set:")
        test_metrics = self.comprehensive_evaluation(test_loader)
        self.plot_comprehensive_results(training_losses, val_accuracies, val_f1_scores, test_metrics)

        return training_losses, val_accuracies, val_f1_scores, test_metrics

    def comprehensive_evaluation(self, dataloader):
        """Comprehensive evaluation with multiple metrics"""
        self.model.eval()
        predictions, true_labels = [], []

        with torch.no_grad():
            for batch in dataloader:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                batch_predictions = torch.argmax(logits, dim=1)

                predictions.extend(batch_predictions.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        # Calculate comprehensive metrics
        accuracy = accuracy_score(true_labels, predictions)
        precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            true_labels, predictions, average='macro', zero_division=0
        )

        # Per-class metrics
        precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
            true_labels, predictions, average=None, zero_division=0
        )

        # Confusion matrix
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1, 2])

        return {
            'accuracy': accuracy,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_per_class': precision_per_class.tolist(),
            'recall_per_class': recall_per_class.tolist(),
            'f1_per_class': f1_per_class.tolist(),
            'support_per_class': support_per_class.tolist(),
            'confusion_matrix': cm.tolist(),
            'predictions': predictions,
            'true_labels': true_labels
        }

    def plot_comprehensive_results(self, train_losses, val_accuracies, val_f1_scores, test_metrics):
        """Plot comprehensive training results"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        # Plot 1: Training loss
        ax1.plot(train_losses, 'b-', linewidth=2, label='Training Loss')
        ax1.set_title('Training Loss Over Epochs')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

        # Plot 2: Validation metrics
        epochs = range(1, len(val_accuracies) + 1)
        ax2.plot(epochs, val_accuracies, 'g-', linewidth=2, label='Validation Accuracy')
        ax2.plot(epochs, val_f1_scores, 'r-', linewidth=2, label='Validation F1')
        ax2.set_title('Validation Metrics Over Epochs')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Score')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        # Plot 3: Confusion matrix
        cm = np.array(test_metrics['confusion_matrix'])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax3,
                   xticklabels=self.urgency_labels,
                   yticklabels=self.urgency_labels)
        ax3.set_title('Test Set Confusion Matrix')
        ax3.set_xlabel('Predicted')
        ax3.set_ylabel('Actual')

        # Plot 4: Per-class F1 scores
        classes = self.urgency_labels
        f1_scores = test_metrics['f1_per_class']
        colors = ['green', 'orange', 'red']
        bars = ax4.bar(classes, f1_scores, color=colors, alpha=0.7)
        ax4.set_title('Per-class F1 Scores on Test Set')
        ax4.set_ylabel('F1 Score')
        ax4.set_ylim(0, 1)

        # Add value labels on bars
        for bar, score in zip(bars, f1_scores):
            ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{score:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig('/comprehensive_training_results.png', dpi=300, bbox_inches='tight')
        plt.close()

    def analyze_class_performance(self, test_metrics):
        """Analyze performance for each class in detail"""
        print("\nüìä DETAILED CLASS PERFORMANCE ANALYSIS:")
        print("=" * 60)

        classes = self.urgency_labels
        precision = test_metrics['precision_per_class']
        recall = test_metrics['recall_per_class']
        f1 = test_metrics['f1_per_class']
        support = test_metrics['support_per_class']

        for i, class_name in enumerate(classes):
            print(f"\n{class_name.upper()} Urgency:")
            print(f"  Precision: {precision[i]:.4f}")
            print(f"  Recall:    {recall[i]:.4f}")
            print(f"  F1-Score:  {f1[i]:.4f}")
            print(f"  Support:   {support[i]} samples")

            # Performance interpretation
            if f1[i] >= 0.8:
                status = "‚úÖ EXCELLENT"
            elif f1[i] >= 0.7:
                status = "‚ö†Ô∏è GOOD"
            elif f1[i] >= 0.6:
                status = "üî∂ FAIR"
            else:
                status = "‚ùå NEEDS IMPROVEMENT"
            print(f"  Status:    {status}")

def main():
    print("üöÄ ENHANCED MEDICAL BERT TRAINING FOR ALL URGENCY CLASSES")
    print("=" * 70)

    # Initialize enhanced trainer
    trainer = EnhancedMedicalBERTTrainer()

    # Load enhanced dataset
    data_path = '/enhanced_medical_training_data.json'
    train_data, val_data, test_data = trainer.load_and_preprocess_data(data_path)

    # Prepare tokenizer
    trainer.tokenizer = AutoTokenizer.from_pretrained(trainer.model_name)

    # Create dataloaders
    print("\nüìö Preparing dataloaders...")
    train_loader = trainer.create_balanced_dataloader(train_data, batch_size=8)
    val_loader = trainer.prepare_standard_dataloader(val_data, batch_size=8)
    test_loader = trainer.prepare_standard_dataloader(test_data, batch_size=8)

    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

    # Train model
    output_dir = '/enhanced_medical_bert_model'
    os.makedirs(output_dir, exist_ok=True)

    print(f"\nüè• Starting Enhanced Training...")
    train_losses, val_accuracies, val_f1_scores, test_metrics = trainer.train_with_advanced_techniques(
        train_loader, val_loader, test_loader, output_dir
    )

    # Detailed analysis
    trainer.analyze_class_performance(test_metrics)

    # Generate comprehensive report
    generate_enhanced_report(test_metrics, output_dir)

    print("\n" + "=" * 70)
    print("üéØ ENHANCED TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 70)

def generate_enhanced_report(test_metrics, model_path):
    """Generate comprehensive performance report"""
    report = {
        'timestamp': str(pd.Timestamp.now()),
        'model': 'Enhanced Bio+Clinical BERT',
        'test_performance': {
            'accuracy': test_metrics['accuracy'],
            'precision_weighted': test_metrics['precision_weighted'],
            'recall_weighted': test_metrics['recall_weighted'],
            'f1_weighted': test_metrics['f1_weighted'],
            'precision_macro': test_metrics['precision_macro'],
            'recall_macro': test_metrics['recall_macro'],
            'f1_macro': test_metrics['f1_macro'],
        },
        'per_class_performance': {
            'low': {
                'precision': test_metrics['precision_per_class'][0],
                'recall': test_metrics['recall_per_class'][0],
                'f1': test_metrics['f1_per_class'][0],
                'support': test_metrics['support_per_class'][0]
            },
            'medium': {
                'precision': test_metrics['precision_per_class'][1],
                'recall': test_metrics['recall_per_class'][1],
                'f1': test_metrics['f1_per_class'][1],
                'support': test_metrics['support_per_class'][1]
            },
            'high': {
                'precision': test_metrics['precision_per_class'][2],
                'recall': test_metrics['recall_per_class'][2],
                'f1': test_metrics['f1_per_class'][2],
                'support': test_metrics['support_per_class'][2]
            }
        },
        'confusion_matrix': test_metrics['confusion_matrix'],
        'performance_interpretation': {
            'overall_quality': 'EXCELLENT' if test_metrics['f1_weighted'] > 0.85 else
                              'GOOD' if test_metrics['f1_weighted'] > 0.75 else
                              'FAIR' if test_metrics['f1_weighted'] > 0.65 else 'POOR',
            'class_balance_quality': 'BALANCED' if min(test_metrics['f1_per_class']) > 0.7 else
                                   'MODERATE' if min(test_metrics['f1_per_class']) > 0.6 else 'IMBALANCED',
            'recommendation': 'Ready for deployment' if test_metrics['f1_weighted'] > 0.8 and min(test_metrics['f1_per_class']) > 0.7 else
                            'Suitable for prototype' if test_metrics['f1_weighted'] > 0.7 else
                            'Needs improvement'
        }
    }

    with open('/enhanced_medical_bert_report.json', 'w') as f:
        json.dump(report, f, indent=2)

    print("\nüìä ENHANCED PERFORMANCE REPORT:")
    print("=" * 50)
    print(f"Overall F1-Score: {test_metrics['f1_weighted']:.4f}")
    print(f"Overall Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Macro F1-Score: {test_metrics['f1_macro']:.4f}")

    print(f"\nPer-class Performance:")
    print(f"  Low:    F1={test_metrics['f1_per_class'][0]:.4f}")
    print(f"  Medium: F1={test_metrics['f1_per_class'][1]:.4f}")
    print(f"  High:   F1={test_metrics['f1_per_class'][2]:.4f}")

    print(f"\nStatus: {report['performance_interpretation']['overall_quality']}")
    print(f"Recommendation: {report['performance_interpretation']['recommendation']}")

if __name__ == "__main__":
    main()

Using device: cpu
üöÄ ENHANCED MEDICAL BERT TRAINING FOR ALL URGENCY CLASSES
Loading training data...
Total samples: 76
Urgency distribution: {'high': 33, 'low': 18, 'medium': 25}
Class weights: {'low': 1.4074074074074074, 'medium': 1.0133333333333334, 'high': 0.7676767676767676}

Data splits:
Train: 52
Validation: 9
Test: 15
Train distribution: {'high': 23, 'low': 12, 'medium': 17}
Validation distribution: {'high': 4, 'low': 2, 'medium': 3}
Test distribution: {'high': 6, 'low': 4, 'medium': 5}


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]


üìö Preparing dataloaders...
Train batches: 7, Val batches: 2, Test batches: 2

üè• Starting Enhanced Training...
Initializing Bio+Clinical BERT...


pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Bio+Clinical BERT model loaded
Starting advanced training...


Epoch 1/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 1:
  Train Loss: 1.1017
  Train Accuracy: 0.4423, Train F1: 0.4223
  Val Accuracy: 0.5556, Val F1: 0.4343
  Val Precision: 0.3651
  Val Recall: 0.5556
  Per-class F1: Low=0.5000, Medium=0.0000, High=0.7273
  ‚úÖ Saved best model (F1: 0.4343)


Epoch 2/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 2:
  Train Loss: 0.8673
  Train Accuracy: 0.6538, Train F1: 0.6127
  Val Accuracy: 0.6667, Val F1: 0.6380
  Val Precision: 0.8095
  Val Recall: 0.6667
  Per-class F1: Low=0.6667, Medium=0.5000, High=0.7273
  ‚úÖ Saved best model (F1: 0.6380)


Epoch 3/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 3:
  Train Loss: 0.8240
  Train Accuracy: 0.6346, Train F1: 0.6109
  Val Accuracy: 0.7778, Val F1: 0.7704
  Val Precision: 0.8519
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.8000, High=0.8000
  ‚úÖ Saved best model (F1: 0.7704)


Epoch 4/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 4:
  Train Loss: 0.6122
  Train Accuracy: 0.9231, Train F1: 0.9232
  Val Accuracy: 0.7778, Val F1: 0.7654
  Val Precision: 0.8000
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.6667, High=0.8889
  ‚è≥ No improvement (1/5)


Epoch 5/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 5:
  Train Loss: 0.4748
  Train Accuracy: 0.9615, Train F1: 0.9609
  Val Accuracy: 0.7778, Val F1: 0.7654
  Val Precision: 0.8000
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.6667, High=0.8889
  ‚è≥ No improvement (2/5)


Epoch 6/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 6:
  Train Loss: 0.4662
  Train Accuracy: 0.8269, Train F1: 0.8245
  Val Accuracy: 0.7778, Val F1: 0.7654
  Val Precision: 0.8000
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.6667, High=0.8889
  ‚è≥ No improvement (3/5)


Epoch 7/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 7:
  Train Loss: 0.3508
  Train Accuracy: 0.8654, Train F1: 0.8649
  Val Accuracy: 0.7778, Val F1: 0.7704
  Val Precision: 0.8519
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.8000, High=0.8000
  ‚è≥ No improvement (4/5)


Epoch 8/15 [Train]:   0%|          | 0/7 [00:00<?, ?it/s]


Epoch 8:
  Train Loss: 0.2763
  Train Accuracy: 0.9615, Train F1: 0.9615
  Val Accuracy: 0.7778, Val F1: 0.7704
  Val Precision: 0.8519
  Val Recall: 0.7778
  Per-class F1: Low=0.6667, Medium=0.8000, High=0.8000
  ‚è≥ No improvement (5/5)
  üõë Early stopping at epoch 8

üî¨ Final Evaluation on Test Set:

üìä DETAILED CLASS PERFORMANCE ANALYSIS:

LOW Urgency:
  Precision: 0.7500
  Recall:    0.7500
  F1-Score:  0.7500
  Support:   4 samples
  Status:    ‚ö†Ô∏è GOOD

MEDIUM Urgency:
  Precision: 0.3333
  Recall:    0.4000
  F1-Score:  0.3636
  Support:   5 samples
  Status:    ‚ùå NEEDS IMPROVEMENT

HIGH Urgency:
  Precision: 0.6000
  Recall:    0.5000
  F1-Score:  0.5455
  Support:   6 samples
  Status:    ‚ùå NEEDS IMPROVEMENT

üìä ENHANCED PERFORMANCE REPORT:
Overall F1-Score: 0.5394
Overall Accuracy: 0.5333
Macro F1-Score: 0.5530

Per-class Performance:
  Low:    F1=0.7500
  Medium: F1=0.3636
  High:   F1=0.5455

Status: POOR
Recommendation: Needs improvement

üéØ ENHANCED TRA

In [None]:
# verification_script.py
import json

# Load and verify the generated dataset
def verify_dataset(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)

    print(f"Verifying dataset: {file_path}")
    print(f"Total examples: {len(data)}")

    # Check required format
    required_keys = ['text', 'metadata']
    required_metadata_keys = ['specialty', 'urgency']

    valid_count = 0
    for i, item in enumerate(data):
        # Check main keys
        if not all(key in item for key in required_keys):
            print(f"‚ùå Example {i}: Missing required keys")
            continue

        # Check metadata keys
        if not all(key in item['metadata'] for key in required_metadata_keys):
            print(f"‚ùå Example {i}: Missing metadata keys")
            continue

        # Check data types
        if not isinstance(item['text'], str) or len(item['text']) == 0:
            print(f"‚ùå Example {i}: Invalid text")
            continue

        if not isinstance(item['metadata']['specialty'], str) or not isinstance(item['metadata']['urgency'], str):
            print(f"‚ùå Example {i}: Invalid metadata types")
            continue

        valid_count += 1

    print(f"‚úÖ Valid examples: {valid_count}/{len(data)}")

    # Show samples
    print("\nSample valid examples:")
    for i, item in enumerate(data[:3]):
        print(f"{i+1}. {item}")

# Verify the main dataset
verify_dataset('/spacy_medical_training_data.json')

Verifying dataset: /spacy_medical_training_data.json
Total examples: 30
‚úÖ Valid examples: 30/30

Sample valid examples:
1. {'text': 'Patient experiencing chest pain with emergency symptoms', 'metadata': {'specialty': 'Musculoskeletal', 'urgency': 'high'}}
2. {'text': 'Patient experiencing chest pain with emergency symptoms', 'metadata': {'specialty': 'Musculoskeletal', 'urgency': 'high'}}
3. {'text': 'Patient reporting pain symptoms with concerning symptoms', 'metadata': {'specialty': 'Gastroenterology', 'urgency': 'medium'}}


In [None]:
# medical_bert_training_fixed.py
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report
import pandas as pd
import numpy as np
import json
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os
import random
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def convert_to_serializable(obj):
    """Convert numpy types to Python native types for JSON serialization"""
    if isinstance(obj, (np.integer, np.int64, np.int32)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float64, np.float32)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    else:
        return obj

class MedicalBERTTrainer:
    def __init__(self):
        # Using Bio+Clinical BERT specifically trained on medical texts
        self.model_name = 'emilyalsentzer/Bio_ClinicalBERT'
        self.tokenizer = None
        self.model = None
        self.urgency_labels = ['low', 'medium', 'high']

    def load_and_preprocess_data(self, data_path):
        """Load and preprocess the medical training data"""
        print("Loading training data...")
        with open(data_path, 'r') as f:
            data = json.load(f)

        print(f"Total samples: {len(data)}")

        # Analyze data distribution
        urgencies = [item['metadata']['urgency'] for item in data]
        specialties = [item['metadata']['specialty'] for item in data]

        print(f"Urgencies: {dict(Counter(urgencies))}")
        print(f"Specialties: {dict(Counter(specialties))}")

        # Stratified split to ensure all classes are represented
        train_data, val_data = self.stratified_split(data)

        print(f"Train: {len(train_data)}, Validation: {len(val_data)}")
        print(f"Train distribution: {dict(Counter([item['metadata']['urgency'] for item in train_data]))}")
        print(f"Val distribution: {dict(Counter([item['metadata']['urgency'] for item in val_data]))}")

        return train_data, val_data

    def stratified_split(self, data, val_ratio=0.2):
        """Stratified split to maintain class distribution"""
        # Group by urgency
        urgency_groups = {}
        for item in data:
            urgency = item['metadata']['urgency']
            if urgency not in urgency_groups:
                urgency_groups[urgency] = []
            urgency_groups[urgency].append(item)

        train_data, val_data = [], []

        for urgency, items in urgency_groups.items():
            # Shuffle items
            random.shuffle(items)
            # Calculate split index
            split_idx = max(1, int(len(items) * (1 - val_ratio)))  # Ensure at least 1 in val
            train_data.extend(items[:split_idx])
            val_data.extend(items[split_idx:])

        return train_data, val_data

    def augment_medical_data(self, data):
        """Medical-specific data augmentation"""
        augmented_data = data.copy()

        # Check current distribution
        urgency_counts = Counter([item['metadata']['urgency'] for item in data])
        print(f"Original distribution: {urgency_counts}")

        # Medical-specific augmentation for each urgency level
        for urgency_level in self.urgency_labels:
            current_count = urgency_counts.get(urgency_level, 0)
            samples = [item for item in data if item['metadata']['urgency'] == urgency_level]

            # Target minimum samples per class
            target_min = 8
            if current_count < target_min and samples:
                needed = target_min - current_count
                for i in range(needed):
                    if len(augmented_data) >= 40:  # Max total samples
                        break
                    sample = random.choice(samples)
                    new_sample = self.create_medical_augmented_sample(sample, urgency_level)
                    augmented_data.append(new_sample)

        print(f"After augmentation - Total: {len(augmented_data)}")
        print(f"Augmented distribution: {dict(Counter([item['metadata']['urgency'] for item in augmented_data]))}")
        return augmented_data

    def create_medical_augmented_sample(self, sample, urgency_level):
        """Create medically relevant augmented samples"""
        text = sample['text']
        specialty = sample['metadata']['specialty']

        # Medical-specific augmentation patterns
        if urgency_level == 'high':
            augmentations = [
                f"Emergency presentation: {text}",
                f"Critical condition with {text}",
                f"Urgent medical attention required for {text}",
                text.replace('pain', 'severe acute pain') if 'pain' in text.lower() else f"Acute {text}"
            ]
        elif urgency_level == 'medium':
            augmentations = [
                f"Patient presents with {text}",
                f"Clinical evaluation for {text}",
                f"Medical consultation regarding {text}",
                text.replace('mild', 'moderate') if 'mild' in text.lower() else f"Moderate {text}"
            ]
        else:  # low
            augmentations = [
                f"Routine medical follow-up: {text}",
                f"Preventive care consultation for {text}",
                f"Stable condition with {text}",
                text.replace('severe', 'mild') if 'severe' in text.lower() else f"Mild {text}"
            ]

        new_text = random.choice(augmentations)

        return {
            'text': new_text[:200],  # Reasonable length for medical texts
            'metadata': {
                'specialty': specialty,
                'urgency': urgency_level
            }
        }

    def prepare_medical_dataloader(self, data, batch_size=4):
        """Prepare DataLoader for medical BERT training"""
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            print("‚úÖ Bio+Clinical BERT tokenizer loaded")

        texts = [item['text'] for item in data]
        urgencies = [item['metadata']['urgency'] for item in data]

        # Convert to numerical labels
        labels = [self.urgency_labels.index(urg) for urg in urgencies]

        # Tokenize with medical BERT
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=256,
            return_tensors='pt'
        )

        dataset = TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            torch.tensor(labels)
        )

        return DataLoader(dataset, batch_size=batch_size, shuffle=True)

    def train_medical_bert(self, train_loader, val_loader, output_dir):
        """Train the Bio+Clinical BERT model"""
        print("Initializing Bio+Clinical BERT...")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=3,
            id2label={0: 'low', 1: 'medium', 2: 'high'},
            label2id={'low': 0, 'medium': 1, 'high': 2}
        )
        self.model.to(device)
        print("‚úÖ Bio+Clinical BERT model loaded and ready for training")

        # Training parameters optimized for medical BERT
        epochs = 10
        learning_rate = 1e-5

        # Optimizer
        optimizer = AdamW(self.model.parameters(), lr=learning_rate)
        total_steps = len(train_loader) * epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )

        # Training loop
        print("Starting Bio+Clinical BERT training...")
        training_losses = []
        val_accuracies = []
        best_accuracy = 0

        for epoch in range(epochs):
            # Training
            self.model.train()
            total_train_loss = 0

            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Medical BERT]')
            for batch in progress_bar:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                self.model.zero_grad()

                outputs = self.model(
                    input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                total_train_loss += loss.item()

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

            avg_train_loss = total_train_loss / len(train_loader)
            training_losses.append(avg_train_loss)

            # Validation
            val_accuracy = self.evaluate_medical_model(val_loader)
            val_accuracies.append(val_accuracy)

            print(f'Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Accuracy = {val_accuracy:.4f}')

            # Save best model
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                self.model.save_pretrained(output_dir)
                self.tokenizer.save_pretrained(output_dir)
                print(f"‚úÖ Saved best model (accuracy: {val_accuracy:.4f})")

        # Plot training history
        self.plot_medical_training_history(training_losses, val_accuracies)

        return training_losses, val_accuracies

    def evaluate_medical_model(self, val_loader):
        """Evaluate model on validation set"""
        self.model.eval()
        predictions, true_labels = [], []

        with torch.no_grad():
            for batch in val_loader:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        return accuracy_score(true_labels, predictions)

    def plot_medical_training_history(self, train_losses, val_accuracies):
        """Plot training history for medical model"""
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(train_losses, 'b-', label='Training Loss', linewidth=2)
        plt.title('Bio+Clinical BERT - Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
        plt.title('Bio+Clinical BERT - Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.savefig('/medical_bert_training_history.png',
                   dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()

    def comprehensive_medical_evaluation(self, val_loader):
        """Comprehensive medical-specific evaluation"""
        self.model.eval()
        predictions, true_labels = [], []

        with torch.no_grad():
            for batch in val_loader:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                predictions.extend(torch.argmax(logits, dim=1).cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        # Handle available classes
        unique_labels = np.unique(true_labels)
        available_labels = [self.urgency_labels[i] for i in unique_labels]

        # Calculate comprehensive metrics
        accuracy = float(accuracy_score(true_labels, predictions))
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )

        # Convert to Python native types
        precision = float(precision)
        recall = float(recall)
        f1 = float(f1)

        # Create classification report with only available classes
        target_names = [self.urgency_labels[i] for i in unique_labels]
        class_report = classification_report(
            true_labels, predictions,
            target_names=target_names,
            output_dict=True,
            zero_division=0
        )

        # Convert class_report to serializable
        class_report = convert_to_serializable(class_report)

        # Confusion matrix
        cm = np.zeros((3, 3))
        for true, pred in zip(true_labels, predictions):
            cm[true][pred] += 1

        # Plot medical-themed confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='.0f', cmap='RdYlGn_r',
                   xticklabels=self.urgency_labels,
                   yticklabels=self.urgency_labels,
                   cbar_kws={'label': 'Number of Cases'})
        plt.title('Medical Urgency Classification - Confusion Matrix\n(Bio+Clinical BERT)',
                 fontsize=14, fontweight='bold', pad=20)
        plt.ylabel('True Medical Urgency', fontweight='bold')
        plt.xlabel('Predicted Medical Urgency', fontweight='bold')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig('/medical_bert_confusion_matrix.png',
                   dpi=300, bbox_inches='tight', facecolor='white')
        plt.close()

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'available_classes': available_labels,
            'class_report': class_report,
            'confusion_matrix': cm.tolist()
        }

class MedicalBERTPredictor:
    def __init__(self, model_path='/medical_bert_model'):
        try:
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model.to(device)
            self.model.eval()
            self.urgency_labels = ['low', 'medium', 'high']
            print("‚úÖ Bio+Clinical BERT model loaded successfully!")
        except Exception as e:
            print(f"‚ùå Error loading Bio+Clinical BERT model: {e}")
            self.model = None
            self.tokenizer = None

    def predict_medical_urgency(self, text, confidence_threshold=0.6):
        """Predict medical urgency using Bio+Clinical BERT"""
        if self.model is None or self.tokenizer is None:
            return self.medical_rule_based_prediction(text)

        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=256,
            padding=True
        )

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0][predicted_class].item()

        result = {
            'urgency': self.urgency_labels[predicted_class],
            'confidence': confidence,
            'probabilities': {
                'low': float(probabilities[0][0]),
                'medium': float(probabilities[0][1]),
                'high': float(probabilities[0][2])
            },
            'model': 'Bio+Clinical BERT'
        }

        if confidence < confidence_threshold:
            result['low_confidence'] = True

        return result

    def medical_rule_based_prediction(self, text):
        """Medical rule-based fallback with clinical terminology"""
        text_lower = text.lower()

        emergency_terms = [
            'chest pain', 'shortness of breath', 'difficulty breathing',
            'stroke', 'heart attack', 'severe pain', 'crushing pain'
        ]

        urgent_terms = [
            'fever', 'high fever', 'persistent cough', 'abdominal pain',
            'vomiting', 'infection', 'moderate pain'
        ]

        routine_terms = [
            'routine', 'follow-up', 'check-up', 'mild pain', 'chronic', 'stable'
        ]

        emergency_count = sum(1 for term in emergency_terms if term in text_lower)
        urgent_count = sum(1 for term in urgent_terms if term in text_lower)
        routine_count = sum(1 for term in routine_terms if term in text_lower)

        if emergency_count > 0:
            return {'urgency': 'high', 'confidence': 0.90, 'fallback': True}
        elif urgent_count > 0:
            return {'urgency': 'medium', 'confidence': 0.80, 'fallback': True}
        elif routine_count > 0:
            return {'urgency': 'low', 'confidence': 0.85, 'fallback': True}
        else:
            return {'urgency': 'medium', 'confidence': 0.70, 'fallback': True}

def main():
    print("üöÄ STARTING BIO+CLINICAL BERT MEDICAL CLASSIFIER TRAINING")
    print("=" * 70)

    # Initialize medical BERT trainer
    trainer = MedicalBERTTrainer()

    # Load data
    data_path = '/spacy_medical_training_data.json'
    train_data, val_data = trainer.load_and_preprocess_data(data_path)

    # Apply medical-specific augmentation
    print("\nü©∫ Applying medical-specific data augmentation...")
    train_data_augmented = trainer.augment_medical_data(train_data)

    # Prepare medical data loaders
    print("üìö Preparing medical data loaders...")
    train_loader = trainer.prepare_medical_dataloader(train_data_augmented, batch_size=4)
    val_loader = trainer.prepare_medical_dataloader(val_data, batch_size=4)

    print(f"   Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Train medical BERT model
    output_dir = '/medical_bert_model'
    os.makedirs(output_dir, exist_ok=True)

    print(f"\nüè• Training Bio+Clinical BERT for Medical Urgency Classification...")
    train_losses, val_accuracies = trainer.train_medical_bert(train_loader, val_loader, output_dir)

    # Comprehensive medical evaluation
    print("\nüî¨ Running comprehensive medical evaluation...")
    metrics = trainer.comprehensive_medical_evaluation(val_loader)

    # Generate medical performance report
    generate_medical_report(metrics, output_dir)

    # Test the medical model
    test_medical_model()

def generate_medical_report(metrics, model_path):
    """Generate comprehensive medical performance report"""
    # Convert all metrics to serializable types
    metrics = convert_to_serializable(metrics)

    report = {
        'timestamp': str(pd.Timestamp.now()),
        'model': 'Bio+Clinical BERT',
        'model_description': 'BERT model pre-trained on biomedical and clinical texts',
        'performance_metrics': {
            'accuracy': metrics['accuracy'],
            'precision': metrics['precision'],
            'recall': metrics['recall'],
            'f1_score': metrics['f1'],
        },
        'available_classes': metrics['available_classes'],
        'class_performance': metrics['class_report'],
        'confusion_matrix': metrics['confusion_matrix'],
        'medical_suitability': {
            'domain': 'Medical NLP - Urgency Classification',
            'training_data': 'Clinical conversations and medical texts',
            'suitable_for': ['Triage systems', 'Medical alert systems', 'Clinical decision support'],
            'limitations': ['Small training dataset', 'Requires medical validation']
        },
        'requirements_validation': {
            'profiling_accuracy': {
                'target': 'F1-score > 0.85',
                'achieved': metrics['f1'],
                'status': 'PASS' if metrics['f1'] > 0.85 else 'NEEDS_IMPROVEMENT'
            },
            'clinical_reliability': {
                'high_risk_recall': 'Critical for patient safety',
                'false_positive_rate': 'Should be minimized',
                'status': 'MEDICAL_VALIDATION_REQUIRED'
            }
        }
    }

    with open('/medical_bert_performance_report.json', 'w') as f:
        json.dump(report, f, indent=2)

    print("\nüìä MEDICAL BERT PERFORMANCE REPORT:")
    print("=" * 50)
    print(f"üè• Model: Bio+Clinical BERT")
    print(f"üìà Accuracy: {metrics['accuracy']:.4f}")
    print(f"üéØ F1-Score: {metrics['f1']:.4f}")
    print(f"üìç Precision: {metrics['precision']:.4f}")
    print(f"üîç Recall: {metrics['recall']:.4f}")
    print(f"üìã Available classes: {metrics['available_classes']}")

    if metrics['f1'] > 0.85:
        print("‚úÖ Profiling Accuracy: EXCELLENT (F1 > 0.85)")
    elif metrics['f1'] > 0.75:
        print("‚ö†Ô∏è Profiling Accuracy: GOOD (F1 > 0.75) - Suitable for prototype")
    else:
        print("üî¥ Profiling Accuracy: NEEDS IMPROVEMENT - Collect more medical data")

    print(f"\nüí° Medical Application Ready!")
    print("   Use for: Medical triage, Alert systems, Clinical support")

def test_medical_model():
    """Test the medical BERT model with clinical cases"""
    print("\nüß™ TESTING BIO+CLINICAL BERT MODEL")
    print("=" * 60)

    predictor = MedicalBERTPredictor()

    # Clinical test cases
    clinical_test_cases = [
        "Patient presents with acute chest pain radiating to left arm",
        "Fever of 102¬∞F with productive cough for 3 days",
        "Routine diabetes follow-up for medication adjustment",
        "Severe headache with photophobia and neck stiffness",
        "Mild seasonal allergies with sneezing"
    ]

    print("Clinical Case Predictions:")
    print("-" * 60)

    for i, case in enumerate(clinical_test_cases, 1):
        result = predictor.predict_medical_urgency(case)

        print(f"\n{i}. {case}")
        print(f"   üè• Urgency: {result['urgency'].upper()} (confidence: {result['confidence']:.3f})")

        if not result.get('fallback', False):
            print(f"   üìä Model: {result.get('model', 'Bio+Clinical BERT')}")
            print(f"   üìà Probabilities - Low: {result['probabilities']['low']:.3f}, "
                  f"Medium: {result['probabilities']['medium']:.3f}, "
                  f"High: {result['probabilities']['high']:.3f}")
        else:
            print("   ‚ö†Ô∏è Using medical rule-based fallback")

if __name__ == "__main__":
    main()

    print("\n" + "=" * 70)
    print("üéØ BIO+CLINICAL BERT TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 70)
    print("\nüìÅ Medical Model saved in: /kaggle/working/medical_bert_model/")
    print("üìä Medical Report: /kaggle/working/medical_bert_performance_report.json")
    print("üìà Training History: /kaggle/working/medical_bert_training_history.png")
    print("üîç Confusion Matrix: /kaggle/working/medical_bert_confusion_matrix.png")
    print("\n‚úÖ Medical NLP Model is ready for clinical applications!")

Using device: cpu
üöÄ STARTING BIO+CLINICAL BERT MEDICAL CLASSIFIER TRAINING
Loading training data...
Total samples: 30
Urgencies: {'high': 25, 'medium': 5}
Specialties: {'Musculoskeletal': 16, 'Gastroenterology': 4, 'Cardiology': 5, 'Respiratory': 5}
Train: 24, Validation: 6
Train distribution: {'high': 20, 'medium': 4}
Val distribution: {'high': 5, 'medium': 1}

ü©∫ Applying medical-specific data augmentation...
Original distribution: Counter({'high': 20, 'medium': 4})
After augmentation - Total: 28
Augmented distribution: {'high': 20, 'medium': 8}
üìö Preparing medical data loaders...
‚úÖ Bio+Clinical BERT tokenizer loaded
   Train batches: 7, Val batches: 2

üè• Training Bio+Clinical BERT for Medical Urgency Classification...
Initializing Bio+Clinical BERT...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Bio+Clinical BERT model loaded and ready for training
Starting Bio+Clinical BERT training...


Epoch 1/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 1: Train Loss = 1.3528, Val Accuracy = 0.1667
‚úÖ Saved best model (accuracy: 0.1667)


Epoch 2/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 2: Train Loss = 1.0064, Val Accuracy = 0.5000
‚úÖ Saved best model (accuracy: 0.5000)


Epoch 3/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 3: Train Loss = 0.8720, Val Accuracy = 0.8333
‚úÖ Saved best model (accuracy: 0.8333)


Epoch 4/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 4: Train Loss = 0.7703, Val Accuracy = 0.8333


Epoch 5/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 5: Train Loss = 0.6522, Val Accuracy = 0.8333


Epoch 6/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 6: Train Loss = 0.5870, Val Accuracy = 0.8333


Epoch 7/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 7: Train Loss = 0.5284, Val Accuracy = 0.8333


Epoch 8/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 8: Train Loss = 0.5102, Val Accuracy = 0.8333


Epoch 9/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch 9: Train Loss = 0.4522, Val Accuracy = 1.0000
‚úÖ Saved best model (accuracy: 1.0000)


Epoch 10/10 [Medical BERT]:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Mount your Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Create the target folder if it doesn't exist
!mkdir -p /content/drive/MyDrive/robi

# Copy the entire model directory to Drive
!cp -r medical_bert_model /content/drive/MyDrive/robi/

print("‚úÖ Model successfully copied to Google Drive at: /MyDrive/robi/medical_bert_model")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
cp: cannot stat 'medical_bert_model': No such file or directory
‚úÖ Model successfully copied to Google Drive at: /MyDrive/robi/medical_bert_model


In [None]:
# ============================================
# üì¶ Zip and Download Trained Model Folder
# ============================================

import shutil
from google.colab import files

# Change this if your model folder name or path is different
MODEL_DIR = "/medical_bert_model"
ZIP_PATH = f"{MODEL_DIR}.zip"

# 1Ô∏è‚É£ Zip the entire folder
shutil.make_archive(MODEL_DIR, 'zip', MODEL_DIR)
print(f"‚úÖ Zipped folder created at: {ZIP_PATH}")

# 2Ô∏è‚É£ Download the zip file to your local machine
files.download(ZIP_PATH)


‚úÖ Zipped folder created at: /medical_bert_model.zip


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
import pandas as pd
import json
import re
import spacy
from collections import Counter
import random

# Load spaCy medical model
print("Loading spaCy medical model...")
try:
    nlp = spacy.load("en_core_sci_md")
    print("Medical spaCy model loaded successfully!")
except OSError:
    print("Medical spaCy model not found. Using basic English model...")
    nlp = spacy.load("en_core_web_sm")

# Load your dataset
print("Loading dataset...")
df = pd.read_csv('/dataset_summary.csv')

# Enhanced urgency classification rules
URGENCY_RULES = {
    'high': [
        # Cardiac emergencies
        r'\b(severe|crushing|sharp) chest pain\b',
        r'\bheart attack\b',
        r'\bcardiac arrest\b',
        r'\bshortness of breath\b',
        r'\bdifficulty breathing\b',
        r'\btrouble breathing\b',
        r'\blightheaded\b',
        r'\bfainting\b',
        r'\bloss of consciousness\b',
        r'\bpassing out\b',
        r'\bradiating pain\b',
        r'\bpain spreading\b',
        r'\bsevere pain\b',
        r'\b10/10 pain\b',
        r'\b9/10 pain\b',
        r'\b8/10 pain\b',
        r'\bworst pain\b',

        # Other emergencies
        r'\bsevere bleeding\b',
        r'\buncontrolled bleeding\b',
        r'\bstroke\b',
        r'\bnumbness\b',
        r'\bweakness\b',
        r'\bparalysis\b',
        r'\bvision loss\b',
        r'\bsevere headache\b',
        r'\bseizure\b',
        r'\bunresponsive\b',
        r'\bcardiac\b',
        r'\bemergency\b',
        r'\bcritical\b'
    ],
    'medium': [
        r'\bchest pain\b',
        r'\bfever\b',
        r'\bpersistent cough\b',
        r'\bworsening symptoms\b',
        r'\babdominal pain\b',
        r'\bvomiting\b',
        r'\bdiarrhea\b',
        r'\bdehydration\b',
        r'\bmoderate pain\b',
        r'\b7/10 pain\b',
        r'\b6/10 pain\b',
        r'\b5/10 pain\b',
        r'\binfection\b',
        r'\binflammatory\b',
        r'\bswelling\b',
        r'\bredness\b',
        r'\bdizziness\b',
        r'\bnausea\b',
        r'\bheadache\b',
        r'\bfatigue\b',
        r'\bweakness\b',
        r'\bconcern\b',
        r'\bworsening\b'
    ],
    'low': [
        r'\broutine\b',
        r'\bfollow.up\b',
        r'\bcheck.up\b',
        r'\bmanagement\b',
        r'\breview\b',
        r'\bmild pain\b',
        r'\bchronic condition\b',
        r'\bstable\b',
        r'\bpreventive care\b',
        r'\bvaccination\b',
        r'\bscreening\b',
        r'\bcold symptoms\b',
        r'\bmild cough\b',
        r'\brunny nose\b',
        r'\b1/10 pain\b',
        r'\b2/10 pain\b',
        r'\b3/10 pain\b',
        r'\b4/10 pain\b',
        r'\bannual\b',
        r'\bphysical\b',
        r'\bexam\b',
        r'\basymptomatic\b',
        r'\bwellness\b'
    ]
}

# Enhanced specialty mapping
SPECIALTY_KEYWORDS = {
    'Cardiology': [
        'chest', 'heart', 'cardiac', 'breathing', 'palpitations',
        'blood pressure', 'hypertension', 'cholesterol', 'ecg',
        'heartbeat', 'arrhythmia', 'angina', 'cardiovascular'
    ],
    'Gastroenterology': [
        'stomach', 'abdominal', 'vomiting', 'diarrhea', 'nausea',
        'bowel', 'digestive', 'constipation', 'indigestion',
        'acid reflux', 'gerd', 'ibs', 'gallbladder', 'liver'
    ],
    'Musculoskeletal': [
        'pain', 'joint', 'muscle', 'elbow', 'shoulder', 'knee',
        'back', 'swelling', 'tendon', 'ligament', 'fracture',
        'arthritis', 'osteoporosis', 'sprain', 'strain'
    ],
    'Dermatology': [
        'rash', 'skin', 'itching', 'redness', 'lesion', 'acne',
        'eczema', 'dermatitis', 'psoriasis', 'hives', 'blister',
        'mole', 'skin cancer', 'sunburn'
    ],
    'Respiratory': [
        'cough', 'breathing', 'wheezing', 'lungs', 'respiratory',
        'asthma', 'pneumonia', 'bronchitis', 'copd', 'shortness',
        'oxygen', 'inhaler'
    ],
    'General Medicine': [
        'fever', 'fatigue', 'general', 'routine', 'check.up',
        'wellness', 'physical', 'annual', 'preventive'
    ]
}

def analyze_dataset_balance(df):
    """Analyze the current dataset distribution"""
    print("Analyzing dataset balance...")

    # Sample analysis of original dataset
    urgency_distribution = {'high': 0, 'medium': 0, 'low': 0}
    specialty_distribution = {}

    for idx, row in df.iterrows():
        text = row['transcription'].lower()

        # Simple urgency classification for analysis
        if any(re.search(pattern, text) for pattern in URGENCY_RULES['high']):
            urgency_distribution['high'] += 1
        elif any(re.search(pattern, text) for pattern in URGENCY_RULES['medium']):
            urgency_distribution['medium'] += 1
        else:
            urgency_distribution['low'] += 1

        # Specialty distribution
        specialty = row.get('specialty', 'Unknown')
        specialty_distribution[specialty] = specialty_distribution.get(specialty, 0) + 1

    total = len(df)
    print(f"Original dataset distribution:")
    print(f"High urgency: {urgency_distribution['high']} ({urgency_distribution['high']/total*100:.1f}%)")
    print(f"Medium urgency: {urgency_distribution['medium']} ({urgency_distribution['medium']/total*100:.1f}%)")
    print(f"Low urgency: {urgency_distribution['low']} ({urgency_distribution['low']/total*100:.1f}%)")

    return urgency_distribution

def create_balanced_synthetic_cases(target_count_per_level=500):
    """Create balanced synthetic medical cases across all urgency levels"""
    synthetic_cases = []

    # LOW URGENCY CASES - Expanded
    low_cases = [
        # General Medicine
        "Patient presents for routine annual physical examination and preventive health screening",
        "Follow-up visit for stable hypertension management and medication review",
        "Wellness check for asymptomatic patient with no current complaints",
        "Routine diabetes management with stable blood glucose levels",
        "Preventive care consultation and vaccination update",
        "Annual health maintenance visit with normal vital signs",
        "Medication refill request for chronic stable condition",
        "General health consultation for lifestyle modifications",
        "Routine blood work and laboratory test review",
        "Health insurance physical examination for employment",
        "Pre-operative clearance for elective surgery",
        "Travel medicine consultation and vaccination requirements",
        "Sports physical examination for school athletics",
        "Routine cholesterol screening and lipid profile review",
        "Annual vision and hearing screening test",
        "Well child check-up with developmental assessment",
        "Routine prenatal visit with normal pregnancy progression",
        "Post-operative follow-up with good wound healing",
        "Medication adjustment for well-controlled condition",
        "Routine bone density screening for osteoporosis",

        # Musculoskeletal - Low
        "Mild occasional back stiffness after long periods of sitting",
        "Minor joint discomfort that resolves with rest and over-the-counter pain relief",
        "Chronic stable arthritis with well-controlled symptoms",
        "Follow-up for previous muscle strain with significant improvement",
        "Mild muscle soreness after exercise or physical activity",
        "Stable orthopedic condition with minimal functional limitation",
        "Routine physical therapy follow-up with good progress",
        "Minor sprain with improving mobility and reduced swelling",
        "Chronic neck pain with stable characteristics",
        "Mild carpal tunnel symptoms with intermittent numbness",

        # Respiratory - Low
        "Mild seasonal allergies with occasional sneezing and runny nose",
        "Stable asthma with infrequent inhaler use and good control",
        "Routine pulmonary function test for monitoring purposes",
        "Mild cold symptoms with clear nasal discharge",
        "Environmental allergy management with good response to medication",
        "Routine spirometry testing for occupational health",
        "Mild post-nasal drip with occasional throat clearing",
        "Stable COPD with minimal daily symptoms",

        # Cardiology - Low
        "Stable blood pressure readings during routine monitoring",
        "Follow-up for well-controlled hyperlipidemia on medication",
        "Routine cardiac assessment with normal findings",
        "Stable palpitations without associated symptoms",
        "Routine EKG with normal sinus rhythm",
        "Medication review for stable cardiac condition",
        "Lifestyle counseling for cardiovascular health",
        "Routine pacemaker check with normal device function",

        # Gastroenterology - Low
        "Mild occasional heartburn responsive to antacids",
        "Stable irritable bowel syndrome with minimal symptoms",
        "Routine colon cancer screening consultation",
        "Mild constipation managed with dietary changes",
        "Stable inflammatory bowel disease in remission",
        "Routine liver enzyme monitoring",
        "Dietary consultation for mild digestive issues",

        # Dermatology - Low
        "Mild dry skin with occasional itching",
        "Stable eczema with good response to moisturizers",
        "Routine skin cancer screening examination",
        "Mild acne with occasional breakouts",
        "Benign mole monitoring with no concerning changes",
        "Mild dandruff managed with medicated shampoo",
        "Seasonal skin dryness with no rash or infection"
    ]

    # MEDIUM URGENCY CASES - Expanded
    medium_cases = [
        # General Medicine
        "Moderate fever of 101¬∞F with body aches and fatigue for 2 days",
        "Persistent cough with yellow phlegm and mild chest discomfort",
        "Worsening headache with sensitivity to light but no neurological symptoms",
        "Abdominal pain with nausea and decreased appetite for 24 hours",
        "Urinary symptoms with burning sensation and increased frequency",
        "Skin infection with localized redness, swelling and mild pain",
        "Moderate dehydration after gastrointestinal illness",
        "Ear pain with mild hearing loss and pressure sensation",
        "Sore throat with difficulty swallowing and swollen glands",
        "Sinus pressure with colored nasal discharge and facial pain",
        "Eye redness with discharge and mild vision blurring",
        "Flu-like symptoms with chills and generalized body aches",
        "Allergic reaction with hives and mild swelling",
        "Insect bite with localized redness and moderate itching",
        "Mild to moderate anxiety with physical symptoms",
        "Sleep disturbances with daytime fatigue and irritability",
        "Moderate stress with tension headaches and muscle tightness",
        "Weight changes with appetite fluctuations",
        "Moderate fatigue affecting daily activities",
        "Recurrent mild infections requiring evaluation",

        # Musculoskeletal - Medium
        "Moderate back pain limiting daily activities but no neurological deficits",
        "Joint pain with swelling and stiffness affecting mobility",
        "Muscle strain with moderate pain and functional limitation",
        "Worsening arthritis symptoms with increased pain levels",
        "Tendonitis with localized pain and movement restriction",
        "Moderate sprain with swelling and difficulty bearing weight",
        "Overuse injury from repetitive activities",
        "Moderate sciatica with radiating leg pain",
        "Rotator cuff strain with shoulder movement limitation",
        "Knee pain with swelling and difficulty climbing stairs",

        # Respiratory - Medium
        "Bronchitis symptoms with productive cough and mild shortness of breath",
        "Asthma exacerbation with increased inhaler use and wheezing",
        "Sinus infection with facial pressure and colored discharge",
        "Moderate allergic rhinitis affecting sleep and daily activities",
        "Pneumonia symptoms with fever and productive cough",
        "COPD exacerbation with increased shortness of breath",
        "Pleuritic chest pain with breathing discomfort",
        "Moderate croup with barking cough and stridor",

        # Cardiology - Medium
        "Palpitations with mild dizziness but no chest pain or fainting",
        "Elevated blood pressure readings with mild headache",
        "Chest discomfort with anxiety but normal cardiac workup",
        "Moderate edema with shortness of breath on exertion",
        "Syncopal episode with quick recovery and no injury",
        "Cardiac medication side effects requiring adjustment",
        "Moderate tachycardia with associated anxiety",

        # Gastroenterology - Medium
        "Gastroenteritis with vomiting and diarrhea for 12 hours",
        "Moderate abdominal pain with bloating and gas",
        "Food poisoning symptoms with nausea and stomach cramps",
        "Moderate GERD with nighttime symptoms affecting sleep",
        "Diverticulitis flare with localized abdominal pain",
        "Gallbladder symptoms with right upper quadrant discomfort",
        "Moderate constipation with abdominal distension",
        "Hemorrhoid flare with bleeding and discomfort",

        # Dermatology - Medium
        "Moderate rash with itching and spreading lesions",
        "Skin infection requiring antibiotic treatment",
        "Moderate eczema flare with significant itching",
        "Allergic contact dermatitis with blistering",
        "Moderate psoriasis flare with scaling and redness",
        "Suspicious mole with recent changes in appearance",
        "Moderate sunburn with blistering and pain",
        "Fungal infection not responding to over-the-counter treatment"
    ]

    # HIGH URGENCY CASES - Expanded
    high_cases = [
        # Cardiac emergencies
        "Severe crushing chest pain radiating to left arm with sweating and nausea",
        "Sudden onset of severe shortness of breath with blue lips and confusion",
        "Patient collapsed and unresponsive with no pulse or breathing",
        "Severe palpitations with chest pain and near fainting",
        "Cardiac arrest requiring immediate CPR and defibrillation",
        "Severe hypertension with blurred vision and headache",
        "Acute myocardial infarction with ST elevation on EKG",
        "Unstable angina with worsening chest pain at rest",
        "Severe aortic dissection with tearing chest pain",
        "Cardiogenic shock with low blood pressure and confusion",

        # Respiratory emergencies
        "Severe asthma attack not responding to inhaler with wheezing",
        "Acute respiratory distress with inability to speak in full sentences",
        "Severe pneumonia with high fever and respiratory failure",
        "Pulmonary embolism with sudden pleuritic chest pain",
        "Severe COPD exacerbation with oxygen saturation dropping",
        "Anaphylaxis with swelling and breathing difficulty",
        "Severe croup with stridor and respiratory distress",
        "Tension pneumothorax with tracheal deviation",

        # Neurological emergencies
        "Stroke symptoms with facial droop and arm weakness",
        "Severe headache with vision loss and difficulty speaking",
        "Seizure lasting more than 5 minutes or multiple seizures",
        "Severe head injury with loss of consciousness",
        "Meningitis symptoms with stiff neck and photophobia",
        "Spinal cord injury with paralysis or numbness",
        "Guillain-Barr√© syndrome with rapidly progressive weakness",
        "Status epilepticus with continuous seizure activity",

        # Abdominal emergencies
        "Severe abdominal pain with rigidity and fever suggesting appendicitis",
        "Ruptured abdominal aortic aneurysm with back pain and hypotension",
        "Severe gastrointestinal bleeding with bloody vomiting",
        "Bowel obstruction with vomiting and abdominal distension",
        "Ectopic pregnancy with abdominal pain and vaginal bleeding",
        "Severe pancreatitis with intense abdominal pain and vomiting",
        "Perforated ulcer with sudden severe abdominal pain",
        "Testicular torsion with severe pain and swelling",

        # Trauma and surgical emergencies
        "Major trauma with multiple injuries and unstable vital signs",
        "Severe burns covering large body surface area",
        "Compound fracture with bone protruding through skin",
        "Severe head injury with decreasing level of consciousness",
        "Stab wound or gunshot wound to torso or head",
        "Severe electrical injury with cardiac complications",
        "Near-drowning with respiratory compromise",
        "Severe animal bite with arterial bleeding",

        # Other emergencies
        "Uncontrolled bleeding from deep laceration with signs of shock",
        "Severe allergic reaction with swelling and breathing difficulty",
        "Diabetic ketoacidosis with confusion and dehydration",
        "Severe sepsis with fever, low blood pressure, and confusion",
        "Heat stroke with high body temperature and altered mental status",
        "Hypothermia with slow heart rate and confusion",
        "Severe drug overdose with respiratory depression",
        "Suicidal ideation with plan and intent"
    ]

    # Balance the cases to have approximately equal numbers
    cases_by_urgency = {
        'low': low_cases,
        'medium': medium_cases,
        'high': high_cases
    }

    # Calculate how many of each we need to reach target
    for urgency_level, cases in cases_by_urgency.items():
        # If we have fewer cases than target, we'll use all available
        num_cases = min(len(cases), target_count_per_level)
        selected_cases = random.sample(cases, num_cases)

        for case in selected_cases:
            synthetic_cases.append({
                'text': case,
                'metadata': {
                    'specialty': predict_specialty_synthetic(case),
                    'urgency': urgency_level
                }
            })

    return synthetic_cases

def predict_specialty_synthetic(text):
    """Predict specialty for synthetic cases"""
    text_lower = text.lower()
    specialty_scores = {}

    for specialty, keywords in SPECIALTY_KEYWORDS.items():
        score = 0
        for keyword in keywords:
            if keyword in text_lower:
                score += 1
        specialty_scores[specialty] = score

    predicted_specialty = max(specialty_scores, key=specialty_scores.get)
    return predicted_specialty if specialty_scores[predicted_specialty] > 0 else 'General Medicine'

def extract_key_information(text):
    """Extract key medical information using spaCy"""
    doc = nlp(text.lower())

    symptoms = []
    conditions = []
    severity_indicators = []

    for ent in doc.ents:
        if ent.label_ in ["DISEASE", "SYMPTOM", "SIGN", "PROBLEM", "CONDITION"]:
            symptoms.append(ent.text)

    # Extract severity words
    severity_words = ['severe', 'mild', 'moderate', 'sharp', 'chronic', 'acute', 'worsening', 'stable']
    for token in doc:
        if token.text in severity_words:
            severity_indicators.append(token.text)

    # Extract pain levels
    pain_levels = re.findall(r'(\d+)/10 pain', text.lower())
    severity_indicators.extend([f"{level}/10 pain" for level in pain_levels])

    return {
        'symptoms': list(set(symptoms)),
        'conditions': list(set(conditions)),
        'severity_indicators': list(set(severity_indicators))
    }

def classify_urgency(text, extracted_info):
    """Classify urgency based on rules and extracted information"""
    text_lower = text.lower()

    # Check high urgency rules
    for pattern in URGENCY_RULES['high']:
        if re.search(pattern, text_lower):
            return 'high'

    # Check medium urgency rules
    for pattern in URGENCY_RULES['medium']:
        if re.search(pattern, text_lower):
            return 'medium'

    # Check low urgency rules
    for pattern in URGENCY_RULES['low']:
        if re.search(pattern, text_lower):
            return 'low'

    # Default based on severity indicators
    if any(word in text_lower for word in ['severe', 'emergency', 'urgent', 'critical']):
        return 'high'
    elif any(word in text_lower for word in ['moderate', 'worsening', 'persistent']):
        return 'medium'
    else:
        return 'low'

def predict_specialty(text, extracted_info):
    """Predict medical specialty based on keywords"""
    text_lower = text.lower()
    specialty_scores = {}

    for specialty, keywords in SPECIALTY_KEYWORDS.items():
        score = 0
        for keyword in keywords:
            score += len(re.findall(r'\b' + re.escape(keyword) + r'\b', text_lower))
        specialty_scores[specialty] = score

    predicted_specialty = max(specialty_scores, key=specialty_scores.get)

    if specialty_scores[predicted_specialty] == 0:
        return 'General Medicine'

    return predicted_specialty

def create_concise_summary(text, extracted_info, urgency, specialty):
    """Create a concise summary in the required format"""
    symptoms = extracted_info['symptoms']
    severity = extracted_info['severity_indicators']

    # Build the summary text
    if symptoms:
        main_symptoms = symptoms[:3]
        symptom_text = ", ".join(main_symptoms)

        if severity:
            severity_text = severity[0]
            summary = f"Patient experiencing {severity_text} {symptom_text}"
        else:
            summary = f"Patient experiencing {symptom_text}"
    else:
        # Enhanced fallback with better context
        if any(word in text.lower() for word in ['chest pain', 'heart', 'cardiac']):
            summary = "Patient with cardiac symptoms"
        elif any(word in text.lower() for word in ['fever', 'cough', 'breathing']):
            summary = "Patient with respiratory symptoms"
        elif any(word in text.lower() for word in ['pain', 'swelling', 'joint']):
            summary = "Patient with musculoskeletal symptoms"
        elif any(word in text.lower() for word in ['rash', 'skin', 'itching']):
            summary = "Patient with dermatological symptoms"
        elif any(word in text.lower() for word in ['vomiting', 'diarrhea', 'abdominal']):
            summary = "Patient with gastrointestinal symptoms"
        else:
            summary = "Patient requiring medical evaluation"

    # Add duration if mentioned
    duration_pattern = r'(\d+\s*(?:hour|day|week|month)s?)'
    duration_match = re.search(duration_pattern, text.lower())
    if duration_match:
        summary += f" for {duration_match.group(1)}"

    # Add urgency context
    if urgency == 'high':
        summary += " with emergency symptoms requiring immediate attention"
    elif urgency == 'medium':
        summary += " with concerning symptoms requiring evaluation"
    else:
        summary += " for routine assessment"

    return summary

def process_medical_conversations(df):
    """Process all medical conversations and create the training dataset"""
    training_data = []

    for idx, row in df.iterrows():
        text = row['transcription']
        original_specialty = row['specialty']

        # Extract information using spaCy
        extracted_info = extract_key_information(text)

        # Classify urgency
        urgency = classify_urgency(text, extracted_info)

        # Predict specialty
        specialty = predict_specialty(text, extracted_info)

        # Create concise summary
        concise_text = create_concise_summary(text, extracted_info, urgency, specialty)

        # Create the training example
        training_example = {
            "text": concise_text,
            "metadata": {
                "specialty": specialty,
                "urgency": urgency
            }
        }

        training_data.append(training_example)

    return training_data

# First, analyze the current dataset balance
print("Analyzing current dataset distribution...")
original_urgency_dist = analyze_dataset_balance(df)

# Process the original dataset
print("\nProcessing original medical conversations...")
training_dataset = process_medical_conversations(df)

# Calculate target counts for balanced dataset
total_original_cases = len(training_dataset)
target_count_per_level = max(
    total_original_cases // 3,  # Aim for roughly equal distribution
    500  # Minimum target per level
)

print(f"\nTarget cases per urgency level: {target_count_per_level}")

# Add balanced synthetic cases
print("Adding balanced synthetic medical cases...")
synthetic_cases = create_balanced_synthetic_cases(target_count_per_level)
training_dataset.extend(synthetic_cases)

# Final balancing: ensure equal distribution
print("Performing final dataset balancing...")
final_dataset = []
urgency_groups = {'low': [], 'medium': [], 'high': []}

# Group by urgency
for case in training_dataset:
    urgency = case['metadata']['urgency']
    urgency_groups[urgency].append(case)

# Find the minimum group size to balance
min_group_size = min(len(urgency_groups['low']), len(urgency_groups['medium']), len(urgency_groups['high']))

# Sample equally from each group
for urgency_level in ['low', 'medium', 'high']:
    sampled_cases = random.sample(urgency_groups[urgency_level], min_group_size)
    final_dataset.extend(sampled_cases)

# Shuffle the balanced dataset
random.shuffle(final_dataset)

# Save the balanced dataset
output_file = '/balanced_medical_training_data.json'
with open(output_file, 'w') as f:
    json.dump(final_dataset, f, indent=2)

print(f"\nBalanced dataset saved to: {output_file}")
print(f"Total training examples created: {len(final_dataset)}")

# Print detailed statistics
urgency_counts = Counter([item['metadata']['urgency'] for item in final_dataset])
specialty_counts = Counter([item['metadata']['specialty'] for item in final_dataset])

print("\nüìä BALANCED DATASET STATISTICS:")
print("=" * 50)
print(f"Urgency distribution:")
for urgency, count in urgency_counts.items():
    percentage = (count / len(final_dataset)) * 100
    print(f"  {urgency.upper()}: {count} cases ({percentage:.1f}%)")

print(f"\nSpecialty distribution:")
for specialty, count in specialty_counts.most_common():
    percentage = (count / len(final_dataset)) * 100
    print(f"  {specialty}: {count} cases ({percentage:.1f}%)")

# Display examples from each urgency level
print("\nüìù SAMPLE CASES FROM EACH URGENCY LEVEL:")
print("=" * 50)

# Show 2 examples from each urgency level
for urgency_level in ['low', 'medium', 'high']:
    print(f"\n{urgency_level.upper()} URGENCY EXAMPLES:")
    urgency_cases = [item for item in final_dataset if item['metadata']['urgency'] == urgency_level]
    for i, case in enumerate(urgency_cases[:2]):
        print(f"  {i+1}. {case['text']}")
        print(f"     Specialty: {case['metadata']['specialty']}")

# Create a balanced dataset report
print(f"\nüéØ FINAL BALANCE REPORT:")
print("=" * 50)
total_cases = len(final_dataset)
print(f"Total cases: {total_cases}")
print(f"Low urgency: {urgency_counts['low']} ({urgency_counts['low']/total_cases*100:.1f}%)")
print(f"Medium urgency: {urgency_counts['medium']} ({urgency_counts['medium']/total_cases*100:.1f}%)")
print(f"High urgency: {urgency_counts['high']} ({urgency_counts['high']/total_cases*100:.1f}%)")

# Calculate balance quality
balance_ratio = min(urgency_counts.values()) / max(urgency_counts.values())
print(f"Balance quality ratio: {balance_ratio:.3f} (1.0 = perfect balance)")

# Save sample file
sample_output = final_dataset[:20]
sample_file = '/balanced_sample_training_data.json'
with open(sample_file, 'w') as f:
    json.dump(sample_output, f, indent=2)

print(f"\nSample preview saved to: {sample_file}")
print("\n‚úÖ Balanced dataset ready for model training!")

Loading spaCy medical model...
Medical spaCy model not found. Using basic English model...
Loading dataset...
Analyzing current dataset distribution...
Analyzing dataset balance...
Original dataset distribution:
High urgency: 25 (83.3%)
Medium urgency: 5 (16.7%)
Low urgency: 0 (0.0%)

Processing original medical conversations...

Target cases per urgency level: 500
Adding balanced synthetic medical cases...
Performing final dataset balancing...

Balanced dataset saved to: /balanced_medical_training_data.json
Total training examples created: 180

üìä BALANCED DATASET STATISTICS:
Urgency distribution:
  HIGH: 60 cases (33.3%)
  MEDIUM: 60 cases (33.3%)
  LOW: 60 cases (33.3%)

Specialty distribution:
  General Medicine: 64 cases (35.6%)
  Musculoskeletal: 35 cases (19.4%)
  Cardiology: 28 cases (15.6%)
  Gastroenterology: 21 cases (11.7%)
  Respiratory: 17 cases (9.4%)
  Dermatology: 15 cases (8.3%)

üìù SAMPLE CASES FROM EACH URGENCY LEVEL:

LOW URGENCY EXAMPLES:
  1. Environmental al

In [None]:
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup, TrainingArguments, Trainer
)
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, classification_report, confusion_matrix
import pandas as pd
import numpy as np
import json
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os
import random
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from tqdm.auto import tqdm
from datasets import Dataset
import evaluate
from sklearn.utils.class_weight import compute_class_weight

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class AdvancedMedicalBERTTrainer:
    def __init__(self):
        self.model_name = 'emilyalsentzer/Bio_ClinicalBERT'
        self.tokenizer = None
        self.model = None
        self.urgency_labels = ['low', 'medium', 'high']
        self.class_weights = None
        self.best_metrics = {}

    def load_and_preprocess_data(self, data_path):
        """Load and preprocess the balanced dataset"""
        print("Loading balanced training data...")
        with open(data_path, 'r') as f:
            data = json.load(f)

        print(f"Total samples: {len(data)}")

        # Analyze data distribution
        urgencies = [item['metadata']['urgency'] for item in data]
        urgency_counts = Counter(urgencies)
        print(f"Urgency distribution: {dict(urgency_counts)}")

        # Calculate enhanced class weights
        self.calculate_enhanced_class_weights(urgencies)

        # Create larger splits for better evaluation
        train_data, val_data, test_data = self.robust_stratified_split(data)

        print(f"\nData splits:")
        print(f"Train: {len(train_data)}")
        print(f"Validation: {len(val_data)}")
        print(f"Test: {len(test_data)}")

        # Print distribution for each split
        for split_name, split_data in [('Train', train_data), ('Validation', val_data), ('Test', test_data)]:
            split_urgencies = [item['metadata']['urgency'] for item in split_data]
            split_counts = Counter(split_urgencies)
            print(f"{split_name} distribution: {dict(split_counts)}")

        return train_data, val_data, test_data

    def calculate_enhanced_class_weights(self, urgencies):
        """Calculate enhanced class weights with smoothing"""
        label_indices = [self.urgency_labels.index(urg) for urg in urgencies]

        # Compute class weights with balanced strategy
        weights = compute_class_weight(
            class_weight='balanced',
            classes=np.array([0, 1, 2]),
            y=label_indices
        )

        # Apply smoothing to prevent extreme weights
        weights = np.clip(weights, 0.5, 2.0)

        self.class_weights = torch.tensor(weights).float().to(device)
        print(f"Enhanced class weights: {dict(zip(self.urgency_labels, weights))}")

    def robust_stratified_split(self, data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
        """Robust stratified split ensuring sufficient samples per class"""
        # Group by urgency
        urgency_groups = {}
        for item in data:
            urgency = item['metadata']['urgency']
            if urgency not in urgency_groups:
                urgency_groups[urgency] = []
            urgency_groups[urgency].append(item)

        train_data, val_data, test_data = [], [], []

        for urgency, items in urgency_groups.items():
            # Shuffle items
            random.shuffle(items)
            n_items = len(items)

            # Ensure minimum samples per split
            min_samples_per_split = 10

            # Calculate split indices
            train_end = int(train_ratio * n_items)
            val_end = train_end + int(val_ratio * n_items)

            # Ensure we have enough samples in each split
            if n_items < min_samples_per_split * 3:
                # If too few samples, use all for training
                train_data.extend(items)
                continue

            train_data.extend(items[:train_end])
            val_data.extend(items[train_end:val_end])
            test_data.extend(items[val_end:])

        # Final shuffle
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

        return train_data, val_data, test_data

    def create_enhanced_dataloader(self, data, batch_size=16, training=True):
        """Create enhanced dataloader with better balancing"""
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        texts = [item['text'] for item in data]
        urgencies = [item['metadata']['urgency'] for item in data]
        labels = [self.urgency_labels.index(urg) for urg in urgencies]

        # Enhanced tokenization with medical context
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=512,  # Increased for medical context
            return_tensors='pt',
            add_special_tokens=True
        )

        dataset = TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            torch.tensor(labels)
        )

        if training:
            # Use weighted sampling for training
            class_counts = Counter(labels)
            sample_weights = [1.0 / class_counts[label] for label in labels]
            sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
            return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        else:
            # No sampling for validation/test
            return DataLoader(dataset, batch_size=batch_size, shuffle=False)

    def initialize_enhanced_model(self):
        """Initialize model with enhanced configuration"""
        print("Initializing Enhanced Bio+Clinical BERT...")

        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=3,
            id2label={0: 'low', 1: 'medium', 2: 'high'},
            label2id={'low': 0, 'medium': 1, 'high': 2},
            attention_probs_dropout_prob=0.2,  # Increased dropout
            hidden_dropout_prob=0.2,           # Increased dropout
            hidden_size=768,
            num_attention_heads=12,
            num_hidden_layers=12
        )

        # Enhanced classifier head
        model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(model.config.hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 3)
        )

        model.to(device)
        print("‚úÖ Enhanced Bio+Clinical BERT model loaded with improved classifier")
        return model

    def train_with_focal_loss(self, train_loader, val_loader, test_loader, output_dir):
        """Train with focal loss to handle class imbalance"""
        print("Initializing model with focal loss...")
        self.model = self.initialize_enhanced_model()

        # Focal Loss for handling class imbalance
        class FocalLoss(nn.Module):
            def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
                super(FocalLoss, self).__init__()
                self.alpha = alpha
                self.gamma = gamma
                self.reduction = reduction

            def forward(self, inputs, targets):
                ce_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
                pt = torch.exp(-ce_loss)
                focal_loss = (1 - pt) ** self.gamma * ce_loss

                if self.alpha is not None:
                    alpha_t = self.alpha[targets]
                    focal_loss = alpha_t * focal_loss

                if self.reduction == 'mean':
                    return focal_loss.mean()
                elif self.reduction == 'sum':
                    return focal_loss.sum()
                else:
                    return focal_loss

        # Use focal loss with class weights
        criterion = FocalLoss(alpha=self.class_weights, gamma=2.0)

        # Enhanced optimizer configuration
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                          if not any(nd in n for nd in no_decay) and 'classifier' not in n],
                'weight_decay': 0.01,
                'lr': 1e-5  # Lower LR for base model
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                          if any(nd in n for nd in no_decay) and 'classifier' not in n],
                'weight_decay': 0.0,
                'lr': 1e-5
            },
            {
                'params': [p for n, p in self.model.named_parameters() if 'classifier' in n],
                'weight_decay': 0.01,
                'lr': 2e-4  # Higher LR for classifier
            }
        ]

        optimizer = AdamW(optimizer_grouped_parameters)

        # Enhanced training parameters
        epochs = 20
        warmup_ratio = 0.1
        total_steps = len(train_loader) * epochs

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(warmup_ratio * total_steps),
            num_training_steps=total_steps
        )

        # Training metrics tracking
        training_losses = []
        val_accuracies = []
        val_f1_scores = []
        best_f1_macro = 0
        patience = 7
        patience_counter = 0

        print("Starting enhanced training with focal loss...")
        for epoch in range(epochs):
            # Training phase
            self.model.train()
            total_train_loss = 0
            all_predictions = []
            all_true_labels = []

            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
            for batch in progress_bar:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                self.model.zero_grad()
                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                # Use focal loss
                loss = criterion(logits, labels)
                total_train_loss += loss.item()

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                # Collect predictions
                predictions = torch.argmax(logits, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_true_labels.extend(labels.cpu().numpy())

                progress_bar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                })

            avg_train_loss = total_train_loss / len(train_loader)
            training_losses.append(avg_train_loss)

            # Calculate training metrics
            train_accuracy = accuracy_score(all_true_labels, all_predictions)
            train_precision, train_recall, train_f1, _ = precision_recall_fscore_support(
                all_true_labels, all_predictions, average='weighted', zero_division=0
            )

            # Validation phase
            val_metrics = self.comprehensive_evaluation(val_loader)
            val_accuracy = val_metrics['accuracy']
            val_f1_macro = val_metrics['f1_macro']  # Use macro F1 for early stopping
            val_f1_weighted = val_metrics['f1_weighted']

            val_accuracies.append(val_accuracy)
            val_f1_scores.append(val_f1_weighted)

            print(f'\nEpoch {epoch+1}:')
            print(f'  Train Loss: {avg_train_loss:.4f}')
            print(f'  Train Accuracy: {train_accuracy:.4f}, Train F1: {train_f1:.4f}')
            print(f'  Val Accuracy: {val_accuracy:.4f}, Val F1 (macro): {val_f1_macro:.4f}')
            print(f'  Val F1 (weighted): {val_f1_weighted:.4f}')

            # Print detailed per-class metrics
            f1_per_class = val_metrics['f1_per_class']
            print(f'  Per-class F1: Low={f1_per_class[0]:.4f}, Medium={f1_per_class[1]:.4f}, High={f1_per_class[2]:.4f}')

            # Enhanced early stopping based on macro F1 and minimum class performance
            current_f1_macro = val_f1_macro
            min_class_f1 = min(f1_per_class)

            if current_f1_macro > best_f1_macro and min_class_f1 > 0.5:
                best_f1_macro = current_f1_macro
                patience_counter = 0
                self.model.save_pretrained(output_dir)
                self.tokenizer.save_pretrained(output_dir)
                self.best_metrics = val_metrics
                print(f'  ‚úÖ Saved best model (Macro F1: {current_f1_macro:.4f}, Min Class F1: {min_class_f1:.4f})')
            else:
                patience_counter += 1
                print(f'  ‚è≥ No improvement ({patience_counter}/{patience})')

            if patience_counter >= patience:
                print(f'  üõë Early stopping at epoch {epoch+1}')
                break

        # Load best model for final evaluation
        self.model = AutoModelForSequenceClassification.from_pretrained(output_dir)
        self.model.to(device)

        # Final evaluation on test set
        print("\nüî¨ Final Evaluation on Test Set:")
        test_metrics = self.comprehensive_evaluation(test_loader)

        # Plot results
        self.plot_enhanced_results(training_losses, val_accuracies, val_f1_scores, test_metrics)

        return training_losses, val_accuracies, val_f1_scores, test_metrics

    def comprehensive_evaluation(self, dataloader):
        """Comprehensive evaluation with multiple metrics"""
        self.model.eval()
        predictions, true_labels = [], []

        with torch.no_grad():
            for batch in dataloader:
                batch = tuple(t.to(device) for t in batch)
                input_ids, attention_mask, labels = batch

                outputs = self.model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                batch_predictions = torch.argmax(logits, dim=1)

                predictions.extend(batch_predictions.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        # Calculate comprehensive metrics
        accuracy = accuracy_score(true_labels, predictions)

        # Multiple averaging strategies
        precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )
        precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
            true_labels, predictions, average='macro', zero_division=0
        )

        # Per-class metrics
        precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
            true_labels, predictions, average=None, zero_division=0
        )

        # Confusion matrix
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1, 2])

        # Additional metrics
        classification_rep = classification_report(true_labels, predictions,
                                                 target_names=self.urgency_labels,
                                                 output_dict=True)

        return {
            'accuracy': accuracy,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_per_class': precision_per_class.tolist(),
            'recall_per_class': recall_per_class.tolist(),
            'f1_per_class': f1_per_class.tolist(),
            'support_per_class': support_per_class.tolist(),
            'confusion_matrix': cm.tolist(),
            'classification_report': classification_rep,
            'predictions': predictions,
            'true_labels': true_labels
        }

    def plot_enhanced_results(self, train_losses, val_accuracies, val_f1_scores, test_metrics):
        """Plot enhanced training results"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Plot 1: Training loss
        ax1.plot(train_losses, 'b-', linewidth=2, label='Training Loss')
        ax1.set_title('Training Loss Over Epochs', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.grid(True, alpha=0.3)
        ax1.legend()

        # Plot 2: Validation metrics
        epochs = range(1, len(val_accuracies) + 1)
        ax2.plot(epochs, val_accuracies, 'g-', linewidth=2, label='Validation Accuracy')
        ax2.plot(epochs, val_f1_scores, 'r-', linewidth=2, label='Validation F1 (Weighted)')
        ax2.set_title('Validation Metrics Over Epochs', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Score')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        # Plot 3: Confusion matrix
        cm = np.array(test_metrics['confusion_matrix'])
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax3,
                   xticklabels=self.urgency_labels,
                   yticklabels=self.urgency_labels,
                   annot_kws={"size": 14})
        ax3.set_title('Test Set Confusion Matrix', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Predicted')
        ax3.set_ylabel('Actual')

        # Plot 4: Per-class F1 scores comparison
        classes = self.urgency_labels
        f1_scores = test_metrics['f1_per_class']
        colors = ['#2ecc71', '#f39c12', '#e74c3c']  # Green, Orange, Red

        bars = ax4.bar(classes, f1_scores, color=colors, alpha=0.8, edgecolor='black')
        ax4.set_title('Per-class F1 Scores on Test Set', fontsize=14, fontweight='bold')
        ax4.set_ylabel('F1 Score')
        ax4.set_ylim(0, 1)
        ax4.grid(True, alpha=0.3, axis='y')

        # Add value labels on bars
        for bar, score in zip(bars, f1_scores):
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width()/2, height + 0.01,
                    f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()
        plt.savefig('/enhanced_medical_bert_results.png', dpi=300, bbox_inches='tight')
        plt.close()

    def analyze_class_performance(self, test_metrics):
        """Enhanced class performance analysis"""
        print("\nüìä ENHANCED CLASS PERFORMANCE ANALYSIS:")
        print("=" * 70)

        classes = self.urgency_labels
        precision = test_metrics['precision_per_class']
        recall = test_metrics['recall_per_class']
        f1 = test_metrics['f1_per_class']
        support = test_metrics['support_per_class']

        for i, class_name in enumerate(classes):
            print(f"\n{class_name.upper()} URGENCY:")
            print(f"  Precision: {precision[i]:.4f}")
            print(f"  Recall:    {recall[i]:.4f}")
            print(f"  F1-Score:  {f1[i]:.4f}")
            print(f"  Support:   {support[i]} samples")

            # Enhanced performance interpretation
            if f1[i] >= 0.85:
                status = "‚úÖ EXCELLENT"
                explanation = "Very reliable predictions"
            elif f1[i] >= 0.75:
                status = "‚úÖ VERY GOOD"
                explanation = "Good performance, suitable for deployment"
            elif f1[i] >= 0.65:
                status = "‚ö†Ô∏è GOOD"
                explanation = "Acceptable performance"
            elif f1[i] >= 0.55:
                status = "üî∂ FAIR"
                explanation = "Needs minor improvements"
            else:
                status = "‚ùå NEEDS IMPROVEMENT"
                explanation = "Significant improvements needed"

            print(f"  Status:    {status}")
            print(f"  Assessment: {explanation}")

        # Overall assessment
        min_f1 = min(f1)
        avg_f1 = sum(f1) / len(f1)

        print(f"\nüìà OVERALL ASSESSMENT:")
        print(f"  Minimum F1: {min_f1:.4f}")
        print(f"  Average F1: {avg_f1:.4f}")
        print(f"  Balanced: {'Yes' if min_f1 > 0.65 else 'Partially' if min_f1 > 0.55 else 'No'}")

def main():
    print("üöÄ ADVANCED MEDICAL BERT TRAINING WITH BALANCED DATASET")
    print("=" * 70)

    # Initialize advanced trainer
    trainer = AdvancedMedicalBERTTrainer()

    # Load balanced dataset
    data_path = '/balanced_medical_training_data.json'  # Use the new balanced dataset
    train_data, val_data, test_data = trainer.load_and_preprocess_data(data_path)

    # Prepare tokenizer
    trainer.tokenizer = AutoTokenizer.from_pretrained(trainer.model_name)

    # Create enhanced dataloaders
    print("\nüìö Preparing enhanced dataloaders...")
    train_loader = trainer.create_enhanced_dataloader(train_data, batch_size=16, training=True)
    val_loader = trainer.create_enhanced_dataloader(val_data, batch_size=16, training=False)
    test_loader = trainer.create_enhanced_dataloader(test_data, batch_size=16, training=False)

    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

    # Train model with enhanced techniques
    output_dir = '/advanced_medical_bert_model'
    os.makedirs(output_dir, exist_ok=True)

    print(f"\nüè• Starting Advanced Training with Focal Loss...")
    train_losses, val_accuracies, val_f1_scores, test_metrics = trainer.train_with_focal_loss(
        train_loader, val_loader, test_loader, output_dir
    )

    # Enhanced analysis
    trainer.analyze_class_performance(test_metrics)

    # Generate comprehensive report
    generate_advanced_report(test_metrics, output_dir)

    print("\n" + "=" * 70)
    print("üéØ ADVANCED TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 70)

def generate_advanced_report(test_metrics, model_path):
    """Generate comprehensive performance report"""

    # Calculate quality metrics
    f1_scores = test_metrics['f1_per_class']
    min_f1 = min(f1_scores)
    avg_f1 = sum(f1_scores) / len(f1_scores)

    if min_f1 >= 0.7 and avg_f1 >= 0.75:
        overall_quality = "EXCELLENT"
        recommendation = "Ready for production deployment"
    elif min_f1 >= 0.6 and avg_f1 >= 0.7:
        overall_quality = "VERY GOOD"
        recommendation = "Suitable for deployment with monitoring"
    elif min_f1 >= 0.55 and avg_f1 >= 0.65:
        overall_quality = "GOOD"
        recommendation = "Suitable for prototype deployment"
    else:
        overall_quality = "NEEDS IMPROVEMENT"
        recommendation = "Requires further model optimization"

    report = {
        'timestamp': str(pd.Timestamp.now()),
        'model': 'Advanced Bio+Clinical BERT with Focal Loss',
        'test_performance': {
            'accuracy': float(test_metrics['accuracy']),
            'precision_weighted': float(test_metrics['precision_weighted']),
            'recall_weighted': float(test_metrics['recall_weighted']),
            'f1_weighted': float(test_metrics['f1_weighted']),
            'precision_macro': float(test_metrics['precision_macro']),
            'recall_macro': float(test_metrics['recall_macro']),
            'f1_macro': float(test_metrics['f1_macro']),
        },
        'per_class_performance': {
            'low': {
                'precision': float(test_metrics['precision_per_class'][0]),
                'recall': float(test_metrics['recall_per_class'][0]),
                'f1': float(test_metrics['f1_per_class'][0]),
                'support': int(test_metrics['support_per_class'][0])
            },
            'medium': {
                'precision': float(test_metrics['precision_per_class'][1]),
                'recall': float(test_metrics['recall_per_class'][1]),
                'f1': float(test_metrics['f1_per_class'][1]),
                'support': int(test_metrics['support_per_class'][1])
            },
            'high': {
                'precision': float(test_metrics['precision_per_class'][2]),
                'recall': float(test_metrics['recall_per_class'][2]),
                'f1': float(test_metrics['f1_per_class'][2]),
                'support': int(test_metrics['support_per_class'][2])
            }
        },
        'performance_interpretation': {
            'overall_quality': overall_quality,
            'class_balance_quality': 'BALANCED' if min_f1 > 0.65 else 'MODERATE' if min_f1 > 0.55 else 'IMBALANCED',
            'minimum_class_f1': float(min_f1),
            'average_class_f1': float(avg_f1),
            'recommendation': recommendation,
            'deployment_ready': min_f1 >= 0.6 and avg_f1 >= 0.7
        }
    }

    with open('/advanced_medical_bert_report.json', 'w') as f:
        json.dump(report, f, indent=2)

    print("\nüìä ADVANCED PERFORMANCE REPORT:")
    print("=" * 50)
    print(f"Overall F1-Score (Weighted): {test_metrics['f1_weighted']:.4f}")
    print(f"Overall F1-Score (Macro): {test_metrics['f1_macro']:.4f}")
    print(f"Overall Accuracy: {test_metrics['accuracy']:.4f}")

    print(f"\nPer-class Performance:")
    print(f"  Low:    F1={test_metrics['f1_per_class'][0]:.4f}")
    print(f"  Medium: F1={test_metrics['f1_per_class'][1]:.4f}")
    print(f"  High:   F1={test_metrics['f1_per_class'][2]:.4f}")

    print(f"\nStatus: {overall_quality}")
    print(f"Deployment Ready: {'Yes' if report['performance_interpretation']['deployment_ready'] else 'No'}")
    print(f"Recommendation: {recommendation}")

if __name__ == "__main__":
    main()

Using device: cpu
üöÄ ADVANCED MEDICAL BERT TRAINING WITH BALANCED DATASET
Loading balanced training data...
Total samples: 180
Urgency distribution: {'high': 60, 'medium': 60, 'low': 60}
Enhanced class weights: {'low': np.float64(1.0), 'medium': np.float64(1.0), 'high': np.float64(1.0)}

Data splits:
Train: 144
Validation: 18
Test: 18
Train distribution: {'high': 48, 'low': 48, 'medium': 48}
Validation distribution: {'low': 6, 'high': 6, 'medium': 6}
Test distribution: {'low': 6, 'medium': 6, 'high': 6}

üìö Preparing enhanced dataloaders...
Train batches: 9, Val batches: 2, Test batches: 2

üè• Starting Advanced Training with Focal Loss...
Initializing model with focal loss...
Initializing Enhanced Bio+Clinical BERT...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úÖ Enhanced Bio+Clinical BERT model loaded with improved classifier
Starting enhanced training with focal loss...


Epoch 1/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 1:
  Train Loss: 0.4918
  Train Accuracy: 0.3333, Train F1: 0.2891
  Val Accuracy: 0.3333, Val F1 (macro): 0.1667
  Val F1 (weighted): 0.1667
  Per-class F1: Low=0.0000, Medium=0.5000, High=0.0000
  ‚è≥ No improvement (1/7)


Epoch 2/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 2:
  Train Loss: 0.4784
  Train Accuracy: 0.4236, Train F1: 0.4048
  Val Accuracy: 0.5556, Val F1 (macro): 0.4444
  Val F1 (weighted): 0.4444
  Per-class F1: Low=0.6667, Medium=0.0000, High=0.6667
  ‚è≥ No improvement (2/7)


Epoch 3/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 3:
  Train Loss: 0.3998
  Train Accuracy: 0.5417, Train F1: 0.4871
  Val Accuracy: 0.6667, Val F1 (macro): 0.6016
  Val F1 (weighted): 0.6016
  Per-class F1: Low=0.7692, Medium=0.2857, High=0.7500
  ‚è≥ No improvement (3/7)


Epoch 4/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 4:
  Train Loss: 0.3413
  Train Accuracy: 0.5833, Train F1: 0.4918
  Val Accuracy: 0.5556, Val F1 (macro): 0.4667
  Val F1 (weighted): 0.4667
  Per-class F1: Low=0.8000, Medium=0.0000, High=0.6000
  ‚è≥ No improvement (4/7)


Epoch 5/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 5:
  Train Loss: 0.2865
  Train Accuracy: 0.6875, Train F1: 0.6578
  Val Accuracy: 0.6667, Val F1 (macro): 0.6217
  Val F1 (weighted): 0.6217
  Per-class F1: Low=0.9091, Medium=0.2500, High=0.7059
  ‚è≥ No improvement (5/7)


Epoch 6/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 6:
  Train Loss: 0.2307
  Train Accuracy: 0.6944, Train F1: 0.6815
  Val Accuracy: 0.8333, Val F1 (macro): 0.8222
  Val F1 (weighted): 0.8222
  Per-class F1: Low=1.0000, Medium=0.6667, High=0.8000
  ‚úÖ Saved best model (Macro F1: 0.8222, Min Class F1: 0.6667)


Epoch 7/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 7:
  Train Loss: 0.1588
  Train Accuracy: 0.7847, Train F1: 0.7762
  Val Accuracy: 0.6667, Val F1 (macro): 0.6217
  Val F1 (weighted): 0.6217
  Per-class F1: Low=0.9091, Medium=0.2500, High=0.7059
  ‚è≥ No improvement (1/7)


Epoch 8/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 8:
  Train Loss: 0.1389
  Train Accuracy: 0.7917, Train F1: 0.7856
  Val Accuracy: 0.8333, Val F1 (macro): 0.8222
  Val F1 (weighted): 0.8222
  Per-class F1: Low=1.0000, Medium=0.6667, High=0.8000
  ‚è≥ No improvement (2/7)


Epoch 9/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 9:
  Train Loss: 0.1189
  Train Accuracy: 0.8194, Train F1: 0.8154
  Val Accuracy: 0.8889, Val F1 (macro): 0.8857
  Val F1 (weighted): 0.8857
  Per-class F1: Low=1.0000, Medium=0.8000, High=0.8571
  ‚úÖ Saved best model (Macro F1: 0.8857, Min Class F1: 0.8000)


Epoch 10/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 10:
  Train Loss: 0.1060
  Train Accuracy: 0.8194, Train F1: 0.8182
  Val Accuracy: 0.8889, Val F1 (macro): 0.8857
  Val F1 (weighted): 0.8857
  Per-class F1: Low=1.0000, Medium=0.8000, High=0.8571
  ‚è≥ No improvement (1/7)


Epoch 11/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 11:
  Train Loss: 0.0774
  Train Accuracy: 0.9167, Train F1: 0.9178
  Val Accuracy: 0.8333, Val F1 (macro): 0.8222
  Val F1 (weighted): 0.8222
  Per-class F1: Low=1.0000, Medium=0.6667, High=0.8000
  ‚è≥ No improvement (2/7)


Epoch 12/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 12:
  Train Loss: 0.0803
  Train Accuracy: 0.8681, Train F1: 0.8612
  Val Accuracy: 0.8333, Val F1 (macro): 0.8222
  Val F1 (weighted): 0.8222
  Per-class F1: Low=1.0000, Medium=0.6667, High=0.8000
  ‚è≥ No improvement (3/7)


Epoch 13/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 13:
  Train Loss: 0.0809
  Train Accuracy: 0.9375, Train F1: 0.9364
  Val Accuracy: 0.7778, Val F1 (macro): 0.7410
  Val F1 (weighted): 0.7410
  Per-class F1: Low=0.9231, Medium=0.5000, High=0.8000
  ‚è≥ No improvement (4/7)


Epoch 14/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 14:
  Train Loss: 0.0529
  Train Accuracy: 0.9375, Train F1: 0.9359
  Val Accuracy: 0.8333, Val F1 (macro): 0.8222
  Val F1 (weighted): 0.8222
  Per-class F1: Low=1.0000, Medium=0.6667, High=0.8000
  ‚è≥ No improvement (5/7)


Epoch 15/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]


Epoch 15:
  Train Loss: 0.0683
  Train Accuracy: 0.9236, Train F1: 0.9238
  Val Accuracy: 0.8889, Val F1 (macro): 0.8857
  Val F1 (weighted): 0.8857
  Per-class F1: Low=1.0000, Medium=0.8000, High=0.8571
  ‚è≥ No improvement (6/7)


Epoch 16/20 [Train]:   0%|          | 0/9 [00:00<?, ?it/s]

Some weights of the model checkpoint at /advanced_medical_bert_model were not used when initializing BertForSequenceClassification: ['classifier.1.bias', 'classifier.1.weight', 'classifier.4.bias', 'classifier.4.weight', 'classifier.7.bias', 'classifier.7.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /advanced_medical_bert_model and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-st


Epoch 16:
  Train Loss: 0.0386
  Train Accuracy: 0.9444, Train F1: 0.9445
  Val Accuracy: 0.8889, Val F1 (macro): 0.8857
  Val F1 (weighted): 0.8857
  Per-class F1: Low=1.0000, Medium=0.8000, High=0.8571
  ‚è≥ No improvement (7/7)
  üõë Early stopping at epoch 16

üî¨ Final Evaluation on Test Set:

üìä ENHANCED CLASS PERFORMANCE ANALYSIS:

LOW URGENCY:
  Precision: 1.0000
  Recall:    0.1667
  F1-Score:  0.2857
  Support:   6 samples
  Status:    ‚ùå NEEDS IMPROVEMENT
  Assessment: Significant improvements needed

MEDIUM URGENCY:
  Precision: 0.0000
  Recall:    0.0000
  F1-Score:  0.0000
  Support:   6 samples
  Status:    ‚ùå NEEDS IMPROVEMENT
  Assessment: Significant improvements needed

HIGH URGENCY:
  Precision: 0.2667
  Recall:    0.6667
  F1-Score:  0.3810
  Support:   6 samples
  Status:    ‚ùå NEEDS IMPROVEMENT
  Assessment: Significant improvements needed

üìà OVERALL ASSESSMENT:
  Minimum F1: 0.0000
  Average F1: 0.2222
  Balanced: No

üìä ADVANCED PERFORMANCE REPORT: