<a href="https://colab.research.google.com/github/qinliuliuqin/iSegFormer/blob/main/notebooks/colab_test_isegformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone iSegFormer and install dependencies.
# It may took 16m to run this cell.
!git clone https://github.com/qinliuliuqin/iSegFormer.git

import os
os.chdir('/content/iSegFormer')
!pip install -r ./requirements.txt

In [6]:
# Download OAI-ZIB test set
URL_ISEGFORMER = "https://github.com/qinliuliuqin/iSegFormer/releases/download/v0.1"
DATA_FOLDER = "./datasets" 

!mkdir -p {DATA_FOLDER}

for dataset in ['OAI-ZIB-test']:
  dataset_url = f"{URL_ISEGFORMER}/{dataset}.zip"
  dataset_path = f"{DATA_FOLDER}/{dataset}.zip"
  !wget -q -O {dataset_path} {dataset_url}
  !unzip -q {dataset_path} -d {DATA_FOLDER}
  !rm {dataset_path}

# Download weights 
WEIGHTS_FOLDER = "./weights"
!mkdir -p {WEIGHTS_FOLDER}

MODEL_NAME_SWINB = "imagenet21k_pretrain_oaizib_finetune_swin_base_epoch_54"
MODEL_NAME_HR32 = "oai_pretrain_oaizib_finetune_hr32_epoch_54"

WEIGHTS_URL_SWINB = f"{URL_ISEGFORMER}/{MODEL_NAME_SWINB}.pth"
WEIGHTS_URL_HR32 = f"{URL_ISEGFORMER}/{MODEL_NAME_HR32}.pth"
!wget -q -P {WEIGHTS_FOLDER} {WEIGHTS_URL_SWINB}
!wget -q -P {WEIGHTS_FOLDER} {WEIGHTS_URL_HR32}

In [18]:
%matplotlib inline
import matplotlib.pyplot as plt

import sys
import torch
import numpy as np
from google.colab import drive

sys.path.insert(0, './')

from isegm.utils import vis, exp
from isegm.inference import utils
from isegm.inference.evaluation import evaluate_dataset, evaluate_sample
from isegm.inference.predictors import get_predictor

# Use the first GPU when one is available; otherwise fall back to the CPU so the
# cell does not crash on a runtime without an accelerator.
# NOTE(review): CPU evaluation is presumably much slower — confirm it is usable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cfg = exp.load_config_file('./config.yml', return_edict=True)

# Evaluation settings.
EVAL_MAX_CLICKS = 20   # passed as `max_clicks` to the evaluation and NoC helpers
MODEL_THRESH = 0.49    # passed as `prob_thresh` (predictor) and `pred_thr` (evaluation)
brs_mode = 'NoBRS'     # predictor mode; also shown in the "BRS Type" table column
TARGET_IOU = 0.9       # passed as `max_iou_thr` to evaluate_dataset

# Locate and load both checkpoints downloaded by the previous cell. The folder
# comes from WEIGHTS_FOLDER so the download and load steps cannot drift apart.
model_path_swinb = utils.find_checkpoint(WEIGHTS_FOLDER, f"{MODEL_NAME_SWINB}.pth")
model_path_hr32 = utils.find_checkpoint(WEIGHTS_FOLDER, f"{MODEL_NAME_HR32}.pth")
model_swinb = utils.load_is_model(model_path_swinb, device)
model_hr32 = utils.load_is_model(model_path_hr32, device)

DATASET = 'OAIZIB'
dataset = utils.get_dataset(DATASET, cfg)

# Evaluate each model on the same dataset and print one results table per model.
# The tables appear in list order: the Swin-B model first, then the HR32 model.
for model in [model_swinb, model_hr32]:
  predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)

  all_ious, elapsed_time = evaluate_dataset(dataset, predictor, pred_thr=MODEL_THRESH,
                                            max_iou_thr=TARGET_IOU, max_clicks=EVAL_MAX_CLICKS)
  # mean_spi is returned alongside mean_spc but is not used in the table below.
  mean_spc, mean_spi = utils.get_time_metrics(all_ious, elapsed_time)
  # The IoU thresholds correspond to the NoC@80% / 85% / 90% table columns.
  noc_list, over_max_list = utils.compute_noc_metric(all_ious,
                                                    iou_thrs=[0.8, 0.85, 0.9],
                                                    max_clicks=EVAL_MAX_CLICKS)

  header, table_row = utils.get_results_table(noc_list, over_max_list, brs_mode, DATASET,
                                              mean_spc, elapsed_time, EVAL_MAX_CLICKS)
  print(header)
  print(table_row)

  0%|          | 0/150 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------------
|  BRS Type   |  Dataset  | NoC@80% | NoC@85% | NoC@90% |>=20@85% |>=20@90% | SPC,s |  Time   |
-----------------------------------------------------------------------------------------------
|    NoBRS    |  OAIZIB   |  7.25   |  11.65  |  17.03  |   68    |   115   | 0.180 | 0:07:40 |


  0%|          | 0/150 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------------
|  BRS Type   |  Dataset  | NoC@80% | NoC@85% | NoC@90% |>=20@85% |>=20@90% | SPC,s |  Time   |
-----------------------------------------------------------------------------------------------
|    NoBRS    |  OAIZIB   |  7.93   |  12.81  |  19.36  |   79    |   142   | 0.147 | 0:07:06 |
