In [15]:
"""
Run the q2l inference code. Use their map method to calculate map 
"""

'\nRun the q2l inference code. Use their map method to calculate map \n'

In [79]:
import argparse
import os, sys
import random
import datetime
import time
from typing import List
import json
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

from tqdm import tqdm
from sklearn.metrics import average_precision_score

In [80]:
import sys
sys.path.append('/home/ksmehrab/FishDatasetTrack/Identification/Query2label/fish_q2l')

In [81]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
server = 'arc'

In [82]:
import numpy as np

def voc_ap(rec, prec, true_num):
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
    i = np.where(mrec[1:] != mrec[:-1])[0]
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def voc_mAP(imagessetfilelist, num, return_each=False):
    if isinstance(imagessetfilelist, str):
        imagessetfilelist = [imagessetfilelist]
    lines = []
    for imagessetfile in imagessetfilelist:
        with open(imagessetfile, 'r') as f:
            lines.extend(f.readlines())
    
    seg = np.array([x.strip().split(' ') for x in lines]).astype(float)
    gt_label = seg[:,num:].astype(np.int32)
    num_target = np.sum(gt_label, axis=1, keepdims = True)


    sample_num = len(gt_label)
    class_num = num
    tp = np.zeros(sample_num)
    fp = np.zeros(sample_num)
    aps = []

    for class_id in range(class_num):
        confidence = seg[:,class_id]
        sorted_ind = np.argsort(-confidence)
        sorted_scores = np.sort(-confidence)
        sorted_label = [gt_label[x][class_id] for x in sorted_ind]

        for i in range(sample_num):
            tp[i] = (sorted_label[i]>0)
            fp[i] = (sorted_label[i]<=0)
        true_num = 0
        true_num = sum(tp)
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        rec = tp / float(true_num)
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = voc_ap(rec, prec, true_num)
        aps += [ap]

    np.set_printoptions(precision=6, suppress=True)
    aps = np.array(aps) * 100
    mAP = np.mean(aps)
    if return_each:
        return mAP, aps
    return mAP



In [83]:
@torch.no_grad()
def validate(val_loader, model, amp, num_class, output_path):
    # switch to evaluate mode
    model.eval()
    saved_data = []
    all_predicted = []
    all_targets = []
    with torch.no_grad():
        for i, (images, target) in enumerate(tqdm(val_loader)):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            with torch.cuda.amp.autocast(enabled=amp):
                output = model(images)
                # breakpoint()
                output_sm = nn.functional.sigmoid(output)
                
            # Flip the pelvic fin trait
            output_sm = output_sm.detach().cpu()
            target = target.detach().cpu()
            
            output_sm[:, 1] = 1 - output_sm[:, 1]
            target[:, 1] = 1 - target[:, 1]
            
            all_predicted.append(output_sm.numpy())
            all_targets.append(target.numpy())
            
            # save some data
            _item = torch.cat((output_sm, target), 1)
            saved_data.append(_item)
        
        all_predicted = np.vstack(all_predicted)
        all_targets = np.vstack(all_targets)

        # calculate mAP
        saved_data = torch.cat(saved_data, 0).numpy()
        saved_name = 'saved_data_tmp.txt'
        np.savetxt(os.path.join(output_path, saved_name), saved_data)

        print("Calculating mAP:")
        filenamelist = ['saved_data_tmp.txt']
        metric_func = voc_mAP                
        mAP, aps = metric_func([os.path.join(output_path, _filename) for _filename in filenamelist], num_class, return_each=True)
       
        ### SKLEARN MAP

        # Calculate Average Precision for each class
        sk_aps = []
        for i in range(all_targets.shape[1]):
            ap = average_precision_score(all_targets[:, i], all_predicted[:, i])
            sk_aps.append(ap)

        # Compute mAP
        sk_mAP = np.mean(sk_aps)
        
    return aps, mAP, sk_aps, sk_mAP

In [84]:
amp = True
num_class = 4
### CHANGE OUTPUT PATH HERE
output = '/home/ksmehrab/FishDatasetTrack/Identification/TraitIDBasic/Outputs/results_swinb_wbce_basic'

In [85]:
# START LOADING MODEL AND DATA LOADER 
# Get model
from model import get_custom_model

#### CHANGE MODEL AND CHECKPOINT_PATH
MODEL = 'swin_b'
N_CLASSES = num_class


model = get_custom_model(
    model_name=MODEL,
    num_classes=N_CLASSES,
    pretrained=False
)

model = model.to(device)

In [86]:
# Get checkpoint
checkpoint_path = '/home/ksmehrab/FishDatasetTrack/Identification/TraitIDBasic/Outputs/results_swinb_wbce_basic/ckpt_9229_S9229_tid_swinb_basic_fishair_processed_swin_b.t7'
ckpt_t = torch.load(checkpoint_path)
model.load_state_dict(ckpt_t['net'])
epoch = ckpt_t['epoch']

In [87]:
# Get test loader
from pathlib import Path
if server == 'pda':
#     train_file = Path('/data/DatasetTrackFinalData/Identification/trait_identification_train.csv')
#     val_file = Path('/data/DatasetTrackFinalData/Identification/trait_identification_val.csv')
    test_file = Path('/data/DatasetTrackFinalData/Identification/trait_identification_test_inspecies.csv')
    lv_sp_normal_test_file = Path('/data/DatasetTrackFinalData/Identification/trait_identification_test_leavespecies.csv')
    lv_sp_difficult_test_file = None
    img_dir = Path('/data/BGRemovedCropped/all')
elif server == 'arc':
#     train_file = Path('/projects/ml4science/FishDatasetTrack/DatasetTrackFinalData/Identification/trait_identification_train.csv')
#     val_file = Path('/projects/ml4science/FishDatasetTrack/DatasetTrackFinalData/Identification/trait_identification_val.csv')
    test_file = Path('/projects/ml4science/FishDatasetTrack/DatasetTrackFinalData/Identification/trait_identification_test_inspecies.csv')
    lv_sp_normal_test_file = Path('/projects/ml4science/FishDatasetTrack/DatasetTrackFinalData/Identification/trait_identification_test_leavespecies.csv')
    lv_sp_difficult_test_file = Path('/projects/ml4science/FishDatasetTrack/DatasetTrackFinalData/Segmentation/annotations_mlic.csv')
    img_dir = Path('/projects/ml4science/FishAIR/BGRemovedCropped/all')

In [88]:
from data_setup import get_transform, get_dataset_and_dataloader

mean = torch.tensor([0.9353, 0.9175, 0.8923])
std = torch.tensor([0.1535, 0.1933, 0.2464])
transform = get_transform(224, mean, std, 'squarepad_augment_normalize')
test_transform = get_transform(224, mean, std, 'squarepad_no_augment_normalize')

BATCH_SIZE=256
num_workers = 8

test_dataset, test_loader = get_dataset_and_dataloader(
    data_file=test_file,
    img_dir=img_dir,
    transform=test_transform,
    batch_size=BATCH_SIZE,
    num_workers=num_workers
)

if lv_sp_normal_test_file:
    lv_sp_normal_dataset, lv_sp_normal_loader = get_dataset_and_dataloader(
        data_file=lv_sp_normal_test_file,
        img_dir=img_dir,
        transform=test_transform,
        batch_size=BATCH_SIZE,
        num_workers=num_workers
    )

if lv_sp_difficult_test_file:
    lv_sp_dif_dataset, lv_sp_dif_loader = get_dataset_and_dataloader(
        data_file=lv_sp_difficult_test_file,
        img_dir=img_dir,
        transform=test_transform,
        batch_size=BATCH_SIZE,
        num_workers=num_workers
    )


In [90]:
q2l_aps, q2l_map, sk_aps, sk_map = validate(lv_sp_normal_loader, model, amp, num_class, output)

100%|██████████| 7/7 [01:47<00:00, 15.40s/it] 

Calculating mAP:





In [92]:
print(f'q2l_aps:  {q2l_aps.tolist()}\nq2l_map: {q2l_map}')

q2l_aps:  [82.97406277947556, 39.74128396051576, 82.04788472937936, 44.67251751695641]
q2l_map:62.358937246581775


In [94]:
print(f'sk_aps:  {sk_aps}\nsk_map: {sk_map}')

sk_aps:  [0.8203605169115833, 0.39597167150629464, 0.8148718101808212, 0.42998019812404914]
sk_map: 0.6152960491806871
