In [None]:
# Fail fast when no Hugging Face access token is exported in the environment.
import os

hf_token = os.environ.get("HF_TOKEN")
assert hf_token, "HF_TOKEN is not set"

In [12]:
# Authenticate this Colab runtime with Google Cloud so that the `gsutil`
# commands issued later in the notebook can access the GCS bucket.
from google.colab import auth

auth.authenticate_user()


In [2]:
import timm
import torch

# Build the Prov-GigaPath tile encoder from the Hugging Face Hub with pretrained
# weights, move it to the GPU and switch it to evaluation mode.
tile_encoder = (
    timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
    .cuda()
    .eval()
)

print("✅ Tile Encoder loaded.")
print("🧮 Total parameters:", sum(param.numel() for param in tile_encoder.parameters()))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/829 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

✅ Tile Encoder loaded.
🧮 Total parameters: 1134953984


In [36]:
import os
import subprocess
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np
import timm


# Per-tile preprocessing applied before encoding:
#   bicubic resize of the shorter side to 256 px -> 224x224 center crop ->
#   float tensor -> per-channel normalisation with the standard ImageNet mean/std.
tile_transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# --------------------------------------------
# 2. Define Dataset
# --------------------------------------------
class TileEncodingDataset(Dataset):
    """Map-style dataset of tile images paired with their (x, y) grid coordinates.

    Each item is a dict:
      * ``"img"``    -- the tile converted to RGB, then passed through ``transform``
                        when one is given (otherwise the PIL image itself);
      * ``"coords"`` -- ``torch.tensor([x, y], dtype=torch.float)`` parsed from the
                        tile's filename.

    Args:
        image_paths: Indexable sequence of tile image paths (str or Path). Each
            filename stem must look like ``"<X>x_<Y>y"``, e.g. ``"00256x_00512y.png"``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _parse_coords(img_path):
        """Return the integer ``(x, y)`` encoded in a tile filename like ``00256x_00512y.png``.

        Raises:
            ValueError: if the stem does not split into exactly two ``_``-separated
                parts, or a part is not an integer once its ``x``/``y`` marker is removed.
        """
        parts = Path(img_path).stem.split("_")
        if len(parts) != 2:
            raise ValueError(
                f"Cannot parse tile coordinates from {str(img_path)!r}: "
                "expected a filename like '<X>x_<Y>y.png'"
            )
        x_part, y_part = parts
        return int(x_part.replace("x", "")), int(y_part.replace("y", ""))

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Close the underlying file deterministically instead of relying on garbage
        # collection; convert() returns an independent, fully loaded image.
        with Image.open(img_path) as tile:
            img = tile.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        x, y = self._parse_coords(img_path)
        return {"img": img, "coords": torch.tensor([x, y], dtype=torch.float)}

# --------------------------------------------
# 3. Define helper function for one WSI
# --------------------------------------------
def _encode_tiles(tile_paths, desc, batch_size=128, num_workers=2):
    """Encode tiles with the notebook-level ``tile_encoder``; return CPU ``(embeddings, coords)``.

    Tiles are preprocessed with ``tile_transform`` and visited in the given order
    (``shuffle=False``), so row ``i`` of both returned tensors corresponds to
    ``tile_paths[i]``. Image batches are moved to the GPU with ``.cuda()``, so
    ``tile_encoder`` must already live there. ``tile_paths`` must be non-empty.
    """
    dataset = TileEncodingDataset(tile_paths, transform=tile_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    embedding_batches, coord_batches = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            embeds = tile_encoder(batch["img"].cuda())
            embedding_batches.append(embeds.cpu())
            coord_batches.append(batch["coords"])

    return torch.cat(embedding_batches, dim=0), torch.cat(coord_batches, dim=0)


def process_one_slide(split, group, subtype, slide_id, gcs_bucket="bracs-dataset-bucket",
                      batch_size=128, num_workers=2):
    """Download one slide's tiles from GCS, encode them, then save and upload the embeddings.

    Steps:
      1. Copy ``dataset.csv`` and the ``output/`` tile folder from
         ``gs://<gcs_bucket>/Tiles/<split>/<group>/<subtype>/<slide_id>/`` into
         ``/content/tiles/<split>/<group>/<subtype>/<slide_id>/``.
      2. Encode every ``*.png`` in ``output/`` (sorted by path) with ``tile_encoder``.
      3. Save ``{"embeddings": ..., "coords": ...}`` to
         ``/content/embeddings/<split>/<group>/<subtype>/<slide_id>/tile_embeddings.pt``.
      4. Upload that file to
         ``gs://<gcs_bucket>/Embeddings/<split>/<group>/<subtype>/<slide_id>/tile_embeddings.pt``.

    Args:
        batch_size: Tiles per encoder forward pass (default 128).
        num_workers: DataLoader worker processes (default 2).

    Returns:
        None. Prints a warning and returns early if the tiles or ``dataset.csv`` are
        missing locally, or if no tiles were found.

    Raises:
        subprocess.CalledProcessError: if any ``gsutil`` command fails.
    """
    print(f"🔍 Processing WSI: {slide_id} under {group}/{subtype}")

    gcs_slide_path = f"gs://{gcs_bucket}/Tiles/{split}/{group}/{subtype}/{slide_id}/"
    local_slide_dir = Path("/content/tiles") / split / group / subtype / slide_id
    # Make sure the destination directory exists so that `cp -r .../output <dir>`
    # always copies into <dir>/output.
    local_slide_dir.mkdir(parents=True, exist_ok=True)

    # ⚡ 1. Download output/ + dataset.csv from GCS
    print(f"🔽 Downloading {gcs_slide_path}dataset.csv to {local_slide_dir}")
    subprocess.run(
        ["gsutil", "cp", f"{gcs_slide_path}dataset.csv", str(local_slide_dir / "dataset.csv")],
        check=True,
    )

    print(f"🔽 Downloading {gcs_slide_path}output/ to {local_slide_dir}")
    subprocess.run(
        ["gsutil", "-m", "cp", "-r", f"{gcs_slide_path}output", str(local_slide_dir)],
        check=True,
    )

    tile_dir = local_slide_dir / "output"
    dataset_csv_path = local_slide_dir / "dataset.csv"

    if not tile_dir.exists() or not dataset_csv_path.exists():
        print(f"⚠️ Missing tiles or dataset.csv for {slide_id}. Skipping.")
        return

    tile_paths = sorted(str(p) for p in tile_dir.glob("*.png"))
    if not tile_paths:
        print(f"⚠️ No tiles found for {slide_id}. Skipping.")
        return

    print(f"📈 Found {len(tile_paths)} tiles.")

    # ⚡ 2. Encode tiles
    all_embeddings, all_coords = _encode_tiles(
        tile_paths,
        desc=f"🚀 Encoding {slide_id}",
        batch_size=batch_size,
        num_workers=num_workers,
    )
    print(f"✅ Finished encoding {slide_id}: {all_embeddings.shape}")

    # ⚡ 3. Save locally
    output_save_dir = Path("/content/embeddings") / split / group / subtype / slide_id
    output_save_dir.mkdir(parents=True, exist_ok=True)
    embeddings_path = output_save_dir / "tile_embeddings.pt"

    torch.save({"embeddings": all_embeddings, "coords": all_coords}, embeddings_path)
    print(f"💾 Saved embeddings to {embeddings_path}")

    # ⚡ 4. Upload embeddings to GCS
    gcs_target = f"gs://{gcs_bucket}/Embeddings/{split}/{group}/{subtype}/{slide_id}"
    print(f"🚀 Uploading to {gcs_target}")
    # Upload to an explicit object name. `gsutil cp -r <dir> gs://.../<slide_id>`
    # nests the directory (.../<slide_id>/<slide_id>/...) when the destination
    # prefix already exists, e.g. when a slide is re-processed.
    subprocess.run(
        ["gsutil", "cp", str(embeddings_path), f"{gcs_target}/tile_embeddings.pt"],
        check=True,
    )

    print(f"✅ Upload complete for {slide_id}")

# --------------------------------------------
# 4. Process All Slides Under a Split
# --------------------------------------------
def _list_gcs_subdirs(gcs_prefix):
    """Return the names of the immediate sub-prefixes ("folders") under ``gcs_prefix``.

    Runs ``gsutil ls <gcs_prefix>`` and keeps only entries ending in ``/``. Plain
    objects are skipped. So is an entry for ``gcs_prefix`` itself, which can appear
    when a zero-byte folder placeholder object exists.

    Raises:
        subprocess.CalledProcessError: if ``gsutil ls`` fails.
    """
    listing = subprocess.check_output(["gsutil", "ls", gcs_prefix]).decode("utf-8")
    own_prefix = gcs_prefix.rstrip("/")
    names = []
    for line in listing.splitlines():
        entry = line.strip()
        if not entry.endswith("/"):
            continue
        entry = entry.rstrip("/")
        if entry == own_prefix:
            continue
        names.append(entry.rsplit("/", 1)[-1])
    return names


def process_all_slides(split="train", gcs_bucket="bracs-dataset-bucket"):
    """Run ``process_one_slide`` for every slide under ``gs://<gcs_bucket>/Tiles/<split>/``.

    Walks the ``<group>/<subtype>/<slide_id>/`` hierarchy with ``gsutil ls``. A failure
    on one slide is printed and skipped, so the remaining slides still run. Failed
    slides are summarised at the end. Listing failures are not caught and abort the run.

    Returns:
        None.
    """
    base_gcs_path = f"gs://{gcs_bucket}/Tiles/{split}/"

    # 1. List groups (Group_AT, Group_BT, etc.)
    groups = _list_gcs_subdirs(base_gcs_path)
    print(f"📂 Found groups: {groups}")

    failed = []
    for group in groups:
        print(f"📂 Checking group: {group}")
        group_gcs_path = f"{base_gcs_path}{group}/"

        # 2. List subtypes
        for subtype in _list_gcs_subdirs(group_gcs_path):
            print(f"📂 Checking subtype: {subtype}")
            subtype_gcs_path = f"{group_gcs_path}{subtype}/"

            # 3. List WSIs
            for slide_id in _list_gcs_subdirs(subtype_gcs_path):
                try:
                    process_one_slide(
                        split=split,
                        group=group,
                        subtype=subtype,
                        slide_id=slide_id,
                        gcs_bucket=gcs_bucket,
                    )
                except Exception as e:
                    # Best-effort: one bad slide must not stop the whole split.
                    print(f"❌ Error processing {slide_id}: {e}")
                    failed.append(f"{group}/{subtype}/{slide_id}")

    if failed:
        print(f"⚠️ {len(failed)} slide(s) failed: {failed}")


In [43]:
# Encode every slide of the training split stored in the BRACS bucket.
process_all_slides("train", "bracs-dataset-bucket")


📂 Found groups: ['Group_AT']
📂 Checking group: Group_AT
📂 Checking subtype: Type_ADH
🔍 Processing WSI: BRACS_1379 under Group_AT/Type_ADH
🔽 Downloading gs://bracs-dataset-bucket/Tiles/train/Group_AT/Type_ADH/BRACS_1379/dataset.csv to /content/tiles/train/Group_AT/Type_ADH/BRACS_1379
🔽 Downloading gs://bracs-dataset-bucket/Tiles/train/Group_AT/Type_ADH/BRACS_1379/output/ to /content/tiles/train/Group_AT/Type_ADH/BRACS_1379
📈 Found 619 tiles.


🚀 Encoding BRACS_1379: 100%|██████████| 5/5 [00:16<00:00,  3.28s/it]


✅ Finished encoding BRACS_1379: torch.Size([619, 1536])
💾 Saved embeddings to /content/embeddings/train/Group_AT/Type_ADH/BRACS_1379/tile_embeddings.pt
🚀 Uploading to gs://bracs-dataset-bucket/Embeddings/train/Group_AT/Type_ADH/BRACS_1379
✅ Upload complete for BRACS_1379
🔍 Processing WSI: BRACS_1892 under Group_AT/Type_ADH
🔽 Downloading gs://bracs-dataset-bucket/Tiles/train/Group_AT/Type_ADH/BRACS_1892/dataset.csv to /content/tiles/train/Group_AT/Type_ADH/BRACS_1892
🔽 Downloading gs://bracs-dataset-bucket/Tiles/train/Group_AT/Type_ADH/BRACS_1892/output/ to /content/tiles/train/Group_AT/Type_ADH/BRACS_1892
📈 Found 937 tiles.


🚀 Encoding BRACS_1892: 100%|██████████| 8/8 [00:24<00:00,  3.07s/it]


✅ Finished encoding BRACS_1892: torch.Size([937, 1536])
💾 Saved embeddings to /content/embeddings/train/Group_AT/Type_ADH/BRACS_1892/tile_embeddings.pt
🚀 Uploading to gs://bracs-dataset-bucket/Embeddings/train/Group_AT/Type_ADH/BRACS_1892
✅ Upload complete for BRACS_1892


In [42]:
import torch

# Sanity-check one slide by reloading the embeddings saved locally by process_one_slide.
# The file holds only a dict of tensors, so weights_only=True suffices. It restricts
# unpickling to tensors and plain containers instead of arbitrary Python objects.
data = torch.load(
    "/content/embeddings/train/Group_AT/Type_ADH/BRACS_1892/tile_embeddings.pt",
    map_location="cpu",
    weights_only=True,
)

# Extract embeddings and coordinates
tile_embeds = data["embeddings"]
tile_coords = data["coords"]

# Every embedding row must have a matching (x, y) coordinate row.
assert tile_embeds.shape[0] == tile_coords.shape[0], "embeddings/coords row count mismatch"

print("✅ Tile embeddings shape:", tile_embeds.shape)
print("✅ Tile coords shape:", tile_coords.shape)

✅ Tile embeddings shape: torch.Size([937, 1536])
✅ Tile coords shape: torch.Size([937, 2])
