In [None]:
# Ensemble Variant 1: S: Densenet201, W: ResNet152d (Batch_size = 16, With ShapCAM)
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import torch
import sys
sys.path.extend(["..", "../../backbone","../.."])
from ctran import CTranEncoder
from densenet201 import DenseNet201
from resnet152d import ResNet152d
from cam_generate import ShapCAMGenerator,generate_heatmaps, generate_heatmap
import torch.nn.functional as F
from torchvision import transforms
# Release any cached CUDA blocks left over from a previous run of this kernel.
torch.cuda.empty_cache()
from multiprocessing import get_start_method, set_start_method, Pool

# Force the 'spawn' start method for the worker pools created further down.
if get_start_method(allow_none=True) != 'spawn':
    set_start_method('spawn', force=True)

# Set device to GPU if available, else use CPU.
# The device is now actually chosen from CUDA availability; previously it was
# hard-coded to "cuda", so the CPU branch of the message below could never be
# paired with a usable device.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Current device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    device = torch.device("cpu")
    print("Current device: CPU")

In [None]:
# Notebook-wide configuration: every tunable constant lives in this cell.

# --- data location and input geometry ---
data_dir = '../../../mured-data/data'
image_size = 320          # images are resized to image_size x image_size
batch_size = 16
num_workers = 6           # NOTE(review): also passed as `num_heads` when the models are built

# --- optimisation settings (not used by the inference cells visible here) ---
num_epochs = 120
learning_rate = 1e-6

# --- transformer encoder size ---
embed_dim = 960
num_layers = 12

# --- label space ---
num_labels = 20
num_classes = 20
thresholds = [0.5 for _ in range(num_labels)]   # one decision threshold per label

# --- ensemble weights: `a` scales model1's sigmoid output, `b` scales model2's ---
a = 0.4
b = 0.6

In [None]:
# Build the two ensemble members and restore their trained weights.
# (No optimizer is created here: this notebook only runs inference.)


def _build_ctran(backbone, checkpoint_path):
    """Wrap `backbone` in a CTranEncoder, load its checkpoint, and prepare it for inference.

    Args:
        backbone: an already-constructed backbone module.
        checkpoint_path: path to a state dict saved for the wrapped model.

    Returns:
        The CTranEncoder, moved to `device` and switched to evaluation mode.
    """
    model = CTranEncoder(
        num_classes=num_classes,
        image_size=image_size,
        embed_dim=embed_dim,
        num_layers=num_layers,
        # NOTE(review): the head count is taken from `num_workers` (6). Left unchanged
        # because altering it would change the attention layout the checkpoints were
        # presumably trained with; a dedicated `num_heads` constant would be clearer.
        num_heads=num_workers,
        backbone=backbone,
    )
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    # Evaluation mode, so any dropout / batch-norm layers behave deterministically
    # during heatmap generation and prediction.
    model.eval()
    return model


# Model 1: ResNet152d backbone.
backbone1 = ResNet152d(num_classes=num_classes, embed_dim=embed_dim)
model1 = _build_ctran(backbone1, 'model1.pth')

# Model 2: DenseNet201 backbone.
backbone2 = DenseNet201(num_classes=num_classes, embed_dim=embed_dim)
model2 = _build_ctran(backbone2, 'model2.pth')

In [None]:
# Single-generator run: Shap-CAM heatmaps via `generate_heatmap`, using only the
# generator built from model1.

# Image indices to process.
test = [3]

# Paths to the images: <data_dir>/images/<index>.png
image_paths = [os.path.join(data_dir, 'images', str(index) + '.png') for index in test]

# Resize, convert to a tensor, and normalise with the ImageNet mean / std.
data_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

images = []
for image_path in image_paths:
    # The context manager closes the file handle. convert("RGB") guarantees three
    # channels, so the 3-channel Normalize above also works for RGBA / grayscale PNGs.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    preprocessed_image = data_transform(image)
    images.append(preprocessed_image)

# Only model1 gets a ShapCAMGenerator in this cell.
shap_cam_generator1 = ShapCAMGenerator(model1, device)

# At most 6 workers, never more than there are images, and never fewer than 1
# (an empty `test` list used to end in a division by zero below).
num_processes = max(1, min(6, len(images)))

# Split the images into chunks of equal size (the last chunk may be shorter).
chunk_size = max(1, len(images) // num_processes)
image_chunks = [images[i:i + chunk_size] for i in range(0, len(images), chunk_size)]

# Run `generate_heatmap` once per chunk in a pool of worker processes.
with Pool(processes=num_processes) as pool:
    heatmaps_chunks = pool.starmap(
        generate_heatmap,
        [(image_chunk, model1, model2, shap_cam_generator1, device, a, b) for image_chunk in image_chunks]
    )

# Flatten the per-chunk results into one list.
# Assumes element 0 of each result is the list of per-image heatmaps — TODO confirm
# against `generate_heatmap`.
heatmaps1 = [heatmap for heatmaps_chunk in heatmaps_chunks for heatmap in heatmaps_chunk[0]]

In [None]:
# Two-generator run: Shap-CAM heatmaps for both models via `generate_heatmaps`.
# This cell rebinds `test`, `images` and `heatmaps1`, and defines `heatmaps2`.

# Image indices to process.
test = [3, 5]

# Paths to the images: <data_dir>/images/<index>.png
image_paths = [os.path.join(data_dir, 'images', str(index) + '.png') for index in test]

# Resize, convert to a tensor, and normalise with the ImageNet mean / std.
data_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

images = []
for image_path in image_paths:
    # The context manager closes the file handle. convert("RGB") guarantees three
    # channels, so the 3-channel Normalize above also works for RGBA / grayscale PNGs.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    preprocessed_image = data_transform(image)
    images.append(preprocessed_image)

# One ShapCAMGenerator per model.
shap_cam_generator1 = ShapCAMGenerator(model1, device)
shap_cam_generator2 = ShapCAMGenerator(model2, device)

# At most 6 workers, never more than there are images, and never fewer than 1
# (an empty `test` list used to end in a division by zero below).
num_processes = max(1, min(6, len(images)))

# Split the images into chunks of equal size (the last chunk may be shorter).
chunk_size = max(1, len(images) // num_processes)
image_chunks = [images[i:i + chunk_size] for i in range(0, len(images), chunk_size)]

# Run `generate_heatmaps` once per chunk in a pool of worker processes.
with Pool(processes=num_processes) as pool:
    heatmaps_chunks = pool.starmap(
        generate_heatmaps,
        [(image_chunk, model1, model2, shap_cam_generator1, shap_cam_generator2, device, a, b) for image_chunk in image_chunks]
    )

# Flatten the per-chunk results into one list per model.
# Assumes each result is indexable as (model1 heatmaps, model2 heatmaps) — TODO
# confirm against `generate_heatmaps`.
heatmaps1 = [heatmap for heatmaps_chunk in heatmaps_chunks for heatmap in heatmaps_chunk[0]]
heatmaps2 = [heatmap for heatmaps_chunk in heatmaps_chunks for heatmap in heatmaps_chunk[1]]

In [None]:
# Show each image next to its two Shap-CAM overlays, then print the ensemble prediction.

# Mean / std applied by transforms.Normalize during preprocessing; needed to map the
# tensors back to displayable RGB in [0, 1].
norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Inference mode, so any dropout / batch-norm layers do not perturb the predictions.
model1.eval()
model2.eval()

for i, (image, heatmap1, heatmap2) in enumerate(zip(images, heatmaps1, heatmaps2)):
    # CHW tensor -> HWC array, then undo the normalisation. Without this step imshow
    # clips the out-of-range values and the "original" image is shown with wrong colours.
    image_hwc = image.squeeze().numpy().transpose(1, 2, 0)
    image_rgb = np.clip(image_hwc * norm_std + norm_mean, 0.0, 1.0)

    # Drop singleton dimensions from the heatmaps.
    # NOTE(review): assumes the heatmaps broadcast against an (H, W, 3) array — confirm
    # against the heatmap generator.
    heatmap1_np = heatmap1.squeeze()
    heatmap2_np = heatmap2.squeeze()

    # Blend each heatmap at half strength over the de-normalised image; clip so the
    # result stays a valid RGB image.
    heatmap1_overlay = np.clip(np.asarray(heatmap1_np * 0.5 + image_rgb), 0.0, 1.0)
    heatmap2_overlay = np.clip(np.asarray(heatmap2_np * 0.5 + image_rgb), 0.0, 1.0)

    # One row, three panels: original, model-1 overlay, model-2 overlay.
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    axes[0].imshow(image_rgb)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    axes[1].imshow(heatmap1_overlay, cmap="jet")
    axes[1].set_title("Shap-CAM Heatmap (Model 1)")
    axes[1].axis("off")

    axes[2].imshow(heatmap2_overlay, cmap="jet")
    axes[2].set_title("Shap-CAM Heatmap (Model 2)")
    axes[2].axis("off")

    fig.tight_layout()
    plt.show()  # Display the plot with overlaid heatmaps
    plt.close(fig)  # free the figure before the next iteration

    # Ensemble prediction: weighted sum of the two models' sigmoid outputs.
    with torch.no_grad():
        batch = image.unsqueeze(0).to(device)  # add the batch dimension and transfer once
        outputs1 = model1(batch)
        outputs2 = model2(batch)
        weighted_outputs = a * torch.sigmoid(outputs1) + b * torch.sigmoid(outputs2)
    print(f"Image {i + 1} - Weighted Average Classification Result: {weighted_outputs[0].cpu().numpy()}")

# Clean up
torch.cuda.empty_cache()