### SAT latency for the SegNext model (sax1)

In [3]:
import sys
sys.path.insert(0, '../segnext')  # make the local `isegm` package importable
from time import time
import torch

from isegm.inference import utils
from isegm.data.datasets import DavisDataset
from isegm.inference.predictor import BasePredictor

# Benchmark: wall-clock time of one `set_image` + `predict_sat` pass on
# sample 0 of the DAVIS dataset, using the SegNext checkpoint loaded below.
device = torch.device('cuda:0')
model = utils.load_is_model(
    checkpoint='../weights/vitb_sa1_cocolvis_epoch_90.pth',
    device=device)
# NOTE(review): presumably keeps the click distance maps on the CPU -- confirm in isegm.
model.cpu_dist_maps = True
predictor = BasePredictor(model)

dataset_path = '../data/DAVIS345/'
dataset = DavisDataset(dataset_path)
sample = dataset.get_sample(0)
image = sample.image

# CUDA work is queued asynchronously, so the host clock is only meaningful
# while the device is idle. Drain anything still pending (e.g. from model
# loading) before the timer starts ...
torch.cuda.synchronize(device)
start_time = time()
with torch.no_grad():
    predictor.set_image(image)
    preds = predictor.predict_sat(points_per_side=16)
# ... and wait for every kernel launched by the prediction before it stops.
torch.cuda.synchronize(device)
end_time = time()

print(f'SAT Latency: {end_time - start_time} (s)')

SAT Latency: 13.298871278762817 (s)


### SAT latency for the SegNext model (sax2)

In [2]:
import sys
sys.path.insert(0, '../segnext')  # make the local `isegm` package importable
from time import time
import torch

from isegm.inference import utils
from isegm.data.datasets import DavisDataset
from isegm.inference.predictor import BasePredictor

# Benchmark: wall-clock time of one `set_image` + `predict_sat` pass on
# sample 0 of the DAVIS dataset, using the SegNext checkpoint loaded below.
device = torch.device('cuda:0')
model = utils.load_is_model(
    checkpoint='../weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth',
    device=device)
# NOTE(review): presumably keeps the click distance maps on the CPU -- confirm in isegm.
model.cpu_dist_maps = True
predictor = BasePredictor(model)

dataset_path = '../data/DAVIS345/'
dataset = DavisDataset(dataset_path)
sample = dataset.get_sample(0)
image = sample.image

# CUDA work is queued asynchronously, so the host clock is only meaningful
# while the device is idle. Drain anything still pending (e.g. from model
# loading) before the timer starts ...
torch.cuda.synchronize(device)
start_time = time()
with torch.no_grad():
    predictor.set_image(image)
    preds = predictor.predict_sat(points_per_side=16)
# ... and wait for every kernel launched by the prediction before it stops.
torch.cuda.synchronize(device)
end_time = time()

print(f'SAT Latency: {end_time - start_time} (s)')

SAT Latency: 17.585328340530396 (s)


### SAT latency for the SAM model

In [3]:
# Requires the `segment_anything` package (SAM) to be installed before running.

import torch
from time import time
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
import numpy as np
import sys
sys.path.insert(0, '../segnext')  # make the local `isegm` package importable

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from isegm.data.datasets import DavisDataset

# Benchmark: wall-clock time of one automatic mask generation pass on
# sample 0 of the DAVIS dataset, using the "vit_b" SAM checkpoint below.
dataset_path = '../data/DAVIS345/'
dataset = DavisDataset(dataset_path)
sample = dataset.get_sample(0)
image = sample.image

sam_path = '../weights/sam_vit_b_01ec64.pth'
sam = sam_model_registry["vit_b"](checkpoint=sam_path)
device = torch.device('cuda:0')
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,
    points_per_batch=1
)

# CUDA work is queued asynchronously, so the host clock is only meaningful
# while the device is idle. Drain anything still pending (e.g. the weight
# transfer above) before the timer starts ...
torch.cuda.synchronize(device)
start_time = time()
with torch.no_grad():
    masks = mask_generator.generate(image)
# ... and wait for every kernel launched by the generator before it stops.
torch.cuda.synchronize(device)
end_time = time()

print(f'SAT Latency: {end_time - start_time} (s)')

SAT Latency: 7.002699375152588 (s)


### SAT latency for HQ-SAM

In [4]:
# Requires the `segment_anything_hq` package (HQ-SAM) to be installed before running.

import torch
from time import time
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
import numpy as np
import sys
sys.path.insert(0, '../segnext')  # make the local `isegm` package importable

from segment_anything_hq import sam_model_registry, SamAutomaticMaskGenerator
from isegm.data.datasets import DavisDataset

# Benchmark: wall-clock time of one automatic mask generation pass on
# sample 0 of the DAVIS dataset, using the "vit_b" HQ-SAM checkpoint below.
dataset_path = '../data/DAVIS345/'
dataset = DavisDataset(dataset_path)
sample = dataset.get_sample(0)
image = sample.image

sam_path = '../weights/sam_hq_vit_b.pth'
sam = sam_model_registry["vit_b"](checkpoint=sam_path)
device = torch.device('cuda:0')
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,
    points_per_batch=1
)

# CUDA work is queued asynchronously, so the host clock is only meaningful
# while the device is idle. Drain anything still pending (e.g. the weight
# transfer above) before the timer starts ...
torch.cuda.synchronize(device)
start_time = time()
with torch.no_grad():
    masks = mask_generator.generate(image)
# ... and wait for every kernel launched by the generator before it stops.
torch.cuda.synchronize(device)
end_time = time()

print(f'SAT Latency: {end_time - start_time} (s)')

<All keys matched successfully>
SAT Latency: 8.309977293014526 (s)


### SAT latency for MobileSAM

In [5]:
# Requires the `mobile_sam` package (MobileSAM) to be installed before running.

import torch
from time import time
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
import numpy as np
import sys
sys.path.insert(0, '../segnext')  # make the local `isegm` package importable

from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator
from isegm.data.datasets import DavisDataset

# Benchmark: wall-clock time of one automatic mask generation pass on
# sample 0 of the DAVIS dataset, using the "vit_t" MobileSAM checkpoint below.
dataset_path = '../data/DAVIS345/'
dataset = DavisDataset(dataset_path)
sample = dataset.get_sample(0)
image = sample.image

sam_path = '../weights/mobile_sam.pt'
sam = sam_model_registry["vit_t"](checkpoint=sam_path)
device = torch.device('cuda:0')
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,
    points_per_batch=1
)

# CUDA work is queued asynchronously, so the host clock is only meaningful
# while the device is idle. Drain anything still pending (e.g. the weight
# transfer above) before the timer starts ...
torch.cuda.synchronize(device)
start_time = time()
with torch.no_grad():
    masks = mask_generator.generate(image)
# ... and wait for every kernel launched by the generator before it stops.
torch.cuda.synchronize(device)
end_time = time()

print(f'SAT Latency: {end_time - start_time} (s)')

SAT Latency: 6.589253902435303 (s)
