In [1]:
import os
import importlib.util
import torch.distributed as dist
import torch
from data.builder import build_dataset
from models.detectors.zid_rcnn import ZidRCNN
from scripts import dist_util
from models.utils.data_container import collate
from functools import partial
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

In [2]:
def build_detector(model_cfg):
    """Instantiate the detector described by ``model_cfg``.

    ``model_cfg`` must contain a ``'type'`` key; only ``'ZidRCNN'`` is accepted
    (anything else trips an assertion). Every other entry is forwarded to
    ``ZidRCNN`` as a keyword argument. The caller's mapping is left untouched.
    """
    detector_type = model_cfg['type']
    assert detector_type == 'ZidRCNN', f'{detector_type} is not implemented yet.'
    # Build the constructor kwargs without mutating the caller's config.
    detector_kwargs = {name: value for name, value in model_cfg.items() if name != 'type'}
    return ZidRCNN(**detector_kwargs)
    
def get_config_from_file(filename, mode):
    """Execute a Python config file and return its top-level names as a dict.

    Args:
        filename: Path to the config file to execute.
        mode: Name given to the throwaway module object created for the file.

    Returns:
        dict mapping every top-level name in the executed file that does not
        start with ``'__'`` to its value. Names with a single leading
        underscore are kept.

    Raises:
        ImportError: if no import loader can be derived for ``filename``
            (``spec_from_file_location`` returned ``None``, e.g. for an
            unrecognized file extension).
        Any exception raised while executing the file propagates unchanged.
    """
    spec = importlib.util.spec_from_file_location(mode, filename)
    if spec is None or spec.loader is None:
        # Fail with an explicit message instead of an opaque AttributeError
        # on ``None.loader``.
        raise ImportError(
            f'Cannot load config {filename!r}: no import loader for this file type.'
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Keep every non-dunder attribute of the executed module.
    return {key: getattr(module, key) for key in dir(module) if not key.startswith('__')}

In [7]:
def test_batch_processing(batch):
    batch['obj_id'] = batch['id'][0]
    batch.pop('id')
    for k, v in batch.items():
        if k == 'img_metas':
            batch[k] = batch[k][0].data
        if k == 'img':
            batch[k] = [batch[k][0].data[0].to(device, non_blocking=True)]
        
        elif k in ['rgb', 'mask', 'traj']:
            batch[k] = batch[k].data.to(device, non_blocking=True)
    return batch

In [3]:
# Load the test config as a plain dict and clear `train_cfg`, since this
# notebook only runs inference (presumably training-only settings; verify).
# Direct indexing gives a clear KeyError if 'model' is missing, instead of a
# TypeError on None.
cfg = get_config_from_file('configs/test_conf.py', 'detection')
cfg['model']['train_cfg'] = None

In [4]:
# Build the evaluation dataset from the config's `data['test']` section.
# Direct indexing surfaces a missing section as a KeyError.
dataset = build_dataset(cfg['data']['test'])

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [5]:
# Build the detector, load trained weights on CPU, then move it to the
# evaluation device and switch to eval mode.
# The checkpoint path and device can be overridden through environment
# variables. The defaults are the values this notebook was originally run with.
model = build_detector(cfg['model'])
model_path = os.environ.get(
    'VOXDET_CKPT',
    '/home/minhnh/project_drive/CV/FewshotObjectDetection/outputs/VoxDet_p2_1/iter_56251.pth',
)
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(
    torch.load(model_path, map_location="cpu")['state_dict']
)
model.CLASSES = dataset.CLASSES
# `device` is also read as a global by test_batch_processing.
device = torch.device(os.environ.get('EVAL_DEVICE', 'cuda:4'))
model.to(device)
model.eval()
print('Loaded model')

load model from: torchvision://resnet50
Loaded model


In [6]:
# Test-time loader: one sample per batch, in dataset order. The order is
# presumably what dataset.evaluate expects when matching `results` back to
# images (verify). `samples_per_gpu` mirrors `batch_size`.
data = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=False,
    collate_fn=partial(collate, samples_per_gpu=1),
)

In [None]:
# Run inference over the whole test loader and collect every detection output.
results = []
# Share worker tensors through the file system rather than file descriptors.
# NOTE(review): presumably set to avoid open-file limits with DataLoader
# workers; confirm it is still needed.
torch.multiprocessing.set_sharing_strategy('file_system')
# Gradients are never needed here, so a single no_grad context covers the
# whole loop instead of being re-entered on every iteration.
with torch.no_grad():
    for batch in tqdm(data):
        batch = test_batch_processing(batch)
        output = model(**batch, return_loss=False, rescale=True)
        results.extend(output)

 65%|██████▍   | 5876/9109 [08:45<04:43, 11.39it/s]

In [None]:
# Score the collected detections with the dataset's own evaluator.
# `jsonfile_prefix` presumably controls where result JSON files are written
# (verify against the dataset implementation).
dataset.evaluate(
    results,
    jsonfile_prefix='results',
)