In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from utils import dataset_test, transform, bc_config, models
from captum.attr import LayerGradCam, LRP
import yaml
# Run on the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def load_pretrained_model(args_dict, fold, magnification):
    """Build the downstream classifier for one fold and load its trained weights.

    Args:
        args_dict: Parsed run configuration. Reads ``["encoder"]["name"]``,
            ``["encoder"]["version"]``, ``["encoder"]["fc_dropout"]`` and
            ``["results"]["result_base_path"]``.
        fold: Fold identifier, interpolated into the weight file name.
        magnification: Magnification tag (e.g. ``"400X"``), interpolated into
            the weight file name.

    Returns:
        The model moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If the configured encoder is not supported (only
            ``"resnet"`` is implemented).
    """
    encoder = args_dict["encoder"]["name"]
    version = args_dict["encoder"]["version"]
    dropout = args_dict["encoder"]["fc_dropout"]

    # Fail fast with a clear message; previously an unsupported encoder left the
    # model as None and crashed later with an opaque AttributeError on `.to`.
    if encoder != "resnet":
        raise ValueError(f"Unsupported encoder {encoder!r}; only 'resnet' is implemented")

    downstream_task_model = models.ResNet_Model(version=int(version), pretrained=True)
    num_ftrs = downstream_task_model.num_ftrs
    # One-logit head; its shape must match the saved checkpoint for load_state_dict to succeed.
    downstream_task_model.model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(num_ftrs, 1))
    downstream_task_model = downstream_task_model.to(device)

    # Checkpoint file name convention: weights_<fold>_<magnification>.pth
    model_path = os.path.join(
        args_dict["results"]["result_base_path"], f"weights_{fold}_{magnification}.pth"
    )
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
    state_dict = torch.load(model_path, map_location=device)
    downstream_task_model.load_state_dict(state_dict)
    downstream_task_model.eval()

    return downstream_task_model

In [None]:
def load_config(config_path):
    """Load configuration information from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file: a dict for a
        mapping-style config, or ``None`` if the file is empty.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as config_file:
        return yaml.safe_load(config_file)

In [None]:
def compute_gradcam(model, dataloader, target_layer, target=0):
    """Compute Grad-CAM attributions for ``target_layer`` over every batch in ``dataloader``.

    Args:
        model: Model (already in eval mode) whose forward pass returns logits.
        dataloader: Iterable of batches. A batch may be a bare input tensor or
            a tuple/list whose first element is the input tensor.
        target_layer: Module whose activations Grad-CAM attributes to.
        target: Output column to explain. Defaults to 0. The classifier head
            built by ``load_pretrained_model`` is ``nn.Linear(num_ftrs, 1)``,
            so index 0 is the only valid column; the previous hard-coded
            ``target=1`` is out of range for that head.

    Returns:
        List with one attribution tensor per batch, detached from autograd
        and moved to the CPU.
    """
    grad_cam = LayerGradCam(model, target_layer)
    attributions = []

    for batch in dataloader:
        # Accept (inputs, labels, ...) tuples as well as bare input tensors.
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        inputs = inputs.to(device)
        attribution = grad_cam.attribute(inputs, target=target)
        # Detach and offload so results (and any autograd references they hold)
        # do not accumulate in GPU memory across batches.
        attributions.append(attribution.detach().cpu())

    return attributions

In [None]:
if __name__ == "__main__":
    config = load_config("imagenet_run.yaml")
    magnification = "400X"

    # NOTE(review): the original cell used `data_path` without ever defining it
    # (NameError on a fresh kernel). Reading it from a "data_path" config key or
    # the BREAKHIS_DATA_PATH environment variable is an assumption — TODO confirm
    # where the dataset root is meant to be configured.
    data_path = config.get("data_path") or os.environ.get("BREAKHIS_DATA_PATH")
    if not data_path:
        raise RuntimeError(
            "Dataset root not configured: add 'data_path' to imagenet_run.yaml "
            "or set the BREAKHIS_DATA_PATH environment variable"
        )

    for fold in config["computational_infra"]["fold_to_gpu_mapping"]:
        model = load_pretrained_model(config, fold, magnification)

        # Per-fold 'test_20' split, loaded in test mode. str() guards against
        # YAML parsing fold keys as ints, which os.path.join rejects.
        val_loader = dataset_test.get_breakhis_data_loader(
            dataset_path=os.path.join(data_path, str(fold), "test_20"),
            transform=transform.resize_transform,
            pre_processing=[],
            image_type_list=[magnification],
            num_workers=2,
            is_test=True,
        )

        # Last residual stage of the backbone (presumably a torchvision ResNet,
        # given `model.model.fc` above) — the conventional Grad-CAM target layer.
        target_layer = model.model.layer4
        gradcam_attributions = compute_gradcam(model, val_loader, target_layer)
        