In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.io import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torch
from torch import nn
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models import resnet18

# Gradient reversal layer (GRL)
class GradientReversalLayer(torch.autograd.Function):
    """Identity on the way forward, sign flip on the way back.

    Call it as ``GradientReversalLayer.apply(x)``. The output has the same
    values as ``x``. During backpropagation the gradient reaching ``x`` is the
    negation of the gradient arriving at the output.
    """

    @staticmethod
    def forward(ctx, x):
        # view_as yields a distinct tensor object (same data) so autograd
        # records this Function as the producer of the result.
        same_values = x.view_as(x)
        return same_values

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output.neg()
        return flipped

# Encoder: reuse the backbone of a pretrained Faster R-CNN as the feature extractor
class Encoder(nn.Module):
    """Feature extractor taken from a COCO-pretrained Faster R-CNN.

    Only the ``backbone`` submodule of ``fasterrcnn_resnet50_fpn(pretrained=True)``
    is kept; the detection heads are discarded. Building this module downloads
    the pretrained checkpoint the first time it runs.
    """

    def __init__(self):
        super().__init__()
        detector = fasterrcnn_resnet50_fpn(pretrained=True)
        self.feature_extractor = detector.backbone

    def forward(self, x):
        """Return the backbone's output for ``x`` unchanged.

        For torchvision's FPN backbone this is a mapping of feature-map
        name -> tensor, not a flat feature vector.
        """
        return self.feature_extractor(x)

# Classifier 1: real vs. fake image classifier
class Classifier1(nn.Module):
    """Real-vs-fake head: one linear layer over pooled encoder features.

    Args:
        in_features: length of the input feature vector. Defaults to 256, the
            channel count assumed for the encoder output.
        num_classes: number of output logits. Defaults to 2 (real / fake).
    """

    def __init__(self, in_features=256, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Map features of shape (..., in_features) to logits (..., num_classes)."""
        return self.fc(x)

# Classifier 2: which generator (GAN1 vs. GAN2) produced a fake image
class Classifier2(nn.Module):
    """Adversarial GAN-source head placed behind a gradient reversal layer.

    ``forward`` passes its input through ``GradientReversalLayer.apply``
    (identity forward, negated gradient backward) before the linear layer. The
    encoder therefore receives the *negated* gradient of this head's loss: it
    is pushed towards features that make the two generators harder to tell
    apart, while the head itself still learns to separate them.

    Args:
        in_features: length of the input feature vector. Defaults to 256.
        num_classes: number of generator classes. Defaults to 2 (GAN1 / GAN2).
    """

    def __init__(self, in_features=256, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)
        self.grl = GradientReversalLayer.apply

    def forward(self, x):
        """Reverse the gradient flowing into ``x``, then map to (..., num_classes) logits."""
        reversed_x = self.grl(x)
        return self.fc(reversed_x)

# Full model
class MyModel(nn.Module):
    """Shared encoder feeding a real/fake head and an adversarial GAN-source head.

    ``forward(x)`` returns ``(class1_pred, class2_pred)``, the logits of
    ``classifier1`` and ``classifier2`` computed on the same pooled features.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.classifier1 = Classifier1()
        self.classifier2 = Classifier2()

    @staticmethod
    def _pool(features):
        """Collapse encoder output into one (batch, channels) vector per image.

        torchvision's FPN backbone returns a mapping of level name ->
        (batch, 256, H, W) feature maps. The heads are linear layers that need
        flat vectors, so passing the mapping straight in fails. Each map is
        global-average-pooled over its spatial dims, and the per-level vectors
        are averaged. A single 4-D tensor is pooled the same way; a 2-D tensor
        is treated as already pooled.
        """
        if isinstance(features, torch.Tensor):
            if features.dim() == 2:
                return features
            maps = [features]
        else:
            maps = list(features.values())
        pooled = [fmap.mean(dim=(2, 3)) for fmap in maps]
        return torch.stack(pooled, dim=0).mean(dim=0)

    def forward(self, x):
        features = self._pool(self.encoder(x))
        class1_pred = self.classifier1(features)
        class2_pred = self.classifier2(features)
        return class1_pred, class2_pred


# Custom dataset: real images plus images from two generators
class ImageDataset(Dataset):
    """Image dataset built from three folders.

    Labels: 0 = real (``real_dir``), 1 = BigGAN (``biggan_dir``),
    2 = the second generator (``which_dir``).

    Each item is ``(image, label)``. ``image`` is a float tensor of shape
    (3, 224, 224), normalised with the ImageNet mean/std expected by the
    pretrained backbone.

    Args:
        real_dir, biggan_dir, which_dir: directories whose image files are loaded
            (non-recursive).
        extensions: file suffixes (case-insensitive) accepted as images.
            Hidden files (e.g. ``.DS_Store``), sub-directories (e.g.
            ``.ipynb_checkpoints``) and other suffixes are skipped, because
            ``read_image`` would fail on them.
    """

    # Per-channel mean/std used by torchvision's detection models for their
    # ImageNet-initialised ResNet backbones; the pretrained encoder expects
    # inputs normalised this way rather than raw [0, 1] pixels.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, real_dir, biggan_dir, which_dir, extensions=(".jpg", ".jpeg", ".png")):
        from torchvision.io import ImageReadMode

        # Decode every file as 3-channel RGB so grayscale / RGBA images can be
        # batched together with colour images.
        self.read_mode = ImageReadMode.RGB
        self.extensions = tuple(ext.lower() for ext in extensions)

        self.real_images = self._list_images(real_dir)
        self.biggan_images = self._list_images(biggan_dir)
        self.which_images = self._list_images(which_dir)
        self.total_images = self.real_images + self.biggan_images + self.which_images
        # Real -> 0, BigGAN -> 1, "which" generator -> 2
        self.labels = (
            [0] * len(self.real_images)
            + [1] * len(self.biggan_images)
            + [2] * len(self.which_images)
        )

        # read_image already returns a uint8 tensor, so ToTensor() (which only
        # accepts PIL images / ndarrays) must not be used here. ConvertImageDtype
        # scales uint8 -> float in [0, 1].
        self.transform = transforms.Compose([
            transforms.ConvertImageDtype(torch.float),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def _list_images(self, directory):
        """Return sorted paths of the image files directly inside ``directory``."""
        paths = []
        for name in sorted(os.listdir(directory)):
            path = os.path.join(directory, name)
            if name.startswith(".") or not os.path.isfile(path):
                continue
            if os.path.splitext(name)[1].lower() in self.extensions:
                paths.append(path)
        return paths

    def __len__(self):
        return len(self.total_images)

    def __getitem__(self, idx):
        image_path = self.total_images[idx]
        image = read_image(image_path, mode=self.read_mode)
        image = self.transform(image)
        label = self.labels[idx]
        return image, label

# Build the dataset from three folders (relative to the notebook's working
# directory) and wrap it in a shuffling loader with batches of 4.
real_dir, biggan_dir, which_dir = 'real', 'biggan', 'which'
dataset = ImageDataset(real_dir=real_dir, biggan_dir=biggan_dir, which_dir=which_dir)

data_loader = DataLoader(dataset, shuffle=True, batch_size=4)

# Train the model
def train(model, data_loader, loss_fn, optimizer, epochs):
    """Jointly train the real/fake head and the adversarial GAN-source head.

    Each batch from ``data_loader`` is ``(images, labels)`` with labels
    0 = real, 1 = BigGAN, 2 = second generator. ``model(images)`` must return
    ``(class1_pred, class2_pred)``.

    * ``class1_pred`` is trained on real (0) vs. fake (1).
    * ``class2_pred`` is a two-way GAN1-vs-GAN2 head. It is trained only on
      the generated images in the batch, with labels shifted {1, 2} -> {0, 1}.
      Feeding it the raw 3-valued labels would make target 2 out of range for
      a 2-logit head, and real images have no generator to predict.

    Args:
        model: module returning the two logit tensors.
        data_loader: iterable of ``(images, labels)`` batches; ``labels`` is an
            integer tensor.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar
            tensor (e.g. ``nn.CrossEntropyLoss()``).
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``data_loader``.
    """
    model.train()
    for epoch in range(epochs):
        for images, labels in data_loader:
            fake_mask = labels > 0
            # Head 1 targets: 0 for real images, 1 for any generated image.
            real_or_fake_labels = fake_mask.long()

            class1_pred, class2_pred = model(images)
            loss1 = loss_fn(class1_pred, real_or_fake_labels)

            # Head 2 only sees generated images; shift {1, 2} -> {0, 1}.
            if fake_mask.any():
                loss2 = loss_fn(class2_pred[fake_mask], labels[fake_mask] - 1)
            else:
                # No generated images in this batch: head 2 contributes nothing.
                loss2 = class2_pred.new_zeros(())

            optimizer.zero_grad()
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item()}")

# Training configuration
epochs = 5
model = MyModel()
# Both heads emit class logits with integer class targets, so cross-entropy
# serves for both. Without this definition the train() call below raises a
# NameError.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Start training
train(model, data_loader, loss_fn, optimizer, epochs)


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /Users/wangjuntong/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
  1%|▏                                   | 1.09M/160M [00:59<15:35:57, 2.96kB/s]

In [1]:
pip install torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torchvision
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/b1/d9/4228947ca56483aebb2162163cc1aee8abc29ebd2c05747863e22145f6b2/torchvision-0.16.1-cp39-cp39-macosx_10_13_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
Collecting torch==2.1.1 (from torchvision)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a2/20/7f297fca8fa55f4401d23f941714834c514b0d43b7f0af9b6615f86e3b97/torch-2.1.1-cp39-none-macosx_10_9_x86_64.whl (147.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.0/147.0 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.10.2
    Uninstalling torch-1.10.2:
      Successfully uninstalled torch-1.10.2
Successfully installed torch-