In [36]:
import torch
import torch.nn.functional as F
import numpy as np
import pydicom
import matplotlib.pyplot as plt
from torchvision import transforms, models
from PIL import Image

In [37]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
def generate_saliency_map(model, image, target_class=1):
    model.eval()
    
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),  # For grayscale image, using single channel mean and std
    ])
    
    # Convert the image to a tensor and set requires_grad to True
    image_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    image_tensor.requires_grad_()

    # Simulate a batch: create the target label tensor for class 0 or 1
    target_tensor = torch.tensor([target_class]).unsqueeze(0)  # Make it a batch of size 1

    # Forward pass using the shared_step function
    loss, logits, preds, _ = model.shared_step((image_tensor, target_tensor))  # Pass the image and target as a batch
    
    # Get the class score for the target class
    class_score = logits[0, target_class]
    
    # Zero all previous gradients
    model.zero_grad()
    
    # Backward pass to compute gradients with respect to the image
    class_score.backward()
    
    # Get the gradient of the image
    saliency, _ = torch.max(image_tensor.grad.data.abs(), dim=1)
    
    # Convert the saliency to numpy for visualization
    saliency = saliency.squeeze().cpu().numpy()
    
    return saliency

In [39]:
def load_dicom_image(dicom_path):
    # Read DICOM file
    dicom_data = pydicom.dcmread(dicom_path)
    
    # Convert the pixel data to numpy array
    image_array = dicom_data.pixel_array
    
    # If the image is grayscale (1 channel), convert to 3 channels (RGB)
    if len(image_array.shape) == 2:  # Check if it's grayscale (height, width)
        image_array = np.stack([image_array] * 3, axis=-1)  # Create 3 channels (RGB)
    
    # Convert to PIL Image
    image = Image.fromarray(image_array)
    return image

In [40]:
def plot_attention_maps(image, saliency_class0, saliency_class1):
    fig, ax = plt.subplots(1, 2, figsize=(15, 7))

    # Plot the saliency maps for class 0 and class 1
    ax[0].imshow(image, cmap='gray')
    ax[0].imshow(saliency_class0, cmap='hot', alpha=0.5)
    ax[0].axis('off')
    ax[0].set_title("Attention for Class 0 (No Finding)")

    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(saliency_class1, cmap='hot', alpha=0.5)
    ax[1].axis('off')
    ax[1].set_title("Attention for Class 1 (Abnormal)")

    plt.show()

In [41]:
import torch
from torchvision import transforms
from PIL import Image
import yaml
from train_cls import load_config
from methods.cls_model import FinetuneClassifier
from datasets.cls_dataset import RSNAImageClsDataset  # Using RSNA dataset now
from datasets.data_module import DataModule
from datasets.transforms import DataTransforms
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

# Load config for RSNA dataset
config = load_config('../configs/rsna.yaml')

finetuned_rsna = FinetuneClassifier(config)

Loading configuration from: ../configs/rsna.yaml


In [42]:
checkpoint_path = '/Users/sandradening/Documents/Dokumente_Sandra/Master_Studium/3_Semester/VLM_Seminar/Code/data/ckpts/FinetuneCLS/rsna/2025_01_17_14_27_25/epoch=27-step=8147.ckpt'


finetuned_resnet50 = FinetuneClassifier(config)
finetuned_resnet50.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))['state_dict'])

first = "/Users/sandradening/Documents/Dokumente_Sandra/Master_Studium/3_Semester/VLM_Seminar/Code/datasets/rsna/stage_2_train_images/f2698fda-0477-435f-b297-f1b284a731aa.dcm"

second = "/Users/sandradening/Documents/Dokumente_Sandra/Master_Studium/3_Semester/VLM_Seminar/Code/datasets/rsna/stage_2_train_images/b76dd4b8-7b51-4cb6-8fd7-0b7365ef3e1e.dcm"

third = "/Users/sandradening/Documents/Dokumente_Sandra/Master_Studium/3_Semester/VLM_Seminar/Code/datasets/rsna/stage_2_train_images/5d8dbcf9-0d68-4aec-8638-b0a9f45d71d6.dcm"

fourth = "/Users/sandradening/Documents/Dokumente_Sandra/Master_Studium/3_Semester/VLM_Seminar/Code/datasets/rsna/stage_2_train_images/f6be6dc3-9539-46c0-a1f5-b10919ff81cd.dcm"
dicom_paths = [first, second, third, fourth]  # List your DICOM file paths


for dicom_path in dicom_paths:
    # Load and preprocess DICOM image
    image = load_dicom_image(dicom_path)

    # Generate saliency maps for class 0 and class 1
    saliency_class0 = generate_saliency_map(finetuned_resnet50, image, target_class=0)
    saliency_class1 = generate_saliency_map(finetuned_resnet50, image, target_class=1)

    # Plot the attention maps for both classes side by side
    plot_attention_maps(image, saliency_class0, saliency_class1)

ValueError: Expected input batch_size (1) to match target batch_size (0).

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, models
from PIL import Image

def generate_saliency_map(model, image, target_class=1):
    """
    Generates the saliency map for a given class (0 or 1)
    """
    model.eval()
    
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Convert the image to a tensor and set requires_grad to True
    image_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
    image_tensor.requires_grad_()

    # Forward pass
    output = model(image_tensor)
    
    # Get the class score for the target class
    class_score = output[0, target_class]
    
    # Zero all previous gradients
    model.zero_grad()
    
    # Backward pass to compute gradients with respect to the image
    class_score.backward()
    
    # Get the gradient of the image
    saliency, _ = torch.max(image_tensor.grad.data.abs(), dim=1)
    
    # Convert the saliency to numpy for visualization
    saliency = saliency.squeeze().cpu().numpy()
    
    return saliency

def plot_attention_maps(image, saliency_class0, saliency_class1):
    """
    Plot the attention maps (saliency maps) for both class 0 and class 1 side by side.
    """
    fig, ax = plt.subplots(1, 2, figsize=(15, 7))

    # Plot the saliency maps for class 0 and class 1
    ax[0].imshow(image)
    ax[0].imshow(saliency_class0, cmap='hot', alpha=0.5)
    ax[0].axis('off')
    ax[0].set_title("Attention for Class 0 (No Finding)")

    ax[1].imshow(image)
    ax[1].imshow(saliency_class1, cmap='hot', alpha=0.5)
    ax[1].axis('off')
    ax[1].set_title("Attention for Class 1 (Abnormal)")

    plt.show()

In [None]:
# Load the test image
image_path = 'path_to_your_image.jpg'
image = Image.open(image_path)

# Load the pre-trained ResNet50 model (or your fine-tuned model)
model = models.resnet50(pretrained=True)

# Generate saliency maps for class 0 and class 1
saliency_class0 = generate_saliency_map(model, image, target_class=0)
saliency_class1 = generate_saliency_map(model, image, target_class=1)

# Plot the attention maps for both classes side by side
plot_attention_maps(image, saliency_class0, saliency_class1)

In [None]:
for image_path in test_image_paths:
    image = Image.open(image_path)
    saliency_class0 = generate_saliency_map(finetuned_rsna, image, target_class=0)
    saliency_class1 = generate_saliency_map(finetuned_rsna, image, target_class=1)
    plot_attention_maps(image, saliency_class0, saliency_class1)