# Notebook for Generating CXR Embeddings from ELIXR Models

This notebook contains the deidentified code we used to generate embeddings from chest radiographs using the [ELIXR](https://arxiv.org/pdf/2308.01317) [1] models.

Because the unique identifiers (UIDs) and the way they map to patients and DICOM files are institution-dependent, and because sharing them would risk identifying subjects from our own clinical center, we provide only the code necessary to compute and save the embeddings.

Details regarding the models and weights can be found in the [HuggingFace model card](https://huggingface.co/google/cxr-foundation) and the original authors' [GitHub repository](https://github.com/Google-Health/cxr-foundation).

[1] Xu, Shawn, et al. "Elixr: Towards a general purpose x-ray artificial intelligence system through alignment of large language models and radiology vision encoders." arXiv preprint arXiv:2308.01317 (2023).

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pydicom import dcmread

from huggingface_hub.utils import HfFolder
from huggingface_hub import notebook_login
import io
import png
import pickle
import tensorflow as tf
import tensorflow_text as tf_text

from skimage.io import imread
from skimage.transform import resize
from skimage.exposure import rescale_intensity, equalize_adapthist

import warnings
warnings.simplefilter(action = 'ignore', category = FutureWarning)
warnings.simplefilter(action = 'ignore', category = Warning)

In [None]:
# Code from ELIXR authors
# https://github.com/Google-Health/cxr-foundation/blob/master/notebooks/quick_start_with_hugging_face.ipynb
# Helper function for processing image data
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a 2-D NumPy image as a PNG wrapped in a ``tf.train.Example``.

    The pixel values are first shifted so that the minimum is zero. ``uint8``
    input is written as an 8-bit PNG without rescaling; every other dtype is
    stretched to the full 16-bit range and written as a 16-bit PNG.

    Args:
        image_array: Grayscale image; must be 2-D.

    Returns:
        A ``tf.train.Example`` holding the PNG bytes under ``image/encoded``
        and ``b'png'`` under ``image/format``.

    Raises:
        ValueError: If the array is not 2-D.
    """
    # Work in float32 so the shift (and optional stretch) cannot overflow.
    shifted = image_array.astype(np.float32)
    shifted -= shifted.min()

    is_8bit = image_array.dtype == np.uint8
    if not is_8bit:
        # Stretch to the 16-bit range; an all-constant image stays at zero.
        peak = shifted.max()
        if peak > 0:
            shifted *= 65535 / peak
    bitdepth = 8 if is_8bit else 16
    pixel_array = shifted.astype(np.uint8 if is_8bit else np.uint16)

    # Only single-channel images can be written by the greyscale PNG writer.
    if pixel_array.ndim != 2:
        raise ValueError(f'Array must be 2-D. Actual dimensions: {pixel_array.ndim}')

    # Serialise the pixels to PNG bytes in memory.
    n_rows, n_cols = pixel_array.shape
    buffer = io.BytesIO()
    writer = png.Writer(
        width=n_cols,
        height=n_rows,
        greyscale=True,
        bitdepth=bitdepth,
    )
    writer.write(buffer, pixel_array.tolist())

    # Pack the PNG and its format tag into the Example's feature map.
    example = tf.train.Example()
    feature_map = example.features.feature
    feature_map['image/encoded'].bytes_list.value.append(buffer.getvalue())
    feature_map['image/format'].bytes_list.value.append(b'png')

    return example

# Code adapted from ELIXR authors
# https://github.com/Google-Health/cxr-foundation/blob/master/notebooks/quick_start_with_hugging_face.ipynb
def get_elixr_models(hf_local_dir) :
    """Download the ELIXR model files and load both SavedModels.

    Files matching ``elixr-c-v2-pooled/*`` and ``pax-elixr-b-text/*`` are
    fetched from the ``google/cxr-foundation`` repository into
    ``hf_local_dir`` and then loaded with ``tf.saved_model.load``.

    Args:
        hf_local_dir: Directory the model repository files are downloaded to
            and loaded from.

    Returns:
        Tuple ``(elixrc_model, qformer_model)`` of the loaded SavedModels.
    """
    # Download the model repository files
    from huggingface_hub import snapshot_download
    snapshot_download(
        repo_id="google/cxr-foundation",
        local_dir=hf_local_dir,
        allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
    )

    # The upstream notebook guarded these loads with "'name' not in locals()"
    # to avoid reloading at cell level. Inside a function the local scope is
    # always fresh, so the guard was always true; load unconditionally.
    elixrc_model = tf.saved_model.load(os.path.join(hf_local_dir, 'elixr-c-v2-pooled'))
    qformer_model = tf.saved_model.load(os.path.join(hf_local_dir, 'pax-elixr-b-text'))

    return elixrc_model, qformer_model

# Code adapted from ELIXR authors
# https://github.com/Google-Health/cxr-foundation/blob/master/notebooks/quick_start_with_hugging_face.ipynb
def embed_cxr(img, elixrc_model, qformer_model) :
    """Compute the ELIXR-C and ELIXR-B embeddings of one image.

    Args:
        img: Image data convertible with ``np.array``; it is PNG-encoded via
            ``png_to_tfexample`` before inference.
        elixrc_model: Loaded ELIXR-C model exposing a ``serving_default``
            signature.
        qformer_model: Loaded QFormer model exposing a ``serving_default``
            signature.

    Returns:
        Tuple ``(elixrc_embedding, elixrb_embeddings)`` of NumPy arrays taken
        from the ``feature_maps_0`` output of ELIXR-C and the
        ``all_contrastive_img_emb`` output of the QFormer, respectively.
    """
    # Step 1 - Run ELIXR-C on the serialized tf.train.Example
    tf_example = png_to_tfexample(np.array(img))
    serialized = tf_example.SerializeToString()

    run_elixrc = elixrc_model.signatures['serving_default']
    elixrc_result = run_elixrc(input_example=tf.constant([serialized]))
    elixrc_embedding = elixrc_result['feature_maps_0'].numpy()

    # Step 2 - Invoke QFormer with Elixr-C embeddings.
    # Only the image branch is of interest, so the text inputs are all zeros.
    text_shape = (1, 1, 128)
    qformer_kwargs = dict(
        image_feature=elixrc_embedding.tolist(),
        ids=np.zeros(text_shape, dtype=np.int32).tolist(),
        paddings=np.zeros(text_shape, dtype=np.float32).tolist(),
    )

    run_qformer = qformer_model.signatures['serving_default']
    qformer_result = run_qformer(**qformer_kwargs)
    elixrb_embeddings = qformer_result['all_contrastive_img_emb'].numpy()

    return elixrc_embedding, elixrb_embeddings

def center_crop(img, out_dims = (224, 224)):
    """Return the central ``out_dims`` window of a 2-D image.

    Args:
        img: 2-D array indexed as ``[row, column]``.
        out_dims: ``(height, width)`` of the crop.

    Returns:
        The slice of ``img`` of shape ``out_dims`` centred on the image.

    Raises:
        AssertionError: If the resulting crop does not have the requested
            shape, e.g. because ``out_dims`` exceeds the image size.
    """
    height, width = img.shape
    crop_h, crop_w = out_dims

    y_start = height // 2 - crop_h // 2
    x_start = width // 2 - crop_w // 2

    # Rows take the crop height and columns the crop width. The extents were
    # previously swapped, which broke every non-square ``out_dims``.
    cropped = img[y_start:y_start + crop_h, x_start:x_start + crop_w]

    assert cropped.shape == tuple(out_dims), (
        f'Crop has shape {cropped.shape}, expected {tuple(out_dims)}'
    )

    return cropped

def center_crop_to_short_edge(img) :
    """Center-crop an image to a square whose side is its shortest dimension.

    Args:
        img: Array whose first two dimensions are compared.

    Returns:
        ``img`` itself when its first two dimensions are equal, otherwise the
        result of ``center_crop`` with both sides set to ``min(img.shape)``.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]

    if n_rows != n_cols:
        side = min(img.shape)
        return center_crop(img, out_dims=(side, side))

    # Already square: hand back the input untouched.
    return img

def get_pixels_from_dcm(src, crop_to_short_edge = True, clahe_clip_limit = None, contrast_stretching = None, resize_dim = 1024) :
    """Read a DICOM file and return its preprocessed pixel array.

    The steps run in a fixed order: optional square crop, optional CLAHE,
    optional percentile contrast stretching, optional resize, and finally
    min-max normalisation when CLAHE was not applied.

    Args:
        src: Path (or file-like object) passed to ``pydicom.dcmread``.
        crop_to_short_edge: If truthy, center-crop to a square first.
        clahe_clip_limit: Clip limit for ``equalize_adapthist``; a falsy
            value (``None`` or ``0``) skips CLAHE.
        contrast_stretching: ``(low, high)`` percentiles as a tuple; any
            non-tuple value skips contrast stretching.
        resize_dim: Side length of the square output; a falsy value skips
            resizing.

    Returns:
        The processed image array. Without CLAHE it is min-max normalised to
        ``[0, 1]``; with CLAHE it is left as produced by the earlier steps.
    """
    img = dcmread(src).pixel_array

    if crop_to_short_edge:
        img = center_crop_to_short_edge(img)
    if clahe_clip_limit:
        img = equalize_adapthist(img, clip_limit=clahe_clip_limit)
    if isinstance(contrast_stretching, tuple):
        plow, phigh = np.percentile(img, (contrast_stretching[0], contrast_stretching[1]))
        img = rescale_intensity(img, in_range=(plow, phigh))
    if resize_dim:
        img = resize(img, (resize_dim, resize_dim))
    if not clahe_clip_limit:
        lowest = img.min()
        span = img.max() - lowest
        # A constant image has zero span. Divide by 1 so it maps to all
        # zeros instead of NaN/inf from a 0/0 division.
        img = (img - lowest) / (span if span > 0 else 1)

    return img

In [None]:
# TO-DO: Define paths for data loading and saving
COMBINED_TABLE_PATH = None
COMBINED_DATA_FOLDER_PATH = None
HF_LOCAL_DIR = None
EMBEDDINGS_FOLDER_PATH = None

# TO-DO: get list or array of file paths
path_list = None
# Sanity check that the placeholder above was replaced with real file paths.
assert os.path.isfile(path_list[0])

# Output layout: <EMBEDDINGS_FOLDER_PATH>/cxr_embeddings/{elixrc,elixrb}_embeddings
embeddings_by_setting_path = os.path.join(EMBEDDINGS_FOLDER_PATH, 'cxr_embeddings')
elixrc_cxr_embeddings_parent = os.path.join(embeddings_by_setting_path, 'elixrc_embeddings')
elixrb_cxr_embeddings_parent = os.path.join(embeddings_by_setting_path, 'elixrb_embeddings')

for path in [embeddings_by_setting_path, elixrb_cxr_embeddings_parent, elixrc_cxr_embeddings_parent] :
    if not os.path.isdir(path) :
        # makedirs also creates missing parents, so the cell still works when
        # EMBEDDINGS_FOLDER_PATH itself does not exist yet.
        os.makedirs(path, exist_ok=True)
        print(f'{path} created.')

In [None]:
# Fetch the model files into HF_LOCAL_DIR and load both models once, so the
# same model objects are reused for every image in the embedding loop.
elixrc_model, qformer_model = get_elixr_models(hf_local_dir = HF_LOCAL_DIR)

In [None]:
# Kept for compatibility with the original notebook; the loop below writes
# each embedding straight to disk and never appends to these lists.
elixrc_cxr_embeddings = []
elixrb_cxr_embeddings = []

n_paths = len(path_list)

for i, src in enumerate(path_list) :

    # Name the pickle after the DICOM file itself. basename() respects the
    # platform's path separator, and only a trailing '.dcm' is stripped, so a
    # '.dcm' inside a directory name no longer truncates the name.
    base_name = os.path.basename(src)
    stem = base_name[:-len('.dcm')] if base_name.endswith('.dcm') else base_name
    filename = f"{stem}.pkl"
    elixrc_cxr_embedding_path = os.path.join(elixrc_cxr_embeddings_parent, filename)
    elixrb_cxr_embedding_path = os.path.join(elixrb_cxr_embeddings_parent, filename)

    # Resume support: skip images whose two embedding files already exist.
    already_done = (
        os.path.isfile(elixrc_cxr_embedding_path)
        and os.path.isfile(elixrb_cxr_embedding_path)
    )

    if not already_done :

        # NOTE(review): resize_dim = 1284 is kept exactly as in the original
        # run; confirm it matches the input size the model expects.
        img = get_pixels_from_dcm(src, crop_to_short_edge = True, clahe_clip_limit = 0.2, resize_dim = 1284)
        current_elixrc_embedding, current_elixrb_embedding = embed_cxr(img, elixrc_model, qformer_model)

        with open(elixrc_cxr_embedding_path, 'wb') as handle:
            pickle.dump(current_elixrc_embedding, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open(elixrb_cxr_embedding_path, 'wb') as handle:
            pickle.dump(current_elixrb_embedding, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Report progress every 250 files and once more on the final file.
    if (i % 250 == 0) or (i == n_paths - 1):
        print(f'{i:>6} / {n_paths} complete.')