In [0]:
import os
import json
import requests
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# Databricks API token for the current notebook context. It is sent as a Bearer
# token to the model-serving endpoint; never print or log its value.
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Model-serving endpoint for Llama 3.3 70B Instruct. The default is workspace-
# specific; set the LLAMA_ENDPOINT environment variable to point this notebook
# at another workspace or endpoint without editing code.
LLAMA_ENDPOINT = os.environ.get(
    "LLAMA_ENDPOINT",
    "https://dbc-4b63042e-6da3.cloud.databricks.com/serving-endpoints/databricks-meta-llama-3-3-70b-instruct/invocations",
)

print("✅ Token obtained")
print(f"✅ Endpoint: {LLAMA_ENDPOINT}")

In [0]:
# DBTITLE 1,Test the Endpoint
# Smoke-test the serving endpoint with a tiny prompt before the distributed run,
# so authentication or URL problems show up here instead of as per-row errors
# inside the extraction UDF. Any failure is reported, not raised.
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

test_payload = {
    "messages": [{"role": "user", "content": "Say 'Hello, I am working!'"}],
    "max_tokens": 20,
    "temperature": 0.1,
}

try:
    response = requests.post(LLAMA_ENDPOINT, json=test_payload, headers=headers, timeout=30)
    response.raise_for_status()
    result = response.json()
    print("✅ Endpoint test successful!")
    first_message = result["choices"][0]["message"]
    print(f"Response: {first_message['content']}")
except Exception as e:
    print(f"❌ Endpoint test failed: {e}")

In [0]:
# Notes shorter than this many characters are treated as too sparse to extract from.
MIN_NOTE_CHARS = 50
# Per-note character cap on the text sent to the model; bounds request size and cost.
MAX_NOTE_CHARS = 4000
# HTTP statuses treated as transient (rate limiting / server errors) and retried.
_RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})
# Total attempts per note, including the first request.
_MAX_ATTEMPTS = 3


def _empty_extraction():
    """Return the placeholder result used for notes too short to extract from.

    It carries the same top-level keys the prompt asks the model for, so that
    downstream JSON-path queries behave the same for placeholder and real rows.
    """
    return {
        "chief_complaint": "",
        "symptoms": [],
        "diagnoses": [],
        "medications": [],
        "procedures": [],
        "vital_signs": {},
        "allergies": [],
        "extraction_confidence": "low",
    }


def _build_extraction_prompt(note):
    """Build the single-turn extraction prompt for one (already truncated) note."""
    return f"""You are a medical AI assistant. Extract structured clinical information from this medical note.

MEDICAL NOTE:
{note}

Extract and return ONLY valid JSON in this exact format:
{{
  "chief_complaint": "primary reason for visit",
  "symptoms": [
    {{"name": "symptom name", "severity": "mild/moderate/severe", "duration": "how long"}}
  ],
  "diagnoses": [
    {{"diagnosis": "diagnosis name", "icd10": "code if mentioned", "status": "confirmed/suspected"}}
  ],
  "medications": [
    {{"name": "medication name", "dosage": "dosage", "frequency": "frequency", "route": "route"}}
  ],
  "procedures": [
    {{"procedure": "procedure name", "cpt": "code if mentioned"}}
  ],
  "vital_signs": {{
    "blood_pressure": "value",
    "temperature": "value with unit",
    "heart_rate": "value",
    "respiratory_rate": "value",
    "oxygen_saturation": "value"
  }},
  "allergies": ["allergy1", "allergy2"],
  "extraction_confidence": "high/medium/low"
}}

IMPORTANT:
- Extract ONLY information explicitly stated in the text
- Do NOT infer or assume information
- Use null for missing fields
- Return ONLY the JSON object, no explanations
- If a section is empty, use empty arrays [] or empty objects {{}}

JSON OUTPUT:"""


def _extract_json_object(content):
    """Return the part of an LLM reply that should hold the JSON object.

    The model is asked for bare JSON but may wrap it in a markdown fence
    (```json ... ```), add leading whitespace, or add a short preamble. Slicing
    from the first "{" to the last "}" tolerates all of these. If there are no
    braces, the stripped reply is returned so json.loads reports a real error.
    """
    start = content.find("{")
    end = content.rfind("}")
    if start != -1 and end > start:
        return content[start:end + 1]
    return content.strip()


def _post_with_retries(payload, headers):
    """POST `payload` to LLAMA_ENDPOINT and return the decoded JSON response.

    Responses with a status in _RETRYABLE_STATUS are retried with exponential
    backoff (1s, 2s, ...) up to _MAX_ATTEMPTS in total. Any other HTTP error,
    or a retryable one on the final attempt, is raised via raise_for_status().
    """
    import time

    for attempt in range(_MAX_ATTEMPTS):
        response = requests.post(LLAMA_ENDPOINT, json=payload, headers=headers, timeout=90)
        if response.status_code not in _RETRYABLE_STATUS or attempt == _MAX_ATTEMPTS - 1:
            response.raise_for_status()
            return response.json()
        time.sleep(2 ** attempt)


@udf(returnType=StringType())
def extract_entities_llama(text):
    """
    Extract clinical entities from one note using Llama 3.3 70B.

    Returns a JSON string:
      * For empty notes or notes shorter than MIN_NOTE_CHARS: a placeholder with
        empty chief_complaint/symptoms/diagnoses/medications/procedures/
        vital_signs/allergies and extraction_confidence "low".
      * On success: the model's JSON object, re-serialised.
      * On failure: {"error": ..., "extraction_confidence": "failed"}, plus
        "raw_response" (first 200 chars of the reply) when the reply could not
        be parsed as a JSON object.
    Errors are returned rather than raised, so one bad row does not fail the
    whole Spark job.
    """
    if not text or len(text) < MIN_NOTE_CHARS:
        return json.dumps(_empty_extraction())

    payload = {
        "messages": [
            {"role": "user", "content": _build_extraction_prompt(text[:MAX_NOTE_CHARS])}
        ],
        "max_tokens": 2000,
        "temperature": 0.1,  # Low temperature for factual extraction
    }

    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    # Bound before the try so every except-handler below can safely reference it.
    content = ""
    try:
        result = _post_with_retries(payload, headers)

        # A missing/empty reply stays "" and is reported as a parse failure
        # rather than silently becoming an empty "successful" extraction.
        choices = result.get("choices") or [{}]
        content = (choices[0].get("message") or {}).get("content") or ""

        parsed = json.loads(_extract_json_object(content))
        if not isinstance(parsed, dict):
            return json.dumps({
                "error": "Model did not return a JSON object",
                "raw_response": content[:200],
                "extraction_confidence": "failed",
            })

        return json.dumps(parsed)

    except requests.exceptions.Timeout:
        return json.dumps({"error": "API timeout", "extraction_confidence": "failed"})
    except requests.exceptions.RequestException as e:
        return json.dumps({"error": f"API error: {str(e)}", "extraction_confidence": "failed"})
    except json.JSONDecodeError as e:
        return json.dumps({"error": f"JSON parse error: {str(e)}", "raw_response": content[:200], "extraction_confidence": "failed"})
    except Exception as e:
        return json.dumps({"error": f"Unexpected error: {str(e)}", "extraction_confidence": "failed"})

print("✅ Extraction UDF defined")

In [0]:
# Load the parsed documents from the raw zone. The table has at least the
# `path` and `text` columns, both selected for the preview below.
parsed_df = spark.table("healthcare_catalog.raw_zone.parsed_documents")

document_count = parsed_df.count()
print(f"Total documents: {document_count}")

preview = parsed_df.select("path", "text").limit(3)
display(preview)

# COMMAND ----------

# DBTITLE 1,Extract Entities from All Documents
# Apply the LLM extraction UDF to every document.
#
# withColumn is lazy. Without persisting, every later action on `extracted`
# (the stats cell, the filtered counts, the Delta write) would re-run the UDF
# and call the serving endpoint again for every row. That multiplies cost, and
# since the model is not deterministic it could yield different entities per
# action. Cache once and materialise with count() so each row is sent to the
# endpoint once.
# NOTE(review): DataFrame caching may be unavailable on Databricks serverless
# compute; there, write to a staging table instead. Confirm for your cluster type.
print("Starting entity extraction...")

extracted = (
    parsed_df
    .withColumn("entities", extract_entities_llama(col("text")))
    .cache()
)
extracted_count = extracted.count()  # forces the UDF to run now, once per row

print(f"\n✅ Extraction complete ({extracted_count} documents). Sample results:")
display(extracted.select("path", "entities").limit(3))

In [0]:
# DBTITLE 1,Validate Extraction Quality
# Summarise extraction outcomes: rows whose entities JSON carries an "error"
# key are failures; all rows are grouped by the model-reported confidence.
from pyspark.sql.functions import get_json_object

error_field = get_json_object(col("entities"), "$.error")
confidence_field = get_json_object(col("entities"), "$.extraction_confidence")

extraction_stats = (
    extracted
    .select(
        error_field.isNotNull().alias("has_error"),
        confidence_field.alias("confidence"),
    )
    .groupBy("has_error", "confidence")
    .count()
)

print("Extraction Statistics:")
display(extraction_stats)

# COMMAND ----------

# DBTITLE 1,Filter Out Failed Extractions
# Keep only rows whose entities JSON has no "error" key, i.e. the UDF got a
# parseable JSON object back from the model (short-note placeholders included).
successful_extractions = extracted.filter(
    get_json_object(col("entities"), "$.error").isNull()
)

# Count each DataFrame once and reuse the numbers: every count() is a separate
# Spark job and, unless `extracted` is cached, re-invokes the LLM UDF per row.
total_count = extracted.count()
success_count = successful_extractions.count()

print(f"✅ Successful extractions: {success_count}")
print(f"❌ Failed extractions: {total_count - success_count}")

In [0]:
# DBTITLE 1,Save to Silver Zone
# Write the successful extractions to the silver zone as a Delta table,
# replacing both the previous data and the previous schema.
SILVER_TABLE = "healthcare_catalog.silver_zone.extracted_entities"

(
    successful_extractions.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(SILVER_TABLE)
)
print(f"✅ Saved to: {SILVER_TABLE}")

# Read the table back to confirm the write landed.
result_count = spark.table(SILVER_TABLE).count()
print(f"✅ Verified: {result_count} records in silver zone")

# COMMAND ----------

# DBTITLE 1,View Sample Extracted Entities
# Show up to 10 silver rows with the main entity groups pulled out of the
# entities JSON string, so they are readable in the notebook output.
silver_entities = spark.table("healthcare_catalog.silver_zone.extracted_entities")

entity_columns = [
    get_json_object(col("entities"), f"$.{field}").alias(field)
    for field in ("chief_complaint", "symptoms", "diagnoses", "medications")
]

sample_entities = (
    silver_entities
    .select(
        "path",
        *entity_columns,
        get_json_object(col("entities"), "$.extraction_confidence").alias("confidence"),
    )
    .limit(10)
)

display(sample_entities)

In [0]:
# Summarise the silver table by model-reported confidence.
# Values are normalised (trimmed, lower-cased) because the model may reply with
# "High" or " medium". Anything outside high/medium/low, including a missing
# field, is counted in other_confidence so the buckets always sum to the total.
final_stats = spark.sql("""
    SELECT
        COUNT(*) AS total_documents,
        COUNT_IF(confidence = 'high')   AS high_confidence,
        COUNT_IF(confidence = 'medium') AS medium_confidence,
        COUNT_IF(confidence = 'low')    AS low_confidence,
        COUNT_IF(confidence IS NULL OR confidence NOT IN ('high', 'medium', 'low')) AS other_confidence
    FROM (
        SELECT lower(trim(get_json_object(entities, '$.extraction_confidence'))) AS confidence
        FROM healthcare_catalog.silver_zone.extracted_entities
    )
""")

display(final_stats)