<a href="https://colab.research.google.com/github/spate472/RecreatingRetinaUNet/blob/main/Attempt1_DicomFilesModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Google Drive is mounted here so the dataset folder used below is reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install pyhealth pydicom scikit-image matplotlib
!pip install pandas==2.2.2
!pip install torch torchvision

In [None]:
import pydicom
import pandas
import numpy
import pyhealth
import skimage

# Record the library versions this run used (same order as the imports above).
for module in (pydicom, pandas, numpy, pyhealth, skimage):
    print(module.__version__)

In [None]:
import os
import pydicom


def get_dcm_file_paths(folder_path):
  """Recursively collect the paths of all .dcm files under a folder.

  Entries are visited in sorted name order so the result is deterministic.
  os.listdir on its own returns entries in arbitrary, filesystem-dependent
  order, which would make any "first N files" subset differ between runs.

  Args:
    folder_path: The path to the folder to search.

  Returns:
    A list of the full paths of the .dcm files (extension matched
    case-insensitively), in depth-first, name-sorted order.
  """
  file_paths = []
  for item in sorted(os.listdir(folder_path)):
    item_path = os.path.join(folder_path, item)
    if os.path.isfile(item_path) and item.lower().endswith(".dcm"):
      file_paths.append(item_path)
    elif os.path.isdir(item_path):
      file_paths.extend(get_dcm_file_paths(item_path))  # Recurse into subfolder
  return file_paths

# Root of the DICOM dataset, assuming it's in your mounted Drive.
folder_path = '/content/drive/MyDrive/DLH_Project/RetinaUNetDataset'

# How many files the rest of the notebook processes; set to None to use every file.
MAX_FILES = 100

# Sort so the MAX_FILES subset is the same on every run, whatever order the
# directory listing came back in.
all_file_paths = sorted(get_dcm_file_paths(folder_path))
file_paths = all_file_paths[:MAX_FILES]  # [:None] keeps the whole list

print(f"Found {len(all_file_paths)} .dcm files; using {len(file_paths)}.")

# Print the selected file paths with their index
for index, file_path in enumerate(file_paths):
  print(f"Index: {index}, File Path: {file_path}")

In [None]:
import numpy as np
from skimage.transform import resize
import pydicom

def process_dicom_image(file_path, target_shape=(128, 128)):
    """Load a DICOM file and turn its pixels into a fixed-length feature vector.

    The pixel data is reduced to a single 2-D plane, min-max normalised to
    [0, 1), resized to ``target_shape`` and flattened. Every file that is
    processed therefore yields a vector of length ``prod(target_shape)``.

    Args:
        file_path (str): The path to the DICOM file.
        target_shape (tuple, optional): The desired (rows, cols) of the resized
            image. Defaults to (128, 128).

    Returns:
        numpy.ndarray: The processed image as a flattened 1-D array.
            Returns None if the file does not contain pixel data, or if it holds
            a multi-frame image that cannot be reduced to one 2-D plane.
    """
    ds = pydicom.dcmread(file_path)

    # Check if the DICOM file contains pixel data
    if not hasattr(ds, 'PixelData'):
        print(f"Warning: DICOM file '{file_path}' does not contain pixel data. Skipping.")
        return None

    raw_pixels = ds.pixel_array
    img = raw_pixels.astype(np.float32)

    # Colour data carries a trailing samples axis (..., rows, cols, samples).
    # Average it away so colour and grayscale files both yield prod(target_shape)
    # values. Otherwise resize() preserves the channel axis and the flattened
    # vector is several times longer.
    # NOTE(review): YBR_* colour data is averaged as stored, not converted to RGB first.
    if int(ds.get("SamplesPerPixel", 1)) > 1:
        img = img.mean(axis=-1)

    # Anything still not 2-D is a multi-frame image. There is no single obvious
    # frame to keep, so skip it rather than emit a vector of a different length.
    if img.ndim != 2:
        print(f"Warning: DICOM file '{file_path}' has pixel array shape {raw_pixels.shape}, not a single 2-D frame. Skipping.")
        return None

    # MONOCHROME1 displays its lowest values as white. Invert so that, as with
    # MONOCHROME2, larger values always mean brighter pixels.
    if ds.get("PhotometricInterpretation", "") == "MONOCHROME1":
        img = np.max(img) - img

    # Min-max normalise to [0, 1); the epsilon avoids division by zero on constant images.
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-5)

    # Resize to the requested 2-D shape
    img_resized = resize(img, target_shape)

    # Flatten into a 1-D feature vector
    return img_resized.flatten()

In [None]:
from pyhealth.data import Patient, Visit, Event

def build_patient_from_dicom(file_path, patient_id):
    """Wrap one DICOM image as a single-visit pyhealth Patient.

    The image is converted by ``process_dicom_image``. The resulting vector, as
    a Python list, is stored as the ``event_value`` of one Event named "image".
    That Event sits inside one Visit whose id is ``"<patient_id>-1"``.

    Args:
        file_path: Path of the DICOM file to load.
        patient_id: Identifier used for the Patient, its Visit and the Event.

    Returns:
        The populated Patient (which is also printed), or None when
        ``process_dicom_image`` returns None for this file.
    """
    features = process_dicom_image(file_path)
    if features is None:
        print(f"Skipping file {file_path} due to missing or invalid pixel data.")
        return None

    patient = Patient(patient_id)
    visit = Visit(visit_id=f"{patient_id}-1", patient_id=patient_id, timestamp=0)

    # NOTE(review): the keyword names passed to Visit/Event (timestamp, event_name,
    # event_value) must be accepted by the installed pyhealth version — confirm.
    image_event = Event(
        visit_id=visit.visit_id,
        patient_id=patient.patient_id,
        event_name="image",
        event_value=features.tolist(),
    )
    visit.add_event(image_event)
    patient.add_visit(visit)

    print(patient)
    return patient

In [None]:
from pyhealth.datasets import SampleEHRDataset

patients = []

for idx, file_path in enumerate(file_paths):
    print(f"Index: {idx}, Processing {file_path}")
    pid = f"patient{idx}"
    patient = build_patient_from_dicom(file_path, pid)
    # build_patient_from_dicom returns None for files it could not use; drop those.
    if patient is not None:
        patients.append(patient)

# Fail early with a clear message when nothing usable was loaded.
if not patients:
    raise ValueError(f"None of the {len(file_paths)} DICOM files produced a usable patient record.")

# NOTE(review): pyhealth documents SampleEHRDataset's `samples` as a list of dicts
# (each with at least "patient_id" and "visit_id" keys). `patients` holds Patient
# objects, so these likely need converting to dicts — confirm against the
# installed pyhealth version.
dataset = SampleEHRDataset(dataset_name="dicom_dataset", samples=patients)

In [None]:
def dummy_label_fn(patient):
    """Placeholder binary label: 1 if the patient id contains "positive", else 0.

    NOTE(review): the ids built in this notebook have the form "patient<idx>",
    which never contain "positive", so every label comes out as 0. A real label
    source is needed before training is meaningful.
    """
    return int("positive" in patient.patient_id)

# NOTE(review): confirm the installed pyhealth exposes `get_patient_prediction_task`
# on this dataset; in pyhealth 1.x tasks are usually attached via `set_task(task_fn)`.
task_config = {
    "task_name": "dummy_binary_classification",
    "label_fn": dummy_label_fn,
    "feature_keys": ["image"],
}
task = dataset.get_patient_prediction_task(**task_config)


In [None]:
from pyhealth.models import RNN  # or MLP for static features
from pyhealth.trainer import Trainer

# NOTE(review): check these calls against the installed pyhealth version. In
# pyhealth 1.x, RNN also takes a `mode` argument, and Trainer.train() expects
# train/val dataloaders rather than a dataset/task given to the constructor.
model = RNN(
    dataset=dataset,
    feature_keys=["image"],
    label_key="label",
)

trainer = Trainer(task=task, dataset=dataset, model=model)
trainer.train()
