In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock
import time

from datetime import datetime

# Create a timestamped run directory so every execution logs to its own
# TensorBoard folder instead of overwriting a previous run.
run_started_at = datetime.now()
timestamp = run_started_at.strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f"runs/FER2013_{timestamp}"
writer = SummaryWriter(log_dir=log_dir)

print(f"TensorBoard logs are being saved to: {log_dir}")

# ==================== 定义网络 ====================
# SE 模块定义
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Each channel of the input is averaged over its spatial dimensions, the
    resulting per-channel vector is passed through a two-layer bottleneck
    (Linear -> ReLU -> Linear -> Sigmoid), and the input is rescaled
    channel-wise by the resulting gates.

    Args:
        channels: number of channels of the feature map fed to ``forward``.
        reduction: bottleneck ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1 so that small
            channel counts do not produce a zero-width (degenerate) layer.
    """

    def __init__(self, channels, reduction=16):
        super(SEModule, self).__init__()
        # channels // reduction is 0 when channels < reduction, which would
        # make both linear layers empty and every gate a constant 0.5.
        hidden = max(1, channels // reduction)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, hidden, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel; expects a 4-D (N, C, H, W) tensor."""
        batch, channels, _, _ = x.size()

        # Every intermediate tensor is derived from x and the module's own
        # parameters, so it already lives on x's device; no explicit
        # device transfers are needed.
        se = self.global_avg_pool(x).view(batch, channels)  # squeeze
        se = self.relu(self.fc1(se))                        # reduce
        se = self.sigmoid(self.fc2(se))                     # expand + gate
        se = se.view(batch, channels, 1, 1)                 # broadcastable shape

        return x * se  # channel-wise reweighting

# 基于 SE 模块的 BasicBlock
class SEBasicBlock(BasicBlock):
    """``BasicBlock`` whose output is reweighted by an ``SEModule``.

    The constructor signature matches what ``ResNetUNet._make_layer`` passes:
    ``(inplanes, planes, stride, downsample)``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__(inplanes, planes, stride, downsample)
        # Channel attention sized to the block's output width.
        self.se = SEModule(planes * self.expansion)

    def forward(self, x):
        # Shortcut branch: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: conv-bn-relu followed by conv-bn.
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))

        # Residual sum, then the final activation.
        features += shortcut
        features = self.relu(features)

        # NOTE(review): the SE gate is applied after the residual addition and
        # ReLU, i.e. it rescales the whole block output rather than only the
        # residual branch -- confirm this placement is intentional.
        return self.se(features)

# ResNetUNet 使用 SE 模块
class ResNetUNet(nn.Module):
    """ResNet-style encoder with a U-Net-style decoder, used as a classifier.

    The encoder is a two-convolution stem followed by four residual stages.
    The decoder upsamples the deepest feature map four times, concatenating
    the matching encoder feature map (skip connection) at every step. The
    final decoder map is globally average-pooled, passed through dropout and
    projected to ``num_classes`` logits.

    Args:
        block: residual block class. It is instantiated as
            ``block(inplanes, planes, stride, downsample)`` and must expose an
            integer class attribute ``expansion``.
        layers: sequence of four ints, the number of blocks per encoder stage.
        num_classes: number of output logits.
        in_channels: number of channels of the input image (default 3, which
            reproduces the original hard-coded behaviour).
    """

    def __init__(self, block, layers, num_classes=7, in_channels=3):
        super(ResNetUNet, self).__init__()
        self.inplanes = 64
        # Encoder stage k outputs planes * expansion channels, so the decoder
        # widths must be scaled the same way. For expansion == 1 this yields
        # exactly the original 512/256/128/64 layout.
        expansion = block.expansion

        # Stem: two 3x3 convolutions at full input resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Despite its name this is a learnable stride-2 convolution, not a
        # max-pooling layer. The attribute name is kept so existing
        # checkpoints (state_dict keys) remain loadable.
        self.maxpool = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

        # Residual encoder stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # U-Net decoder: each step is a 2x transposed convolution followed by
        # a convolution block applied to [upsampled, skip] concatenated on
        # the channel axis.
        self.upconv4 = self._upconv(512 * expansion, 256 * expansion)
        self.dec4 = self._block(256 * expansion + 256 * expansion, 256 * expansion)

        self.upconv3 = self._upconv(256 * expansion, 128 * expansion)
        self.dec3 = self._block(128 * expansion + 128 * expansion, 128 * expansion)

        self.upconv2 = self._upconv(128 * expansion, 64 * expansion)
        self.dec2 = self._block(64 * expansion + 64 * expansion, 64 * expansion)

        # The last skip connection comes from the stem, which always has 64
        # channels regardless of the block's expansion.
        self.upconv1 = self._upconv(64 * expansion, 64)
        self.dec1 = self._block(64 + 64, 64)

        # Classification head.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one encoder stage of ``blocks`` residual blocks.

        Only the first block may change resolution or width; when it does, a
        1x1 convolution + batch norm projects the shortcut to match.
        Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _block(self, in_channels, out_channels):
        """Return a decoder block: two (3x3 conv -> batch norm -> ReLU) units."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def _upconv(self, in_channels, out_channels):
        """Return a transposed convolution that doubles the spatial size."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_classes)`` logit tensor."""
        # Encoder
        enc1 = self.conv1(x)                    # stem, full resolution
        enc2 = self.layer1(self.maxpool(enc1))  # residual stage 1
        enc3 = self.layer2(enc2)                # residual stage 2
        enc4 = self.layer3(enc3)                # residual stage 3
        enc5 = self.layer4(enc4)                # residual stage 4

        # Decoder: upsample, align the skip feature map, concatenate, refine.
        dec4 = self.upconv4(enc5)
        enc4 = self._resize(enc4, dec4)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.dec4(dec4)

        dec3 = self.upconv3(dec4)
        enc3 = self._resize(enc3, dec3)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.dec3(dec3)

        dec2 = self.upconv2(dec3)
        enc2 = self._resize(enc2, dec2)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.dec2(dec2)

        dec1 = self.upconv1(dec2)
        enc1 = self._resize(enc1, dec1)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.dec1(dec1)

        # Global pooling and classification.
        x = self.global_avg_pool(dec1)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

    def _resize(self, enc, dec):
        """Bilinearly resize ``enc`` to ``dec``'s spatial size when they differ.

        Needed because an exact 2x upsampling does not always reproduce the
        encoder's resolution (stride-2 stages round odd sizes).
        """
        if enc.shape[2:] != dec.shape[2:]:
            enc = F.interpolate(enc, size=dec.shape[2:], mode="bilinear", align_corners=False)
        return enc

# ==================== Label Smoothing Loss ====================
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed targets.

    The one-hot target is scaled by ``1 - smoothing`` and ``smoothing`` is
    spread uniformly over all classes (including the true one), so each
    target row still sums to 1.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """Return the batch-mean loss.

        Args:
            pred: raw logits; classes are read from dimension 1.
            target: integer class indices, one per row of ``pred``.
        """
        num_classes = pred.size(1)
        uniform_mass = self.smoothing / num_classes

        # Soft targets: (1 - smoothing) on the true class plus a uniform share.
        one_hot = F.one_hot(target, num_classes).to(pred.dtype)
        soft_target = one_hot * (1 - self.smoothing) + uniform_mass

        log_prob = F.log_softmax(pred, dim=1)
        per_sample = -(soft_target * log_prob).sum(dim=1)
        return per_sample.mean()

# ==================== Dataset loading ====================
print("Step 1: Loading and preparing dataset...")

# Training-time preprocessing with random augmentation (flip, rotation,
# random 44x44 crop out of the 48x48 resized image).
transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.RandomCrop(44),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Evaluation-time preprocessing: same size and normalisation but no random
# operations, so validation/test metrics are deterministic and are not
# degraded by augmentation noise.
eval_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.CenterCrop(44),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_dir = "FER-2013 dataset/train"
test_dir = "FER-2013 dataset/test"

train_dataset_full = datasets.ImageFolder(root=train_dir, transform=transform)
train_size = int(0.8 * len(train_dataset_full))
val_size = len(train_dataset_full) - train_size

# A seeded generator makes the 80/20 train/validation split reproducible
# across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split = random_split(
    train_dataset_full, [train_size, val_size], generator=split_generator
)

# random_split subsets share their parent's transform, so the validation
# subset is rebuilt on a second view of the same folder that uses the
# deterministic transform, reusing the indices chosen above.
# NOTE(review): this relies on ImageFolder listing the directory in the same
# order for both instances -- confirm for the torchvision version in use.
val_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(root=train_dir, transform=eval_transform),
    val_split.indices,
)

test_dataset = datasets.ImageFolder(root=test_dir, transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# ==================== Model definition ====================
print("Step 2: Defining the model...")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Four encoder stages with 3, 4, 6 and 3 SE residual blocks; 7 output classes.
model = ResNetUNet(SEBasicBlock, [3, 4, 6, 3], num_classes=7)
model = model.to(device)

print("Model structure:")
print(model, "\n")

# ==================== Loss function and optimizer ====================
print("Step 3: Defining loss function and optimizer...")
num_epochs = 50
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
optimizer = optim.Adam(params=model.parameters(), lr=0.0005, weight_decay=1e-4)
# Cosine annealing with warm restarts: the first cycle lasts T_0 = 10
# scheduler steps and every following cycle is T_mult = 2 times longer.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# ==================== Training loop ====================
print("Step 4: Starting training...\n")
best_val_acc = 0.0   # best validation accuracy seen so far
patience = 15        # epochs without improvement tolerated before stopping
trigger_times = 0    # consecutive epochs without improvement

try:
    for epoch in range(num_epochs):
        start_time = time.time()
        print(f"Epoch {epoch+1}/{num_epochs}")

        # ---- training pass ----
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns a batch mean; weight by batch size so the
            # epoch average is exact even when the last batch is smaller.
            train_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

        train_loss /= train_total
        train_acc = train_correct / train_total

        # ---- validation pass ----
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_loss /= val_total
        val_acc = val_correct / val_total

        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Loss/Validation", val_loss, epoch)
        writer.add_scalar("Accuracy/Train", train_acc, epoch)
        writer.add_scalar("Accuracy/Validation", val_acc, epoch)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{num_epochs} completed in {epoch_time:.2f} seconds.")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\n")

        # CosineAnnealingWarmRestarts.step() takes an optional *epoch index*,
        # not a metric. Passing val_loss (as a ReduceLROnPlateau-style call
        # would) made the scheduler treat the loss value as the position in
        # the cosine cycle, so the learning rate never followed the schedule.
        scheduler.step()

        # Checkpoint on improvement; otherwise count towards early stopping.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print("Best model updated and saved!")
            trigger_times = 0
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping triggered!")
                break
finally:
    # Flush and close the TensorBoard writer even if training is interrupted
    # (e.g. by a KeyboardInterrupt) so the scalars logged so far are kept.
    writer.close()

# ==================== Testing ====================
print("Step 5: Testing the model...")
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written during a GPU run can still be loaded on a CPU-only
# machine (and vice versa).
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()
test_correct, test_total = 0, 0

# Gradients are not needed for evaluation.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)  # index of the largest logit per row
        test_correct += (preds == labels).sum().item()
        test_total += labels.size(0)

test_acc = test_correct / test_total


print(f"Test Accuracy: {test_acc:.4f}")


TensorBoard logs are being saved to: runs/FER2013_2024-11-21_00-49-44
Step 1: Loading and preparing dataset...
Step 2: Defining the model...
Model structure:
ResNetUNet(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (maxpool): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (layer1): Sequential(
    (0): SEBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), s

KeyboardInterrupt: 