# Imports

In [None]:
import os
import queue
from concurrent.futures import ThreadPoolExecutor, as_completed
from google.cloud import storage
from google.colab import auth 
import h5py 
import numpy as np 
import time #

from extract_vision_features import process_single_wsi, load_config

# Configuration Variables

In [None]:
# --- Configuration (UPDATE THESE VALUES) ---

# Google Cloud project used for authentication and the gcloud CLI.
GCP_PROJECT_ID = "your-gcp-project-id"  # <--- IMPORTANT: replace with your own GCP project ID

# Source locations inside the GCS bucket (folder prefixes end with "/").
GCS_BUCKET_NAME = "histo-bench"
GCS_WSI_FOLDER = "TCGA-LGG/wsi/"
GCS_COORDINATES_FOLDER = "TCGA-LGG/coordinates/"

# Local scratch directories on the Colab filesystem.
LOCAL_WSI_DOWNLOAD_DIR = "/content/wsi_temp_downloads/"
LOCAL_COORDINATES_DIR = "/content/coordinates/"

# Path handed to load_config() by the processing loop.
# NOTE(review): empty here -- must be set to the real config file before running.
CONFIG_PATH = ""

# Number of WSI URIs grouped per batch when listing the bucket.
BATCH_SIZE_WSI_LISTING = 5
# Number of concurrent download threads (tune to the available bandwidth).
MAX_DOWNLOAD_WORKERS = 3

# Authentication 

In [None]:
# --- 0. Colab Authentication and Project Setup ---
# Authenticate this Colab session with Google Cloud and set GCP_PROJECT_ID as
# the active project, both in the environment and in the gcloud CLI config.
print("--- Authenticating Google Colab and Setting GCP Project ---")
try:
    auth.authenticate_user()
    print("Colab authenticated.")
    # Presumably read by the Google Cloud client libraries (e.g. storage.Client())
    # as the default project -- confirm against google-auth's env-var handling.
    os.environ['GCLOUD_PROJECT'] = GCP_PROJECT_ID
    # IPython shell escape: a non-zero gcloud exit status does NOT raise a Python
    # exception, so a failure here is not caught by the except branch below.
    !gcloud config set project {GCP_PROJECT_ID}
except Exception as e:
    print(f"Authentication or project setup failed: {e}")
    # NOTE(review): in a notebook kernel exit() may not halt later cells the way
    # it halts a script -- confirm; re-raising would stop "Run all" reliably.
    exit()

In [None]:
# --- 1. Download Patch Coordinates ---
# Copy every .h5 patch-coordinate file from gs://<bucket>/<coordinates folder>
# into LOCAL_COORDINATES_DIR with the gcloud CLI.
print(f"\n--- Downloading Patch Coordinate Files from {GCS_BUCKET_NAME}/{GCS_COORDINATES_FOLDER} ---")
os.makedirs(LOCAL_COORDINATES_DIR, exist_ok=True)
# IPython shell escape: failures do not raise, so the confirmation message
# below is printed even when the copy fails -- check the gcloud output.
!gcloud storage cp gs://{GCS_BUCKET_NAME}/{GCS_COORDINATES_FOLDER}*.h5 {LOCAL_COORDINATES_DIR}
print(f"Coordinate files downloaded to: {LOCAL_COORDINATES_DIR}")

# --- Helper Functions ---
def get_wsi_batches_from_gcs(bucket_name: str, gcs_folder_path: str, batch_size: int = 5):
    """
    Yield lists of ``gs://`` URIs for the ``.svs`` objects under a GCS folder.

    Objects are sorted by name so batches are stable across runs. Each yielded
    list holds ``batch_size`` URIs, except possibly the last one. Nothing is
    yielded when the folder contains no ``.svs`` files.

    Args:
        bucket_name: Name of the GCS bucket to list.
        gcs_folder_path: Folder prefix inside the bucket; a trailing ``/`` is
            added when missing.
        batch_size: Number of URIs per yielded list.
    """
    # Normalize to a trailing slash so the prefix only matches this folder
    # (e.g. "wsi" would otherwise also match "wsi_old/...").
    if gcs_folder_path and gcs_folder_path[-1] != '/':
        gcs_folder_path = gcs_folder_path + '/'

    print(f"Listing files in bucket: {bucket_name}, folder: {gcs_folder_path}")

    source_bucket = storage.Client().bucket(bucket_name)
    # Deterministic ordering keeps batch membership identical between runs.
    svs_blobs = sorted(
        (obj for obj in source_bucket.list_blobs(prefix=gcs_folder_path) if obj.name.endswith('.svs')),
        key=lambda obj: obj.name,
    )

    if not svs_blobs:
        print(f"No .svs files found in gs://{bucket_name}/{gcs_folder_path}. Exiting.")
        return

    print(f"Found {len(svs_blobs)} WSI files.")

    pending = []
    for obj in svs_blobs:
        pending.append(f"gs://{bucket_name}/{obj.name}")
        if len(pending) == batch_size:
            yield pending
            pending = []
    # Flush the final, possibly short, batch.
    if pending:
        yield pending

def download_single_svs_file(bucket_name: str, gcs_file_path: str, local_download_dir: str, download_queue: queue.Queue):
    """
    Download one .svs file from GCS and report the outcome through a queue.

    Meant to run in a worker thread. Exactly one item is put on
    ``download_queue`` per call: the local file path on success, or ``None``
    on failure. A consumer can therefore count one queue item per submitted
    download.

    Args:
        bucket_name: Name of the GCS bucket.
        gcs_file_path: Full ``gs://<bucket_name>/<object>`` URI of the slide.
        local_download_dir: Directory the file is written into (created if missing).
        download_queue: Queue that receives the local path, or ``None`` on failure.
    """
    destination_file_name = None
    try:
        os.makedirs(local_download_dir, exist_ok=True)
        storage_client = storage.Client()
        bucket = storage_client.bucket(bucket_name)

        # Strip only the leading "gs://<bucket>/" prefix; str.replace would also
        # remove any later occurrence of that text inside the object name.
        uri_prefix = f"gs://{bucket_name}/"
        if gcs_file_path.startswith(uri_prefix):
            blob_name = gcs_file_path[len(uri_prefix):]
        else:
            blob_name = gcs_file_path
        blob = bucket.blob(blob_name)
        destination_file_name = os.path.join(local_download_dir, os.path.basename(blob.name))

        print(f"[DOWNLOAD] Starting download of {blob.name} to {destination_file_name}")
        blob.download_to_filename(destination_file_name)
        print(f"[DOWNLOAD] Finished download of {blob.name}")
    except Exception as e:
        print(f"[ERROR] Error downloading {gcs_file_path}: {e}")
        # Remove any partially written slide so failed downloads do not keep
        # consuming local disk space.
        if destination_file_name is not None and os.path.exists(destination_file_name):
            try:
                os.remove(destination_file_name)
            except OSError as cleanup_error:
                print(f"[ERROR] Could not remove partial file {destination_file_name}: {cleanup_error}")
        download_queue.put(None)  # Signal failure for this specific download
    else:
        download_queue.put(destination_file_name)

def get_coordinate_file_path(wsi_filename: str, local_coordinates_dir: str):
    """
    Return the expected local path of the ``.h5`` coordinate file for a WSI.

    Any directory part of ``wsi_filename`` is ignored. Its final extension is
    replaced by ``.h5`` and the result is placed in ``local_coordinates_dir``,
    e.g. ``/x/TCGA-XX-YYYY.svs`` -> ``<local_coordinates_dir>/TCGA-XX-YYYY.h5``.
    """
    stem, _extension = os.path.splitext(os.path.basename(wsi_filename))
    return os.path.join(local_coordinates_dir, stem + ".h5")

# --- Main processing loop structure ---
def main_batch_processor(
    bucket_name: str,
    gcs_wsi_folder: str,
    local_wsi_download_dir: str,
    batch_size_wsi_listing: int = 5,
    max_download_workers: int = 3
):
    """
    Download WSI files from GCS concurrently and process them one at a time.

    Downloads run on a thread pool. The calling thread takes finished files
    from a queue, runs ``process_single_wsi`` on each, and then deletes the
    local copy, also when processing fails. Downloads are submitted
    incrementally, with at most ``max_download_workers`` outstanding. This
    keeps the number of slides on local disk bounded instead of pulling the
    whole folder down ahead of processing.

    The extraction config is loaded once, before any download starts, so a
    bad ``CONFIG_PATH`` fails fast.

    Args:
        bucket_name: GCS bucket containing the slides.
        gcs_wsi_folder: Folder prefix inside the bucket to scan for ``.svs`` files.
        local_wsi_download_dir: Local scratch directory for downloaded slides.
        batch_size_wsi_listing: Batch size forwarded to ``get_wsi_batches_from_gcs``.
        max_download_workers: Number of download threads; also the cap on
            downloads submitted ahead of processing.
    """
    os.makedirs(local_wsi_download_dir, exist_ok=True)
    # Receives one item per finished download: the local path, or None on failure.
    download_queue = queue.Queue()

    print("\n--- Starting Concurrent WSI Processing ---")

    # List the bucket once. The same list drives both submission and the
    # completion count, so they cannot disagree. The listing helper also
    # normalizes the folder prefix.
    wsi_gcs_paths = [
        gcs_path
        for batch_gcs_paths in get_wsi_batches_from_gcs(bucket_name, gcs_wsi_folder, batch_size_wsi_listing)
        for gcs_path in batch_gcs_paths
    ]
    total_files_to_process = len(wsi_gcs_paths)
    print(f"Total WSI files identified for processing: {total_files_to_process}")

    if not wsi_gcs_paths:
        print("\n--- All WSI files processed (or attempted). ---")
        return

    # Load the config once instead of re-reading it for every slide.
    config = load_config(CONFIG_PATH)

    pending_gcs_paths = iter(wsi_gcs_paths)
    download_futures = []  # Every submitted download, used to detect completion

    with ThreadPoolExecutor(max_workers=max_download_workers) as download_executor:

        def submit_next_download():
            """Submit the next not-yet-started download; return False if none remain."""
            gcs_path = next(pending_gcs_paths, None)
            if gcs_path is None:
                return False
            download_futures.append(download_executor.submit(
                download_single_svs_file,
                bucket_name,
                gcs_path,
                local_wsi_download_dir,
                download_queue,
            ))
            return True

        # Prime the pipeline with one outstanding download per worker.
        for _ in range(max_download_workers):
            if not submit_next_download():
                break

        processed_count = 0
        while processed_count < total_files_to_process:
            try:
                # Generous timeout so slow downloads of large slides don't trip
                # the empty-queue check below.
                local_wsi_file_path = download_queue.get(timeout=100)
            except queue.Empty:
                # Each submitted download puts exactly one item. If all of them
                # are done and the queue is empty, nothing else will arrive.
                if all(f.done() for f in download_futures) and download_queue.empty():
                    print("[MAIN_LOOP] All downloads complete and queue is empty. Exiting processing loop.")
                    break
                print("[MAIN_LOOP] Queue is temporarily empty, waiting for more downloads or completion...")
                continue

            processed_count += 1
            # A slot has freed up: start the next download so it overlaps with processing.
            submit_next_download()

            if local_wsi_file_path is None:  # Download failed; nothing to process
                print("[MAIN_LOOP] Skipping processing due to a previous download error.")
                continue

            print(f"\n[MAIN_LOOP] Processing local file: {local_wsi_file_path}")
            try:
                # NOTE(review): only the basename is passed (unchanged behavior);
                # presumably process_single_wsi resolves the directory via config.
                process_single_wsi(os.path.basename(local_wsi_file_path), config)
            except Exception as e:
                print(f"[MAIN_LOOP] An unexpected error occurred during processing: {e}")
            finally:
                # Always delete the slide, even after a processing error, so
                # failures do not leak local disk space.
                if os.path.exists(local_wsi_file_path):
                    try:
                        os.remove(local_wsi_file_path)
                        print(f"[CLEANUP] Removed local WSI file: {local_wsi_file_path}")
                    except OSError as cleanup_error:
                        print(f"[CLEANUP] Could not remove {local_wsi_file_path}: {cleanup_error}")
                else:
                    print(f"[CLEANUP] WSI file not found for cleanup: {local_wsi_file_path}")

    print("\n--- All WSI files processed (or attempted). ---")


# Execution

In [None]:
if __name__ == '__main__':
    # Ensure both local scratch directories exist before starting the pipeline.
    for scratch_dir in (LOCAL_WSI_DOWNLOAD_DIR, LOCAL_COORDINATES_DIR):
        os.makedirs(scratch_dir, exist_ok=True)

    pipeline_kwargs = {
        "bucket_name": GCS_BUCKET_NAME,
        "gcs_wsi_folder": GCS_WSI_FOLDER,
        "local_wsi_download_dir": LOCAL_WSI_DOWNLOAD_DIR,
        "batch_size_wsi_listing": BATCH_SIZE_WSI_LISTING,
        "max_download_workers": MAX_DOWNLOAD_WORKERS,
    }
    main_batch_processor(**pipeline_kwargs)


# Preprocess Manifest and Labels

In [71]:
import pandas as pd

# Seed for the row shuffle below, so the written files are reproducible across runs.
SHUFFLE_SEED = 42

path = "/Users/bakhtierzhon.pashshoev/Downloads/gdc_download_20250711_164811.924176/f3a1bc62-9552-4553-b318-7d9c21d21ce7/nationwidechildrens.org_clinical_patient_lgg.txt"
patient_data = pd.read_csv(path, sep="\t")

# Drop the first two rows under the header. These are presumably the extra
# descriptor rows of TCGA clinical exports (TODO confirm). Then keep only the
# label columns.
patient_data = patient_data.iloc[2:]
selected_columns = ["bcr_patient_barcode", "histologic_diagnosis", "tumor_grade"]
patient_data = patient_data[selected_columns].dropna()
patient_data = patient_data.rename(columns={"bcr_patient_barcode": "barcode"})

# Read the manifest. The patient barcode is the first three "-"-separated
# fields of each slide filename (e.g. TCGA-XX-YYYY).
manifest = pd.read_csv("/Users/bakhtierzhon.pashshoev/Desktop/Thesis/TCGA-LGG/gdc_manifest.txt", sep="\t")
manifest["barcode"] = manifest["filename"].apply(lambda x: "-".join(x.split("-")[:3]))

# Attach labels to each slide; slides without complete labels are dropped.
combined = manifest.merge(patient_data, on="barcode", how="left")
combined = combined[["barcode", "filename", "histologic_diagnosis", "tumor_grade", "size"]]
combined = combined.dropna()

# Keep only the largest file for each barcode.
combined = combined.sort_values(by=["barcode", "size"], ascending=[True, False])
combined = combined.drop_duplicates(subset="barcode", keep="first")

# slide_id = filename without its final extension.
combined = combined.assign(slide_id=combined["filename"].apply(lambda x: ".".join(x.split(".")[:-1])))

# Shuffle with a fixed seed; an unseeded shuffle produced a different order on every run.
combined = combined.sample(frac=1, random_state=SHUFFLE_SEED)

# Restrict the manifest to the selected slides and drop the helper barcode column.
filtered_manifest = manifest[manifest["filename"].isin(combined["filename"])]
filtered_manifest = filtered_manifest.drop(columns=["barcode"])

# Encode histologic diagnosis as integer labels. The labels are sorted so the
# mapping is deterministic. Deriving it from the order of appearance in the
# shuffled frame made class indices change between runs.
training_metadata = combined[["slide_id", "histologic_diagnosis", "tumor_grade", "barcode"]]
unique_labels = sorted(training_metadata["histologic_diagnosis"].unique())
label_to_index = {label: i for i, label in enumerate(unique_labels)}
training_metadata = training_metadata.assign(label=training_metadata["histologic_diagnosis"].map(label_to_index))

# Save the training metadata
training_metadata.to_csv("training_metadata.csv", index=False)

# Save the filtered manifest for downstream tasks
filtered_manifest.to_csv("filtered_manifest.txt", index=False, sep="\t")

In [53]:
import sys

# Make the local histo-bench checkout importable (machine-specific path).
sys.path.extend(["/Users/bakhtierzhon.pashshoev/Cursor/histo-bench"])
from scripts.models.vision.resnet import ResNetEncoder

# Instantiate the ResNet encoder on the "mps" device.
encoder = ResNetEncoder(device="mps")


In [None]:
# Display the encoder summary. `get_summary` is accessed without parentheses and
# the cell output is a dict, so it is presumably a property -- confirm in
# scripts.models.vision.resnet.
encoder.get_summary

{'input_shape': (1, 3, 224, 224),
 'output_shape': (1, 2048, 1, 1),
 'total_parameters': 23508032}