In [None]:
import json
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import seaborn as sns
import numpy as np
import sys
sys.path.append("./code/")

from data_processing import load_data
from model import RNNBinPacking
from utils import count_parameters

In [None]:
# ---- Evaluation settings: edit here, then re-run the evaluation cell ----

# Which subset the evaluation cell reports on.
# Options: None (overall), 'PC2', 'PC3', 'PC4'. Any other value makes the
# evaluation cell raise NotImplementedError.
filter_string = None

# When this is not an empty list it is handed to load_data in place of the
# dataset named in args_info.json.
customize_dataset = list()

# Checkpoint file name; the evaluation cell loads it from ./models/.
checkpoint_path = 'best_model_checkpoint.pth'



In [None]:
# Announce which test subset this run evaluates, and reject unsupported values
# before any data or checkpoint is loaded.
if filter_string is None:
    # None is the "overall" setting.
    print("TESTING ON OVERALL")
elif filter_string in ('PC2', 'PC3', 'PC4'):
    print(f"TESTING ON {filter_string}")
else:
    # Same exception type as before, but the message now names the bad value.
    raise NotImplementedError(
        f"Unsupported filter_string {filter_string!r}; "
        "expected None, 'PC2', 'PC3' or 'PC4'."
    )

# Read the saved run arguments from args_info.json (resolved against the
# current working directory) and derive the model sizes that depend on them.
with open('args_info.json', 'r', encoding='utf-8') as args_file:
    args_dict = json.load(args_file)

# Data / reproducibility settings.
data_set, batch_size, seed = (
    args_dict[key] for key in ('dataset', 'batch_size', 'seed')
)

# Architecture settings.
hidden_size, nhead, num_transformer_layers, num_rnn_layers = (
    args_dict[key] for key in ('dim', 'head', 'transformer_layers', 'rnn_layers')
)
d_ff = hidden_size * 4          # feed-forward width is four times the hidden size
num_fc_neurons = hidden_size    # fully connected width equals the hidden size

# Use the first CUDA device only when the saved args ask for a GPU and one is
# actually available; otherwise run on the CPU.
gpu_requested = args_dict['gpu']
if gpu_requested and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed PyTorch's RNG with the seed stored in args_info.json.
torch.manual_seed(seed)

# load_data is called from the parent directory (presumably because the dataset
# paths are relative to it — confirm against data_processing.load_data).
# The try/finally restores the working directory even when loading fails, so
# re-running this cell after an error does not climb one directory per attempt.
original_dir = os.getcwd()
os.chdir("..")
try:
    if customize_dataset:
        # NOTE(review): this branch does not pass test_batch_size=32, unlike the
        # default branch below — confirm that the difference is intended.
        train_loader, val_loader, test_loader, pos_weight = load_data(
            customize_dataset, batch_size, filter_string=filter_string
        )
    else:
        train_loader, val_loader, test_loader, pos_weight = load_data(
            data_set, batch_size, test_batch_size=32, filter_string=filter_string
        )
finally:
    os.chdir(original_dir)

# Create the model and load saved weights.
model = RNNBinPacking(
    hidden_size, nhead, num_transformer_layers, num_rnn_layers, num_fc_neurons, d_ff
).to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine can still be loaded when device is the CPU.
checkpoint = torch.load(f"./models/{checkpoint_path}", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print(count_parameters(model))

# Testing: run the model over the test loader with gradients disabled, keeping
# every sigmoid output and label and tallying confusion counts at a 0.5 cut-off.
model.eval()
all_outputs, all_labels = [], []
true_positives = true_negatives = false_positives = false_negatives = 0

with torch.no_grad():
    for batch_data, batch_labels in test_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        # The model output is passed through a sigmoid before thresholding.
        outputs = torch.sigmoid(model(batch_data))
        all_outputs.extend(outputs.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

        # NOTE(review): the masks below combine outputs and batch_labels
        # element-wise, which assumes both have the same shape — TODO confirm.
        predicted = (outputs > 0.5).float()
        predicted_pos, predicted_neg = predicted == 1, predicted == 0
        actual_pos, actual_neg = batch_labels == 1, batch_labels == 0

        true_positives += (predicted_pos & actual_pos).sum().item()
        true_negatives += (predicted_neg & actual_neg).sum().item()
        false_positives += (predicted_pos & actual_neg).sum().item()
        false_negatives += (predicted_neg & actual_pos).sum().item()

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# Derive class prevalence and the four confusion-matrix rates from the counts.
# A filtered test set may contain no positives, no negatives, or no samples at
# all; the affected rates are then reported as NaN instead of aborting the cell
# with ZeroDivisionError.
actual_positives = true_positives + false_negatives
actual_negatives = true_negatives + false_positives
total_samples = actual_positives + actual_negatives

pr = _safe_ratio(actual_positives, total_samples)    # share of samples labelled 1
nr = _safe_ratio(actual_negatives, total_samples)    # share of samples labelled 0
tpr = _safe_ratio(true_positives, actual_positives)
fpr = _safe_ratio(false_positives, actual_negatives)
tnr = _safe_ratio(true_negatives, actual_negatives)
fnr = _safe_ratio(false_negatives, actual_positives)

print("Positive Rate:", pr)
print("Negative Rate:", nr)
print("True Positive Rate (TPR):", tpr)
print("False Positive Rate (FPR):", fpr)
print("True Negative Rate (TNR):", tnr)
print("False Negative Rate (FNR):", fnr)

#################### Metrics ###################
# Bar chart with one bar per thresholded rate (TPR, FPR, TNR, FNR).
labels = ["TPR", "FPR", "TNR", "FNR"]
values = [tpr, fpr, tnr, fnr]
bar_colors = ['blue', 'green', 'red', 'purple']

metrics_ax = plt.gca()  # the current axes, i.e. where plt.bar would draw
metrics_ax.bar(labels, values, color=bar_colors)
metrics_ax.set_title("Classification Metrics")
metrics_ax.set_xlabel("Metrics")
metrics_ax.set_ylabel("Values")
metrics_ax.figure.tight_layout()
plt.show()
#################### Metrics ###################

################ Confusion Matrix ################
# Rows are the actual class and columns the predicted class, both ordered 0, 1.
actual_negative_row = [true_negatives, false_positives]
actual_positive_row = [false_negatives, true_positives]
confusion_matrix = np.array([actual_negative_row, actual_positive_row])

cm_fig, cm_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    confusion_matrix,
    annot=True,
    fmt="d",          # the counts are integers
    cmap="Blues",
    xticklabels=["Predicted 0", "Predicted 1"],
    yticklabels=["Actual 0", "Actual 1"],
    ax=cm_ax,
)
cm_ax.set_ylabel('True label')
cm_ax.set_xlabel('Predicted label')
cm_ax.set_title('Confusion Matrix')
plt.show()
################ Confusion Matrix ################

##################### ROC ####################
# NOTE: fpr and tpr are rebound here. Earlier in this cell they hold the scalar
# rates at the 0.5 threshold; from this point on they hold the arrays returned
# by roc_curve.
fpr, tpr, _roc_thresholds = roc_curve(all_labels, all_outputs)
roc_auc = auc(fpr, tpr)

roc_fig, roc_ax = plt.subplots(figsize=(7, 5))
roc_ax.plot(
    fpr,
    tpr,
    color='darkorange',
    lw=2,
    label=f'ROC curve (area = {roc_auc:.2f})',
)
# Dashed y = x reference line from (0, 0) to (1, 1).
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
roc_ax.legend(loc="lower right")
roc_ax.grid(True)
# Saving is disabled; plot_dir is not defined in the cells shown here.
# plt.savefig(f"{plot_dir}/roc_curve.png")
plt.show()
##################### ROC ####################

# Overall accuracy at the 0.5 threshold, recomputed from the collected
# per-sample outputs and labels.
# NOTE(review): the comparison is element-wise, so it assumes the stacked
# outputs and labels have matching shapes — TODO confirm.
stacked_outputs = np.array(all_outputs)
stacked_labels = np.array(all_labels)

predicted_labels = stacked_outputs > 0.5
prediction_matches = predicted_labels == stacked_labels
correct_predictions = prediction_matches.sum()
total_predictions = len(all_labels)

accuracy = correct_predictions / total_predictions
print(f"Test Accuracy: {100 * accuracy:.2f}%")