In [None]:
import argparse
import datetime
import json

import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from transformers import (
    ViTImageProcessor, ViTForImageClassification,
    AutoImageProcessor, EfficientNetForImageClassification,
    ResNetForImageClassification
)

import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset,DistributedSamplerWrapper,TransformWrapper
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.losses import FocalLoss, compute_alpha_from_labels
from util.evaluation import InsertionMetric, DeletionMetric
from baselines.Attention import Attention_Map
from baselines.GradCAM import GradCAM
from baselines.RISE import RISE, RISEBatch
from huggingface_hub import hf_hub_download, login
from engine_finetune import evaluate_half3D, train_one_epoch, evaluate

import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image

from huggingface_hub import hf_hub_download
from zennit.image import imgify
from zennit.composites import LayerMapComposite
import zennit.rules as z_rules
from lxt.efficient import monkey_patch, monkey_patch_zennit

# Print numpy arrays in full (no "..." summarisation). Beware: large arrays
# will flood the cell output.
np.set_printoptions(threshold=np.inf)

# Seed every RNG the notebook may touch so results are reproducible.
import random

SEED = 1
random.seed(SEED)
np.random.seed(SEED)
# Trailing ';' suppresses the returned Generator's repr in the notebook output.
torch.manual_seed(SEED);

  warn("'zennit' library is not available. Please install it to use for vision transformers.")


<torch._C.Generator at 0x14aae40c5bb0>

In [None]:
# Build the RETFound ViT, initialise it from the public pre-trained weights on
# the Hugging Face Hub, then overwrite it with the locally fine-tuned checkpoint.
model_name = 'RETFound_mae'  # key into models_vit; replaces the undefined `args.model`
processor = None
input_size = 224
nb_classes = 2
patch_size = 16
finetune = "RETFound_mae_natureOCT"  # Hub repo suffix and file stem of the pre-trained weights
resume = "output_dir/Cataract_all_split-IRB2024_v4-all-RETFound_mae_natureOCT-OCT-bs16ep50lr5e-4optadamw-roc_auceval--/checkpoint-best.pth"

model = models.__dict__[model_name](
    img_size=input_size,
    num_classes=nb_classes,
    drop_path_rate=0.2,
    global_pool=True,
)

# --- Pre-trained weights -----------------------------------------------------
print(f"Downloading pre-trained weights from: {finetune}")
checkpoint_path = hf_hub_download(
    repo_id=f'YukunZhou/{finetune}',
    filename=f'{finetune}.pth',
)
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if these
# checkpoints contain non-tensor objects the load may need weights_only=False
# (trusted files only) — confirm against the installed torch version.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % finetune)
# RETFound_mae checkpoints are read from 'model'; any other model name reads 'teacher'.
checkpoint_model = checkpoint['teacher'] if model_name != 'RETFound_mae' else checkpoint['model']

# Rename checkpoint keys to the models_vit naming; renames are applied in order.
key_renames = (
    ("backbone.", ""),
    ("mlp.w12.", "mlp.fc1."),
    ("mlp.w3.", "mlp.fc2."),
)
for old, new in key_renames:
    checkpoint_model = {k.replace(old, new): v for k, v in checkpoint_model.items()}

# Drop the classification head when its shape differs from this model's
# (e.g. a different nb_classes); it is re-initialised below.
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model; strict=False tolerates the removed head keys
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
trunc_normal_(model.head.weight, std=2e-5)

# --- Fine-tuned weights ------------------------------------------------------
checkpoint = torch.load(resume, map_location='cpu')
# Use checkpoint['model'] when present, otherwise treat the file as a bare state_dict.
checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint
msg = model.load_state_dict(checkpoint_model, strict=False)
print("Resume checkpoint %s" % resume)
# strict=False silently ignores missing/unexpected keys, so surface them.
print(msg)

# The model is used for inference from here on: eval() disables dropout and,
# assuming models_vit uses timm's DropPath, the stochastic depth enabled by
# drop_path_rate=0.2 (otherwise outputs would be random per forward pass).
model.eval()

In [None]:
# Extract one feature vector per image together with the image names.
# NOTE(review): `get_feature`, `data_path`, `chkpt_dir`, `device` and `arch`
# are not defined in any visible cell, so this fails on a fresh kernel —
# TODO define or import them in the config cell above.
name_list, feature = get_feature(
    data_path,
    chkpt_dir,
    device,
    arch=arch,
)

In [23]:
# Save the extracted features to Feature.csv: a "name" column followed by one
# "feature_<i>" column per feature dimension, one row per image.
df_feature = pd.DataFrame(feature)
df_imgname = pd.DataFrame(name_list)
# pd.concat(axis=1) aligns on the index, so a length mismatch would silently
# produce NaN-padded rows instead of an error; fail loudly instead.
assert len(df_imgname) == len(df_feature), (
    f"{len(df_imgname)} image names but {len(df_feature)} feature rows"
)
df_visualization = pd.concat([df_imgname, df_feature], axis=1)

# Derive the number of feature columns from the data instead of hard-coding
# 1024, so any embedding width works.
column_name_list = [f"feature_{i}" for i in range(df_feature.shape[1])]
df_visualization.columns = ["name"] + column_name_list
df_visualization.to_csv("Feature.csv", index=False)

# LXT sample: LRP heatmaps for a torchvision ViT-B/16 (ImageNet weights)

In [None]:
import torch
import itertools
from PIL import Image
from torchvision.models import vision_transformer

from zennit.image import imgify
from zennit.composites import LayerMapComposite
import zennit.rules as z_rules

from lxt.efficient import monkey_patch, monkey_patch_zennit

# Modify torchvision's Vision Transformer module so that its backward pass
# computes Layer-wise Relevance Propagation (LRP) relevance instead of plain
# gradients. For ViTs we use the LRP Gamma rule, which is implemented in the
# 'zennit' library; zennit is monkey patched as well so its rules are
# compatible with LXT.
# NOTE(review): these are global, process-wide patches; presumably they must
# run before the ViT is constructed — confirm against the LXT docs.
monkey_patch(vision_transformer, verbose=True)
monkey_patch_zennit(verbose=True)


def get_vit_imagenet(device="cuda"):
    """
    Build torchvision's ViT-B/16 with IMAGENET1K_V1 weights, frozen and in eval mode.

    Parameters:
    device (str): Device to place the model on, e.g. 'cuda' or 'cpu'.

    Returns:
    tuple: (model, weights) - the ViT and the torchvision weights entry it was
    built from (which also provides the matching preprocessing transforms).
    """
    vit_weights = vision_transformer.ViT_B_16_Weights.IMAGENET1K_V1
    vit = vision_transformer.vit_b_16(weights=vit_weights)
    vit.eval().to(device)

    # Only input gradients are needed for the explanations, so freezing every
    # parameter avoids allocating parameter gradients and saves memory.
    vit.requires_grad_(False)

    return vit, vit_weights

# Use the GPU when one is available, otherwise fall back to CPU instead of
# assuming CUDA exists.
lxt_device = "cuda" if torch.cuda.is_available() else "cpu"

# Gamma values to sweep for Conv2d and Linear layers. Gamma is a hyperparameter
# in LRP that controls how much positive vs. negative contributions are
# considered in the explanation.
CONV_GAMMAS = [0.1, 0.25, 100]
LIN_GAMMAS = [0, 0.01, 0.05, 0.1, 1]

# Load the pre-trained ViT under its own name so the RETFound `model` built
# earlier in the notebook is not overwritten.
vit_model, weights = get_vit_imagenet(device=lxt_device)

# Load and preprocess the input image with the weights' own transforms.
image = Image.open('docs/source/_static/cat_dog.jpg').convert('RGB')
input_tensor = weights.transforms()(image).unsqueeze(0).to(lxt_device)

# Class labels do not change between iterations, so look them up once.
labels = weights.meta["categories"]

# One normalised heatmap per (conv_gamma, lin_gamma) pair, in product order.
heatmaps = []

for conv_gamma, lin_gamma in itertools.product(CONV_GAMMAS, LIN_GAMMAS):
    input_tensor.grad = None  # Reset gradients from the previous pair
    print("Gamma Conv2d:", conv_gamma, "Gamma Linear:", lin_gamma)

    # LayerMapComposite maps each layer type to a zennit LRP rule.
    zennit_comp = LayerMapComposite([
        (torch.nn.Conv2d, z_rules.Gamma(conv_gamma)),
        (torch.nn.Linear, z_rules.Gamma(lin_gamma)),
    ])
    zennit_comp.register(vit_model)
    try:
        # Forward pass with gradient tracking enabled on the input.
        y = vit_model(input_tensor.requires_grad_())

        # Report the top-5 predictions and their labels.
        top5_classes = torch.topk(y, 5, dim=1).indices.squeeze(0).tolist()
        for rank, class_idx in enumerate(top5_classes, start=1):
            print(f'Top {rank} predicted class: {class_idx}, label: {labels[class_idx]}')

        # Backward pass from the highest-scoring class; this initiates the LRP
        # computation through the network.
        y[0, top5_classes[0]].backward()
    finally:
        # Always unregister the rules, even if the forward/backward pass fails,
        # so they cannot interfere with later iterations or cells.
        zennit_comp.remove()

    # Relevance = Gradient * Input, summed over dim 1 (the channel dim of the
    # batched image tensor).
    heatmap = (input_tensor * input_tensor.grad).sum(1)

    # Normalise relevance to [-1, 1] for plotting; skip when the map is all
    # zeros to avoid 0/0 producing NaNs.
    max_abs = heatmap.abs().max()
    if max_abs > 0:
        heatmap = heatmap / max_abs

    heatmaps.append(heatmap[0].detach().cpu().numpy())

# Save all heatmaps as one grid whose shape (len(CONV_GAMMAS) x len(LIN_GAMMAS))
# follows the sweep; vmin/vmax fix the colour mapping range.
imgify(heatmaps, vmin=-1, vmax=1, grid=(len(CONV_GAMMAS), len(LIN_GAMMAS))).save('vit_heatmap.png')