In [1]:
from aiohttp.web_exceptions import HTTPResetContent
from sympy.physics.units import temperature
# Title: Biomedical NER and SDoH Extraction from PubMed Abstracts with LLM Verification

# --------------------------- 1. Setup and Imports ---------------------------
!pip install transformers torch nltk seqeval spacy tqdm matplotlib sentence-transformers numpy==1.26.4
!python -m spacy download en_core_web_sm

import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, AutoTokenizer, AutoModelForCausalLM
import nltk
from tqdm import tqdm
import spacy
from collections import Counter
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util
import os
from nltk.tokenize import sent_tokenize

# Ensure NLTK punkt data is downloaded
nltk.download('punkt')


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[31mERROR: Invalid requirement: 'numpy=1.26.4': Expected end or semicolon (after name and no valid version specifier)
    numpy=1.26.4
         ^
Hint: = is not a valid operator. Did you mean == ?[0m[31m
[0mCollecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.8.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /Users/Tim/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [12]:

# --------------------------- 2. Load Data and Models ---------------------------

# Update these paths if needed
input_file = '../MS_SDoH_pubmed_abstracts_20241127.json'
output_file = '../Processed_MS_SDoH_pubmed_abstracts_with_entities.json'

with open(input_file, 'r', encoding='utf-8') as f:
    articles = json.load(f)

# Load BioBERT NER model and tokenizer
model_name = "d4data/biomedical-ner-all"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    ner_model = AutoModelForTokenClassification.from_pretrained(model_name)
    print("BioBERT NER model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model '{model_name}': {e}")
    import sys
    sys.exit()

# Create NER pipeline
device = 0 if torch.cuda.is_available() else -1
ner_pipeline = pipeline("ner", model=ner_model, tokenizer=tokenizer, aggregation_strategy="simple", device=device)
print(f"Using device: {'GPU' if device == 0 else 'CPU'}")

# Test the NER pipeline on a sample sentence
test_sentence = "The patient reported no recurrence of palpitations at follow-up 6 months after the ablation."
print("\nTesting NER pipeline with a sample sentence...")
try:
    test_entities = ner_pipeline(test_sentence)
    print("Test sentence entities:")
    print(json.dumps(test_entities, indent=4))
except Exception as e:
    print(f"Error during NER pipeline test: {e}")

BioBERT NER model and tokenizer loaded successfully.
Using device: CPU

Testing NER pipeline with a sample sentence...
Test sentence entities:
Error during NER pipeline test: Object of type float32 is not JSON serializable


In [None]:

# --------------------------- 3. Perform NER on All Articles ---------------------------
all_entities = []
for article in tqdm(articles, desc="Processing abstracts"):
    abstract = article.get('abstract', '')
    if abstract:
        entities = ner_pipeline(abstract)
        article['entities'] = entities
        all_entities.extend(entities)
    else:
        article['entities'] = []

# Count entity types
entity_labels = [entity['entity_group'] for entity in all_entities]
entity_counts = Counter(entity_labels)

print("Entity Counts:")
for label, count in entity_counts.items():
    print(f"{label}: {count}")

# Optional: Plot entity counts
if len(entity_counts) > 0:
    labels, counts = zip(*entity_counts.items())
    plt.bar(labels, counts)
    plt.xlabel('Entity Type')
    plt.ylabel('Count')
    plt.title('Entity Counts by Type')
    plt.show()

In [11]:

# --------------------------- 4. Define the Refined SDoH Ontology with Exact Matches ---------------------------
sdoh_ontology = {
    "socioeconomic_factors": {
        "poverty_low_income": [
            "poverty", 
            "living in poverty",
            "low income",
            "low-income population",
            "financial hardship",
            "low socioeconomic status"
        ],
        "high_income_wealthy": [
            "high income",
            "high-income population",
            "wealthy communities",
            "high ses"
        ],
        "lower_education": [
            "low education",
            "low educational attainment",
            "limited schooling"
        ],
        "higher_education": [
            "high education",
            "college-educated",
            "high educational level"
        ],
        "lower_social_class": [
            "lower social class",
            "lower class",
            "working class"
        ],
        "upper_social_class": [
            "upper social class",
            "middle class",
            "upper class"
        ]
    },
    "lifestyle_behavioral_factors": {
        "nutrition_poor": [
            "poor nutrition",
            "unhealthy dietary habits",
            "unhealthy diet",
            "bad nutrition"
        ],
        "nutrition_high_fat": [
            "high-fat diet",
            "high fat diet",
            "high fat dietary intake"
        ],
        "nutrition_balanced": [
            "nutrition",
            "good nutrition",
            "balanced diet",
            "nutrient-rich diet",
            "healthy diet",
            "balanced nutrition"
        ],
        "stress_high": [
            "high stress",
            "high stress levels",
            "chronic stress"
        ],
        "stress_low": [
            "low stress",
            "low stress levels",
            "reduced stress environment"
        ],
        "physical_activity_high": [
            "physical activity",
            "regular exercise",
            "physically active lifestyle"
        ],
        "physical_activity_low": [
            "sedentary lifestyle",
            "sedentary behavior"
        ],
        "smoking_tobacco_use": [
            "smoking",
            "smoker",
            "heavy smoking",
            "current smoker",
            "former smoker",
            "tobacco use"
        ],
        "alcohol_use": [
            "alcohol use",
            "alcohol consumption",
            "excessive alcohol consumption",
            "moderate drinking",
            "alcohol misuse"
        ],
        "substance_abuse": [
            "substance abuse",
            "illicit drug use",
            "substance dependence",
            "drug misuse"
        ],
        "general_unhealthy_lifestyle": [
            "unhealthy lifestyle",
            "unhealthy life"
        ],
        "general_healthy_lifestyle": [
            "healthy lifestyle",
            "health-promoting behaviors"
        ]
    },
    "environmental_factors": {
        "housing": [
            "housing",
            "poor housing conditions",
            "overcrowded housing",
            "stable housing"
        ],
        "neighborhood": [
            "neighborhood",
            "low-resource neighborhood",
            "unsafe neighborhood",
            "affluent neighborhood"
        ],
        "pollution_high": [
            "pollution",
            "high pollution exposure",
            "environmental toxins"
        ],
        "urban_environment": [
            "urban",
            "urban environment",
            "city dwelling"
        ],
        "rural_environment": [
            "rural",
            "rural setting",
            "remote area"
        ],
        "high_latitude": [
            "high latitude",
            "high-latitude region"
        ],
        "low_latitude": [
            "low latitude",
            "low-latitude region"
        ],
        "transportation_limited": [
            "transportation",
            "limited transportation access",
            "inadequate public transit"
        ],
        "food_insecurity": [
            "food insecurity",
            "food insecure",
            "limited healthy food access"
        ],
        "general_environment": [
            "environment",
            "environmental factors",
            "local environmental conditions"
        ]
    },
    "cultural_social_factors": {
        "demographic_identity_factors": [
            "ethnicity",
            "race",
            "gender",
            "ethnic minority group",
            "racial minority",
            "diverse racial background",
            "female sex",
            "male sex",
            "gender minorities"
        ],
        "social_support_cultural_context": [
            "social support",
            "cultural beliefs",
            "community",
            "family",
            "strong social support network",
            "lack of social support",
            "traditional cultural norms",
            "tight-knit community",
            "dysfunctional family environment"
        ]
    },
    "health_system_related_factors": {
        "access_to_care_limited": [
            "no access to care",
            "limited access to healthcare",
            "barriers to care"
        ],
        "access_to_care_adequate": [
            "access to care",
            "adequate access to healthcare",
            "improved healthcare availability"
        ],
        "health_disparities_inequality": [
            "health disparities",
            "inequality",
            "health inequities",
            "unequal healthcare access",
            "healthcare inequality"
        ],
        "health_literacy_low": [
            "low health literacy",
            "limited health literacy"
        ],
        "health_literacy_high": [
            "health literacy",
            "high health literacy",
            "understanding health information"
        ]
    },
    "violence_safety_factors": {
        "violence": [
            "violence",
            "exposure to violence",
            "interpersonal violence",
            "unsafe environment"
        ]
    }
}

# --------------------------- 5. Prepare Embeddings for SDoH Ontology ---------------------------
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

flattened_ontology = {}
for main_cat, subcats in sdoh_ontology.items():
    for subcat, phrases in subcats.items():
        category_name = f"{main_cat}::{subcat}"
        flattened_ontology[category_name] = [p.lower() for p in phrases]

ontology_embeddings = {}
for cat_name, phrases in flattened_ontology.items():
    cat_embeddings = embedding_model.encode(phrases, convert_to_tensor=True)
    ontology_embeddings[cat_name] = cat_embeddings

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:

# --------------------------- LLM Initialization ---------------------------
# Ensure you have the appropriate model and permission
import torch
from transformers import pipeline

# Your Hugging Face access token
hf_token = "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"  # Replace with your actual token

model_id = "meta-llama/Llama-3.2-1B"

# Initialize the LLM pipeline with proper authentication and parameters
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.float16, # Use torch.bfloat16 if supported by your hardware
    device_map="cpu",
    token=hf_token,           # Correct parameter for authentication
    trust_remote_code=True    # Enable execution of remote code if necessary
)

# Example usage
prompt = "Once upon a time"
response = pipe(prompt, max_length=50, truncation=True, temperature=0.4)
print(response)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


[{'generated_text': 'Once upon a time, a young man named Robert was born into a family of farmers. His father was a hardworking man who loved his family and his farm. His mother was a gentle woman who took care of the family and the farm.'}]


In [19]:

import logging

# --------------------------- Logging Configuration ---------------------------
logging.basicConfig(
    level=logging.DEBUG,  # Set to DEBUG for detailed logging
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Uncomment the following line to log to a file
        # logging.FileHandler('sdoh_extraction_enhanced.log')
    ]
)

logger = logging.getLogger(__name__)
logger.info(f"Beginning extraction and mapping of SDoH mentions from {len(articles)} articles.")

all_sdoh_mentions = []

LOW_THRESHOLD=0.5
HIGH_THRESHOLD=0.7

for idx, article in enumerate(tqdm(articles, desc="Processing Articles")):
    logger.debug(f"Processing Article {idx + 1}/{len(articles)}")
    abstract = article.get('abstract', '')
    sdoh_mentions = []
    
    if abstract:
        logger.debug("Abstract found. Proceeding with noun phrase extraction.")
        
        # Extract noun phrases using SpaCy
        try:
            doc = nlp(abstract)
            noun_phrases = [chunk.text.strip() for chunk in doc.noun_chunks]
            logger.debug(f"Extracted {len(noun_phrases)} noun phrases from abstract.")
        except Exception as e:
            logger.error(f"Error during noun phrase extraction: {e}")
            noun_phrases = []
        
        # Tokenize sentences using NLTK
        try:
            sentences = nltk.sent_tokenize(abstract)
            logger.debug(f"Tokenized abstract into {len(sentences)} sentences.")
        except Exception as e:
            logger.error(f"Error during sentence tokenization: {e}")
            sentences = []
        
        for phrase_idx, phrase in enumerate(noun_phrases):
            logger.debug(f"Processing Phrase {phrase_idx + 1}/{len(noun_phrases)}: '{phrase}'")
            phrase_text = phrase.lower()
            
            # Encode the phrase to get its embedding
            try:
                phrase_embedding = embedding_model.encode(phrase_text, convert_to_tensor=True)
                logger.debug(f"Encoded phrase '{phrase_text}' into embedding.")
            except Exception as e:
                logger.error(f"Error encoding phrase '{phrase_text}': {e}")
                continue  # Skip this phrase if encoding fails
            
            best_cat = None
            best_score = -1.0
            
            # Compare phrase with each ontology category's embeddings
            for cat_idx, (cat_name, cat_embeddings) in enumerate(ontology_embeddings.items(), start=1):
                try:
                    cosine_scores = util.cos_sim(phrase_embedding, cat_embeddings).squeeze(0)
                    max_score, _ = torch.max(cosine_scores, dim=0)
                    max_score = max_score.item()
                    logger.debug(f"Category '{cat_name}' (Index {cat_idx}): Cosine Score = {max_score:.4f}")
                    
                    if max_score > best_score:
                        best_score = max_score
                        best_cat = cat_name
                        logger.debug(f"New best category for phrase '{phrase}': '{best_cat}' with score {best_score:.4f}")
                except Exception as e:
                    logger.error(f"Error computing cosine similarity for category '{cat_name}': {e}")
                    continue  # Skip this category if similarity computation fails
            
            if best_cat is None:
                logger.warning(f"No suitable category found for phrase '{phrase}'. Skipping.")
                continue  # Skip if no category is found
            
            logger.debug(f"Best category for phrase '{phrase}': '{best_cat}' with similarity {best_score:.4f}")
            
            if best_score < LOW_THRESHOLD:
                logger.info(f"Phrase '{phrase}' similarity {best_score:.4f} below LOW_THRESHOLD {LOW_THRESHOLD}. Discarding.")
                continue  # Discard the phrase
            elif best_score > HIGH_THRESHOLD:
                logger.info(f"Phrase '{phrase}' assigned to category '{best_cat}' with high similarity {best_score:.4f}.")
                sdoh_mentions.append({
                    'phrase': phrase,
                    'category': best_cat,
                    'similarity': best_score
                })
                all_sdoh_mentions.append(best_cat)
            else:
                logger.info(f"Phrase '{phrase}' similarity {best_score:.4f} is between LOW_THRESHOLD and HIGH_THRESHOLD. Initiating LLM verification.")
                
                # Find the sentence containing the phrase
                phrase_sent = next((s for s in sentences if phrase in s), abstract)
                logger.debug(f"Selected sentence for LLM verification: '{phrase_sent}'")
                
                # Create the prompt for the LLM
                prompt = (
                    f"Context: {phrase_sent}\n\n"
                    f"Phrase: '{phrase}'\n"
                    f"Proposed Category: '{best_cat}'\n\n"
                    "The phrase above is proposed to be classified under the given category of social determinants of health. "
                    "Is it reasonable and semantically correct to classify this phrase under that category? "
                    "Answer 'yes' if it is correct and 'no' if not."
                )
                
                logger.debug(f"LLM Prompt: {prompt}")
                
                # Generate response from LLM
                try:
                    llm_response = pipe(
                        prompt,
                        max_new_tokens=5,
                        temperature=0.4,
                        top_k=50,
                        top_p=0.95,
                        return_full_text=False
                    )[0]['generated_text'].strip().lower()
                    logger.debug(f"LLM Response: '{llm_response}'")
                except Exception as e:
                    logger.error(f"Error during LLM generation for phrase '{phrase}': {e}")
                    continue  # Skip LLM verification if generation fails
                
                # Interpret LLM response
                if "yes" in llm_response and "no" not in llm_response:
                    logger.info(f"LLM verified phrase '{phrase}' as correctly categorized under '{best_cat}'.")
                    sdoh_mentions.append({
                        'phrase': phrase,
                        'category': best_cat,
                        'similarity': best_score,
                        'verified_by_llm': True
                    })
                    all_sdoh_mentions.append(best_cat)
                else:
                    logger.info(f"LLM did not verify phrase '{phrase}' under category '{best_cat}'. Skipping.")
                    continue  # Do not include the phrase if not verified by LLM
    else:
        logger.warning(f"Article {idx + 1} does not contain an abstract. Skipping SDoH extraction.")
    
    # Assign the extracted SDoH mentions to the article
    article['sdoh_mentions'] = sdoh_mentions
    logger.debug(f"Assigned {len(sdoh_mentions)} SDoH mentions to Article {idx + 1}.")

# --------------------------- Counting SDoH Mentions ---------------------------
logger.info("Counting all SDoH mentions across processed articles.")
sdoh_counts = Counter(all_sdoh_mentions)

logger.info("\nSDoH Mentions Counts:")
for sdoh, count in sdoh_counts.items():
    logger.info(f"{sdoh}: {count}")

# --------------------------- Plotting Top SDoH Mentions ---------------------------
if sdoh_counts:
    logger.info("Plotting top 10 SDoH mentions.")
    try:
        top_sdoh = sdoh_counts.most_common(10)
        sdoh_labels, sdoh_counts_values = zip(*top_sdoh)
        plt.figure(figsize=(12, 8))
        plt.bar(sdoh_labels, sdoh_counts_values, color='skyblue')
        plt.xlabel('SDoH Category')
        plt.ylabel('Count')
        plt.title('Top 10 SDoH Mentions')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()
        logger.info("Plot displayed successfully.")
    except Exception as e:
        logger.error(f"Error during plotting SDoH mentions: {e}")
else:
    logger.info("No SDoH mentions found to plot.")

2024-12-12 16:54:12,830 - INFO - Beginning extraction and mapping of SDoH mentions from 523 articles.
Processing Articles:   0%|          | 0/523 [00:00<?, ?it/s]2024-12-12 16:54:12,837 - DEBUG - Processing Article 1/523
2024-12-12 16:54:12,838 - DEBUG - Abstract found. Proceeding with noun phrase extraction.
2024-12-12 16:54:12,855 - DEBUG - Extracted 24 noun phrases from abstract.
2024-12-12 16:54:12,856 - DEBUG - Tokenized abstract into 4 sentences.
2024-12-12 16:54:12,856 - DEBUG - Processing Phrase 1/24: 'Autoimmune diseases'

Batches: 100%|██████████| 1/1 [00:00<00:00, 10.54it/s]
2024-12-12 16:54:12,962 - DEBUG - Encoded phrase 'autoimmune diseases' into embedding.
2024-12-12 16:54:12,983 - DEBUG - Category 'socioeconomic_factors::poverty_low_income' (Index 1): Cosine Score = 0.0941
2024-12-12 16:54:12,984 - DEBUG - New best category for phrase 'Autoimmune diseases': 'socioeconomic_factors::poverty_low_income' with score 0.0941
2024-12-12 16:54:12,987 - DEBUG - Category 'socioeco

KeyboardInterrupt: 

In [None]:

# --------------------------- 7. Save Processed Data ---------------------------
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(articles, f, ensure_ascii=False, indent=4)

print(f"Processed data with mapped entities saved to {output_file}")