In [None]:
from HSI_utils import *
import torch
from vit import *
from mae import *
from sklearn import metrics

In [None]:
# Fix the random seeds so that repeated runs of the notebook are comparable.
seed = 23
# Seed PyTorch's random number generator.
torch.manual_seed(seed)
# Seed NumPy's global random number generator.
# NOTE(review): `np` is not imported explicitly in this notebook; presumably it
# arrives through one of the star imports above — confirm.
np.random.seed(seed)

## Creating the Model

In [None]:
# Training parameters
# PATCH_SIZE is handed to the ViT encoder as `patch_size` and is also used to
# tile the reconstructed image back together in the evaluation cells.
PATCH_SIZE = 10
# EPOCHS is the second argument of `mae.training_loop`
# (presumably the number of training epochs — confirm in mae.py).
EPOCHS = 4000
# Loss selector passed to `mae.training_loop` as `loss_fcn`.
# The alternatives are kept commented out below.
loss = 'mse'
# loss = 'weighted_l2'
# loss = 'mae'

# Optional PCA reduction of the spectral layers.
# When PCA > 0 the image goes through `reduce_HSI_dim_with_PCA(hs_image, PCA)`
# and is normalized again; PCA = 0 skips that step.
# NOTE(review): presumably PCA is the number of components to keep — confirm
# against HSI_utils.
PCA = 0
# dim_enc is the ViT `dim`; dim_dec is the MAE `decoder_dim`.
dim_enc = 8
dim_dec = 8

# Set TRAIN to True to train a model. When it is False the training step is skipped.
# Set LOAD to True to load previously saved weights. Otherwise, a newly created model is used.
TRAIN = True
LOAD = False

In [None]:
# Build the table that describes every HSI file found in the data folder.
ruta = os.path.join("data")
df_files = create_df_from_files_in_path(ruta, verbose = True)

# List with the index of the images to use in the training.
# Every image will be trained on a separate model.
list_imgs = [3]

# for id_img in df_files.index:
for id_img in list_imgs:
  name = df_files.Filename[id_img]
  print(f'>> Processing image {name}...')

  # Prepare dataset (HSI)
  height = df_files.Height[id_img]
  width = df_files.Width[id_img]

  anomaly_map, hs_image = load_HSI_from_idx(id_img, df_files, verbose = True)
  hs_image = norm_img(hs_image)
  if PCA > 0:
    hs_image = reduce_HSI_dim_with_PCA(hs_image, PCA)
    hs_image = norm_img(hs_image)

  # Take the channel count from the image that will actually be fed to the
  # model. The value stored in df_files describes the raw file, so it is stale
  # as soon as the optional PCA step changes the number of layers; using it
  # made the tensor assignment below (and the ViT `channels`) inconsistent
  # whenever PCA > 0.
  layers = int(np.shape(hs_image)[2])

  print_RGB_HSI(hs_image, img_name = name)
  print(f'Image under test with dimmensions {np.shape(hs_image)}')

  # The image is indexed (row, column, layer); the model input is built as
  # (1, layer, row, column).
  x_i = torch.from_numpy(np.float32(hs_image.copy()))
  x_i = torch.permute(x_i, (2, 0, 1))
  x_n = torch.zeros(1, layers, height, width)
  x_n[0] = x_i

  if TRAIN or LOAD:
    # Create the ViT encoder
    v = ViT(
        image_size = height,
        patch_size = PATCH_SIZE,
        channels = layers,
        num_classes = 1,
        dim = dim_enc,
        depth = 4,
        heads = 8,
        mlp_dim = 32,
        emb_dropout = 0,
        pool = 'mean',
        verbose = 1
    )

    # Create the Masked Autoencoder model
    mae = MAE(
        encoder = v,
        masking_ratio = 0.7,
        decoder_dim = dim_dec,
        decoder_depth = 4,       # anywhere from 1 to 8
        decoder_heads = 8,
        decoder_dim_head = 32,
        lr = 5e-4,
        verbose = 1
    )
    print(f'>> Model for {name} created.')

  if LOAD:
    # Load previous weights for the model, together with the loss history
    # that was saved next to them.
    mae.load_state_dict(torch.load(name + '_TEST.pth'))
    df_history = pd.read_csv('./training_losses_TEST_' + name + '.csv')
    mae.loss_history = df_history.MSE.values.tolist()

  if TRAIN:
    # Train the model
    mae.training_loop(x_n, EPOCHS, autosave = 0, loss_fcn = loss)

    # Save the training history
    history = mae.loss_history
    df_history = pd.DataFrame(history, columns=['MSE'])
    df_history.to_csv('./training_losses_TEST_' + name + '.csv')

    # Save the trained weights
    torch.save(mae.state_dict(), './' + name + '_TEST.pth')

    # Plot the training history
    plt.plot(history)
    plt.title(name + ' TRAINING LOSS HISTORY')
    plt.xlabel('EPOCH')
    plt.ylabel('RL')
    plt.show()

## Evaluation of the model

In [None]:
# Auxiliary functions
def reconstruct_patch(patch_vector, patch_size = None, n_layers = None):
    """Reshape a flattened patch into a (patch_size, patch_size, n_layers) tensor.

    The flat vector is read pixel by pixel in row-major order: the first
    `n_layers` values are pixel (0, 0), the next `n_layers` values are pixel
    (0, 1), and so on. Values beyond the first
    `patch_size * patch_size * n_layers` are ignored.

    Args:
        patch_vector: 1-D tensor (or array-like) holding the flattened patch.
        patch_size: side length of the square patch. Defaults to the
            notebook-level `PATCH_SIZE`.
        n_layers: number of values stored per pixel. Defaults to the
            notebook-level `layers`.

    Returns:
        A new float tensor of shape (patch_size, patch_size, n_layers).

    Raises:
        ValueError: if `patch_vector` holds fewer values than the patch needs.
    """
    # Resolve the defaults at call time so that the function follows the
    # current notebook globals, exactly as the original implementation did.
    patch_size = int(PATCH_SIZE if patch_size is None else patch_size)
    n_layers = int(layers if n_layers is None else n_layers)

    n_values = patch_size * patch_size * n_layers
    flat = torch.as_tensor(patch_vector)[:n_values]
    if flat.shape[0] < n_values:
        raise ValueError(
            f'patch_vector has {flat.shape[0]} values, expected at least {n_values}'
        )

    # A single reshape replaces the former pixel-by-pixel copy, which was also
    # needlessly repeated PATCH_SIZE**2 times by an unused outer loop.
    # Copying into a zeros tensor keeps the dtype/device of the original output.
    rec_patch = torch.zeros((patch_size, patch_size, n_layers))
    rec_patch[...] = flat.reshape(patch_size, patch_size, n_layers)
    return rec_patch

def iterative_gaussian_blur(image, n_iters = 5, var = 1):
  """Apply a 3x3 `cv2.GaussianBlur` to `image` `n_iters` times in a row.

  `var` is forwarded unchanged as the third positional argument of
  `cv2.GaussianBlur`. With `n_iters = 0` the input is returned untouched.
  """
  kernel_size = (3, 3)
  blurred = image
  for _ in range(n_iters):
    # Each pass blurs the output of the previous pass.
    blurred = cv2.GaussianBlur(blurred, kernel_size, var)
  return blurred

In [None]:
# Build the table describing the HSI files found in the data folder.
# This is the same call as in the training cell; presumably it is repeated so
# that the evaluation cells can be run on their own — confirm.
ruta = os.path.join("data")
df_files = create_df_from_files_in_path(ruta, verbose = True)

#### Reconstruction of the HSI

In [None]:
# Test loop: rebuild the model of every selected image, load its saved weights,
# reconstruct the image and keep the results for the metrics cell.

list_names_imgs = []
list_original_imgs = []
list_reconstructed_imgs = []
list_anomaly_maps = []

list_imgs = [3]

# for id_img in df_files.index:
for id_img in list_imgs:
  name = df_files.Filename[id_img]
  print(f'>> Processing image {name}...')

  # Prepare dataset (HSI)
  height = df_files.Height[id_img]
  width = df_files.Width[id_img]
  anomaly_map, hs_image = load_HSI_from_idx(id_img, df_files, verbose = True)
  hs_image = norm_img(hs_image)
  if PCA > 0:
    hs_image = reduce_HSI_dim_with_PCA(hs_image, PCA)
    hs_image = norm_img(hs_image)

  # Take the channel count from the processed image. The value stored in
  # df_files describes the raw file and is stale once the optional PCA step
  # has changed the number of layers; `layers` also sizes `x_hat` and is read
  # by `reconstruct_patch`, so it must match the data the model really sees.
  layers = int(np.shape(hs_image)[2])

  # The image is indexed (row, column, layer); the model input is built as
  # (1, layer, row, column).
  x_i = torch.from_numpy(np.float32(hs_image.copy()))
  x_i = torch.permute(x_i, (2, 0, 1))
  x_n = torch.zeros(1, layers, height, width)
  x_n[0] = x_i

  # Create the ViT encoder
  v = ViT(
      image_size = height,
      patch_size = PATCH_SIZE,
      channels = layers,
      num_classes = 1,
      dim = dim_enc,
      depth = 4,
      heads = 8,
      mlp_dim = 32,
      emb_dropout = 0,
      pool = 'mean',
      verbose = 1
  )

  # Create the Masked Autoencoder model
  mae = MAE(
      encoder = v,
      masking_ratio = 0.7,
      decoder_dim = dim_dec,
      decoder_depth = 4,       # anywhere from 1 to 8
      decoder_heads = 8,
      decoder_dim_head = 32,
      lr = 5e-4,
      verbose = 1
  )

  print(f'>> Model for {name} created.')

  # Load previous weights for the model
  mae.load_state_dict(torch.load(name + '_TEST.pth'))
  df_history = pd.read_csv('./training_losses_TEST_' + name + '.csv')
  mae.loss_history = df_history.MSE.values.tolist()

  # Reconstruct the image
  reconstructed_img = mae.reconstruct_image(x_n)

  # Reorganization of the reconstructed image: every row of `rec_img` is one
  # flattened patch, and the patches are placed left to right, top to bottom.
  # NOTE(review): this assumes the model emits patches in row-major order —
  # confirm against mae.reconstruct_image.
  x_hat = torch.zeros((height, width, layers))
  rec_img = reconstructed_img[0,:]
  n_patch_rows = int(height // PATCH_SIZE)
  n_patch_cols = int(width // PATCH_SIZE)
  id_patch = 0
  for row in range(n_patch_rows):
    for column in range(n_patch_cols):
      patch_vector = rec_img[id_patch, :]
      patch = reconstruct_patch(patch_vector)
      x_hat[row*PATCH_SIZE:(row+1)*PATCH_SIZE, column*PATCH_SIZE:(column+1)*PATCH_SIZE, :] = patch
      id_patch += 1
  img = x_hat.detach().numpy()
  img = norm_img(img)

  print(f'Image {name} reconstructed. Storing the data...')

  # Store the images
  list_names_imgs.append(name)
  list_original_imgs.append(hs_image)
  list_anomaly_maps.append(anomaly_map)
  list_reconstructed_imgs.append(img)

#### Anomaly Detection and metrics

In [None]:
# Per-pixel reconstruction error used as the anomaly score.
ERROR_MAP = 'mse'
# ERROR_MAP = 'mae'
# ERROR_MAP = 'SAD'
# Number of Gaussian smoothing passes applied to the error map.
smooth_iters = 2

# Fail early on an unsupported selector. Without this check `error_map` would
# be undefined on the first image, or silently reused from a previous run.
if ERROR_MAP not in ('mse', 'mae', 'SAD'):
  raise ValueError(f"Unsupported ERROR_MAP {ERROR_MAP!r}; expected 'mse', 'mae' or 'SAD'.")

# Test loop
for idx, img in enumerate(list_reconstructed_imgs):
  name = list_names_imgs[idx]
  hs_image = list_original_imgs[idx]
  anomaly_map = list_anomaly_maps[idx]

  print_RGB_HSI(hs_image, img_name = name)
  print_RGB_HSI(img, img_name = 'RECONSTRUCTED ' + name)

  # Collapse the layer axis (axis 2) into one error value per pixel.
  if ERROR_MAP == 'mse':
    error_map = np.mean((hs_image - img)**2, axis=2)
  elif ERROR_MAP == 'mae':
    error_map = np.mean(np.abs(hs_image - img), axis=2)
  elif ERROR_MAP == 'SAD':
    error_map = SAD(hs_image, img)

  # Plot the reference map, the raw and smoothed error and the 2D-CFAR result
  fig, ax = plt.subplots(2, 2, figsize=(9,9))

  ax[0, 0].imshow(anomaly_map, cmap=plt.cm.gray)
  ax[0, 0].set_title('Reference Anomaly map')
  ax[0, 0].axis('off')

  ax[0, 1].imshow(error_map, cmap=plt.cm.gray, vmin=0, vmax=1)
  ax[0, 1].set_title('Error map')
  ax[0, 1].axis('off')

  # Gaussian blur
  error_map_gauss = iterative_gaussian_blur(error_map, n_iters = smooth_iters, var = 0.8)
  error_map_gauss = norm_img(error_map_gauss)

  ax[1, 0].imshow(error_map_gauss, cmap=plt.cm.gray, vmin=0, vmax=1)
  ax[1, 0].set_title('Smoothed Error map')
  ax[1, 0].axis('off')

  # Detection threshold: CFAR
  anomaly_detection_map = CFAR_2D(error_map_gauss, thr_type = 'higher', filter_dims = 5, gap_pixels = 1, thr_factor = 1.75)
  anomaly_detection_map = norm_img(anomaly_detection_map)

  ax[1, 1].imshow(anomaly_detection_map, cmap=plt.cm.gray, vmin=0, vmax=1)
  ax[1, 1].set_title('Anomaly Detection map')
  ax[1, 1].axis('off')

  # Calculate pd and pfa by thresholding the CFAR
  AD_CFAR = filter_image_over_threshold(anomaly_detection_map, 0)
  pd_CFAR = compute_pd(anomaly_map, AD_CFAR)
  pfa_CFAR = compute_pfa(anomaly_map, AD_CFAR)

  # Calculate and plot ROC. Copies are passed so the stored maps are not
  # affected if compute_ROC modifies its arguments.
  pd_list_error, pl_list_error, pfa_list_error = compute_ROC(anomaly_map.copy(), error_map_gauss.copy())
  AUC = metrics.auc(pfa_list_error, pd_list_error)

  fig = plt.figure(figsize=(7,4))
  # plt.semilogx(pfa_list_error, pl_list_error, color='k', label = 'ROC localization')
  plt.semilogx(pfa_list_error, pd_list_error, color='r', label = 'ROC pd/pfa')
  plt.scatter(pfa_CFAR, pd_CFAR, color = 'r', label = '2D-CFAR threshold')
  plt.ylabel('Probability of Detection')
  plt.xlabel('Probability of False Alarm')
  plt.title(f'ROC curve with AUC = {AUC:.4f}')
  plt.grid()
  plt.legend(loc='lower right')
  plt.show()
