In [12]:
import os,yaml
from pathlib import Path
def load_config(config_path):
    """Load configuration information from a yaml file."""
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

def find_best_weights_for_magnification(root_dir, magnification, num_folds=5):
    """
    Iterate through each fold and find the best weight file based on validation accuracy.

    Args:
    - root_dir (str): Root directory containing fold sub-directories.
    - magnification (str): Desired magnification like "400X".
    - num_folds (int): Number of folds. Default is 5.

    Returns:
    - str: Path to the best weight file across all folds.
    """
    
    bc_config = load_config("run_imagenet.yaml")
    weight_files = []
    
    for fold in list(bc_config["computational_infra"]["fold_to_gpu_mapping"].keys()):
        fold_dir = Path(root_dir) / f"_{fold}_{magnification}_BreakHis_FT_60_resnet50_imagenet_"
        # Gather all weight files for this fold
        fold_weight_files = list(fold_dir.glob('*.pth'))
        weight_files.extend(fold_weight_files)
        
    # Sort the weight files based on validation accuracy which is the second value in the filename
    sorted_files = sorted(weight_files, key=lambda x: float(x.stem.split('_')[1]), reverse=True)
    print(sorted_files)
    
    # Return the path of the weight file with the highest validation accuracy
    return str(sorted_files[0]) if sorted_files else None

# Example usage
root_directory = "/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet"
magnification = "400X"
best_weight = find_best_weights_for_magnification(root_directory, magnification)
print('Best weight:', best_weight)
        
        
        



[PosixPath('/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_4_5_400X_BreakHis_FT_60_resnet50_imagenet_/_75_85.25798525798525_86.29455539952498_0.8421162962913513.pth'), PosixPath('/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_1_5_400X_BreakHis_FT_60_resnet50_imagenet_/_68_91.83168316831683_89.45802206500767_0.9190493822097778.pth'), PosixPath('/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_0_5_400X_BreakHis_FT_60_resnet50_imagenet_/_55_89.33717579250721_86.90818840905345_0.893511176109314.pth'), PosixPath('/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_2_5_400X_BreakHis_FT_60_resnet50_imagenet_/_37_96.07250755287009_95.06917631917632_0.9614846110343933.pth'), PosixPath('/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_3_5_400X_BreakHis_FT_60_resnet50_imagenet_/_9_86.70694864048339_88.22885510136258_0.8655245304107666.pth')]
Best weight: /home/student/Desk

In [None]:
def evaluate_test_set(self, test_loader):
    confusion_matrix_test = torch.zeros(len(bc_config.binary_label_list), len(bc_config.binary_label_list))
    self.model.eval()
    with torch.no_grad():
        for patient_id, magnification, item_dict, binary_label, multi_label in tqdm(test_loader):
            view = item_dict[magnification[0]]
            view = view.cuda(self.device, non_blocking=True)                
            target = binary_label.to(self.device)
            outputs = self.model(view)
            outputs = outputs.squeeze(1)
            target = target.type_as(outputs)
            
            # Since it's testing, no need for loss calculation
            predicted = (outputs > self.threshold).int()
            
            for targetx, predictedx in zip(target.view(-1), predicted.view(-1)):
                confusion_matrix_test[(targetx.long(), predictedx.long())] += 1

    # Extract metrics from the confusion matrix similar to your validation function
    weighted_f1, accuracy, _, classwise_precision, classwise_recall, classwise_f1 = self.get_metrics_from_confusion_matrix(confusion_matrix_test)
    print(f'{self.experiment_description}: Test classwise precision', classwise_precision)
    print(f'{self.experiment_description}: Test classwise recall', classwise_recall)
    print(f'{self.experiment_description}: Test classwise f1', classwise_f1)
    print(f'{self.experiment_description}: Test Weighted F1', weighted_f1)
    print(f'{self.experiment_description}: Test Accuracy', accuracy)
    print(confusion_matrix_test)
    return (weighted_f1, accuracy, classwise_precision, classwise_recall, classwise_f1)


Generating Explainations


In [52]:
import sys
import torch
from tqdm import tqdm
# Append custom paths to sys.path for importing custom modules
sys.path.append(os.path.dirname("/home/student/Desktop/31171109-donotdelete/xai-chan/utils"))
from utils import train_utils, dataset_test, transform, models
# Append custom paths to sys.path for importing custom modules
sys.path.append(os.path.dirname("/home/student/Desktop/31171109-donotdelete/xai-chan/utils"))

# Initialize device
device = torch.device('cuda:0')

In [49]:
# Load the data
test_loader = dataset_test.get_breakhis_data_loader(
        dataset_path="/home/student/Desktop/31171109-donotdelete/xai-chan/dataset/Fold_0_5/val_20",
        transform=transform.resize_transform,
        pre_processing=[],
        image_type_list=["400X"],
        num_workers=2,
        is_test=True
    )

452 445 422 391
391


ResNet

In [59]:
# Load model
version = 50
model = models.ResNet_Model(version=version).to(device)
weights_path = "/home/student/Desktop/31171109-donotdelete/xai-chan/result/imagenet/_Fold_2_5_400X_BreakHis_FT_60_resnet50_imagenet_/_37_96.07250755287009_95.06917631917632_0.9614846110343933.pth"  # TODO: Provide the model path
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()



ResNet_Model(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
       

In [36]:
import numpy as np
def get_metrics_from_confusion_matrix(confusion_matrix_epoch):
        epoch_classwise_precision_manual_cpu = np.array(confusion_matrix_epoch.diag().cpu()) / np.array(confusion_matrix_epoch.cpu()).sum(axis=0)
        epoch_classwise_precision_manual_cpu = np.nan_to_num(epoch_classwise_precision_manual_cpu, nan=0, neginf=0, posinf=0)
        epoch_classwise_recall_manual_cpu = np.array(confusion_matrix_epoch.diag().cpu()) / np.array(confusion_matrix_epoch.cpu()).sum(axis=1)
        epoch_classwise_recall_manual_cpu = np.nan_to_num(epoch_classwise_recall_manual_cpu, nan=0, neginf=0, posinf=0)
        epoch_classwise_f1_manual_cpu = 2 * (epoch_classwise_precision_manual_cpu * epoch_classwise_recall_manual_cpu) / (epoch_classwise_precision_manual_cpu + epoch_classwise_recall_manual_cpu)
        epoch_classwise_f1_manual_cpu = np.nan_to_num(epoch_classwise_f1_manual_cpu, nan=0, neginf=0, posinf=0)
        epoch_avg_f1_manual = np.sum(epoch_classwise_f1_manual_cpu * np.array(confusion_matrix_epoch.cpu()).sum(axis=1)) / np.array(confusion_matrix_epoch.cpu()).sum(axis=1).sum()
        epoch_acc_manual = 100 * np.sum(np.array(confusion_matrix_epoch.diag().cpu())) / np.sum(np.array(confusion_matrix_epoch.cpu()))
        return (
         epoch_avg_f1_manual, epoch_acc_manual, epoch_classwise_precision_manual_cpu, epoch_classwise_recall_manual_cpu, epoch_classwise_f1_manual_cpu)

confusion_matrix_test = torch.zeros(2, 2).to(device)

with torch.no_grad():
    for patient_id, magnification, item_dict, binary_label, multi_label in tqdm(test_loader):
        view = item_dict[magnification[0]]
        view = view.cuda(device, non_blocking=True)                
        target = binary_label.to(device)
        outputs = model(view)
        outputs = outputs.squeeze(1)
        target = target.type_as(outputs)
        
        # Since it's testing, no need for loss calculation
        predicted = (outputs > 0.2).int()
        
        for targetx, predictedx in zip(target.view(-1), predicted.view(-1)):
            confusion_matrix_test[(targetx.long(), predictedx.long())] += 1

# Extract metrics from the confusion matrix similar to your validation function

weighted_f1, accuracy, classwise_precision, classwise_recall, classwise_f1 = get_metrics_from_confusion_matrix(confusion_matrix_test.cpu())
# Display the metrics
print(f'Test classwise precision: {classwise_precision}')
print(f'Test classwise recall: {classwise_recall}')
print(f'Test classwise f1: {classwise_f1}')
print(f'Test Weighted F1: {weighted_f1}')
print(f'Test Accuracy: {accuracy}')
print('Confusion Matrix:')
print(confusion_matrix_test.cpu().numpy())


100%|██████████| 13/13 [00:04<00:00,  3.22it/s]

Test classwise precision: [0.9097744 0.996124 ]
Test classwise recall: [0.9918033  0.95539033]
Test classwise f1: [0.9490196 0.9753321]
Test Weighted F1: 0.9671220183372498
Test Accuracy: 96.67519181585678
Confusion Matrix:
[[121.   1.]
 [ 12. 257.]]





In [None]:
from PIL import Image
from torchvision.transforms import ToPILImage
from zennit.attribution import Gradient, SmoothGrad
from zennit.core import Stabilizer
from zennit.composites import EpsilonGammaBox, EpsilonPlusFlat
from zennit.composites import SpecialFirstLayerMapComposite, NameMapComposite
from zennit.rules import Epsilon, ZPlus, ZBox, Norm, Pass, Flat
from zennit.types import Convolution, Activation, AvgPool, Linear as AnyLinear
from zennit.types import BatchNorm, MaxPool
from zennit.torchvision import VGGCanonizer, ResNetCanonizer
import torch.nn as nn

# Define XAI composite
low, high = torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]])
composite = SpecialFirstLayerMapComposite(
    layer_map=[
        (nn.ReLU, Pass()),
        (nn.AvgPool2d, Norm()),
        (nn.Conv2d, ZPlus()),
        (nn.Linear, Epsilon(epsilon=1e-6)),
        (nn.BatchNorm2d, Pass()),
    ],
    first_map=[
        (AnyLinear, ZBox(low, high))
    ]
)



In [None]:
from zennit.image import imgify, imsave
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
from zennit.attribution import Gradient

# Use the ResNet-specific canonizer
canonizer = ResNetCanonizer()

# Create a composite, specifying the canonizers
composite = EpsilonPlusFlat(canonizers=[canonizer])

# Iterate over the test_loader again
with torch.no_grad():
    for patient_id, magnification, item_dict, binary_label, multi_label in tqdm(test_loader):

        view = item_dict[magnification[0]]
        view = view.cuda(device, non_blocking=True)                

        outputs = model(view)
        outputs = outputs.squeeze(1)
        

        
        print(view)
        # Since it's testing, no need for loss calculation
        predicted = (outputs > 0.2).int()
        print(predicted)
         # Iterate over each image and its corresponding prediction
        for i, (single_view, single_pred) in enumerate(zip(view, predicted)):
            
            with Gradient(model=model, composite=composite) as attributor:
                _, attribution = attributor(single_view.unsqueeze(0), single_pred.unsqueeze(0))

            relevance = attribution.sum(1).cpu()
            imsave(f"/home/student/Desktop/31171109-donotdelete/xai-chan/saved/explanation/{patient_id}_{magnification[0]}_{[i]}.png", relevance, symmetric=True, cmap='coldnhot')

In [None]:

from IPython.display import display

# Make sure your model is in evaluation mode
model.eval()

# Define canonizer and composite
canonizer = ResNetCanonizer()
composite = EpsilonPlusFlat(canonizers=[canonizer])

# Iterate over the test_loader again
with torch.no_grad():
    for patient_id, magnification, item_dict, binary_label, multi_label in tqdm(test_loader):

        view = item_dict[magnification[0]]
        view = view.cuda(device, non_blocking=True)                

        outputs = model(view)
        outputs = outputs.squeeze(1)
        

        
        print(view)
        # Since it's testing, no need for loss calculation
        predicted = (outputs > 0.2).int()
        print(predicted)
   
        
        
        # Get the XAI attribution
        with Gradient(model=model, composite=composite) as attributor:
            _, attribution = attributor(view, predicted)

        # Sum over the channels and visualize
        relevance = attribution.sum(1).cpu()
        imsave(f"/home/student/Desktop/31171109-donotdelete/xai-chan/saved/explanation/{patient_id}_{magnification[0]}.png", relevance, symmetric=True, cmap='coldnhot')


In [57]:
import os
from PIL import Image
from torchvision import transforms
import torch

from zennit.image import imgify, imsave  # For creating visualizations
from zennit.torchvision import ResNetCanonizer  # For ResNet-specific canonization
from zennit.composites import EpsilonPlusFlat  # For the composite function in LRP
from zennit.attribution import Gradient  # For attributing using gradients
from IPython.display import display



def compute_heatmap(img_tensor, model, target):
    canonizer = ResNetCanonizer()
    composite = EpsilonPlusFlat(canonizers=[canonizer])
    
    with Gradient(model=model, composite=composite) as attributor:
        output, attribution = attributor(img_tensor, target)
    
    # Sum over the channels
    relevance = attribution.sum(1).cpu()
    return relevance

# Function to generate heatmaps for all images in a directory
def generate_heatmaps(directory, model):
    model.eval()
    target = torch.tensor([[1.0]]).to(device)

    for filename in os.listdir(directory):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            print(f"Processing image: {filename}")

            img_path = os.path.join(directory, filename)
            img = Image.open(img_path)
            print(img)
            img_tensor = transform.resize_transform(img).unsqueeze(0).to(device)  # Add batch dimension and move to device

            relevance = compute_heatmap(img_tensor, model, target)
            
            # Save the heatmap
            heatmap_filename = f"{filename.split('.')[0]}_gradient.png"  # Removing the original extension and appending the method name
            heatmap_path = os.path.join(directory, heatmap_filename)
            imsave(heatmap_path, relevance, symmetric=True, cmap='coldnhot')

# Execute the function
generate_heatmaps("/home/student/Desktop/31171109-donotdelete/xai-chan/explanation/val_10/SOB_B_F_14-21998EF/400X", model)



Processing image: SOB_B_F-14-21998EF-400-020.png
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=700x460 at 0x7F76BBF61400>


TypeError: pic should be Tensor or ndarray. Got <class 'PIL.PngImagePlugin.PngImageFile'>.