In [None]:
import numpy as np
from BurgerDataTest3 import BurgerData
import torch
import torchvision
import os
import matplotlib.pyplot as plt
from PIL import Image

def reassemble_patches(data, model, patch_size=224):
    """Stitch per-patch pixel anomaly scores into one full-image anomaly map.

    Only the image referenced by the first patch (``data.index2position[0]``)
    is reassembled. Patches that belong to any other file are skipped, so they
    cannot overwrite this image's map. Where patches overlap, the patch visited
    later overwrites the scores written by earlier ones.

    Args:
        data: Patch dataset. ``data[i]`` returns ``(patch, _)``.
            ``data.index2position[i]`` returns ``(file_name, (x, y))``, where
            ``x`` is the column and ``y`` the row offset of the patch's
            top-left corner. ``data.image_folder`` is the directory holding the
            full images and their ``<name>_groundtruth.png`` masks.
        model: Object whose ``predict(batch)`` returns
            ``(image_level_score, pixel_level_score)``.
        patch_size: Side length of the square patches in pixels. Defaults to
            224, matching ``BurgerData(imgSize=224)`` in this script.

    Returns:
        ``(image, ground_truth, anomaly_map)``:
            - the full PIL image;
            - its ground-truth mask converted to grayscale ("L");
            - a float ``(height, width)`` array of pixel anomaly scores, zero
              wherever no patch was placed.

    Raises:
        ValueError: If ``data`` contains no patches.
    """
    n_patches = len(data)
    if n_patches == 0:
        # Without this guard the image variables would never be assigned.
        raise ValueError("reassemble_patches: dataset contains no patches")

    # Everything is sized and loaded from the first patch's source image.
    first_file, _ = data.index2position[0]
    image_path = os.path.join(data.image_folder, first_file)
    # splitext handles extensions of any length (".png", ".jpeg", ...).
    gt_path = os.path.join(
        data.image_folder, os.path.splitext(first_file)[0] + '_groundtruth.png'
    )
    # Load fully and close the file handles instead of leaking them.
    with Image.open(image_path) as img:
        imagefull = img.copy()
    with Image.open(gt_path) as gt:
        ground_truth_image = gt.convert('L')

    width, height = imagefull.size
    pxl_anom_whole_image = np.zeros((height, width))

    for i in range(n_patches):
        file, (x, y) = data.index2position[i]
        if file != first_file:
            # Patches from other images would corrupt this image's map.
            continue
        patch, _ = data[i]
        # patch[None, ...] adds a leading batch axis of size 1.
        # NOTE(review): the patch is not moved to the model's device here;
        # presumably model.predict handles that. Confirm.
        _, pxl_lvl_anom_score = model.predict(patch[None, ...])
        # Assumes the pixel score broadcasts to (patch_size, patch_size).
        # TODO confirm. A patch running past the image border would fail here.
        pxl_anom_whole_image[y:y + patch_size, x:x + patch_size] = pxl_lvl_anom_score

    return imagefull, ground_truth_image, pxl_anom_whole_image

def visualize_whole_images(whole_image, ground_truth_image, pxl_anom):
    """Show a 2x2 figure comparing an image, its ground truth and its anomaly map.

    Panels, in row-major order:
        1. the original image;
        2. the ground-truth image in grayscale;
        3. the anomaly map in the 'jet' colormap;
        4. the anomaly map blended over the original image, each at alpha 0.5.

    Args:
        whole_image: Full input image (anything ``np.array`` accepts, e.g. a
            PIL image).
        ground_truth_image: Ground-truth image for the same input.
        pxl_anom: Pixel-level anomaly map for the same input.
    """
    original = np.array(whole_image)
    truth = np.array(ground_truth_image)
    heatmap = np.array(pxl_anom)

    # Each panel is (title, layers). Layers are (array, imshow kwargs) pairs,
    # drawn bottom layer first.
    panels = [
        ('a. Original Image', [(original, {})]),
        ('b. Ground Truth Image', [(truth, {'cmap': 'gray'})]),
        ('c. Anomaly Map (Detection Alone) PatchCore', [(heatmap, {'cmap': 'jet'})]),
        ('d. Detection Overlay PatchCore', [
            (original, {'alpha': 0.5}),
            (heatmap, {'alpha': 0.5, 'cmap': 'jet'}),
        ]),
    ]

    plt.figure(figsize=(15, 15))
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for layer, kwargs in layers:
            plt.imshow(layer, **kwargs)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# --- Configuration ---------------------------------------------------------
# ``w`` selects the test subset to evaluate: 1 = "white", 2 = "white with edges".
# Each subset has its own:
#   - per-channel Normalize mean/std;
#   - coordinate JSON file, passed to BurgerData as ``json_file``;
#   - image folder.
w = 2
if w == 1:
    # white
    mean = [0.5815, 0.5940, 0.5015]
    std = [0.2716, 0.2812, 0.2710]
    json = '/home/shn/PatchCore/white_coords.json'
    folder = "/home/shn/data/test/white"
elif w == 2:
    # white with edges
    mean = [0.6384, 0.6557, 0.5500]
    std = [0.2846, 0.2897, 0.2772]
    json = '/home/shn/PatchCore/white_with_edges_coords.json'
    folder = "/home/shn/data/test/white_with_edges"
else:
    # Fail fast. Continuing would crash below with a NameError on ``mean``.
    raise ValueError(f"No mean and std defined for w={w!r}")
# NOTE(review): ``json`` shadows the stdlib module name. It is kept because
# later notebook cells may still refer to it.

trans = torchvision.transforms.Normalize(mean=mean, std=std)
data = BurgerData(imgSize=224, stride=112, image_folder=folder, json_file=json, transform=trans)

# The checkpoint is a pickled model object (``.predict`` is called on it),
# not a state_dict. torch >= 2.6 defaults to ``weights_only=True``, which
# refuses to unpickle arbitrary objects, so opt out explicitly. Only load
# checkpoints from a trusted source: unpickling can execute arbitrary code.
# NOTE(review): this checkpoint is loaded for every ``w``, including w == 1.
# Confirm that a "white" model should not be used there.
model = torch.load('white_with_edges448s', weights_only=False)
model.to("cuda:0")

# --- Run -------------------------------------------------------------------
print("Starting reassembly of patches...")
whole_image, ground_truth_image, pxl_anom = reassemble_patches(data, model)
print("Reassembly of patches completed.")

print("Starting visualization...")
visualize_whole_images(whole_image, ground_truth_image, pxl_anom)
print("Visualization completed.")
