In [1]:
import torch
from torch.utils.data import DataLoader
import os
import pandas as pd
from work.utils.dataset import PandasDataset, RGB2Fusion
from work.utils.models import EfficientNet
from work.utils.metrics import evaluation
from sklearn.metrics import confusion_matrix
import albumentations

In [2]:
# --- Configuration --------------------------------------------------------
# Backbone architecture and the local file holding its pretrained weights,
# keyed by backbone name (the form passed to EfficientNet(pre_trained_model=...)).
backbone_model = 'efficientnet-b0'
pretrained_model = {backbone_model: '../efficientnet-b0-08094119.pth'}

# Dataset locations, relative to the notebook's working directory.
data_dir = '../../dataset'
images_dir = os.path.join(data_dir, 'tiles')
df_test = pd.read_csv("../data/test.csv")

# Size of the model's output layer (presumably one unit per class — TODO confirm).
output_dimensions = 5

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Report the device inference will run on; `device` is "cuda" or "cpu",
# so the label must not claim CUDA unconditionally.
print("Device:", device)

Cuda cuda


In [4]:
# Test-time preprocessing: only the colour-space fusion step, no augmentation.
# RGB2Fusion is project code (work.utils.dataset); mode="max" presumably merges
# the "rgb", "xyz" and "lab" representations by taking a maximum — TODO confirm.
# The checkpoint name ("fusion-max-3-images") suggests it was trained with this
# same setting, so keep the two in sync.
fusion_step = RGB2Fusion(mode="max", space_colors=["rgb", "xyz", "lab"])
transform = albumentations.Compose([fusion_step])

In [5]:
# Evaluation loader over the test split. shuffle=False keeps batches in the
# dataset's index order (presumably the row order of df_test — verify in
# PandasDataset), so predictions can be matched back to test rows.
test_dataset = PandasDataset(images_dir, df_test, transforms=transform)
dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [6]:
# Build the EfficientNet classifier, restore the fine-tuned checkpoint and
# evaluate it on the test loader.
checkpoint_path = "models/with-noise-fusion-max-3-images.pth"

model = EfficientNet(
    backbone=backbone_model,
    output_dimensions=output_dimensions,
    pre_trained_model=pretrained_model,
)
model.to(device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU can still be restored on a CPU-only kernel. weights_only=True restricts
# unpickling to tensors/primitive containers.
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device, weights_only=True)
)
# Inference mode: disables dropout and makes BatchNorm use running statistics.
# Harmless if `evaluation` already does this — NOTE(review): confirm whether it does.
model.eval()

response_0 = evaluation(model, dataloader, device)
# NOTE(review): assumes evaluation returns (metrics_dict, (predictions, targets)).
# sklearn's confusion_matrix takes (y_true, y_pred), which is why the original
# code passed response_0[1][1] first — confirm against work.utils.metrics.
metrics_0 = response_0[0]
y_pred_0 = response_0[1][0]
y_true_0 = response_0[1][1]
print("without treatment", metrics_0)
cm_0 = confusion_matrix(y_true_0, y_pred_0)

Loaded pretrained weights for efficientnet-b0


100%|██████████| 796/796 [32:55<00:00,  2.48s/it]


without treatment {'val_acc': {'mean': 61.47292720079422, 'std': 1.2406626209459328, 'ci_5': 59.42211151123047, 'ci_95': 63.44221234321594}, 'val_kappa': {'mean': 0.8498734462910676, 'std': 0.010031731820573283, 'ci_5': 0.8325826474902217, 'ci_95': 0.8659759518665118}, 'val_f1': {'mean': 0.5562127445340157, 'std': 0.013066023371671984, 'ci_5': 0.5356912225484848, 'ci_95': 0.5779006004333496}, 'val_recall': {'mean': 0.5563964654803276, 'std': 0.013326440267590106, 'ci_5': 0.5352728098630906, 'ci_95': 0.578430563211441}, 'val_precision': {'mean': 0.5592746970057487, 'std': 0.01294799330635752, 'ci_5': 0.5378402233123779, 'ci_95': 0.5807172626256942}}
