In [None]:
import numpy as np
from BurgerDataTest3 import BurgerData
import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix ,accuracy_score , precision_score, recall_score


# Pick the compute device once, up front: the first CUDA GPU when PyTorch
# reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def _as_host_scalar(value):
    """Return single-element tensors as plain Python numbers.

    Tensors are detached and moved to the CPU first; a tensor with more than
    one element is returned as that CPU tensor. Non-tensor values are passed
    through untouched.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu()
        if value.numel() == 1:
            return value.item()
    return value


def reassemble_patches(data, model, run_device=None):
    """Score every patch in ``data`` with ``model`` and collect scores and labels.

    Args:
        data: Indexable dataset. Only ``len(data)`` and ``data[i]`` are used;
            each item must unpack into ``(patch, label)`` where ``patch`` is a
            tensor (it is moved with ``.to(...)`` and given a leading batch
            axis).
        model: Object exposing ``predict(batch)`` that returns a pair
            ``(image_level_score, pixel_level_score)``. Only the image-level
            score is kept.
        run_device: Device the patch is moved to before prediction. When
            ``None`` (the default) the module-level ``device`` is used, which
            is the original behaviour.

    Returns:
        ``(ypred, ytrue)``: two lists of equal length holding, per patch, the
        image-level anomaly score and the label. Single-element tensors are
        converted to Python numbers so the lists can be handed directly to
        NumPy / scikit-learn even when the model ran on a GPU.
    """
    target_device = device if run_device is None else run_device

    ytrue = []
    ypred = []

    # Pure inference: disabling autograd avoids recording a graph for every
    # patch, which would otherwise waste memory over a whole dataset.
    with torch.no_grad():
        # Index-based access on purpose: the dataset is only known to support
        # __len__ and __getitem__, not necessarily the iterator protocol.
        for idx in range(len(data)):
            patch, label = data[idx]
            batch = patch.to(target_device)[None, ...]  # add a batch axis of 1
            img_lvl_anom_score, _pxl_lvl_anom_score = model.predict(batch)
            ytrue.append(_as_host_scalar(label))
            ypred.append(_as_host_scalar(img_lvl_anom_score))

    return ypred, ytrue

# Selects which dataset variant to evaluate: 1 = "white", 2 = "white with edges".
Category=1

# Per-category normalisation statistics and data locations.
# NOTE(review): mean/std are presumably per-channel statistics of the matching
# training set - confirm where they were computed.
# NOTE(review): the name `json` would shadow the standard-library module if it
# were ever imported in this file; it is kept unchanged because other cells may
# refer to it.
if Category == 1:
    #white
    mean=[0.5815, 0.5940, 0.5015]
    std=[0.2716, 0.2812, 0.2710]
    json= '/home/shn/PatchCore/white_coords.json'
    folder="/home/shn/data/white/ground_truth"

elif Category == 2:
    #white with edges
    mean=[0.6384, 0.6557, 0.5500]
    std=[0.2846, 0.2897, 0.2772]
    json= '/home/shn/PatchCore/white_with_edges_coords.json'
    folder="/home/shn/data/white_with_edges/ground_truth"
else:
    # Fail fast with a clear message: merely printing here would let execution
    # continue and crash below with a NameError on `mean`.
    raise ValueError(f"No mean and std defined for Category={Category!r}")

trans = torchvision.transforms.Normalize(mean=mean, std=std)

data = BurgerData(imgSize=224, stride=112, image_folder=folder, json_file=json, transform=trans)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source. Recent PyTorch releases
# default to weights_only=True, which may reject a fully pickled model object;
# confirm against the installed version.
# NOTE(review): the checkpoint name is 'white_new' regardless of Category -
# confirm this is intended when Category == 2.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
model = torch.load('white_new', map_location=device)
model.to(device)

print("Started reassembling patches")
ypred , ytrue= reassemble_patches(data, model)
print("Starting metrcics...")

def calculate_roc_auc(ytrue, ypred):
    """Return the area under the ROC curve for the given labels and scores.

    Thin wrapper: both arguments are forwarded unchanged to
    ``sklearn.metrics.roc_auc_score`` and its result is returned as-is.
    """
    return roc_auc_score(ytrue, ypred)

# --- Threshold-free summary -------------------------------------------------
roc_auc = calculate_roc_auc(ytrue, ypred)
print("ROC AUC:", roc_auc)

# --- ROC curve, drawn against the chance diagonal ---------------------------
fpr, tpr, thresholds = roc_curve(ytrue, ypred)
plt.plot(fpr, tpr, label='ROC Curve')
plt.plot([0, 1], [0, 1], 'k--', label='Random Guess')  # y=x line
plt.title('ROC Curve')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.show()

# --- Operating point ---------------------------------------------------------
# Choose the threshold where TPR - FPR is largest (Youden's J statistic).
optimal_idx = (tpr - fpr).argmax()
optimal_threshold = thresholds[optimal_idx]
print(f'Optimal Threshold: {optimal_threshold}')

# A score at or above the threshold is labelled 1, anything below it 0.
y_pred_binarized = [int(score >= optimal_threshold) for score in ypred]

# --- Metrics at the chosen threshold ----------------------------------------
accuracy = accuracy_score(ytrue, y_pred_binarized)
precision = precision_score(ytrue, y_pred_binarized)
recall = recall_score(ytrue, y_pred_binarized)
conf_matrix = confusion_matrix(ytrue, y_pred_binarized)

# Report everything in the same order as before.
print(f'Accuracy: {accuracy}')
print(f'Precision: {precision}')
print(f'Recall: {recall}')
print(f'Confusion Matrix:\n{conf_matrix}')
