In [None]:
import os
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from PIL import Image
from sklearn.metrics import recall_score, precision_score, confusion_matrix, top_k_accuracy_score
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from tqdm import tqdm


In [None]:

# ---- Hyper-parameters -------------------------------------------------------
IMG_SIZE = 300          # square input resolution used by every transform
BATCH_SIZE = 64
EPOCHS = 60
LEARNING_RATE = 0.01    # initial SGD learning rate (reduced on plateau during training)
MOMENTUM = 0.9
ALPHA = 1.0             # weight of the toxicity loss in the multi-task objective
SEED = 42               # global RNG seed so training runs are reproducible

# ---- Data locations ---------------------------------------------------------
IMAGE_FOLDER = './data/image/yolo8x_300*300'  # literal folder name; os.path.join does not glob
TRAIN_CSV = 'final_train_80.csv'
VAL_CSV   = 'final_val_10.csv'
PUBTEST_CSV = './data/DanishFungi2024-Mini-pubtest.csv'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the pipeline touches (head initialisation, shuffling,
# random augmentations) so a Restart & Run All reproduces the same run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Dataset
class SafeShroomDataset(Dataset):
    """Image dataset yielding ``(image, species_index, poisonous_flag)`` tuples.

    Each CSV row must provide ``image_path`` (only its basename is used; the
    file is looked up inside ``root_dir``), ``class_id`` and ``poisonous``.

    Args:
        csv_file: Path (or file-like object) readable by ``pd.read_csv``.
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the PIL image.
        class_to_idx: Optional mapping ``class_id -> contiguous index``. Pass the
            training set's mapping when building validation/test sets so that
            species indices line up with the classifier's output units. When
            omitted, the mapping is built from this CSV's sorted class ids.
    """

    def __init__(self, csv_file, root_dir, transform=None, class_to_idx=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        if class_to_idx is None:
            self.unique_classes = sorted(self.df['class_id'].unique())
            self.class_to_idx = {cls: idx for idx, cls in enumerate(self.unique_classes)}
        else:
            # Copy so later mutation by the caller cannot change our labels.
            self.class_to_idx = dict(class_to_idx)
            self.unique_classes = sorted(self.class_to_idx, key=self.class_to_idx.get)
        self.num_classes = len(self.class_to_idx)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, os.path.basename(row['image_path']))
        try:
            # The context manager closes the file handle; convert() returns a
            # new, loaded image that stays valid after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as err:
            # Best-effort: an unreadable file must not abort an epoch, but it
            # should not go unnoticed either.
            warnings.warn(f"Could not read {img_path!r} ({err}); using a black placeholder image.")
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE))

        if self.transform:
            image = self.transform(image)

        # NOTE: class ids absent from class_to_idx fall back to index 0, so they
        # are indistinguishable from class 0 downstream; filter them if needed.
        cls_idx = self.class_to_idx.get(row['class_id'], 0)
        return image, cls_idx, float(row['poisonous'])

class SafeShroomMTL(nn.Module):
    """Multi-task EfficientNet-B3: one shared feature trunk and two heads.

    ``forward`` returns ``(species_logits, toxicity_logit)``: the species head
    has ``num_species`` outputs and the toxicity head a single raw logit (no
    sigmoid is applied here).
    """

    def __init__(self, num_species):
        super().__init__()
        trunk = models.efficientnet_b3(weights='DEFAULT')
        # Input width of the stock classifier's Linear layer = feature size.
        feat_dim = trunk.classifier[1].in_features
        # Replace the ImageNet classifier so the trunk returns raw features.
        trunk.classifier = nn.Identity()
        self.backbone = trunk
        self.species_head = nn.Linear(feat_dim, num_species)
        self.tox_head = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        features = self.backbone(x)
        species_logits = self.species_head(features)
        tox_logit = self.tox_head(features)
        return species_logits, tox_logit

class SafeShroomToxOnly(nn.Module):
    """Single-task EfficientNet-B3 that outputs only a raw toxicity logit.

    Its parameter names (``backbone.*``, ``tox_head.*``) mirror the toxicity
    branch of ``SafeShroomMTL``; there is no species head.
    """

    def __init__(self):
        super().__init__()
        trunk = models.efficientnet_b3(weights='DEFAULT')
        # Input width of the stock classifier's Linear layer = feature size.
        feat_dim = trunk.classifier[1].in_features
        trunk.classifier = nn.Identity()
        self.backbone = trunk
        self.tox_head = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.tox_head(self.backbone(x))

# Load Model
def load_model(path, num_species):
    """Rebuild the matching SafeShroom architecture for a saved ``state_dict``.

    The architecture is inferred from the checkpoint's keys: a checkpoint with
    ``species_head.*`` entries is loaded into ``SafeShroomMTL``; anything else
    is loaded into ``SafeShroomToxOnly``.

    Args:
        path: File written by ``torch.save(model.state_dict(), path)``.
        num_species: Output size of the species head (MTL checkpoints only).

    Returns:
        ``(model, is_mtl)`` on success, or ``(None, None)`` when the checkpoint
        is not a state dict or does not fit the chosen architecture.
    """
    # weights_only=True restricts unpickling to tensors/containers, which is all
    # a state_dict needs, and refuses arbitrary pickled code from the file.
    checkpoint = torch.load(path, map_location=DEVICE, weights_only=True)
    if not isinstance(checkpoint, dict):
        print(f"Error loading {path}")
        return None, None

    # Decide the architecture up front instead of trial-building both models
    # (each constructor instantiates a full pretrained EfficientNet-B3).
    is_mtl = any(str(key).startswith('species_head.') for key in checkpoint)
    model = SafeShroomMTL(num_species) if is_mtl else SafeShroomToxOnly()
    try:
        model.load_state_dict(checkpoint)
    except RuntimeError as err:  # missing/unexpected keys or shape mismatch
        print(f"Error loading {path}: {err}")
        return None, None
    return model, is_mtl


In [None]:

# Transforms
def _normalized(*steps):
    """Compose ``steps`` followed by ToTensor and ImageNet mean/std normalisation."""
    return transforms.Compose([
        *steps,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline: random crop, flips and colour jitter for augmentation.
train_transform = _normalized(
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
)

# Validation pipeline: deterministic resize to an IMG_SIZE x IMG_SIZE square.
val_transform = _normalized(
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
)




# Train
def _validate_epoch(model, loader, criterion_sp, criterion_tox):
    """Evaluate the multi-task ``model`` on ``loader`` without gradient tracking.

    Returns a dict keyed by the W&B metric names used in ``train``:
    ``val_loss`` (mean per-batch multi-task loss, fed to the LR scheduler),
    ``species_accuracy`` and ``toxic_accuracy`` (percent), ``toxicity_recall``
    (0-1) and ``toxicity_fnr`` (percent of poisonous samples predicted edible).
    """
    model.eval()
    val_loss = 0.0
    sp_correct = 0
    total = 0
    tox_pred_chunks, tox_true_chunks = [], []

    with torch.no_grad():
        for imgs, sp_lbl, tox_lbl in loader:
            imgs, sp_lbl = imgs.to(DEVICE), sp_lbl.to(DEVICE)
            tox_target = tox_lbl.to(DEVICE).float().unsqueeze(1)

            sp_out, tox_out = model(imgs)

            # Same objective as training; drives the LR scheduler.
            l_sp = criterion_sp(sp_out, sp_lbl)
            l_tox = criterion_tox(tox_out, tox_target)
            val_loss += (l_sp + ALPHA * l_tox).item()

            sp_pred = sp_out.argmax(dim=1)
            sp_correct += (sp_pred == sp_lbl).sum().item()
            total += sp_lbl.size(0)

            # tox_out is (B, 1) while tox_lbl is (B,): flatten both to 1-D.
            # Comparing an (N, 1) array with an (N,) array broadcasts to (N, N)
            # and silently produces a meaningless accuracy.
            tox_probs = torch.sigmoid(tox_out)
            tox_pred_chunks.append((tox_probs > 0.5).float().reshape(-1).cpu().numpy())
            tox_true_chunks.append(tox_lbl.reshape(-1).numpy())

    y_pred = np.concatenate(tox_pred_chunks)
    y_true = np.concatenate(tox_true_chunks)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fn = cm[1, 0]
    actual_positives = cm[1, :].sum()

    return {
        "val_loss": val_loss / len(loader),
        "species_accuracy": 100 * sp_correct / total,
        "toxicity_recall": recall_score(y_true, y_pred, pos_label=1, zero_division=0),
        "toxicity_fnr": 100 * fn / (actual_positives + 1e-6),
        "toxic_accuracy": 100 * (y_pred == y_true).sum() / len(y_true),
    }


def train():
    """Train ``SafeShroomMTL`` with SGD and checkpoint the best epochs.

    Logs per-step training loss and per-epoch validation metrics to W&B and
    writes three state_dict checkpoints to the working directory:
    ``sp_safeshroom_best.pth`` (highest species accuracy),
    ``tox_safeshroom_best.pth`` (highest toxicity accuracy) and
    ``fnr_safeshroom_best.pth`` (lowest toxicity false-negative rate).
    """
    wandb.init(project="SafeShroom", name="SGD_MTL_Train")
    try:
        # load data
        train_data = SafeShroomDataset(TRAIN_CSV, IMAGE_FOLDER, train_transform)
        val_data = SafeShroomDataset(VAL_CSV, IMAGE_FOLDER, val_transform)
        # The species head is sized and indexed by the *training* classes. Reuse
        # that mapping so validation labels point at the same output units even
        # when the validation split lacks some classes.
        val_data.class_to_idx = train_data.class_to_idx
        val_data.unique_classes = train_data.unique_classes
        val_data.num_classes = train_data.num_classes

        train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
        val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

        print(f"Stats: Train={len(train_data)} | Val={len(val_data)}")

        # Set model
        # pos_weight up-weights the poisonous (positive) class in the BCE loss.
        pos_weight = torch.tensor([15.0]).to(DEVICE)
        model = SafeShroomMTL(train_data.num_classes).to(DEVICE)

        optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
        # `verbose` is deprecated for LR schedulers; the LR is printed and logged below.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

        criterion_sp = nn.CrossEntropyLoss()
        criterion_tox = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        best_acc = 0
        best_fnr = float('inf')
        best_tox = 0

        for epoch in range(EPOCHS):
            model.train()
            loop = tqdm(train_loader, desc=f"Ep {epoch+1}")

            for imgs, sp_lbl, tox_lbl in loop:
                imgs, sp_lbl = imgs.to(DEVICE), sp_lbl.to(DEVICE)
                tox_lbl = tox_lbl.to(DEVICE).float().unsqueeze(1)

                optimizer.zero_grad()
                sp_out, tox_out = model(imgs)

                loss_sp = criterion_sp(sp_out, sp_lbl)
                loss_tox = criterion_tox(tox_out, tox_lbl)
                total_loss = loss_sp + (ALPHA * loss_tox)

                total_loss.backward()
                optimizer.step()

                loss_value = total_loss.item()
                wandb.log({"train_loss": loss_value})
                loop.set_postfix(loss=loss_value)

            metrics = _validate_epoch(model, val_loader, criterion_sp, criterion_tox)
            sp_acc = metrics["species_accuracy"]
            tox_acc = metrics["toxic_accuracy"]
            fnr = metrics["toxicity_fnr"]

            scheduler.step(metrics["val_loss"])
            lr = optimizer.param_groups[0]['lr']

            print(f"Ep {epoch+1} | Sp Acc: {sp_acc:.2f}% | Tox Recall: {metrics['toxicity_recall']:.4f} "
                  f"| FNR: {fnr:.2f}% | LR: {lr:.2e}")

            wandb.log({"epoch": epoch + 1, **metrics, "lr": lr})

            # Save Models
            if sp_acc > best_acc:
                best_acc = sp_acc
                torch.save(model.state_dict(), "sp_safeshroom_best.pth")
                print(f"   Saved Best Species Model")
            if tox_acc > best_tox:
                best_tox = tox_acc
                torch.save(model.state_dict(), "tox_safeshroom_best.pth")
                print(f"   Saved Best Toxic Model")
            if fnr < best_fnr:
                best_fnr = fnr
                torch.save(model.state_dict(), "fnr_safeshroom_best.pth")
                print(f"   Saved Best FNR Model")
    finally:
        # Close the W&B run even if training fails, so a re-run starts cleanly.
        wandb.finish()


if __name__ == "__main__":
    train()

### Inference

In [None]:

# Checkpoints evaluated on the public test split (missing files are skipped):
# the default-alpha run followed by the alpha-sweep variants.
MODEL_LIST = ["./tox_safeshroom_best.pth"] + [
    f"./tox_safeshroom_best_alpha{alpha}.pth" for alpha in (2.0, 5.0, 10.0)
]
REPORT_FILE = "final_pubtest_report.txt"  # rewritten from scratch on each evaluation run

# Inference
def _predict_pubtest(model, loader, is_mtl, desc=None):
    """Collect predictions of ``model`` over ``loader``, in loader order.

    Returns:
        ``(sp_probs, sp_lbls, tox_pred, tox_true)`` as NumPy arrays.
        ``sp_probs`` holds the species softmax rows and ``sp_lbls`` the species
        indices; both are empty for toxicity-only models. ``tox_pred`` holds
        0/1 predictions at a 0.5 probability threshold and ``tox_true`` the
        ``poisonous`` labels, both flattened to 1-D.
    """
    sp_probs, sp_lbls, tox_pred, tox_true = [], [], [], []
    with torch.no_grad():
        for imgs, sp_lbl, tox_lbl in tqdm(loader, desc=desc):
            imgs = imgs.to(DEVICE)
            if is_mtl:
                sp_out, tox_out = model(imgs)
                sp_probs.append(torch.softmax(sp_out, dim=1).cpu().numpy())
                sp_lbls.append(sp_lbl.numpy())
            else:
                tox_out = model(imgs)
            tox_pred.append((torch.sigmoid(tox_out) > 0.5).float().reshape(-1).cpu().numpy())
            tox_true.append(tox_lbl.reshape(-1).numpy())

    def _stack(chunks):
        return np.concatenate(chunks) if chunks else np.empty(0)

    return _stack(sp_probs), _stack(sp_lbls), _stack(tox_pred), _stack(tox_true)


def run_evaluation():
    """Evaluate every checkpoint in ``MODEL_LIST`` on the public test CSV.

    Toxicity metrics use a 0.5 probability threshold. Species top-1/top-3 are
    reported only for multi-task checkpoints, and only over test rows whose
    ``class_id`` also occurs in the training CSV. Results are printed and
    written to ``REPORT_FILE``, which is overwritten at the start of each run.
    """
    # Deterministic preprocessing (same steps as the validation transform).
    test_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # The training CSV defines the species-head layout (class_id -> output index).
    train_dummy = SafeShroomDataset(TRAIN_CSV, IMAGE_FOLDER, transform=None)
    class_to_idx = train_dummy.class_to_idx

    # Re-index test labels with the training mapping so species labels point
    # at the same output units the models were trained on.
    test_data = SafeShroomDataset(PUBTEST_CSV, IMAGE_FOLDER, test_transform)
    test_data.class_to_idx = class_to_idx
    test_data.unique_classes = train_dummy.unique_classes
    test_data.num_classes = train_dummy.num_classes
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # The dataset maps class ids it does not know to index 0, which would be
    # scored as "class 0". Mask those rows out of the species metrics; the mask
    # lines up with predictions because the loader does not shuffle.
    known_species = test_data.df['class_id'].isin(list(class_to_idx)).to_numpy()
    n_known = int(known_species.sum())

    # Prepare Report File
    with open(REPORT_FILE, "w") as f:
        f.write("Inference Report\n")
        f.write("=========================\n\n")

    if len(test_data) == 0:
        print(f"No test samples found in {PUBTEST_CSV}")
        return

    # Iterate Models
    for model_path in MODEL_LIST:
        if not os.path.exists(model_path):
            print(f"Skipping {model_path} (File not found)")
            continue

        print(f"\nProcessing: {model_path}")
        model, is_mtl = load_model(model_path, train_dummy.num_classes)
        if model is None:
            continue
        model.to(DEVICE)
        model.eval()

        sp_probs, sp_lbls, y_pred, y_true = _predict_pubtest(
            model, test_loader, is_mtl, desc=os.path.basename(model_path))

        # Toxicity (positive class = poisonous)
        final_tox_acc = 100 * (y_pred == y_true).sum() / len(y_true)
        final_tox_recall = recall_score(y_true, y_pred, pos_label=1, zero_division=0)
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])  # always 2x2 with fixed labels
        fn = cm[1, 0]
        actual_positives = cm[1, :].sum()
        fnr = 100 * fn / (actual_positives + 1e-6)

        # Species (only MTL checkpoints have a species head)
        if is_mtl and n_known > 0:
            valid_ids = np.arange(train_dummy.num_classes)
            probs_known = sp_probs[known_species]
            lbls_known = sp_lbls[known_species]
            final_sp_top1 = top_k_accuracy_score(lbls_known, probs_known, k=1, labels=valid_ids) * 100
            final_sp_top3 = top_k_accuracy_score(lbls_known, probs_known, k=3, labels=valid_ids) * 100
            sp_top1_str = f"{final_sp_top1:.2f}%"
            sp_top3_str = f"{final_sp_top3:.2f}%"
        else:
            sp_top1_str = "N/A"
            sp_top3_str = "N/A"
        sp_rows_str = f"{n_known}/{len(test_data)} (classes seen in training)" if is_mtl else "N/A"

        # Output
        output_str = (
            f"FINAL SLIDE RESULTS (PubTest)\n"
            f"Model: {os.path.basename(model_path)}\n"
            f"========================================\n"
            f"Species Top-1 Acc:  {sp_top1_str}\n"
            f"Species Top-3 Acc:  {sp_top3_str}\n"
            f"Species Eval Rows:  {sp_rows_str}\n"
            f"------------------------------\n"
            f"Toxicity Accuracy:  {final_tox_acc:.2f}%\n"
            f"Toxicity Recall:    {final_tox_recall:.4f}\n"
            f"FNR (Death Rate):   {fnr:.2f}%\n"
            f"Confusion Matrix:\n{cm}\n"
            f"========================================\n"
        )

        print(output_str)

        # Save to File
        with open(REPORT_FILE, "a") as f:
            f.write(output_str + "\n")

    print(f"\nReport saved to {REPORT_FILE}")


if __name__ == "__main__":
    run_evaluation()