In [31]:
import os
import sys
import pickle

import numpy as np
import torch
from torch.multiprocessing import set_start_method
from torch.utils.data import DataLoader, DistributedSampler

# 3DETR codebase specific imports
from datasets import build_dataset_base
from engine import evaluate, train_one_epoch, evaluate_incremental
from models import build_model
from optimizer import build_optimizer
from criterion import build_criterion
from utils.dist import init_distributed, is_distributed, is_primary, get_rank, barrier
from utils.misc import my_worker_init_fn
from utils.io import save_checkpoint, resume_if_possible, resume_if_possible_SDCoT
from utils.logger import Logger
from utils.ap_calculator import APCalculator, get_ap_config_dict, parse_predictions_SDCoT, get_ap_config_dict

In [32]:
def test_model(args, model, model_no_ddp, criterion, dataset_config, dataloaders):
    if args.test_ckpt is None or not os.path.isfile(args.test_ckpt):
        f"Please specify a test checkpoint using --test_ckpt. Found invalid value {args.test_ckpt}"
        sys.exit(1)

    sd = torch.load(args.test_ckpt, map_location=torch.device("cpu"))
    model_no_ddp.load_state_dict(sd["model"])
    logger = Logger()
    criterion = None  # do not compute loss for speed-up; Comment out to see test loss
    epoch = -1
    curr_iter = 0
    ap_calculator = evaluate(
        args,
        epoch,
        model,
        criterion,
        dataset_config,
        dataloaders["test"],
        logger,
        curr_iter,
    )
    metrics = ap_calculator.compute_metrics()
    metric_str = ap_calculator.metrics_to_str(metrics)
    if is_primary():
        print("==" * 10)
        print(f"Test model; Metrics {metric_str}")
        print("==" * 10)

In [33]:
class TempArgs:
    """Stand-in for the argparse namespace that the 3DETR builders expect.

    Holds the dataset, model, criterion and data-loading settings used by this
    notebook. Any setting can be overridden by keyword, e.g.
    ``TempArgs(batchsize_per_gpu=8)``. An unknown keyword raises ``TypeError``,
    so a typo cannot silently create an unused attribute.
    """

    def __init__(self, **overrides) -> None:
        # dataset
        self.dataset_name = 'scannet'
        self.num_base_class = 9
        self.num_novel_class = 9
        self.dataset_root_dir = None
        self.meta_data_dir = None
        self.use_color = False
        self.seed = 42
        self.test_ckpt = 'ckpts_scannet/debug_test_notebook/checkpoint_best_6480.pth'

        # model / encoder
        self.enc_dim = 256
        self.nqueries = 256
        self.mlp_dropout = 0.3
        self.model_name = '3detr'
        self.preenc_npoints = 2048
        self.enc_type = 'masked'
        self.enc_nhead = 4
        self.enc_ffn_dim = 128
        self.enc_dropout = 0.1
        self.enc_activation = 'relu'
        self.enc_nlayers = 3

        # define for the decoder
        # dec_dim is set once: an earlier 512 assignment was always overwritten
        # by this value, so 256 is what the model was actually built with.
        self.dec_dim = 256
        self.dec_nhead = 4
        self.dec_ffn_dim = 256
        self.dec_dropout = 0.1
        self.dec_nlayers = 8

        # criterion
        self.matcher_cls_cost = 1
        self.matcher_giou_cost = 2

        # data loading / logging
        self.batchsize_per_gpu = 16
        self.dataset_num_workers = 0
        self.log_every = 10

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"TempArgs got an unexpected setting {name!r}")
            setattr(self, name, value)

# Instantiate the notebook config and echo one field as a quick sanity check.
_args = TempArgs()
cls_cost = _args.matcher_cls_cost
print(cls_cost)

1


In [34]:
# Seed every RNG from the config so runs are reproducible.
np.random.seed(_args.seed)
torch.manual_seed(_args.seed)
if torch.cuda.is_available():
    # Select the GPU only when one exists; set_device is not meaningful (and
    # may raise) on a CPU-only machine.
    torch.cuda.set_device(0)
    torch.cuda.manual_seed_all(_args.seed)

datasets, dataset_config = build_dataset_base(_args)

# Build the network in this cell so that a fresh kernel defines `model` and
# `model_no_ddp` before test_model() is called. Trained weights are loaded
# inside test_model() from _args.test_ckpt, so no resume step is needed here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, _ = build_model(_args, dataset_config)
model = model.to(device)
model_no_ddp = model  # single-process notebook: no DDP wrapper

# NOTE(review): test_model() obtains its own AP calculator from evaluate();
# this instance is kept only in case later cells use it.
ap_calculator = APCalculator(dataset_config=dataset_config,
        ap_iou_thresh=[0.25, 0.5],
        class2type_map=dataset_config.class2type,
        exact_eval=True)

# Deterministic, in-order test loader (no shuffling, no distributed sampler).
dataloaders = {}
split = "test"
sampler = torch.utils.data.SequentialSampler(datasets[split])

dataloaders[split] = DataLoader(
    datasets[split],
    sampler=sampler,
    batch_size=_args.batchsize_per_gpu,
    num_workers=_args.dataset_num_workers,
    worker_init_fn=my_worker_init_fn,
)
dataloaders[split + "_sampler"] = sampler

criterion = None  # faster evaluation

kept 1199 scans out of 1201
kept 312 scans out of 312


In [None]:
# Set class threshold

In [35]:
# Evaluate the checkpoint at _args.test_ckpt on the test split and print metrics.
test_model(
    args=_args,
    model=model,
    model_no_ddp=model_no_ddp,
    criterion=criterion,
    dataset_config=dataset_config,
    dataloaders=dataloaders,
)

Evaluate ; Batch [0/20];  Iter time 4.91; Mem 5217.60MB
Evaluate ; Batch [10/20];  Iter time 4.90; Mem 5225.33MB
Test model; Metrics mAP0.25, mAP0.50: 65.04, 45.69
AR0.25, AR0.50: 77.99, 58.35
-----
IOU Thresh=0.25
bathtub Average Precision: 89.65
bed Average Precision: 79.93
bookshelf Average Precision: 51.00
cabinet Average Precision: 47.01
chair Average Precision: 89.00
counter Average Precision: 57.35
curtain Average Precision: 51.37
desk Average Precision: 71.77
door Average Precision: 48.30
bathtub Recall: 90.32
bed Recall: 88.89
bookshelf Recall: 70.13
cabinet Recall: 69.09
chair Recall: 91.23
counter Recall: 80.77
curtain Recall: 64.18
desk Recall: 85.83
door Recall: 61.46
-----
IOU Thresh=0.5
bathtub Average Precision: 75.88
bed Average Precision: 73.45
bookshelf Average Precision: 43.81
cabinet Average Precision: 18.03
chair Average Precision: 74.53
counter Average Precision: 20.94
curtain Average Precision: 29.17
desk Average Precision: 46.48
door Average Precision: 28.96
ba