# DermaCon-IN: CBM Model Inferencing

This notebook cell provides a full inference pipeline for a trained Concept Bottleneck Model (CBM) with a Swin Transformer backbone. The model uses:

- A Concept Bottleneck layer of 96 dimensions, capturing interpretable skin lesion descriptors and anatomical body parts.

- A final classification head predicting among 9 main diagnostic classes.

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
import torch.nn as nn
from torchvision import transforms
import timm

# ===============================
# Configuration
# ===============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "./checkpoints/CBM_MC_best_model.pth"
image_path = "./DATASET/IMG_0618.jpg"  # Replace with the image you want to test
csv_path = "./METADATA/test_split.csv"
mean = [0.53749797, 0.45875554, 0.40382471]
std = [0.21629889, 0.20366619, 0.20136241]

# ===============================
# Step 1: Load label mapping
# ===============================
df = pd.read_csv(csv_path)
label_to_idx = {label: idx for idx, label in enumerate(sorted(df['Main_class'].dropna().unique()))}
idx_to_label = {v: k for k, v in label_to_idx.items()}
concept_columns = [col for col in df.columns if 'Body_part' in col or 'Descriptor' in col]
concept_dim = len(concept_columns)
num_classes = len(label_to_idx)

# ===============================
# Step 2: Define transforms
# ===============================
val_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# ===============================
# Step 3: Load and preprocess image
# ===============================
img = Image.open(image_path).convert('RGB')
img_tensor = val_transform(img).unsqueeze(0).to(device)  # Add batch dim

# ===============================
# Step 4: Define models
# ===============================
class ConceptPredictor(nn.Module):
    def __init__(self, concept_dim: int = 100, pretrained: bool = False, model_name: str = "swin_base_patch4_window12_384", dropout_p: float = 0.3):
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg", img_size=512
        )
        in_dim = self.backbone.num_features
        self.fc = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(in_dim, concept_dim, bias=True),
        )
        nn.init.trunc_normal_(self.fc[-1].weight, std=0.02)
        nn.init.zeros_(self.fc[-1].bias)

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        concept_logits = self.fc(features)
        return torch.sigmoid(concept_logits), concept_logits


class LabelPredictor(nn.Module):
    def __init__(self, concept_dim: int = 96, num_classes: int = 8):
        super().__init__()
        self.fc = nn.Linear(concept_dim, num_classes)
        nn.init.trunc_normal_(self.fc.weight, std=0.02)
        nn.init.zeros_(self.fc.bias)

    def forward(self, concept_logits: torch.Tensor):
        return self.fc(concept_logits)


# ===============================
# Step 5: Load models & checkpoint
# ===============================
concept_model = ConceptPredictor(concept_dim=concept_dim).to(device)
label_model = LabelPredictor(concept_dim=concept_dim, num_classes=num_classes).to(device)

checkpoint = torch.load(checkpoint_path, map_location=device)
concept_model.load_state_dict(checkpoint["concept_model"])
label_model.load_state_dict(checkpoint["label_model"])
concept_model.eval()
label_model.eval()

# ===============================
# Step 6: Inference
# ===============================
with torch.no_grad():
    pred_concepts_prob, pred_concepts_logits = concept_model(img_tensor)  # (1, 96)
    pred_label_logits = label_model(pred_concepts_logits)  # (1, num_classes)
    pred_label_idx = torch.argmax(pred_label_logits, dim=1).item()
    pred_label_name = idx_to_label[pred_label_idx]

# ===============================
# Step 7: Display Results
# ===============================
print(f"✅ Predicted Class: {pred_label_name} (index={pred_label_idx})\n")

# Print Top-5 positive and negative concepts
concept_vals = pred_concepts_logits.cpu().numpy().flatten()
concept_list = list(zip(concept_columns, concept_vals))

top_pos = sorted(concept_list, key=lambda x: x[1], reverse=True)[:4]
top_neg = sorted(concept_list, key=lambda x: x[1])[:4]

print("🔵 Top 4 Positive Concepts:")
for name, val in top_pos:
    if val<0:
        continue
    print(f"  {name:50s} +{val:.3f}")

print("\n🔴 Top 4 Negative Concepts:")
for name, val in top_neg:
    print(f"  {name:50s} {val:.3f}")


  checkpoint = torch.load(checkpoint_path, map_location=device)


✅ Predicted Class: Keratanisation Disorders (index=2)

🔵 Top 4 Positive Concepts:
  Descriptor_Hyperkeratotic plaques                  +2.364
  Descriptor_Fissure                                 +2.312
  Body_part_Lower Extremities Soles (Plantar Region) +1.949
  Descriptor_Scale                                   +0.931

🔴 Top 4 Negative Concepts:
  Body_part_Calves (Sural Region)                    -22.963
  Body_part_Perianal Area                            -14.085
  Descriptor_Lichenification                         -11.711
  Body_part_Upper Extremities Shoulders              -11.575
