In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import wandb
from sklearn.metrics import recall_score, precision_score, confusion_matrix, accuracy_score
from tqdm import tqdm

In [None]:
# --- Data locations -------------------------------------------------------
# The '*' in IMAGE_FOLDER is a literal character of the directory name; this
# notebook only ever joins the path with a file name, it never globs it.
IMAGE_FOLDER = './data/image/yolo8x_300*300'
TRAIN_CSV = 'final_train_80.csv'
VAL_CSV = 'final_val_10.csv'
PUBTEST_CSV = './data/DanishFungi2024-Mini-pubtest.csv'

# --- Input / batching -----------------------------------------------------
IMG_SIZE = 300    # side length (pixels) used by every resize/crop transform
BATCH_SIZE = 32

# --- Optimisation hyper-parameters ----------------------------------------
EPOCHS = 60
LEARNING_RATE = 0.0003011514889996196
MOMENTUM = 0.9
DROPOUT = 0.7         # dropout probability inside the toxicity head
POS_WEIGHT = 50.0     # weight on the positive class in BCEWithLogitsLoss

# --- Compute device -------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Dataset
class SafeShroomDataset(Dataset):
    """Image dataset backed by a CSV with `image_path`, `class_id` and `poisonous` columns.

    Only the base name of `image_path` is used; the file is looked up directly
    inside `root_dir`.

    Each item is a tuple `(image, class_index, poisonous)`:
      * `image`       -- the RGB image after `transform` (or the raw PIL image
                         when no transform is given),
      * `class_index` -- position of the row's `class_id` in the sorted list of
                         class ids found in *this* CSV,
      * `poisonous`   -- the `poisonous` column converted to `float`.

    Note: the class-index mapping is derived per CSV, so two datasets built
    from CSVs with different class sets will not share the same indices.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Read `csv_file` and build the class-id -> contiguous-index mapping."""
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.unique_classes = sorted(self.df['class_id'].unique())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.unique_classes)}
        self.num_classes = len(self.unique_classes)

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(image, class_index, poisonous)` for row `idx`.

        If the image cannot be opened or decoded, a warning is emitted and a
        blank `IMG_SIZE` x `IMG_SIZE` RGB image is substituted so that a single
        bad file does not abort a long run.
        """
        row = self.df.iloc[idx]

        img_name = os.path.basename(row['image_path'])
        img_path = os.path.join(self.root_dir, img_name)

        try:
            # convert() decodes the file and returns an independent image, so
            # the underlying file handle can be closed right away.
            with Image.open(img_path) as handle:
                image = handle.convert('RGB')
        except Exception as exc:
            # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate. The fallback is kept, but it is no longer
            # silent: a blank image still carries the row's real label.
            import warnings
            warnings.warn(
                f"Could not load image '{img_path}' ({exc!r}); substituting a blank image.",
                RuntimeWarning,
            )
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE))

        if self.transform:
            image = self.transform(image)

        cls_idx = self.class_to_idx.get(row['class_id'], 0)

        return image, cls_idx, float(row['poisonous'])

# Model
class SafeShroomToxOnly(nn.Module):
    """EfficientNet-B3 feature extractor followed by a small toxicity head.

    The torchvision classifier is swapped for `nn.Identity`, so the backbone
    emits its pooled feature vector, which the head maps to one raw logit per
    image (the notebook applies the sigmoid outside the model).
    """

    def __init__(self):
        super().__init__()

        # Backbone with the torchvision default weights.
        backbone = models.efficientnet_b3(weights='DEFAULT')

        # Width of the features that fed the original classification layer.
        feature_dim = backbone.classifier[1].in_features

        # Strip the stock classifier so the backbone returns raw features.
        backbone.classifier = nn.Identity()
        self.backbone = backbone

        # Toxicity head: features -> 64 -> 1 logit.
        head_layers = [
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(64, 1),
        ]
        self.tox_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the toxicity logit(s) for the image batch `x`."""
        return self.tox_head(self.backbone(x))


In [None]:

# Transforms
# Per-channel statistics handed to Normalize in both pipelines below
# (presumably the ImageNet values the pretrained backbone expects -- confirm).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/colour augmentation, then tensor + normalise.
_train_steps = [
    transforms.RandomResizedCrop(size=IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize only, then tensor + normalise.
_val_steps = [
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
val_transform = transforms.Compose(_val_steps)


# Train
def _train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over `loader`; return the mean per-batch loss."""
    model.train()
    loop = tqdm(loader, desc=f"Ep {epoch+1}")
    running_loss = 0.0

    for imgs, _, tox_lbl in loop:
        imgs = imgs.to(DEVICE)
        # BCEWithLogitsLoss needs float targets shaped like the (B, 1) logits.
        tox_lbl = tox_lbl.to(DEVICE).float().unsqueeze(1)

        optimizer.zero_grad()
        tox_out = model(imgs)
        loss = criterion(tox_out, tox_lbl)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        wandb.log({"train_loss": batch_loss})
        loop.set_postfix(loss=batch_loss)

    return running_loss / max(len(loader), 1)


def _validate(model, loader, criterion, threshold=0.5):
    """Evaluate `model` on `loader`.

    Returns `(mean_batch_loss, labels, preds)` where `labels` and `preds` are
    flat 1-D numpy arrays; `preds` is 1.0 where sigmoid(logit) > `threshold`.
    """
    model.eval()
    total_loss = 0.0
    label_chunks, pred_chunks = [], []

    with torch.no_grad():
        for imgs, _, tox_lbl in loader:
            imgs, tox_lbl = imgs.to(DEVICE), tox_lbl.to(DEVICE)
            tox_out = model(imgs)

            total_loss += criterion(tox_out, tox_lbl.float().unsqueeze(1)).item()

            tox_probs = torch.sigmoid(tox_out)
            # ravel() turns the (B, 1) predictions into (B,) so they line up
            # with the labels without relying on sklearn's implicit flattening.
            pred_chunks.append((tox_probs > threshold).float().cpu().numpy().ravel())
            label_chunks.append(tox_lbl.cpu().numpy().ravel())

    mean_loss = total_loss / max(len(loader), 1)
    return mean_loss, np.concatenate(label_chunks), np.concatenate(pred_chunks)


def train():
    """Train the toxicity-only model and checkpoint the best epochs.

    Reads the module-level configuration (CSV paths, IMAGE_FOLDER, BATCH_SIZE,
    EPOCHS, LEARNING_RATE, POS_WEIGHT, DEVICE), starts a Weights & Biases run
    and, after every epoch, logs validation loss, recall, false-negative rate,
    accuracy and the current learning rate.

    Files written to the working directory:
      * tox_only_best_recall_bestsweep1.pth -- highest validation recall
        (ties broken by higher accuracy),
      * tox_only_best_acc_bestsweep1.pth    -- highest validation accuracy,
      * final_epoch_bestsweep1.pth          -- weights after the last epoch.

    The W&B run is left open; the caller is expected to call `wandb.finish()`.

    Raises:
        ValueError: if the validation CSV contains no rows.
    """
    wandb.init(project="SafeShroom-toxic-head-only", name="ADAMW_ToxOnly_Experiment-bestsweep")

    # Load data
    train_data = SafeShroomDataset(TRAIN_CSV, IMAGE_FOLDER, train_transform)
    val_data = SafeShroomDataset(VAL_CSV, IMAGE_FOLDER, val_transform)

    # Fail before training starts rather than with a ZeroDivisionError after
    # the first full epoch.
    if len(val_data) == 0:
        raise ValueError(f"Validation set '{VAL_CSV}' is empty.")

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

    print(f"Stats: Train={len(train_data)} | Val={len(val_data)}")

    # Setup model
    model = SafeShroomToxOnly().to(DEVICE)

    # Weighted loss: positive (poisonous) samples count POS_WEIGHT times.
    pos_weight = torch.tensor([POS_WEIGHT]).to(DEVICE)
    criterion_tox = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # `verbose` is deprecated in recent PyTorch releases, so it is not passed;
    # learning-rate reductions are reported explicitly after scheduler.step().
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_tox_acc = 0
    best_tox_recall = 0
    acc_at_best_recall = 0

    for epoch in range(EPOCHS):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion_tox, optimizer, epoch)
        avg_val_loss, all_tox_lbls, all_tox_preds = _validate(model, val_loader, criterion_tox)

        # Metrics
        tox_acc = accuracy_score(all_tox_lbls, all_tox_preds) * 100
        tox_recall = recall_score(all_tox_lbls, all_tox_preds, pos_label=1, zero_division=0)

        cm = confusion_matrix(all_tox_lbls, all_tox_preds, labels=[0, 1])
        fn = cm[1, 0]
        actual_positives = cm[1, :].sum()
        # Same guard as evaluate_model: exact rate, 0.0 when there are no positives.
        fnr = 100.0 * fn / actual_positives if actual_positives > 0 else 0.0

        print(f"Ep {epoch+1} | Val Loss: {avg_val_loss:.4f} | Recall: {tox_recall:.4f} | Acc: {tox_acc:.2f}% | FNR: {fnr:.2f}%")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(avg_val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"   Learning rate reduced: {lr_before:.3e} -> {lr_after:.3e}")

        wandb.log({
            "epoch": epoch+1,
            "train_loss_epoch": avg_train_loss,
            "val_loss": avg_val_loss,
            "toxicity_recall": tox_recall,
            "toxicity_fnr": fnr,
            "toxic_accuracy": tox_acc,
            "lr": lr_after
        })

        # Best recall (safety priority). With a large pos_weight, recall can
        # saturate within the first epochs; a strict '>' alone would then pin
        # this checkpoint to the earliest such epoch. On equal (non-zero)
        # recall, prefer the epoch with higher accuracy.
        recall_improved = tox_recall > best_tox_recall
        recall_tied_acc_improved = (
            tox_recall > 0
            and tox_recall == best_tox_recall
            and tox_acc > acc_at_best_recall
        )
        if recall_improved or recall_tied_acc_improved:
            best_tox_recall = tox_recall
            acc_at_best_recall = tox_acc
            torch.save(model.state_dict(), "tox_only_best_recall_bestsweep1.pth")
            print(f"   Saved Best Recall Model")

        # Best overall accuracy
        if tox_acc > best_tox_acc:
            best_tox_acc = tox_acc
            torch.save(model.state_dict(), "tox_only_best_acc_bestsweep1.pth")
            print(f" Saved Best Accuracy Model")

    torch.save(model.state_dict(), "final_epoch_bestsweep1.pth")

if __name__ == "__main__":
    # train() opens a W&B run; make sure it is closed even when training
    # fails or is interrupted, otherwise the run stays open in the kernel.
    try:
        train()
    except BaseException:
        # Mark the run as failed, then let the original error surface.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

### Inference

In [None]:

# Text file that run_evaluation() writes its summary to.
REPORT_FILE = 'toxic_only_report.txt'

# Checkpoints to evaluate, in report order.
# NOTE(review): train() in this notebook saves its checkpoints into the working
# directory, while these paths point into ./toxic_only/ -- confirm the files are
# moved there, otherwise every model is reported as "not found".
_CHECKPOINT_DIR = './toxic_only'
MODEL_PATHS = [
    f"{_CHECKPOINT_DIR}/{checkpoint_name}"
    for checkpoint_name in (
        'tox_only_best_recall_bestsweep1.pth',
        'tox_only_best_acc_bestsweep1.pth',
        'final_epoch_bestsweep1.pth',
    )
]

# Test-time preprocessing: deterministic resize, tensor conversion, normalisation.
_test_steps = [
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
test_transform = transforms.Compose(_test_steps)


def evaluate_model(model_path, test_loader):
    """Evaluate one saved checkpoint on `test_loader`.

    Args:
        model_path: path to a state dict saved from `SafeShroomToxOnly`.
        test_loader: iterable of batches whose first element is the image
            tensor and whose last element is the toxicity label tensor. Both
            `(image, label)` and `(image, class_idx, label)` batches (the shape
            produced by `SafeShroomDataset`) are accepted.

    Returns:
        `None` (after printing a warning) when `model_path` does not exist,
        otherwise a dict with keys `model_name`, `acc` (percent), `recall`,
        `fnr` (percent; 0.0 when there are no positive labels) and `cm`
        (2x2 confusion matrix for labels [0, 1]).
    """
    # Read the checkpoint first so a missing file is reported without having
    # to instantiate the network.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # you trust (consider weights_only=True where the installed torch supports it).
    try:
        state_dict = torch.load(model_path, map_location=DEVICE)
    except FileNotFoundError:
        print(f"Warning: {model_path} not found.")
        return None

    model = SafeShroomToxOnly().to(DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Testing {os.path.basename(model_path)}"):
            # SafeShroomDataset yields (image, class_idx, poisonous); unpacking
            # into two names raised ValueError. Take first/last instead.
            imgs, tox_lbls = batch[0], batch[-1]
            imgs = imgs.to(DEVICE)
            logits = model(imgs)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float().cpu().numpy()
            all_preds.extend(preds.ravel())
            all_labels.extend(tox_lbls.cpu().numpy().ravel())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds).flatten()

    acc = accuracy_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds, pos_label=1, zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    # labels=[0, 1] fixes the matrix at 2x2, so ravel() is (tn, fp, fn, tp).
    tn, fp, fn, tp = cm.ravel()
    actual_positives = fn + tp
    fnr = (fn / actual_positives) * 100 if actual_positives > 0 else 0.0

    return {
        "model_name": model_path,
        "acc": acc * 100,
        "recall": recall,
        "fnr": fnr,
        "cm": cm
    }

def run_evaluation():
    """Evaluate every checkpoint in MODEL_PATHS on the public test CSV.

    Each model that `evaluate_model` can load has its metrics printed and
    appended to REPORT_FILE; the report file is recreated at the start of
    every call. Checkpoints for which `evaluate_model` returns nothing are
    skipped.
    """
    # Test data
    test_data = SafeShroomDataset(PUBTEST_CSV, IMAGE_FOLDER, test_transform)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # Start a fresh report with its header.
    report_header = "SafeShroom Inference Report\n" + "========================================\n\n"
    with open(REPORT_FILE, "w") as report:
        report.write(report_header)

    console_rule = "=" * 30
    report_rule = "-" * 40

    for checkpoint in MODEL_PATHS:
        outcome = evaluate_model(checkpoint, test_loader)
        if not outcome:
            continue

        # Lines shared verbatim by the console output and the report file.
        summary = [
            f"Model: {outcome['model_name']}",
            f"Toxicity Accuracy:  {outcome['acc']:.2f}%",
            f"Toxicity Recall:    {outcome['recall']:.4f}",
            f"FNR (Death Rate):   {outcome['fnr']:.2f}%",
            "Confusion Matrix:",
        ]

        # Console
        print("\n" + console_rule)
        for line in summary:
            print(line)
        print(outcome['cm'])
        print(console_rule + "\n")

        # Report file
        with open(REPORT_FILE, "a") as report:
            report.writelines(line + "\n" for line in summary)
            report.write(str(outcome['cm']) + "\n")
            report.write(report_rule + "\n\n")

    print(f"Inference complete. Results saved to {REPORT_FILE}")

# Entry point for the inference cell: evaluates every checkpoint listed in
# MODEL_PATHS and writes the summary to REPORT_FILE. The guard is true when
# this code runs as the main module (a notebook kernel or a direct script run).
if __name__ == "__main__":
    run_evaluation()