In [21]:
import numpy as np
import pandas as pd
from pathlib import Path
import ast

import re
import matplotlib.pyplot as plt
from skimage import io

import matplotlib.patches as patches

import cv2
from skimage.transform import resize
from skimage.util import img_as_ubyte
from skimage.io import imread

# Dataset roots on the shared biomedic3 data volume
_DATA_ROOT = Path("/vol/biomedic3/data")
EMBED_ROOT = _DATA_ROOT / "EMBED"
VINDR_ROOT = _DATA_ROOT / "VinDR-Mammo"

In [None]:
# Image-level EMBED DICOM metadata: the per-instance rows are needed for our
# case-wise aggregation methodologies.
dicom = pd.read_csv(EMBED_ROOT / "tables" / "EMBED_OpenData_metadata.csv", low_memory=False)

# VinDr-Mammo finding annotations (an image can appear on several rows).
vindr_findings = pd.read_csv(VINDR_ROOT / "finding_annotations.csv", low_memory=False)

In [None]:
# Relative PNG path: "<empi_anon>/<DICOM file name up to '.dcm'>.png"
dicom_file_stem = (
    dicom["anon_dicom_path"]
    .str.split("/").str[-1]      # file name component of the DICOM path
    .str.split(".dcm").str[0]    # everything before the ".dcm" suffix
)
dicom["image_path"] = dicom["empi_anon"].astype("str") + "/" + dicom_file_stem + ".png"

In [None]:
# XCCL shouldn't be converted to CC, so set it explicitly from the series description
dicom.loc[dicom["SeriesDescription"].isin(["RXCCL", "LXCCL"]), "ViewPosition"] = "XCCL"

# Rows with a missing ViewPosition but a present SeriesDescription are the ones subject to
# the data-entry error; their view can be recovered from the series description.
view_nan_mask = dicom["ViewPosition"].isna() & dicom["SeriesDescription"].notna()
# .copy() makes view_nan an independent frame, so the assignment below does not raise a
# SettingWithCopyWarning (it was previously a slice of `dicom`).
view_nan = dicom.loc[view_nan_mask].copy()

# All other rows, left untouched. Using the mask (rather than index membership) stays
# correct even if `dicom` has a non-unique index.
dicom_no_nans = dicom.loc[~view_nan_mask]

# Infer the view from the description: "CC" is checked before "MLO"; anything else -> None
view_nan["ViewPosition"] = view_nan["SeriesDescription"].apply(
    lambda x: "CC" if "CC" in x else ("MLO" if "MLO" in x else None)
)

# Re-assemble: untouched rows first, repaired rows appended, fresh 0..n-1 index
dicom = pd.concat([dicom_no_nans, view_nan], axis=0, ignore_index=True)

In [None]:
print(len(dicom))
dicom = (
    dicom
    # Remove any duplicated images (one row per DICOM file)
    .drop_duplicates(subset="anon_dicom_path")
    # Remove spot compressed and magnified images
    .loc[lambda frame: frame["spot_mag"].isna()]
    # Keep only the standard CC / MLO views
    .loc[lambda frame: frame["ViewPosition"].isin(["CC", "MLO"])]
    # Keep only images from female clients
    .loc[lambda frame: frame["PatientSex"] == "F"]
)
print(len(dicom))

In [None]:
# Remove any unnecessary fields from the DICOM imagewise dataframe (this may need to be updated in the future if other fields are deemed relevant)
# Fields kept from the EMBED image-wise DICOM table (extend if other fields become relevant)
EMBED_DICOM_COLUMNS = [
    "empi_anon",
    "acc_anon",
    "image_path",
    "FinalImageType",
    "ImageLateralityFinal",
    "ViewPosition",
    "Manufacturer",
    "ManufacturerModelName",
    "ROI_coords",
    "num_roi",
    "PatientOrientation",
    "Rows",
    "Columns",
]

# Fields kept from the VinDr-Mammo findings table
VINDR_FINDING_COLUMNS = [
    "study_id",
    "image_id",
    "height",
    "width",
    "breast_birads",
    "breast_density",
    "finding_categories",
    "finding_birads",
    "xmin",
    "ymin",
    "xmax",
    "ymax",
]

dicom = dicom[EMBED_DICOM_COLUMNS]
vindr_findings = vindr_findings[VINDR_FINDING_COLUMNS]

In [None]:
# Combine the corner columns into one [ymin, xmin, ymax, xmax] (i.e. [y1, x1, y2, x2]) box
# per row. Building it from the numpy values instead of a row-wise apply gives the same
# Python-float lists and also works when the frame is empty.
vindr_findings['bbox'] = pd.Series(
    vindr_findings[['ymin', 'xmin', 'ymax', 'xmax']].to_numpy().tolist(),
    index=vindr_findings.index,
)

# Drop the individual corner columns now that they live in 'bbox' (no inplace mutation)
vindr_findings = vindr_findings.drop(columns=['xmin', 'ymin', 'xmax', 'ymax'])

# Keep findings that have a BI-RADS assessment, and always keep the explicit
# "No Finding" rows even if their finding_birads is missing. .copy() makes the
# result an independent frame rather than a slice of vindr_findings.
filtered_vindr = vindr_findings[
    vindr_findings['finding_birads'].notna()
    | vindr_findings['finding_categories'].eq("['No Finding']")
].copy()

In [None]:
# Images with several annotated findings appear on several rows; collapse them to one row
# per (study_id, image_id).
is_duplicate = filtered_vindr.duplicated(subset=['study_id', 'image_id'], keep=False)

# Find duplicate rows
duplicate_rows = filtered_vindr[is_duplicate]


def aggregate_columns(group):
    """Collapse all rows of one (study_id, image_id) group into a single row.

    Returns a Series in which:
    - ``finding_categories`` / ``finding_birads`` are flat lists (scalar values are
      wrapped, list values concatenated), one entry per scalar input row;
    - ``bbox`` is a list of boxes, one element per input row;
    - every other column (except study_id / image_id) holds the group's first value;
    - ``image_path`` is ``"<study_id>/<image_id>.png"``.
    """
    agg_dict = {}
    for col in ['finding_categories', 'finding_birads']:
        # Wrap scalars in a list so .sum() concatenates them into a single flat list
        agg_dict[col] = group[col].apply(lambda x: [x] if not isinstance(x, list) else x).sum()
    # For 'bbox', keep one element per row so the result is a list of boxes
    agg_dict['bbox'] = group['bbox'].apply(lambda x: [x] if not isinstance(x, list) else x).tolist()
    # Include all other columns, taking the first value
    for col in group.columns:
        if col not in agg_dict and col not in ['study_id', 'image_id']:
            agg_dict[col] = group[col].iloc[0]
    agg_dict['image_path'] = group['study_id'].iloc[0] + '/' + group['image_id'].iloc[0] + '.png'
    return pd.Series(agg_dict)


# Group duplicate rows by 'study_id' and 'image_id' and apply the aggregation function
collapsed_duplicates = duplicate_rows.groupby(['study_id', 'image_id']).apply(aggregate_columns).reset_index(drop=True)

# Non-duplicate rows (one row per image already)
non_duplicate_rows = filtered_vindr[~is_duplicate].copy()

# Add 'image_path' ("<study_id>/<image_id>.png"), vectorized over the string columns
non_duplicate_rows['image_path'] = non_duplicate_rows['study_id'] + '/' + non_duplicate_rows['image_id'] + '.png'

# Store each single box as a one-element list of boxes. Collapsed rows already hold a list
# of boxes, and the plotting code iterates over boxes; a flat [y1, x1, y2, x2] would
# otherwise be iterated element by element.
non_duplicate_rows['bbox'] = non_duplicate_rows['bbox'].apply(lambda box: [box])

# Remove 'study_id' and 'image_id' columns (collapsed rows do not carry them either)
non_duplicate_rows = non_duplicate_rows.drop(columns=['study_id', 'image_id'])

# Combine collapsed duplicates with non-duplicate rows
vindr_final = pd.concat([collapsed_duplicates, non_duplicate_rows], ignore_index=True)

In [None]:
# Conversion dictionary to standardised naming of various fields in clinical metadata

# Human reader BIRADS density assessment (numeric code -> letter category)
dens_conversion = {1.0: "A", 2.0: "B", 3.0: "C", 4.0: "D"}
# Load in the clinical metadata
mag = pd.read_csv(EMBED_ROOT / "tables/EMBED_OpenData_clinical.csv", low_memory=False)
print(len(mag))
# Keep only cases with a valid BIRADS density assessment and map the code to its letter.
# Every remaining value is a key of dens_conversion, so .map leaves nothing unmapped.
mag = mag[mag.tissueden.isin(list(dens_conversion))].assign(
    tissueden=lambda frame: frame["tissueden"].map(dens_conversion)
)
# Keep important study metadata tags to join up with final aggregated dataframe at end of script
mag = mag[["empi_anon", "tissueden", "study_date_anon", "acc_anon"]].drop_duplicates(
    subset="acc_anon"
)
print(len(mag))
# Convert to pandas datetime object (unparseable dates become NaT)
mag["study_date_anon"] = pd.to_datetime(mag["study_date_anon"], errors="coerce")


In [None]:
# Inner join on accession number: only studies present in both the clinical and the
# DICOM metadata are kept. Both frames also carry "empi_anon", so the result gets the
# default "_x" / "_y" suffixed copies of that column.
print(len(dicom))
df = pd.merge(mag, dicom, how="inner", on="acc_anon")
print(len(df))

In [None]:
vindr_final.loc[:, "finding_categories"]

In [None]:
def scale_and_flip_bounding_box(orig_coords, orig_height, orig_width, new_height, new_width, flip_horizontal = False, flip_vertical = False):
    """
    Transform bounding box coords to fit on the rescaled DICOM image.

    The box is scaled by a single factor, ``min(new_height / orig_height, new_width / orig_width)``,
    i.e. it assumes the image was resized with its aspect ratio preserved. It can optionally be
    reflected within the new ``new_height x new_width`` frame.

    Args:
        orig_coords: array-like ``[y1, x1, y2, x2]`` in original-image pixels.
        orig_height, orig_width: size of the original image.
        new_height, new_width: size of the rescaled image.
        flip_horizontal: reflect the box left <-> right within ``new_width``.
        flip_vertical: reflect the box top <-> bottom within ``new_height``.

    Returns:
        Tuple ``(y1, x1, y2, x2)`` of ints. After a flip, the coordinates are ordered so that
        ``y1 <= y2`` and ``x1 <= x2`` (top-left, then bottom-right).
    """
    scale_factor = min(new_height / orig_height, new_width / orig_width)

    # Rescale the box (truncating to whole pixels); np.asarray also accepts plain lists.
    y1, x1, y2, x2 = (np.asarray(orig_coords) * scale_factor).astype("int").tolist()

    y1_new, x1_new, y2_new, x2_new = y1, x1, y2, x2
    # Reflect about the frame (indices run 0..size-1, hence the -1). Reflection reverses
    # the order of the two edges, so the larger original edge becomes the new first edge.
    if flip_horizontal:
        x1_new = new_width - 1 - max(x1, x2)
        x2_new = new_width - 1 - min(x1, x2)
    if flip_vertical:
        # Uses the box's own y extent (previously the box width was used here by mistake)
        y1_new = new_height - 1 - max(y1, y2)
        y2_new = new_height - 1 - min(y1, y2)

    return y1_new, x1_new, y2_new, x2_new

def convert_str_bbox_to_numpy(coords, embed=True):
    """Parse a bounding-box string into a numpy array.

    Spaces are stripped before parsing. With ``embed=True`` (EMBED-style, tuple syntax),
    parentheses are first turned into square brackets so that even a single box without a
    trailing comma parses as a nested list. With ``embed=False`` the string is parsed as-is.
    Parsing uses ``ast.literal_eval``, so only literals are accepted.
    """
    text = coords.replace(' ', '')
    if embed:
        text = text.translate(str.maketrans('()', '[]'))
    return np.array(ast.literal_eval(text))

In [None]:
new_height = 1024
new_width = 768


# Take a small subset of EMBED images with at least one annotated ROI
head = df.loc[df.num_roi > 0].head(30)

f, ax = plt.subplots(5, 6, figsize=(20, 20))
ax = ax.ravel()
# Hide axes on every panel up front, including unused panels and panels without boxes
for axis in ax:
    axis.axis('off')

for p, (_, row_df) in enumerate(head.iterrows()):
    # Original (full-resolution) image size from the DICOM metadata
    orig_height = row_df["Rows"]
    orig_width = row_df["Columns"]

    # Parse the ROI string; np.atleast_2d lets a single flat box iterate as one box
    bboxes = np.atleast_2d(convert_str_bbox_to_numpy(row_df['ROI_coords']))
    img_path = row_df.image_path

    # Plot image (presumably already 1024x768 per the directory name, so it is not resized here)
    ax[p].imshow(io.imread(str(EMBED_ROOT / "images/png/1024x768" / img_path)), cmap='gray')

    for orig_bbox_coordinates in bboxes:
        # Processed bbox coords in [y1, x1, y2, x2] order, rescaled to the displayed image
        y1, x1, y2, x2 = scale_and_flip_bounding_box(orig_bbox_coordinates, orig_height, orig_width, new_height, new_width)

        # matplotlib Rectangle takes ((x, y), width, height): width is the x extent and
        # height is the y extent (these were previously swapped)
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax[p].add_patch(rect)

plt.show()

In [None]:
new_height = 1024
new_width = 768


# Take a small subset of VinDr images that contain at least one finding
head = vindr_final.loc[vindr_final.finding_categories != "['No Finding']"].head(30)

f, ax = plt.subplots(5, 6, figsize=(20, 20))
ax = ax.ravel()
# Hide axes on every panel up front, including unused panels and panels without boxes
for axis in ax:
    axis.axis('off')

for p, (_, row_df) in enumerate(head.iterrows()):
    # Original image size as recorded in the findings csv
    orig_height = row_df["height"]
    orig_width = row_df["width"]

    # Parse the stored box list; np.atleast_2d lets a single flat [y1, x1, y2, x2] box
    # iterate as one box instead of four scalars
    bboxes = np.atleast_2d(convert_str_bbox_to_numpy(str(row_df['bbox']), embed=False))
    img_path = row_df.image_path

    image = io.imread(str(VINDR_ROOT / "pngs" / img_path))
    # scale_and_flip_bounding_box scales boxes by one uniform factor
    # min(new_h / h, new_w / w). So resize the image with that same aspect-preserving
    # factor and pad it at the bottom/right, instead of stretching it to 1024x768.
    # Assumes the PNG has the size recorded in the csv (height/width) — TODO confirm.
    scale = min(new_height / image.shape[0], new_width / image.shape[1])
    resized_shape = (
        min(new_height, round(image.shape[0] * scale)),
        min(new_width, round(image.shape[1] * scale)),
    )
    resized = resize(image, resized_shape, preserve_range=True)
    image = np.zeros((new_height, new_width) + resized.shape[2:], dtype=resized.dtype)
    image[: resized_shape[0], : resized_shape[1]] = resized

    # Plot image
    ax[p].imshow(image, cmap='gray')

    for orig_bbox_coordinates in bboxes:
        # Processed bbox coords in [y1, x1, y2, x2] order, rescaled to the displayed image
        y1, x1, y2, x2 = scale_and_flip_bounding_box(orig_bbox_coordinates, orig_height, orig_width, new_height, new_width)

        # matplotlib Rectangle takes ((x, y), width, height): width is the x extent and
        # height is the y extent (these were previously swapped)
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax[p].add_patch(rect)

plt.show()