In [1]:
import timm
import torch
from sklearn import metrics
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd

from test_dataset import Test_Dataset

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
config = {"batch_size": 41}

In [3]:
def expand_prediction(arr):
    arr_reshaped = arr.reshape(-1, 1)
    return np.clip(np.concatenate((1.0 - arr_reshaped, arr_reshaped), axis=1), 0.0, 1.0)

In [4]:
def validate(df, root_dir, mode):
    model = timm.create_model("tf_efficientnet_b4_ns", pretrained=False, num_classes=1)
    if mode == "random_erase":
        model.load_state_dict(torch.load("weights/random_erase_tf_efficientnet_b4_ns.h5"))
    elif mode == "face_cutout":
        model.load_state_dict(torch.load("weights/face_cutout_tf_efficientnet_b4_ns.h5"))
    model.to(device)

    data = Test_Dataset(df, root_dir)
    data_loader = DataLoader(
        data,
        batch_size=config["batch_size"],
        num_workers=8,
        shuffle=True,
        drop_last=True,
    )

    predictions = []
    targets = []

    with torch.no_grad():
        for batch in tqdm(data_loader):
            batch_images = batch["image"].to(device)
            batch_labels = batch["label"].to(device)

            out = model(batch_images)

            batch_targets = (batch_labels.view(-1, 1).cpu() >= 0.5) * 1
            batch_preds = torch.sigmoid(out).cpu()

            targets.append(batch_targets)
            predictions.append(batch_preds)

        targets = np.vstack((targets)).ravel()
        predictions = np.vstack((predictions)).ravel()

        auc = metrics.roc_auc_score(targets, predictions)
        mAP = metrics.average_precision_score(targets, predictions)
        log_loss = metrics.log_loss(targets, expand_prediction(predictions))

        print('')
        print(f"AUC : {auc}")
        print(f"mAP : {mAP}")
        print(f"LogLoss : {log_loss}")

In [5]:
df = pd.read_csv("metadata.csv")

print("Face Cutout --------------")
validate(df, root_dir="data/val_images", mode="face_cutout")
print()
print("Random Erase -------------")
validate(df, root_dir="data/val_images", mode="random_erase")

Face Cutout --------------
100%|██████████| 40/40 [00:05<00:00,  8.04it/s]

AUC : 0.9659584773104254
mAP : 0.9927165331395946
LogLoss : 0.2119294066442086

Random Erase -------------
100%|██████████| 40/40 [00:05<00:00,  8.10it/s]
AUC : 0.9448474498375792
mAP : 0.988081434519885
LogLoss : 0.24963293145048693

