# LRP Medit (Resnet) Notebook
In this notebook we apply our LRPEngine to the **ResNet-based wound classification model**. As described in detail in the provided paper (see "literature/lrp_resnet"), implementing LRP on models with residual connections (such as ResNet) is non-trivial.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torchvision.transforms import v2
import cv2
import json
import os
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from src.xai.lrp import LRPEngine
import medit_classifier.medit_resnet_model as medit_classifier

In [None]:
# Notebook-wide constants (upper-case by convention).
# Paths are joined by plain string concatenation, so every directory constant keeps its trailing '/'.
MEDIT_PATH = "../medit_classifier/"
OUTPUT_PATH = "../output/241128 lrp_heatmaps/"
# One heatmap output folder per wound stage: .../stage1/ ... .../stage4/
STAGE1_OUTPUT, STAGE2_OUTPUT, STAGE3_OUTPUT, STAGE4_OUTPUT = (
    f"{OUTPUT_PATH}stage{stage}/" for stage in range(1, 5)
)

# Checkpoint (.pth) and its config (.json) share the same training-run name.
MODEL_PATH, CONFIG_PATH = (
    MEDIT_PATH
    + 'Resnet50_v2_learning rate_0.001_loss_function_CrossEntropyLoss_batch_size_16_dropout_False_20241027_023602'
    + suffix
    for suffix in ('_savedepoch_2.pth', '.json')
)
PRED_PATH = f"{MEDIT_PATH}saved/"

DATASET = f"{MEDIT_PATH}Dataset_reduced/"
# One input image folder per wound stage: .../1/ ... .../4/
STAGE1_DATA, STAGE2_DATA, STAGE3_DATA, STAGE4_DATA = (
    f"{DATASET}{stage}/" for stage in range(1, 5)
)

## 1. Model Initialisation

In [None]:
# Build the classifier via the project helper from the run config (.json) and checkpoint (.pth).
model = medit_classifier.initialise_model(
    CONFIG_PATH,
    MODEL_PATH,
)

## 2. LRP Workflow

In [None]:
def _collect_calculated_titles(output_directory):
    """Return the input references for which plot files already exist in *output_directory*.

    Parsing assumes plot file names contain ``'Calc <reference> class'`` (the naming
    is produced by ``LRPEngine.plot_relevance_scores`` -- TODO confirm there).
    Files without ``'Calc '`` in their name (e.g. ``.DS_Store``) are ignored
    instead of raising ``IndexError``.
    """
    titles = set()
    for plot_name in os.listdir(output_directory):
        parts = plot_name.split('Calc ')
        if len(parts) > 1:
            titles.add(parts[1].split(' class')[0])
    return titles


def lrp_workflow(classifier_model, image_directory, output_directory=None, rel_filter_ratio=0.75,
                 layers_to_inspect=(0, 10, 20), check_already_calculated=True):
    """Run LRP on every entry of *image_directory* and plot the relevance scores.

    Args:
        classifier_model: model handed to ``LRPEngine``.
        image_directory: directory whose entries are loaded with
            ``medit_classifier.load_img_and_tensor``; entries that fail to load
            are reported and skipped.
        output_directory: forwarded to ``LRPEngine`` as ``plot_output_dir`` and
            created if it does not exist. With ``None`` the already-calculated
            check is disabled.
        rel_filter_ratio: forwarded to ``LRPEngine.calculate_relevance_scores``.
        layers_to_inspect: layer indices for which ``plot_relevance_scores`` is
            called for each image.
        check_already_calculated: if True (and *output_directory* is given),
            skip images whose file stem already appears in an existing plot name.
    """
    if output_directory is not None:
        # A fresh output folder would otherwise make os.listdir below raise FileNotFoundError.
        os.makedirs(output_directory, exist_ok=True)

    lrp_engine = LRPEngine(classifier_model, plot_output_dir=output_directory)

    # Set of input references that already have plots; empty when the check is disabled.
    files_calculated = set()
    if check_already_calculated and output_directory is not None:
        files_calculated = _collect_calculated_titles(output_directory)

    # List once and sort so the processing order is deterministic across runs/platforms.
    image_files = sorted(os.listdir(image_directory))
    for file in tqdm(image_files, ncols=80, position=0, leave=True):
        # splitext strips only the final extension, so 'a.1.jpg' and 'a.2.jpg' stay distinct.
        file_title = os.path.splitext(file)[0]
        if file_title in files_calculated:
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"{file_title} already calculated. Skipping!")
            continue

        # load image:
        image_path = os.path.join(image_directory, file)
        try:
            _, img_tensor = medit_classifier.load_img_and_tensor(image_path)
        except Exception as err:  # not BaseException: Ctrl+C / SystemExit must still stop the run
            tqdm.write(str(err))
            continue

        # calculate LRP on a batch of one image (leading dimension added by unsqueeze(0)):
        lrp_engine.input_batch = img_tensor.unsqueeze(0)
        lrp_engine.calculate_relevance_scores(rel_filter_ratio=rel_filter_ratio)

        # plot and save output results:
        for layer_ind in layers_to_inspect:
            lrp_engine.plot_relevance_scores(layer_ind, input_reference=file_title, hidden=True, plt_cmap='bwr')

### 2.1. Stage 1 Picture Analysis

In [None]:
# Run LRP over all stage-1 images; plots are written to STAGE1_OUTPUT.
lrp_workflow(model, image_directory=STAGE1_DATA, output_directory=STAGE1_OUTPUT)

### 2.2. Stage 2 Picture Analysis

In [None]:
# Run LRP over all stage-2 images; plots are written to STAGE2_OUTPUT.
lrp_workflow(model, image_directory=STAGE2_DATA, output_directory=STAGE2_OUTPUT)

### 2.3. Stage 3 Picture Analysis

In [None]:
# Run LRP over all stage-3 images; plots are written to STAGE3_OUTPUT.
lrp_workflow(model, image_directory=STAGE3_DATA, output_directory=STAGE3_OUTPUT)

### 2.4. Stage 4 Picture Analysis

In [None]:
# Run LRP over all stage-4 images; plots are written to STAGE4_OUTPUT.
lrp_workflow(model, image_directory=STAGE4_DATA, output_directory=STAGE4_OUTPUT)