In [None]:
import torch
import cv2
import numpy as np
import torch.nn as nn
import pandas as pd
import os
from torchvision import models, transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus 
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch.nn.functional as F
from collections import defaultdict

In [None]:
# Checkpoint files to compare. Each entry is handed to evaluate_model() as-is,
# so these bare filenames are resolved against the current working directory.
MODEL_LIST = [
    # "safeshroom" checkpoints (file names suggest a baseline plus an alpha
    # sweep -- TODO confirm against the training notebook).
    "tox_safeshroom_best.pth",
    "tox_safeshroom_best_alpha2.0.pth",
    "tox_safeshroom_best_alpha5.0.pth",
    "tox_safeshroom_best_alpha10.0.pth",
    # "tox_only" checkpoint.
    "tox_only_best_acc_bestsweep2.pth",
]

In [None]:
class SafeShroomMTL(nn.Module):
    """EfficientNet-B3 with two heads on shared features: species and toxicity.

    The torchvision classifier is replaced by ``nn.Identity`` so the backbone
    returns its pooled feature vector, which feeds both heads.

    Args:
        num_species: Number of output units of the species head.

    ``forward`` returns ``(species_logits, tox_logits)``; the toxicity output
    is a single raw logit per sample (no sigmoid is applied here).
    """

    def __init__(self, num_species):
        super().__init__()

        # Attribute names (backbone / species_head / tox_head) and the layer
        # order inside tox_head define the state_dict keys, so they must stay
        # in sync with the saved checkpoints.
        backbone = models.efficientnet_b3(weights='DEFAULT')
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        self.species_head = nn.Linear(feature_dim, num_species)
        self.tox_head = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        shared_features = self.backbone(x)
        species_logits = self.species_head(shared_features)
        tox_logits = self.tox_head(shared_features)
        return species_logits, tox_logits

class SafeShroomToxOnly(nn.Module):
    """EfficientNet-B3 with a single toxicity head (no species output).

    The torchvision classifier is replaced by ``nn.Identity`` so the backbone
    returns its pooled feature vector, which feeds the toxicity head.

    Args:
        dropout_rate: Dropout probability used inside the toxicity head.

    ``forward`` returns one raw toxicity logit per sample (no sigmoid applied).
    """

    def __init__(self, dropout_rate=0.7):
        super().__init__()

        # Attribute names and the layer order inside tox_head define the
        # state_dict keys, so they must stay in sync with saved checkpoints.
        backbone = models.efficientnet_b3(weights='DEFAULT')
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        self.tox_head = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        features = self.backbone(x)
        tox_logits = self.tox_head(features)
        return tox_logits

class ToxicityModelWrapper(nn.Module):
    """Adapter that exposes only the toxicity output of a wrapped model.

    Args:
        model: The wrapped model. When ``is_mtl`` is true it must return a
            2-tuple whose second element is the toxicity output; otherwise it
            must return the toxicity output directly.
        is_mtl: Whether ``model`` returns the 2-tuple form.
    """

    def __init__(self, model, is_mtl):
        super().__init__()
        self.model = model
        self.is_mtl = is_mtl

    def forward(self, x):
        # Single-output model: its result already is the toxicity output.
        if not self.is_mtl:
            return self.model(x)

        # Two-output model: discard the first element, keep the second.
        _first_out, tox_out = self.model(x)
        return tox_out
    
def load_model(path, num_species, device):
    """Load a checkpoint, detecting which of the two architectures it fits.

    The checkpoint is first loaded into ``SafeShroomMTL``; if that raises a
    ``RuntimeError`` (e.g. key or shape mismatch in ``load_state_dict``) it is
    loaded into ``SafeShroomToxOnly`` instead.

    Args:
        path: Path of the checkpoint file; it is passed directly to
            ``load_state_dict``, so it must contain a state dict.
        num_species: Size of the species head used when trying ``SafeShroomMTL``.
        device: ``map_location`` for ``torch.load``. The returned model itself
            is not moved to this device here.

    Returns:
        ``(model, is_mtl)`` where ``is_mtl`` is True for ``SafeShroomMTL`` and
        False for ``SafeShroomToxOnly``.

    Raises:
        RuntimeError: If the checkpoint fits neither architecture. The raised
            error is the tox-only failure, chained (``__cause__``) to the MTL
            failure so both reasons are visible.
        Any exception raised by ``torch.load`` itself (e.g. a missing file)
        propagates unchanged.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # Keep the first failure instead of discarding it: for an MTL checkpoint
    # with a mismatched num_species, this is the error that explains why.
    mtl_error = None
    try:
        model = SafeShroomMTL(num_species=num_species)
        model.load_state_dict(checkpoint)
        return model, True
    except RuntimeError as err:
        mtl_error = err

    try:
        model = SafeShroomToxOnly()
        model.load_state_dict(checkpoint)
        return model, False
    except RuntimeError as tox_error:
        print(f"Could not load model {path}. Unknown architecture.")
        # load_state_dict messages can list hundreds of keys; show only the start.
        print(f"  SafeShroomMTL attempt failed with: {str(mtl_error)[:300]}")
        print(f"  SafeShroomToxOnly attempt failed with: {str(tox_error)[:300]}")
        raise tox_error from mtl_error

In [None]:
# Configuration A: species the model has seen (test dataset).
# NOTE: the next cell assigns the same names again. Run only the cell for the
# experiment you want; under "Run All" the later cell's values are the ones
# evaluate_model() ends up using.

# Runtime settings
IMG_SIZE = 300
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Inputs
TRAIN_CSV = '../../final_train_80.csv'
TEST_CSV = './gold_standard_100.csv'
IMAGE_FOLDER = '../image/yolo8x_300*300'

# Outputs
BASE_OUTPUT_DIR = './gradcam_comparisons'

In [None]:
# Configuration B: species the model has not seen but that resemble species
# in the dataset.
# NOTE: this cell rebinds the same names as the previous configuration cell,
# so running it replaces those values for every later evaluate_model() call.

# Runtime settings
IMG_SIZE = 300
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Inputs
TRAIN_CSV = '../../final_train_80.csv'
TEST_CSV = '../image/similar_unseen_image/similar_unseen_meta.csv'
IMAGE_FOLDER = '../image/similar_unseen_image'

# Outputs
BASE_OUTPUT_DIR = './gradcam_comparisons_unseen_similar'


In [None]:

def _write_accuracy_report(report_path, model_name, test_csv, correct_count,
                           total_count, species_stats, species_details):
    """Write the accuracy report for one model and echo the per-image lines.

    The file has three sections: overall accuracy, a per-species summary
    table, and a per-image breakdown grouped by species. Every per-image line
    is also printed to stdout.

    Args:
        report_path: Destination text file (overwritten if it exists).
        model_name: Label written in the report header.
        test_csv: Path of the evaluated CSV, written in the header.
        correct_count: Number of correctly classified images.
        total_count: Number of images that were classified.
        species_stats: Mapping species -> dict with 'correct', 'total', 'type'.
        species_details: Mapping species -> list of per-image dicts with keys
            'filename', 'true', 'pred', 'prob' (in percent) and 'icon'.
    """
    overall_acc = (correct_count / total_count * 100) if total_count > 0 else 0

    # UTF-8 is requested explicitly: the per-image lines contain emoji, which
    # a non-UTF-8 locale default encoding cannot encode.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Model: {model_name}\n")
        f.write(f"Evaluating on: {test_csv}\n")
        f.write("="*60 + "\n")
        f.write(f"OVERALL ACCURACY: {overall_acc:.2f}% ({correct_count}/{total_count})\n")
        f.write("="*60 + "\n\n")

        # Summary table: one row per species
        f.write(f"{'SPECIES (TYPE)':<45} | {'ACCURACY':<10} | {'CORRECT/TOTAL'}\n")
        f.write("-" * 75 + "\n")
        for sp, data in sorted(species_stats.items()):
            sp_acc = (data['correct'] / data['total'] * 100) if data['total'] > 0 else 0
            merged_name = f"{sp} ({data['type']})"
            f.write(f"{merged_name:<45} | {sp_acc:<9.1f}% | {data['correct']}/{data['total']}\n")

        f.write("\n\n")

        # Detail breakdown: every image, grouped by species
        for sp in sorted(species_details.keys()):
            f.write("="*50 + "\n")
            f.write(f"SPECIES GROUP: {sp}\n")
            f.write("="*50 + "\n")

            for i, item in enumerate(species_details[sp], 1):
                line = (f"  [{i}] {item['filename']}: "
                        f"True={item['true']} | "
                        f"Pred={item['pred']} ({item['prob']:.1f}%) {item['icon']}")

                f.write(line + "\n")
                print(line)
            f.write("\n")


def evaluate_model(model_path, train_csv, test_csv, img_folder):
    """Evaluate one checkpoint on a test CSV and save GradCAM++ overlays.

    For every row of ``test_csv`` whose image can be opened, the toxicity
    probability is predicted, compared with the ``poisonous`` column, and an
    annotated GradCAM++ overlay is written to
    ``BASE_OUTPUT_DIR/<model name>/``. A text report with overall, per-species
    and per-image accuracy is written to the same directory.

    Reads the module-level ``BASE_OUTPUT_DIR``, ``DEVICE`` and ``IMG_SIZE``.

    Args:
        model_path: Checkpoint file; its base name (without ``.pth``) names
            the output sub-directory.
        train_csv: CSV with a ``class_id`` column; the number of distinct
            values is used as the species-head size when loading the model.
        test_csv: CSV with ``image_path`` and ``poisonous`` columns, and
            optionally ``species`` / ``text_label`` / ``class_id`` for naming.
        img_folder: Directory containing the images; only the base name of
            each ``image_path`` is used.

    Returns:
        None. If the checkpoint cannot be loaded, the reason is printed and
        the function returns without writing a report.
    """
    # Output Setup
    model_name = os.path.basename(model_path).replace('.pth', '')
    output_dir = os.path.join(BASE_OUTPUT_DIR, model_name)
    os.makedirs(output_dir, exist_ok=True)

    report_path = os.path.join(output_dir, "accuracy_report_detailed.txt")
    print(f"\nProcessing Model: {model_name}")

    # Load Resources
    train_df = pd.read_csv(train_csv)
    num_species = len(train_df['class_id'].unique())

    try:
        model, is_mtl = load_model(model_path, num_species, DEVICE)
        model = model.to(DEVICE)
        model.eval()
    except Exception as err:
        # Best effort: one bad checkpoint should not stop the comparison of
        # the others, but the reason must be visible.
        print(f" Skipping {model_name}: could not load model ({type(err).__name__}: {err})")
        return

    # GradCAM Setup
    wrapped_model = ToxicityModelWrapper(model, is_mtl)
    target_layers = [model.backbone.features[-1]]
    cam = GradCAMPlusPlus(model=wrapped_model, target_layers=target_layers)

    stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(*stats)
    ])

    # Processing Loop
    df = pd.read_csv(test_csv)

    correct_count = 0
    total_count = 0
    skipped_count = 0

    # Storage for detailed list
    species_details = defaultdict(list)

    # Storage for summary stats
    species_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'type': '?'})

    print(" Generating GradCAMs")

    for idx, row in df.iterrows():
        species_name = row.get('species', row.get('text_label', f"Class {row.get('class_id', '?')}"))
        filename = os.path.basename(row['image_path'])
        full_image_path = os.path.join(img_folder, filename)

        if not os.path.exists(full_image_path):
            skipped_count += 1
            continue
        try:
            img = Image.open(full_image_path).convert('RGB')
        except Exception as err:
            # Unreadable/corrupt image: skip it, but say so. Catching
            # Exception (not a bare except) keeps Ctrl-C working.
            print(f" Skipping unreadable image {full_image_path}: {err}")
            skipped_count += 1
            continue

        # Inference
        img_resized = img.resize((IMG_SIZE, IMG_SIZE))
        input_tensor = preprocess(img_resized).unsqueeze(0).to(DEVICE)
        rgb_img = np.float32(img_resized) / 255.0

        if is_mtl:
            _, tox_logits = model(input_tensor)
        else:
            tox_logits = model(input_tensor)

        tox_prob = torch.sigmoid(tox_logits).item()

        # Prediction Logic
        # > 0.5 = POISON, <= 0.5 = SAFE
        pred_label_str = "POISON" if tox_prob > 0.5 else "SAFE"
        true_label_str = "POISON" if int(row['poisonous']) == 1 else "SAFE"

        is_correct = (pred_label_str == true_label_str)
        icon = "✅" if is_correct else "❌"

        # Update Summary Stats
        total_count += 1
        species_stats[species_name]['total'] += 1
        species_stats[species_name]['type'] = true_label_str
        if is_correct:
            correct_count += 1
            species_stats[species_name]['correct'] += 1

        # Store Details for the Report
        species_details[species_name].append({
            'filename': filename,
            'true': true_label_str,
            'pred': pred_label_str,
            'prob': tox_prob * 100,
            'icon': icon
        })

        # Generate GradCAM
        # Target index 0 is the single toxicity logit, so the heatmap is
        # computed for that logit regardless of the predicted label.
        targets = [ClassifierOutputTarget(0)]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        vis = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

        # Annotate Image (colors are BGR here: green if correct, red if wrong)
        color = (0, 255, 0) if is_correct else (0, 0, 255)
        cv2.putText(vis, f"Sp: {species_name}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        cv2.putText(vis, f"True: {true_label_str}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
        cv2.putText(vis, f"Pred: {pred_label_str} ({tox_prob*100:.1f}%)", (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        # Save Image
        safe_sp = str(species_name).replace(" ", "_").replace("/", "-")
        save_name = os.path.join(output_dir, f"{safe_sp}_{idx}_{true_label_str}_vs_{pred_label_str}.jpg")
        cv2.imwrite(save_name, vis)

    if skipped_count:
        # Accuracy below only covers the images that could be opened.
        print(f" Warning: skipped {skipped_count} of {len(df)} rows (image missing or unreadable)")

    # Generate FINAL Report
    print(f" Saving Detailed Report to {report_path}")
    _write_accuracy_report(report_path, model_name, test_csv, correct_count,
                           total_count, species_stats, species_details)

    print(f"\nDone. Check {report_path}")



In [None]:
# Entry point: evaluate every checkpoint in MODEL_LIST with whichever
# TRAIN_CSV / TEST_CSV / IMAGE_FOLDER values are currently bound (i.e. the
# configuration cell that was executed last).
if __name__ == '__main__':
    print(f"Starting Comparison for {len(MODEL_LIST)} models...")
    for model_path in MODEL_LIST:
        # evaluate_model returns early instead of raising when a checkpoint
        # cannot be loaded, so one bad file does not stop the loop.
        evaluate_model(model_path, TRAIN_CSV, TEST_CSV, IMAGE_FOLDER)
