<img src="figures/hiti.png" alt="HITILab" width="150"/>

<!-- Author: Theo Dapamede, MD, PhD -->
<!-- Github: theodapamede -->

# CXR: Data Preprocessing

By going through this lecture and notebook, you should be able to:

1. Understand the basics of working with DICOM files
2. Open and display a DICOM image using Python
3. Perform standard DICOM image preprocessing techniques
4. Understand 2 different normalization techniques

# 0. Load Libraries and Prepare Environment

In [None]:
import os
import glob
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tqdm

# 1. Basics of working with DICOM Files

## 1.1. Opening a DICOM File

We will be using the Pydicom library. 

For more information regarding Pydicom:
[Pydicom Userguide](https://pydicom.github.io/pydicom/stable/old/pydicom_user_guide.html)

In [None]:
import pydicom

In [None]:
# Sample a dicom file for example purposes
dicom_file = "/fsx/embed/summer-school-24/Theo_session/dicoms/cxr_sample/1.2.826.0.1.3680043.8.498.12497993392562429345008417030462807206.dcm"

In [None]:
loaded_dicom = pydicom.dcmread(dicom_file)

In [None]:
# DICOM Tags --> in parentheses (xxxx, xxxx)
# Followed by DICOM Tag name; Value Representation; Value
loaded_dicom

In [None]:
# Get information using Tag --> treat as hexadecimal numbers by using 0x in front
print(loaded_dicom[0x0010, 0x0020])
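
A minimal sketch of why the `0x` prefix matters: a DICOM tag is a (group, element) pair of 16-bit numbers that is conventionally written in hexadecimal. The helper name `format_tag` is hypothetical and only serves this illustration; it is not part of pydicom.

```python
# A DICOM tag is a (group, element) pair, each a 16-bit number written in hex.
group, element = 0x0010, 0x0020  # Patient ID

# Both halves can be packed into one 32-bit integer: group in the high 16 bits.
packed = (group << 16) | element


def format_tag(value):
    """Format a packed 32-bit tag as '(gggg, eeee)' (hypothetical helper)."""
    high = value >> 16
    low = value & 0xFFFF
    return f"({high:04x}, {low:04x})"


print(hex(packed))         # 0x100020
print(format_tag(packed))  # (0010, 0020)

# Without the 0x prefix, Python would read 0010 and 0020 as something else
# entirely; with it, 0x7FE0 is the decimal number 32736.
print(0x7FE0)              # 32736
```

Writing the numbers without `0x` would address a different tag (or raise a syntax error for leading zeros), which is why the hexadecimal literal form is used throughout.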

In [None]:
# Getting Tag Keyword using .keyword
print(loaded_dicom[0x0010, 0x0020].keyword)

In [None]:
# Getting information using Tag Keyword
print(loaded_dicom["PatientID"])

In [None]:
# Getting Tag using .tag
print(loaded_dicom["PatientID"].tag)

In [None]:
# Getting Tag Name using .name
print(loaded_dicom["PatientID"].name)

In [None]:
# Getting Tag Value
print(loaded_dicom["PatientID"].value)

In [None]:
# Another way to get value
print(loaded_dicom.PatientID)

### ☢️ ***Exercise 1***

Print the following values:

1. **Exposure**

In [None]:
# Code your solution here

2. **Manufacturer**

In [None]:
# Code your solution here

3. **SOP Instance UID**

In [None]:
# Code your solution here

### ☢️ ***Exercise 2***
Find one example of a DICOM tag whose name contains non-alphanumeric characters (", ', -, ...) and get its value using its keyword.

In [None]:
# Code your solution here

## 1.2. DICOM Image

In a DICOM file, the image is stored in the **Pixel Data [7fe0, 0010]** tag.

In [None]:
print(loaded_dicom[0x7FE0,0x0010])

In [None]:
print(loaded_dicom[0x7FE0,0x0010].keyword)

**PixelData** is often not immediately useful as data may be stored in a variety of different ways:

- The pixel values may be signed or unsigned integers, or floats
- There may be multiple image frames
- There may be multiple planes per frame (i.e. RGB) and the order of the pixels may be different
- The image data may be encoded using one of the available compression standards (1.2.840.10008.1.2.4.50 JPEG Baseline, 1.2.840.10008.1.2.5 RLE Lossless, etc). Encoded image data will also be encapsulated and each encapsulated image frame may be broken up into one or more fragments.
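
To make the first point concrete, here is a small sketch of decoding uncompressed pixel bytes by hand with NumPy. All header values and pixel values below are made up for illustration; in a real file they would come from tags such as Rows, Columns, Bits Allocated, and Pixel Representation. It only covers the simplest case (single frame, single plane, no compression).

```python
import numpy as np

# Hypothetical header values for a tiny 2 x 3 image.
rows, cols = 2, 3
bits_allocated = 16
pixel_representation = 0  # 0 = unsigned, 1 = signed (two's complement)

# Six made-up pixel values serialized as little-endian 16-bit integers,
# standing in for the raw bytes stored in the Pixel Data tag.
original = np.array([[0, 1, 2], [1000, 4095, 65535]], dtype="<u2")
raw_bytes = original.tobytes()

# Build the NumPy dtype from the header values, then reshape to (rows, cols).
kind = "u" if pixel_representation == 0 else "i"
dtype = np.dtype(f"<{kind}{bits_allocated // 8}")
decoded = np.frombuffer(raw_bytes, dtype=dtype).reshape(rows, cols)

print(len(raw_bytes))  # 12 bytes for 6 pixels
print(decoded)

# Interpreting the same bytes as signed integers changes the values.
as_signed = np.frombuffer(raw_bytes, dtype="<i2").reshape(rows, cols)
print(as_signed[1, 2])  # -1 instead of 65535
```

The last line shows why the header matters: the same bytes produce a different image when the signedness is assumed incorrectly.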

*Note: May cause slow response when running:*
```loaded_dicom.PixelData```

See image below for output example

![loaded_dicom.PixelData](figures/fig_dicom_pixel_data.png "loaded_dicom.PixelData")

Because of the complexity in interpreting the pixel data, pydicom provides an easy way to get it in a convenient form: **.pixel_array**

In [None]:
loaded_dicom.pixel_array

In [None]:
image = loaded_dicom.pixel_array

In [None]:
print(image.shape)  # (Height, Width)

In [None]:
image.min()

In [None]:
image.max()

## 1.3. View a DICOM Image

In [None]:
plt.figure(dpi=150)
plt.imshow(image, 'gray')
plt.show()

# 2. DICOM Image Preprocessing

## 2.1. Modality Specific Units

Modality-specific units are standardized measurement units used in different imaging modalities to quantify and interpret the image data.
For example, the modality-specific unit for CT is the Hounsfield Unit (HU), where water = 0 HU.

In plain radiography, the units are in **optical density** for radiographic films and **pixel values** for digital radiography.

The range of pixel values in an image is determined by the **Bit Depth** of the imaging system (see Table).
If the minimum is 0, the maximum value is calculated as $2^B - 1$, where $B$ is the bit depth. For example, a 12-bit system will have a maximum pixel value of $2^{12}-1=4095$.
This means that a 12-bit system can represent 4,096 shades of gray (values 0 through 4,095). The higher the bit depth, the higher the contrast resolution and dynamic range.


| System Bit Depth | Minimum Value | Maximum Value |
| ---------------- | ------------- | ------------- |
| 12-bit | 0 | 4,095 |
| 14-bit | 0 | 16,383 |
| 16-bit | 0 | 65,535 |
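
The formula $2^B - 1$ can be checked with a few lines of Python. The function name `pixel_range` is a hypothetical helper introduced only for this sketch, and it assumes unsigned pixel values starting at 0.

```python
def pixel_range(bit_depth):
    """Return (minimum, maximum) for unsigned pixels of the given bit depth."""
    return 0, 2 ** bit_depth - 1


for bit_depth in (12, 14, 16):
    low, high = pixel_range(bit_depth)
    # The number of gray levels is one more than the maximum value.
    print(f"{bit_depth}-bit: min={low}, max={high:,}, levels={high + 1:,}")
```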

### ☢️ ***Exercise 3***

What is the bit depth of your loaded dicom image?

In [None]:
# Code your solution here

## 2.1. Presentation State

### 2.1.1. Applying Modality Transforms
The raw Pixel Data in a DICOM file may not be in the modality units. Therefore, we first need to apply a modality transformation to standardize the units.

There are 2 ways of transforming the values:
1. Using the Rescale Intercept and Rescale Slope
We apply the Rescale Slope and Intercept using the following equation:

$$\text{Output Value} = m \cdot \text{Stored Value} + b$$

where *m* is the Rescale Slope and *b* is the Rescale Intercept

2. Using the Modality LUT
This method uses a Look Up Table which will specify what a pixel value will be transformed into.

In a given DICOM file, only one of the above methods is available.
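
The two methods can be sketched on a toy array with NumPy. The slope, intercept, and LUT contents below are assumed values chosen for illustration (an intercept of -1024 is typical for CT, but nothing here is read from a real file), and the toy LUT is deliberately built to match the linear rescale so the two results can be compared.

```python
import numpy as np

# Made-up stored values, standing in for a pixel array.
stored = np.array([[0, 1], [1024, 2048]], dtype=np.uint16)

# Method 1: Rescale Slope and Intercept (output = m * stored + b).
slope, intercept = 1.0, -1024.0  # assumed values
rescaled = slope * stored.astype(np.float64) + intercept

# Method 2: Modality LUT. Entry k of the table holds the output for the
# stored value (first_mapped + k). This toy table mirrors the linear rescale.
first_mapped = 0
lut = np.arange(4096, dtype=np.int32) - 1024

# Stored values outside the table are clamped to its first or last entry here.
indices = np.clip(stored.astype(np.int64) - first_mapped, 0, len(lut) - 1)
looked_up = lut[indices]

print(rescaled)
print(looked_up)
print(np.array_equal(rescaled, looked_up))  # True for this toy table
```

A real Modality LUT does not have to be linear, which is the reason a table is used instead of a slope and intercept.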

In [None]:
from pydicom.pixel_data_handlers import apply_modality_lut, apply_voi_lut

In [None]:
rescale_slope = loaded_dicom.RescaleSlope
print(rescale_slope)

In [None]:
rescale_intercept = loaded_dicom.RescaleIntercept
print(rescale_intercept)

### Try using manual `modality_transform` function

In [None]:
def modality_transform(img, dcm):
    return dcm.RescaleSlope * img + dcm.RescaleIntercept

In [None]:
output_image = modality_transform(image, loaded_dicom)

In [None]:
np.all(output_image==image)

### CT Scan Example

In [None]:
example_ct = glob.glob('/fsx/embed/summer-school-24/Theo_session/dicoms/ct_scan_sample/*.dcm')

In [None]:
other_shape = []
axial = {}
axial_lut = {}
axial_full = {}
for i, dicom_file in tqdm.tqdm(enumerate(example_ct)):
    dcm_ = pydicom.dcmread(dicom_file)
    # print(dcm_.InstanceNumber)
    # img_lut = apply_voi_lut(apply_modality_lut(dcm_.pixel_array, dcm_), dcm_)
    img_lut = modality_transform(dcm_.pixel_array, dcm_)
    img_full = apply_voi_lut(apply_modality_lut(dcm_.pixel_array, dcm_), dcm_)
    img_ = dcm_.pixel_array
    if img_.shape[0] == 512:
        # image_stack[dcm_.InstanceNumber] = img_
        # image_stack_lut[dcm_.InstanceNumber] = img_lut
        if dcm_.SeriesDescription == 'AX BRAIN THIN':
            axial[dcm_.InstanceNumber] = img_
            axial_lut[dcm_.InstanceNumber] = img_lut
            axial_full[dcm_.InstanceNumber] = img_full
    else:
        other_shape.append(img_)
        
axial_sorted = dict(sorted(axial.items()))
axial_stack = np.array(list(axial_sorted.values()))

axial_lut_sorted = dict(sorted(axial_lut.items()))
axial_stack_lut = np.array(list(axial_lut_sorted.values()))

axial_full_sorted = dict(sorted(axial_full.items()))
axial_stack_lut_full = np.array(list(axial_full_sorted.values()))

In [None]:
for i in range(0, axial_stack.shape[0], 100):
    
    print(f"Original Image Min Value = {axial_stack[i].min()}")
    print(f"Original Image Max Value = {axial_stack[i].max()}")

    print(f"Modality Transformed Min Value = {axial_stack_lut[i].min()}")
    print(f"Modality Transformed Max Value = {axial_stack_lut[i].max()}")

    fig, axs = plt.subplots(1, 2, dpi=150)
    axs[0].imshow(axial_stack[i], 'gray')
    axs[1].imshow(axial_stack_lut[i], 'gray')
    # axs[2].imshow(axial_stack_lut[i] - axial_stack[i], 'gray')
    axs[0].set_title("Original")
    axs[1].set_title("Modality Transformed")
    # axs[2].set_title("Modality Transformed - Original")
    # Use a separate loop variable so the outer index `i` is not shadowed
    for ax in axs:
        ax.axis('off')
    plt.show()

### 2.1.2. Applying the VOI LUT

![VOI LUT](./figures/voi_lut.png)

In [None]:
loaded_dicom.WindowCenter

In [None]:
loaded_dicom.WindowWidth
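
Window Center and Window Width describe a linear VOI transform: values inside the window are stretched over the output range, and values outside it are clipped. Below is a simplified sketch of that idea in NumPy. It is not pydicom's implementation, and it omits the small offsets that the DICOM standard's exact linear formula includes; the function name `apply_window` and the sample values are assumptions for illustration.

```python
import numpy as np


def apply_window(img, center, width):
    """Simplified linear window: map [center - width/2, center + width/2] to [0, 1]."""
    lower = center - width / 2.0
    upper = center + width / 2.0
    # Values below the window become 0, values above it become 1.
    clipped = np.clip(np.asarray(img, dtype=np.float64), lower, upper)
    return (clipped - lower) / (upper - lower)


# Made-up values in Hounsfield Units with a brain-like window (assumed numbers).
hu = np.array([-1000.0, 0.0, 40.0, 80.0, 1000.0])
windowed = apply_window(hu, center=40, width=80)
print(windowed)  # [0.  0.  0.5 1.  1. ]
```

A narrow window spends the whole display range on a small interval of input values, which increases the visible contrast inside that interval at the cost of everything outside it.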

In [None]:
loaded_dicom.VOILUTSequence[0]

In [None]:
for i in range(len(loaded_dicom.VOILUTSequence)):
    print(loaded_dicom.VOILUTSequence[i].LUTExplanation)

In [None]:
plt.plot(loaded_dicom.VOILUTSequence[0].LUTData)
plt.show()

### **Creating a general function `process_dicom_image`**

In [None]:
def process_dicom_image(dicom):
    return apply_voi_lut(apply_modality_lut(dicom.pixel_array, dicom), dicom)

In [None]:
for i in range(0, axial_stack.shape[0], 100):
    fig, axs = plt.subplots(1, 3, dpi=150)
    axs[0].imshow(axial_stack[i], 'gray')
    axs[1].imshow(axial_stack_lut[i], 'gray')
    axs[2].imshow(axial_stack_lut_full[i], 'gray')
    axs[0].set_title("Raw")
    axs[1].set_title("Modality Transform")
    axs[2].set_title("Modality LUT + VOI LUT")
    # Use a separate loop variable so the outer index `i` is not shadowed
    for ax in axs:
        ax.axis('off')
    plt.show()

### **Apply on CXR Image**

In [None]:
fig, axs = plt.subplots(1, 2, dpi=300, sharey=True)
axs[0].imshow(image, 'gray')
axs[1].imshow(process_dicom_image(loaded_dicom), 'gray')

axs[0].set_title("Original Image")
axs[1].set_title("Processed Image")

plt.show()

In [None]:
loaded_dicom.BitsAllocated

In [None]:
loaded_dicom.BitsStored

### Experimenting with Histogram Equalization

In [None]:
img = np.array((image - image.min()) / (np.ptp(image)) * 255, dtype=np.uint8)
he_img = cv2.equalizeHist(img)
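
For intuition about what histogram equalization does, here is a NumPy sketch of the usual approach: build the cumulative histogram and use it as a lookup table. It is an illustration of the technique under stated assumptions, not a reimplementation of `cv2.equalizeHist`, and its rounding may differ from OpenCV's. The function name `equalize_hist` and the test image are made up for this example.

```python
import numpy as np


def equalize_hist(img_u8):
    """Histogram-equalize an 8-bit image using its cumulative histogram."""
    hist = np.bincount(img_u8.ravel(), minlength=256)
    cdf = hist.cumsum()
    cdf_min = cdf[cdf > 0][0]  # cumulative count at the darkest occupied bin
    denom = cdf[-1] - cdf_min
    if denom == 0:
        # A constant image has nothing to equalize.
        return img_u8.copy()
    # Map each gray level to its scaled cumulative frequency.
    lut = np.round((cdf - cdf_min) / denom * 255).clip(0, 255).astype(np.uint8)
    return lut[img_u8]


# A low-contrast toy image that only uses gray levels 100 to 119.
low_contrast = np.tile(np.arange(100, 120, dtype=np.uint8), (20, 1))
equalized = equalize_hist(low_contrast)

print(low_contrast.min(), low_contrast.max())  # 100 119
print(equalized.min(), equalized.max())        # 0 255
```

Unlike the VOI LUT, which is defined by the file, this mapping depends on the content of each image, so two images of the same patient can end up with different intensity scales.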

In [None]:
fig, axs = plt.subplots(1, 2, dpi=300, sharey=True)
axs[0].imshow(he_img, 'gray')
axs[1].imshow(process_dicom_image(loaded_dicom), 'gray')

axs[0].set_title("HE Image")
axs[1].set_title("VOI LUT Image")

plt.show()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].hist(he_img.flatten())
axs[1].hist(process_dicom_image(loaded_dicom).flatten())

axs[0].set_title("HE Image")
axs[1].set_title("VOI LUT Image")

plt.show()

# 3.1. Saving to PNG

In [None]:
cxr_image = process_dicom_image(loaded_dicom)

In [None]:
# Convert pixel array to PNG as a 16-bit greyscale
image_to_save = cxr_image.astype(np.double)

# Rescale grey scale between 0-65535
image_to_save = (np.maximum(image_to_save, 0) / image_to_save.max()) * 65535.0

# Convert to uint16
image_to_save = np.uint16(image_to_save)

output_png_path = f"./output/{loaded_dicom.SOPInstanceUID}.png"

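# Create the output directory if it does not exist yet
os.makedirs("./output/", exist_ok=True)

# Use a new name so the `image` pixel array loaded earlier is not overwritten
png_image = Image.fromarray(image_to_save)
png_image.save(output_png_path)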

# 3.1. Common Normalization Techniques

1. Min-max normalization
2. Standardization

## 3.1.1 Min-Max Normalization

The most common Min-Max Normalization technique rescales the pixel values to the range 0 to 1.

The steps are as follows:

1. Calculate the minimum and maximum pixel values of the image
2. Subtract the minimum value from the image
3. Divide the result by the range of pixel values, i.e. maximum minus minimum (max - min)

In [None]:
img_max = cxr_image.max()
img_min = cxr_image.min()

In [None]:
normalized_img = (cxr_image - img_min) / (img_max - img_min)

In [None]:
normalized_img.min()

In [None]:
normalized_img.max()

In [None]:
print(f"Processed Image: (min={cxr_image.min():.2f}, max={cxr_image.max():.2f})")
print(f"Normalized Image: (min={normalized_img.min():.2f}, max={normalized_img.max():.2f})")

fig, axs = plt.subplots(1, 2, dpi=300, constrained_layout=True, sharey=True)
axs[0].imshow(cxr_image, 'gray')
axs[1].imshow(normalized_img, 'gray')
axs[0].set_title("Processed Image")
axs[1].set_title("Normalized Image")
plt.show()
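
The min-max steps above can be wrapped in a reusable function. This is a minimal sketch: the name `min_max_normalize`, the `eps` parameter, and the choice to return zeros for a constant image are assumptions made here to avoid a division by zero, not part of the original notebook.

```python
import numpy as np


def min_max_normalize(img, eps=1e-8):
    """Rescale an image to [0, 1]; a constant image maps to all zeros."""
    img = np.asarray(img, dtype=np.float64)
    img_min = img.min()
    value_range = img.max() - img_min
    if value_range < eps:
        # max == min, so dividing by the range would fail.
        return np.zeros_like(img)
    return (img - img_min) / value_range


sample = np.array([[200, 400], [600, 1000]])
print(min_max_normalize(sample))
# [[0.   0.25]
#  [0.5  1.  ]]
```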

## 3.1.2 Standardization

The most common standardization technique is transforming the pixel distribution to a mean of zero and a standard deviation of 1 (or unit variance).

The steps are as follows:

1. Calculate the mean and standard deviation of the image
2. Subtract the mean from the image
3. Divide the result by the standard deviation

In [None]:
img_mean = cxr_image.mean()
img_std = cxr_image.std()

In [None]:
standard_img = (cxr_image - img_mean) / img_std

In [None]:
standard_img.mean()

In [None]:
standard_img.std()

In [None]:
print(f"Processed Image: (mean={cxr_image.mean():.2f}, std={cxr_image.std():.2f})")
print(f"Standardized Image: (mean={standard_img.mean():.2f}, std={standard_img.std():.2f})")

fig, axs = plt.subplots(1, 2, dpi=300, constrained_layout=True, sharey=True)
axs[0].imshow(cxr_image, 'gray')
axs[1].imshow(standard_img, 'gray')
axs[0].set_title("Processed Image")
axs[1].set_title("Standardized Image")
plt.show()
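
The standardization steps above can be wrapped in the same way. As before, this is only a sketch: the name `standardize`, the `eps` parameter, and the zero output for a constant image are assumptions introduced here to guard against a standard deviation of zero.

```python
import numpy as np


def standardize(img, eps=1e-8):
    """Shift an image to zero mean and scale it to unit standard deviation."""
    img = np.asarray(img, dtype=np.float64)
    std = img.std()
    if std < eps:
        # A constant image has no spread to scale by.
        return np.zeros_like(img)
    return (img - img.mean()) / std


sample = np.array([[1, 2], [3, 4]])
standardized = standardize(sample)
print(f"mean={standardized.mean():.2f}, std={standardized.std():.2f}")
```

Unlike min-max normalization, the output is not bounded: the result contains negative values and its range depends on how spread out the original pixel values are.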

# End