In [1]:
import torch
import torch.nn as nn

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau  # add import

from torchvision import transforms, datasets
from torchvision.transforms import InterpolationMode
from torchvision.transforms import v2 as T

from torch.utils.data import DataLoader, Dataset

from ShroomDataset import ShroomDataset

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import classification_report, top_k_accuracy_score

import numpy as np
import matplotlib.pyplot as plt
import math
import time
import pandas as pd
from PIL import Image, ImageOps
import os
from tqdm.auto import tqdm

data_path = './data/'

train_meta = pd.read_csv(os.path.join(data_path, 'train.csv'))#.iloc[:1000]
val_meta = pd.read_csv(os.path.join(data_path, 'val.csv'))
test_meta = pd.read_csv(os.path.join(data_path, 'test.csv'))

# val_meta = train_meta
# test_meta = train_meta


In [2]:
# class ShroomDataset(Dataset):

#     def __init__(self, df, base_path = './data/', transform = None, label2idx = None):
#         self.df = df
#         self.base_path = base_path
#         self.transform = transform

#         if label2idx is None:
#             unique = sorted(self.df['label'].unique())
#             self.label2idx = {label : idx for idx, label in enumerate(unique)}
#         else:
#             self.label2idx = label2idx

#     def __len__(self):
#         return len(self.df)
    
#     def __getitem__(self, index):
        
#         row  =  self.df.iloc[index]

#         img_path = self.base_path + row['image_path']
#         img = Image.open(img_path)
#         img = ImageOps.exif_transpose(img)       # handle camera rotation
#         if img.mode != 'RGB':                    # <- key line
#             img = img.convert('RGB')

#         if self.transform:
#             img = self.transform(img)
        
#         label = self.label2idx[row['label']]

#         return img, label
    



In [3]:

IMG_SIZE = 244  # try 320/384/448
mean = (0.485, 0.456, 0.406)  # replace with dataset stats later
std  = (0.229, 0.224, 0.225)

train_tf = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.7,1.0), ratio=(0.75,1.33), antialias=True),
    T.RandomHorizontalFlip(0.5),
    T.RandomApply([T.ColorJitter(0.2,0.2,0.2,0.05)], p=0.4),
    T.RandomApply([T.RandomRotation(15)], p=0.3),
    T.RandomApply([T.RandomPerspective(0.25)], p=0.15),
    T.RandomApply([T.GaussianBlur(3)], p=0.15),
    T.ToDtype(torch.float32, scale=True),   # scales [0..255] -> [0..1]
    T.Normalize(mean, std),
])

val_tf = T.Compose([
    T.Resize(IMG_SIZE, antialias=True),
    T.CenterCrop(IMG_SIZE),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean, std),
])

train_ds = ShroomDataset(train_meta, data_path, transform=train_tf)
val_ds   = ShroomDataset(val_meta,   data_path, transform=val_tf, label2idx=train_ds.label2idx)
test_ds  = ShroomDataset(test_meta,  data_path, transform=val_tf, label2idx=train_ds.label2idx)

num_classes = len(train_ds.label2idx)
idx2label = {v:k for k,v in train_ds.label2idx.items()}
class_names = [idx2label[i] for i in range(num_classes)]

print(f"# classes: {num_classes}")

train_loader = DataLoader(train_ds, batch_size=32, shuffle = True, num_workers=6, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size = 32, num_workers=6, persistent_workers=True)
test_loader = DataLoader(test_ds, batch_size = 32, num_workers=6, persistent_workers=True)
# BATCH_SIZE = 32  # try 24/32/48 depending on memory
# NUM_WORKERS = 4

# train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
#                           num_workers=NUM_WORKERS, persistent_workers=True, pin_memory=False)
# val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
#                           num_workers=NUM_WORKERS, persistent_workers=True, pin_memory=False)
# test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
#                           num_workers=NUM_WORKERS, persistent_workers=True, pin_memory=False)

# classes: 169


In [4]:

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block using 1×1 convolutions for efficiency."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.GELU(),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Squeeze: global spatial info
        y = self.avg_pool(x)
        # Excitation & scale
        y = self.fc(y)
        return x * y


class CBAM(nn.Module):
    """Convolutional Block Attention Module (CBAM) with streamlined operations."""
    def __init__(self, channels, reduction=16, kernel_size=5):
        super().__init__()
        # Channel attention
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.BatchNorm2d(channels // reduction),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
            nn.Sigmoid()
        )
        # Spatial attention
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Channel attention: refine channel weights
        ca = self.channel_attn(x)
        x = x * ca

        # Spatial attention: refine spatial focus
        max_pool, _ = x.max(dim=1, keepdim=True)
        avg_pool = x.mean(dim=1, keepdim=True)
        sa = self.spatial_attn(torch.cat([max_pool, avg_pool], dim=1))
        return x * sa


class ResidualBlock(nn.Module):
    def __init__(self, channels, use_se=False, reduction=16):
        super().__init__()
        self.use_se = use_se
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(channels)
        if use_se:
            self.se = SEBlock(channels, reduction)
        self.act = nn.GELU()
    def forward(self, x):
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)
        return self.act(x + out)



class ConvAttnBlock(nn.Module):
    """
    Modular conv block:
      1) Conv → BN → GELU
      2) Optional CBAM attention
      3) Optional ResidualBlock (with SE)
    """
    def __init__(self, in_channels, out_channels,
                 use_cbam=False, use_res=False, use_se=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        self.attn = CBAM(out_channels) if use_cbam else nn.Identity()
        self.res  = ResidualBlock(out_channels, use_se=use_se) if use_res else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.attn(x)
        x = self.res(x)
        return x


class ShroomCNNAttentive(nn.Module):
    def __init__(self, in_ch=3, block_cfgs=None,
                 mlp_units=(512,), num_classes=162):
        super().__init__()
        
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, 32, kernel_size=3, stride=2, padding=1, bias=False),  # down to 192×192
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
        )
        layers, ch = [], in_ch
        ch = 64
        for out_ch, do_pool, cbam, res, se in block_cfgs:
            layers.append(ConvAttnBlock(ch, out_ch,
                                       use_cbam=cbam,
                                       use_res=res,
                                       use_se=se))
            if do_pool:
                # stride-2 conv instead of MaxPool for more capacity
                layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False))
                layers.append(nn.BatchNorm2d(out_ch))
                layers.append(nn.GELU())
            ch = out_ch
        self.model       = nn.Sequential(*layers)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # classifier
        mlp = []
        in_feat = ch
        for u in mlp_units:
            mlp += [nn.Linear(in_feat, u),
                    nn.BatchNorm1d(u),
                    nn.GELU(),
                    nn.Dropout(0.4)]
            in_feat = u
        mlp.append(nn.Linear(in_feat, num_classes))
        self.classifier = nn.Sequential(*mlp)

        self.apply(self._init_weights)
        self.name = 'Attentive CNN'
    
    @staticmethod
    def _init_weights(m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if getattr(m, 'bias', None) is not None: nn.init.zeros_(m.bias)
        if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
    def forward(self, x):
        x = self.stem(x)
        x = self.model(x)
        x = self.global_pool(x).view(x.size(0), -1)
        return self.classifier(x)
    

blocks = [
    # (out_ch, downsample, CBAM, Residual, SE)
    (64,   True,  False, True,  True),   # light attention, residuals on
    (128,  True,  False, True,  True),
    (256,  True,  True,  True,  True),   # start CBAM here
    (384,  True,  True,  True,  True),
]
# add more models to compare here
models = { 'attentive': ShroomCNNAttentive(in_ch=3, block_cfgs=blocks,
                    mlp_units=[1024, 512], num_classes=num_classes)}

In [5]:

class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization (SAM) optimizer wrapper.
    Wraps any base optimizer (e.g. SGD or AdamW) to perform the two-step SAM update.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        """
        params         : iterable of parameters to optimize
        base_optimizer : torch.optim.Optimizer class (not instance), e.g. torch.optim.SGD
        rho            : SAM neighborhood size
        kwargs         : arguments for the base optimizer (lr, momentum, weight_decay, etc.)
        """
        assert rho >= 0.0, "rho must be non-negative"
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # instantiate your base optimizer with the same param groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self, zero_grad=True):
        """
        1) Ascent step: move to the worst‐case neighbor w + ε
        """
        # 1a) Compute the L2 norm of all gradients
        grad_norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2)
                for group in self.param_groups for p in group['params']
                if p.grad is not None
            ]),
            p=2
        )
        scale = self.param_groups[0]['rho'] / (grad_norm + 1e-12)

        # 1b) Perturb each parameter by ε = scale * grad
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)                  # w = w + ε
                self.state[p]['eps'] = e_w   # store ε for the second step

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=True):
        """
        2) Descent step: take an optimizer step from the perturbed weights,
           then restore w ← w + ε − ε = the new w.
        """
        # 2a) descent on the perturbed weights
        self.base_optimizer.step()

        # 2b) subtract ε to return to the updated original weights
        for group in self.param_groups:
            for p in group['params']:
                eps = self.state[p].get('eps')
                if eps is None:
                    continue
                p.sub_(eps)

        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """
        We don’t use this—call first_step() and second_step() explicitly.
        """
        raise NotImplementedError("Use first_step() and second_step() instead")

In [6]:
def train(model,
          train_loader,
          val_loader,
          epochs: int = 50,
          lr: float = 1e-3,
          weight_decay: float = 5e-4):
    """
    Train `model` with SAM, logging per-epoch loss & accuracy.
    Returns: (train_loss_hist, train_acc_hist, val_loss_hist, val_acc_hist)
    """
    # 1) Device setup
    device = 'mps' if torch.backends.mps.is_available() else \
             'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[Training] using device = {device}")
    model.to(device)

    # 2) Loss, optimizer, scheduler
    loss_fn = nn.CrossEntropyLoss()
    base_opt = optim.AdamW
    optimizer = SAM(
        model.parameters(),
        base_optimizer=base_opt,
        rho=0.06,
        lr=lr,
        weight_decay=weight_decay
    )


    scheduler = ReduceLROnPlateau(
        optimizer.base_optimizer,   # <- important with SAM
        mode='min',
        factor=0.5,                 # LR *= 0.5  (try 0.2–0.5)
        patience=3,                 # epochs with no val improvement before reducing
        threshold=1e-3,             # “min improvement” to count as progress
        cooldown=0,
        min_lr=1e-6
    )

    # 3) History buffers
    train_loss_hist, train_acc_hist = [], []
    val_loss_hist,   val_acc_hist   = [], []

    # 4) Epoch loop
    for epoch in range(1, epochs+1):
        model.train()
        running_loss, running_correct, running_total = 0.0, 0, 0
        i = 0
        for X, y in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            X, y = X.to(device), y.to(device)
            # ---- step 1: forward & backward on current weights ----
            logits = model(X)
            loss   = loss_fn(logits, y)
            loss.backward()
            # record train-batch accuracy
            preds = logits.argmax(dim=1)
            running_correct += (preds == y).sum().item()
            running_total   += y.size(0)

            optimizer.first_step(zero_grad=True)

            # ---- step 2: forward & backward on perturbed weights ----
            logits2 = model(X)
            loss2   = loss_fn(logits2, y)
            loss2.backward()
            optimizer.second_step(zero_grad=True)

            running_loss += loss2.item() * y.size(0)
            # print(f"Epoch {epoch}: loss={running_loss/running_total:.4f}, "
            #     f"acc={running_correct/running_total:.4f}")


        epoch_loss = running_loss / running_total
        epoch_acc  = running_correct / running_total
        train_loss_hist.append(epoch_loss)
        train_acc_hist.append(epoch_acc)

        # ---- validation ----
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for Xv, yv in tqdm(val_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
                Xv, yv = Xv.to(device), yv.to(device)
                out = model(Xv)
                val_loss += loss_fn(out, yv).item() * yv.size(0)
                val_correct += (out.argmax(dim=1) == yv).sum().item()
                val_total   += yv.size(0)

        val_loss /= val_total
        scheduler.step(val_loss)
        val_acc   = val_correct / val_total
        val_loss_hist.append(val_loss)
        val_acc_hist.append(val_acc)
        curr_lr = optimizer.base_optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch:3d}/{epochs:3d} | "
              f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | "
              f"LR: {curr_lr:.6g}")
        
 

    return train_loss_hist, train_acc_hist, val_loss_hist, val_acc_hist

In [7]:
def train_models(models, train_loader, val_loader, epochs = 20, lr = 0.001):

    ''' trains a dictionary of models with the given params'''
    
    model_history = {}
    for model in models.values():
        start_time = time.time()
        print('######################################################')
        print(f'###         TRAINING MODEL: {model.name}          ###')
        print('######################################################')

        history = train(model,
                train_loader,
                val_loader,
                epochs=epochs,
                lr=lr)
        
        model_history[model.name] = history

        print('######################################################')
        print(f'##         TIME TO TRAIN MODEL:{model.name}       ###')
        print(f'###     {(time.time() - start_time)/60} MINUTES   ###')
        print('######################################################')

    return model_history

        


In [8]:
def evaluate_models(models, test_loader, classes = idx2label.values()):

    ''' evaluate the models on never seen test data'''

    for model in models.values():
        print('######################################################')
        print(f'###         EVALUATING MODEL: {model.name}          ###')
        print('######################################################')
        
        # set device
        device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
        model.to(device)
        model.eval()

        # collect predictions and true labels to calculate metrics
        y_true = []
        y_pred = []
        y_logits = []

        with torch.no_grad():
            # load data
            for X, y in test_loader:
                # to gpu
                X = X.to(device)
                y = y.to(device)

                # make predictions
                logits = model(X)

                # get most confident
                preds = logits.argmax(dim=1)

                y_logits.append(logits.cpu())
                y_pred.extend(preds.cpu().numpy())
                y_true.extend(y.cpu().numpy())

        # convert stacked logits
        y_logits = torch.cat(y_logits).numpy()

        # top-k accuracy
        top1_acc = accuracy_score(y_true, y_pred)
        top5_acc = top_k_accuracy_score(y_true, y_logits, k=5, labels=list(range(len(classes))))  # labels=class range

        # other classification metrics from sklearn
        f1 = f1_score(y_true, y_pred, average='weighted')
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')

        print(f"\nTest Accuracy (Top-1): {top1_acc:.4f}")
        print(f"Top-5 Accuracy: {top5_acc:.4f}")
        print(f"F1 Score (weighted): {f1:.4f}")
        print(f"Precision (weighted): {precision:.4f}")
        print(f"Recall (weighted): {recall:.4f}")

        # per-class performance
        print("\nPer-Class Report:")
        cl_report = classification_report(y_true, y_pred, target_names=classes)
        print(cl_report)
    return (top1_acc, top5_acc, f1, precision, recall, cl_report)

In [9]:
def compare_models(models, train_loader, val_loader, test_loader, epochs=50, lr=0.001, class_names=None):

    # train all models
    model_history = train_models(models, train_loader, val_loader, epochs=epochs, lr=lr)

    # evaluate all models
    eval_results = {}
    for name, model in models.items():
        print("\n\n")
        results = evaluate_models({name: model}, test_loader, classes=class_names)
        eval_results[name] = results

    # plot training curves
    fig, axs = plt.subplots(2, 1, figsize=(10, 10))
    for model_name, history in model_history.items():
        train_loss, train_acc, val_loss, val_acc = history
        axs[0].plot(train_loss, label=f'{model_name} - Train' )
        axs[0].plot(val_loss, label=f'{model_name} - Val',marker='o', markersize=4 )
        axs[1].plot(train_acc, label=f'{model_name} - Train')
        axs[1].plot(val_acc, label=f'{model_name} - Val', marker='o', markersize=4)

    axs[0].set_title('Loss Curves')
    axs[0].set_xlabel('Epoch')
    axs[0].set_ylabel('Loss')
    axs[0].legend()

    axs[1].set_title('Accuracy Curves')
    axs[1].set_xlabel('Epoch')
    axs[1].set_ylabel('Accuracy')
    axs[1].legend()

    plt.tight_layout()
    plt.show()

    # return eval results and training history for data logging
    return model_history, eval_results

In [None]:
num_epochs = 25
lr = 0.01

model_history, eval_results = compare_models(models, train_loader, val_loader, test_loader,
                                             epochs=num_epochs, lr=lr, class_names=idx2label.values())

######################################################
###         TRAINING MODEL: Attentive CNN          ###
######################################################
[Training] using device = mps


Epoch 1/25:   0%|          | 0/21548 [00:23<?, ?it/s]

Epoch 1/25:   0%|          | 0/488 [00:08<?, ?it/s]

Epoch   1/ 25 | Train Loss: 4.4582, Train Acc: 0.0705 | Val Loss: 3.5042, Val Acc: 0.1816 | LR: 0.01


Epoch 2/25:   0%|          | 0/21548 [00:00<?, ?it/s]

Epoch 2/25:   0%|          | 0/488 [00:00<?, ?it/s]

Epoch   2/ 25 | Train Loss: 3.6754, Train Acc: 0.1767 | Val Loss: 2.9971, Val Acc: 0.2734 | LR: 0.01


Epoch 3/25:   0%|          | 0/21548 [00:00<?, ?it/s]