# Create Embeddings of Summaries with Different Filter Criteria

In [1]:
import numpy as np
from sklearn.manifold import TSNE
from transformers import BertTokenizer, BertForTokenClassification
from datasets import load_dataset
import plotly.express as px
import re
from collections import Counter
import torch
import pickle
from rouge_score import rouge_scorer

In [2]:
# Load device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [3]:
# Incorporate test predictions so use test sets
mimic4_unfiltered_path = '/home/s_hegs02/mimic-iv-note-di/dataset/embeddings/embeddings_all_unprocessed_services_test.json'
mimic4_filtered_path = '/home/s_hegs02/mimic-iv-note-di/dataset/embeddings/embeddings_all_services_test.json'

mimic4_unfiltered = load_dataset('json', data_files=mimic4_unfiltered_path)['train']
mimic4_filtered = load_dataset('json', data_files=mimic4_filtered_path)['train']

print(f"Loaded {len(mimic4_unfiltered)} MIMIC-IV")
print(f"Loaded {len(mimic4_filtered)} MIMIC-IV (preprocessed)")

# Only select 10k examples
mimic4_unfiltered = mimic4_unfiltered.select(range(10000))
mimic4_filtered = mimic4_filtered.select(range(10000))

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loaded 10000 MIMIC-IV
Loaded 10000 MIMIC-IV (preprocessed)


In [4]:
# Load HF models
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased').to(device)

Downloading tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Get token distributions
def get_token_distributions(dataset):
    num_tokens = [len(tokenizer.tokenize(ex)) for ex in dataset['summary']]
    return num_tokens
    
# Print ratio of summaries longer than 512 tokens - cannot use the full summary for those
print(f"MIMIC-IV: {sum([1 for n in get_token_distributions(mimic4_unfiltered) if n > 512]) / len(mimic4_unfiltered)}")
print(f"MIMIC-IV (preprocessed): {sum([1 for n in get_token_distributions(mimic4_filtered) if n > 512]) / len(mimic4_filtered)}")

MIMIC-IV: 0.1119
MIMIC-IV (preprocessed): 0.0003


In [6]:
# Get cls embedding tokens for all summaries
def get_cls_embedding(batch):
    inputs = tokenizer(batch['summary'], padding=True, truncation=True, return_tensors='pt').to(device)
    outputs = model(**inputs, output_hidden_states=True)
    last_hidden_states = outputs.hidden_states[-1]
    cls_embedding = last_hidden_states[:,0,:]
    return {'cls_embedding': cls_embedding}

In [7]:
mimic4_unfiltered = mimic4_unfiltered.map(get_cls_embedding, batched=True, batch_size=32)  # Adapt batch_size to your GPU memory
mimic4_filtered = mimic4_filtered.map(get_cls_embedding, batched=True, batch_size=32)  # Adapt batch_size to your GPU memory

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [8]:
# Add tsne embeddings to dataset
def create_tsne_embeddings(dataset):
    cls_embeddings = np.array([ex for ex in dataset['cls_embedding']])
    embeddings = TSNE(n_components=2, random_state=1).fit_transform(cls_embeddings)
    return embeddings

In [9]:
mimic4_unfiltered_tsne = create_tsne_embeddings(mimic4_unfiltered)
mimic4_filtered_tsne = create_tsne_embeddings(mimic4_filtered)

In [10]:
# Normalize tsne embeddings
def normalize_tsne_embeddings(embeddings):
    return (embeddings - embeddings.min()) / (embeddings.max() - embeddings.min())

mimic4_unfiltered_tsne = normalize_tsne_embeddings(mimic4_unfiltered_tsne)
mimic4_filtered_tsne = normalize_tsne_embeddings(mimic4_filtered_tsne)

# Convert to lists to change ordering easily
mimic4_unfiltered_tsne = mimic4_unfiltered_tsne.tolist()
mimic4_filtered_tsne = mimic4_filtered_tsne.tolist()

In [11]:
# Extract sumamries from datasets
def extract_summaries(dataset):
    return [ex for ex in dataset['summary']]

mimic4_unfiltered_summaries = extract_summaries(mimic4_unfiltered)
mimic4_filtered_summaries = extract_summaries(mimic4_filtered)

In [84]:
# Add label classes for scatter plot to color dots
# Different options:
# 1. No labels (dummies)
# 2. Medical services (given in MIMIC notes)
# 3. Prediction results

In [85]:
# 1. No labels (dummies)
mimic4_unfiltered_classes = [1 for x in mimic4_unfiltered]
mimic4_filtered_classes = [1 for x in mimic4_filtered]

In [86]:
# 2. Medical services (given in MIMIC notes)
# We assume the medical services where (wrongly) stored in the text field of the summary
# Counts: ('MEDICINE', 195679), ('SURGERY', 46529), ('ORTHOPAEDICS', 18302), ('NEUROLOGY', 18056), ('CARDIOTHORACIC', 13786), ('NEUROSURGERY', 10638), ('OBSTETRICS/GYNECOLOGY', 9522), ('PSYCHIATRY', 7560), ('UROLOGY', 4947), ('PLASTIC', 3363), ('PODIATRY', 1315), ('OTOLARYNGOLOGY', 851), ('UNKNOWN', 653), ('OME', 362), ('EMERGENCY', 136), ('ANESTHESIOLOGY', 43), ('BIOLOGIC', 24), ('DENTAL', 14), ('RADIATION', 5), ('OPHTHALMOLOGY', 4), ('RADIOLOGY', 4)]
# Map medical services to 10 most common classes
map_services = {'MEDICINE': 'medicine', 'SURGERY': 'surgery', 'ORTHOPAEDICS': 'orthopaedics', 'NEUROLOGY': 'neurology', 'CARDIOTHORACIC': 'cardiothoracic', 'NEUROSURGERY': 'neurosurgery', 'OBSTETRICS/GYNECOLOGY': 'obs/gyn', 'PSYCHIATRY': 'psychiatry', 'UROLOGY': 'urology', 'PLASTIC': 'other', 'PODIATRY': 'other', 'OTOLARYNGOLOGY': 'other', 'UNKNOWN': 'other', 'OME': 'other', 'EMERGENCY': 'other', 'ANESTHESIOLOGY': 'other', 'BIOLOGIC': 'other', 'DENTAL': 'other', 'RADIATION': 'other', 'OPHTHALMOLOGY': 'other', 'RADIOLOGY': 'other'}
services = ['medicine', 'surgery', 'orthopaedics', 'neurology', 'cardiothoracic', 'neurosurgery', 'obs/gyn', 'psychiatry', 'urology', 'other']

mimic4_unfiltered_classes = [map_services[ex['text']] for ex in mimic4_unfiltered]
mimic4_filtered_classes = [map_services[ex['text']] for ex in mimic4_filtered]

# Move one value for each class in services at first positions of mimic4_unfiltered_classes and mimic4_filtered_classes
# This is done to ensure that the legend of the plot shows the correct colors for each class
for service in reversed(services):
    for i, c in enumerate(mimic4_unfiltered_classes):
        if c == service:
            mimic4_unfiltered_classes.insert(0, mimic4_unfiltered_classes.pop(i))
            mimic4_unfiltered_summaries.insert(0, mimic4_unfiltered_summaries.pop(i))
            mimic4_unfiltered_tsne.insert(0, mimic4_unfiltered_tsne.pop(i))
            break
    for i, c in enumerate(mimic4_filtered_classes):
        if c == service:
            mimic4_filtered_classes.insert(0, mimic4_filtered_classes.pop(i))
            mimic4_filtered_summaries.insert(0, mimic4_filtered_summaries.pop(i))
            mimic4_filtered_tsne.insert(0, mimic4_filtered_tsne.pop(i))
            break

In [12]:
# 3. Prediction results

# Prediction metrics for LED-large model
# * Trained on 100000 steps
# * max_source_length 16384 and max_target_length 512 

# Use custom rouge function to obtain rouge 3/4 which are not available in huggingface
def get_rouge_score(gold, pred):
    rouge_scores = ['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL']
    scorer = rouge_scorer.RougeScorer(rouge_scores, use_stemmer=True)
    scores = scorer.score(gold, pred)
    return {k: scores[k].fmeasure * 100 for k in rouge_scores}

# Metrics obtained with eval_summarization.py on test predictions
# TODO This is the filtered one
mimic4_unfiltered_pred_path = '/home/s_hegs02/mimic-iv-note-di/models/led-large-16384/mimic-iv-note-di-embeddings/processed-200k-steps/test_generations.pkl'
mimic4_filtered_pred_path = '/home/s_hegs02/mimic-iv-note-di/models/led-large-16384/mimic-iv-note-di-embeddings/unprocessed-200k-steps/test_generations.pkl'
mimic4_unfiltered_pred = pickle.load(open(mimic4_unfiltered_pred_path, 'rb'))
mimic4_filtered_pred = pickle.load(open(mimic4_filtered_pred_path, 'rb'))

mimic4_unfiltered_metrics = [get_rouge_score(gold, pred) for gold, pred in zip(mimic4_unfiltered_summaries, mimic4_unfiltered_pred)]
mimic4_filtered_metrics = [get_rouge_score(gold, pred) for gold, pred in zip(mimic4_filtered_summaries, mimic4_filtered_pred)]

# mimic4_unfiltered_metrics_path = '/home/s/s_hegs02/scratch/mimic-iv-avs_reproduced/models/output-train-led-base/test_generations_metrics.pkl'
# mimic4_filtered_metrics_path = '/home/s/s_hegs02/scratch/mimic-iv-avs/models/output-train-led-base/test_generations_metrics.pkl'
# mimic4_unfiltered_metrics = pd.read_pickle(mimic4_unfiltered_metrics_path)
# mimic4_filtered_metrics = pd.read_pickle(mimic4_filtered_metrics_path)
# 
metric = 'rouge1'
mimic4_unfiltered_classes = [int(m[metric]) for m in mimic4_unfiltered_metrics]
mimic4_filtered_classes = [int(m[metric]) for m in mimic4_filtered_metrics]

In [13]:
# Print counts of all avlues in mimic4_unfiltered_classes and mimic4_filtered_classes
print(Counter(mimic4_unfiltered_classes))
print(Counter(mimic4_filtered_classes))

Counter({28: 458, 29: 448, 30: 430, 26: 429, 27: 427, 25: 425, 24: 408, 22: 380, 32: 380, 23: 378, 20: 348, 31: 339, 18: 338, 21: 338, 19: 328, 16: 299, 17: 297, 33: 297, 15: 287, 14: 281, 34: 270, 0: 249, 13: 243, 35: 228, 12: 224, 11: 190, 36: 168, 10: 149, 37: 137, 9: 119, 38: 104, 2: 82, 8: 70, 40: 66, 39: 58, 7: 55, 6: 44, 5: 41, 41: 38, 4: 38, 42: 28, 3: 23, 43: 18, 1: 14, 44: 12, 45: 8, 48: 3, 47: 3, 46: 2, 50: 1})
Counter({27: 474, 25: 466, 29: 455, 26: 453, 28: 423, 20: 422, 22: 414, 23: 413, 30: 408, 31: 404, 24: 389, 21: 378, 18: 375, 32: 372, 19: 366, 33: 336, 17: 334, 34: 325, 16: 315, 0: 278, 15: 254, 35: 230, 14: 216, 36: 194, 37: 179, 13: 170, 12: 147, 38: 117, 39: 81, 11: 77, 10: 72, 40: 69, 2: 57, 9: 56, 1: 41, 41: 37, 42: 36, 8: 31, 43: 24, 7: 20, 4: 16, 44: 15, 5: 15, 3: 14, 6: 14, 45: 8, 46: 4, 47: 3, 49: 1, 50: 1, 48: 1})


In [15]:
# Create scatter plot
def create_labeled_hover_sp(summaries, embeddings, classes):
    hover_labels = [re.sub("(.{64})", "\\1<br>", ex, 0, re.DOTALL) for ex in summaries]
    
    # 2. Medical services: Discrete color scale
    # fig = px.scatter(x=np.array(embeddings)[:,0], y=np.array(embeddings)[:,1], hover_name=hover_labels, color=classes, width=900, height=600)
    # fig.update_layout(legend_title_text='Medical services')
    
    # 3. Prediction results: Continuous color scale
    fig = px.scatter(x=np.array(embeddings)[:,0], y=np.array(embeddings)[:,1], hover_name=hover_labels, color=classes, width=900, height=600, color_continuous_scale='viridis')
    # Name continuous color scale
    fig.update_layout(coloraxis_colorbar=dict(title='ROUGE-1'))
    fig.update_layout(coloraxis=dict(cmin=0, cmax=50))
    
    # Set all margins to zero
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    
    # Change legend and axis font size to 20
    fig.update_layout(legend=dict(font=dict(size=20)), font=dict(size=20))
    
    # Axis ticks at 0, 0.2, 0.4, 0.6, 0.8, 1
    fig.update_xaxes(tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1])
    fig.update_yaxes(tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1])
    return fig

# create_labeled_hover_sp(mimic4_unfiltered_summaries, mimic4_unfiltered_tsne, mimic4_unfiltered_classes).show() 
# create_labeled_hover_sp(mimic4_filtered_summaries, mimic4_filtered_tsne, mimic4_filtered_classes).show()

# Store as pdf
# create_labeled_hover_sp(mimic4_unfiltered_summaries, mimic4_unfiltered_tsne, mimic4_unfiltered_classes).write_image("/home/s_hegs02/patient_summaries_with_llms/mimic4_emb_unfiltered_services.pdf")
# create_labeled_hover_sp(mimic4_filtered_summaries, mimic4_filtered_tsne, mimic4_filtered_classes).write_image("/home/s_hegs02/patient_summaries_with_llms/mimic4_emb_filtered_services.pdf")

# create_labeled_hover_sp(mimic4_unfiltered_summaries, mimic4_unfiltered_tsne, mimic4_unfiltered_classes).write_image("/home/s_hegs02/patient_summaries_with_llms/mimic4_emb_unfiltered_embeddings.pdf")
# create_labeled_hover_sp(mimic4_filtered_summaries, mimic4_filtered_tsne, mimic4_filtered_classes).write_image("/home/s_hegs02/patient_summaries_with_llms/mimic4_emb_filtered_embeddings.pdf")