<a href="https://colab.research.google.com/github/sktexaslearn/MIMIC-NLP/blob/main/mimic_nlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries
!pip uninstall -y spacy scispacy en-core-web-sm en-core-sci-sm
!pip install spacy==3.4.4
!python -m spacy download en_core_web_sm-3.4.0 --direct
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz
!pip install bigframes
!pip install gensim
!pip install matplotlib scikit-learn
!pip install transformers

# Import libraries
import bigframes.pandas as bf
import spacy
import scispacy
from gensim.models import Word2Vec
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
import pandas as pd

# Discharge summaries joined with prescriptions, admission details, patient
# demographics and ICD-9 diagnoses. `age` is the year difference between
# admission time and date of birth; `los` is the day difference between
# discharge and admission. The joins yield one row per
# (note, prescription, diagnosis) combination.
# Prescriptions are matched on hadm_id as well as subject_id so that a note is
# only paired with drugs ordered during the same admission, not with every
# drug the patient ever received.
# NOTE(review): MIMIC-III is documented to shift DOB for patients older than
# 89, which would produce implausibly large ages here -- confirm and cap if needed.
df = bf.read_gbq("""
SELECT
    n.text,
    p.drug,
    a.subject_id,
    a.hadm_id,
    DATE_DIFF(a.admittime, pat.dob, YEAR) AS age,  -- Calculate age from DOB
    pat.gender,
    d.icd9_code,
    DATE_DIFF(a.dischtime, a.admittime, DAY) AS los
FROM `physionet-data.mimiciii_notes.noteevents` n
JOIN `physionet-data.mimiciii_clinical.prescriptions` p
    ON n.subject_id = p.subject_id
    AND n.hadm_id = p.hadm_id
JOIN `physionet-data.mimiciii_clinical.admissions` a
    ON n.hadm_id = a.hadm_id
JOIN `physionet-data.mimiciii_clinical.patients` pat
    ON n.subject_id = pat.subject_id
JOIN `physionet-data.mimiciii_clinical.diagnoses_icd` d
    ON n.hadm_id = d.hadm_id
WHERE n.category = 'Discharge summary'
AND p.drug IS NOT NULL

""")

# Take the first 50 rows and pull each column of interest out as a Python list.
query_job = df.head(50)
(
    sample_notes,
    prescribed_drugs,
    ages,
    genders,
    icd9_codes,
    los,
) = (
    query_job[column].tolist()
    for column in ("text", "drug", "age", "gender", "icd9_code", "los")
)

# Quick sanity check of what was fetched (first row only).
print(f"Number of discharge summaries: {len(sample_notes)}")
print("Sample note excerpt:", sample_notes[0][:200])
print(
    "Sample drug:", prescribed_drugs[0],
    "Age:", ages[0],
    "Gender:", genders[0],
    "LOS:", los[0],
)

# General-purpose English pipeline and the scispaCy biomedical pipeline.
nlp_spacy = spacy.load("en_core_web_sm")
nlp_scispacy = spacy.load("en_core_sci_sm")

# Function to extract ADE pairs with severity
def extract_ade_pairs(text, prescribed_drugs, nlp_scispacy):
    """Find prescribed drugs and ADE-like entities mentioned in a note.

    Args:
        text: Note text. It is lower-cased before entity extraction and
            before the severity keyword scan.
        prescribed_drugs: Iterable of drug names to look for. Entries that
            are not strings (e.g. missing values) are skipped.
        nlp_scispacy: Callable taking a string and returning an object whose
            ``ents`` iterable yields items with a ``.text`` attribute.

    Returns:
        A ``(drugs, ades, severity)`` tuple where ``drugs`` is the list of
        lower-cased prescribed drugs that exactly match an extracted entity,
        ``ades`` is the list of extracted entities containing one of the ADE
        keywords, and ``severity`` is ``"severe"``, ``"moderate"`` or
        ``"mild"`` based on keywords found anywhere in the note. Both lists
        are de-duplicated and keep first-seen order, so index 0 is stable
        across runs.
    """
    ade_terms = (
        "rash", "bleeding", "nausea", "dizziness", "pain", "fever", "swelling",
        "vomiting", "diarrhea", "headache", "fatigue", "allergy", "infection",
        "cough", "shortness", "breath", "itching", "confusion", "weakness",
    )
    severe_cues = ("severe", "critical", "emergency", "life-threatening")
    moderate_cues = ("moderate", "significant", "persistent")

    # Lower-case once; the same string feeds the NLP pipeline and the scan.
    lowered = text.lower()
    doc = nlp_scispacy(lowered)
    entities = [ent.text.lower() for ent in doc.ents]
    entity_set = set(entities)  # O(1) membership for the drug lookup

    drug_mentions = [
        drug.lower()
        for drug in prescribed_drugs
        if isinstance(drug, str) and drug.lower() in entity_set
    ]
    # Substring match, so multi-word entities such as "severe bleeding" count.
    ade_mentions = [e for e in entities if any(term in e for term in ade_terms)]

    # Severity classification: "severe" cues take precedence over "moderate".
    if any(cue in lowered for cue in severe_cues):
        severity = "severe"
    elif any(cue in lowered for cue in moderate_cues):
        severity = "moderate"
    else:
        severity = "mild"  # Default

    # dict.fromkeys de-duplicates while preserving order (unlike set()).
    return (
        list(dict.fromkeys(drug_mentions)),
        list(dict.fromkeys(ade_mentions)),
        severity,
    )

# ADE data: one record per row whose note mentions both the row's prescribed
# drug and at least one ADE-like entity.
ade_data = []
for note, drug, age, gender, icd9, los_val in zip(
    sample_notes, prescribed_drugs, ages, genders, icd9_codes, los
):
    drugs, ades, severity = extract_ade_pairs(note, [drug], nlp_scispacy)
    if not (drugs and ades):
        continue
    # Only the first element of each returned list is kept for the record.
    ade_data.append({
        "drug": drugs[0], "ade": ades[0], "severity": severity,
        "age": age, "gender": gender, "icd9": icd9, "los": los_val,
    })
print("Sample ADE data:", ade_data[:5] if ade_data else "No ADE pairs found")

# notes for Word2Vec
def preprocess_text(text):
    """Tokenise a note into lower-cased tokens suitable for Word2Vec.

    Runs ``text`` through the module-level ``nlp_scispacy`` pipeline and
    drops stop words, punctuation and whitespace-only tokens.

    Args:
        text: Note text to tokenise.

    Returns:
        List of lower-cased token strings in document order.
    """
    doc = nlp_scispacy(text)
    return [
        token.text.lower()
        for token in doc
        # Whitespace tokens (newlines, runs of spaces) carry no meaning and
        # would otherwise end up in the embedding vocabulary.
        if not (token.is_stop or token.is_punct or token.is_space)
    ]

# One token list per note, in the same order as sample_notes.
processed_notes = list(map(preprocess_text, sample_notes))

# Fit Word2Vec on the tokenised notes; tokens seen fewer than twice are dropped.
w2v_params = dict(vector_size=100, window=5, min_count=2, workers=2)
w2v_model = Word2Vec(sentences=processed_notes, **w2v_params)
print("Word2Vec vocabulary size:", len(w2v_model.wv))

# t-SNE on ADE terms
# Collect every drug and ADE string seen in the ADE records, keep those that
# made it into the Word2Vec vocabulary, and sort them so the embedding order
# (and therefore the t-SNE layout for a fixed random_state) is reproducible.
relevant_terms = {
    term for entry in ade_data for term in (entry["drug"], entry["ade"])
}
relevant_terms = sorted(term for term in relevant_terms if term in w2v_model.wv)
if not relevant_terms:
    print("No ADE terms for t-SNE, using top 50 vocab terms")
    relevant_terms = list(w2v_model.wv.key_to_index)[:50]
embeddings = np.array([w2v_model.wv[term] for term in relevant_terms])

if len(relevant_terms) < 2:
    # With fewer than two points the perplexity below would be <= 0 and
    # TSNE would raise, so there is nothing meaningful to plot.
    print("Not enough terms for t-SNE; skipping plot")
else:
    # Perplexity must stay below the number of samples.
    perplexity = min(5, len(embeddings) - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    tsne_results = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 6))
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c='blue', alpha=0.5)
    # Label each point with its term.
    for term, (x_coord, y_coord) in zip(relevant_terms, tsne_results):
        plt.annotate(term, (x_coord, y_coord), fontsize=9)
    plt.title("t-SNE Visualization of ADE-Related Terms")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.savefig("tsne_plot_ade.png")
    plt.show()

# ClinicalBERT: load the tokenizer and model, then show how the start of the
# first note is split into word pieces.
clinicalbert_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(clinicalbert_name)
model = AutoModel.from_pretrained(clinicalbert_name)

# Only the first 512 characters of the first note are tokenised.
note_excerpt = sample_notes[0][:512]
clinicalbert_tokens = tokenizer(
    note_excerpt,
    return_tensors="pt",
    truncation=True,
    padding=True,
)
first_sequence_ids = clinicalbert_tokens["input_ids"][0]
tokens = tokenizer.convert_ids_to_tokens(first_sequence_ids)
print("ClinicalBERT Tokens:", tokens[:20])

# Output: summary table, demographic breakdown, LOS comparison and severity chart.
# Everything here is skipped when no ADE records were found.
if ade_data:
    ade_df = pd.DataFrame(ade_data)
    print("\nADE Summary Table:")
    print(ade_df[["drug", "ade", "severity", "age", "gender", "los"]].to_string(index=False))

    # Demographic Impact
    print("\nAverage Age by Severity:")
    print(ade_df.groupby("severity")["age"].mean().round(1))
    print("\nADE Frequency by Gender:")
    print(ade_df["gender"].value_counts())

    # Outcome Analysis
    ade_los = ade_df["los"].mean()
    # Each ADE record carries the LOS of the row it was built from, so the
    # non-ADE rows are exactly the complement: subtract the ADE totals from
    # the overall totals instead of guessing which row positions had ADEs.
    non_ade_count = len(los) - len(ade_data)
    if non_ade_count > 0:
        non_ade_los = float(sum(los) - ade_df["los"].sum()) / non_ade_count
    else:
        # Every sampled row produced an ADE record; there is no comparison group.
        non_ade_los = float("nan")
    print(f"\nAverage LOS - ADE Patients: {ade_los:.1f} days, Non-ADE Patients: {non_ade_los:.1f} days")

    # Bar Chart
    plt.figure(figsize=(8, 5))
    ade_df["severity"].value_counts().plot(kind="bar", color="skyblue")
    plt.title("Distribution of ADE Severity")
    plt.xlabel("Severity")
    plt.ylabel("Frequency")
    plt.savefig("severity_bar.png")
    plt.show()

# results: persist the sampled rows as CSV and the ADE records as text.
query_job.to_csv("ade_notes.csv", index=False)
# Explicit UTF-8 so non-ASCII characters in extracted entity text cannot fail
# under a platform-dependent default encoding.
with open("ade_data.txt", "w", encoding="utf-8") as f:
    f.write("ADE Data:\n" + str(ade_data))