In [None]:
# !git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
# Enter the cloned repository (presumably so the `baselines` package imported below
# resolves from the working directory -- confirm for your environment).
# Guarded so that re-running this cell does not try to descend into a nested,
# non-existent `Transformer-Explainability/Transformer-Explainability` directory.
REPO_DIR = 'Transformer-Explainability'
if os.path.basename(os.getcwd()) != REPO_DIR:
    os.chdir(REPO_DIR)

# !pip install einops

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
from tqdm import tqdm

In [None]:
import pandas as pd
from torch.utils.data import Dataset
import torch
import PIL
import numpy as np
from torchvision import datasets, transforms
import os

In [None]:
from baselines.ViT.ViT_LRP import sleep_base_patch16_224 as vit_SLEEP
from baselines.ViT.ViT_explanation_generator import LRP

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Build the ViT from the stored checkpoint, move it to the GPU and put it in
# evaluation mode, then wrap it in the LRP explainer used by the cells below.
model_path = "/tf/data_AIoT1/ViT_models/ViT-full-2023-12-09.pth"
model = vit_SLEEP(checkpoint_path=model_path, pretrained=True)
model = model.cuda()
model.eval()
attribution_generator = LRP(model)

In [None]:
# Annotation file; it is parsed below as tab-separated text without a header row.
test_labels = "/tf/data_AIoT1/psg_image/labels/test_1209.txt"
# Cell output: True when the annotation file is reachable from this machine.
os.path.exists(test_labels)

In [None]:
# Directory holding the image arrays; the dataset class joins file names onto it
# and reads each sample with np.load.
img_path = "/tf/data_AIoT1/psg_image/full_test_1116/"

In [None]:
# Parse the annotation table (tab-separated, no header) and keep its first
# column -- the file names -- as a plain list.
df_file_list = pd.read_csv(test_labels, header=None, sep="\t")
files = df_file_list.loc[:, 0].tolist()

In [None]:
# Count how many files belong to each patient.
# The first 16 characters of a file name are used as the patient identifier
# (e.g. 'A2020-NX-01-0179' in 'A2020-NX-01-0179_0279.npy').
# The tally is keyed on the identifier itself, so a patient whose files are not
# listed contiguously keeps one running total instead of being reset to 1.
patient_dict = {}  # patient id -> number of files from that patient

for file in files:
    patient = file[0:16]
    patient_dict[patient] = patient_dict.get(patient, 0) + 1

In [None]:
# Cell output: the parsed annotation table (column 0 = file name, column 1 = label).
df_file_list

In [None]:
import pandas as pd
from torch.utils.data import Dataset
import torch
import PIL
import numpy as np
from torchvision import datasets, transforms
import os

class IntraEpochDataset(Dataset):
    """Dataset of ``.npy`` images indexed by a tab-separated annotation file.

    The annotation file has no header; column 0 is the file name (relative to
    ``img_dir``) and column 1 is the class label. Samples are enumerated in the
    order the file names appear in the annotation file.
    """

    def __init__(self, annotations_file, img_dir):
        table = pd.read_csv(annotations_file, sep="\t", header=None)
        names, targets = table[0], table[1]
        self.img_dir = img_dir
        # file name -> label
        self.labels = dict(zip(names, targets))
        self.image_filenames = list(self.labels)

    def __len__(self):
        # One sample per distinct file name.
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float32 tensor and a Python ``int``."""
        name = self.image_filenames[idx]
        array = np.load(os.path.join(self.img_dir, name))
        image = torch.from_numpy(array.astype(np.float32))
        return image, int(self.labels[name])

In [None]:
# Build the dataset from the annotation file and image directory defined above.
test_dataset = IntraEpochDataset(test_labels, img_path)
# Cell output: number of samples in the dataset.
len(test_dataset)

In [None]:
# Batches of 10 samples, reshuffled on every pass over the dataset.
test_dataloader = DataLoader(test_dataset, batch_size=10, shuffle=True)

In [None]:
# Pull a single batch to inspect.
images, labels = next(iter(test_dataloader))

In [None]:
# Cell output: shape of the image batch.
images.shape

In [None]:
# First image of the batch, used as the running example in the cells below.
ex_img = images[0]

In [None]:
# Label of the example image, unwrapped from the tensor with .item().
ex_lbl = labels[0].item()

In [None]:
# Class index -> stage name, used for plot titles.
labels_map = {0:'Wake', 1:'N1', 2:'N2', 3:'N3', 4:'REM'}

In [None]:
# Display the example image; permute moves axis 0 to the end, as imshow expects
# the channel axis last.
plt.imshow(ex_img.permute(1, 2, 0))
# Title the figure with the stage name of the example's label.
plt.title(labels_map[ex_lbl])
plt.show()

In [None]:
def generate_visualization(original_image, class_index=None):
    """Overlay the LRP relevance map of ``original_image`` on the image itself.

    Args:
        original_image: image tensor with the channel axis first (it is permuted
            to channel-last for the overlay). A batch axis is added and the
            tensor is moved to the GPU before it is passed to the explainer.
        class_index: forwarded as ``index`` to ``generate_LRP``; ``None`` leaves
            the choice of class to the attribution generator.

    Returns:
        ``uint8`` array: the heat-map blended over the image, after the
        ``cv2.COLOR_RGB2BGR`` channel swap.
    """
    def _rescale(values):
        # Min-max scale to [0, 1]. A constant input has zero range; return zeros
        # for it instead of the NaNs that 0 / 0 would produce.
        lo, hi = values.min(), values.max()
        if hi == lo:
            return np.zeros_like(values, dtype=np.float32)
        return (values - lo) / (hi - lo)

    relevance = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).cuda(),
        method="transformer_attribution",
        index=class_index,
    ).detach()
    # The relevance vector is laid out as a 14 x 14 grid and upsampled by a
    # factor of 16 to a 224 x 224 map.
    relevance = relevance.reshape(1, 1, 14, 14)
    relevance = torch.nn.functional.interpolate(relevance, scale_factor=16, mode='bilinear')
    relevance = _rescale(relevance.reshape(224, 224).cpu().numpy())

    image = _rescale(original_image.permute(1, 2, 0).detach().cpu().numpy())

    vis = show_cam_on_image(image, relevance)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

def print_top_classes(predictions, **kwargs):
    """Print the five highest-scoring classes of the first sample.

    Args:
        predictions: score tensor indexed as ``[sample, class]``; only row 0 is
            reported. Needs at least five class columns because ``topk(5)`` is used.
        **kwargs: accepted and ignored.

    Each printed line shows the class index and name, the raw score and the
    softmax probability in percent.
    """
    class_names = ['Wake', 'N1', 'N2', 'N3', 'REM']
    max_str_len = 4  # width used to align the class-name column
    prob = torch.softmax(predictions, dim=1)
    _, top_idx = predictions.data.topk(5, dim=1)
    ranked = top_idx[0].tolist()

    print('Top 5 classes:')
    for cls_idx in ranked:
        name = class_names[cls_idx]
        padding = ' ' * (max_str_len - len(name))
        score = predictions[0, cls_idx]
        pct = 100 * prob[0, cls_idx]
        print(f'\t{cls_idx} : {name}{padding}\t\tvalue = {score:.3f}\t prob = {pct:.1f}%')

In [None]:
def get_predictions(output):
    """Return the stage name with the highest softmax probability.

    ``output`` is indexed as ``[sample, class]`` and must hold exactly one
    sample, because the winning index is read with ``.item()``.
    """
    stage_names = {0:'Wake', 1:'N1', 2:'N2', 3:'N3', 4:'REM'}
    probabilities = torch.softmax(output, dim=1)
    best = probabilities.max(dim=1)[1]
    return stage_names[best.item()]

In [None]:
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` with ``img``.

    ``mask`` is scaled by 255 and cast to ``uint8`` before colour-mapping, so it
    is expected to lie in [0, 1]. The colour map (rescaled to [0, 1]) is added
    to ``img`` and the sum is divided by its maximum, so the result peaks at 1.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    blended = colored + np.float32(img)
    return blended / np.max(blended)

## Bigger Image

In [None]:
# Use the example picked from the first batch. `image` was never bound at
# notebook scope, so this cell raised NameError on a fresh kernel.
image = ex_img

# show original image and ground truth label
plt.imshow(image.permute(1,2,0))
plt.title(f"Ground Truth: {labels_map[ex_lbl]}")
plt.axis("off")

plt.show()

# create output and predictions
output = model(image.unsqueeze(0).cuda())
print_top_classes(output)
pred = get_predictions(output)
im = generate_visualization(image)
plt.imshow(im)

# show prediction and explanations
plt.title(f"Prediction: {pred}")
plt.axis("off")

plt.show()

In [None]:
def show_inidivdual(img_path, labels):
    """Show one stored sample and, below it, its LRP explanation.

    Args:
        img_path: path of the ``.npy`` file to load.
        labels: mapping from file name (the last component of ``img_path``)
            to class index.

    The first figure is titled with the ground-truth stage, the second with the
    model's prediction; the top-5 scores are printed in between.
    """
    image = torch.from_numpy(np.load(img_path).astype(np.float32))
    # Look the label up by the file name of the path that was passed in
    # (the original read an undefined `ex_path` here).
    ex_lbl = int(labels[os.path.basename(img_path)])

    # show original image and ground truth label
    plt.imshow(image.permute(1,2,0))
    plt.title(f"Ground Truth: {labels_map[ex_lbl]}")
    plt.axis("off")

    plt.show()

    # create output and predictions
    output = model(image.unsqueeze(0).cuda())
    print_top_classes(output)
    pred = get_predictions(output)
    im = generate_visualization(image)
    plt.imshow(im)

    # show prediction and explanations
    plt.title(f"Prediction: {pred}")
    plt.axis("off")

    plt.show()

In [None]:
# Pass the dataset's file-name -> label mapping. `labels` at this point is the
# label tensor of the last batch and cannot be indexed by file name.
show_inidivdual('/tf/data_AIoT1/psg_image/full_test_1116/A2020-NX-01-0179_0279.npy', test_dataset.labels)

In [None]:
class IntraEpochDataset(Dataset):
    """Dataset over the labelled ``.npy`` files found in ``img_dir``.

    The annotation file is tab-separated without a header: column 0 is the file
    name, column 1 the class label. Only directory entries that have a label
    are kept, so stray or unlabeled files cannot raise ``KeyError`` when
    indexed. Entries are sorted so that an index always maps to the same file,
    independent of the order ``os.listdir`` happens to return.
    """

    def __init__(self, annotations_file, img_dir):
        self.img_dir = img_dir
        df = pd.read_csv(annotations_file, sep="\t", header=None)
        self.labels = dict(zip(df[0], df[1])) # file_name will be the key labels will be the value

        self.image_paths = sorted(
            name for name in os.listdir(self.img_dir) if name in self.labels
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float32 tensor and a Python ``int``."""
        image_name = self.image_paths[idx]
        image_path = os.path.join(self.img_dir, image_name)

        image = torch.from_numpy(np.load(image_path).astype(np.float32))
        label = int(self.labels[image_name])

        return image, label

In [None]:
# Rebuild the dataset with the class redefined above. `test_label_path` and
# `test_image_path` are not defined anywhere in this notebook; the annotation
# file and image directory are held in `test_labels` and `img_path`.
test_dataset = IntraEpochDataset(test_labels, img_path)

In [None]:
# Report the dataset size.
print('Total test images: ', len(test_dataset))

In [None]:
from torch.utils.data import DataLoader
# Larger shuffled batches (80 samples) loaded by 8 worker processes.
test_dataloader = DataLoader(test_dataset, batch_size=80, shuffle=True, num_workers=8)

In [None]:
# Pull one batch and show the shape of a single image plus the batch labels.
# The shape is printed explicitly: a bare expression in the middle of a cell is
# evaluated and discarded, so it was never displayed.
images, labels = next(iter(test_dataloader))
print(images[0].shape)
print(labels)

In [None]:
# Class index -> stage name (same mapping as defined earlier in the notebook).
labels_map = {0:'Wake', 1:'N1', 2:'N2', 3:'N3', 4:'REM'}

In [None]:
# Pick one random sample from the current batch and display it with its stage name.
sample_idx = torch.randint(len(images), size=(1,)).item()
img = images[sample_idx]
label = labels[sample_idx].item()
plt.imshow(img.permute((1, 2, 0)))
plt.title(labels_map[label])
plt.axis("off")
plt.show()

In [None]:
# Seed PyTorch's global random number generator; this affects only random draws
# made after this cell runs.
torch.manual_seed(42)

In [None]:
def visualize_image_heatmap(images, labels):
    """Plot every image of a batch followed by its LRP explanation.

    Args:
        images: iterable of image tensors with the channel axis first.
        labels: matching iterable of label tensors; each is converted with
            ``.tolist()`` and looked up in ``labels_map``.

    Two figures are shown per sample: the image titled with its ground truth,
    then the heat-map overlay titled with the model's prediction.
    """
    for image, label in zip(images, labels):
        # Input image with its ground-truth stage.
        plt.imshow(image.permute(1, 2, 0))
        truth = labels_map[label.tolist()]
        plt.title(f"Ground Truth: {truth}")
        plt.axis("off")
        plt.show()

        # Model prediction and the relevance overlay for the same image.
        logits = model(image.unsqueeze(0).cuda())
        pred = get_predictions(logits)
        overlay = generate_visualization(image)
        plt.imshow(overlay)
        plt.title(f"Prediction: {pred}")
        plt.axis("off")
        plt.show()

In [None]:
# Visualise every sample of the batch drawn above.
visualize_image_heatmap(images, labels)

In [None]:
# Draw a fresh batch; a new iterator over the shuffled loader is created each time.
images, labels = next(iter(test_dataloader))

In [None]:
# Visualise the newly drawn batch.
visualize_image_heatmap(images, labels)

# Visualize only correct guesses & aggregate
- see the events of that epoch 
- event labels

In [None]:
# Global switch read by transformer_attribution(): when True the relevance map
# is binarised with an Otsu threshold and returned as a 0/1 mask.
use_thresholding = True

In [None]:
# Per-class accumulators for the five classes, all starting empty (None):
# the running heat-map sum and the number of heat-maps added to it.
class_heatmaps = dict.fromkeys(range(5))
class_num = dict.fromkeys(range(5))

In [None]:
def transformer_attribution(original_image, class_index=None):
    """Return the 224 x 224 LRP relevance map of one image.

    Args:
        original_image: image tensor; a batch axis is added and it is moved to
            the GPU before being passed to the explainer.
        class_index: forwarded as ``index`` to ``generate_LRP``; ``None`` leaves
            the choice of class to the attribution generator.

    Returns:
        NumPy array of shape (224, 224). With the global ``use_thresholding``
        set it is a ``uint8`` 0/1 mask from Otsu thresholding, otherwise the
        relevance min-max scaled to [0, 1].
    """
    relevance = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).cuda(),
        method="transformer_attribution",
        index=class_index,
    ).detach()
    # The relevance vector is laid out as a 14 x 14 grid and upsampled by a
    # factor of 16 to a 224 x 224 map.
    relevance = relevance.reshape(1, 1, 14, 14)
    relevance = torch.nn.functional.interpolate(relevance, scale_factor=16, mode='bilinear')
    relevance = relevance.reshape(224, 224).cpu().numpy()

    # Min-max scale to [0, 1]. A constant map has zero range; use zeros for it
    # instead of the NaNs that 0 / 0 would produce.
    lo, hi = relevance.min(), relevance.max()
    if hi == lo:
        relevance = np.zeros_like(relevance)
    else:
        relevance = (relevance - lo) / (hi - lo)

    if use_thresholding:
        relevance = (relevance * 255).astype(np.uint8)
        _, relevance = cv2.threshold(relevance, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # cv2 marks foreground as 255; turn it into a 0/1 mask.
        relevance[relevance == 255] = 1

    return relevance

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

# Per-class cap on accumulated heat-maps. It is also the fixed divisor applied
# to every heat-map, so class_heatmaps[c] is sum / num_samples.
# NOTE(review): a class with fewer than num_samples correct predictions is
# therefore scaled down rather than averaged -- confirm this is intended.
num_samples = 10000

y_pred = []
y_true = []

for inputs, labels in tqdm(test_dataloader):
    inputs = inputs.cuda()
    labels = labels.data.cpu().numpy()

    # Predicted class per sample. exp() is monotonic, so taking the arg-max of
    # the raw model outputs selects the same class as arg-max of exp(output).
    output = model(inputs).argmax(dim=1).cpu().numpy()
    y_pred.extend(output)  # Save Prediction
    y_true.extend(labels)  # Save Truth

    # Accumulate a heat-map for each correctly classified sample in this batch.
    correct_mask = (output == labels)
    correct_indices = correct_mask.nonzero()[0]
    for idx in correct_indices:
        i = int(idx)  # plain Python int for tensor indexing
        lbl = labels[i]

        # class_num starts out as None for every class.
        seen = class_num[lbl] or 0
        if seen >= num_samples:
            # Class is full: skip the expensive LRP pass altogether.
            continue

        heatmap = transformer_attribution(inputs[i]) / num_samples
        if class_heatmaps[lbl] is None:
            class_heatmaps[lbl] = heatmap
        else:
            class_heatmaps[lbl] += heatmap
        class_num[lbl] = seen + 1

## Confusion matrix

In [None]:
# Class names, ordered by class index.
classes = ('Wake', 'N1', 'N2', 'N3', 'REM')

# Build the confusion matrix (rows = true class, columns = predicted class).
# Passing `labels` keeps it len(classes) x len(classes) even when a class never
# occurs in y_true / y_pred.
cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

# Row-normalise so each row shows how a true class is distributed over the
# predictions; rows without any sample stay at 0 instead of dividing by zero.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
cf_normalized = np.divide(
    cf_matrix,
    row_totals,
    out=np.zeros(cf_matrix.shape, dtype=float),
    where=row_totals != 0,
)
df_cm = pd.DataFrame(cf_normalized, index=list(classes), columns=list(classes))
plt.figure(figsize = (5,4))
ax = sn.heatmap(df_cm, annot=True, cmap="Blues")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.show()

## F1 score

In [None]:
from sklearn.metrics import classification_report
# Pin the label set so `target_names` always lines up with the class indices,
# even if a class is missing from y_true / y_pred.
print(classification_report(y_true, y_pred, labels=list(range(len(classes))), target_names=classes))

In [None]:
# Number of heat-maps accumulated per class; None means no correct prediction
# was recorded for that class.
print(class_num)