In [2]:
import sys
sys.path.append('..')
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
import time
from thop import profile
from sklearn.metrics import classification_report, accuracy_score
import numpy as np

In [3]:
data_transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])

In [4]:
# Load datasets
dataset_ecgid_test = datasets.ImageFolder(
    root='../data/ecgid/test',
    transform=data_transform
)

dataset_ptb_test = datasets.ImageFolder(
    root='../data/ptb/test',
    transform=data_transform
)

dataloader_ecgid_test = DataLoader(
    dataset_ecgid_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=False
)

dataloader_ptb_test = DataLoader(
    dataset_ptb_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=False
)

datasets_to_test = {
    'ecgid': {'loader': dataloader_ecgid_test, 'classes': 89},
    # 'ptb': {'loader': dataloader_ptb_test, 'classes': 289}
}

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
model_path = '../models_para/deit_tiny_patch16_224_ecgid.pth'
model = timm.create_model('deit_tiny_patch16_224', pretrained=False, num_classes=89)
model.load_state_dict(torch.load(model_path), strict=False)
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for data, targets in dataloader_ecgid_test:
        # data, targets = data.to('cuda'), targets.to('cuda')
        logits = model.forward(data)
        preds = logits.argmax(dim=1)
        all_preds.append(preds.item())
        all_targets.append(targets.item())
    

report = classification_report(all_targets, all_preds)
cm = confusion_matrix(all_targets, all_preds)
acp = 0
for i, r in enumerate(cm):
    pred = r.argmax()
    if pred == i:
        acp += 1
acc = acp / 89

[[3 0 0 ... 0 0 0]
 [0 4 0 ... 0 0 0]
 [0 0 4 ... 0 0 0]
 ...
 [0 0 0 ... 3 0 0]
 [0 0 0 ... 0 4 0]
 [0 0 0 ... 0 0 4]]


In [5]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
def evaluate_model(model, dataloader, nb_classes):
    model.eval()
    all_preds = []
    all_targets = []
    
    test_time = 0
    with torch.no_grad():
        for data, targets in dataloader:
            # data, targets = data.to('cuda'), targets.to('cuda')
            start_time = time.time()
            logits = model.forward(data)
            preds = logits.argmax(dim=1)
            test_time += time.time() - start_time

            all_preds.append(preds.item())
            all_targets.append(targets.item())
    
    test_time = test_time / nb_classes

    report = classification_report(all_targets, all_preds, output_dict=True)

    # 计算ACC
    cm = confusion_matrix(all_targets, all_preds)
    acp = 0
    for i, r in enumerate(cm):
        pred = r.argmax()
        if pred == i:
            acp += 1
    acc = acp / nb_classes
    
    return acc, report, test_time

In [6]:
from torchvision import models as tv_models
from torch import nn

for dataset_name, info in datasets_to_test.items():
    print(f'--- Evaluating on {dataset_name} dataset ---')
    
    # Load model with correct number of classes
    nb_classes = info['classes']
    st_model_path = '../models_para/mobilenetv3_small_050_ecgid_baseline.pth'
    tea_model_path = f'../models_para/resnet34.a1_in1k_{dataset_name}_baseline.pth'
    st_model = timm.create_model('mobilenetv3_small_050', pretrained=False, num_classes=nb_classes)

    # st_model = tv_models.squeezenet1_0(pretrained=True)
    # st_model.classifier[1] = nn.Conv2d(512, 89, kernel_size=1)
    # st_model = tv_models.shufflenet_v2_x1_0(pretrained=False)
    # st_model.fc = nn.Linear(st_model.fc.in_features, 89)
    st_model.load_state_dict(torch.load(st_model_path), strict=False)

    tea_model = timm.create_model('resnet34.a1_in1k', pretrained=False, num_classes=nb_classes)
    tea_model.load_state_dict(torch.load(tea_model_path), strict=False)
    
    # Evaluate
    print('==== eva student ====')
    acc, report, test_time = evaluate_model(st_model, info['loader'], nb_classes=nb_classes)
    
    mp = report['macro avg']['precision']
    mr = report['macro avg']['recall']
    mf1 = report['macro avg']['f1-score']

    weighted_p = report['weighted avg']['precision']
    weighted_r = report['weighted avg']['recall']
    weighted_f1 = report['weighted avg']['f1-score']
    # Print results
    print(f'Accuracy (ACC): {acc:.4f}')
    print(report)
    print(f'Precision (P): {mp:.4f}')
    print(f'Recall (R): {mr:.4f}')
    print(f'F1-score (F1): {mf1:.4f}')
    print(f'Test time: {test_time:.2f}s')

    print('==== eva teacher ====')
    acc, report, test_time = evaluate_model(tea_model, info['loader'], nb_classes=nb_classes)
    
    mp = report['macro avg']['precision']
    mr = report['macro avg']['recall']
    mf1 = report['macro avg']['f1-score']

    weighted_p = report['weighted avg']['precision']
    weighted_r = report['weighted avg']['recall']
    weighted_f1 = report['weighted avg']['f1-score']
    # Print results
    print(f'Accuracy (ACC): {acc:.4f}')
    print(report)
    print(f'Precision (P): {mp:.4f}')
    print(f'Recall (R): {mr:.4f}')
    print(f'F1-score (F1): {mf1:.4f}')
    print(f'Test time: {test_time:.2f}s')

--- Evaluating on ecgid dataset ---
==== eva student ====
Accuracy (ACC): 0.9775
{'0': {'precision': 1.0, 'recall': 0.8, 'f1-score': 0.8888888888888888, 'support': 5.0}, '1': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 5.0}, '2': {'precision': 1.0, 'recall': 0.8, 'f1-score': 0.8888888888888888, 'support': 5.0}, '3': {'precision': 1.0, 'recall': 0.4, 'f1-score': 0.5714285714285714, 'support': 5.0}, '4': {'precision': 0.8, 'recall': 0.8, 'f1-score': 0.8, 'support': 5.0}, '5': {'precision': 0.75, 'recall': 0.6, 'f1-score': 0.6666666666666666, 'support': 5.0}, '6': {'precision': 0.7142857142857143, 'recall': 1.0, 'f1-score': 0.8333333333333334, 'support': 5.0}, '7': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 5.0}, '8': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 5.0}, '9': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 5.0}, '10': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 5.0}, '11': {'precision': 0.2

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
