In [None]:
import os
import shutil
import torch
import numpy as np
import SimpleITK as sitk
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

# ==============================================================================
# 1. CONFIGURATION & AUTOMATIC PATH DISCOVERY
# ==============================================================================
# Root directory holding one sub-folder per DICOM series.
SERIES_ROOT_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series'

# Destination folder for the resampled NIfTI volumes.
OUTPUT_DIR = "/kaggle/working/resampled_nifti"

# --- AUTOMATIC SERIES DISCOVERY ---
# Every immediate sub-directory of SERIES_ROOT_DIR is treated as one series.
# The full paths are sorted so chunk boundaries are reproducible across runs.
print(f"Scanning for series directories in: {SERIES_ROOT_DIR}...")
batch_series_paths = sorted(
    filter(
        os.path.isdir,
        (os.path.join(SERIES_ROOT_DIR, entry) for entry in os.listdir(SERIES_ROOT_DIR)),
    )
)

if len(batch_series_paths) == 0:
    raise FileNotFoundError(f"No series directories found in {SERIES_ROOT_DIR}. Please check the path.")

print(f"Found {len(batch_series_paths)} unique series to process.")


# ==============================================================================
# 2. THE DATASET CLASS
# ==============================================================================

class DicomSeriesDataset(Dataset):
    """Load DICOM series, resample them to a fixed voxel spacing, and cache as NIfTI.

    Each item corresponds to one series directory. When ``save_dir`` is given,
    the resampled volume is written to ``<save_dir>/<series_name>.nii.gz`` and
    re-read on later accesses, so an interrupted run can be restarted without
    redoing series that were already finished.

    Args:
        series_paths: Sequence of series directory paths.
        new_spacing: Output voxel spacing handed to SimpleITK's resampler
            (SimpleITK orders spacing as x, y, z).
        transform: Optional callable applied to the output tensor.
        save_dir: Optional cache directory for the resampled NIfTI files;
            created if it does not exist.

    ``__getitem__`` returns ``(tensor, series_name)``. ``tensor`` is float32
    with a leading axis of size 1 prepended to the array returned by
    ``sitk.GetArrayFromImage``. Series that fail to load yield a 1x1x1
    placeholder image and are NOT written to ``save_dir``.
    """

    def __init__(self, series_paths, new_spacing=(1, 1, 1), transform=None, save_dir=None):
        self.series_paths = series_paths
        self.new_spacing = new_spacing
        self.transform = transform
        self.save_dir = save_dir
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def __len__(self):
        return len(self.series_paths)

    def __getitem__(self, idx):
        series_path = self.series_paths[idx]
        # normpath strips a trailing separator so basename is the folder name.
        series_name = os.path.basename(os.path.normpath(series_path))

        if self.save_dir:
            save_path = os.path.join(self.save_dir, f"{series_name}.nii.gz")
            resampled_img = self._read_cached(save_path)
            if resampled_img is None:
                resampled_img = self._load_and_resample_dicom(series_path, save_path)
        else:
            resampled_img = self._load_and_resample_dicom(series_path)

        arr = sitk.GetArrayFromImage(resampled_img).astype(np.float32)
        arr = np.expand_dims(arr, axis=0)  # add a leading channel axis
        tensor = torch.from_numpy(arr)

        if self.transform:
            tensor = self.transform(tensor)

        return tensor, series_name

    @staticmethod
    def _read_cached(save_path):
        """Return the cached image at ``save_path``, or None if it is absent or unreadable.

        An unreadable file (e.g. left truncated by a previous crash) is treated
        as a cache miss so the series gets regenerated instead of aborting the
        whole DataLoader.
        """
        if not os.path.exists(save_path):
            return None
        try:
            return sitk.ReadImage(save_path)
        except RuntimeError as e:  # SimpleITK reports I/O failures as RuntimeError
            print(f"Warning: cached file {save_path} is unreadable ({e}); regenerating.")
            return None

    @staticmethod
    def _output_size(original_size, original_spacing, new_spacing):
        """Return the per-axis voxel count covering the original physical extent.

        Each axis is clamped to at least 1 so that a very thin axis (e.g. a
        single slice thinner than half the target spacing) does not round to
        an empty dimension.
        """
        return [
            max(1, int(round(size * spacing / target)))
            for size, spacing, target in zip(original_size, original_spacing, new_spacing)
        ]

    @staticmethod
    def _write_atomically(image, save_path):
        """Write ``image`` to ``save_path`` through a temporary file and a rename.

        If the process dies mid-write only the temporary file is left behind,
        so the existence check in ``__getitem__`` never sees a partial file.
        """
        directory, filename = os.path.split(save_path)
        # Keep the original extension so SimpleITK still selects the NIfTI writer.
        tmp_path = os.path.join(directory, f"tmp_{filename}")
        try:
            sitk.WriteImage(image, tmp_path)
            os.replace(tmp_path, save_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def _load_and_resample_dicom(self, series_path, save_path=None):
        """Read the DICOM series in ``series_path`` and resample it to ``self.new_spacing``.

        The output keeps the input's origin and direction. If ``save_path`` is
        given, the result is also written there. This is a deliberate
        best-effort step: any failure is printed and a 1x1x1 float32
        placeholder is returned (and not saved).
        """
        try:
            reader = sitk.ImageSeriesReader()
            dicom_names = reader.GetGDCMSeriesFileNames(series_path)
            if not dicom_names:
                print(f"Warning: No DICOM files found in {series_path}")
                return sitk.Image(1, 1, 1, sitk.sitkFloat32)
            reader.SetFileNames(dicom_names)
            sitk_img = reader.Execute()

            if sitk_img.GetDimension() == 4:
                # Keep only the first volume: a size of 0 on axis 3 makes
                # Extract collapse that axis, yielding a 3-D image.
                size = list(sitk_img.GetSize())
                size[3] = 0
                index = [0, 0, 0, 0]
                sitk_img = sitk.Extract(sitk_img, size, index)

            sitk_img = sitk.Cast(sitk_img, sitk.sitkFloat32)
            new_size = self._output_size(sitk_img.GetSize(), sitk_img.GetSpacing(), self.new_spacing)

            resample_filter = sitk.ResampleImageFilter()
            resample_filter.SetDefaultPixelValue(0)
            resample_filter.SetOutputSpacing(self.new_spacing)
            resample_filter.SetSize(new_size)
            resample_filter.SetOutputDirection(sitk_img.GetDirection())
            resample_filter.SetOutputOrigin(sitk_img.GetOrigin())
            resample_filter.SetInterpolator(sitk.sitkLinear)
            resampled_img = resample_filter.Execute(sitk_img)

            if save_path:
                self._write_atomically(resampled_img, save_path)

            return resampled_img
        except Exception as e:
            print(f"Error processing {series_path}: {e}")
            return sitk.Image(1, 1, 1, sitk.sitkFloat32)


# ==============================================================================
# 3. MAIN PROCESSING SCRIPT WITH CHUNKING LOGIC
# ==============================================================================

# --- CONFIGURATION ---
CHUNK_SIZE = 100
TOTAL_FILES = len(batch_series_paths)
NUM_CHUNKS = -(-TOTAL_FILES // CHUNK_SIZE)  # Ceiling division
# Archives are written next to OUTPUT_DIR (i.e. /kaggle/working) explicitly,
# instead of relying on the notebook's current working directory.
ARCHIVE_DIR = os.path.dirname(os.path.abspath(OUTPUT_DIR))

print(f"Total unique series to process: {TOTAL_FILES}")
print(f"Processing in {NUM_CHUNKS} chunks of up to {CHUNK_SIZE} series each.")
print("-" * 50)


# --- MAIN LOOP ---
for i in range(NUM_CHUNKS):
    start_index = i * CHUNK_SIZE
    end_index = min((i + 1) * CHUNK_SIZE, TOTAL_FILES)

    print(f"\n--- Processing Chunk {i+1}/{NUM_CHUNKS}: Series {start_index} to {end_index-1} ---")

    current_paths_chunk = batch_series_paths[start_index:end_index]

    dataset = DicomSeriesDataset(current_paths_chunk, new_spacing=(1, 1, 1), save_dir=OUTPUT_DIR)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    # Iterating the loader is what triggers loading, resampling and caching to
    # OUTPUT_DIR; the returned tensors themselves are not needed here.
    for _ in tqdm(dataloader, desc=f"Chunk {i+1}"):
        pass

    # Series that fail to load are not written to OUTPUT_DIR, so check the disk
    # rather than assuming every series in the chunk succeeded.
    chunk_names = [os.path.basename(os.path.normpath(p)) for p in current_paths_chunk]
    missing = [
        name for name in chunk_names
        if not os.path.exists(os.path.join(OUTPUT_DIR, f"{name}.nii.gz"))
    ]
    if missing:
        print(f"\n⚠️ Chunk {i+1} finished with {len(missing)} failed series (not saved):")
        for name in missing:
            print(f"   - {name}")
    else:
        print(f"\n✅ Chunk {i+1} processed successfully.")
    print(f"{len(chunk_names) - len(missing)} of {len(chunk_names)} NIfTI files are saved in: {OUTPUT_DIR}")

    zip_filename = f"chunk_{i+1}_series_{start_index}_to_{end_index-1}.zip"
    # make_archive appends ".zip" itself and returns the path it actually wrote.
    archive_path = shutil.make_archive(
        os.path.join(ARCHIVE_DIR, zip_filename[:-len(".zip")]), 'zip', OUTPUT_DIR
    )

    print("\n--- ACTION REQUIRED ---")
    print(f"1. A zip file named '{zip_filename}' has been created in {ARCHIVE_DIR}/.")
    print("2. Download this zip file to your local machine now.")
    print("3. After the download is complete, return to this notebook.")

    if i < NUM_CHUNKS - 1:
        input("4. Press [Enter] here to DELETE the processed files and continue to the next chunk...")

        try:
            shutil.rmtree(OUTPUT_DIR)
            os.remove(archive_path)
            print(f"🧹 Cleaned up {OUTPUT_DIR} and {zip_filename}. Ready for the next chunk.")
        except OSError as e:
            print(f"Error during cleanup: {e}. Please manually delete the folder.")
    else:
        print("\n🎉 All chunks have been processed!")

print("-" * 50)
print("Processing complete.")

Scanning for series directories in: /kaggle/input/rsna-intracranial-aneurysm-detection/series...
Found 4348 unique series to process.
Total unique series to process: 4348
Processing in 44 chunks of up to 100 series each.
--------------------------------------------------

--- Processing Chunk 1/44: Series 0 to 99 ---


Chunk 1:   0%|          | 0/100 [00:00<?, ?it/s]

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:9.92447e-05

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.697059

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000312229

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.00014289

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000746138

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000162543

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:0.000101863

ImageSeriesReader (0x404fc350): Non uniform sampling or missing slices detected,  maximum nonuniformity:9.4387e-05

