In [None]:
import torch # type: ignore
from torchvision import models, transforms # type: ignore

import numpy as np # type: ignore

import os # type: ignore

import cv2 # type: ignore

import matplotlib.pyplot as plt # type: ignore

from code_files.utils import ( # type: ignore
    set_seed,
    worker_init_fn
) 
from code_files.preprocessing import ( # type: ignore
    get_middle_age_dataset, 
    to_memory
) 
from code_files.models import ( # type: ignore
    mixedresnetnetwork,
)
from code_files.grad_cam import ( # type: ignore
    GradCAM
) 

In [None]:
# Set the random seed and select the compute device.

SEED = 44

# A bare `torch.__version__` in the middle of a cell is never displayed, so print it explicitly.
print(f"PyTorch version: {torch.__version__}")

set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

torch.cuda.is_available()

In [None]:
# Load the test data and wrap it in a DataLoader.

# NOTE(review): Resize/CenterCrop/Normalize are applied without ToTensor, so this assumes
# get_middle_age_dataset yields float tensors shaped (C, H, W) — confirm.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_data_dir = './data/your_data'

# Only the shapes are inspected here, so memory-map the arrays instead of reading them fully into RAM.
np_images = np.load(os.path.join(test_data_dir, 'np_images.npy'), mmap_mode='r')
features = np.load(os.path.join(test_data_dir, 'features.npy'), mmap_mode='r')
print(f"Shape of np_images: {np_images.shape}")
print(f"Shape of features: {features.shape}")

test_dataset = get_middle_age_dataset(test_data_dir, transform=test_transform)

# to_memory receives `device`; presumably it moves every sample onto that device — verify.
test_dataset = to_memory(test_dataset, device)

BATCH_SIZE = 16
# Forked worker processes cannot re-initialise CUDA, so workers are disabled when the data may
# already live on the GPU (an in-memory dataset gains little from workers anyway).
# If worker start-up fails on your platform, set NUM_WORKERS = 0.
NUM_WORKERS = 0 if device.type == 'cuda' else 4

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    worker_init_fn=worker_init_fn,
)

In [None]:
# Build the model (ResNet-50 backbone wrapped by mixedresnetnetwork) and optionally load the
# fine-tuned 3BTRON checkpoint.
# For this to work, the notebook must be saved in the same folder as '3BTRON.pt' and the
# 'code_files' package.

# Despite its name, this flag controls *loading* the saved end model from CHECKPOINT_PATH.
SAVE_END_MODEL = True
CHECKPOINT_PATH = './3BTRON.pt'

# When the checkpoint is loaded, load_state_dict (strict by default) replaces the backbone
# weights, so downloading the ImageNet weights first would be wasted work.
resnet50 = models.resnet50(weights=None if SAVE_END_MODEL else 'DEFAULT')
model = mixedresnetnetwork(model=resnet50, embeddings=resnet50.fc.in_features)
print(model)

if SAVE_END_MODEL:
    # A single call works on both GPU and CPU-only machines: map_location remaps tensors saved
    # from CUDA onto `device` (without it, a CUDA checkpoint fails to load on a CPU-only machine).
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state_dict)

model = model.to(device)
model.eval();  # trailing ';' suppresses a second repr of the model already printed above

In [None]:
# Re-enable gradients on every backbone parameter whose name contains "layer4",
# so Grad-CAM can backpropagate into the target layer.

for param_name, parameter in model.image_features_.named_parameters():
    if "layer4" not in param_name:
        continue
    parameter.requires_grad = True

In [None]:
# Confirm which layer4 parameters now require gradients.

layer4_params = (
    (param_name, parameter)
    for param_name, parameter in model.image_features_.named_parameters()
    if "layer4" in param_name
)
for param_name, parameter in layer4_params:
    print(f"{param_name} - requires_grad: {parameter.requires_grad}")

In [None]:
# Pre-determined decision thresholds on the predicted probability.
# The Grad-CAM loop bins each prediction as: certain_negative (<= green),
# uncertain_negative (<= binary), uncertain_positive (<= amber), certain_positive (> amber).

green_threshold = 0.25
amber_threshold = 0.75
binary_threshold = 0.5

# The binning only covers every probability when the thresholds are strictly ordered.
assert 0.0 <= green_threshold < binary_threshold < amber_threshold <= 1.0, (
    "thresholds must satisfy 0 <= green < binary < amber <= 1"
)

In [None]:
# Build the Grad-CAM helper around the final conv of the last layer4 block.
# NOTE(review): assumes image_features_ is a torchvision ResNet-50 backbone, where
# layer4[2].conv3 is the final convolution — confirm for other backbones.

last_block = model.image_features_.layer4[2]
target_layer = last_block.conv3
grad_cam = GradCAM(model, target_layer)

In [None]:
# Create the output folders: raw input images, standalone CAM heatmaps, and one folder per
# prediction category for the Grad-CAM overlays.

output_dir = './grad_cam_outputs_all'
original_dir = os.path.join(output_dir, 'original_images')
cam_dir = os.path.join(output_dir, 'cam_images')

categories = ['certain_positive', 'certain_negative', 'uncertain_positive', 'uncertain_negative']

# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
for directory in [original_dir, cam_dir] + [os.path.join(output_dir, c) for c in categories]:
    os.makedirs(directory, exist_ok=True)

In [None]:
# Run the model over the test set and bin each prediction with the thresholds above. For each
# image, save (a) the Grad-CAM overlay into its category folder, (b) the rescaled input image,
# and (c) the standalone CAM heatmap.

VERBOSE = False  # set True to print per-image shape/dtype diagnostics
IMAGE_FEATURE_LENGTH = 224 * 224 * 3  # length of the flattened image part of each sample


def _min_max_scale(array, eps=1e-8):
    """Rescale `array` to [0, 1] as float32; a (near-)constant array maps to zeros instead of NaN."""
    array = np.asarray(array, dtype=np.float32)
    value_range = array.max() - array.min()
    if value_range < eps:
        return np.zeros_like(array)
    return (array - array.min()) / value_range


def _categorize(prob):
    """Map a probability to a category name using the threshold cell's values.

    Returns None for NaN, which would otherwise fall into no band.
    """
    if np.isnan(prob):
        return None
    if prob > amber_threshold:
        return 'certain_positive'
    if prob > binary_threshold:
        return 'uncertain_positive'
    if prob > green_threshold:
        return 'uncertain_negative'
    return 'certain_negative'


sample_offset = 0  # global index of the first sample in the current batch

for data in test_loader:
    data = data.to(device)  # Send data to the GPU if available

    # Batch forward pass, used only for the probabilities. It is deliberately not wrapped in
    # torch.no_grad(): GradCAM may register gradient hooks during forward (not visible here).
    outputs = model(data)
    # NOTE(review): sigmoid on column 1 assumes that column is the positive-class logit — confirm.
    # reshape(-1) instead of squeeze(): squeeze turns a final batch of size 1 into a 0-d array,
    # which probs[i] cannot index.
    probs = torch.sigmoid(outputs[:, 1]).reshape(-1).detach().cpu().numpy()
    del outputs  # release this batch's autograd graph before the per-image Grad-CAM passes

    for i in range(data.size(0)):
        image_index = sample_offset + i
        image = data[i:i+1]  # keep the batch dimension
        prob = float(probs[i])

        category = _categorize(prob)
        if category is None:
            print(f"Warning: probability is NaN for image {image_index}; skipping")
            continue

        # NOTE(review): assumes each sample holds a flattened image of IMAGE_FEATURE_LENGTH values
        # along dim 2, possibly followed by extra features — confirm against get_middle_age_dataset.
        image_tensor = image[:, :, :IMAGE_FEATURE_LENGTH]
        image_tensor = image_tensor.reshape(3, 224, 224)  # reshape, not view: the slice may be non-contiguous
        if VERBOSE:
            print(f"Original image shape: {image.shape}")
            print(f"Reshaped image_tensor shape: {image_tensor.shape}")

        cam = grad_cam.generate_cam(image, device)
        if cam is None:
            print(f"Warning: grad_cam.generate_cam returned None for image {image_index}")
            continue

        image_numpy = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
        if VERBOSE:
            print(f"image_numpy shape: {image_numpy.shape}, dtype: {image_numpy.dtype}")
        # Min-max rescale to [0, 1] for display; this is a visual rescale, not an exact
        # inverse of the Normalize transform.
        image_numpy = _min_max_scale(image_numpy)

        # Filename encodes the sample index and the predicted probability.
        image_filename = f'grad_cam_image_{image_index}_prob_{prob:.2f}.png'
        category_dir = os.path.join(output_dir, category)
        overlay_path = os.path.join(category_dir, image_filename)
        if VERBOSE:
            print(f"Saving image to: {category_dir}/{image_filename}")

        cam_image = grad_cam.visualize_cam(image_numpy, cam, save_as=overlay_path)
        if cam_image is None:
            print(f"Warning: visualize_cam returned None for image {image_index}")
            continue
        print(f"Grad-CAM image saved: {overlay_path}")

        save_path_original = os.path.join(original_dir, image_filename)
        save_path_cam = os.path.join(cam_dir, image_filename)

        fig_orig, ax_orig = plt.subplots(figsize=(4, 4))
        ax_orig.imshow(image_numpy)
        ax_orig.axis('off')
        fig_orig.savefig(save_path_original, dpi=300, bbox_inches='tight')
        plt.close(fig_orig)
        print(f"Figure saved as {save_path_original}")

        # Standalone heatmap upsampled to the image size (cam is known to be non-None here).
        if torch.is_tensor(cam):
            cam = cam.detach().cpu().numpy()
        cam = np.squeeze(np.asarray(cam, dtype=np.float32))  # cv2.resize needs a 2-D float array
        cam = cv2.resize(cam, (image_numpy.shape[1], image_numpy.shape[0]))  # dsize is (width, height)
        cam = _min_max_scale(cam)

        fig_cam, ax_cam = plt.subplots(figsize=(4, 4))
        ax_cam.imshow(cam, cmap='jet')
        ax_cam.axis('off')
        fig_cam.savefig(save_path_cam, dpi=300, bbox_inches='tight')
        plt.close(fig_cam)
        print(f"Figure saved as {save_path_cam}")

    sample_offset += data.size(0)

print("Grad-CAM generation complete for all test images.")