In [None]:
import os

import torch
from torch import nn

# Work from two levels above the notebook (presumably the repo root) so that the
# `scripts.*` imports below resolve. The original directory is remembered on the
# first run so re-executing this cell is idempotent instead of climbing two more
# levels every time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))

# Required configuration; a missing variable raises KeyError naming it.
DATA_DIR = os.environ["DATA_DIR"]
OUTPUT_DIR = os.environ["OUTPUT_DIR"]

# Restrict PyTorch to the listed GPUs. This only takes effect if CUDA has not yet
# been initialised in this process, so it must run before any CUDA call.
gpus = ["0"]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpus)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def count_trainable_params(model: nn.Module) -> int:
    """Count the scalar elements in every parameter of ``model`` that requires grad.

    Frozen parameters (``requires_grad=False``) are excluded; a module with no
    trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [None]:
def get_model_size(model: nn.Module) -> float:
    """Return the in-memory size of a module's parameters plus buffers.

    Each tensor contributes ``nelement() * element_size()`` bytes. The total is
    divided by 1024**2, so the unit is mebibytes (MiB).
    """

    def _nbytes(tensors) -> int:
        # Raw storage footprint of an iterable of tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / (1024 * 1024)

In [None]:
# Benchmark input: every PNG tile extracted from one sample slide.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so the
# paths are sorted to make tile (and therefore batch) order reproducible.
sample_slide = os.path.join(DATA_DIR, "tiles/output/660060-1.svs")
tile_paths = sorted(
    os.path.join(sample_slide, fname)
    for fname in os.listdir(sample_slide)
    if fname.endswith(".png")
)
num_tiles = len(tile_paths)
# The per-1000-tile runtime columns divide by num_tiles; fail early and clearly.
assert num_tiles > 0, f"no .png tiles found in {sample_slide}"

In [None]:
# One slot per benchmarked model; each is replaced by a stats dict once that model has run.
model_stats = dict.fromkeys(("UNI", "gigapath", "prism"))

In [None]:
# UNI
import time
from scripts.uni_embed import load_model, run_inference

# CUDA kernels are launched asynchronously, so synchronize before starting and
# after stopping the clock; otherwise `elapsed` can miss GPU work that is still
# queued when run_inference returns.
# NOTE(review): this is the first model timed, so it may also absorb one-off
# costs (cuBLAS/cuDNN warm-up, presumably cold disk reads of the tiles) that the
# later models avoid — consider a warm-up pass if the comparison must be fair.
model, transform = load_model(device)
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
inf = run_inference(tile_paths, model, transform, 128, device)
if device.type == "cuda":
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start

In [None]:
# Summarise the UNI run: size, embedding width and inference time per 1000 tiles.
# NOTE(review): "params" counts only parameters with requires_grad=True; if the
# loader freezes the weights this reads 0 — confirm against scripts.uni_embed.
model_stats["UNI"] = dict(
    architecture="ViT large, patch size 16",
    params=count_trainable_params(model),
    model_size="{:.2f}MB".format(get_model_size(model)),
    runtime="{:.4f} sec/k tiles".format(elapsed / (num_tiles / 1000)),
    embed_dim=inf["tile_embeds"].shape[-1],
)

In [None]:
# PRISM
from scripts.prism_embed import load_model, run_inference

# Drop references to the previous model and its outputs so that memory can be
# freed before the next model is loaded, instead of holding both at once.
model = transform = inf = None

# NOTE(review): unlike the UNI loader, load_model() receives no device here —
# presumably run_inference moves the model to `device`; confirm.
model, transform = load_model()
if device.type == "cuda":
    torch.cuda.synchronize()  # don't bill earlier queued GPU work to this run
start = time.perf_counter()
inf = run_inference(tile_paths, model, transform, 128, device)
if device.type == "cuda":
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
elapsed = time.perf_counter() - start

In [None]:
# Summarise the PRISM run: size, embedding width and inference time per 1000 tiles.
# NOTE(review): "params" counts only parameters with requires_grad=True; if the
# loader freezes the weights this reads 0 — confirm against scripts.prism_embed.
model_stats["prism"] = dict(
    architecture="ViT huge, patch size 14",
    params=count_trainable_params(model),
    model_size="{:.2f}MB".format(get_model_size(model)),
    runtime="{:.4f} sec/k tiles".format(elapsed / (num_tiles / 1000)),
    embed_dim=inf["tile_embeds"].shape[-1],
)

In [None]:
# prov-gigapath
from scripts.gigapath_embed import load_tile_encoder, run_inference

# Drop references to the previous model and its outputs so that memory can be
# freed before this (largest) encoder is loaded, instead of holding both at once.
model = transform = inf = None

model, transform = load_tile_encoder()
if device.type == "cuda":
    torch.cuda.synchronize()  # don't bill earlier queued GPU work to this run
start = time.perf_counter()
inf = run_inference(tile_paths, model, transform, 128, device)
if device.type == "cuda":
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
elapsed = time.perf_counter() - start

In [None]:
# Summarise the prov-gigapath run: size, embedding width and time per 1000 tiles.
# NOTE(review): "params" counts only parameters with requires_grad=True; if the
# loader freezes the weights this reads 0 — confirm against scripts.gigapath_embed.
model_stats["gigapath"] = dict(
    architecture="ViT giant, patch size 14",
    params=count_trainable_params(model),
    model_size="{:.2f}MB".format(get_model_size(model)),
    runtime="{:.4f} sec/k tiles".format(elapsed / (num_tiles / 1000)),
    embed_dim=inf["tile_embeds"].shape[-1],
)

In [None]:
import pandas as pd

# Comparison table: one column per model (outer keys), one row per statistic.
df = pd.DataFrame.from_dict(model_stats, orient="columns")
df