In [1]:
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

%load_ext autoreload
%autoreload all
    
from utils import read_images
from datasets import DominanceDataset
from vit import SiameseViT

# Inference-only notebook: disable autograd globally so no graphs are built.
# The trailing semicolon stops IPython from echoing the returned context-manager repr.
torch.set_grad_enabled(False);

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7f3a14669810>

In [2]:
# Load the trained dominance model and the test split.
# Inference in this notebook never moves tensors to another device, so the
# weights are mapped onto the CPU regardless of where they were saved.
WEIGHTS_PATH = "weights/dominance.pth"

model = SiameseViT()
# weights_only=True restricts unpickling to tensors/containers: a state_dict
# needs nothing more, and it avoids executing arbitrary pickled code.
model.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
)
model.eval()

images = read_images(ids="test")
test_dataset = DominanceDataset(images, split="test")

Loaded test images from cache.


In [5]:
def get_pred(left_imgs, right_imgs, net=None):
    """Score every (left, right) image pair and return the mean and per-pair scores.

    Args:
        left_imgs: sequence of left-side inputs; each is passed unchanged as the
            first argument to the model (in this notebook: transformed frames
            with a leading batch dimension of 1).
        right_imgs: sequence of right-side inputs, passed as the second argument.
        net: callable ``net(left, right)`` returning logits indexable as
            ``[0, 1]``. Defaults to the notebook-global ``model``.

    Returns:
        ``(mean, preds)`` where ``preds`` is a 1-D float array of
        ``sigmoid(logits)[0, 1]``, one entry per pair, ordered left-major
        (all right images for the first left image, then the next, ...), and
        ``mean`` is their average.

    Raises:
        ValueError: if either side is empty (the mean would otherwise be NaN
            and silently turn into a negative prediction downstream).
    """
    # Condition is evaluated first, so the global is only looked up when needed.
    scorer = model if net is None else net
    if len(left_imgs) == 0 or len(right_imgs) == 0:
        raise ValueError("get_pred needs at least one left and one right image")

    with torch.no_grad():
        # .item() yields plain Python floats, so the array is a clean float
        # array independent of tensor device or autograd state.
        preds = np.array(
            [
                torch.sigmoid(scorer(l, r))[0, 1].item()
                for l in left_imgs
                for r in right_imgs
            ]
        )
    return preds.mean(), preds

In [13]:
# Evaluate with representative-frame sampling: for each test study, every
# selected left frame is paired with every selected right frame. The
# study-level prediction thresholds the mean pair score; the per-pair
# ("partial") predictions are kept too so both granularities can be reported.
THRESHOLD = 0.5  # decision threshold applied to sigmoid scores


def load_frames(img_ids, frame_ns):
    """Return one transformed tensor (with a leading batch dim) per (image id, frame index).

    Raises:
        ValueError: if ``img_ids`` and ``frame_ns`` differ in length, which
            ``zip`` would otherwise silently truncate.
    """
    if len(img_ids) != len(frame_ns):
        raise ValueError(f"{len(img_ids)} image ids but {len(frame_ns)} frame indices")
    return [
        test_dataset.transform(images[img_id][frame_n]).unsqueeze(0)
        for img_id, frame_n in zip(img_ids, frame_ns)
    ]


labels = []          # study-level ground truth
partial_labels = []  # ground truth repeated once per scored pair
final_preds = []     # study-level predictions (thresholded mean pair score)
partial_preds = []   # per-pair predictions
for i in range(len(test_dataset)):
    study_id, data = test_dataset.index[i]

    imgs_left = load_frames(data["l"], data["l_f"])
    imgs_right = load_frames(data["r"], data["r_f"])

    pred, all_preds = get_pred(imgs_left, imgs_right)
    final_preds.append(int(pred > THRESHOLD))
    partial_preds.extend(int(p > THRESHOLD) for p in all_preds)
    labels.append(data["label"])
    partial_labels.extend([data["label"]] * len(all_preds))


In [14]:
# Study-level metrics: one prediction per test study.
report = classification_report(labels, final_preds)
matrix = confusion_matrix(labels, final_preds)
print("Classification Report:", report, sep="\n")
print("Confusion Matrix:", matrix, sep="\n")

Classification Report:
              precision    recall  f1-score   support

         0.0       0.50      0.67      0.57         3
         1.0       0.93      0.88      0.90        16

    accuracy                           0.84        19
   macro avg       0.72      0.77      0.74        19
weighted avg       0.86      0.84      0.85        19

Confusion Matrix:
[[ 2  1]
 [ 2 14]]


In [15]:
# Pair-level ("partial") metrics: each scored (left frame, right frame) pair
# counts as its own sample, labelled with its study's ground truth.
report, matrix = (
    classification_report(partial_labels, partial_preds),
    confusion_matrix(partial_labels, partial_preds),
)
for heading, table in (("Classification Report:", report), ("Confusion Matrix:", matrix)):
    print(heading)
    print(table)

Classification Report:
              precision    recall  f1-score   support

         0.0       0.27      0.89      0.42        28
         1.0       0.99      0.76      0.86       280

    accuracy                           0.78       308
   macro avg       0.63      0.83      0.64       308
weighted avg       0.92      0.78      0.82       308

Confusion Matrix:
[[ 25   3]
 [ 66 214]]
