Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 67 additions & 8 deletions scripts/prepare_USGS.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
and to updating labels (overwriting UBFAI from detection/crops). Use
--no-generate-detection-crops or --no-update-labels to skip either step.

Stage 0 parallelizes split_raster across images when multiple CPUs are available
(set SLURM_CPUS_PER_TASK in your sbatch script, or PREPARE_USGS_CROP_WORKERS).

Why new Label Studio annotations might not appear in classification:
- Stage 0 (generate detection crops): runs by default; only regenerates a crop
CSV when that image's annotation CSV is newer (or the crop is missing).
Expand All @@ -42,6 +45,7 @@
import os
import random
import shutil
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import pandas as pd
Expand All @@ -64,6 +68,37 @@
PATCH_OVERLAP = 0


def _detection_crop_parallel_workers() -> int:
"""Workers for per-image split_raster in Stage 0.

Prefer SLURM allocation, then explicit override, then host CPU count.
"""
for key in ("PREPARE_USGS_CROP_WORKERS", "SLURM_CPUS_PER_TASK"):
raw = os.environ.get(key)
if raw:
return max(1, int(raw))
return max(1, os.cpu_count() or 1)


def _detection_crop_worker(task: tuple) -> str:
"""Run split_raster for one image (pickled args; must stay top-level for spawn)."""
root_dir, save_dir, image_path, records, patch_size, patch_overlap = task
import pandas as pd
from src import data_processing

annotation_df = pd.DataFrame.from_records(records)
data_processing.process_image(
image_path=image_path,
annotation_df=annotation_df,
root_dir=root_dir,
save_dir=save_dir,
patch_size=patch_size,
patch_overlap=patch_overlap,
allow_empty=True,
)
return image_path


def parse_args():
parser = argparse.ArgumentParser(
description="Prepare USGS detection data for training"
Expand Down Expand Up @@ -310,14 +345,38 @@ def generate_detection_crops():
continue

combined_refresh = combined[combined["image_path"].isin(images_to_refresh)]
data_processing.preprocess_images(
combined_refresh,
root_dir=root_dir,
save_dir=save_dir,
patch_size=PATCH_SIZE,
patch_overlap=PATCH_OVERLAP,
allow_empty=True,
)
n_workers = min(_detection_crop_parallel_workers(), len(images_to_refresh))
if n_workers <= 1:
data_processing.preprocess_images(
combined_refresh,
root_dir=root_dir,
save_dir=save_dir,
patch_size=PATCH_SIZE,
patch_overlap=PATCH_OVERLAP,
allow_empty=True,
)
else:
tasks = []
for image_path in images_to_refresh:
ann = combined_refresh[combined_refresh["image_path"] == image_path]
tasks.append(
(
root_dir,
save_dir,
image_path,
ann.to_dict("records"),
PATCH_SIZE,
PATCH_OVERLAP,
)
)
print(
f" {flight_name}: parallel split_raster "
f"({len(tasks)} images, {n_workers} workers)"
)
with ProcessPoolExecutor(max_workers=n_workers) as pool:
futures = [pool.submit(_detection_crop_worker, t) for t in tasks]
for fut in as_completed(futures):
fut.result()
print(
f" {flight_name}: refreshed {len(images_to_refresh)} images (of "
f"{combined['image_path'].nunique()} total) -> {save_dir}"
Expand Down
17 changes: 17 additions & 0 deletions submit_prepare_annotations.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# One SLURM job: request many CPUs; prepare_USGS.py parallelizes split_raster per image.
# Run from repo root on the cluster: sbatch submit_prepare_annotations.sh

#SBATCH --job-name=prep_ann
#SBATCH --account=ewhite
#SBATCH --partition=hpg-b200
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=64GB
#SBATCH --time=24:00:00
#SBATCH --output=/home/b.weinstein/logs/prep_ann_%j.out
#SBATCH --error=/home/b.weinstein/logs/prep_ann_%j.err

cd "${SLURM_SUBMIT_DIR}"
uv run python scripts/prepare_USGS.py
Loading