In [49]:
import torch
from transformers import LongformerTokenizer, LongformerModel
import pandas as pd
import numpy as np
import faiss

# Load Clinical-Longformer model and tokenizer
longformer_model_name = "yikuan8/Clinical-Longformer"
tokenizer = LongformerTokenizer.from_pretrained(longformer_model_name)
embedding_model = LongformerModel.from_pretrained(longformer_model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model.to(device)

Some weights of LongformerModel were not initialized from the model checkpoint at yikuan8/Clinical-Longformer and are newly initialized: ['longformer.pooler.dense.bias', 'longformer.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
def generate_visit_embeddings(texts, batch_size=8):
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(device)
        with torch.no_grad():
            outputs = embedding_model(**inputs)
        batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        all_embeddings.append(batch_embeddings)
    return np.vstack(all_embeddings)

In [4]:
data_df = pd.read_csv("data/discharge_journal_df.csv")

In [5]:
all_texts = data_df['text'].tolist()
visit_embeddings = generate_visit_embeddings(all_texts)

Input ids are automatically padded to be a multiple of `config.attention_window`: 512


In [6]:
patient_embeddings = {}
for subject_id, group in data_df.groupby('subject_id'):
    indices = group.index.tolist()
    patient_embedding = np.mean(visit_embeddings[indices], axis=0)
    patient_embeddings[subject_id] = patient_embedding

In [7]:
# Convert patient embeddings dictionary to array for indexing
patient_ids = list(patient_embeddings.keys())
embeddings_array = np.array(list(patient_embeddings.values()))

# Create FAISS index
index = faiss.IndexFlatL2(embeddings_array.shape[1])
index.add(embeddings_array)

In [8]:
df = pd.read_csv("data/training_text.csv")

# Combine the specified columns into a single text block for each row
df['masked_text'] = df.apply(lambda row: '\n'.join([
    f"Admission Type:\n{row['Admission Type']}",
    f"Chief Complaint:\n{row['Chief Complaint']}",
    f"Major Surgical or Invasive Procedure:\n{row['Major Surgical or Invasive Procedure']}",
    f"History of Present Illness:\n{row['History of Present Illness']}",
    f"Past Medical History:\n{row['Past Medical History']}",
    f"Social History:\n{row['Social History']}",
    f"Physical Exam:\n{row['Physical Exam']}",
    f"Brief Hospital Course:\n{row['Brief Hospital Course']}"
]), axis=1)

df = df[['masked_text']]
# Save the new DataFrame to a CSV file
output_csv = 'data/output_with_masked_text.csv'
df.to_csv(output_csv, index=False)

In [64]:
# Example new visit data and generating its embedding
new_visit_data = df['masked_text'][4]
new_visit_embedding = generate_visit_embeddings([new_visit_data])

# Find similar patients
distances, indices = index.search(new_visit_embedding, k=2)
print("Distances:", distances)
print("Indices:", indices)

# Retrieve patient data based on indices
retrieved_patient_data_df = data_df[data_df['subject_id'].isin([patient_ids[i] for i in indices[0]])]
print("Retrieved Patient Data:\n", retrieved_patient_data_df)

Distances: [[0.36186546 0.3762812 ]]
Indices: [[6696 3696]]
Retrieved Patient Data:
               note_id  subject_id   hadm_id   charttime  \
8089    10260254-DS-2    10260254  28861371  2147-12-15   
14870  10472450-DS-20    10472450  24626805  2140-10-06   
14871  10472450-DS-21    10472450  24553761  2141-04-12   
14872  10472450-DS-22    10472450  23142912  2141-04-26   

                                                    text  
8089   Subject ID: 10260254, HAdm ID: 28861371, Chart...  
14870  Subject ID: 10472450, HAdm ID: 24626805, Chart...  
14871  Subject ID: 10472450, HAdm ID: 24553761, Chart...  
14872  Subject ID: 10472450, HAdm ID: 23142912, Chart...  


In [51]:
from transformers import AutoTokenizer, LongT5ForConditionalGeneration
# Load LongT5Model for summarization
longt5_model_name = "google/long-t5-tglobal-base"
longt5_tokenizer = AutoTokenizer.from_pretrained(longt5_model_name)
longt5_model = LongT5ForConditionalGeneration.from_pretrained(longt5_model_name)



In [56]:
torch.cuda.empty_cache()

In [65]:
longt5_model.to(device)
def summarize_text(text, max_length=512):
    inputs = longt5_tokenizer.encode(text, return_tensors="pt", max_length=8192, truncation=True).to(device)
    summary_ids = longt5_model.generate(inputs, max_length=max_length, min_length=40, length_penalty=2.0, num_beams=2)
    summary = longt5_tokenizer.decode(summary_ids[0].cpu().detach(), skip_special_tokens=True)
    return summary

def summarize_retrieved_data(retrieved_data):
    summaries = []
    for text in retrieved_data['text']:
        summary = summarize_text(text)
        summaries.append(summary)
    return " ".join(summaries)

# Summarize retrieved patient data
summarized_text = summarize_retrieved_data(retrieved_patient_data_df)
print("Summarized Text:\n", summarized_text)



Summarized Text:


In [66]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

phi_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct",trust_remote_code=True)
phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

messages = [
    {"role": "user", "content": "Recommend: "+summarized_text}
]
pipe = pipeline(
    "text-generation",
    model=phi_model,
    tokenizer=phi_tokenizer,
)

generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "temperature": 0.0,
    "do_sample": False
}

output = pipe(messages, **generation_args)
print(output[0]['generated_text'])

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


 Chief Complaint: Shortness of breath

Major Surgical or Invasive Procedure: None

History of Present Illness: A 55-year-old male with a past medical history of nonischemic cardiomyopathy (ejection fraction 10%), bioprosthetic mitral valve replacement, tachy-brady syndrome status post DDD pacemaker on the left side, and a history of an implantable cardioverter-defibrillator (ICD) shock delivered 2 months ago. The patient has a history of alcoholic cirrhosis, end-stage renal disease on dialysis, and presents with shortness of breath, orthopnea, and weight gain.

EKG: Atrial fibrillation with left bundle branch block (LBBB)

Labs/Studies:

- Troponin T: 0.10 ng/mL (normal range: 0-0.04 ng/mL)

- Prothrombin time (PTT): 38.5 seconds (normal range: 25-35 seconds)

- International normalized ratio (INR): 2.0 (normal range: 0.8-1.2)

- Lactate: 1.6 mmol/L (normal range: 0.5-2.2 mmol/L)

- Alanine aminotransferase (ALT): 46 U/L (normal range: 7-56 U/L)

- Aspartate aminotransferase (AST): 41 