In [1]:
import os
import glob
import numpy as np
import pandas as pd
import torch
import SimpleITK as sitk
from tqdm import tqdm
from monai.networks.nets import BasicUNet
from monai.losses import DiceCELoss
from monai.data import Dataset, DataLoader
from monai.transforms import (
    Compose, EnsureChannelFirstd, ScaleIntensityd, ToTensord, ResizeWithPadOrCropd
)
from torch.nn.functional import one_hot
from torch.optim import Adam

# ==== USER PATHS ====
# Defaults are this machine's locations. Set the matching environment variable
# to point the notebook at other data without editing the code.
KTRANS_DIR = os.environ.get(
    "PROSTATEX_KTRANS_DIR", r"C:\Users\anude\Downloads\ProstateXKtrains-train-fixed"
)
DICOM_ROOT = os.environ.get(
    "PROSTATEX_DICOM_ROOT", r"C:\Users\anude\OneDrive\Desktop\project"
)
FINDINGS_CSV = os.environ.get(
    "PROSTATEX_FINDINGS_CSV",
    r"C:\Users\anude\OneDrive\Desktop\project\ProstateX-Findings-Train100.csv",
)
IMAGES_CSV = os.environ.get(
    "PROSTATEX_IMAGES_CSV",
    r"C:\Users\anude\OneDrive\Desktop\project\ProstateX-Images-Train100.csv",
)
# The metadata CSVs are read and merged once, in the "READ CSVs" cell that follows.

In [3]:
# ==== READ CSVs ====
# Load the findings table and the per-image metadata table, then inner-join them
# on the patient ID and finding ID so each image row carries its finding's fields.
print("Loading metadata CSVs...")
findings_df = pd.read_csv(FINDINGS_CSV)
images_df = pd.read_csv(IMAGES_CSV)
meta_df = images_df.merge(findings_df, how="inner", on=["ProxID", "fid"])
print("Metadata loaded and merged.")


Loading metadata CSVs...
Metadata loaded and merged.


In [5]:
# ==== UTILS ====
def load_dicom_volume(proxid, series_description="ADC"):
    """Load the first DICOM series for a patient whose folder path matches a description.

    Searches ``DICOM_ROOT/<proxid>`` recursively, visiting directories in sorted
    order. It picks the first directory whose path *below the patient folder*
    contains ``series_description`` (case-insensitive) and that holds ``*.dcm`` files.

    Args:
        proxid: Patient ID; also the name of the patient's folder under ``DICOM_ROOT``.
        series_description: Case-insensitive substring to look for in the
            directory path, e.g. ``"ADC"``.

    Returns:
        The series as read by ``sitk.ImageSeriesReader``.

    Raises:
        RuntimeError: If no matching directory containing ``*.dcm`` files is found.
    """
    patient_dir = os.path.join(DICOM_ROOT, proxid)
    needle = series_description.lower()
    found_dicom_files = []
    for root, dirs, _files in os.walk(patient_dir):
        # os.walk yields sub-directories in filesystem order; sorting in place
        # makes the chosen series reproducible when several folders match.
        dirs.sort()
        # Match only the part of the path below the patient folder, so text in
        # DICOM_ROOT or the patient ID itself cannot produce a false match.
        if needle not in os.path.relpath(root, patient_dir).lower():
            continue
        dicom_files = sorted(glob.glob(os.path.join(root, "*.dcm")))
        if not dicom_files:
            continue
        # Order slices by DICOM geometry (via GDCM) rather than by file name,
        # because file-name order is not guaranteed to follow slice position.
        # Fall back to name order if GDCM does not recognise a series here.
        gdcm_sorted = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(root)
        found_dicom_files = list(gdcm_sorted) or dicom_files
        break
    if not found_dicom_files:
        raise RuntimeError(f"No DICOM series found for {proxid} matching '{series_description}'")
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(found_dicom_files)
    return reader.Execute()

def load_ktrans(proxid):
    """Read a patient's Ktrans volume from ``KTRANS_DIR/<proxid>/<proxid>-Ktrans.mhd``.

    Args:
        proxid: Patient ID; names both the sub-folder and the file prefix.

    Returns:
        The image returned by ``sitk.ReadImage``.

    Raises:
        FileNotFoundError: If the expected ``.mhd`` file does not exist.
    """
    mhd_file = os.path.join(KTRANS_DIR, proxid, f"{proxid}-Ktrans.mhd")
    if os.path.exists(mhd_file):
        return sitk.ReadImage(mhd_file)
    raise FileNotFoundError(f"Ktrans not found at {mhd_file}")

def generate_label_mask(image, ijk_list):
    """Build a uint8 mask on ``image``'s grid with 1 at each listed voxel.

    Args:
        image: Reference 3-D ``SimpleITK.Image``. The mask takes its size,
            origin, spacing and direction from it.
        ijk_list: Iterable of ``(i, j, k)`` voxel indices in SimpleITK index
            order (i along x, j along y, k along z). Entries that cannot be
            converted to ``int``, or that lie outside the image, are skipped
            with a printed warning.

    Returns:
        A ``SimpleITK.Image`` of uint8 holding 1 at each valid voxel and 0 elsewhere.

    Raises:
        ValueError: If ``image`` is not 3-D.
    """
    if image.GetDimension() != 3:
        raise ValueError(f"Expected a 3-D image, got {image.GetDimension()}-D")
    # GetSize() is (x, y, z); numpy arrays from SimpleITK are indexed (z, y, x).
    # Reading the size avoids copying the whole volume just to learn its shape.
    shape = tuple(reversed(image.GetSize()))
    label = np.zeros(shape, dtype=np.uint8)
    for i, j, k in ijk_list:
        try:
            index = (int(k), int(j), int(i))
        except (TypeError, ValueError) as e:
            print(f"\u26a0\ufe0f  Skipping non-integer ijk {(i, j, k)} — Reason: {e}")
            continue
        # Explicit bounds check: numpy would silently wrap a negative index to
        # the opposite edge of the volume instead of rejecting it.
        if not all(0 <= n < size for n, size in zip(index, shape)):
            print(f"\u26a0\ufe0f  Skipping out-of-bounds ijk {(i, j, k)} for image size {image.GetSize()}")
            continue
        label[index] = 1
    label_image = sitk.GetImageFromArray(label)
    label_image.CopyInformation(image)
    return label_image

def parse_ijk(ijk_str):
    """Parse an ``"i j k"`` (or ``"i,j,k"``) string into a tuple of three ints.

    Args:
        ijk_str: Voxel-index text, with the values separated by spaces and/or commas.

    Returns:
        ``(i, j, k)`` as ints. Returns ``None``, after printing a warning, if the
        input is not exactly three integer tokens or is not a string at all.
    """
    try:
        tokens = ijk_str.strip().replace(",", " ").split()
        if len(tokens) != 3:
            raise ValueError("Invalid IJK format")
        return tuple(int(token) for token in tokens)
    except Exception as err:
        print(f"\u26a0\ufe0f  Skipping bad ijk string: {ijk_str} — Reason: {err}")
        return None


In [None]:

# ==== CREATE OUTPUT FOLDERS ====
os.makedirs("images", exist_ok=True)
os.makedirs("labels", exist_ok=True)

# ==== BUILD DATA ====
# Series whose voxel grid the lesion labels are drawn on. The same series is
# loaded from DICOM below, so label coordinates and image grid agree.
LABEL_SERIES = "ADC"


def _series_ijk_strings(lesions, series_description):
    """Return the non-null ``ijk`` strings of the ``lesions`` rows for one series.

    NOTE(review): assumes the Images CSV lists one row per (finding, series),
    with ``ijk`` given in that series' own voxel grid and the series named in a
    ``DCMSerDescr`` (or ``Name``) column, as in the public ProstateX CSVs.
    Confirm this against the CSV. If neither column exists, all rows are used.
    """
    for col in ("DCMSerDescr", "Name"):
        if col in lesions.columns:
            in_series = lesions[col].astype(str).str.contains(
                series_description, case=False, regex=False
            )
            return lesions.loc[in_series, "ijk"].dropna().tolist()
    return lesions["ijk"].dropna().tolist()


data = []
# Emoji outside the BMP use \U escapes. A surrogate-pair spelling such as
# "\ud83d\udce6" is two lone surrogates in a Python str and cannot be printed.
print("\n\U0001F4E6 Building dataset...\n")
# groupby hands over each patient's rows once instead of re-filtering meta_df per patient.
for proxid, lesions in tqdm(meta_df.groupby("ProxID", sort=False)):
    try:
        # tqdm.write keeps log lines from breaking the progress bar.
        tqdm.write(f"\n\U0001F50D Processing: {proxid}")
        adc_image = load_dicom_volume(proxid, series_description=LABEL_SERIES)
        tqdm.write("\u2705 ADC loaded")
        ktrans_image = load_ktrans(proxid)
        tqdm.write("\u2705 Ktrans loaded")
        # Resample Ktrans onto the ADC grid so both channels and the label share one grid.
        ktrans_resampled = sitk.Resample(ktrans_image, adc_image)
        tqdm.write("\u2705 Ktrans resampled")
        parsed = (parse_ijk(s) for s in _series_ijk_strings(lesions, LABEL_SERIES))
        # dict.fromkeys drops repeated voxels while keeping first-seen order.
        ijk_coords = list(dict.fromkeys(c for c in parsed if c is not None))
        if not ijk_coords:
            raise ValueError(f"No lesion ijk coordinates found for series '{LABEL_SERIES}'")
        tqdm.write(f"\u2705 Lesions found: {len(ijk_coords)}")
        label_image = generate_label_mask(adc_image, ijk_coords)
        adc_path = f"images/{proxid}_adc.nii.gz"
        ktrans_path = f"images/{proxid}_ktrans.nii.gz"
        label_path = f"labels/{proxid}_label.nii.gz"
        sitk.WriteImage(adc_image, adc_path)
        sitk.WriteImage(ktrans_resampled, ktrans_path)
        sitk.WriteImage(label_image, label_path)
        data.append({"image": [adc_path, ktrans_path], "label": label_path})
        tqdm.write(f"\u2705 Added {proxid} to training set.")
    except Exception as e:
        # Best-effort per patient: report the failure and move on to the next one.
        tqdm.write(f"❌ Failed on {proxid}: {type(e).__name__}: {e}")

print(f"\n\u2705 Total usable training samples: {len(data)}\n")
if len(data) == 0:
    raise ValueError("\u274c No usable training data was found. Please check preprocessing steps.")
