In [4]:
import os
import argparse
import random
import time
import datetime

import torch
import numpy as np
import pandas as pd
import peft
import wandb
from flair import FLAIRModel

from data.dataset import build_dataset_single
from model.flair_distill import FLAIRMultiLayer
from process.finetune import train_one_epoch, evaluate
from utils.eval import save_model, load_model
from utils.logger import MetricLogger
from utils.eval_single import compute_metrics, compute_classwise_metrics, print_result, print_result_whole


def get_args_parser():
    """Build the argument parser holding every option used by this notebook.

    The parser is created with ``add_help=False`` (no ``-h/--help`` action) and
    covers device selection, optimisation hyper-parameters, data/checkpoint
    paths, augmentation, mixup/cutmix and loader/run settings.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args`` or
        ``parse_known_args`` on it to obtain the option namespace.
    """
    parser = argparse.ArgumentParser('Multi Eye CLIP', add_help=False)

    # Device / modality selection
    parser.add_argument('--modality', default='fundus', type=str, help='modality for training backbone model')
    parser.add_argument('--device_id', default='4', type=str, help='select device id')
    parser.add_argument('--device', default='cuda', type=str, help='device: cuda or cpu')

    # Optimisation and distillation hyper-parameters
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
    parser.add_argument('--alpha_distill', type=float, default=1)
    parser.add_argument('--beta_distill', type=float, default=10)
    parser.add_argument('--temperature', type=float, default=0.4)

    # Paths
    parser.add_argument('--data_path', default='multieye/assemble', type=str, help='dataset path')
    parser.add_argument('--concept_path', default='concepts', type=str, help='concept path')
    # Help text previously said "oct checkpoint" although the default points at the fundus checkpoint.
    parser.add_argument('--checkpoint_path', default='checkpoint/fundus_checkpoint',
                        help='path of the model checkpoint to load')

    # Augmentation parameters
    parser.add_argument('--input_size', default=512, type=int,
                        help='images input size')
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    # Help text previously embedded a literal '" + "' from a broken concatenation,
    # and the statement ended with a stray comma.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup parameters
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # Data loading and run control
    parser.add_argument('--batch_size', default=64, type=int, help='batch size')
    parser.add_argument('--num_samples', default=36000, type=int, help='number of the sampled training data')
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--n_classes', default=9, type=int, help='number of the classification types')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--accum_iter', default=1, type=int)

    # Help text previously duplicated the one of --batch_size.
    parser.add_argument('--print_freq', default=100, type=int,
                        help='how often to print progress (in iterations)')
    parser.add_argument('--eval', action='store_true', default=False, help='Perform evaluation only')

    return parser

In [7]:
# Evaluation setup: parse options, seed the RNGs, build the test loader and
# restore the trained model.

# parse_known_args tolerates flags the parser does not define (e.g. whatever the
# notebook kernel was launched with); those end up in `unknown`.
parser = get_args_parser()
args, unknown = parser.parse_known_args()

# Only attach the ordinal to a bare 'cuda' device type; other values of
# --device (e.g. 'cpu') are used as given instead of being paired with the
# default GPU index.
if args.device == 'cuda':
    device = torch.device(args.device, int(args.device_id))
else:
    device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.benchmark = True

# Use the parsed modality / batch size rather than hard-coded copies of their
# defaults ('fundus', 64) so the command-line options actually take effect.
test_dataset = build_dataset_single('test', args=args, mod=args.modality)

data_loader_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    pin_memory=args.pin_mem,
    shuffle=False,  # keep dataset order so predictions line up with samples
    num_workers=args.num_workers,
    drop_last=False)  # evaluate every test sample, including the last partial batch

# Build the model, then overwrite its weights with the checkpoint selected by
# --checkpoint_path (if_best=True is forwarded to the project's load_model).
concept_feat_path = os.path.join(args.concept_path, 'concepts_raw.npy')
model = FLAIRMultiLayer(args, device, concept_feat_path)

model_state_dict = load_model(args=args, if_best=True, device=device, checkpoint=args.checkpoint_path)
model.load_state_dict(model_state_dict)
model = model.to(device)
model.eval()

criterion = torch.nn.CrossEntropyLoss()
print("criterion = %s" % str(criterion))

Pretrained weights: IMAGENET1K_V1
load model weight from: ./flair/modeling/flair_pretrained_weights/flair_resnet.pth
criterion = CrossEntropyLoss()


In [8]:
# Inference over the test split followed by overall and class-wise metrics.

# Imports are placed before the inference loop so that a missing dependency
# fails immediately rather than after the whole test set has been processed.
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, cohen_kappa_score, roc_auc_score, average_precision_score, confusion_matrix
from sklearn.preprocessing import label_binarize
from utils.eval_single import compute_sen_spe

metric_logger = MetricLogger(delimiter="  ")
header = 'Test:'

# Per-sample accumulators, filled batch by batch.
prediction_decode_list = []
prediction_prob_list = []
true_label_decode_list = []
img_name_list = []
nn = []  # not used in this cell; kept in case another cell references it

model.eval()

with torch.no_grad():
    for samples, targets, image_names in metric_logger.log_every(data_loader_test, 200, header):
        images = samples.to(device)
        target = targets.to(device)
        image_name = image_names

        with torch.cuda.amp.autocast():
            outputs, concept_sim = model.forward_concept(images)

        # Class probabilities, arg-max prediction and its probability.
        prediction_prob = torch.softmax(outputs, dim=1)
        prediction_decode = torch.argmax(prediction_prob, dim=1)
        # `.values` avoids assigning to `_`, which IPython uses for the last output.
        prediction_score = torch.max(prediction_prob, dim=1).values
        # Not consumed below; left bound for the last batch in case later cells inspect it.
        concept_sim = torch.sigmoid(concept_sim)
        true_label_decode = target

        # No gradients are tracked under no_grad, so a plain .cpu().tolist() suffices.
        prediction_decode_list.extend(prediction_decode.cpu().tolist())
        true_label_decode_list.extend(true_label_decode.cpu().tolist())
        prediction_prob_list.extend(prediction_prob.cpu().tolist())
        # Previously declared but never filled.
        # Assumes the loader yields an iterable of per-sample names — TODO confirm.
        img_name_list.extend(image_names)

true_label_decode_list = np.array(true_label_decode_list)
prediction_decode_list = np.array(prediction_decode_list)
prediction_prob_list = np.array(prediction_prob_list)

results = compute_metrics(true_label_decode_list, prediction_decode_list, prediction_prob_list)
class_wise_results = compute_classwise_metrics(true_label_decode_list, prediction_decode_list, prediction_prob_list)

# Class names passed to the report printer.
# NOTE(review): presumably ordered by label index — verify against the dataset's label encoding.
alldiseases = ['NOR', 'AMD', 'CSC', 'DR', 'GLC', 'MEM', 'MYO', 'RVO', 'WAMD']

print_result_whole(class_wise_results, results, alldiseases)

Test:  [  0/182]  eta: 0:09:49    time: 3.2397  data: 1.2991  max mem: 0
Test:  [181/182]  eta: 0:00:00    time: 0.3463  data: 0.0013  max mem: 0
Test: Total time: 0:00:56 (0.3081 s / it)
Class	Pre	Rec	F1_PR	Sen	Spe	F1_SS	AUC	AP
NOR	0.8183	0.8819	0.8489	0.8819	0.7063	0.7844	0.8683	0.8821
AMD	0.4627	0.6691	0.5471	0.6691	0.9906	0.7987	0.9711	0.6512
CSC	0.4800	0.4000	0.4364	0.4000	0.9989	0.5712	0.9928	0.4350
DR	0.7778	0.6714	0.7207	0.6714	0.9017	0.7697	0.8587	0.8241
GLC	0.6791	0.6594	0.6691	0.6594	0.9962	0.7936	0.9324	0.6924
MEM	0.7500	0.3571	0.4839	0.3571	0.9996	0.5263	0.9557	0.3804
MYO	0.5412	0.9109	0.6790	0.9109	0.9932	0.9503	0.9811	0.7533
RVO	0.8200	0.8367	0.8283	0.8367	0.9984	0.9105	0.9755	0.8572
WAMD	0.7143	0.2484	0.3687	0.2484	0.9986	0.3979	0.8122	0.3169
Average	Pre	Rec	F1_PR	Sen	Spe	F1_SS	AUC	MAP	Acc	Kappa
	0.6715	0.6261	0.6202	0.6261	0.9537	0.7225	0.9275	0.6436	0.7933	0.5965
