<div>
<img src="https://github.com/fsemerar/SlicerTomoSAM/raw/main/TomoSAM/Resources/Media/tomosam_logo.png" width="300"/>
</div>

This notebook generates the SAM image embeddings for every slice of your TIFF stack along the three Cartesian directions. You can run it either locally or on Colab. A GPU is recommended to speed up this step; in Colab, select `Runtime`→`Change runtime type` and set the `Hardware accelerator` to GPU. To run it locally, first create the conda environment as described in the README.
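To get a feel for the cost of this step, here is a minimal sketch of how one embedding is produced per 2D slice along each axis, and roughly how much memory the results take. It assumes the stack is held as a NumPy array of shape `(pages, rows, cols)` and that each stored feature tensor is SAM's `1×256×64×64` float32 image embedding; `estimate_embeddings` is a hypothetical helper, not part of TomoSAM.

```python
import numpy as np

# Assumed shape of one stored SAM image embedding: (batch, channels, height, width)
FEATURE_SHAPE = (1, 256, 64, 64)
BYTES_PER_FLOAT32 = 4


def estimate_embeddings(volume_shape):
    """Return (slices per axis, total embeddings, approximate bytes) for a 3D stack.

    Hypothetical helper: one embedding per 2D slice along each of the three axes,
    each stored as a float32 tensor of FEATURE_SHAPE.
    """
    per_axis = list(volume_shape)  # axis i contributes volume_shape[i] slices
    total = sum(per_axis)
    bytes_per_embedding = int(np.prod(FEATURE_SHAPE)) * BYTES_PER_FLOAT32
    return per_axis, total, total * bytes_per_embedding


# Slice a toy volume along each axis in turn
volume = np.zeros((10, 20, 30), dtype=np.uint8)
slices = [[volume[k] for k in range(volume.shape[0])],
          [volume[:, k] for k in range(volume.shape[1])],
          [volume[:, :, k] for k in range(volume.shape[2])]]
print([len(s) for s in slices])           # [10, 20, 30]
print([s[0].shape for s in slices])       # [(20, 30), (10, 30), (10, 20)]
print(estimate_embeddings(volume.shape))  # ([10, 20, 30], 60, 251658240)
```

Under these assumptions each embedding is 4 MiB, so a 512×512×512 stack yields 1,536 embeddings, or about 6 GiB of features held in memory and written to the output file.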

In [1]:
# Colab setup: enable custom widgets and install Segment Anything at a pinned commit
from google.colab import output
output.enable_custom_widget_manager()
!pip install https://github.com/facebookresearch/segment-anything/archive/6fdee8f2727f4506cfbbe553e23b895e27956588.zip

Collecting https://github.com/facebookresearch/segment-anything/archive/6fdee8f2727f4506cfbbe553e23b895e27956588.zip
  Downloading https://github.com/facebookresearch/segment-anything/archive/6fdee8f2727f4506cfbbe553e23b895e27956588.zip
[2K     [32m\[0m [32m19.2 MB[0m [31m29.8 MB/s[0m [33m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=7687172a456e93c8dd4c5259cce4d3306d3443cc224256b730b1ebf29177f8be
  Stored in directory: /tmp/pip-ephem-wheel-cache-vfq0h1p4/wheels/f7/85/24/6c615ef5d04ea1f87f1a717cd18e472ef44962e9ad28d07b69
Successfully built segment_anything
Installing collected packages: segment_anything
Successfully installed segment_anything-1.0


In [2]:
# Download weights for SAM
![ ! -f "sam_vit_h_4b8939.pth" ] && wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2025-02-18 08:14:13--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.226.210.15, 13.226.210.78, 13.226.210.111, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.226.210.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2564550879 (2.4G) [binary/octet-stream]
Saving to: ‘sam_vit_h_4b8939.pth’


2025-02-18 08:14:32 (129 MB/s) - ‘sam_vit_h_4b8939.pth’ saved [2564550879/2564550879]
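Because the download cell skips `wget` whenever `sam_vit_h_4b8939.pth` already exists, an interrupted download would leave a truncated checkpoint that is never replaced. A small sketch that compares the file size with the `Length` reported by the server in the log above (2564550879 bytes); `checkpoint_is_complete` is a hypothetical helper.

```python
import os

# Size reported by the server for sam_vit_h_4b8939.pth in the wget log
EXPECTED_BYTES = 2564550879


def checkpoint_is_complete(path, expected_bytes=EXPECTED_BYTES):
    """Return True if `path` is an existing file of exactly `expected_bytes` bytes."""
    return os.path.isfile(path) and os.path.getsize(path) == expected_bytes


if not checkpoint_is_complete("sam_vit_h_4b8939.pth"):
    print("Checkpoint missing or incomplete; delete it and rerun the download cell.")
```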



In [3]:
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
import torch
import sys, os
import pickle


def create_embeddings(img_input_filepath, output_filepath, sam_checkpoint_path):
    """Compute SAM embeddings for every slice of a TIFF stack along all three axes
    and pickle them to `output_filepath + ".pkl"`."""
    # Read all pages of the (multi-page) TIFF; check success before using the result
    check, img = cv2.imreadmulti(img_input_filepath)
    if not check:
        raise FileNotFoundError(f"Could not read image file: {img_input_filepath}")
    img = np.array(img)
    if img.ndim > 3 or img.ndim < 2:
        raise ValueError(f"Unsupported image shape: {img.shape}")
    elif img.ndim == 2:
        img = img[:, :, np.newaxis]

    print(f"Image dimensions: {img.shape}")

    # Load the ViT-H SAM model and move it to the GPU when one is available
    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint_path)
    if torch.cuda.is_available():
        sam.to(device="cuda")
    predictor = SamPredictor(sam)

    # One list of per-slice embeddings for each slicing direction
    embeddings = [[], [], []]
    slice_direction = ['x', 'y', 'z']
    for i, d in enumerate(slice_direction):
        print(f"\nSlicing along {d} direction")
        for k in range(img.shape[i]):
            # Take the k-th 2D slice perpendicular to axis i
            if i == 0:
                img_slice = img[k]
            elif i == 1:
                img_slice = img[:, k]
            else:
                img_slice = img[:, :, k]
            sys.stdout.write(f"\rCreating embedding for slice {k + 1}/{img.shape[i]}")
            predictor.reset_image()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # SAM expects a 3-channel image, so replicate the grayscale slice
            predictor.set_image(np.repeat(img_slice[:, :, np.newaxis], 3, axis=2))
            # Store the sizes and the image features for this slice (features moved to CPU memory)
            embeddings[i].append({'original_size': predictor.original_size,
                                  'input_size': predictor.input_size,
                                  'features': predictor.features.to('cpu')})

    with open(output_filepath + ".pkl", 'wb') as f:
        pickle.dump(embeddings, f)
    print(f"\nSaved {output_filepath}.pkl")
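`SamPredictor.set_image` documents its input as an HWC `uint8` image with values in [0, 255], while tomography TIFF stacks are often 16-bit or floating point. One possible preprocessing step, sketched here as an assumption rather than something TomoSAM does, is to min-max rescale the whole volume to `uint8` once, before slicing; `volume_to_uint8` is a hypothetical helper.

```python
import numpy as np


def volume_to_uint8(volume):
    """Linearly rescale a volume of any numeric dtype to uint8 in [0, 255].

    Hypothetical preprocessing helper: uses the global min/max of the whole
    stack so that all slices share one intensity mapping.
    """
    if volume.dtype == np.uint8:
        return volume  # already in the range SAM expects
    vol = volume.astype(np.float32)
    lo, hi = float(vol.min()), float(vol.max())
    if hi == lo:  # constant volume: avoid division by zero
        return np.zeros(volume.shape, dtype=np.uint8)
    scaled = (vol - lo) / (hi - lo) * 255.0
    return np.round(scaled).astype(np.uint8)


stack16 = np.array([[[0, 1000], [2000, 4000]]], dtype=np.uint16)
print(volume_to_uint8(stack16))
```

Using the global range instead of a per-slice range keeps a given gray value mapped to the same intensity in every slice and direction. Applying it would amount to adding `img = volume_to_uint8(img)` right after the image is loaded in `create_embeddings`.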

In [None]:
# for Colab use: upload your TIFF stack
from google.colab import files
img_filename = list(files.upload().keys())[0]

In [None]:
# for local use: set the path to your TIFF stack
img_filename = ""

In [None]:
create_embeddings(img_filename, os.path.splitext(img_filename)[0], "sam_vit_h_4b8939.pth")
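Once `create_embeddings` has run, the `.pkl` file holds a list of three lists (one per slicing direction), each containing one dict per slice with the keys `original_size`, `input_size` and `features`. A minimal sketch for checking that structure after reloading; `summarize_embeddings` is a hypothetical helper. Unpickling the `features` tensors requires PyTorch, and `pickle` should only be used on files you trust.

```python
import pickle


def summarize_embeddings(pkl_path):
    """Return (slices stored per direction, original_size and input_size of the first x slice).

    Hypothetical inspection helper for the file written by create_embeddings.
    """
    with open(pkl_path, "rb") as f:
        embeddings = pickle.load(f)  # [x_list, y_list, z_list], each a list of dicts
    counts = {d: len(e) for d, e in zip("xyz", embeddings)}
    first = embeddings[0][0]
    return counts, first["original_size"], first["input_size"]
```

For example, `summarize_embeddings(os.path.splitext(img_filename)[0] + ".pkl")` should report as many slices per direction as the stack has along each axis.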

In [None]:
# Download from Colab (the .pkl name drops the image extension)
files.download(os.path.splitext(img_filename)[0] + ".pkl")
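Since `create_embeddings` is called with `os.path.splitext(img_filename)[0]` as its output prefix, the file it writes drops the image extension (for example, `stack.tif` becomes `stack.pkl`). A small hypothetical helper that keeps the saving and downloading steps in agreement:

```python
import os


def embeddings_path(img_filename):
    """Path of the .pkl written for `img_filename`: same name, extension replaced by .pkl."""
    return os.path.splitext(img_filename)[0] + ".pkl"


print(embeddings_path("stack.tif"))            # stack.pkl
print(embeddings_path("data/volume.v2.tiff"))  # data/volume.v2.pkl
```

With it, `files.download(embeddings_path(img_filename))` fetches the file that was actually saved. Only the last extension is removed, so dotted names like `volume.v2.tiff` keep their inner dots.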