In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

In [None]:
##fer2013
import os
from PIL import Image
from torch.utils.data import Dataset
import torch
# FER2013 emotion classes; each name's position in the tuple is its integer label.
# NOTE(review): the CK+ cell further down rebinds `label_map` to a different
# mapping, so the value seen by later cells depends on which cell ran last.
label_map = {
    emotion: index
    for index, emotion in enumerate(
        ('angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral')
    )
}
class CustomFERDataset(Dataset):
    """Image-folder dataset laid out as ``<root_dir>/<subset>/<class_name>/<image>``.

    Every sub-directory of ``<root_dir>/<subset>`` is treated as one class; its
    lower-cased name is looked up in ``label_map`` to obtain the integer label.
    Only files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are
    collected.

    NOTE(review): the CK+ cell later in this notebook defines another class with
    the same name and a different constructor; whichever cell ran last wins.

    Args:
        root_dir: Directory that contains the ``subset`` folder.
        subset: Name of the split folder under ``root_dir`` (default ``'train'``).
        transform: Optional callable applied to the loaded RGB image.
        label_map: Optional ``{class_name: int}`` mapping. When falsy, the
            built-in seven-class mapping in ``default_label_map`` is used.

    Raises:
        KeyError: If a non-hidden class folder has no entry in the label map.
    """

    def __init__(self, root_dir, subset='train', transform=None, label_map=None):
        # Standard label mapping, used when the caller does not supply one.
        self.default_label_map = {
            'angry': 0,
            'disgust': 1,
            'fear': 2,
            'happy': 3,
            'sad': 4,
            'surprise': 5,
            'neutral': 6
        }
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform
        self.label_map = label_map or self.default_label_map
        # List of (image_path, label) pairs in a deterministic order.
        self.samples = self._collect_samples()

    def _collect_samples(self):
        """Walk the class folders and return sorted ``(image_path, label)`` pairs."""
        samples = []
        # os.listdir returns entries in arbitrary order; sorting makes sample
        # indices reproducible across machines and runs.
        for class_name in sorted(os.listdir(self.root_dir)):
            class_dir = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            key = class_name.lower()  # match label names case-insensitively
            if key not in self.label_map:
                if class_name.startswith('.'):
                    # Hidden folders (e.g. ".ipynb_checkpoints") are not classes.
                    continue
                raise KeyError(
                    f"Unknown class folder {class_name!r} in {self.root_dir!r}; "
                    f"known classes: {list(self.label_map)}"
                )
            label = self.label_map[key]
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    samples.append((os.path.join(class_dir, img_file), label))
        return samples

    def __len__(self):
        """Return the number of collected images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': LongTensor}`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        # Open inside a context manager so the file handle is released as soon
        # as the RGB copy has been produced.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long)
        }

In [None]:
##CK+
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
from torchvision import transforms

# Label mapping for CK+: each name's position in the tuple is its integer label.
# NOTE(review): this rebinds the `label_map` defined in the FER2013 cell above.
label_map = {
    emotion: index
    for index, emotion in enumerate(
        ('anger', 'contempt', 'disgust', 'fear', 'happy', 'sadness', 'surprise')
    )
}

class CustomFERDataset(Dataset):
    """Dataset over a pre-built list of ``(image_path, label)`` pairs.

    NOTE(review): this redefines the ``CustomFERDataset`` from the FER2013 cell
    with a different constructor signature; whichever cell ran last wins.

    Args:
        samples: Sequence of ``(image_path, label)`` pairs; ``None`` gives an
            empty dataset.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, samples=None, transform=None):
        self.transform = transform
        self.samples = [] if samples is None else samples

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def _load_rgb(self, img_path):
        """Open ``img_path`` as RGB; log the path and re-raise on any failure."""
        try:
            return Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': LongTensor}`` for sample ``idx``."""
        img_path, label = self.samples[idx]
        image = self._load_rgb(img_path)
        if self.transform:
            image = self.transform(image)
        return dict(image=image, label=torch.tensor(label, dtype=torch.long))

In [None]:
##RAF-DB
class RAFDBDataset(Dataset):
    """Dataset that loads the ``*_aligned`` image for each row of an annotation table.

    Args:
        data: Table indexed with ``.iloc[row, col]``; column 0 holds the image
            file name and column 1 the label (presumably a pandas DataFrame —
            confirm against the caller).
        root_dir: Directory the aligned image files are joined onto.
        transform: Optional callable applied to the loaded RGB image.
        label_map: Optional mapping applied to the raw label before it is
            turned into a tensor.
    """

    def __init__(self, data, root_dir, transform=None, label_map=None):
        self.data, self.root_dir = data, root_dir
        self.transform, self.label_map = transform, label_map

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'image', 'label'}`` for row ``idx``, or ``None`` if the file is missing.

        A missing file is reported on stdout and yields ``None``; such entries
        have to be dropped before batching (see ``collate_fn``).
        """
        # Tensor indices are converted to plain Python values first.
        row = idx.tolist() if torch.is_tensor(idx) else idx

        # Column 0: file name, column 1: raw label.
        file_name = str(self.data.iloc[row, 0])
        raw_label = self.data.iloc[row, 1]

        # The images on disk carry an "_aligned" suffix before the extension.
        base_name, ext = os.path.splitext(file_name)
        full_path = os.path.join(self.root_dir, f"{base_name}_aligned{ext}")

        try:
            picture = Image.open(full_path).convert('RGB')
        except FileNotFoundError:
            print(f"Image not found: {full_path}")
            return None

        if self.transform:
            picture = self.transform(picture)

        mapped = raw_label if self.label_map is None else self.label_map[raw_label]
        return {'image': picture, 'label': torch.tensor(mapped, dtype=torch.long)}
def collate_fn(batch):
    """Collate a batch after dropping the ``None`` entries.

    Samples that could not be loaded are returned as ``None`` by the dataset;
    they are removed here and the remainder is handed to PyTorch's default
    collate function.

    NOTE(review): if every entry is ``None`` the empty list is passed through
    unchanged; default_collate is expected to fail on it — TODO confirm.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)