In [1]:
import os
os.environ["OMP_NUM_THREADS"]="1"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import json
import numpy as np
import torch
import torch.nn.functional as F

from guide_train_net import Trainer, setup, main
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    launch,
    hooks,
)

from detectron2.checkpoint import DetectionCheckpointer
import wandb
import matplotlib.pyplot as plt

from utils.cvppp_evaluation import read_gt_mask
from torchvision.io import read_image
from detectron2.utils.visualizer import Visualizer, VisImage, ColorMode
import h5py

from PIL import Image

def filter_pred_mask(pred, score_threshold=0.0, image_index=0):
    """Collapse one image's predicted instance masks into a single-channel label map.

    Args:
        pred: model output as used in this notebook: an indexable sequence with
            one dict per input image, each holding an ``'instances'`` object that
            exposes ``.to(device)``, ``.scores`` and ``.pred_masks``.
        score_threshold: instances with ``score < score_threshold`` are dropped.
        image_index: which element of ``pred`` to convert. Defaults to 0, which
            matches the original behaviour (only the first image is used).

    Returns:
        np.ndarray of shape (H, W). Background is 0 and the k-th kept instance
        (1-based, in the order stored in ``pred_masks``) is labelled k. Where
        masks overlap, the later instance overwrites the earlier one.
        NOTE(review): if instances are ordered by descending score, this means
        lower-score instances win at overlaps. Confirm this is intended.
        The dtype is the smallest unsigned integer type that can hold the
        number of kept instances: uint8 for <= 255, otherwise wider. This
        prevents labels from wrapping around.
    """
    pred_ins = pred[image_index]['instances'].to('cpu')

    # confidence score filtering
    keep = pred_ins.scores >= score_threshold
    masks = pred_ins.pred_masks[keep]  # torch tensor indexed as (N_kept, H, W) below

    n_kept, height, width = masks.shape
    # a uint8 map silently overflows past 255 instances; widen only when needed
    label_dtype = np.min_scalar_type(n_kept)
    single_channel_mask = np.zeros((height, width), dtype=label_dtype)

    # paint instance k with label k; later instances overwrite earlier ones
    for label, mask in enumerate(masks.numpy(), start=1):
        single_channel_mask[mask > 0] = label

    return single_channel_mask

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build CVPPP-style HDF5 submissions: for every config below, load its
# `model_best_sbd.pth` checkpoint, run it on `test_set`, and write the predicted
# label maps to ./submission/<config-name>.h5 (one group per domain A1..A5).

# for test
# config_files = ['./configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_nohpe.yaml',
#                 './configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_nodpe.yaml',
#                 './configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_noauxsup.yaml',]
# config_files = ['./configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre.yaml',]
# config_files = ['./configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_baseline_cocopre.yaml',]
config_files = ['./configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_onlyhpe.yaml',
                './configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_onlydpe.yaml',
                './configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_onlyauxsup.yaml',]
test_set = ('A1_coco',)

SCORE_THRESHOLD = 0.85  # instances scoring below this are left out of the submission
DOMAINS = ('A1', 'A2', 'A3', 'A4', 'A5')  # top-level groups of every submission file

# loop-invariant settings
device = 'cuda'  # not used below: the model is already placed on cfg.MODEL.DEVICE when built
save_dir = './submission'
os.makedirs(save_dir, exist_ok=True)  # h5py.File(..., 'w') does not create missing parent dirs

for config_file in config_files:
    args = default_argument_parser().parse_args([])
    args.eval_only = True
    args.config_file = config_file

    print("Command Line Args:", args)
    cfg = setup(args)

    # change cfg.DATASETS.TEST to the real test set
    cfg.defrost()
    cfg.DATASETS.TEST = test_set
    cfg.freeze()

    # build test loader: a single dataset name, or the whole tuple for several datasets
    if len(cfg.DATASETS.TEST) == 1:
        test_loader = Trainer.build_test_loader(cfg, dataset_name=cfg.DATASETS.TEST[0])
    else:
        test_loader = Trainer.build_test_loader(cfg, dataset_name=cfg.DATASETS.TEST)

    # build and load model (building already moves it to cfg.MODEL.DEVICE, default='cuda')
    model = Trainer.build_model(cfg)
    # resume=False loads exactly this file; resume=True would load the latest checkpoint, not the best
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        os.path.join(cfg.OUTPUT_DIR, "model_best_sbd.pth"), resume=False)
    model.eval()

    # submission file named after the config, with any extension (.yaml/.yml) swapped for .h5
    config_stem = os.path.splitext(os.path.basename(args.config_file))[0]
    save_name = os.path.join(save_dir, config_stem + '.h5')
    save_dict = {domain: {} for domain in DOMAINS}

    for batch in test_loader:
        with torch.no_grad():
            outputs = model(batch)
        # pair every input with its own output so batches larger than 1 are not truncated
        for inputs, output in zip(batch, outputs):
            file_name = inputs['file_name']
            img_domain = os.path.basename(os.path.dirname(file_name))  # A1 to A5
            img_name = os.path.basename(file_name).replace('_rgb.png', '')
            label_filename = img_name + '_label.png'

            pred_mask = filter_pred_mask([output], score_threshold=SCORE_THRESHOLD)
            save_dict[img_domain][img_name] = {'label_filename': label_filename, 'label': pred_mask}

    # every domain group is created (even when empty), with one subgroup per plant image
    with h5py.File(save_name, 'w') as hf:
        for group_name in DOMAINS:
            group = hf.create_group(group_name)
            for plant_name, entry in save_dict[group_name].items():
                plant_group = group.create_group(plant_name)
                plant_group.create_dataset('label', data=entry['label'])
                plant_group.create_dataset('label_filename', data=entry['label_filename'])

Command Line Args: Namespace(config_file='./configs/guide_exp/cvppp_a1_90/cvppp_a1_90_r50_guide_cocopre_onlyhpe.yaml', resume=False, eval_only=True, num_gpus=1, num_machines=1, machine_rank=0, dist_url='tcp://127.0.0.1:56601', opts=[])
[32m[08/04 14:02:40 detectron2]: [0mRank of current process: 0. World size: 1
[32m[08/04 14:02:45 detectron2]: [0mEnvironment info:
-------------------------------  -------------------------------------------------------------------------------------------------
sys.platform                     linux
Python                           3.9.16 (main, May 15 2023, 23:46:34) [GCC 11.2.0]
numpy                            1.24.3
detectron2                       0.6 @/remote/rds/users/fchen2/codes/detectron2/detectron2
Compiler                         GCC 10.4
CUDA compiler                    CUDA 12.0
detectron2 arch flags            8.0
DETECTRON2_ENV_MODULE            <not set>
PyTorch                          2.1.0.dev20230618 @/home/fchen2/RDS/anaconda3/