# Notebook project about testing SAM to segment Cardiac MRI scans

## Google Colab configuration

In [None]:
# Clone the project repository into the current working directory
!git clone https://github.com/Silvano315/Med-Physics.git

In [None]:
# Change working directory to the cloned repository.
# Guarded so that re-running this cell does not fail once we are already inside it.

import os

if os.path.basename(os.getcwd()) != "Med-Physics":
    os.chdir("Med-Physics")
os.getcwd()

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read from this runtime
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy your Kaggle API to /root/.config/kaggle and /root/.kaggle/kaggle.json

os.makedirs('/root/.kaggle', exist_ok = True)

!cp /content/drive/MyDrive/Kaggle_api/kaggle.json /root/.config/kaggle.json
!cp /content/drive/MyDrive/Kaggle_api/kaggle.json /root/.kaggle/kaggle.json

In [None]:
# Install requirements

!pip install segment_anything

## Import Libraries

In [None]:
# If you're running this repository LOCALLY, RUN this cell:

import os

# Remember where the kernel started so that re-running this cell always lands in the
# same parent directory instead of climbing one level higher on every execution.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(_NOTEBOOK_START_DIR))
os.getcwd()

'/Users/silvanoquarto/Desktop/LAVORO/MEDICAL_PHYSICS/Med-Physics'

In [None]:
# Import libraries 

import kaggle
import urllib.request
from pathlib import Path
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry, SamPredictor
import h5py
import pandas as pd
import re
from pathlib import Path
from typing import List, Dict

## Setup and Configuration

In [9]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

Using device: cpu


In [10]:
# Make sure the local folder that will hold the dataset exists

data_dir = Path('data')
if not data_dir.is_dir():
    data_dir.mkdir()

## Download dataset

In [None]:
def download_kaggle_dataset(dataset_name : str = None, kaggle_url : str = None):
    """
    Download a dataset from Kaggle using the Kaggle API and unzip it into ./data.

    Requires:
    1. Kaggle account
    2. API token (kaggle.json) in ~/.kaggle/
    3. kaggle package installed: pip install kaggle

    Args:
        dataset_name: Human-readable name, used only in the progress message.
            Falls back to `kaggle_url` when omitted.
        kaggle_url: Kaggle dataset identifier passed to the Kaggle client,
            e.g. 'anhoangvo/acdc-dataset'. Required.

    Raises:
        ValueError: If `kaggle_url` is missing or empty.
        Exception: Any error raised while authenticating or downloading is
            re-raised after troubleshooting hints have been printed.
    """
    # Fail early with a precise message instead of handing None to the Kaggle
    # client and then printing credential hints that have nothing to do with it.
    if not kaggle_url:
        raise ValueError("kaggle_url is required, e.g. 'anhoangvo/acdc-dataset'")

    label = dataset_name if dataset_name else kaggle_url

    data_dir = Path('data')
    data_dir.mkdir(exist_ok=True)

    print(f"Downloading {label} dataset from Kaggle...")
    try:
        kaggle.api.authenticate()
        kaggle.api.dataset_download_files(
            kaggle_url,
            path=data_dir,
            unzip=True
        )
        print("Dataset downloaded and extracted successfully!")

    except Exception as e:
        print(f"Error downloading dataset: {e}")
        print("\nPlease ensure you have:")
        print("1. Created a Kaggle account")
        print("2. Generated an API token from https://www.kaggle.com/settings")
        print("3. Placed kaggle.json in ~/.kaggle/")
        print("4. Set appropriate permissions: chmod 600 ~/.kaggle/kaggle.json")
        # Re-raise so the notebook stops here rather than continuing without data
        raise

In [None]:
# Fetch the ACDC dataset from Kaggle (needs the API token configured beforehand)

download_kaggle_dataset(
    dataset_name="ACDC",
    kaggle_url='anhoangvo/acdc-dataset',
)

Downloading ACDC dataset from Kaggle...
Dataset URL: https://www.kaggle.com/datasets/anhoangvo/acdc-dataset
Dataset downloaded and extracted successfully!


## Cardiac MRI Segmentation with SAM - Exploratory Analysis

In [13]:
# Set up SAM Model

def setup_sam():
    """Initialize and load the SAM model (vit_b checkpoint).

    Downloads the checkpoint into the current working directory when it is not
    already present, builds the model through `sam_model_registry`, and moves
    it to the module-level `DEVICE`.

    Returns:
        The model object produced by `sam_model_registry["vit_b"]`.
    """

    sam_checkpoint = "sam_vit_b_01ec64.pth"
    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

    if not os.path.exists(sam_checkpoint):
        print("Downloading SAM checkpoint...")
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that the existence
        # check above would accept on the next run.
        partial_path = sam_checkpoint + ".part"
        try:
            urllib.request.urlretrieve(checkpoint_url, partial_path)
            os.replace(partial_path, sam_checkpoint)
        finally:
            # After a successful rename this is a no-op; after a failure it cleans up
            if os.path.exists(partial_path):
                os.remove(partial_path)

    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=DEVICE)

    return sam

In [16]:
# Build the SAM model once, then wrap it in a predictor
sam = setup_sam()
predictor = SamPredictor(sam_model=sam)

Downloading SAM checkpoint...


  state_dict = torch.load(f)


## Data Loading & Preprocessing

In [None]:
def list_h5_files(data_dir: Path, subset: str = 'training'):
    """
    List all H5 files of one dataset subset, in sorted (deterministic) order.

    Args:
        data_dir: Base directory containing the dataset (Path or string)
        subset: Fragment of the subset directory name. When it contains
            'training' (e.g. 'training_volumes', 'training_slices') the
            directory name must end with ``_<subset>``; otherwise
            (e.g. 'testing') the directory name must contain ``_<subset>_``.

    Returns:
        Sorted list of paths to the ``.h5`` files found directly inside the
        matching directories, at any depth below ``data_dir``. Empty when
        nothing matches.
    """
    if 'training' in subset:
        pattern = f"**/*_{subset}/*.h5"
    else:
        pattern = f"**/*_{subset}_*/*.h5"

    # glob() yields files in filesystem order, which varies between machines;
    # sorting makes positional access such as files[0] reproducible.
    return sorted(Path(data_dir).glob(pattern))

In [None]:
# Collect the H5 files of each dataset split
training_volumes_files = list_h5_files(data_dir, 'training_volumes')
training_slices_files = list_h5_files(data_dir, 'training_slices')
testing_volumes_files = list_h5_files(data_dir, 'testing')

print(
    f"Found {len(training_slices_files)} training slices "
    f"and {len(training_volumes_files)} training volumes files "
    f"and {len(testing_volumes_files)} testing volumes files"
)

In [None]:
def create_dataset_info(file_lists: Dict[str, List[Path]]) -> pd.DataFrame:
    """
    Create a DataFrame with information about the dataset files.

    File names are parsed with the pattern
    ``patient<digits>_frame<digits>[_slice_<digits>]``; files whose name does
    not contain that pattern are left out.

    Args:
        file_lists: Dictionary mapping a type label (e.g. 'training_volumes',
                   'training_slices', 'testing_volumes') to a list of paths

    Returns:
        DataFrame with columns: patient_id, frame, slice (missing when the
        file name has no slice part), type, path. Rows are sorted by
        patient_id, frame, slice. When no file matches, an empty DataFrame
        with those same columns is returned.
    """
    columns = ['patient_id', 'frame', 'slice', 'type', 'path']
    pattern = re.compile(r'patient(\d+)_frame(\d+)(?:_slice_(\d+))?')

    all_data = []
    for data_type, files in file_lists.items():
        for file_path in files:
            match = pattern.search(Path(file_path).name)
            if match is None:
                # Name does not follow the expected convention: leave the file out
                continue

            patient_id, frame, slice_num = match.groups()
            all_data.append({
                'patient_id': int(patient_id),
                'frame': int(frame),
                # The slice group is optional; volume files do not have it
                'slice': int(slice_num) if slice_num is not None else None,
                'type': data_type,
                'path': str(file_path)
            })

    # Naming the columns explicitly keeps the sort below valid when no file
    # matched (an unlabelled empty frame would raise KeyError).
    df = pd.DataFrame(all_data, columns=columns)
    df = df.sort_values(['patient_id', 'frame', 'slice'])

    return df

In [None]:
# Group the discovered files by split, then tabulate them
file_lists = dict(
    training_volumes=training_volumes_files,
    training_slices=training_slices_files,
    testing_volumes=testing_volumes_files,
)

dataset_df = create_dataset_info(file_lists)
print("\nDataset Overview:")
dataset_df

In [None]:
# Headline numbers: total files, files per type, number of distinct patients
print("Dataset Statistics:")
print(f"Total files: {len(dataset_df)}")
print("\nFiles per type:", dataset_df['type'].value_counts(), sep="\n")
print("\nUnique patients:", dataset_df.loc[:, 'patient_id'].nunique())

In [None]:
def load_h5_data(file_path: str, verbose: bool = True):
    """
    Load the 'image', 'label' and 'scribble' entries stored in an H5 file.

    Args:
        file_path: Path to H5 file
        verbose: When True (the default, matching the previous behaviour),
            print the keys available in the file. Pass False to keep the
            output quiet when loading many files in a loop.

    Returns:
        Dictionary containing image and mask data: one entry for each of
        'image', 'label', 'scribble' that exists in the file. Keys missing
        from the file are missing from the dictionary; any other keys in the
        file are ignored.
    """
    wanted_keys = ('image', 'label', 'scribble')

    with h5py.File(file_path, 'r') as f:
        if verbose:
            print(f"Available keys in {Path(file_path).name}:", list(f.keys()))

        # `[:]` reads each entry out of the file while it is still open, so the
        # returned values remain usable after the `with` block closes it.
        return {key: f[key][:] for key in wanted_keys if key in f}

In [None]:
# Load one training volume as a sanity check.
# Fail with a clear message instead of a NameError on `sample_data` further down.
assert training_volumes_files, "No training volume files found - run the dataset download cell first"
sample_data = load_h5_data(training_volumes_files[0])

# Show shape and dtype of every loaded entry rather than dumping raw arrays;
# this also works when a key such as 'scribble' is absent from the file.
{name: (array.shape, array.dtype) for name, array in sample_data.items()}