diff --git a/scripts/prepare_USGS.py b/scripts/prepare_USGS.py index b0809dd..f8b9c57 100644 --- a/scripts/prepare_USGS.py +++ b/scripts/prepare_USGS.py @@ -19,6 +19,9 @@ and to updating labels (overwriting UBFAI from detection/crops). Use --no-generate-detection-crops or --no-update-labels to skip either step. +Stage 0 parallelizes split_raster across images when multiple CPUs are available +(set SLURM_CPUS_PER_TASK in your sbatch script, or PREPARE_USGS_CROP_WORKERS). + Why new Label Studio annotations might not appear in classification: - Stage 0 (generate detection crops): runs by default; only regenerates a crop CSV when that image's annotation CSV is newer (or the crop is missing). @@ -42,6 +45,7 @@ import os import random import shutil +from concurrent.futures import ProcessPoolExecutor, as_completed import numpy as np import pandas as pd @@ -64,6 +68,37 @@ PATCH_OVERLAP = 0 +def _detection_crop_parallel_workers() -> int: + """Workers for per-image split_raster in Stage 0. + + Prefer the explicit PREPARE_USGS_CROP_WORKERS override, then the SLURM allocation (SLURM_CPUS_PER_TASK), then host CPU count. 
+ """ + for key in ("PREPARE_USGS_CROP_WORKERS", "SLURM_CPUS_PER_TASK"): + raw = os.environ.get(key) + if raw: + return max(1, int(raw)) + return max(1, os.cpu_count() or 1) + + +def _detection_crop_worker(task: tuple) -> str: + """Run split_raster for one image (pickled args; must stay top-level for spawn).""" + root_dir, save_dir, image_path, records, patch_size, patch_overlap = task + import pandas as pd + from src import data_processing + + annotation_df = pd.DataFrame.from_records(records) + data_processing.process_image( + image_path=image_path, + annotation_df=annotation_df, + root_dir=root_dir, + save_dir=save_dir, + patch_size=patch_size, + patch_overlap=patch_overlap, + allow_empty=True, + ) + return image_path + + def parse_args(): parser = argparse.ArgumentParser( description="Prepare USGS detection data for training" @@ -310,14 +345,38 @@ def generate_detection_crops(): continue combined_refresh = combined[combined["image_path"].isin(images_to_refresh)] - data_processing.preprocess_images( - combined_refresh, - root_dir=root_dir, - save_dir=save_dir, - patch_size=PATCH_SIZE, - patch_overlap=PATCH_OVERLAP, - allow_empty=True, - ) + n_workers = min(_detection_crop_parallel_workers(), len(images_to_refresh)) + if n_workers <= 1: + data_processing.preprocess_images( + combined_refresh, + root_dir=root_dir, + save_dir=save_dir, + patch_size=PATCH_SIZE, + patch_overlap=PATCH_OVERLAP, + allow_empty=True, + ) + else: + tasks = [] + for image_path in images_to_refresh: + ann = combined_refresh[combined_refresh["image_path"] == image_path] + tasks.append( + ( + root_dir, + save_dir, + image_path, + ann.to_dict("records"), + PATCH_SIZE, + PATCH_OVERLAP, + ) + ) + print( + f" {flight_name}: parallel split_raster " + f"({len(tasks)} images, {n_workers} workers)" + ) + with ProcessPoolExecutor(max_workers=n_workers) as pool: + futures = [pool.submit(_detection_crop_worker, t) for t in tasks] + for fut in as_completed(futures): + fut.result() print( f" {flight_name}: 
refreshed {len(images_to_refresh)} images (of " f"{combined['image_path'].nunique()} total) -> {save_dir}" diff --git a/submit_prepare_annotations.sh b/submit_prepare_annotations.sh new file mode 100755 index 0000000..4be0d66 --- /dev/null +++ b/submit_prepare_annotations.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# One SLURM job: request many CPUs; prepare_USGS.py parallelizes split_raster per image. +# Run from repo root on the cluster: sbatch submit_prepare_annotations.sh + +#SBATCH --job-name=prep_ann +#SBATCH --account=ewhite +#SBATCH --partition=hpg-b200 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=64GB +#SBATCH --time=24:00:00 +#SBATCH --output=/home/b.weinstein/logs/prep_ann_%j.out +#SBATCH --error=/home/b.weinstein/logs/prep_ann_%j.err + +cd "${SLURM_SUBMIT_DIR}" +uv run python scripts/prepare_USGS.py