In [1]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.4
MIXUP_PROB = 0.3
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.05

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(2048, num_classes)
    )
    # 冻结前两层
    for param in m.layer1.parameters():
        param.requires_grad = True
    for param in m.layer2.parameters():
        param.requires_grad = True
    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.2),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)
opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop()

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom9.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom10.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom10.pth")

if __name__ == "__main__":
    train()


  from .autonotebook import tqdm as notebook_tqdm


[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250707_154934-d779159b[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mdog-12[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/bdbvti0qailz65eqx85d0[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.2),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 507/507 [02:10<00:00,  3.89it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 63/63 [00:11<00:00,  5.32it/s]


E1/30 | TL 0.766 TA 0.849 | VL 0.910 VA 0.879
  ✔ Save best 0.879407001370375


Epoch 2 Train: 100%|██████████| 507/507 [01:45<00:00,  4.78it/s]
Epoch 2 Val: 100%|██████████| 63/63 [00:07<00:00,  8.45it/s]


E2/30 | TL 0.781 TA 0.820 | VL 0.900 VA 0.878


Epoch 3 Train: 100%|██████████| 507/507 [01:45<00:00,  4.78it/s]
Epoch 3 Val: 100%|██████████| 63/63 [00:07<00:00,  8.40it/s]


E3/30 | TL 0.762 TA 0.844 | VL 0.918 VA 0.878


Epoch 4 Train: 100%|██████████| 507/507 [01:45<00:00,  4.78it/s]
Epoch 4 Val: 100%|██████████| 63/63 [00:07<00:00,  8.55it/s]


E4/30 | TL 0.803 TA 0.830 | VL 0.902 VA 0.880
  ✔ Save best 0.8799053195465305


Epoch 5 Train: 100%|██████████| 507/507 [01:46<00:00,  4.78it/s]
Epoch 5 Val: 100%|██████████| 63/63 [00:07<00:00,  8.49it/s]


E5/30 | TL 0.785 TA 0.808 | VL 0.905 VA 0.880
  ✔ Save best 0.8800298990905693


Epoch 6 Train: 100%|██████████| 507/507 [01:46<00:00,  4.76it/s]
Epoch 6 Val: 100%|██████████| 63/63 [00:07<00:00,  8.51it/s]


E6/30 | TL 0.791 TA 0.822 | VL 0.893 VA 0.881
  ✔ Save best 0.8811511149869191


Epoch 7 Train: 100%|██████████| 507/507 [01:46<00:00,  4.78it/s]
Epoch 7 Val: 100%|██████████| 63/63 [00:07<00:00,  8.51it/s]


E7/30 | TL 0.765 TA 0.825 | VL 0.903 VA 0.880


Epoch 8 Train: 100%|██████████| 507/507 [01:45<00:00,  4.79it/s]
Epoch 8 Val: 100%|██████████| 63/63 [00:07<00:00,  8.35it/s]


E8/30 | TL 0.731 TA 0.840 | VL 0.900 VA 0.882
  ✔ Save best 0.8818985922511524


Epoch 9 Train: 100%|██████████| 507/507 [01:46<00:00,  4.78it/s]
Epoch 9 Val: 100%|██████████| 63/63 [00:07<00:00,  8.32it/s]


E9/30 | TL 0.746 TA 0.845 | VL 0.896 VA 0.880


Epoch 10 Train: 100%|██████████| 507/507 [01:46<00:00,  4.77it/s]
Epoch 10 Val: 100%|██████████| 63/63 [00:07<00:00,  8.52it/s]


E10/30 | TL 0.739 TA 0.859 | VL 0.905 VA 0.884
  ✔ Save best 0.8842656035878909


Epoch 11 Train: 100%|██████████| 507/507 [01:46<00:00,  4.76it/s]
Epoch 11 Val: 100%|██████████| 63/63 [00:07<00:00,  8.38it/s]


E11/30 | TL 0.787 TA 0.830 | VL 0.905 VA 0.883


Epoch 12 Train:  41%|████      | 207/507 [00:44<01:04,  4.63it/s]


KeyboardInterrupt: 

In [3]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 192
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.35
MIXUP_PROB = 0.25
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.05

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.4),
        nn.Linear(2048, num_classes)
    )
    # 冻结前两层
    for param in m.layer1.parameters():
        param.requires_grad = True
    for param in m.layer2.parameters():
        param.requires_grad = True
    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.2),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)
opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5, weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop()

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom8.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom10.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom10.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250707_161220-c5c1effa[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mdog-12[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/3jbm15i658ielbe1p7gm4[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.2),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 42/42 [00:07<00:00,  5.41it/s]


E1/30 | TL 0.695 TA 0.870 | VL 0.864 VA 0.893
  ✔ Save best 0.8929861716706117


Epoch 2 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
Epoch 2 Val: 100%|██████████| 42/42 [00:07<00:00,  5.50it/s]


E2/30 | TL 0.677 TA 0.847 | VL 0.873 VA 0.888


Epoch 3 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
Epoch 3 Val: 100%|██████████| 42/42 [00:07<00:00,  5.37it/s]


E3/30 | TL 0.674 TA 0.875 | VL 0.875 VA 0.890


Epoch 4 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
Epoch 4 Val: 100%|██████████| 42/42 [00:07<00:00,  5.51it/s]


E4/30 | TL 0.726 TA 0.855 | VL 0.880 VA 0.888


Epoch 5 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
Epoch 5 Val: 100%|██████████| 42/42 [00:07<00:00,  5.43it/s]


E5/30 | TL 0.723 TA 0.882 | VL 0.882 VA 0.889


Epoch 6 Train: 100%|██████████| 338/338 [01:45<00:00,  3.19it/s]
Epoch 6 Val: 100%|██████████| 42/42 [00:07<00:00,  5.39it/s]


E6/30 | TL 0.686 TA 0.877 | VL 0.880 VA 0.889


Epoch 7 Train: 100%|██████████| 338/338 [01:45<00:00,  3.21it/s]
Epoch 7 Val: 100%|██████████| 42/42 [00:07<00:00,  5.36it/s]


E7/30 | TL 0.675 TA 0.879 | VL 0.884 VA 0.890


Epoch 8 Train: 100%|██████████| 338/338 [01:45<00:00,  3.20it/s]
Epoch 8 Val: 100%|██████████| 42/42 [00:07<00:00,  5.42it/s]


E8/30 | TL 0.643 TA 0.866 | VL 0.881 VA 0.892
Early stop!


In [5]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 224
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.35
MIXUP_PROB = 0.25
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.05

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(2048, num_classes)
    )
    # 冻结前两层
    for param in m.layer1.parameters():
        param.requires_grad = True
    for param in m.layer2.parameters():
        param.requires_grad = True
    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.1),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)
opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop()

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom10.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom11.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom11.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250707_163655-eccd019e[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mpig-13[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/x47544ky0ok1ra44y32bt[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.1),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 290/290 [01:46<00:00,  2.72it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 36/36 [00:07<00:00,  4.60it/s]


E1/30 | TL 0.705 TA 0.862 | VL 0.857 VA 0.893
  ✔ Save best 0.8933599103027283


Epoch 2 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 2 Val: 100%|██████████| 36/36 [00:07<00:00,  4.63it/s]


E2/30 | TL 0.703 TA 0.846 | VL 0.862 VA 0.892


Epoch 3 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 3 Val: 100%|██████████| 36/36 [00:07<00:00,  4.73it/s]


E3/30 | TL 0.689 TA 0.877 | VL 0.857 VA 0.894
  ✔ Save best 0.8939828080229226


Epoch 4 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 4 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E4/30 | TL 0.661 TA 0.852 | VL 0.859 VA 0.895
  ✔ Save best 0.8946057057431169


Epoch 5 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 5 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E5/30 | TL 0.690 TA 0.884 | VL 0.858 VA 0.894


Epoch 6 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 6 Val: 100%|██████████| 36/36 [00:07<00:00,  4.76it/s]


E6/30 | TL 0.679 TA 0.831 | VL 0.858 VA 0.895


Epoch 7 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 7 Val: 100%|██████████| 36/36 [00:07<00:00,  4.60it/s]


E7/30 | TL 0.639 TA 0.913 | VL 0.861 VA 0.895
  ✔ Save best 0.8948548648311947


Epoch 8 Train: 100%|██████████| 290/290 [01:44<00:00,  2.76it/s]
Epoch 8 Val: 100%|██████████| 36/36 [00:07<00:00,  4.72it/s]


E8/30 | TL 0.652 TA 0.884 | VL 0.858 VA 0.894
Early stop!


In [8]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 224
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.32
MIXUP_PROB = 0.22
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.02

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(2048, num_classes)
    )
    # 冻结前两层
    for param in m.layer1.parameters():
        param.requires_grad = True
    for param in m.layer2.parameters():
        param.requires_grad = True
    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)
param_groups = [
    {"params": model.layer1.parameters(), "lr": 1e-6},
    {"params": model.layer2.parameters(), "lr": 5e-6},
    {"params": model.layer3.parameters(), "lr": 1e-5},
    {"params": model.layer4.parameters(), "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
]

opt = optim.AdamW(param_groups, weight_decay=1e-4)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop()

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom11.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom12.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom12.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: \ Waiting for the swanlab cloud response.

[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250707_170453-6e788643[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mcat-14[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/2ue53lgkqhtziaeb0m0je[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 36/36 [00:07<00:00,  4.71it/s]


E1/30 | TL 0.419 TA 0.834 | VL 0.638 VA 0.894
  ✔ Save best 0.8942319671110004


Epoch 2 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 2 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E2/30 | TL 0.417 TA 0.871 | VL 0.640 VA 0.895
  ✔ Save best 0.8948548648311947


Epoch 3 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 3 Val: 100%|██████████| 36/36 [00:07<00:00,  4.79it/s]


E3/30 | TL 0.381 TA 0.895 | VL 0.641 VA 0.895


Epoch 4 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 4 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E4/30 | TL 0.437 TA 0.867 | VL 0.642 VA 0.895
  ✔ Save best 0.8953531830073502


Epoch 5 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 5 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E5/30 | TL 0.431 TA 0.858 | VL 0.644 VA 0.894


Epoch 6 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 6 Val: 100%|██████████| 36/36 [00:07<00:00,  4.55it/s]


E6/30 | TL 0.399 TA 0.898 | VL 0.648 VA 0.896
  ✔ Save best 0.8959760807275445


Epoch 7 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 7 Val: 100%|██████████| 36/36 [00:07<00:00,  4.78it/s]


E7/30 | TL 0.390 TA 0.884 | VL 0.644 VA 0.896


Epoch 8 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 8 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E8/30 | TL 0.353 TA 0.918 | VL 0.643 VA 0.895
Early stop!


In [12]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile
from torch.optim.swa_utils import AveragedModel, SWALR

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 224
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.32
MIXUP_PROB = 0.22
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.02

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(2048, num_classes)
    )
    # 冻结前两层
    for param in m.layer1.parameters():
        param.requires_grad = True
    for param in m.layer2.parameters():
        param.requires_grad = True
    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)

param_groups = [
    {"params": model.layer1.parameters(), "lr": 1e-6},
    {"params": model.layer2.parameters(), "lr": 5e-6},
    {"params": model.layer3.parameters(), "lr": 1e-5},
    {"params": model.layer4.parameters(), "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
]


opt = optim.AdamW(param_groups, weight_decay=1e-4)
# SWA 设置
swa_start = 15  # 第几轮开始SWA，可调
swa_model = AveragedModel(model)
swa_scheduler = SWALR(opt, swa_lr=1e-5)



sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop(p=99)

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom12.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    use_swa = False  # 开关，在第 swa_start 轮开启
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        # SWA 更新
        if ep >= swa_start:
            if not use_swa:
                print(f"🔁 SWA started from epoch {ep}")
                use_swa = True
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            if use_swa:
                torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
            else:
                torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom13.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    if use_swa:
        torch.optim.swa_utils.update_bn(train_ld, swa_model, device=DEVICE)
        torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
        print("📦 SWA model saved.")
    else:
        torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom13.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: \ Waiting for the swanlab cloud response.

[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250707_174719-fd2802a2[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mdog-12[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/zk95qa07y3kfq4vrtkq47[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 36/36 [00:08<00:00,  4.50it/s]


E1/30 | TL 0.379 TA 0.895 | VL 0.638 VA 0.894
  ✔ Save best 0.8937336489348449


Epoch 2 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 2 Val: 100%|██████████| 36/36 [00:07<00:00,  4.56it/s]


E2/30 | TL 0.401 TA 0.892 | VL 0.637 VA 0.895
  ✔ Save best 0.8947302852871558


Epoch 3 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 3 Val: 100%|██████████| 36/36 [00:07<00:00,  4.56it/s]


E3/30 | TL 0.385 TA 0.893 | VL 0.649 VA 0.895
  ✔ Save best 0.8952286034633113


Epoch 4 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 4 Val: 100%|██████████| 36/36 [00:07<00:00,  4.74it/s]


E4/30 | TL 0.384 TA 0.889 | VL 0.647 VA 0.895
  ✔ Save best 0.8954777625513891


Epoch 5 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 5 Val: 100%|██████████| 36/36 [00:07<00:00,  4.63it/s]


E5/30 | TL 0.444 TA 0.869 | VL 0.651 VA 0.895


Epoch 6 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 6 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E6/30 | TL 0.436 TA 0.870 | VL 0.639 VA 0.895


Epoch 7 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 7 Val: 100%|██████████| 36/36 [00:07<00:00,  4.66it/s]


E7/30 | TL 0.386 TA 0.886 | VL 0.639 VA 0.896
  ✔ Save best 0.8956023420954279


Epoch 8 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 8 Val: 100%|██████████| 36/36 [00:07<00:00,  4.72it/s]


E8/30 | TL 0.379 TA 0.894 | VL 0.647 VA 0.896
  ✔ Save best 0.8962252398156223


Epoch 9 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 9 Val: 100%|██████████| 36/36 [00:07<00:00,  4.68it/s]


E9/30 | TL 0.399 TA 0.894 | VL 0.641 VA 0.898
  ✔ Save best 0.8977201943440887


Epoch 10 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 10 Val: 100%|██████████| 36/36 [00:08<00:00,  4.46it/s]


E10/30 | TL 0.363 TA 0.890 | VL 0.632 VA 0.897


Epoch 11 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 11 Val: 100%|██████████| 36/36 [00:07<00:00,  4.54it/s]


E11/30 | TL 0.407 TA 0.861 | VL 0.644 VA 0.896


Epoch 12 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 12 Val: 100%|██████████| 36/36 [00:07<00:00,  4.58it/s]


E12/30 | TL 0.403 TA 0.883 | VL 0.644 VA 0.898


Epoch 13 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 13 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E13/30 | TL 0.446 TA 0.872 | VL 0.641 VA 0.896


Epoch 14 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 14 Val: 100%|██████████| 36/36 [00:07<00:00,  4.74it/s]


E14/30 | TL 0.381 TA 0.857 | VL 0.636 VA 0.898


Epoch 15 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 15 Val: 100%|██████████| 36/36 [00:07<00:00,  4.57it/s]


🔁 SWA started from epoch 15
E15/30 | TL 0.390 TA 0.864 | VL 0.632 VA 0.897


Epoch 16 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 16 Val: 100%|██████████| 36/36 [00:07<00:00,  4.53it/s]


E16/30 | TL 0.408 TA 0.898 | VL 0.630 VA 0.899
  ✔ Save best 0.8988414102404385


Epoch 17 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 17 Val: 100%|██████████| 36/36 [00:07<00:00,  4.63it/s]


E17/30 | TL 0.408 TA 0.895 | VL 0.636 VA 0.898


Epoch 18 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 18 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E18/30 | TL 0.386 TA 0.916 | VL 0.640 VA 0.898


Epoch 19 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 19 Val: 100%|██████████| 36/36 [00:07<00:00,  4.68it/s]


E19/30 | TL 0.421 TA 0.839 | VL 0.632 VA 0.897


Epoch 20 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 20 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E20/30 | TL 0.423 TA 0.882 | VL 0.629 VA 0.899
  ✔ Save best 0.8992151488725552


Epoch 21 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 21 Val: 100%|██████████| 36/36 [00:07<00:00,  4.58it/s]


E21/30 | TL 0.397 TA 0.878 | VL 0.635 VA 0.897


Epoch 22 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 22 Val: 100%|██████████| 36/36 [00:07<00:00,  4.65it/s]


E22/30 | TL 0.402 TA 0.900 | VL 0.633 VA 0.897


Epoch 23 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 23 Val: 100%|██████████| 36/36 [00:07<00:00,  4.64it/s]


E23/30 | TL 0.394 TA 0.879 | VL 0.637 VA 0.895


Epoch 24 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 24 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E24/30 | TL 0.398 TA 0.888 | VL 0.644 VA 0.895


Epoch 25 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 25 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E25/30 | TL 0.335 TA 0.905 | VL 0.659 VA 0.894


Epoch 26 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 26 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E26/30 | TL 0.407 TA 0.893 | VL 0.646 VA 0.895


Epoch 27 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 27 Val: 100%|██████████| 36/36 [00:07<00:00,  4.53it/s]


E27/30 | TL 0.415 TA 0.861 | VL 0.637 VA 0.894


Epoch 28 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 28 Val: 100%|██████████| 36/36 [00:07<00:00,  4.55it/s]


E28/30 | TL 0.358 TA 0.886 | VL 0.650 VA 0.894


Epoch 29 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 29 Val: 100%|██████████| 36/36 [00:07<00:00,  4.62it/s]


E29/30 | TL 0.433 TA 0.885 | VL 0.632 VA 0.896


Epoch 30 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 30 Val: 100%|██████████| 36/36 [00:07<00:00,  4.51it/s]


E30/30 | TL 0.435 TA 0.892 | VL 0.632 VA 0.895
📦 SWA model saved.


In [1]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile
from torch.optim.swa_utils import AveragedModel, SWALR

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 224
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.15
MIXUP_PROB = 0.15
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.05

# ---------------------------- SE ResNet50 + Dropout ----------------------------
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.pool(x).view(b, c)
        w = self.fc(w).view(b, c, 1, 1)
        return x * w

class SEBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        self.se = SEBlock(bottleneck.conv3.out_channels)

    def forward(self, x):
        return self.se(self.body(x))

def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        setattr(m, name, nn.Sequential(*[SEBottleneck(b) for b in getattr(m, name)]))
    m.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(2048, num_classes)
    )
    
    for param in m.parameters():
        param.requires_grad = True

    return m

# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)




opt = optim.AdamW(param_groups, weight_decay=1e-4)
# SWA 设置
swa_start = 12  # 第几轮开始SWA，可调
swa_model = AveragedModel(model)
swa_scheduler = SWALR(opt, swa_lr=1e-5)



sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop(p=99)

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom14.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    use_swa = False  # 开关，在第 swa_start 轮开启
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        # SWA 更新
        if ep >= swa_start:
            if not use_swa:
                print(f"🔁 SWA started from epoch {ep}")
                use_swa = True
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            if use_swa:
                torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
            else:
                torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom15.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    if use_swa:
        torch.optim.swa_utils.update_bn(train_ld, swa_model, device=DEVICE)
        torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
        print("📦 SWA model saved.")
    else:
        torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom15.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.6.1                                   
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/swanlog/run-20250708_043714-5f256344[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mSZY_230507[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mdog-12[0m to the cloud
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/1seep8rhaz5uuyj3b1zpf[0m[0m


  A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
  scaler = GradScaler()
  with autocast():
Epoch 1 Train: 100%|██████████| 290/290 [01:55<00:00,  2.51it/s]
  with autocast():
Epoch 1 Val: 100%|██████████| 36/36 [00:09<00:00,  3.95it/s]


E1/50 | TL 0.579 TA 0.922 | VL 0.853 VA 0.897
  ✔ Save best 0.8968481375358166


Epoch 2 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 2 Val: 100%|██████████| 36/36 [00:07<00:00,  4.64it/s]


E2/50 | TL 0.592 TA 0.873 | VL 0.861 VA 0.896


Epoch 3 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 3 Val: 100%|██████████| 36/36 [00:07<00:00,  4.74it/s]


E3/50 | TL 0.583 TA 0.872 | VL 0.862 VA 0.895


Epoch 4 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 4 Val: 100%|██████████| 36/36 [00:07<00:00,  4.68it/s]


E4/50 | TL 0.592 TA 0.887 | VL 0.860 VA 0.897


Epoch 5 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 5 Val: 100%|██████████| 36/36 [00:07<00:00,  4.80it/s]


E5/50 | TL 0.587 TA 0.900 | VL 0.865 VA 0.895


Epoch 6 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 6 Val: 100%|██████████| 36/36 [00:07<00:00,  4.88it/s]


E6/50 | TL 0.605 TA 0.875 | VL 0.868 VA 0.895


Epoch 7 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 7 Val: 100%|██████████| 36/36 [00:07<00:00,  4.77it/s]


E7/50 | TL 0.611 TA 0.882 | VL 0.860 VA 0.895


Epoch 8 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 8 Val: 100%|██████████| 36/36 [00:07<00:00,  4.64it/s]


E8/50 | TL 0.572 TA 0.891 | VL 0.866 VA 0.895


Epoch 9 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 9 Val: 100%|██████████| 36/36 [00:07<00:00,  4.71it/s]


E9/50 | TL 0.577 TA 0.916 | VL 0.860 VA 0.896


Epoch 10 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 10 Val: 100%|██████████| 36/36 [00:07<00:00,  4.73it/s]


E10/50 | TL 0.589 TA 0.886 | VL 0.860 VA 0.896


Epoch 11 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 11 Val: 100%|██████████| 36/36 [00:07<00:00,  4.63it/s]


E11/50 | TL 0.595 TA 0.878 | VL 0.857 VA 0.896


Epoch 12 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 12 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E12/50 | TL 0.607 TA 0.879 | VL 0.861 VA 0.896


Epoch 13 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 13 Val: 100%|██████████| 36/36 [00:08<00:00,  4.22it/s]


E13/50 | TL 0.626 TA 0.874 | VL 0.865 VA 0.896


Epoch 14 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 14 Val: 100%|██████████| 36/36 [00:07<00:00,  4.74it/s]


E14/50 | TL 0.620 TA 0.873 | VL 0.858 VA 0.897


Epoch 15 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 15 Val: 100%|██████████| 36/36 [00:07<00:00,  4.65it/s]


E15/50 | TL 0.569 TA 0.888 | VL 0.863 VA 0.897


Epoch 16 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 16 Val: 100%|██████████| 36/36 [00:08<00:00,  4.44it/s]


E16/50 | TL 0.605 TA 0.883 | VL 0.858 VA 0.897
  ✔ Save best 0.8970972966238944


Epoch 17 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 17 Val: 100%|██████████| 36/36 [00:07<00:00,  4.69it/s]


E17/50 | TL 0.588 TA 0.887 | VL 0.872 VA 0.898
  ✔ Save best 0.8979693534321664


Epoch 18 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 18 Val: 100%|██████████| 36/36 [00:07<00:00,  4.55it/s]


E18/50 | TL 0.608 TA 0.900 | VL 0.864 VA 0.896


Epoch 19 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 19 Val: 100%|██████████| 36/36 [00:07<00:00,  4.62it/s]


E19/50 | TL 0.539 TA 0.896 | VL 0.870 VA 0.896


Epoch 20 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 20 Val: 100%|██████████| 36/36 [00:07<00:00,  4.56it/s]


E20/50 | TL 0.562 TA 0.914 | VL 0.866 VA 0.895


Epoch 21 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 21 Val: 100%|██████████| 36/36 [00:07<00:00,  4.67it/s]


E21/50 | TL 0.598 TA 0.885 | VL 0.861 VA 0.896


Epoch 22 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 22 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E22/50 | TL 0.589 TA 0.914 | VL 0.869 VA 0.896


Epoch 23 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 23 Val: 100%|██████████| 36/36 [00:07<00:00,  4.73it/s]


E23/50 | TL 0.597 TA 0.877 | VL 0.870 VA 0.895


Epoch 24 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 24 Val: 100%|██████████| 36/36 [00:07<00:00,  4.77it/s]


E24/50 | TL 0.560 TA 0.879 | VL 0.866 VA 0.894


Epoch 25 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 25 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E25/50 | TL 0.589 TA 0.873 | VL 0.861 VA 0.897


Epoch 26 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 26 Val: 100%|██████████| 36/36 [00:07<00:00,  4.59it/s]


E26/50 | TL 0.625 TA 0.865 | VL 0.867 VA 0.898


Epoch 27 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 27 Val: 100%|██████████| 36/36 [00:07<00:00,  4.54it/s]


E27/50 | TL 0.638 TA 0.882 | VL 0.861 VA 0.896


Epoch 28 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 28 Val: 100%|██████████| 36/36 [00:07<00:00,  4.64it/s]


E28/50 | TL 0.624 TA 0.881 | VL 0.864 VA 0.897


Epoch 29 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 29 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


E29/50 | TL 0.620 TA 0.868 | VL 0.864 VA 0.899
  ✔ Save best 0.8985922511523607


Epoch 30 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 30 Val: 100%|██████████| 36/36 [00:07<00:00,  4.61it/s]


🔁 SWA started from epoch 30
E30/50 | TL 0.595 TA 0.880 | VL 0.862 VA 0.898


Epoch 31 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 31 Val: 100%|██████████| 36/36 [00:07<00:00,  4.52it/s]


E31/50 | TL 0.579 TA 0.903 | VL 0.865 VA 0.898


Epoch 32 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 32 Val: 100%|██████████| 36/36 [00:07<00:00,  4.57it/s]


E32/50 | TL 0.575 TA 0.893 | VL 0.856 VA 0.899
  ✔ Save best 0.899339728416594


Epoch 33 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 33 Val: 100%|██████████| 36/36 [00:07<00:00,  4.62it/s]


E33/50 | TL 0.615 TA 0.879 | VL 0.866 VA 0.897


Epoch 34 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 34 Val: 100%|██████████| 36/36 [00:07<00:00,  4.57it/s]


E34/50 | TL 0.547 TA 0.915 | VL 0.868 VA 0.898


Epoch 35 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 35 Val: 100%|██████████| 36/36 [00:07<00:00,  4.68it/s]


E35/50 | TL 0.636 TA 0.883 | VL 0.872 VA 0.897


Epoch 36 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 36 Val: 100%|██████████| 36/36 [00:07<00:00,  4.53it/s]


E36/50 | TL 0.568 TA 0.894 | VL 0.873 VA 0.896


Epoch 37 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 37 Val: 100%|██████████| 36/36 [00:07<00:00,  4.51it/s]


E37/50 | TL 0.599 TA 0.888 | VL 0.871 VA 0.897


Epoch 38 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 38 Val: 100%|██████████| 36/36 [00:07<00:00,  4.54it/s]


E38/50 | TL 0.581 TA 0.895 | VL 0.879 VA 0.895


Epoch 39 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 39 Val: 100%|██████████| 36/36 [00:07<00:00,  4.64it/s]


E39/50 | TL 0.560 TA 0.903 | VL 0.869 VA 0.897


Epoch 40 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 40 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E40/50 | TL 0.615 TA 0.886 | VL 0.863 VA 0.898


Epoch 41 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 41 Val: 100%|██████████| 36/36 [00:07<00:00,  4.70it/s]


E41/50 | TL 0.568 TA 0.905 | VL 0.875 VA 0.898


Epoch 42 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 42 Val: 100%|██████████| 36/36 [00:07<00:00,  4.53it/s]


E42/50 | TL 0.616 TA 0.856 | VL 0.870 VA 0.896


Epoch 43 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 43 Val: 100%|██████████| 36/36 [00:07<00:00,  4.53it/s]


E43/50 | TL 0.574 TA 0.899 | VL 0.870 VA 0.896


Epoch 44 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 44 Val: 100%|██████████| 36/36 [00:07<00:00,  4.75it/s]


E44/50 | TL 0.544 TA 0.900 | VL 0.874 VA 0.896


Epoch 45 Train: 100%|██████████| 290/290 [01:45<00:00,  2.76it/s]
Epoch 45 Val: 100%|██████████| 36/36 [00:08<00:00,  4.49it/s]


E45/50 | TL 0.582 TA 0.907 | VL 0.868 VA 0.896


Epoch 46 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 46 Val: 100%|██████████| 36/36 [00:07<00:00,  4.73it/s]


E46/50 | TL 0.563 TA 0.913 | VL 0.865 VA 0.896


Epoch 47 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 47 Val: 100%|██████████| 36/36 [00:07<00:00,  4.68it/s]


E47/50 | TL 0.573 TA 0.900 | VL 0.876 VA 0.896


Epoch 48 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 48 Val: 100%|██████████| 36/36 [00:08<00:00,  4.48it/s]


E48/50 | TL 0.580 TA 0.899 | VL 0.869 VA 0.898


Epoch 49 Train: 100%|██████████| 290/290 [01:45<00:00,  2.75it/s]
Epoch 49 Val: 100%|██████████| 36/36 [00:07<00:00,  4.62it/s]


E49/50 | TL 0.560 TA 0.895 | VL 0.880 VA 0.895


Epoch 50 Train: 100%|██████████| 290/290 [01:45<00:00,  2.74it/s]
Epoch 50 Val: 100%|██████████| 36/36 [00:08<00:00,  4.47it/s]

E50/50 | TL 0.574 TA 0.873 | VL 0.877 VA 0.894





📦 SWA model saved.


In [8]:
swanlab.finish()

[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@SZY_230507/mushroom-toxicity-detection/runs/99ipcz7n12jh9pjv87wwd[0m[0m
                                                                                                    

In [9]:
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets
from torch.utils.data import DataLoader, WeightedRandomSampler
import swanlab
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from PIL import ImageFile
from torch.optim.swa_utils import AveragedModel, SWALR

ImageFile.LOAD_TRUNCATED_IMAGES = True

swanlab.init(project="mushroom-toxicity-detection2", run="se_cbam_resnet50_mushroom1")

# ---------------------------- 基础配置 ----------------------------
PROJECT = "mushroom-toxicity-detection"
RUN_NAME = "se_cbam_resnet50_v100"
DATA_DIR = Path("/workspace/mushroom_dataset_single_split")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 256
EPOCHS = 30
NUM_WORKERS = 8
MIXUP_ALPHA = 0.15
MIXUP_PROB = 0.15
TTA_SCALES = [224, 256]
LABEL_SMOOTHING = 0.05



# ---------------------------- CBAM 模块 ----------------------------
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            nn.Linear(in_planes, in_planes // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_planes // ratio, in_planes, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        return x * self.sigmoid(out).view(b, c, 1, 1)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        return x * self.sigmoid(self.conv(x_cat))

class CBAM(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x)
        x = self.sa(x)
        return x

# ---------------------------- 修改 SEBottleneck 加 CBAM ----------------------------
class SECBAMBottleneck(nn.Module):
    def __init__(self, bottleneck):
        super().__init__()
        self.body = bottleneck
        out_channels = bottleneck.conv3.out_channels
        self.se = SEBlock(out_channels)
        self.cbam = CBAM(out_channels)

    def forward(self, x):
        x = self.body(x)
        x = self.se(x)
        x = self.cbam(x)
        return x

# ---------------------------- 替换原始构建函数 ----------------------------
def build_backbone(num_classes):
    m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    for name in ["layer1", "layer2", "layer3", "layer4"]:
        blocks = [SECBAMBottleneck(b) for b in getattr(m, name)]
        setattr(m, name, nn.Sequential(*blocks))
    m.fc = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(2048, num_classes)
    )
    for param in m.parameters():
        param.requires_grad = True
    return m


# ---------------------------- 数据增强 ----------------------------
train_tf = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.HueSaturationValue(10, 15, 10, p=0.5),
    A.RandomBrightnessContrast(0.2, 0.2, p=0.5),
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16, p=0.05),
    A.GaussianBlur(5, p=0.4),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])
val_tf = A.Compose([
    A.Resize(224, 224),
    A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ---------------------------- Dataset ----------------------------
class AlbDataset(torch.utils.data.Dataset):
    def __init__(self, root, tf):
        self.ds = datasets.ImageFolder(root)
        self.tf = tf
    def __len__(self):
        return len(self.ds)
    def __getitem__(self, idx):
        p, l = self.ds.samples[idx]
        img = np.asarray(Image.open(p).convert("RGB"))
        return self.tf(image=img)["image"], l

train_ds = AlbDataset(DATA_DIR/"train", train_tf)
val_ds = AlbDataset(DATA_DIR/"val", val_tf)
num_classes = len(train_ds.ds.classes)

# -------- WeightedRandomSampler --------
labels = [l for _, l in train_ds.ds.samples]
counts = np.bincount(labels)
weights = 1.0 / counts
sample_weights = [weights[l] for l in labels]
train_sampler = WeightedRandomSampler(sample_weights, len(train_ds), replacement=True)

train_ld = DataLoader(train_ds, batch_size=BATCH_SIZE,shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=True)
val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=True)

# ---------------------------- MixUp ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, None, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# ---------------------------- Label Smoothing Loss ----------------------------
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, target):
        logprobs = self.log_softmax(x)
        with torch.no_grad():
            true_dist = torch.zeros_like(logprobs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * logprobs, dim=-1))

criterion = LabelSmoothingLoss(num_classes, smoothing=LABEL_SMOOTHING)

# ---------------------------- 优化器 & 调度 ----------------------------
model = build_backbone(num_classes).to(DEVICE)

param_groups = [
    {"params": model.conv1.parameters(), "lr": 1e-6},
    {"params": model.bn1.parameters(), "lr": 1e-6},
    {"params": model.layer1.parameters(), "lr": 1e-6},
    {"params": model.layer2.parameters(), "lr": 5e-6},
    {"params": model.layer3.parameters(), "lr": 1e-5},
    {"params": model.layer4.parameters(), "lr": 1e-5},
    {"params": model.fc.parameters(), "lr": 1e-4},
]


opt = optim.AdamW(param_groups, weight_decay=1e-4)
# SWA 设置
swa_start = 12  # 第几轮开始SWA，可调
swa_model = AveragedModel(model)
swa_scheduler = SWALR(opt, swa_lr=1e-5)



sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)
scaler = GradScaler()

class EarlyStop:
    def __init__(self, p=7, delta=0.001):
        self.p, self.d = p, delta; self.best = 1e9; self.c = 0
    def __call__(self, v):
        if v < self.best - self.d: self.best, self.c = v, 0
        else: self.c += 1
        return self.c >= self.p

estop = EarlyStop(p=99)

# 加载之前保存的最优模型
best_model_path = "/workspace/best_se_resnet50_mushroom14.pth"
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    
# ---------------------------- 训练循环 ----------------------------
def train():
    best = 0
    use_swa = False  # 开关，在第 swa_start 轮开启
    for ep in range(1, EPOCHS+1):
        model.train(); tl, tc = 0, 0
        train_pbar = tqdm(train_ld, desc=f"Epoch {ep} Train", unit="it", dynamic_ncols=True)
        for xb, yb in train_pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            lam = 1.0
            if random.random() < MIXUP_PROB:
                xb, y_a, y_b, lam = mixup_data(xb, yb)
            opt.zero_grad(set_to_none=True)
            with autocast():
                logits = model(xb)
                if lam == 1.0:
                    loss = criterion(logits, yb)
                else:
                    loss = lam*criterion(logits, y_a) + (1-lam)*criterion(logits, y_b)
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
            tl += loss.item()*xb.size(0)
            preds = logits.argmax(1)
            tc += (preds==yb).sum().item()
        train_loss = tl/len(train_ds); train_acc = tc/len(train_ds)

        model.eval(); vl, vc = 0, 0
        with torch.no_grad():
            val_pbar = tqdm(val_ld, desc=f"Epoch {ep} Val", unit="it", dynamic_ncols=True)
            for xb, yb in val_pbar:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                with autocast():
                    logits = model(xb)
                    loss = criterion(logits, yb)
                vl += loss.item()*xb.size(0)


                vc += (logits.argmax(1)== yb).sum().item()
        val_loss = vl/len(val_ds); val_acc = vc/len(val_ds)
        # SWA 更新
        if ep >= swa_start:
            if not use_swa:
                print(f"🔁 SWA started from epoch {ep}")
                use_swa = True
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            sched.step()

        print(f"E{ep}/{EPOCHS} | TL {train_loss:.3f} TA {train_acc:.3f} | VL {val_loss:.3f} VA {val_acc:.3f}")
        swanlab.log({"epoch":ep,"train_loss":train_loss,"train_acc":train_acc,"val_loss":val_loss,"val_acc":val_acc,"lr":sched.get_last_lr()[0]})

        if val_acc > best:
            best = val_acc
            if use_swa:
                torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
            else:
                torch.save(model.state_dict(), "/workspace/best_se_resnet50_mushroom15.pth")
            print("  ✔ Save best", best)
        if estop(val_loss):
            print("Early stop!"); break
    if use_swa:
        torch.optim.swa_utils.update_bn(train_ld, swa_model, device=DEVICE)
        torch.save(swa_model.module.state_dict(), "/workspace/best_swa_model.pth")
        print("📦 SWA model saved.")
    else:
        torch.save(model.state_dict(), "/workspace/last_se_resnet50_mushroom15.pth")

if __name__ == "__main__":
    train()


[1m[34mswanlab[0m[0m: \ Waiting for the swanlab cloud response.

[1m[34mswanlab[0m[0m: swanlab version 0.6.5 is available!  Upgrade: `pip install -U swanlab`    
[1m[34mswanlab[0m[0m: \ Creating experiment...                                                  

KeyboardInterrupt: 