In [2]:
import pandas as pd
import re
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from tqdm import tqdm

In [3]:
transfers_df = pd.read_csv("/Users/jacobsussman/Desktop/mimic-iv-clinical-database-demo-2.2/hosp/transfers.csv.gz")
radiology_notes = pd.read_csv("/Users/jacobsussman/Downloads/mimic-iv-note-deidentified-free-text-clinical-notes-2/note/radiology.csv.gz")

In [4]:
icu_units = [
    "Trauma SICU (TSICU)", "Medical Intensive Care Unit (MICU)",
    "Surgical Intensive Care Unit (SICU)", "Medical/Surgical Intensive Care Unit (MICU/SICU)",
    "PACU", "Neuro Surgical Intensive Care Unit (Neuro SICU)"
]

transfers_df["intime"] = pd.to_datetime(transfers_df["intime"], errors="coerce")
icu_transfers = transfers_df[transfers_df["careunit"].isin(icu_units)].copy()

first_icu_transfer = (
    icu_transfers.groupby("subject_id", as_index=False)["intime"]
    .min()
    .rename(columns={"intime": "first_icu_time"})
)

def clean_text(text):
    text = re.sub(r"\[\*\*.*?\*\*\]", " ", text)
    text = re.sub(r"[^a-zA-Z0-9.,;:()\-/% ]", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()

radiology_notes["charttime"] = pd.to_datetime(radiology_notes["charttime"], errors="coerce")
radiology_notes["clean_text"] = radiology_notes["text"].apply(clean_text)

notes_with_icu = radiology_notes.merge(first_icu_transfer, on="subject_id", how="left")

pre_icu_notes = notes_with_icu[
    (notes_with_icu["first_icu_time"].isna()) |
    (notes_with_icu["charttime"] < notes_with_icu["first_icu_time"])
].copy()

filtered_notes = (
    pre_icu_notes.groupby("subject_id")["clean_text"]
    .apply(lambda x: " ".join(x))
    .reset_index()
)

print(filtered_notes.head())


   subject_id                                         clean_text
0    10000032  examination: chest (pa and lat) indication: wi...
1    10000084  examination: chest (pa and lat) indication: hi...
2    10000102  chest pa and lateral. comparison: none. histor...
3    10000108  history: male with left buccal abscess and nec...
4    10000117  bilateral digital screening mammogram with cad...


In [4]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model.to("cuda" if torch.cuda.is_available() else "cpu")
bert_model.eval()



config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]



pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [5]:
def get_embedding(text, tokenizer, model, max_length=512):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state[:, 0, :].cpu().numpy().flatten()

In [7]:
embeddings = []
for text in tqdm(adm_notes["clean_text"]):
    embeddings.append(get_embedding(text, tokenizer, bert_model))

adm_notes["embedding"] = embeddings

  0%|                                    | 39/309670 [00:24<54:39:37,  1.57it/s]


KeyboardInterrupt: 