In [None]:
import torch
import torchvision
import PIL
import os
import numpy as np
from matplotlib import pyplot as plt
import random

# 加载GPU设备

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# 切分训练集和测试集

In [None]:
# 加载所有文件

print(f'current dir: {os.getcwd()}')

model_path = "../model/zdan"
file_path = "../data/images"
files_normal = []
files_nsfw = []
for file in os.listdir(file_path):
    if file[0] == '0':
        files_normal.append(file)
    else:
        files_nsfw.append(file)

random.shuffle(files_normal)
random.shuffle(files_nsfw)

print(f'total images: {len(files_normal)} + {len(files_nsfw)} = {len(files_normal) + len(files_nsfw)}')


In [None]:
# Build balanced train/test splits: take the same number of compliant and
# non-compliant images (capped at 8000 per class) and hold out 10% of that
# per-class count for testing. The combined splits are then shuffled.

total_count = min(len(files_nsfw), len(files_normal), 8000)
train_count = int(total_count * 0.9)
test_count = int(total_count - train_count)


def _split_head(files, n_train, n_test):
    """Return (the first n_train items, the following n_test items) of ``files``."""
    return files[:n_train], files[n_train:n_train + n_test]


files_normal_train, files_normal_test = _split_head(files_normal, train_count, test_count)
files_nsfw_train, files_nsfw_test = _split_head(files_nsfw, train_count, test_count)

# Compliant images first, then non-compliant; order is randomized right after.
files_train = files_normal_train + files_nsfw_train
files_test = files_normal_test + files_nsfw_test

random.shuffle(files_train)
random.shuffle(files_test)

print(f'train:{len(files_train)}, test: {len(files_test)}')

In [None]:
# Print label statistics for a split
def statistics_info(files):
    """Print label counts for a list of image file names.

    The first character of each name is read as the level label and the third
    character as the type label (both parsed with ``int``). Prints the total
    count, per-level counts, per-type counts (levels/types 0-2 are pre-seeded)
    and the joint "<level><type>" counts in first-seen order.
    """
    level_counts = dict.fromkeys(range(3), 0)
    type_counts = dict.fromkeys(range(3), 0)
    pair_counts = {}

    for name in files:
        level, kind = int(name[0]), int(name[2])
        level_counts[level] += 1
        type_counts[kind] += 1

        key = f'{level}{kind}'
        pair_counts[key] = pair_counts.get(key, 0) + 1

    print('total:', len(files))
    print('ls:', level_counts)
    print('ts:', type_counts)
    print('cross:', pair_counts)

# Report label statistics for the training split, then the test split.
for split_files in (files_train, files_test):
    statistics_info(split_files)

In [None]:
# Image preprocessing pipelines. Both take a file path and produce a normalized
# 3x224x224 tensor; only the training pipeline applies random augmentation.

def _open_rgb(path):
    """Open the image at ``path`` and return it as an RGB PIL image.

    The ``with`` block guarantees the underlying file handle is closed. A named
    function (unlike a lambda) is also picklable, which DataLoader workers need.
    """
    with PIL.Image.open(path) as img:
        return img.convert('RGB')


# Standard ImageNet channel mean/std; must match the statistics the pretrained
# ResNet backbone expects.
_IMAGENET_NORMALIZE = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                       std=[0.229, 0.224, 0.225])

train_transform = torchvision.transforms.Compose([
    _open_rgb,
    # Resize so the shorter side is 256 px (aspect ratio preserved)
    torchvision.transforms.Resize(256),
    # Random rotation in [-15, 15] degrees
    torchvision.transforms.RandomRotation(15),
    # Random horizontal flip
    torchvision.transforms.RandomHorizontalFlip(),
    # Center crop; also trims the black corners introduced by the rotation
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    _IMAGENET_NORMALIZE,
])

test_transform = torchvision.transforms.Compose([
    _open_rgb,
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    _IMAGENET_NORMALIZE,
])


# Load one image file through the train or test pipeline
def load_file(file, train):
    """Load ``file`` from ``file_path`` as a preprocessed tensor.

    Args:
        file: image file name inside ``file_path``.
        train: truthy -> ``train_transform`` (with augmentation),
            falsy -> ``test_transform``.

    Returns:
        The output of the selected transform pipeline.
    """
    transform = train_transform if train else test_transform
    return transform(f'{file_path}/{file}')

# Sanity check: push the first training image through the augmentation pipeline.
x = load_file(files_train[0], train=True)
x.shape

# 定义数据集类

In [None]:
# Dataset over image file names; images are loaded lazily in __getitem__.
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of (image tensor, binary label) pairs.

    The raw label is the leading digit of each file name and is stored as-is in
    ``self.y``. At access time label 2 is folded into class 1, so both
    non-compliant levels form a single positive class.

    Args:
        files: iterable of file names inside ``file_path``.
        train: forwarded to ``load_file`` to pick the train or test transform.
    """

    def __init__(self, files, train):
        self.train = train
        self.x = [name for name in files]
        self.y = [int(name[0]) for name in self.x]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        image = load_file(self.x[i], self.train)
        label = self.y[i]
        # Collapse the two non-compliant levels into one class
        label = 1 if label == 2 else label
        return image.to(device), torch.tensor(label, device=device)


# Training loader: shuffled; drop_last=True so every step sees a full batch of 50.
dataloader_train = torch.utils.data.DataLoader(dataset=Dataset(files_train, True),
                                               batch_size=50,
                                               shuffle=True,
                                               drop_last=True)

# Test loader: fixed order; keeps the final partial batch so every test image is scored.
dataloader_test = torch.utils.data.DataLoader(dataset=Dataset(files_test, False),
                                              batch_size=50,
                                              shuffle=False,
                                              drop_last=False)

# Fetch a single batch to check shapes. next(iter(...)) fails loudly on an empty
# loader instead of silently leaving the single-image `x` from the previous cell.
x, y = next(iter(dataloader_train))

x.shape, y.shape

# 训练模型

In [None]:
# Pretrained ResNet-18, used as a frozen feature extractor.
resnet = torchvision.models.resnet18(weights=None)
resnet_state_dict = torch.load(f"{model_path}/resnet18-f37072fd.pth", weights_only=True, map_location=device)
resnet.load_state_dict(resnet_state_dict)

# Drop the last child module (the final classification layer); the remaining
# stack produces the pooled features fed to the classifier head.
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
resnet.to(device)

# eval() only switches layers such as BatchNorm to inference behaviour; it does NOT
# disable autograd. Freeze the weights explicitly so no gradients are ever computed
# or accumulated for the backbone.
resnet.eval()
resnet.requires_grad_(False)

# Trial forward pass on the sample batch
with torch.no_grad():
    out = resnet(x)
out.shape

In [None]:
# Transfer-learning head: a small MLP trained on top of the frozen backbone features.
class Model(torch.nn.Module):
    """Two-layer MLP classifier head.

    Args:
        in_features: size of each input feature vector (default 512).
        hidden: width of the hidden layer (default 512).
        num_classes: number of output logits (default 2).
        dropout: dropout probability after the hidden ReLU (default 0.2).

    Input: tensor of shape (..., in_features). Output: logits of shape (..., num_classes).
    """

    def __init__(self, in_features=512, hidden=512, num_classes=2, dropout=0.2):
        super().__init__()

        # Keep the layers in one Sequential named ``fc`` so state_dict keys
        # (fc.0.*, fc.3.*) stay compatible with previously saved checkpoints.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features, hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden, num_classes)
        )

    def forward(self, x):
        return self.fc(x)

model = Model().to(device)

# flatten(1) keeps the batch dimension even for a batch of one (squeeze() would drop it)
# and matches the out.view(out.size(0), -1) reshape used during training.
with torch.no_grad():
    logits = model(out.flatten(1))
logits.shape

In [None]:
# Train the classifier head on frozen backbone features
def train(num_epochs=10, patience=3):
    """Train the global ``model`` head on features from the global ``resnet``.

    Uses Adam (lr=1e-4) with cross-entropy loss. The learning rate is halved when
    the epoch's mean training loss has not improved for 2 epochs. Training stops
    early after ``patience`` consecutive epochs without a new best training loss.
    The best epoch's weights are restored into ``model`` and saved to
    ``{model_path}/model.mdl``.

    Args:
        num_epochs: maximum number of passes over ``dataloader_train``.
        patience: consecutive non-improving epochs tolerated before stopping.

    Raises:
        ValueError: if ``dataloader_train`` yields no batches.
    """
    if len(dataloader_train) == 0:
        raise ValueError("dataloader_train is empty; not enough training images for one batch")

    optimizer_model = torch.optim.Adam(model.parameters(), lr=1e-4)
    criteon = torch.nn.CrossEntropyLoss()

    best_loss = float('inf')
    uneffect_times = 0  # consecutive epochs without a new best loss
    best_model_state = None

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_model, mode='min', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        # Only the head trains; the backbone stays in inference mode.
        model.train()
        resnet.eval()

        total_loss = 0
        total_correct = 0
        total_train = 0

        for i, (x, y) in enumerate(dataloader_train):
            optimizer_model.zero_grad()

            # Forward. The backbone is frozen, so extract its features without building
            # an autograd graph; otherwise backward() would also compute (and keep
            # accumulating) unused gradients for every backbone weight.
            with torch.no_grad():
                feats = resnet(x)
                feats = feats.view(feats.size(0), -1)
            out = model(feats)

            loss = criteon(out, y)
            loss.backward()

            optimizer_model.step()

            total_loss += loss.item()
            preds = out.argmax(dim=1)
            batch_correct = (preds == y).sum().item()
            total_correct += batch_correct
            total_train += y.size(0)

            if i % 10 == 0:
                acc = batch_correct / float(y.size(0))
                print(f"Epoch [{epoch+1}/{num_epochs}] Step [{i}] Loss: {loss.item():.4f} Acc: {acc*100:.2f}%")

        train_loss = total_loss / len(dataloader_train)
        train_acc = total_correct / total_train

        # Adjust the learning rate when the training loss plateaus
        scheduler.step(train_loss)

        print(f"==> Epoch [{epoch+1}/{num_epochs}] Finished! Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%\n")

        # Early stopping: end training after `patience` epochs without a new best loss
        if train_loss < best_loss:
            best_loss = train_loss
            # state_dict() returns tensors sharing storage with the live parameters,
            # which later optimizer steps overwrite in place; clone to take a real snapshot.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            uneffect_times = 0
        else:
            uneffect_times += 1
            print(f"Trigger Times: {uneffect_times}")

            if uneffect_times >= patience:
                print("Early stopping!")
                break

    # Restore the best weights seen during training
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Save the model
    torch.save(model.state_dict(), f'{model_path}/model.mdl')
    print("Best model saved!")


# Train the model
train(num_epochs=1)


# 验证模型

In [None]:
# Evaluate on the held-out test set.
# Rebuild the classifier head and load the weights that train() saved.
test_model = Model().to(device)
state_dict = torch.load(f'{model_path}/model.mdl', weights_only=True, map_location=device)
test_model.load_state_dict(state_dict)
test_model.eval()


def test():
    """Evaluate ``test_model`` on ``dataloader_test`` and print loss and accuracy.

    Features come from the global frozen ``resnet`` backbone. The reported loss is
    the per-sample mean over the whole test set. Each batch is weighted by its size,
    so the smaller final batch (drop_last=False) is not over-weighted. Prints a
    notice and returns early if the test set is empty.
    """
    # Both networks must be in inference mode (BatchNorm running stats, no dropout).
    resnet.eval()
    test_model.eval()

    # reduction='sum' lets per-batch losses be summed, then divided by the sample count.
    criteon = torch.nn.CrossEntropyLoss(reduction='sum')

    total_loss = 0.0
    total_test = 0
    total_correct = 0
    with torch.no_grad():
        for x_test, y_test in dataloader_test:
            feats = resnet(x_test)
            feats = feats.view(feats.size(0), -1)
            out_test = test_model(feats)

            total_loss += criteon(out_test, y_test).item()

            preds_val = out_test.argmax(dim=1)
            total_correct += (preds_val == y_test).sum().item()
            total_test += y_test.size(0)

    if total_test == 0:
        print("==> Test set is empty; nothing to evaluate.\n")
        return

    test_loss = total_loss / total_test
    test_acc = total_correct / total_test

    print(f"==> Test Loss: {test_loss:.4f}, Test Acc: {test_acc*100:.2f}%\n")


test()