# TorchXrayVision

## Installation

In [None]:
%load_ext autoreload
%autoreload 2
# %pip (rather than !pip) installs into the environment of the running kernel.
# scikit-image is listed explicitly because the notebook imports skimage directly.
%pip install numpy matplotlib scikit-image torch torchvision torchxrayvision

In [None]:
import numpy as np
import skimage
import torch
import torchvision
import matplotlib.pyplot as plt
import torchxrayvision as xrv

In [None]:
# Load the pre-trained PSPNet segmentation model from TorchXRayVision
model = xrv.baseline_models.chestx_det.PSPNet()
# Put the model in evaluation mode explicitly (as the dataset cell below does)
# so inference does not depend on what the constructor leaves set.
model = model.eval()

In [None]:
# Echo the model as the cell's last expression so the notebook displays it
model

## Single Image Test

In [None]:
import skimage.io
import torch
import torchxrayvision as xrv
import torchvision.transforms as transforms

# Load the image from disk (replace the placeholder with a real file path)
img = skimage.io.imread("path/to/image")

# Normalize the 8-bit image to the [-1024, 1024] range
img = xrv.datasets.normalize(img, 255)

# Reduce colour images (H, W, C) to a single grayscale plane. Without this
# step only grayscale files work, because the transforms below expect a
# (C, H, W) array.
if img.ndim == 3:
    if img.shape[2] in (2, 4):  # trailing alpha channel: drop it
        img = img[..., :-1]
    img = img.mean(axis=2)

# Add the channel dimension -> (1, H, W)
img = img[None, ...]

# Center-crop and resize to 512 pixels
transform = transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(512),
])

img = transform(img)
img = torch.from_numpy(img)

print(img.shape)  # Verify the shape of the tensor


In [None]:
# Run the segmentation model on the prepared image. torch.no_grad() turns off
# gradient tracking, which is not needed for inference.
with torch.no_grad():
    pred = model(img)

In [None]:
# One panel for the input image followed by one panel per model target.
panel_count = len(model.targets) + 1
plt.figure(figsize=(26, 5))

# Leftmost panel: the preprocessed input image
plt.subplot(1, panel_count, 1)
plt.imshow(img[0], cmap='gray')

# Remaining panels: the raw model output for each target, titled by name
for channel, target_name in enumerate(model.targets):
    plt.subplot(1, panel_count, channel + 2)
    plt.imshow(pred[0, channel])
    plt.title(target_name)
    plt.axis('off')

plt.tight_layout()

In [None]:
# Convert the raw model outputs to probabilities, then binarise at 0.5.
# A single comparison guarantees every pixel ends up exactly 0 or 1; the
# previous pair of "< 0.5" / "> 0.5" assignments left a value of exactly 0.5
# untouched.
pred = torch.sigmoid(pred)
pred = (pred > 0.5).to(pred.dtype)

In [None]:
# Same layout as the raw-output figure, now showing the thresholded masks.
total_panels = 1 + len(model.targets)
plt.figure(figsize=(26, 5))

# First panel: the preprocessed input image
plt.subplot(1, total_panels, 1)
plt.imshow(img[0], cmap='gray')

# One panel per target, offset by one to leave room for the input image
for position, name in enumerate(model.targets, start=2):
    plt.subplot(1, total_panels, position)
    plt.imshow(pred[0, position - 2])
    plt.title(name)
    plt.axis('off')

plt.tight_layout()

## Full Chest-XRay Dataset

The following code generates a dataset of images and segmentation masks for the Chest X-ray dataset. The current label mapping is ideal for visually distinguishing each class; however, it is not suitable for training SPADEGAN. Either manually change the label mapping to consecutive integer class indices starting at 0, or use the remapping code after the next cell.

In [None]:
import os
import numpy as np
import skimage.io
import torch
import torchxrayvision as xrv
import torchvision.transforms as transforms
from skimage import img_as_ubyte  # Import this to convert images to 8-bit
from skimage.exposure import rescale_intensity  # Import to rescale intensity

# Set the dataset path and output directory
data_dir = "dataset"  # Replace with your dataset path
output_image_dir = "output/path/images"  # Directory to save the original images
output_mask_dir = "output/path/masks"  # Directory to save the combined masks

# Create the output directories if they don't exist
os.makedirs(output_image_dir, exist_ok=True)
os.makedirs(output_mask_dir, exist_ok=True)

# Load the pre-trained PSPNet model from TorchXRayVision
model = xrv.baseline_models.chestx_det.PSPNet()
model.eval()  # Set the model to evaluation mode

# Define the necessary transforms
transform = transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(512),
])

# Gray value written into the combined mask for each anatomical structure.
# NOTE(review): the loop below pairs the i-th entry of this dict with the i-th
# output channel of the model, so the insertion order here must match the order
# of model.targets - confirm if the model or this dict is changed.
labels = {
    "Left Clavicle": 20,
    "Right Clavicle": 40,
    "Left Scapula": 60,
    "Right Scapula": 80,
    "Left Lung": 100,
    "Right Lung": 120,
    "Left Hilus Pulmonis": 140,
    "Right Hilus Pulmonis": 160,
    "Heart": 180,
    "Aorta": 200,
    "Facies Diaphragmatica": 220,
    "Mediastinum": 240,
    "Weasand": 255,
    "Spine": 30  # Ensure distinct from others
}

# Process images from all subdirectories
for root, dirs, files in os.walk(data_dir):
    for file in files:
        if not file.endswith((".png", ".jpg", ".jpeg")):
            continue

        # Load and preprocess the image
        img_path = os.path.join(root, file)
        img = skimage.io.imread(img_path)
        img = xrv.datasets.normalize(img, 255)  # Normalize to [-1024, 1024]

        # Colour files load as (H, W, C); collapse them to one grayscale plane.
        # Previously only 2-D images were handled, so colour images reached the
        # transforms with their axes in the wrong order.
        if img.ndim == 3:
            if img.shape[2] in (2, 4):  # trailing alpha channel: drop it
                img = img[..., :-1]
            img = img.mean(axis=2)
        img = img[None, ...]  # Add a channel dimension (1, H, W)

        img = transform(img)
        img_tensor = torch.from_numpy(img).unsqueeze(0)  # Add batch dimension

        # Generate segmentation outputs for each anatomical structure
        with torch.no_grad():
            pred = model(img_tensor)

        # Sigmoid + 0.5 threshold -> one boolean mask per output channel
        structure_masks = (torch.sigmoid(pred)[0] > 0.5).cpu().numpy()

        # Combine the masks into one image with a distinct gray value each.
        # Where structures overlap, the one written last wins.
        combined_mask = np.zeros(structure_masks.shape[1:], dtype=np.uint8)
        for channel, gray_value in enumerate(labels.values()):
            combined_mask[structure_masks[channel]] = gray_value

        # Rescale the preprocessed image to [0, 1] before converting to 8-bit
        img_rescaled = rescale_intensity(img.squeeze(), in_range=(-1024, 1024), out_range=(0, 1))
        img_8bit = img_as_ubyte(img_rescaled)

        # Mirror the input's subdirectory layout under both output directories
        rel_path = os.path.relpath(root, data_dir)
        output_image_subdir = os.path.join(output_image_dir, rel_path)
        output_mask_subdir = os.path.join(output_mask_dir, rel_path)
        os.makedirs(output_image_subdir, exist_ok=True)
        os.makedirs(output_mask_subdir, exist_ok=True)

        # Save the preprocessed image and the combined mask (mask is always PNG)
        output_image_path = os.path.join(output_image_subdir, file)
        output_mask_path = os.path.join(output_mask_subdir, f"mask_{os.path.splitext(file)[0]}.png")

        skimage.io.imsave(output_image_path, img_8bit)
        skimage.io.imsave(output_mask_path, combined_mask)

        print(f"Saved image: {output_image_path}")
        print(f"Saved combined segmentation mask: {output_mask_path}")


The following code remaps the labels to be suitable for training SPADEGAN.

In [None]:
import os
import numpy as np
from skimage.io import imread, imsave

# Map each gray value used in the combined masks to a contiguous class index.
# Keys are listed in ascending gray-value order; 0 is the background.
# 200 (Aorta) was previously missing, which left those pixels at the value 200
# in the remapped masks. With it included the indices run 0-14
# (background + 14 structures).
remap_dict = {
    0: 0,
    20: 1,
    30: 2,
    40: 3,
    60: 4,
    80: 5,
    100: 6,
    120: 7,
    140: 8,
    160: 9,
    180: 10,
    200: 11,
    220: 12,
    240: 13,
    255: 14,
}

def remap_mask(mask, remap_dict):
    """Return a copy of ``mask`` with pixel values replaced per ``remap_dict``.

    Every pixel equal to a key of ``remap_dict`` is set to the matching value.
    Comparisons are made against the untouched input, so one replacement can
    never feed into another. Values that are not keys are left as they are,
    and the input array is not modified.
    """
    output = np.copy(mask)
    for source_value in remap_dict:
        selected = mask == source_value
        output[selected] = remap_dict[source_value]
    return output

# Source folder with the original masks and destination for the remapped ones
mask_dir = "path/to/masks"
output_mask_dir = "path/to/masks_remapped"
os.makedirs(output_mask_dir, exist_ok=True)

# Remap every PNG mask found directly inside mask_dir.
# NOTE(review): os.listdir is not recursive, so masks stored in subdirectories
# of mask_dir are not visited.
for filename in os.listdir(mask_dir):
    if not filename.endswith(".png"):  # Assuming masks are in PNG format
        continue

    source_path = os.path.join(mask_dir, filename)
    output_path = os.path.join(output_mask_dir, filename)

    # Read, remap and write the mask under the same file name
    imsave(output_path, remap_mask(imread(source_path), remap_dict))
    print(f"Saved remapped mask: {output_path}")
