In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn

In [2]:
# 自定义数据集类
class ImageDataset(Dataset):
	def __init__(self, folder_path, label, transform=None):
		self.image_paths = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(('.jpg'))]
		self.labels = [label] * len(self.image_paths)  # 所有图像的标签相同
		self.transform = transform

	def __len__(self):
		return len(self.image_paths)

	def __getitem__(self, idx):
		image = self._load_image(self.image_paths[idx])
		label = torch.tensor(self.labels[idx], dtype=torch.float32)

		if self.transform:
			image = self.transform(image)

		return image, label

	def _load_image(self, path):
		return Image.open(path).convert('RGB')

# 数据加载器创建函数
def create_dataloaders(train_real_folder, train_fake_folder, val_real_folder, val_fake_folder, batch_size=16):
	transform = transforms.Compose([
		transforms.Resize((512, 512)),
		transforms.ToTensor(),
		transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 使用ImageNet均值和标准差
	])

	# 创建数据集实例
	train_real_dataset = ImageDataset(train_real_folder, label=1, transform=transform)  # 真实图像标签为1
	train_fake_dataset = ImageDataset(train_fake_folder, label=0, transform=transform)  # 虚假图像标签为0

	val_real_dataset = ImageDataset(val_real_folder, label=1, transform=transform)  # 真实图像标签为1
	val_fake_dataset = ImageDataset(val_fake_folder, label=0, transform=transform)  # 虚假图像标签为0
	
	print("there are ", len(train_real_dataset), " photos in train_real_dataset")
	print("there are ", len(train_fake_dataset), " photos in train_fake_dataset")
	print("there are ", len(val_real_dataset), " photos in val_real_dataset")
	print("there are ", len(val_fake_dataset), " photos in val_fake_dataset")

	# 使用ConcatDataset合并训练集
	train_dataset = torch.utils.data.ConcatDataset([train_real_dataset, train_fake_dataset])
	val_dataset = torch.utils.data.ConcatDataset([val_real_dataset, val_fake_dataset])

	train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
	val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

	return train_loader, val_loader

In [3]:
# 使用示例
train_real_folder = "../AIGC-Detection-Dataset/train/0_real"  # 训练集真实图像文件夹路径
train_fake_folder = "../AIGC-Detection-Dataset/train/1_fake"  # 训练集虚假图像文件夹路径
val_real_folder 	= "../AIGC-Detection-Dataset/val/0_real"  	# 验证集真实图像文件夹路径
val_fake_folder 	= "../AIGC-Detection-Dataset/val/1_fake"  	# 验证集虚假图像文件夹路径

# 创建数据加载器
train_loader, val_loader = create_dataloaders(train_real_folder, train_fake_folder, val_real_folder, val_fake_folder, batch_size=16)

there are  500  photos in train_real_dataset
there are  500  photos in train_fake_dataset
there are  2500  photos in val_real_dataset
there are  2500  photos in val_fake_dataset


In [4]:
class SSPNetwork(nn.Module):
    def __init__(self, patch_size=32):
        super(SSPNetwork, self).__init__()
        self.patch_size = patch_size
        
        # SRMfilter初始化
        self.srm_filters = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)
        self.srm_filters.weight.data = self._init_srm_filters()
        self.srm_filters.weight.requires_grad = False  # 固定滤波器
        
        # ResNet50模型
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(2048, 1)  # 修改为二分类输出

    def _init_srm_filters(self):
        filters = torch.zeros((3, 3, 3, 3))
        filters[0, 0] = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=torch.float32)  # 示例滤波器
        filters[1, 0] = torch.tensor([[-1, 2, -1], [2, -4, 2], [-1, 2, -1]], dtype=torch.float32)  # 边缘检测
        filters[2, 0] = torch.tensor([[1, -2, 1], [-2, 4, -2], [1, -2, 1]], dtype=torch.float32)  # 高通滤波器
        return filters

    def extract_simple_patch(self, x):
        patches = []
        batch_size, _, H, W = x.shape
        num_patches_h = H // self.patch_size
        num_patches_w = W // self.patch_size

        for i in range(batch_size):
            min_div = float('inf')
            simplest_patch = None

            for h in range(num_patches_h):
                for w in range(num_patches_w):
                    patch = x[i, :, h * self.patch_size:(h + 1) * self.patch_size, w * self.patch_size:(w + 1) * self.patch_size]
                    l_div = self._compute_texture_diversity(patch)
                    if l_div < min_div:
                        min_div = l_div
                        simplest_patch = patch

            patches.append(simplest_patch)

        return torch.stack(patches)

    def _compute_texture_diversity(self, patch):
        patch = patch.mean(dim=0)  # 转为灰度图
        h_diff = torch.abs(patch[:, :-1] - patch[:, 1:]).sum()
        v_diff = torch.abs(patch[:-1, :] - patch[1:, :]).sum()
        diag_diff = torch.abs(patch[:-1, :-1] - patch[1:, 1:]).sum()
        anti_diag_diff = torch.abs(patch[1:, :-1] - patch[:-1, 1:]).sum()
        return h_diff + v_diff + diag_diff + anti_diag_diff

    def forward(self, x):
        simplest_patch = self.extract_simple_patch(x)
        noise_pattern = self.srm_filters(simplest_patch)
        output = self.resnet(noise_pattern)
        return torch.sigmoid(output)

def evaluate_model(model, val_loader, loss_fn, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).squeeze()  # 输出形状为 (batch_size)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()
            predictions = (outputs >= 0.5).float()  # 阈值为0.5
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return total_loss / len(val_loader), accuracy

def train_model(ssp_network, train_loader, val_loader, epochs=10, lr=1e-4):
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(ssp_network.parameters(), lr=lr)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ssp_network = ssp_network.to(device)
    
    for epoch in range(epochs):
        ssp_network.train()
        total_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = ssp_network(images).squeeze()  # 输出形状为 (batch_size)
            loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        val_loss, val_accuracy = evaluate_model(ssp_network, val_loader, loss_fn, device)
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {total_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

In [5]:
ssp_network = SSPNetwork(patch_size=32)
train_model(ssp_network, train_loader, val_loader, epochs=10, lr=1e-4)



Epoch 1/10, Train Loss: 38.2038, Val Loss: 0.6197, Val Accuracy: 0.7584
Epoch 2/10, Train Loss: 33.0496, Val Loss: 0.5275, Val Accuracy: 0.7266
Epoch 3/10, Train Loss: 30.0471, Val Loss: 0.5191, Val Accuracy: 0.7244
Epoch 4/10, Train Loss: 24.2383, Val Loss: 0.5143, Val Accuracy: 0.7630
Epoch 5/10, Train Loss: 22.5169, Val Loss: 0.3455, Val Accuracy: 0.8542
Epoch 6/10, Train Loss: 18.0961, Val Loss: 0.3106, Val Accuracy: 0.8576
Epoch 7/10, Train Loss: 15.2795, Val Loss: 0.2811, Val Accuracy: 0.8676
Epoch 8/10, Train Loss: 15.2168, Val Loss: 0.2503, Val Accuracy: 0.8786
Epoch 9/10, Train Loss: 14.1664, Val Loss: 0.3066, Val Accuracy: 0.8590
Epoch 10/10, Train Loss: 13.4335, Val Loss: 0.2291, Val Accuracy: 0.9198
