In [1]:

import os
import shutil
import argparse
import time
import traceback
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm

import fiftyone as fo

from tator_tools.download_datasets import DatasetDownloader
from tator_tools.fiftyone_clustering import FiftyOneDatasetViewer

import tator

import cv2
import numpy as np
import pandas as pd

import torch
from ultralytics import YOLO
from ultralytics import RTDETR

# Custom DataDownloader (for getting select frames from media)

In [2]:
class DataDownloader:
    """
    Download selected frames from Tator media into ``<output_dir>/frames``.

    Each frame is stored as ``<media_id>_<frame_id>.jpg``. Downloads run in one thread pool.
    A ``<frame>.jpg.lock`` file is written while a frame is being fetched; it is a best-effort
    marker (not an exclusive lock) that lets another worker wait for the frame instead of
    requesting it again.
    """

    # Age in seconds after which a leftover lock file is treated as abandoned
    LOCK_STALE_SECONDS = 300
    # Maximum number of one-second polls spent waiting on an existing lock file
    LOCK_WAIT_POLLS = 60

    def __init__(self, api_token: str, project_id: int, media_ids: List[int],
                 frame_ids_dict: Dict[int, List[int]], output_dir: str,
                 max_workers: int = 10, max_retries: int = 10,
                 force_scale: str = "1024x768"):
        """
        Initialize the DataDownloader with multiple media IDs and their corresponding frames.

        :param api_token: Tator API token for authentication
        :param project_id: Project ID in Tator
        :param media_ids: List of media IDs to process
        :param frame_ids_dict: Dictionary mapping media IDs to their frame IDs to download
        :param output_dir: Output directory for downloaded frames
        :param max_workers: Maximum number of concurrent download threads
        :param max_retries: Maximum number of attempts per frame before giving up
        :param force_scale: Value forwarded as ``force_scale`` to the Tator ``get_frame`` call
        """
        self.project_id = project_id
        self.media_ids = media_ids
        self.frames_dict = frame_ids_dict
        self.output_dir = output_dir
        self.max_workers = max_workers
        self.max_retries = max_retries
        self.force_scale = force_scale

        # Create a single API instance for all operations
        self.api = self._authenticate(api_token)

        # Set up directories
        self._setup_directories()

        # Cache for media objects, keyed by media ID
        self.media_cache = {}

        # Populated by download_data(): media ID -> list of downloaded frame paths
        self.output_data = None

    @staticmethod
    def _authenticate(api_token: str):
        """
        Authenticate with the Tator API.

        :param api_token: API token for authentication
        :return: Authenticated API instance
        :raises Exception: If the API instance could not be created
        """
        try:
            return tator.get_api(host='https://cloud.tator.io', token=api_token)
        except Exception as e:
            raise Exception(f"ERROR: Could not authenticate with provided API Token\n{e}") from e

    def _setup_directories(self):
        """
        Create necessary directories for frame storage.
        """
        os.makedirs(f"{self.output_dir}/frames", exist_ok=True)

    def _get_media(self, media_id: int):
        """
        Get media object with caching to avoid redundant API calls.

        :param media_id: Media ID to retrieve
        :return: Media object
        """
        if media_id not in self.media_cache:
            self.media_cache[media_id] = self.api.get_media(id=int(media_id))
        return self.media_cache[media_id]

    @staticmethod
    def _release_lock(lock_path: str) -> None:
        """
        Remove a lock file, ignoring the case where it is already gone or cannot be removed.

        :param lock_path: Path of the lock file to delete
        """
        try:
            os.remove(lock_path)
        except OSError:
            pass

    def _wait_for_lock(self, frame_path: str, lock_path: str) -> bool:
        """
        Handle a lock file left by another worker for the same frame.

        A lock older than LOCK_STALE_SECONDS is deleted. Otherwise the frame is polled once per
        second, for at most LOCK_WAIT_POLLS polls, until it appears or the lock disappears.

        :param frame_path: Path where the finished frame is expected
        :param lock_path: Path of the lock file guarding that frame
        :return: True if the frame exists after waiting, False if the caller should download it
        """
        if not os.path.exists(lock_path):
            return False

        try:
            lock_age = time.time() - os.path.getmtime(lock_path)
        except OSError:
            # The lock vanished between the two checks; its owner may just have finished
            return os.path.exists(frame_path)

        if lock_age > self.LOCK_STALE_SECONDS:
            self._release_lock(lock_path)
            return False

        for _ in range(self.LOCK_WAIT_POLLS):
            time.sleep(1)
            if os.path.exists(frame_path):
                return True
            if not os.path.exists(lock_path):
                break

        return os.path.exists(frame_path)

    def download_frame(self, params: tuple) -> Tuple[int, int, Optional[str]]:
        """
        Download a single frame for a given media with retry logic.

        :param params: Tuple containing (media_id, frame_id)
        :return: Tuple of (media_id, frame_id, frame_path or None if failed)
        """
        media_id, frame_id = params

        frame_path = os.path.abspath(
            os.path.join(self.output_dir, "frames", f"{media_id}_{frame_id}.jpg")
        )
        lock_path = f"{frame_path}.lock"

        # Frames already on disk need neither an API call nor a lock
        if os.path.exists(frame_path):
            return media_id, frame_id, frame_path

        # Fetched before any lock is written so a failing lookup cannot leave a lock behind
        media = self._get_media(media_id)

        if self._wait_for_lock(frame_path, lock_path):
            return media_id, frame_id, frame_path

        try:
            with open(lock_path, 'w') as f:
                f.write(str(os.getpid()))
        except OSError:
            # Could not write the marker; give a competing worker a moment, then carry on
            time.sleep(1)
            if os.path.exists(frame_path):
                return media_id, frame_id, frame_path

        try:
            for attempt in range(self.max_retries):
                try:
                    temp = self.api.get_frame(
                        id=media.id,
                        tile=f"{media.width}x{media.height}",
                        force_scale=self.force_scale,
                        frames=[int(frame_id)]
                    )
                    shutil.move(temp, frame_path)
                    return media_id, frame_id, frame_path

                except Exception as e:
                    error_msg = f"Error downloading frame {frame_id} for media {media_id}: {e}"
                    if attempt < self.max_retries - 1:
                        print(f"{error_msg}, retrying...")
                        # Exponential backoff between attempts
                        time.sleep(2 ** attempt)
                    else:
                        print(f"{error_msg}, giving up.")
        finally:
            # Runs on success, on exhausted retries and on unexpected exits alike
            self._release_lock(lock_path)

        return media_id, frame_id, None

    def download_data(self) -> Dict[int, List[str]]:
        """
        Download frames for all media IDs using a single thread pool.

        Each distinct (media_id, frame_id) pair is requested once, even if it is listed several
        times. Media IDs without an entry in the frame dictionary are treated as having no frames.
        The result is also stored on ``self.output_data``.

        :return: Dictionary mapping media IDs to lists of frame paths
        """
        # dict.fromkeys drops repeated pairs while keeping first-seen order
        all_tasks = list(dict.fromkeys(
            (media_id, frame_id)
            for media_id in self.media_ids
            for frame_id in self.frames_dict.get(media_id, [])
        ))

        results_dict = {media_id: [] for media_id in self.media_ids}

        # Use a single thread pool for all downloads
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self.download_frame, task) for task in all_tasks]

            with tqdm(total=len(all_tasks), desc="Downloading frames") as pbar:
                for future in as_completed(futures):
                    media_id, _, frame_path = future.result()
                    if frame_path:  # If download was successful
                        results_dict[media_id].append(frame_path)
                    pbar.update(1)

        self.output_data = results_dict
        return results_dict


In [3]:
# Tator connection settings; the token comes from the environment rather than the notebook
project_id = 155
api_token = os.environ.get("TATOR_TOKEN")

# Directory handed to the frame downloader as its output location
output_dir = "../Data/NCICS/Madeline_Data/"

# Read in CSV file

In [21]:
# Read and preprocess the data
df = pd.read_csv("../Data/NCICS/MadelineIs_All_Annotations_20241113.csv")

# Drop rows where TatorMediaID OR TatorFrame is NA; they cannot be matched to a Tator frame
df = df.dropna(subset=['TatorMediaID', 'TatorFrame'])

# Convert the columns to integers after removing NA values
df['TatorMediaID'] = df['TatorMediaID'].astype(int)
df['TatorFrame'] = df['TatorFrame'].astype(int)

# Media IDs in order of first appearance
media_ids = df['TatorMediaID'].unique().tolist()

# Map each media ID to its distinct frames using a single grouped pass over the table.
# sort=False keeps the same first-appearance order as media_ids; unique() avoids asking for the
# same frame repeatedly when several annotation rows share it.
frame_ids_dict = {
    int(media_id): frames.unique().tolist()
    for media_id, frames in df.groupby('TatorMediaID', sort=False)['TatorFrame']
}

# Download frames from CSV

In [None]:
# Build one frame downloader covering every media ID referenced in the annotations
downloader = DataDownloader(
    api_token=api_token,
    project_id=project_id,
    output_dir=output_dir,
    media_ids=media_ids,
    frame_ids_dict=frame_ids_dict,
    max_retries=10,
    max_workers=10,
)

# Fetch the requested frames; the trailing semicolon keeps the cell from echoing a return value
downloader.download_data();

In [25]:
# Results stored by DataDownloader.download_data(): media ID -> list of downloaded frame paths
frame_paths_dict = downloader.output_data

In [26]:
# Attach the downloaded frame path to every annotation row, one media ID at a time
output_df = []
for media_id, group in df.groupby('TatorMediaID'):
    # Frame paths that were downloaded for this media ID
    media_frame_paths = frame_paths_dict[media_id]

    # File names end in "_<frame>.jpg", so the text after the last underscore is the frame number
    frame_to_path = {}
    for path in media_frame_paths:
        frame_number = int(path.rsplit('_', 1)[-1].replace('.jpg', ''))
        frame_to_path[frame_number] = path

    # Work on a new frame so the grouped source rows are left untouched;
    # frames without a downloaded file map to NaN
    group = group.assign(**{'Image_Path': group['TatorFrame'].map(frame_to_path)})
    output_df.append(group)

# Stitch the per-media pieces back together with a fresh 0..n-1 index
final_df = pd.concat(output_df, ignore_index=True)

In [None]:
# Preview the first three rows: identifiers, the resolved image path, and the Sclass/Ssubclass/Sgroup columns
final_df[['TatorMediaID', 'TatorFrame', 'Image_Path', 'Sclass', 'Ssubclass', 'Sgroup']].head(3)

In [28]:
# Write the annotations, now including the Image_Path column, to a new CSV (row index omitted)
final_df.to_csv("../Data/NCICS/Madeline_Data/MadelineIs_Modified.csv", index=False)

# Download Unlabeled AUV Data

In [4]:
# Dataset name and parent directory passed to DatasetDownloader.
# NOTE: this rebinds output_dir, which the frame-download cells set to the Madeline_Data directory.
dataset_name = "Unlabeled_AUV_Data"
output_dir = "../Data/NCICS/"

# Value forwarded to DatasetDownloader as `frac`
# (presumably the fraction of matching media to sample — confirm against tator_tools)
frac = 0.01

# Encoded search filter copied from Tator's Data Metadata Export utility
search_string = "eyJtZXRob2QiOiJBTkQiLCJvcGVyYXRpb25zIjpbeyJhdHRyaWJ1dGUiOiJNaXNzaW9uTmFtZSIsIm9wZXJhdGlvbiI6Imljb250YWlucyIsImludmVyc2UiOmZhbHNlLCJ2YWx1ZSI6Ik1hZGVsaW5lIn0seyJtZXRob2QiOiJPUiIsIm9wZXJhdGlvbnMiOlt7ImF0dHJpYnV0ZSI6IiR0eXBlIiwib3BlcmF0aW9uIjoiZXEiLCJpbnZlcnNlIjpmYWxzZSwidmFsdWUiOjMzMX1dfV19"

In [5]:
# Configure the Tator dataset downloader for the search defined above (label_field is left empty).
# NOTE: this rebinds `downloader`, replacing the frame downloader created earlier in the notebook.
downloader = DatasetDownloader(
    api_token,
    project_id=project_id,
    search_string=search_string,
    frac=frac,
    output_dir=output_dir,
    dataset_name=dataset_name,
    label_field="",
    download_width=1024,
)

NOTE: Authentication successful for jordan.pierce
NOTE: Search string saved to e:\tator-tools\Data\NCICS\Unlabeled_AUV_Data\search_string.txt


In [None]:
# Run the download for the dataset configured above (dataset_name "Unlabeled_AUV_Data")
downloader.download_data()

NOTE: Querying Tator for labeled data


In [8]:
# Tabular view of what the dataset downloader fetched.
# NOTE(review): this rebinds `df`, which earlier cells use for the annotation CSV — re-running
# those cells afterwards would operate on this table instead.
df = downloader.as_dataframe()  # .as_dict()

In [None]:
# Persist the downloaded-dataset table next to the dataset itself (row index omitted)
df.to_csv("../Data/NCICS/Unlabeled_AUV_Data/Unlabeled_AUV_Data.csv", index=False)

# Clustering

In [None]:
# Get the labeled data written by the frame-download step
labeled_df = pd.read_csv("../Data/NCICS/Madeline_Data/MadelineIs_Modified.csv")

# That CSV stores the downloaded frame locations under 'Image_Path', while the cells below
# expect the image location in 'Path'; prefer the downloaded frames when the column is present
if 'Image_Path' in labeled_df.columns:
    labeled_df = (
        labeled_df
        .drop(columns=['Path'], errors='ignore')
        .rename(columns={'Image_Path': 'Path'})
    )

# Subset to the image path and label columns; rows whose frame was never downloaded have no
# path to embed, so they are removed
labeled_df = (
    labeled_df[['Path', 'Sclass', 'Ssubclass', 'Sgroup']]
    .dropna(subset=['Path'])
    .reset_index(drop=True)
)

labeled_df.head(3)

In [None]:
# Get the unlabeled data from the location the dataset-download step wrote it to
unlabeled_df = pd.read_csv("../Data/NCICS/Unlabeled_AUV_Data/Unlabeled_AUV_Data.csv")

# Conform to the labeled table: keep only the image location, call it 'Path', and mark the
# class columns as "Unknown". Built as one chain so no column-sliced frame is mutated in place.
unlabeled_df = (
    unlabeled_df[['image_path']]
    .rename(columns={'image_path': 'Path'})
    .assign(Sclass="Unknown", Ssubclass="Unknown")
)

unlabeled_df.head(3)

In [None]:
# Stack the labeled rows on top of the unlabeled rows and renumber the index from 0.
# Columns present in only one of the two tables are filled with NaN for the other.
combined_df = pd.concat([labeled_df, unlabeled_df], ignore_index=True)
combined_df.head(3)

In [None]:
# Calculate custom embeddings
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere
model_weights = "E:\\tator-tools\\Data\\Runs\\2024-06-26_20-47-34_detect_yolov10m\\weights\\best.pt"

# Load the model
model = YOLO(model_weights)

# Reuse the image size recorded in the loaded model's overrides, falling back to 640 when it is
# missing. This reads from `model`; there is no `self` at notebook scope.
model_overrides = getattr(model, 'overrides', None) or {}
imgsz = model_overrides.get('imgsz') or 640

# Get the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"NOTE: Using device {device}")

In [None]:
# Request embeddings for every image path; stream=True means results are consumed by iterating
image_paths = combined_df['Path'].tolist()
embedding_stream = model.embed(image_paths, imgsz=imgsz, stream=True, device=device)

# One embedding is expected per row of combined_df, which gives the progress bar its total
embeddings_list = []
progress = tqdm(embedding_stream, total=len(combined_df), desc="Processing embeddings")
for embedding in progress:
    # Copy the result to host memory as a NumPy array, then release cached GPU memory
    embeddings_list.append(embedding.cpu().numpy())
    torch.cuda.empty_cache()

# Collect the per-image arrays into a single array and show its shape
embeddings = np.array(embeddings_list)
embeddings.shape

In [12]:
# Release PyTorch's cached CUDA memory now that the embeddings have been copied to NumPy
torch.cuda.empty_cache()

In [34]:
# Initialize the viewer from the combined dataframe: 'Path' locates each image and the
# Sclass/Ssubclass columns are passed as feature columns
viewer = FiftyOneDatasetViewer(
    dataframe=combined_df,
    image_path_column='Path',
    feature_columns=['Sclass', 'Ssubclass'],
    nickname='MadelineIs',
    custom_embeddings=embeddings,  # Pass the embeddings, or None
)

In [None]:
# Process the dataset to create the FiftyOne dataset and generate the UMAP visualization
# (behaviour is implemented in tator_tools' FiftyOneDatasetViewer and is not visible from this notebook)
viewer.process_dataset()

In [None]:
# Launch the FiftyOne app
try:
    session = fo.launch_app(viewer.dataset)
except Exception as launch_error:
    # The original author noted the first launch can misbehave inside a notebook, so retry once.
    # Only ordinary errors are caught (Ctrl-C still interrupts), and the first failure is reported.
    print(f"NOTE: First attempt to launch the FiftyOne app failed ({launch_error}); retrying")
    session = fo.launch_app(viewer.dataset)