In [1]:
import torch
from torch.utils.data import DataLoader
import os
import pandas as pd
from work.utils.dataset import PandasDataset, MultiColorSpaceTransform
from work.utils.models import EfficientNetMultiColor as EfficientNet
import albumentations
from tqdm import tqdm

In [2]:
output_dimensions = 5
data_dir = '../../../dataset'
images_dir = os.path.join(data_dir, 'tiles')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
df_test = pd.read_csv(f"../../data/test.csv")
backbone_model = 'efficientnet-b0'
pretrained_model = {
    backbone_model: '/home/woshington/Projects/Doutorado/work/efficientnet-b0-08094119.pth'
}

In [3]:
print("Cuda", device)

Cuda cuda


In [4]:
val_transform =albumentations.Compose([
    MultiColorSpaceTransform()
])

In [5]:
dataloader = DataLoader(
    PandasDataset(images_dir, df_test, transforms=val_transform),
    batch_size=1,
    shuffle=False,
)

In [6]:
model = EfficientNet(
    backbone=backbone_model,
    output_dimensions=output_dimensions,
    pre_trained_model=pretrained_model
)
model.load_state_dict(
    torch.load(
        "../models/with-noise-mult-color.pth",
        weights_only=True
    )
)

Loaded pretrained weights for efficientnet-b0


<All keys matched successfully>

In [7]:
def compute_global_channel_importance(model, dataloader, target_class, device="cuda"):
    model.eval()
    model.to(device)

    # Initialize storage for channel gradients across the dataset
    all_channel_gradients = []

    for images, labels, _ in tqdm(dataloader):  # Assuming dataloader yields (image, label)
        images = images.to(device)
        images.requires_grad = True

        # Forward pass
        outputs = model(images)
        target_scores = outputs[:, target_class]

        # Backward pass (compute gradients per image in batch)
        gradients = []
        for score in target_scores:
            model.zero_grad()
            if images.grad is not None:
                images.grad.zero_()
            score.backward(retain_graph=True)
            gradients.append(images.grad.data.mean(dim=(2, 3)))  # [batch, C]

        # Stack gradients for the batch
        batch_gradients = torch.stack(gradients).mean(dim=0)  # [batch, C] -> [C]
        all_channel_gradients.append(batch_gradients)

    # Aggregate across all batches
    global_gradients = torch.stack(all_channel_gradients).mean(dim=0)  # [C]

    # Get top 3 channels globally
    top3_values, top3_indices = torch.topk(global_gradients.abs(), k=3)

    return {
        "global_gradients": global_gradients.cpu().numpy(),
        "top3_indices": top3_indices.cpu().numpy(),
        "top3_values": top3_values.cpu().numpy()
    }

In [8]:
result = compute_global_channel_importance(model, dataloader, target_class=1)

print("Global Top 3 Channels (Indices):", result["top3_indices"])
print("Global Top 3 Channel Gradients:", result["top3_values"])

100%|██████████| 1592/1592 [38:36<00:00,  1.46s/it]


Global Top 3 Channels (Indices): [[14  6  4]]
Global Top 3 Channel Gradients: [[3.0710787e-08 2.4052094e-08 2.0998312e-08]]


In [12]:
print(result["global_gradients"])

[[-3.04177417e-09 -1.20104602e-08 -1.05012452e-08 -8.68790195e-09
   2.09983124e-08  1.42345735e-08  2.40520937e-08 -1.38739304e-08
  -1.26334001e-08 -1.08983880e-08  6.57189325e-09 -1.82706614e-08
  -1.76451598e-09  4.47961401e-09  3.07107868e-08  8.45942605e-09
  -4.39729986e-09 -3.14057980e-09]]


In [1]:
import matplotlib.pyplot as plt
import numpy as np

channel_labels = [
    'RGB_R', 'RGB_G', 'RGB_B',      # img (RGB)
    'XYZ_X', 'XYZ_Y', 'XYZ_Z',      # image_xyz
    'HED_H', 'HED_E', 'HED_D',      # image_hed
    'LAB_L', 'LAB_A', 'LAB_B',      # image_lab
    'CIELUV_L', 'CIELUV_U', 'CIELUV_V', # image_luv
    'HSV_H', 'HSV_S', 'HSV_V'       # image_hsv
]

# Converter lista de arrays para um array 1D de médias por canal
arr_2d = np.vstack(result["global_gradients"])  # Se cada item for array 1D
grads = arr_2d.mean(axis=0)

plt.figure(figsize=(12, 6))
bars = plt.bar(channel_labels, np.abs(grads), color='skyblue')
plt.xticks(rotation=45, ha='right')
plt.ylabel('Channel Importance (|Mean Gradient|)')
plt.title('Color Channel Contribution to Target Class Activation')
plt.tight_layout()
plt.savefig('channel_importance.png')
plt.show()

NameError: name 'result' is not defined