In [1]:
pip install torch transformers pandas tqdm numpy

Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import pandas as pd

In [None]:
# Load ClinicalBERT
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [4]:
def get_embedding(text):
    # Tokenize and truncate
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Use CLS token representation as embedding
    cls_embedding = outputs.last_hidden_state[:, 0, :]  # shape: (1, 768)
    return cls_embedding.squeeze().cpu().numpy()        # shape: (768,)

In [3]:
df=pd.read_csv("Discharge_Summaries.csv")
df

Unnamed: 0.1,Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT
0,0,1112,192293,Admission Date: [**2161-8-22**] ...
1,1,57199,135740,Admission Date: [**2139-10-29**] ...
2,2,24573,184883,Admission Date: [**2128-1-30**] ...
3,3,15747,184674,Admission Date: [**2142-11-7**] ...
4,4,56027,106560,Admission Date: [**2189-6-20**] ...
...,...,...,...,...
59647,59647,8566,122833,"Name: [**Known lastname 8238**], [**Known fir..."
59648,59648,1034,133565,"Name: [**Known lastname 17760**], [**Known fi..."
59649,59649,91388,175262,"Name: [**Known lastname 7062**],[**Known firs..."
59650,59650,13436,195728,"Name: [**Known lastname **], [**Known firstna..."


In [9]:
sample_text = df["TEXT"].iloc[0]
embedding = get_embedding(sample_text)
print(embedding.shape)  # Should be (768,)

(768,)


In [10]:
def get_embedding_chunked(text, max_length=512, stride=256):
    tokens = tokenizer(text, return_tensors="pt", return_overflowing_tokens=True,
                       truncation=True, max_length=max_length, stride=stride)
    
    all_embeddings = []
    for i in range(len(tokens["input_ids"])):
        input_batch = {k: v[i:i+1].to(device) for k, v in tokens.items() if isinstance(v, torch.Tensor)}
        with torch.no_grad():
            output = model(**input_batch)
        cls_emb = output.last_hidden_state[:, 0, :]  # CLS token
        all_embeddings.append(cls_emb.cpu())
    
    return torch.mean(torch.cat(all_embeddings, dim=0), dim=0).numpy()