In [38]:
from HSI_utils import *
from sklearn import metrics
# plot_learning_curves(os.path.join('models','training_losses.csv'))

In [41]:

# Plot toggles: each flag switches on one optional group of figures in the
# per-image analysis loop further down. All are off by default so that the
# loop only prints the numeric metrics.
PRINT_HYPERCUBES, PRINT_FAKE_COLORS = False, False
PRINT_ANOMALY_DET, PRINT_ROC = False, False

# Root folder of the input data and folder holding the saved results.
data_path = os.path.join("data")
dir_results = os.path.join('models', 'full_HSI_AD')

# Table of the files found under data_path; its index drives the loop and its
# `Filename` column names each image.
df_files = create_df_from_files_in_path(data_path, verbose=False)

## LOOP OVER ALL THE HSIs
# For every hyperspectral image (HSI) listed in df_files:
#   1. load the image, its reference anomaly map and its saved reconstruction,
#   2. build a per-pixel reconstruction-error map (raw and Gaussian-smoothed),
#   3. run the 2D-CFAR detector on the smoothed error map,
#   4. print pd / pfa of the CFAR detection and the AUC of the ROC obtained
#      from the smoothed error map.
# Figures are only produced when the corresponding PRINT_* flag is set.

# Reconstruction-error metric: 'MSE' or 'SAD'. It does not depend on the
# image, so it is selected once instead of on every iteration.
ERROR = 'MSE'

for training_data_index in df_files.index:
    anomaly_map, hs_image = load_HSI_from_idx(training_data_index, df_files, verbose = False)
    height, width, layers = np.shape(hs_image)

    # Only images whose height is 100 are processed; everything else is skipped.
    # TODO rescale images with 150x150
    if height != 100:
        continue

    filename = df_files.Filename[training_data_index]

    # Load the reconstructed image saved in the results folder
    # (dir_results is the same 'models/full_HSI_AD' path that was previously
    # rebuilt by hand here).
    save_path = os.path.join(dir_results, f'reconstructed_img_{filename}.npy')
    img_hat = np.load(save_path)

    # Plot hypercubes (done before normalisation, on the data as loaded)
    if PRINT_HYPERCUBES:
        scatter_hypercube_from_hyperspectral_image(hs_image)
        scatter_hypercube_from_hyperspectral_image(img_hat)

    # Normalise both images; everything below works on the normalised data
    hs_image = norm_img(hs_image)
    img_hat = norm_img(img_hat)

    # Plot fake-color images
    if PRINT_FAKE_COLORS:
        print_RGB_HSI(hs_image, img_name = f'{filename}')
        print_RGB_HSI(img_hat, img_name = f'reconstructed {filename}')

    ###########################################
    ## ANOMALY DETECTION BASED ON THE RECONSTRUCTION ERROR
    if ERROR == 'MSE':
        # Mean squared error along the last axis -> one value per pixel
        diff_img = np.mean((hs_image - img_hat) ** 2, -1)
    elif ERROR == 'SAD':
        diff_img = SAD(hs_image, img_hat)
    else:
        # Without this branch an unknown metric would leave diff_img undefined
        # (or silently reuse the one from the previous image).
        raise ValueError(f"Unknown ERROR metric {ERROR!r}; expected 'MSE' or 'SAD'")

    # Keep the unsmoothed map under its own name so it can still be plotted
    # after smoothing.
    raw_error_map = norm_img(diff_img)

    # Gaussian blur (3x3 kernel, sigma 0.8), re-normalised afterwards.
    # This smoothed map is the one used for CFAR and for the ROC.
    error_map = norm_img(cv2.GaussianBlur(raw_error_map, (3, 3), 0.8))

    # Anomaly detection: CFAR
    anomaly_detection_map = CFAR_2D(error_map, thr_type = 'higher', filter_dims = 5, gap_pixels = 1, thr_factor = 1.75)

    # Anomaly detection: global thresholding (alternative, disabled)
    # anomaly_detection_map = filter_image_under_threshold(anomaly_detection_map, 0.05)
    # anomaly_detection_map = filter_image_over_threshold(anomaly_detection_map, 0)

    if PRINT_ANOMALY_DET:
        # Use the axes returned by plt.subplots directly instead of
        # re-selecting them with plt.subplot(22x).
        fig, axes = plt.subplots(2, 2, figsize=(9, 9))
        panels = (
            (axes[0, 0], anomaly_map, 'Reference Anomaly map', {}),
            # Raw (unsmoothed) error: previously this panel showed the
            # smoothed map a second time.
            (axes[0, 1], raw_error_map, 'Error map', {'vmin': 0, 'vmax': 1}),
            (axes[1, 0], error_map, 'Smoothed Error map', {}),
            (axes[1, 1], anomaly_detection_map, 'Anomaly Detection map', {'vmin': 0, 'vmax': 1}),
        )
        for axis, image, title, limits in panels:
            axis.imshow(image, cmap=plt.cm.gray, **limits)
            axis.set_title(title)
            axis.axis('off')

        # SHOW RESULTS (optional overlays, disabled)
        # print_RGB_HSI_with_mask(hs_image, anomaly_detection_map, cmap = plt.cm.hot, alpha=0.8, title = 'Anomaly Detection (red) over fake color HSI')
        # print_RGB_HSI_with_mask(hs_image, error_map, cmap = plt.cm.hot, alpha=0.8, title = 'Anomaly Detection (red) over fake color HSI')

    # Calculate pd and pfa by thresholding the CFAR output at 0
    # NOTE(review): presumably this binarises the CFAR map - confirm in HSI_utils.
    AD_CFAR = filter_image_over_threshold(anomaly_detection_map, 0)
    pd_CFAR = compute_pd(anomaly_map, AD_CFAR)
    pfa_CFAR = compute_pfa(anomaly_map, AD_CFAR)

    # Calculate ROC and AUC from the smoothed error map. Copies are passed so
    # the originals cannot be modified by compute_ROC.
    pd_list_error, pl_list_error, pfa_list_error = compute_ROC(anomaly_map.copy(), error_map.copy())
    AUC = metrics.auc(pfa_list_error, pd_list_error)

    if PRINT_ROC:
        fig = plt.figure(figsize=(7, 4))
        # plt.semilogx(pfa_list_error, pl_list_error, color='k', label = 'ROC localization')
        plt.semilogx(pfa_list_error, pd_list_error, color='r', label = 'ROC pd/pfa')
        # Operating point reached by the 2D-CFAR detector
        plt.scatter(pfa_CFAR, pd_CFAR, color = 'r', label = '2D-CFAR threshold')
        plt.ylabel('Probability of Detection')
        plt.xlabel('Probability of False Alarm')
        plt.title(f'ROC curve with AUC = {AUC:.4f}')
        plt.grid()
        plt.legend(loc='lower right')

    print(f'>> IMAGEN {filename}:')
    print(f'pd = {pd_CFAR:.3f}, pfa = {pfa_CFAR:.3f}, AUC = {AUC:.3f}')

>> IMAGEN abu-airport-1:
pd = 0.201, pfa = 0.026, AUC = 0.731
>> IMAGEN abu-airport-2:
pd = 0.517, pfa = 0.095, AUC = 0.915
>> IMAGEN abu-airport-3:
pd = 0.365, pfa = 0.035, AUC = 0.706
>> IMAGEN abu-airport-4:
pd = 0.300, pfa = 0.005, AUC = 0.985
>> IMAGEN abu-beach-2:
pd = 0.000, pfa = 0.018, AUC = 0.449
>> IMAGEN abu-beach-3:
pd = 0.818, pfa = 0.004, AUC = 0.931
>> IMAGEN abu-urban-1:
pd = 0.687, pfa = 0.008, AUC = 0.923
>> IMAGEN abu-urban-2:
pd = 0.000, pfa = 0.000, AUC = 0.000
>> IMAGEN abu-urban-3:
pd = 0.808, pfa = 0.025, AUC = 0.737
>> IMAGEN abu-urban-4:
pd = 0.000, pfa = 0.000, AUC = 0.000
>> IMAGEN abu-urban-5:
pd = 0.556, pfa = 0.063, AUC = 0.852
