In [1]:
!pip install rasterio

import pandas as pd
import os
import shutil
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import output
output.enable_custom_widget_manager()
import warnings
warnings.filterwarnings("ignore")
import logging
logging.getLogger('rasterio._env').setLevel(logging.ERROR)

from google.colab import drive
drive.mount('/content/drive')

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m69.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1.2-py2.py3-none-any.whl (11 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1.2 cligj-0.7.2 rasterio-1.4.3
Mounted at /content/drive

In [3]:
# ---- Paths -------------------------------------------------------------------
# Paths are assembled by plain string concatenation throughout the notebook, so
# every folder path must keep its trailing slash.
data_folder_path = '/content/drive/MyDrive/img-label-correction-SAM/data/raw/'
output_folder_path = '/content/drive/MyDrive/img-label-correction-SAM/output/'
output_metadata_folder_path = output_folder_path+'metadata/'
manifest_path = output_metadata_folder_path+"rts_auto_segmentation_manifest.csv"

input_folder = output_folder_path+'images/'            # candidate .png / .tif files to review
output_folder = output_folder_path+'approved_images/'  # destination for approved .tif files

# ---- Review queue ------------------------------------------------------------
# Only UIDs with no row marked 'Approved' in the manifest still need review.
df = pd.read_csv(manifest_path)
is_approved = df['approval_status'] == 'Approved'
uids_with_approved = df.loc[is_approved, 'uid'].unique()
all_uids = df['uid'].unique()
# Test membership against a set: checking each UID against the NumPy array
# rescans the whole array every time, which is quadratic in the number of UIDs.
approved_uid_lookup = set(uids_with_approved)
uids = [uid for uid in all_uids if uid not in approved_uid_lookup]

# ---- Shared UI state -----------------------------------------------------------
# One-element list so the button callback can advance the position in place
# without a `global` statement.
uid_index = [0]  # update index when next is clicked
# NOTE: this rebinds the name `output`, which the setup cell imported from
# google.colab; from here on `output` is the widget the review panel renders into.
output = widgets.Output()

def _make_thumbnail(image, title, title_size):
  """Draw `image` on a 3x3-inch figure with the axes hidden and return the figure.

  The figure is closed before it is returned so pyplot stops tracking it; callers
  show it explicitly with `display(fig)`.
  """
  fig, ax = plt.subplots(figsize=(3, 3))
  ax.imshow(image)
  ax.set_title(title, fontsize=title_size)
  ax.axis('off')
  plt.close(fig)
  return fig


def show_uid_images(uid):
  """Render the review panel for one UID into the shared `output` widget.

  The panel shows the UID heading, the base .tif read from `data_folder_path`,
  and one thumbnail plus a "Save" button for every manifest row of this UID,
  ordered by descending `iou` and laid out four per row.

  Clicking "Save" copies the candidate's .tif from `input_folder` to
  `output_folder` and sets `approval_status` to 'Approved' for that
  `output_fileid` in the manifest CSV on disk. The in-memory `df` and `uids`
  are not refreshed.

  Args:
    uid: value matched against the manifest's `uid` column; it is also
      concatenated into the base image filename, so it must be a string.
  """
  output.clear_output(wait=True)
  uid_df = df[df['uid'] == uid].sort_values(by='iou', ascending=False)

  with output:
    display(widgets.HTML(f"<h4>UID: {uid}</h4>"))

    # Without this guard an unknown UID fails with an IndexError on .iloc[0].
    if uid_df.empty:
      print(f"No manifest rows found for UID: {uid}")
      return

    # --------- BASE TIFF DISPLAY (first row, no mask) ---------
    # level/complexity come from the top-ranked row and select the raw-data subfolder.
    first_row = uid_df.iloc[0]
    level_path = 'Level ' + str(first_row['level']) + '/' + str(first_row['complexity']) + '/'
    base_filepath = data_folder_path + level_path + uid + '.tif'
    with rasterio.open(base_filepath) as src:
      base_array = src.read()
    # Move the first axis to the end for imshow and keep the first three channels.
    # NOTE(review): assumes the raster has at least three bands whose values fit
    # in uint8 - confirm against the source data.
    base_rgb = np.transpose(base_array, (1, 2, 0)).astype(np.uint8)[:, :, :3]
    display(_make_thumbnail(base_rgb, "Base Image", 9))

    # --------- Masked Png DISPLAY WITH SAVE BUTTONS ---------
    widgets_list = []
    for _, row in uid_df.iterrows():
      filename = row['output_fileid']
      # Shorten the caption to the part starting at the first "sam", when present.
      filename_short = filename[filename.index("sam"):] if "sam" in filename else filename
      filename_display = str(row['iou']) + ' ' + filename_short
      png_filepath = input_folder + filename + '.png'
      tif_filepath = input_folder + filename + '.tif'

      # Close the PNG as soon as it has been drawn so file handles are not
      # leaked, one per candidate.
      with Image.open(png_filepath) as img:
        fig = _make_thumbnail(img, filename_display, 8)

      save_btn = widgets.Button(description='Save', layout=widgets.Layout(width='80px'), button_style='success')

      # Default arguments bind this iteration's paths to the callback; a plain
      # closure would see only the values from the last loop iteration.
      def save_callback(b, filepath=tif_filepath, filename=filename):
        # The approved folder may not exist yet on a fresh Drive layout.
        os.makedirs(output_folder, exist_ok=True)
        dst_path = output_folder + filename + '.tif'
        shutil.copy(filepath, dst_path)
        print(f"Saved: {filename}")
        # Re-read the manifest from disk so approvals saved by earlier clicks
        # are kept when the file is rewritten.
        manifest = pd.read_csv(manifest_path)
        manifest.loc[manifest['output_fileid'] == filename, 'approval_status'] = 'Approved'
        manifest.to_csv(manifest_path, index=False)

      save_btn.on_click(save_callback)

      # Only the figure needs to be captured by the per-candidate Output widget.
      out_box = widgets.Output()
      with out_box:
        display(fig)
      centered_btn = widgets.HBox([save_btn], layout=widgets.Layout(justify_content='center'))
      widgets_list.append(widgets.VBox([out_box, centered_btn]))

    # Lay the candidates out four per row.
    rows = [widgets.HBox(widgets_list[i:i+4]) for i in range(0, len(widgets_list), 4)]
    display(widgets.VBox(rows))

# Next button
def on_next_uid_clicked(b):
  """Button handler: move to the next pending UID and render it.

  Increments the shared `uid_index` counter. If a UID exists at the new
  position its panel is drawn; otherwise the output area is cleared and an
  end-of-list message is printed there.

  Args:
    b: the clicked button (unused).
  """
  uid_index[0] += 1
  position = uid_index[0]

  if position < len(uids):
    show_uid_images(uids[position])
    return

  # Ran past the end of the queue.
  output.clear_output()
  with output:
    print("No more UIDs.")

# Navigation control: each click advances to the next pending UID.
next_btn = widgets.Button(
  description='Next UID',
  icon='forward',
  button_style='primary',
)
next_btn.on_click(on_next_uid_clicked)

# Show the button first, then the output area the review panel renders into.
display(next_btn)
display(output)

# Initial render: show the first pending UID, or say that there is nothing to review.
if not uids:
  with output:
    print("No UIDs to display.")
else:
  show_uid_images(uids[uid_index[0]])

Button(button_style='primary', description='Next UID', icon='forward', style=ButtonStyle())

Output()

Saved: b88297fa-643c-5a51-b74c-03a8ce4ea9f6_sam2_small_point_prompt_mask_1
