In [None]:
import os

import torch
import torch.nn.functional as F

import transformer_models
from custom_dataset import METEORDataLayer
from custom_utils import generate_dict, ModelConfig

# Directory of the experiment run to inspect. The configuration cell below
# reads 'log_dist.txt' from this directory to rebuild the model arguments.
# NOTE(review): the trailing '...' looks like a placeholder -- set the real path.
EXPERIMENT_PATH = 'experiments/att_back/...'

In [None]:
# Rebuild the run configuration from the experiment's log file.
# `os` is provided by the notebook's import cell; without that import this
# cell raises NameError on a fresh kernel (Restart & Run All).
config_dict = generate_dict(os.path.join(EXPERIMENT_PATH, 'log_dist.txt'))
args = ModelConfig(**config_dict)

# Instantiate the transformer with the hyper-parameters recovered above.
# NOTE(review): no checkpoint is loaded in this cell, so unless weights are
# restored elsewhere the model keeps whatever initialisation its constructor
# gives it -- confirm before interpreting any attribution maps.
model = transformer_models.VisionTransformer_v3(
    args=args,
    img_dim=args.enc_layers,
    patch_dim=args.patch_dim,
    out_dim=args.numclass,
    embedding_dim=args.embedding_dim,
    num_heads=args.num_heads,
    num_layers=args.num_layers,
    hidden_dim=args.hidden_dim,
    dropout_rate=args.dropout_rate,
    attn_dropout_rate=args.attn_dropout_rate,
    num_channels=args.dim_feature,
    positional_encoding_type=args.positional_encoding_type,
    with_motion=args.use_flow,
)


In [None]:
def _scale_to_unit_max(heatmap):
    """Divide ``heatmap`` by its global maximum; leave it untouched if that maximum is not positive."""
    peak = heatmap.max()
    # An all-zero map (everything clipped by ReLU) would otherwise become NaN.
    return heatmap / peak if peak > 0 else heatmap


def grad_cam_multilabel(model, rgb_extractor, flow_extractor, input_rgb, input_flow, target_classes=4):
    """Compute a combined Grad-CAM heatmap over one or more target classes.

    The model is run once on ``(input_rgb, input_flow)``. For every requested
    class the gradient of that class' logit (summed over the batch) with
    respect to ``model.x_last_attention`` is averaged over the last two
    dimensions to obtain one weight per entry of dimension 1; the weighted sum
    over dimension 1 is passed through ReLU and scaled to a maximum of 1. The
    per-class maps are summed and the sum is scaled to a maximum of 1 again.

    Args:
        model: Module called as ``model(input_rgb, input_flow)``. It must
            return a 2-tuple whose first element is the logits, indexed as
            ``logits[:, class_index]``, and must expose the activation to
            explain as ``model.x_last_attention`` after the forward pass.
            That activation must be 4-D and part of the autograd graph.
            NOTE(review): dimension 1 is assumed to be the channel/head axis
            and dimensions 2-3 the spatial axes -- confirm against the model.
        rgb_extractor: Unused; kept so existing call sites keep working.
        flow_extractor: Unused; kept so existing call sites keep working.
        input_rgb: Floating-point input tensor. ``requires_grad`` is switched
            on in place so a graph is built even if the model is frozen.
        input_flow: Second input tensor, treated like ``input_rgb``.
        target_classes: Iterable of class indices, or a single ``int`` class
            index. NOTE(review): a bare int (including the default ``4``) is
            interpreted as ONE class index, not as a class count -- confirm
            this matches the intended default.

    Returns:
        numpy.ndarray with the shape of ``model.x_last_attention[:, -1, :, :]``.
        Scaling uses the global maximum across the whole batch.

    Raises:
        ValueError: If ``target_classes`` is empty or the activation is not 4-D.
        RuntimeError: If the activation is detached from the autograd graph or
            a requested logit does not depend on it.
    """
    # Accept a single class index as well as any iterable of indices.
    if isinstance(target_classes, int):
        class_indices = [target_classes]
    else:
        class_indices = list(target_classes)
    if not class_indices:
        raise ValueError("target_classes must contain at least one class index")

    # Evaluation mode disables dropout etc.; gradients are still required.
    model.eval()
    model.zero_grad()
    input_rgb.requires_grad_()
    input_flow.requires_grad_()

    # Guarantee graph construction even when called inside torch.no_grad().
    with torch.enable_grad():
        logits, _ = model(input_rgb, input_flow)
        activation = model.x_last_attention

        if activation.dim() != 4:
            raise ValueError(
                "model.x_last_attention must be 4-D, got %d dimensions" % activation.dim()
            )
        if not activation.requires_grad:
            raise RuntimeError(
                "model.x_last_attention is detached from the autograd graph; "
                "store the live tensor in forward() so gradients can reach it"
            )

        # The heatmaps themselves need no graph, only the gradient queries do.
        features = activation.detach()
        combined_heatmap = torch.zeros_like(features[:, -1, :, :])

        for target_class in class_indices:
            # Summing over the batch yields a scalar; samples stay independent
            # as long as the model does not mix information across the batch.
            score = logits[:, target_class].sum()

            # autograd.grad returns a fresh gradient per class, so nothing
            # accumulates between classes or into parameter .grad buffers.
            (gradients,) = torch.autograd.grad(
                score, activation, retain_graph=True, allow_unused=True
            )
            if gradients is None:
                raise RuntimeError(
                    "logit %r does not depend on model.x_last_attention" % (target_class,)
                )

            # Grad-CAM weights: one scalar per entry of dimension 1.
            weights = gradients.mean(dim=(2, 3), keepdim=True)

            # Weighted sum over dimension 1, keeping only positive evidence.
            heatmap = F.relu((weights * features).sum(dim=1))
            combined_heatmap = combined_heatmap + _scale_to_unit_max(heatmap)

    return _scale_to_unit_max(combined_heatmap).cpu().numpy()
