## Data Pipeline

### Steps Overview

Dataset Organization: Directory structure

```plaintext
data/raw/
├── Calc-Test_P_00038_LEFT_CC/
│   └── .../1-1.dcm         (full mammogram image)
├── Calc-Test_P_00038_LEFT_CC_1/
│   └── .../1-1.dcm, 1-2.dcm (ROI mask images)
```

1. Scan DICOM Files
- Recursively scan all .dcm files from data/raw/. Separate into full mammogram images and ROI masks.
- Input: data/raw/
- Output: List of file paths (full_mammo, roi_masks)

2. Extract Metadata from File Paths
- Parse abnormality type (Calc/Mass), patient ID, laterality (LEFT/RIGHT), and view (CC/MLO) from DICOM folder names.
- Input: file paths
- Output: structured dictionary per DICOM

3. Pair Images and Masks
- Match each full mammogram image to its corresponding ROI mask images based on naming convention.
- Input: lists from Step 1
- Output: paired image-mask metadata

4. Consolidate Clinical Metadata CSVs
- Read clinical metadata (calc_case_description_\*.csv, mass_case_description_\*.csv) from data/metadata/ and merge into a single DataFrame.
- Input: data/metadata/ CSV files
- Output: consolidated clinical metadata

5. Merge Image–Mask Metadata and Clinical Metadata
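- Combine the paired image-mask metadata with the consolidated clinical metadata by joining on patient_id, laterality, and view. The paired metadata carries no abnormality_id, so an image with several annotated abnormalities yields one row per abnormality.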
- Export the merged, fully consolidated metadata DataFrame to data/processed/cbis_ddsm_metadata_full.csv.
- Input: outputs of Step 3 and Step 4
- Output: cbis_ddsm_metadata_full.csv

6. Build TensorFlow Dataset
- Create a tf.data.Dataset that loads (image, mask) pairs for deep learning model training, with preprocessing (normalization to [0, 1] and resizing to 512×512).
- Input: final metadata CSV
- Output: TensorFlow-ready dataset (train_ds)

### Scan DICOM Files

- Recursively scan all .dcm files from `base_dir`.
- Separate into full mammogram images and ROI mask images.

- Args:
  - base_dir (str): Root directory containing DICOM folders (e.g., "data/raw/").

- Returns:
  - dict: {
    - "full_mammo": [list of full mammogram DICOM paths],
    - "roi_masks": [list of ROI mask DICOM paths] }

In [1]:
import os
from pathlib import Path
from typing import Dict, List

def scan_dicom_files(base_dir: str) -> Dict[str, List[str]]:
    base_path = Path(base_dir)
    dicom_files = {
        "full_mammo": [],
        "roi_masks": [],
    }

    if not base_path.exists():
        raise ValueError(f"Provided base directory does not exist: {base_dir}")

    for root, dirs, files in os.walk(base_path):
        for file in files:
            if file.lower().endswith(".dcm"):
                file_path = Path(root) / file
                lower_root = str(root).lower()

                if "full mammogram images" in lower_root:
                    dicom_files["full_mammo"].append(str(file_path))
                elif "roi mask images" in lower_root:
                    dicom_files["roi_masks"].append(str(file_path))

    print(f"[INFO] Found {len(dicom_files['full_mammo'])} full mammogram DICOMs.")
    print(f"[INFO] Found {len(dicom_files['roi_masks'])} ROI mask DICOMs.")

    return dicom_files


In [None]:
base_dir = "../data/raw/"
dicom_paths = scan_dicom_files(base_dir)

# Access full mammograms and ROI masks
full_mammo_files = dicom_paths["full_mammo"]
roi_mask_files = dicom_paths["roi_masks"]

# Preview first few rows
print("\nExample full mammogram paths:")
for path in full_mammo_files[:3]:
    print(path)

print("\nExample ROI mask paths:")
for path in roi_mask_files[:3]:
    print(path)

[INFO] Found 3103 full mammogram DICOMs.
[INFO] Found 7026 ROI mask DICOMs.

Example full mammogram paths:
../data/raw/Mass-Test_P_00066_LEFT_CC/10-04-2016-DDSM-NA-12982/1.000000-full mammogram images-25433/1-1.dcm
../data/raw/Calc-Test_P_02176_RIGHT_MLO/08-29-2017-DDSM-NA-33174/1.000000-full mammogram images-82696/1-1.dcm
../data/raw/Calc-Training_P_00418_LEFT_CC/08-07-2016-DDSM-NA-95820/1.000000-full mammogram images-43865/1-1.dcm

Example ROI mask paths:
../data/raw/Calc-Training_P_01205_RIGHT_MLO_1/09-06-2017-DDSM-NA-99604/1.000000-ROI mask images-30743/1-1.dcm
../data/raw/Calc-Training_P_01205_RIGHT_MLO_1/09-06-2017-DDSM-NA-99604/1.000000-ROI mask images-30743/1-2.dcm
../data/raw/Calc-Training_P_00778_LEFT_CC_1/09-06-2017-DDSM-NA-38576/1.000000-ROI mask images-63875/1-1.dcm
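
As a quick sanity check on the scan result, the ROI mask paths can be grouped by their top-level case folder to see how many DICOM files each abnormality folder holds. This is a minimal sketch that assumes the folder layout shown in the example paths; `count_files_per_case` is a hypothetical helper, not part of the pipeline.

```python
from collections import Counter
from pathlib import PurePosixPath
from typing import Dict, Iterable


def count_files_per_case(paths: Iterable[str]) -> Dict[str, int]:
    """Count DICOM files per case folder (first path part starting with Calc- or Mass-)."""
    counts: Counter = Counter()
    for p in paths:
        for part in PurePosixPath(p).parts:
            if part.startswith(("Calc-", "Mass-")):
                counts[part] += 1
                break  # Only the outermost case folder is counted
    return dict(counts)


# Paths copied from the example output above
sample_paths = [
    "../data/raw/Calc-Training_P_01205_RIGHT_MLO_1/09-06-2017-DDSM-NA-99604/1.000000-ROI mask images-30743/1-1.dcm",
    "../data/raw/Calc-Training_P_01205_RIGHT_MLO_1/09-06-2017-DDSM-NA-99604/1.000000-ROI mask images-30743/1-2.dcm",
    "../data/raw/Calc-Training_P_00778_LEFT_CC_1/09-06-2017-DDSM-NA-38576/1.000000-ROI mask images-63875/1-1.dcm",
]
per_case = count_files_per_case(sample_paths)
print(per_case)
```

Folders with more than one file are worth inspecting: the clinical CSVs list a cropped image and an ROI mask under the same folder, so a two-file folder may hold one of each rather than two masks.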


### Extract Metadata from File Paths

- Extracts abnormality type, patient ID, laterality, and view from a CBIS-DDSM DICOM file path.

- Args:
  - path (str or Path): Full path to a DICOM file.

- Returns:
  - dict: { 
      - "abnormality_type": "Calc" or "Mass",
      - "patient_id": "00038",
      - "laterality": "LEFT" or "RIGHT",
      - "view": "CC" or "MLO",
      - "path": full file path (str) }
  - If no folder name in the path matches the expected pattern, every field except "path" is None.

In [3]:
import re
from pathlib import Path
from typing import Union, Dict

def extract_metadata_from_path(path: Union[str, Path]) -> Dict[str, Union[str, None]]:
    """
    Extract abnormality type, patient ID, laterality, view from any DICOM path
    by searching path parts for the folder name.
    """
    path = Path(path)

    # Try to find the folder matching "Calc-..." or "Mass-..." with pattern
    for part in path.parts:
        pattern = r"^(Calc|Mass)-(Test|Training)_P_(\d+)_([A-Z]+)_(CC|MLO)"
        match = re.match(pattern, part)
        if match:
            abnormality_type, dataset_split, patient_id, laterality, view = match.groups()
            return {
                "abnormality_type": abnormality_type,
                "patient_id": patient_id,
                "laterality": laterality,
                "view": view,
                "path": str(path)
            }
    
    # If no match found
    return {
        "abnormality_type": None,
        "patient_id": None,
        "laterality": None,
        "view": None,
        "path": str(path)
    }

In [16]:
dicom_path = "../data/raw/Calc-Test_P_00038_LEFT_CC/08-29-2017-DDSM-NA-96009/1.000000-full mammogram images-63992/1-1.dcm"

metadata = extract_metadata_from_path(dicom_path)

print("Extracted Metadata:")
for key, value in metadata.items():
    print(f"{key}: {value}")

Extracted Metadata:
abnormality_type: Calc
patient_id: 00038
laterality: LEFT
view: CC
path: ../data/raw/Calc-Test_P_00038_LEFT_CC/08-29-2017-DDSM-NA-96009/1.000000-full mammogram images-63992/1-1.dcm
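
ROI mask folders carry a trailing index (for example `Calc-Test_P_00038_LEFT_CC_1`), which the pattern above ignores. The sketch below shows one way to capture it as well. It assumes the suffix is the abnormality index; `CASE_PATTERN` and `parse_case_folder` are hypothetical names introduced for illustration.

```python
import re
from typing import Dict, Optional

# Same fields as the pattern above, plus an optional trailing "_<n>" index.
# Assumption: the suffix identifies the abnormality within the image.
CASE_PATTERN = re.compile(
    r"^(Calc|Mass)-(Test|Training)_P_(\d+)_(LEFT|RIGHT)_(CC|MLO)(?:_(\d+))?$"
)


def parse_case_folder(folder: str) -> Optional[Dict[str, Optional[str]]]:
    """Parse a case folder name; return None when it does not match."""
    match = CASE_PATTERN.match(folder)
    if match is None:
        return None
    abnormality_type, split, patient_id, laterality, view, abnormality_id = match.groups()
    return {
        "abnormality_type": abnormality_type,
        "dataset_split": split,
        "patient_id": patient_id,
        "laterality": laterality,
        "view": view,
        "abnormality_id": abnormality_id,  # None for full mammogram folders
    }


print(parse_case_folder("Calc-Test_P_00038_LEFT_CC_1"))
print(parse_case_folder("Calc-Test_P_00038_LEFT_CC"))
```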


### Pair Images and Masks

- Pairs full mammogram images with corresponding ROI mask images based on naming convention.

- Args:
  - full_mammo_paths (List[str]): List of full mammogram DICOM paths.
  - roi_mask_paths (List[str]): List of ROI mask DICOM paths.

- Returns:
  - List[Dict]: List of dictionaries with keys:
    - abnormality_type
    - patient_id
    - laterality
    - view
    - path
    - image_path
    - mask_paths (list of all ROI mask DICOM paths matched to the image)

In [5]:
from pathlib import Path
from collections import defaultdict
from typing import Optional
import re

def get_base_key(path: str) -> Optional[str]:
    path = Path(path)

    # Search the path parts for the case folder
    for part in path.parts:
        pattern = r"^(Calc|Mass)-(Test|Training)_P_(\d+)_([A-Z]+)_(CC|MLO)"
        match = re.match(pattern, part)
        if match:
            abnormality_type, dataset_split, patient_id, laterality, view = match.groups()
            # Keep the real split (Test/Training) so cases from different splits never share a key
            return f"{abnormality_type}-{dataset_split}_P_{patient_id}_{laterality}_{view}"
    
    # No matching folder found
    return None


def pair_images_and_masks(full_mammo_paths, roi_mask_paths):
    # Group ROI mask paths by the base key of their case folder
    grouped_masks = defaultdict(list)
    
    for mask_path in roi_mask_paths:
        base_key = get_base_key(mask_path)
        if base_key is None:
            continue  # Skip masks whose folder name cannot be parsed
        grouped_masks[base_key].append(str(mask_path))

    paired_records = []

    for image_path in full_mammo_paths:
        metadata = extract_metadata_from_path(image_path)
        base_key = get_base_key(image_path)
        # Images without a matching key get an empty mask list
        mask_list = grouped_masks.get(base_key, [])

        record = metadata.copy()
        record["image_path"] = str(image_path)
        record["mask_paths"] = mask_list  # <== LIST of mask paths!
        paired_records.append(record)

    return paired_records

# get_base_key extracts a base key like 'Calc-Test_P_00038_LEFT_CC' or 'Mass-Test_P_00123_RIGHT_MLO'
# from a DICOM file path, ignoring the trailing index of ROI mask folders.
# Args:
#     path (str): Full DICOM file path.
# Returns:
#     str or None: Base key for pairing (abnormality type + patient ID + laterality + view),
#     or None if no folder name in the path matches.

In [6]:
# from scan_dicom_files import scan_dicom_files

dicom_files = scan_dicom_files("../data/raw/")

full_mammo_paths = dicom_files["full_mammo"]
roi_mask_paths = dicom_files["roi_masks"]

paired_metadata = pair_images_and_masks(full_mammo_paths, roi_mask_paths)

# Convert to DataFrame and preview
import pandas as pd
df = pd.DataFrame(paired_metadata)
print(df.head())

# Optional: Save to CSV
df.to_csv("../data/processed/cbis_ddsm_metadata_paired.csv", index=False)

[INFO] Found 3103 full mammogram DICOMs.
[INFO] Found 7026 ROI mask DICOMs.
  abnormality_type patient_id laterality view  \
0             Mass      00066       LEFT   CC   
1             Calc      02176      RIGHT  MLO   
2             Calc      00418       LEFT   CC   
3             Mass      01307      RIGHT  MLO   
4             Mass      00488       LEFT   CC   

                                                path  \
0  ../data/raw/Mass-Test_P_00066_LEFT_CC/10-04-20...   
1  ../data/raw/Calc-Test_P_02176_RIGHT_MLO/08-29-...   
2  ../data/raw/Calc-Training_P_00418_LEFT_CC/08-0...   
3  ../data/raw/Mass-Test_P_01307_RIGHT_MLO/10-04-...   
4  ../data/raw/Mass-Training_P_00488_LEFT_CC/07-2...   

                                          image_path  \
0  ../data/raw/Mass-Test_P_00066_LEFT_CC/10-04-20...   
1  ../data/raw/Calc-Test_P_02176_RIGHT_MLO/08-29-...   
2  ../data/raw/Calc-Training_P_00418_LEFT_CC/08-0...   
3  ../data/raw/Mass-Test_P_01307_RIGHT_MLO/10-04-...   
4  ../data/r
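
Because `mask_paths` holds a Python list, `to_csv` writes each value as its string representation, and the list has to be parsed back after `read_csv`. The sketch below illustrates the round trip on toy data with an in-memory buffer; the file names are placeholders, not real dataset paths.

```python
import ast
import io

import pandas as pd

# Toy frame: one image with two masks, one image with none (placeholder names)
df = pd.DataFrame({
    "image_path": ["img_a.dcm", "img_b.dcm"],
    "mask_paths": [["mask_a1.dcm", "mask_a2.dcm"], []],
})

# Write to an in-memory CSV instead of a file on disk
buffer = io.StringIO()
df.to_csv(buffer, index=False)
buffer.seek(0)

reloaded = pd.read_csv(buffer)
raw_type = type(reloaded.loc[0, "mask_paths"])  # the list comes back as a string
print(raw_type)

# Parse the stringified lists back into real Python lists
reloaded["mask_paths"] = reloaded["mask_paths"].apply(ast.literal_eval)
print(reloaded.loc[0, "mask_paths"])
```

`ast.literal_eval` only evaluates Python literals, which makes it a safer choice than `eval` for this kind of parsing.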

### Consolidate Clinical Metadata CSVs

The four separately provided CBIS-DDSM metadata CSV files contain clinical information (such as BI-RADS category, pathology, assessment, and lesion subtlety) that is not embedded in the DICOM files or their folder names.

- data/metadata/calc_case_description_test_set.csv
- data/metadata/calc_case_description_train_set.csv
- data/metadata/mass_case_description_test_set.csv
- data/metadata/mass_case_description_train_set.csv

Each row contains the following columns (example values are taken from one calcification case):

- Patient ID: `P_00038` — Unique patient identifier.
- Breast Side: `LEFT` — Left breast.
- Image View: `CC` or `MLO` — Standard cranio-caudal or mediolateral oblique view used in mammography.
- Breast Density: `2`
- Abnormality ID: `1`
- Abnormality Type: `calcification` — Specifically dealing with microcalcifications.
- Calcification Type: `PUNCTATE-PLEOMORPHIC` — Mixed types, suggesting variable morphology.
- Calcification Distribution: `CLUSTERED` — Clustered microcalcifications, often suspicious.
- Assessment: `4` — BI-RADS 4, suspicious abnormality; biopsy usually recommended.
- Pathology: `BENIGN` — Biopsy/pathology confirmed the finding as benign.
- Subtlety: `2` — Fairly subtle (1 = very subtle, 5 = very obvious).
- Image Files:
  - Original Image Path: Full mammogram DICOM.
    Calc-Test_P_00038_LEFT_CC/1.3.6.1.4.1.9590.100.1.2.85935434310203356712688695661986996009/1.3.6.1.4.1.9590.100.1.2.374115997511889073021386151921807063992/000000.dcm
  - Cropped Image Path: Focused region where calcifications are.
    Calc-Test_P_00038_LEFT_CC_1/1.3.6.1.4.1.9590.100.1.2.161465562211359959230647609981488894942/1.3.6.1.4.1.9590.100.1.2.419081637812053404913157930753972718515/000001.dcm
  - ROI Mask Path: Binary mask of the calcifications (region of interest).
    Calc-Test_P_00038_LEFT_CC_1/1.3.6.1.4.1.9590.100.1.2.161465562211359959230647609981488894942/1.3.6.1.4.1.9590.100.1.2.419081637812053404913157930753972718515/000000.dcm
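
If a classification label is needed, the Pathology column is the natural source. The sketch below maps it to a binary label. It assumes the column takes the values `BENIGN`, `BENIGN_WITHOUT_CALLBACK`, and `MALIGNANT`, and that both benign variants should map to 0; `PATHOLOGY_TO_LABEL` and `pathology_to_label` are hypothetical names, and the pipeline itself does not define a label.

```python
from typing import Optional

# Assumed value set for the pathology column; both benign variants map to 0.
PATHOLOGY_TO_LABEL = {
    "BENIGN": 0,
    "BENIGN_WITHOUT_CALLBACK": 0,
    "MALIGNANT": 1,
}


def pathology_to_label(pathology: Optional[str]) -> Optional[int]:
    """Return 0 (benign) or 1 (malignant); None for missing or unknown values."""
    if pathology is None:
        return None
    return PATHOLOGY_TO_LABEL.get(pathology.strip().upper())


print(pathology_to_label("BENIGN"))
print(pathology_to_label("malignant "))
```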

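The following script reads and concatenates the four metadata CSV files, renames the columns, normalizes the file paths (strips stray whitespace, newlines, and quotes, and shortens each path to folder/subfolder/file), adds any missing columns, and exports the result to data/metadata/metadata_master.csv.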

In [7]:
import pandas as pd
import csv
from pathlib import Path

# Input CSV files
input_files = [
    '../data/metadata/calc_case_description_test_set.csv',
    '../data/metadata/calc_case_description_train_set.csv',
    '../data/metadata/mass_case_description_test_set.csv',
    '../data/metadata/mass_case_description_train_set.csv'
]

# Read each CSV; the Python engine is used here to cope with newlines embedded in quoted fields.
dfs = []
for file in input_files:
    df = pd.read_csv(
        file,
        engine="python",        # Use Python engine to properly handle malformed newlines
        quoting=csv.QUOTE_MINIMAL, # Respect quotes
        skip_blank_lines=True   # Optional: Skip totally blank lines
    )
    dfs.append(df)

# Concatenate all together
metadata = pd.concat(dfs, ignore_index=True)

# Rename columns
metadata = metadata.rename(columns={
    'patient_id': 'patient_id',
    'left or right breast': 'side',
    'image view': 'view',
    'abnormality id': 'abnormality_id',
    'abnormality type': 'abnormality_type',
    'calc type': 'calc_type',
    'calc distribution': 'distribution',
    'mass shape': 'mass_shape',
    'mass margins': 'mass_margins',
    'breast density': 'breast_density',
    'assessment': 'assessment',
    'pathology': 'pathology',
    'subtlety': 'subtlety',
    'image file path': 'full_mammo_path',
    'cropped image file path': 'cropped_roi_path',
    'ROI mask file path': 'roi_mask_path'
})

# Add missing columns if needed
for col in ['calc_type', 'distribution', 'mass_shape', 'mass_margins']:
    if col not in metadata.columns:
        metadata[col] = pd.NA

# Normalize file paths
def fix_path(path):
    if pd.isna(path):
        return None
    # Strip surrounding whitespace and newlines (seen in the "cropped image file path" field),
    # normalize backslashes to forward slashes, and drop stray quotes.
    path = path.strip().replace('\\', '/').replace('\"', '')
    parts = Path(path).parts
    if len(parts) < 4:
        return path
    parent_folder = parts[0]
    subfolder = parts[1]
    file_name = parts[-1]
    return f'raw/{parent_folder}/{subfolder}/{file_name}'

# Apply path fixing
for col in ['full_mammo_path', 'cropped_roi_path', 'roi_mask_path']:
    metadata[col] = metadata[col].apply(fix_path)

# Select final columns
final_cols = [
    'patient_id', 'breast_density', 'side', 'view', 'abnormality_id',
    'abnormality_type', 'calc_type', 'distribution', 'mass_shape', 'mass_margins',
    'assessment', 'pathology', 'subtlety',
    'full_mammo_path', 'cropped_roi_path', 'roi_mask_path'
]
metadata = metadata[final_cols]

# Save to CSV
output_path = '../data/metadata/metadata_master.csv'
metadata.to_csv(output_path, index=False)
print(f'Master metadata CSV created: {output_path}')

Master metadata CSV created: ../data/metadata/metadata_master.csv
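
To make the path normalization concrete, the sketch below repeats the steps of `fix_path` on the cropped image path listed earlier, with a trailing newline added to imitate a malformed field (the exact malformed forms in the CSVs are an assumption). `normalize_csv_path` is a hypothetical stand-in that uses `PurePosixPath` so it behaves the same on any operating system.

```python
from pathlib import PurePosixPath
from typing import Optional


def normalize_csv_path(path: Optional[str]) -> Optional[str]:
    """Mirror of fix_path: clean the string, then keep folder/subfolder/file."""
    if path is None:
        return None
    cleaned = path.strip().replace("\\", "/").replace('"', "")
    parts = PurePosixPath(cleaned).parts
    if len(parts) < 4:
        return cleaned  # Too short to restructure; return the cleaned string
    return f"raw/{parts[0]}/{parts[1]}/{parts[-1]}"


example = (
    "Calc-Test_P_00038_LEFT_CC_1/"
    "1.3.6.1.4.1.9590.100.1.2.161465562211359959230647609981488894942/"
    "1.3.6.1.4.1.9590.100.1.2.419081637812053404913157930753972718515/"
    "000001.dcm\n"
)
normalized = normalize_csv_path(example)
print(normalized)
```

The result keeps the case folder, the first UID folder, and the file name. These UID-based names differ from the date-based folder names found under data/raw/, so the normalized columns appear to serve as identifiers rather than as paths that can be opened directly.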


### Merge Image–Mask Metadata and Clinical Metadata

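- Merge paired image–mask metadata with clinical metadata on patient_id, laterality, and view, using a left join so that every paired image row is kept. abnormality_id is not a join key, so an image with several abnormalities appears once per clinical record.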

- Args:
  - paired_metadata_path (str): Path to the paired image-mask metadata CSV.
  - clinical_metadata_path (str): Path to the consolidated clinical metadata CSV.
  - output_path (str): Path to save the merged master metadata CSV.

- Returns:
  - pd.DataFrame: Final merged metadata DataFrame.

In [None]:
import pandas as pd
from pathlib import Path

def merge_metadata(
    paired_metadata_path: str,
    clinical_metadata_path: str,
    output_path: str
) -> pd.DataFrame:

    # Load CSVs
    paired_df = pd.read_csv(paired_metadata_path)
    clinical_df = pd.read_csv(clinical_metadata_path)

    # Normalize column names if needed
    paired_df = paired_df.rename(columns={
        'side': 'laterality'  # If needed (depending on how it was named)
    })

    # Remove 'P_' prefix from clinical metadata patient IDs (literal match, not a regex)
    clinical_df['patient_id'] = clinical_df['patient_id'].astype(str).str.replace('P_', '', regex=False).str.zfill(5)

    # Paired metadata: ensure patient_id is 5-digit zero-padded string
    paired_df['patient_id'] = paired_df['patient_id'].astype(str).str.zfill(5)

    # Ensure consistent casing for joining keys
    paired_df['laterality'] = paired_df['laterality'].astype(str).str.upper()
    paired_df['view'] = paired_df['view'].astype(str).str.upper()
    clinical_df['side'] = clinical_df['side'].astype(str).str.upper()
    clinical_df['view'] = clinical_df['view'].astype(str).str.upper()

    # Merge: LEFT JOIN, keep all paired image/mask metadata
    merged = pd.merge(
        paired_df,
        clinical_df,
        how="left",
        left_on=["patient_id", "laterality", "view"],
        right_on=["patient_id", "side", "view"]
    )

    # Drop duplicate columns (like "side" from clinical metadata)
    if "side" in merged.columns:
        merged = merged.drop(columns=["side"])

    # Save merged DataFrame
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    merged.to_csv(output_path, index=False)

    print(f"[INFO] Master merged metadata saved at: {output_path}")
    print(f"[INFO] Merged metadata shape: {merged.shape}")

    return merged

In [9]:
paired_metadata_path = "../data/processed/cbis_ddsm_metadata_paired.csv"
clinical_metadata_path = "../data/metadata/metadata_master.csv"
output_path = "../data/processed/cbis_ddsm_metadata_full.csv"
merged_metadata = merge_metadata(paired_metadata_path, clinical_metadata_path, output_path)

# Preview
print(merged_metadata.head())

[INFO] Master merged metadata saved at: ../data/processed/cbis_ddsm_metadata_full.csv
[INFO] Merged metadata shape: (3751, 21)
  abnormality_type_x patient_id laterality view  \
0               Mass      00066       LEFT   CC   
1               Calc      02176      RIGHT  MLO   
2               Calc      00418       LEFT   CC   
3               Mass      01307      RIGHT  MLO   
4               Mass      00488       LEFT   CC   

                                                path  \
0  ../data/raw/Mass-Test_P_00066_LEFT_CC/10-04-20...   
1  ../data/raw/Calc-Test_P_02176_RIGHT_MLO/08-29-...   
2  ../data/raw/Calc-Training_P_00418_LEFT_CC/08-0...   
3  ../data/raw/Mass-Test_P_01307_RIGHT_MLO/10-04-...   
4  ../data/raw/Mass-Training_P_00488_LEFT_CC/07-2...   

                                          image_path  \
0  ../data/raw/Mass-Test_P_00066_LEFT_CC/10-04-20...   
1  ../data/raw/Calc-Test_P_02176_RIGHT_MLO/08-29-...   
2  ../data/raw/Calc-Training_P_00418_LEFT_CC/08-0...   
3  ..
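
The merged frame has more rows (3751) than there are full mammograms (3103), which is what a left join produces when one image matches several clinical records. The sketch below shows how `indicator=True` can be used to check both effects, unmatched images and row multiplication, on toy data. All values in the two frames are made up for illustration.

```python
import pandas as pd

# Toy stand-ins for the paired and clinical metadata (values are invented)
paired = pd.DataFrame({
    "patient_id": ["00038", "00066", "09999"],
    "laterality": ["LEFT", "LEFT", "RIGHT"],
    "view": ["CC", "CC", "MLO"],
})
clinical = pd.DataFrame({
    "patient_id": ["00038", "00038", "00066"],
    "side": ["LEFT", "LEFT", "LEFT"],
    "view": ["CC", "CC", "CC"],
    "abnormality_id": [1, 2, 1],
})

checked = paired.merge(
    clinical,
    how="left",
    left_on=["patient_id", "laterality", "view"],
    right_on=["patient_id", "side", "view"],
    indicator=True,  # adds a "_merge" column: "both" or "left_only"
)

# Images with no clinical record at all
unmatched = checked[checked["_merge"] == "left_only"]

# Number of merged rows per image; values above 1 mean several abnormalities
rows_per_image = checked.groupby(["patient_id", "laterality", "view"]).size()

print(len(checked), len(unmatched), rows_per_image.max())
```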

### Build TensorFlow Dataset

In [10]:
# pip uninstall tensorflow -y
# python3 -m venv .venv
# source .venv/bin/activate
# pip install tensorflow==2.15.0
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# pip install pandas pydicom

In [11]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [12]:
import tensorflow as tf

print("TensorFlow version:", tf.__version__)
print("Num GPUs Available:", len(tf.config.list_physical_devices('GPU')))
print("GPU Devices:", tf.config.list_physical_devices('GPU'))


TensorFlow version: 2.15.0
Num GPUs Available: 2
GPU Devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]


In [13]:
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA device count: 2
Current device: 0
Device name: NVIDIA RTX 6000 Ada Generation


In [15]:
import pandas as pd
import tensorflow as tf
import numpy as np
import pydicom
from pathlib import Path
import ast

# Configuration
IMG_SIZE = (512, 512)  # Resize target size
AUTOTUNE = tf.data.AUTOTUNE

# --- Helper Functions ---

def load_dicom_image(path_tensor):
    """
    Load and normalize a DICOM image from a byte string path.
    """
    path = path_tensor.decode('utf-8')  # Decode byte string to UTF-8
    try:
        ds = pydicom.dcmread(path)
        img = ds.pixel_array.astype(np.float32)
        img -= np.min(img)
        img /= (np.max(img) + 1e-6)  # normalize to [0,1]
    except Exception as e:
        print(f"Failed to load DICOM file: {path} with error: {e}")
        img = np.zeros(IMG_SIZE, dtype=np.float32)  # blank fallback so the pipeline keeps running
    return img

def load_and_preprocess_image(image_path: tf.Tensor) -> tf.Tensor:
    """
    Load and preprocess a single full mammogram image.
    """
    img = tf.numpy_function(load_dicom_image, [image_path], tf.float32)
    img.set_shape([None, None])  # 2D
    img = tf.expand_dims(img, axis=-1)  # Make it [H, W, 1]
    img.set_shape([None, None, 1])
    img = tf.image.resize(img, IMG_SIZE)
    return img

def load_and_preprocess_mask(mask_paths: tf.Tensor) -> tf.Tensor:
    """
    Load and preprocess multiple ROI masks and combine into a single mask tensor.
    """
    def load_single_mask(path):
        mask = tf.numpy_function(load_dicom_image, [path], tf.float32)
        mask.set_shape([None, None])
        mask = tf.expand_dims(mask, axis=-1)
        mask.set_shape([None, None, 1])
        mask = tf.image.resize(mask, IMG_SIZE)
        return mask

    masks = tf.map_fn(
        load_single_mask,
        mask_paths,
        fn_output_signature=tf.TensorSpec(shape=(IMG_SIZE[0], IMG_SIZE[1], 1), dtype=tf.float32)
    )

    # Combine multiple masks into a single one
    combined_mask = tf.reduce_max(masks, axis=0)
    return combined_mask

def parse_record(record):
    """
    Parse a dictionary record into (image, mask) tensors.
    """
    image_path = record['image_path']
    mask_paths = record['mask_paths']

    img = load_and_preprocess_image(image_path)
    mask = load_and_preprocess_mask(mask_paths)

    return img, mask

# --- Main Dataset Builder ---

def build_tf_dataset(
    metadata_csv: str,
    batch_size: int = 8,
    shuffle: bool = True
) -> tf.data.Dataset:
    """
    Build a tf.data.Dataset of (image, mask) batches from cbis_ddsm_metadata_full.csv
    """
    # Load metadata
    df = pd.read_csv(metadata_csv)

    # Parse stringified list of mask_paths
    df['mask_paths'] = df['mask_paths'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else [])

    # Drop rows without masks: reducing over zero masks would not yield a valid mask
    df = df[df['mask_paths'].map(len) > 0]

    # Convert dataframe to list of dicts
    records = df[['image_path', 'mask_paths']].to_dict(orient='records')

    # Build tf.data.Dataset
    ds = tf.data.Dataset.from_generator(
        lambda: (r for r in records),
        output_signature={
            "image_path": tf.TensorSpec(shape=(), dtype=tf.string),
            "mask_paths": tf.TensorSpec(shape=(None,), dtype=tf.string),
        }
    )

    # Shuffle the path records before decoding, so the shuffle buffer
    # holds short strings rather than decoded images
    if shuffle:
        ds = ds.shuffle(buffer_size=max(len(records), 1))

    ds = ds.map(parse_record, num_parallel_calls=AUTOTUNE)

    ds = ds.batch(batch_size).prefetch(AUTOTUNE)

    return ds

# Build dataset
train_ds = build_tf_dataset(
    metadata_csv="../data/processed/cbis_ddsm_metadata_full.csv",
    batch_size=8
)

# Preview one batch
for images, masks in train_ds.take(1):
    print(f"Images batch shape: {images.shape}")  # (8, 512, 512, 1)
    print(f"Masks batch shape: {masks.shape}")    # (8, 512, 512, 1)

Images batch shape: (8, 512, 512, 1)
Masks batch shape: (8, 512, 512, 1)
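
The `tf.reduce_max` step in `load_and_preprocess_mask` amounts to a pixel-wise union of all masks belonging to one image. The NumPy-only sketch below shows the same operation on two toy 4×4 masks; the arrays are invented for illustration.

```python
import numpy as np

# Two toy binary masks covering different corners of a 4x4 image
mask_a = np.zeros((4, 4), dtype=np.float32)
mask_a[0:2, 0:2] = 1.0
mask_b = np.zeros((4, 4), dtype=np.float32)
mask_b[2:4, 2:4] = 1.0

# Stack to shape (num_masks, H, W) and take the maximum over the mask axis
stacked = np.stack([mask_a, mask_b], axis=0)
combined = stacked.max(axis=0)

print(combined)
```

A pixel is set in the combined mask if it is set in any individual mask, so separate lesions on the same image end up in a single target.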


The resulting train_ds yields batches of (image, mask) pairs that are normalized to [0, 1], resized to 512×512, shuffled, and batched, ready for model training.
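
The code above builds a single dataset. If a validation set is needed, one option is to split the metadata records by patient before building the datasets, so that images of the same patient never land on both sides. The sketch below is one way to do that under stated assumptions: `split_by_patient`, the 20% validation fraction, and the seed are all hypothetical choices, and the toy records are invented.

```python
import random
from typing import Dict, List, Tuple


def split_by_patient(
    records: List[Dict[str, str]],
    val_fraction: float = 0.2,
    seed: int = 42,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
    """Split records into (train, val) so that no patient appears in both."""
    patient_ids = sorted({r["patient_id"] for r in records})
    rng = random.Random(seed)  # Local RNG keeps the split reproducible
    rng.shuffle(patient_ids)
    n_val = max(1, int(round(len(patient_ids) * val_fraction)))
    val_ids = set(patient_ids[:n_val])
    train = [r for r in records if r["patient_id"] not in val_ids]
    val = [r for r in records if r["patient_id"] in val_ids]
    return train, val


# Toy records: five patients, two views each (invented values)
toy_records = [
    {"patient_id": f"{i:05d}", "view": view}
    for i in range(1, 6)
    for view in ("CC", "MLO")
]
train_records, val_records = split_by_patient(toy_records)
print(len(train_records), len(val_records))
```

CBIS-DDSM folder names already distinguish `Training` from `Test` cases, so this kind of split would typically be applied within the training cases only.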