In [1]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix
from model import CustomEnsembleModel
from kf_data import CustomImageCSVModule_kf


  r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count


In [None]:
# Load the hyperparameters.
def load_hyperparameters(file_path='config.yaml'):
    """Read the experiment configuration from a YAML file.

    Args:
        file_path: Path to the YAML config file. Defaults to ``config.yaml``
            relative to the current working directory.

    Returns:
        dict: The parsed top-level mapping of the YAML document.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the document's top level is not a mapping (for example
            an empty file, for which ``yaml.safe_load`` returns ``None``);
            callers index the result by key, so fail early with a clear message.
    """
    import yaml

    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252
    # on Windows) can mis-decode non-ASCII text such as Portuguese comments.
    with open(file_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a YAML mapping in {file_path!r}, got {type(config).__name__}"
        )
    return config

# Read the configuration and build the path of the final checkpoint,
# laid out as <PROJECT>/<WANDB_RUN_NAME>.ckpt.
hyperparams = load_hyperparameters()
final_model_path = "{}/{}.ckpt".format(
    hyperparams['PROJECT'],
    hyperparams['WANDB_RUN_NAME'],
)

In [None]:
# Restore the trained ensemble from the final checkpoint and switch it to
# evaluation mode; the cell's last expression displays the returned module.
checkpoint_path = final_model_path
print(f"Carregando modelo final de: {checkpoint_path}")
model = CustomEnsembleModel.load_from_checkpoint(checkpoint_path)
model.eval()

In [None]:
# Build the test DataModule from the config values.
# NOTE(review): fold_idx is pinned to 0; presumably the fold only matters for
# the train/validation split and not for the test stage — confirm in kf_data.
datamodule_kwargs = {
    'train_dir': hyperparams['TRAIN_DIR'],
    'test_dir': hyperparams['TEST_DIR'],
    'shape': hyperparams['SHAPE'],
    'batch_size': hyperparams['BATCH_SIZE'],
    'num_workers': hyperparams['NUM_WORKERS'],
    'n_splits': hyperparams['K_FOLDS'],
    'fold_idx': 0,
}
data_module = CustomImageCSVModule_kf(**datamodule_kwargs)

# Prepare the test split and grab its loader.
data_module.setup(stage='test')
test_loader = data_module.test_dataloader()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    accelerator = 'gpu'
else:
    accelerator = 'cpu'
trainer = pl.Trainer(accelerator=accelerator)

# Evaluate the model on the datamodule's test split; `results` is a list
# with one metrics dict per test dataloader.
results = trainer.test(model, datamodule=data_module)

In [None]:
# Pull the test metrics out of the results for the first test dataloader.
test_metrics = results[0]
accuracy = test_metrics['test/accuracy']
precision = test_metrics['test/precision']
recall = test_metrics['test/recall']
f1 = test_metrics['test/f1_score']
# NOTE(review): Lightning's `trainer.test` returns logged metrics as scalars;
# confirm how the model logs 'test/confusion_matrix' before plotting it as a matrix.
conf_matrix_value = test_metrics['test/confusion_matrix']

# Report each scalar metric with four decimal places.
for label, value in (
    ("Acurácia", accuracy),
    ("Precisão", precision),
    ("Recall", recall),
    ("F1-Score", f1),
):
    print(f"{label}: {value:.4f}")

In [None]:
# Display the confusion matrix.
def plot_confusion_matrix(cm):
    """Draw a confusion matrix as an annotated seaborn heatmap.

    Args:
        cm: 2-D matrix — a NumPy array, nested list, or torch tensor (CPU or
            CUDA). The y-axis is labelled as the true class ('Real') and the
            x-axis as the predicted class ('Predito').

    Cells are annotated as integers when every entry is integer-valued (the
    usual case for raw counts, even if they arrive as floats) and with two
    decimals otherwise (e.g. a normalized matrix).

    Raises:
        ValueError: If ``cm`` is not two-dimensional (e.g. a single scalar).
    """
    # Tensors (possibly on the GPU) must be moved to host memory before they
    # can be turned into a NumPy array for seaborn.
    if hasattr(cm, 'detach'):
        cm = cm.detach().cpu().numpy()
    values = np.asarray(cm)

    if values.ndim != 2:
        raise ValueError(
            f"Expected a 2-D confusion matrix, got an array of shape {values.shape}"
        )

    # fmt='d' raises ValueError on float data, so only use it for
    # integer-valued matrices (casting integral floats back to int).
    if np.issubdtype(values.dtype, np.integer):
        fmt = 'd'
    elif np.all(np.isfinite(values)) and np.all(values == np.round(values)):
        values = values.astype(int)
        fmt = 'd'
    else:
        fmt = '.2f'

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(values, annot=True, fmt=fmt, cmap='Blues', ax=ax)
    ax.set_xlabel('Predito')
    ax.set_ylabel('Real')
    ax.set_title('Matriz de Confusão')
    plt.show()

plot_confusion_matrix(cm=conf_matrix_value)  # render the test-set confusion matrix