In [None]:
import os
# Enter the Transformer-Explainability checkout (presumably so that the
# `baselines` package imported further down resolves).
# Guarded so that re-running this cell does not try to enter a nested
# ./Transformer-Explainability/Transformer-Explainability and fail.
if os.path.basename(os.getcwd()) != 'Transformer-Explainability':
    os.chdir('./Transformer-Explainability')

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from tqdm import tqdm

In [None]:
import pandas as pd
from torch.utils.data import Dataset
import torch
import PIL
import numpy as np
from torchvision import datasets, transforms
import os
import ast
import random

In [None]:
from baselines.ViT.ViT_Seq import sleep_inter_epoch as ViT_SLEEP_Inter
from baselines.ViT.Seq_explanation_generator import LRP

# 1. Load the model onto the device and wrap it with LRP

In [None]:
# Initialize the inter-epoch ViT from a saved checkpoint and wrap it with LRP.
# NOTE(review): `.cuda()` is unconditional here, while the `device` defined in a
# later cell falls back to CPU — as written this cell requires a GPU.
model_path = "/tf/data_AIoT1/ViT_models/Seq-Full-15-relprop-final-relu-2023-12-18.pth"
model = ViT_SLEEP_Inter(pretrained=True, checkpoint_path = model_path).cuda()
# Switch to evaluation mode before generating predictions / attributions.
model.eval()
# Shared LRP wrapper; `generate_attributions` reads this module-level name.
attribution_generator = LRP(model)

In [None]:
# Display the model's repr to inspect its structure.
model

In [None]:
# Device used for the input batches and aggregation buffers: first GPU if
# available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 2. Argument settings

In [None]:
# Root directory under which the per-hospital / per-patient vector folders live.
img_dir = "/tf/data_AIoT1/psg_image"
# Tab-separated label file: column 0 is read as the file name, column 1 as the label.
test_labels = "/tf/data_AIoT1/psg_image/labels/test_1209.txt"
# NOTE(review): `workers` and `save_path` are not referenced in the visible cells.
workers = 4
# Number of sliding windows sent through the model at once.
batch_size = 128
# Sliding-window length (number of consecutive epochs per model input).
num_seq = 15
save_path = "/tf/data_AIoT1/psg_image_codes/confusion_matrix/"

In [None]:
# Read the test label file (tab-separated, no header row, so columns are 0, 1, ...).
df_file_list = pd.read_csv(test_labels, sep="\t", header=None)
# Column 0 holds the file names; the patient id is later taken from their prefix.
files = df_file_list[0].tolist()

In [None]:
# create patient_dict: patient id -> number of files (epochs) from that patient.
# The patient id is the first 16 characters of the file name.
# Counts are accumulated per patient regardless of file order; the previous
# version reset a patient's count to 1 whenever its files were not contiguous,
# which disagreed with the per-patient row selection done when loading vectors.
patient_dict = {}

for file in files:
    patient = file[0:16]
    patient_dict[patient] = patient_dict.get(patient, 0) + 1

# 3. Functions
- generate attributions and aggregate them over sliding windows

In [None]:
def generate_attributions(image, num_seq):
    """Compute one normalized LRP attribution vector per sequence position.

    For every position ``idx`` in ``range(num_seq)`` the module-level
    ``attribution_generator`` is queried with ``method="transformer_attribution"``
    and ``seq_idx=idx``. The entry of the position itself is masked, the
    remaining entries are rescaled so that the second-smallest value maps to 0
    and the maximum maps to 1, and finally the position's own entry is set to 1.

    Parameters
    ----------
    image : torch.Tensor
        One input window without a batch dimension (a batch dimension is added
        here before the tensor is moved to the GPU).
    num_seq : int
        Number of positions in the window; must be at least 2 because the two
        smallest attribution values are looked up.

    Returns
    -------
    list of numpy.ndarray
        ``num_seq`` arrays, one per queried position.
    """
    attributions = []

    for idx in range(num_seq):
        transformer_attribution = attribution_generator.generate_LRP(
            image.unsqueeze(0).cuda(),
            method="transformer_attribution",
            index=None,
            seq_idx=idx,
        ).detach()
        transformer_attribution = transformer_attribution.squeeze()

        # Mask the position's own attribution so it does not dominate the scale.
        transformer_attribution[idx] = 0

        # Second-smallest value: with the own entry zeroed, this is presumably
        # the smallest attribution among the *other* positions.
        values, _ = torch.topk(transformer_attribution, k=2, largest=False)
        second_smallest_value = values[1].item()

        # Normalize so that the second-smallest value becomes 0 and the max 1.
        value_range = transformer_attribution.max() - second_smallest_value
        if value_range.item() == 0:
            # All remaining attributions are equal: dividing would yield
            # NaN/inf, so report "no relative relevance" instead.
            transformer_attribution = torch.zeros_like(transformer_attribution)
        else:
            transformer_attribution = (transformer_attribution - second_smallest_value) / value_range

        transformer_attribution = transformer_attribution.cpu().numpy()
        # The position always attends fully to itself in the final map.
        transformer_attribution[idx] = 1
        attributions.append(transformer_attribution)

    return attributions

In [None]:
def sliding_windows_relprop(patient_vector, num_seq, batch_size, length, model, device, labels=None):
    """
    Run the model and LRP attribution over every sliding window of one patient.

    [INPUT]
    patient_vector : extracted feature vectors per image and concatenated from one patient
    num_seq : sliding window's kernel size
    batch_size : batch size to be processed
    length : length of the total epoch of one patient
    model : trained model; its output is indexed as output[sequence position][batch item]
    device : torch device the window batches and aggregation buffers are placed on
    labels : per-epoch ground-truth labels of this patient (optional). When omitted,
             the notebook-level `lbl` list is used, as the original code did.

    [OUTPUT]
    final_pred = for each epoch, the predicted label sequence of every window covering it
    final_attributions = Patient's attributions for each epoch.
    final_agg = final aggregated softmax value - one value per epoch
    final_lbl = for each epoch, the ground-truth label slice of every window covering it
    """
    if labels is None:
        # Legacy behaviour: fall back to the notebook-level label list.
        labels = lbl

    final_attributions = [[] for _ in range(length)]
    final_pred = [[] for _ in range(length)]
    # One accumulator per epoch; 5 matches the number of sleep-stage classes
    # used by the original zero vector.
    final_agg = [torch.zeros(5, dtype=torch.float32, device=device) for _ in range(length)]
    final_lbl = [[] for _ in range(length)]

    def _process_batch(first_window, windows):
        """Predict and attribute a batch of consecutive windows starting at `first_window`."""
        batch_seq = torch.stack(windows).to(device)
        output = model(batch_seq)

        for k in range(len(windows)):  # batch
            window_start = first_window + k
            attributions = generate_attributions(batch_seq[k], num_seq)
            window_pred = []
            for j in range(num_seq):  # sequence length
                # Detach so the accumulated results do not keep the autograd
                # graph of every forward pass alive.
                soft_val = torch.softmax(output[j][k].detach(), dim=0)
                final_agg[window_start + j].add_(soft_val)
                final_attributions[window_start + j].append(attributions[j])
                _, predict = soft_val.max(dim=0)
                window_pred.append(predict.item())
            # Every epoch covered by this window receives the window's full
            # prediction sequence.
            for p in range(window_start, window_start + num_seq):
                final_pred[p].append(window_pred)

    n_windows = length - num_seq + 1
    pending = []

    for i in tqdm(range(n_windows)):
        pending.append(patient_vector[i:i + num_seq])
        current_lbl = list(labels[i:i + num_seq])

        # iterate for labels
        for j in range(i, i + num_seq):
            final_lbl[j].append(current_lbl)

        if len(pending) == batch_size:
            # Window i is the last of this batch, so the batch starts at
            # i - (batch_size - 1).
            _process_batch(i - (batch_size - 1), pending)
            pending = []

    # for remainder
    if pending:
        _process_batch(n_windows - len(pending), pending)

    return final_pred, final_attributions, final_agg, final_lbl

In [None]:
# Load the feature vectors of the first patient in `patient_dict` and run the
# sliding-window relevance propagation on them.
for patient, length in patient_dict.items():
    hospital = '-'.join(patient.split('-')[:-1])
    patient_folder = os.path.join(img_dir, hospital, "vectors_v2", patient) # change vectors_v2 here
    # `os.path.exists` must be called: the bare function object is always
    # truthy, so the original check could never report a missing folder.
    if not os.path.exists(patient_folder):
        print(patient_folder)
    temp = []
    # NOTE: `lbl` must stay a notebook-level name; sliding_windows_relprop reads it.
    lbl = []
    patient_epoch = df_file_list[df_file_list[0].str.startswith(patient)]
    sorted_epoch = patient_epoch.sort_values(by=0)
    for index, row in sorted_epoch.iterrows():
        # Vector files are named "<file name without its 4-char extension>_<label>.npy".
        file_name = row[0][:-4] + "_" + str(row[1]) + ".npy"
        vector_np = torch.from_numpy(np.load(os.path.join(patient_folder, file_name)).astype(np.float32))
        temp.append(vector_np)
        lbl.append(row[1])
    patient_vector = torch.stack(temp)

    final_pred, final_attributions, final_agg, final_lbl = sliding_windows_relprop(patient_vector=patient_vector,
                                                                        num_seq=num_seq,
                                                                        batch_size = batch_size,
                                                                        length = length, model = model,
                                                                        device=device)
    # do for just one patient
    break

## Change the middle labeling

In [None]:
from matplotlib.patheffects import withStroke
import matplotlib.patches as patches

In [None]:
def visualization_for_v2(final_pred, final_attributions, final_agg, final_lbl):
    """Plot one attribution heat map per epoch.

    Each row of a heat map is one sliding window that covers the epoch; the
    columns are consecutive epochs starting at the first such window, so row
    ``i`` occupies columns ``i .. i + num_seq - 1``. The column of the epoch
    itself is zeroed, outlined, and its predictions are written in black.
    The window length is taken from the notebook-level ``num_seq``.

    Parameters
    ----------
    final_pred, final_attributions, final_agg, final_lbl
        The four per-epoch lists returned by ``sliding_windows_relprop``.

    Returns
    -------
    None
        Figures are shown with ``plt.show()`` and then closed.
    """
    # get aggregated predictions for current epoch
    labels_map = {0:'W', 1:'N1', 2:'N2', 3:'N3', 4:'R'}
    n_cols = num_seq * 2 - 1

    for epoch in range(len(final_pred)):
        window_preds = final_pred[epoch]         # predictions of each covering window
        attributions = final_attributions[epoch]  # attribution vector of each covering window
        agg = final_agg[epoch]                    # aggregated prediction for this epoch
        lbl = final_lbl[epoch]                    # label slice of each covering window

        # Epochs that no window covers (recording shorter than num_seq) have
        # nothing to draw.
        if not window_preds:
            continue

        # Column of the target epoch: windows start at max(0, epoch - num_seq + 1),
        # so the epoch sits at offset min(epoch, num_seq - 1) from the first one.
        target_col = min(epoch, num_seq - 1)

        rolled_attributions = [[0 for _ in range(n_cols)] for _ in range(len(window_preds))]
        for i in range(len(attributions)):
            for j in range(i, i + num_seq):
                rolled_attributions[i][j] = attributions[i][j - i]

        rolled_predictions = [[None for _ in range(n_cols)] for _ in range(len(attributions))]
        for i in range(len(window_preds)):
            # make the target sequence as 0
            rolled_attributions[i][target_col] = 0
            for j in range(i, i + num_seq):
                rolled_predictions[i][j] = labels_map[window_preds[i][j - i]]

        # Labels of all epochs spanned by the covering windows: consecutive
        # windows are shifted by one epoch, so take one label from each window
        # before the last and the whole last window. (The previous `[:-1]` was
        # only right when exactly num_seq windows cover the epoch.)
        new_lbl = lbl[0][:len(lbl) - 1] + lbl[-1]

        # Create a figure and axis object
        fig, ax = plt.subplots(figsize=(10, 8))

        # Plot the rolled_attributions with square cells
        im = ax.imshow(rolled_attributions, cmap='hot_r', aspect='equal')

        # Add colorbar
        fig.colorbar(im, ax=ax, shrink=0.5)

        # Calculate predictions and set title
        prob = torch.softmax(agg, dim=0)
        _, agg_class = prob.max(dim=0)
        pred_label = labels_map[agg_class.item()]
        ax.set_title(f'Epoch {epoch} | Prediction: {pred_label}', fontsize = 10)

        # Set labels
        ax.set_ylabel('Sliding Windows')
        ax.set_xlabel('GT Labels')

        # Set ticks
        ax.set_xticks(np.arange(len(new_lbl)))
        ax.set_yticks(np.arange(num_seq))
        ax.set_xticklabels([labels_map[l] for l in new_lbl])

        # Put predictions labels on the cell (black on the zeroed target column)
        for i, row_predictions in enumerate(rolled_predictions):
            for j, txt in enumerate(row_predictions):
                if txt is None:
                    continue  # cell outside this window: nothing to write
                color = "black" if j == target_col else "white"
                ax.text(j, i, txt, color=color, ha="center", va="center")

        # Add borders around each cell of the target column
        for i in range(len(rolled_attributions)):
            rect = patches.Rectangle((target_col - 0.5, i - 0.5), 1, 1, linewidth=1, edgecolor='black', facecolor='none')
            ax.add_patch(rect)

        # Show the plot, then release it so long recordings do not pile up figures
        plt.show()
        plt.close(fig)

In [None]:
# Plot the per-epoch attribution heat maps for the patient processed above.
visualization_for_v2(final_pred, final_attributions, final_agg, final_lbl)

## Repeat for a second, randomly selected patient

In [None]:
# Display the patient id -> file-count mapping built earlier.
patient_dict

In [None]:
def generate_inter_epoch_visualization(patient, length):
    """Load one patient's feature vectors and run sliding-window relevance propagation.

    Parameters
    ----------
    patient : str
        Patient id; everything before its last '-' is used as the hospital folder.
    length : int
        Number of epochs of this patient (as stored in ``patient_dict``).

    Returns
    -------
    tuple
        ``(final_pred, final_attributions, final_agg, final_lbl)`` exactly as
        returned by ``sliding_windows_relprop``.

    Notes
    -----
    ``sliding_windows_relprop`` reads the ground-truth labels from the
    notebook-level ``lbl``. This function therefore rebinds that global to the
    labels of ``patient``; previously its labels lived in a local variable and
    the labels of whichever patient was loaded last at notebook level were
    used instead.
    """
    global lbl

    hospital = '-'.join(patient.split('-')[:-1])
    patient_folder = os.path.join(img_dir, hospital, "vectors_v2", patient) # change vectors_v2 here
    # `os.path.exists` must be called; the bare function object is always truthy.
    if not os.path.exists(patient_folder):
        print(patient_folder)
    temp = []
    lbl = []
    patient_epoch = df_file_list[df_file_list[0].str.startswith(patient)]
    sorted_epoch = patient_epoch.sort_values(by=0)
    for index, row in sorted_epoch.iterrows():
        # NOTE(review): the notebook-level loader names vector files
        # "<name without extension>_<label>.npy" in this same folder. Prefer
        # that name and fall back to the raw entry used here originally —
        # confirm which naming the vectors_v2 folders actually use.
        vector_path = os.path.join(patient_folder, row[0][:-4] + "_" + str(row[1]) + ".npy")
        if not os.path.exists(vector_path):
            vector_path = os.path.join(patient_folder, row[0])
        vector_np = torch.from_numpy(np.load(vector_path).astype(np.float32))
        temp.append(vector_np)
        lbl.append(row[1])
    patient_vector = torch.stack(temp)

    final_pred, final_attributions, final_agg, final_lbl = sliding_windows_relprop(patient_vector=patient_vector,
                                                                        num_seq=num_seq,
                                                                        batch_size = batch_size,
                                                                        length = length, model = model,
                                                                        device=device)

    return final_pred, final_attributions, final_agg, final_lbl

In [None]:
# Fixed seed so the same patient is picked on every run.
random.seed(31)
# Pick one (patient id, epoch count) pair at random.
patient, length = random.choice(list(patient_dict.items()))
print(patient, length)

In [None]:
# Run prediction + attribution for the randomly selected patient.
random_pred, random_attributions, random_agg, random_lbl = generate_inter_epoch_visualization(patient, length)

In [None]:
# Plot the per-epoch attribution heat maps for the randomly selected patient.
visualization_for_v2(random_pred, random_attributions, random_agg, random_lbl)