In [1]:
from google.colab import drive
drive.mount('/content/drive')

!pip install rasterio

%cd "/content/drive/MyDrive/img-label-correction-SAM/SAM/segment-anything"
!pip install .

import torch
import rasterio
import numpy as np
import pandas as pd
from datetime import datetime, timezone
import os
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import shape, MultiPolygon
from PIL import Image
from rasterio.warp import transform
from rasterio.features import rasterize, shapes
from segment_anything import sam_model_registry, SamPredictor

import logging
logging.getLogger('rasterio._env').setLevel(logging.ERROR)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

Mounted at /content/drive
Collecting rasterio
  Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m112.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4

In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import shutil
import rasterio
import geopandas as gpd
from shapely.geometry import shape, MultiPolygon
from rasterio.features import shapes
from segment_anything import sam_model_registry, SamPredictor
from datetime import datetime, timezone

# --- Configuration ---
# (Your paths remain the same)
data_folder_path = '/content/drive/MyDrive/img-label-correction-SAM/src/Arav_App_v1/data/input/'
bad_geojson_path = data_folder_path+'SAM_Test_RTS_20250909.geojson'
output_folder_path = '/content/drive/MyDrive/img-label-correction-SAM/src/Arav_App_v1/data/output/'
output_images_folder_path = output_folder_path+'images/'
output_metadata_folder_path = output_folder_path+'metadata/'
geojson_output_path = output_metadata_folder_path + "predicted_polygons_sam_box_prompt.geojson"
manifest_path = output_metadata_folder_path+"rts_auto_segmentation_manifest.csv"
missing_uid_list_path = output_metadata_folder_path+'missing_uid_list.txt'
sam_folder_path = '/content/drive/MyDrive/img-label-correction-SAM/SAM/'

# --- Setup (SAM Model, Device etc.) ---
# This setup assumes you have a GPU available and have defined 'device'
# For example: device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda"

models = [
    {
        "basemodel": "sam",
        "name": "sam_huge",
        "config": "vit_h",
        "checkpoint": sam_folder_path+"sam_vit_h_4b8939.pth"
    }
]

# --- MODIFICATION: ACCUMULATE, DON'T APPEND ---
# We will collect all verified records here and write the manifest once at the end.
all_verified_manifest_records = []
all_polygon_records = []

# --- Main Processing Loop ---
for model_info in models:
    basemodel = model_info["basemodel"]
    model_name = model_info["name"]
    config = model_info["config"]
    checkpoint = model_info["checkpoint"]

    sam = sam_model_registry[config](checkpoint=checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    print(f"======= Loaded model: {model_info['name']}")

    level_list=['1','2','3']
    complexity_list=['A','B','C']
    bad_jsondf = gpd.read_file(bad_geojson_path)

    for level in level_list:
        for complexity in complexity_list:
            level_path = f'Level {level}/{complexity}{level}'
            tif_path = os.path.join(data_folder_path, level_path)
            if not os.path.exists(tif_path): continue

            tif_list= [f[:-4] for f in os.listdir(tif_path) if f.endswith('.tif')]
            print(f"--- Processing folder: {level_path} ---")

            for uid in tif_list:
                print(f"Processing {uid}")
                geotiff_path = os.path.join(tif_path, uid + '.tif')
                bad_json_uid_record = bad_jsondf[bad_jsondf['UID'] == uid]

                if bad_json_uid_record.empty:
                    with open(missing_uid_list_path, "a") as f:
                        f.write(f"{uid} {level_path}\n")
                    continue

                with rasterio.open(geotiff_path) as geotiff_reader:
                    allbands_array = geotiff_reader.read()
                    tiff_profile = geotiff_reader.profile.copy()
                    tiff_image = np.transpose(allbands_array, (1, 2, 0))
                    tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]

                    bad_polygon_geom = bad_json_uid_record.iloc[0].geometry
                    delineation_date = bad_json_uid_record.iloc[0].get('BaseMapDate', '') # Safer way to get data

                    minx, miny, maxx, maxy = bad_polygon_geom.bounds
                    top_left = geotiff_reader.index(minx, maxy)
                    bottom_right = geotiff_reader.index(maxx, miny)
                    pad = 0
                    input_box = np.array([
                        max(0, top_left[1] - pad), max(0, top_left[0] - pad),
                        min(tiff_image_rgb.shape[1], bottom_right[1] + pad),
                        min(tiff_image_rgb.shape[0], bottom_right[0] + pad)
                    ])

                    predictor.set_image(tiff_image_rgb)
                    masks, scores, logits = predictor.predict(box=input_box[None, :], multimask_output=True)

                    for i, mask in enumerate(masks):
                        output_fileid = f'{uid}_{i}'
                        png_path = os.path.join(output_images_folder_path, output_fileid + '.png')
                        tif_path_out = os.path.join(output_images_folder_path, output_fileid + '.tif')

                        try:
                            # --- 1. Create and Save PNG ---
                            mask_uint8 = (mask > 0).astype(np.uint8)
                            rgba_image = np.dstack((tiff_image_rgb, np.full(tiff_image_rgb.shape[:2], 255, dtype=np.uint8)))
                            overlay = np.zeros_like(rgba_image)
                            overlay[:, :, 0] = mask_uint8 * 255
                            overlay[:, :, 3] = mask_uint8 * int(255 * 0.3)
                            composited = Image.alpha_composite(Image.fromarray(rgba_image), Image.fromarray(overlay))
                            composited.save(png_path)

                            # --- 2. Create and Save 4-Band TIFF ---
                            rgb_chw = np.transpose(tiff_image_rgb, (2, 0, 1))
                            mask_chw = mask[None, :, :]
                            output_array = np.vstack([rgb_chw, mask_chw])
                            tiff_profile.update(count=4, dtype=rasterio.uint8, driver='GTiff')
                            tiff_profile.pop("nodata", None)
                            with rasterio.open(tif_path_out, 'w', **tiff_profile) as dst:
                                dst.write(output_array.astype(rasterio.uint8))

                            # --- 3. MODIFICATION: VERIFY FILES EXIST ---
                            if os.path.exists(png_path) and os.path.exists(tif_path_out):
                                # --- 4. MODIFICATION: IF VERIFIED, ADD RECORD TO LISTS ---

                                # Add to manifest list
                                manifest_record = {
                                    'uid': uid,
                                    'method': 'Auto_Generated', 'approval_status': 'Pending',
                                    'base_tiff': uid + '.tif', 'delineation_date': delineation_date,
                                    'worker': '', 'mask id': i,
                                    'output_fileid': output_fileid,
                                    'output_filename': output_fileid + '.tif',
                                    'output_folder_location': output_images_folder_path,
                                    'assigned_time_utc': None, 'completed_time_utc': None,
                                    'notes':''
                                }
                                all_verified_manifest_records.append(manifest_record)

                                # Add to polygon list for GeoJSON
                                mask_bool = mask.astype(bool)
                                mask_polygons = [shape(geom) for geom, value in shapes(mask_bool.astype(np.uint8), mask=mask_bool, transform=geotiff_reader.transform) if value == 1]
                                if mask_polygons:
                                    combined_geom = MultiPolygon(mask_polygons) if len(mask_polygons) > 1 else mask_polygons[0]
                                    polygon_record = {
                                        "geometry": combined_geom, "UID": uid,
                                        "output_fileid": output_fileid, "basemodel": basemodel,
                                        "model_name": model_name, "prompt": 'box', "maskid": i
                                    }
                                    all_polygon_records.append(polygon_record)
                            else:
                                print(f"    ❌ Verification failed for {output_fileid}. Files not found after saving. Skipping manifest entry.")

                        except Exception as e:
                            print(f"    ❌ An error occurred processing mask {i} for {uid}: {e}. Skipping manifest entry.")

# --- MODIFICATION: WRITE MANIFEST AND GEOJSON ONCE AT THE END ---
if all_verified_manifest_records:
    print("\n--- All UIDs processed. Writing final manifest file... ---")
    final_df = pd.DataFrame(all_verified_manifest_records)
    # Overwrite the old manifest completely to ensure a clean file
    final_df.to_csv(manifest_path, index=False)
    print(f"✅ Successfully created manifest with {len(final_df)} records at: {manifest_path}")
else:
    print("⚠️ No records were successfully processed and verified. Manifest file was not created.")

if all_polygon_records:
    print("\n--- Writing final GeoJSON file... ---")
    gdf = gpd.GeoDataFrame(all_polygon_records, crs=tiff_profile['crs'])
    gdf.to_file(geojson_output_path, driver="GeoJSON")
    print(f"✅ SAM output polygons saved to: {geojson_output_path}")
else:
    print("⚠️ No mask polygons were generated to save.")

print("\nProcessing Completed")


--- Processing folder: Level 1/A1 ---
Processing eee7c4c4-d3c1-5dcf-9ab7-a0d9e64f3958
Processing c1fc9daf-0d41-59cf-811c-8ba864120d3d
Processing f7522b25-37a3-59e3-980a-fa855fcdf520
Processing e8e1c95a-2b30-5d1b-aaab-6c0baa7d3c11
Processing e59250bd-0a30-5ab2-b96f-85109552723b
Processing d8858bfb-89e8-5a7d-b64c-85517919fd26
Processing ef526c89-fda1-536a-9869-0516e2882327
Processing c2a75e01-8f08-5829-9b6a-0c3c2f4d2967
Processing be2f8d5b-11a7-5e90-b01e-80cdf65b1712
Processing e461289f-ec38-5bc2-846b-7237dd2b1015


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing cd2e956e-2eb7-5c56-87c8-28e9f1bf6759
Processing d68e3cd9-f88e-5f09-bc86-e75d549af64c


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing d0a82842-ea84-57ff-9b74-8d344a7816fc
Processing f64c6b97-a6bb-5884-81ff-8b1443666aee
Processing d710212e-5efd-5f2c-beee-7814ba06dc9a
Processing ea276a88-b7f8-5ca0-bf4a-f4adfa0b4b88


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing c63ba86d-2019-52ad-9ecf-99fdfd3df4e3
Processing d40e2d7e-bbb7-5bc0-8aee-3ffd68064b2f
Processing d557366d-965b-5229-b296-faf3546d4d1f
Processing be34e400-0f70-51cd-a585-6b07c9182d77
Processing fd19e7a9-e608-5785-b7f5-c8ea2dfcb0cf
Processing e5cc5251-6a71-5c2b-a82d-8d45d655810a
Processing e5f525be-e4e1-5633-a1bb-9de5259ec6db
Processing e7f21b11-1f9d-52d2-9bbb-3a727632b78f
Processing f4b4d191-d1d4-55b2-a7be-1c52ce902c0a
Processing cfb5245d-7f2f-5f02-a3e7-d77f07ffbbf8
Processing eeb9ce75-def7-59af-83e5-e8cc00d24669
Processing f22a971c-24d0-5085-b794-3f6365a85a41


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing d9977814-1106-5689-9bbe-1c5c73405988
Processing f9229c27-3eca-5940-a445-827cc71d82a6


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing c83f9928-8336-5c84-879d-6a0ff43ed8ef
Processing e4309004-216a-5a95-9327-620d4a15a4aa


  tiff_image_rgb = tiff_image.astype(np.uint8)[:, :, :3]


Processing c5e16665-bccf-5262-a9c7-e8ed6296b43f
Processing d8613848-69dc-53a0-a3b2-d4c0403dcdf1
Processing e78c3135-76f0-58f7-8983-186c5c5d18e5
Processing df0ba7b7-a8cd-5659-9321-5d36c9e07a2a
Processing fe3cc38f-906f-5efc-807b-5e58bd54d9fa
Processing dcea6013-6c6f-5ae9-b2d0-ea5e4626d82b
Processing fe7c47c4-0900-5a83-b664-c6742b50e4cf
--- Processing folder: Level 1/B1 ---
--- Processing folder: Level 1/C1 ---
--- Processing folder: Level 2/A2 ---
--- Processing folder: Level 2/B2 ---
--- Processing folder: Level 2/C2 ---
--- Processing folder: Level 3/A3 ---
--- Processing folder: Level 3/B3 ---
--- Processing folder: Level 3/C3 ---

--- All UIDs processed. Writing final manifest file... ---
✅ Successfully created manifest with 117 records at: /content/drive/MyDrive/img-label-correction-SAM/src/Arav_App_v1/data/output/metadata/rts_auto_segmentation_manifest.csv

--- Writing final GeoJSON file... ---
✅ SAM output polygons saved to: /content/drive/MyDrive/img-label-correction-SAM/src/Arav