In [1]:
import random
from typing import Sequence, Tuple

import numpy as np
import torch
from sklearn.metrics import mean_absolute_error, roc_auc_score
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.data import IDAOData, val_transforms 

In [2]:
def fake_model_energy(images: torch.Tensor) -> torch.Tensor:
    """Fake energy regressor: ignores the pixels and returns random energies.

    Args:
        images: batch of images as a tensor; only its first dimension (batch
            size) and its ``device`` are used.

    Returns:
        1-D float tensor of length ``images.shape[0]`` on ``images.device``,
        with values drawn uniformly from the half-open range [1, 30).
    """
    # torch.rand samples [0, 1); scale by 29 and shift by 1 to land in [1, 30).
    return torch.rand((images.shape[0],), device=images.device) * 29 + 1

def fake_model_classes(images: torch.Tensor) -> torch.Tensor:
    """Fake classifier: ignores the pixels and returns random hard labels.

    Args:
        images: batch of images as a tensor; only its first dimension (batch
            size) and its ``device`` are used.

    Returns:
        1-D integer tensor of length ``images.shape[0]`` on ``images.device``,
        each entry 0 or 1 with 50% probability.
    """
    # torch.randint's upper bound is exclusive, so 2 yields values in {0, 1}.
    return torch.randint(2, (images.shape[0],), device=images.device)

# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The fake models are random; seed every RNG in play so that
# Restart & Run All reproduces the printed scores.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [3]:
@torch.no_grad()
def _collect_predictions(dataloader, model, target: str) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over every batch of ``dataloader`` and gather labels and predictions.

    Args:
        dataloader: iterable yielding ``(images, classes, energies)`` batches.
        model: callable taking a batch of images (moved to the module-level
            ``device``) and returning a tensor of per-sample predictions.
        target: which label to collect, ``'classes'`` or ``'energies'``.

    Returns:
        ``(targets, predictions)`` as NumPy arrays concatenated over all batches.

    Raises:
        ValueError: if ``target`` is unknown or ``dataloader`` yields no batches.
    """
    if target not in ('classes', 'energies'):
        raise ValueError(f'unknown target {target!r}')

    all_targets = []
    all_preds = []
    for images, classes, energies in tqdm(dataloader):
        all_targets.append(classes if target == 'classes' else energies)
        # Move predictions back to CPU so they can be concatenated with the CPU labels.
        all_preds.append(model(images.to(device)).cpu())

    if not all_targets:
        # torch.cat([]) fails with an opaque error; this happens e.g. when a
        # filter removed every sample from the dataset.
        raise ValueError('dataloader yielded no batches')

    return torch.cat(all_targets).numpy(), torch.cat(all_preds).numpy()


@torch.no_grad()
def get_mae(dataloader, model) -> float:
    """Mean absolute error of ``model``'s energy predictions over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(images, classes, energies)`` batches.
        model: callable mapping a batch of images to per-sample energy predictions.

    Returns:
        MAE between the true energies and the predictions.
    """
    energies, energies_pred = _collect_predictions(dataloader, model, 'energies')
    return mean_absolute_error(energies, energies_pred)


@torch.no_grad()
def get_rocauc(dataloader, model) -> float:
    """ROC AUC of ``model``'s class predictions over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(images, classes, energies)`` batches.
        model: callable mapping a batch of images to per-sample class scores or labels.

    Returns:
        ROC AUC between the true classes and the predictions. Note that
        ``roc_auc_score`` rejects inputs whose true labels contain a single class.
    """
    classes, classes_pred = _collect_predictions(dataloader, model, 'classes')
    return roc_auc_score(classes, classes_pred)

# Evaluation

## Global evaluation

In [4]:
train_data = IDAOData('data/train', val_transforms())
dataloader = DataLoader(
    train_data,
    batch_size=8,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Evaluate both fake models on the full training set; the final score printed
# below is 1000 * (ROC AUC - MAE).
MAE = get_mae(dataloader, fake_model_energy)
ROCAUC = get_rocauc(dataloader, fake_model_classes)
print(f'{MAE=:.2f}, {ROCAUC=:.3f}, final score {1000 * (ROCAUC - MAE):.2f}')

  0%|          | 0/1507 [00:00<?, ?it/s]

  0%|          | 0/1507 [00:00<?, ?it/s]

MAE=11.42, ROCAUC=0.498, final score -10918.78


## Evaluation by class/energy

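Only MAE is evaluated here. ROC AUC is undefined on a subset that contains a single class, which is the case for every per-class subset. It is presumably also the case for every per-energy subset, if each energy level in the training data belongs to only one class.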

In [5]:
import pandas as pd

def filter_dataset(dataset, energies: Sequence[int], classes: Sequence[str]):
    """Keep only the samples of ``dataset`` whose label matches the given filters.

    The dataset is modified in place. ``dataset.classes`` and ``dataset.image_files``
    are treated as parallel lists (entry ``i`` of one describes entry ``i`` of the
    other), and both are replaced by copies that keep only the matching positions.

    Args:
        dataset: object exposing the ``classes`` list of ``(class, energy)``
            labels and the ``image_files`` list.
        energies: energy values to keep.
        classes: class names to keep (e.g. ``'ER'``, ``'NR'``).
    """

    def belongs(label: Tuple[str, int]) -> bool:
        return label[0] in classes and label[1] in energies

    # A set gives O(1) membership in the two filters below; with a list this
    # was O(n^2) over the whole dataset.
    filtered_idx = {idx for idx, label in enumerate(dataset.classes) if belongs(label)}

    dataset.classes = [label for idx, label in enumerate(dataset.classes) if idx in filtered_idx]
    dataset.image_files = [img for idx, img in enumerate(dataset.image_files) if idx in filtered_idx]

### Energy

In [12]:
results = []
for energy in (1, 3, 6, 10, 20, 30):
    # Build a fresh dataset every iteration: filter_dataset mutates it in place.
    data = IDAOData('data/train', val_transforms())
    filter_dataset(data, energies=[energy], classes=['ER', 'NR'])
    dataloader = DataLoader(
        data, batch_size=8, shuffle=False, num_workers=8, pin_memory=True
    )
    results.append(get_mae(dataloader, fake_model_energy))

  0%|          | 0/246 [00:00<?, ?it/s]

  0%|          | 0/255 [00:00<?, ?it/s]

  0%|          | 0/253 [00:00<?, ?it/s]

  0%|          | 0/254 [00:00<?, ?it/s]

  0%|          | 0/249 [00:00<?, ?it/s]

  0%|          | 0/253 [00:00<?, ?it/s]

In [13]:
pd.DataFrame(dict(Energy=[1, 3, 6, 10, 20, 30], MAE=results))

Unnamed: 0,Energy,MAE
0,1,14.528497
1,3,12.730168
2,6,10.477481
3,10,8.34336
4,20,7.910609
5,30,14.549813


### Classes

In [9]:
results = []
for iclass in ('ER', 'NR'):
    # Build a fresh dataset every iteration: filter_dataset mutates it in place.
    data = IDAOData('data/train', val_transforms())
    filter_dataset(data, energies=[1, 3, 6, 10, 20, 30], classes=[iclass])
    dataloader = DataLoader(
        data, batch_size=8, shuffle=False, num_workers=8, pin_memory=True
    )
    results.append(get_mae(dataloader, fake_model_energy))

  0%|          | 0/762 [00:00<?, ?it/s]

  0%|          | 0/746 [00:00<?, ?it/s]

In [11]:
pd.DataFrame(dict(Class=['ER', 'NR'], MAE=results))

Unnamed: 0,Class,MAE
0,ER,11.859465
1,NR,10.939141
