In [6]:
import torch
from torch import nn
from torch.nn import *

from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import *
from torchvision import datasets, transforms


In [15]:
class ResModule(nn.Module):
    """Basic residual unit: two 3x3 conv + BatchNorm layers plus a shortcut, then ReLU.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: a 1x1 convolution (same ``stride``) when ``use_1x1`` is True, otherwise the
    input itself. The two are summed and passed through ReLU.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by both 3x3 convolutions.
        use_1x1: use a 1x1 projection shortcut. Without it the input is added unchanged,
            so the input generally has to already match the main path's output shape
            (``in_channel == out_channel`` and ``stride == 1``).
        stride: stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, in_channel, out_channel, use_1x1=False, stride=1):
        super().__init__()
        self.relu = ReLU(inplace=True)
        # Projection shortcut used when the main path changes channels and/or resolution.
        self.conv3 = (
            Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0)
            if use_1x1
            else None
        )

        self.model = Sequential(
            Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1),
            BatchNorm2d(out_channel, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            ReLU(inplace=True),
            Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            BatchNorm2d(out_channel, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.model(x)
        # Explicit None check: relying on the truthiness of an nn.Module is fragile
        # (e.g. an empty Sequential has len() == 0 and would be treated as "no shortcut").
        shortcut = x if self.conv3 is None else self.conv3(x)
        out += shortcut
        return self.relu(out)



def resnet_block(input_channel, output_channel, resModule_num, firstblock=False):
    """Build one ResNet stage as a list of ``resModule_num`` ResModule units.

    Unless ``firstblock`` is True, the first unit maps ``input_channel`` ->
    ``output_channel`` with stride 2 and a 1x1 projection shortcut. Every other unit
    maps ``output_channel`` -> ``output_channel`` with stride 1 and no projection.
    When ``firstblock`` is True, ``input_channel`` is not used at all.
    The returned list is meant to be unpacked into a ``Sequential``.
    """
    downsample_first = not firstblock
    return [
        ResModule(input_channel, output_channel, True, stride=2)
        if unit_idx == 0 and downsample_first
        else ResModule(output_channel, output_channel)
        for unit_idx in range(resModule_num)
    ]



class Resnet(nn.Module):
    """ResNet-style image classifier for 3-channel input.

    Layout:
    - a 7x7/stride-2 conv stem, then BN, ReLU and a 3x3/stride-2 max-pool;
    - four residual stages built by ``resnet_block`` with 64, 128, 256 and 512 channels,
      where every stage after the first starts with a stride-2 unit;
    - global average pooling, then a single linear classifier.

    Args:
        num_classes: number of output logits. Defaults to 2, the previously hard-coded value.
        layers: number of ``ResModule`` units in each of the four stages. Defaults to
            ``(3, 4, 8, 3)``, the previously hard-coded configuration.

    Raises:
        ValueError: if ``layers`` does not have exactly four entries.
    """

    # Output channels of the four residual stages; the classifier consumes the last one.
    STAGE_CHANNELS = (64, 128, 256, 512)

    def __init__(self, num_classes=2, layers=(3, 4, 8, 3)):
        super(Resnet, self).__init__()
        if len(layers) != len(self.STAGE_CHANNELS):
            raise ValueError(
                f"layers must have {len(self.STAGE_CHANNELS)} entries, got {len(layers)}"
            )

        modules = [
            # padding=3 (= kernel_size // 2) is the standard padding for a 7x7 stride-2 stem:
            # a 224x224 input becomes 112x112. The previous padding=1 produced 110x110 and
            # under-padded the image border.
            Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        in_ch = 64
        for stage_idx, (out_ch, n_units) in enumerate(zip(self.STAGE_CHANNELS, layers)):
            # The first stage keeps the stem's resolution; later stages downsample by 2.
            modules.extend(resnet_block(in_ch, out_ch, n_units, firstblock=(stage_idx == 0)))
            in_ch = out_ch
        modules += [
            AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            Linear(self.STAGE_CHANNELS[-1], num_classes),
        ]
        # The attribute name `module` and the module order are unchanged, so existing
        # state_dict keys (module.<index>.*) still line up.
        self.module = Sequential(*modules)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for the image batch ``x``."""
        return self.module(x)


In [None]:
class TransformSubset(torch.utils.data.Subset):
    """A ``Subset`` that applies ``transform`` to the input of every (input, label) pair.

    Args:
        dataset: the underlying dataset. Its items must unpack into ``(x, y)``.
        indices: positions in ``dataset`` that make up this subset.
        transform: optional callable applied to ``x``; ``y`` is returned unchanged.
    """

    def __init__(self, dataset, indices, transform=None):
        super().__init__(dataset, indices)
        self.transform = transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __getitems__(self, indices):
        # Recent torch versions define Subset.__getitems__, and DataLoader prefers it for
        # batched fetching. Without this override that path reads self.dataset directly and
        # silently skips self.transform, so route every batch through __getitem__ instead.
        return [self[i] for i in indices]
    
IMAGE_SIZE = (224, 224)
# Seed for the train/test split, weight init and DataLoader shuffling, so runs are reproducible.
SEED = 42
# ImageNet per-channel statistics used for normalisation.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Data-augmentation pipeline (applied to a copy of the training set only).
transform_augmented = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Deterministic pipeline without augmentation: resize + normalise.
transform_plain = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

dataset_path = 'images'

torch.manual_seed(SEED)
all_data = datasets.ImageFolder(root=dataset_path, transform=transform_plain)
# A second view of the same folder whose images go through the augmentation pipeline.
# Augmentation must start from the raw PIL image: re-transforming items of `all_data`
# (already normalised tensors) would hand a tensor to ToTensor(), which raises TypeError.
augmented_view = datasets.ImageFolder(root=dataset_path, transform=transform_augmented)
assert augmented_view.samples == all_data.samples, "both views must index the same files"

train_size = int(0.8 * len(all_data))
test_size = len(all_data) - train_size
# Split into training and test sets; the seeded generator keeps the held-out set stable.
train_data, test_data = torch.utils.data.random_split(
    all_data, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)
# Augmented copy of exactly the training indices; the test set is never augmented.
train_data_augmented = torch.utils.data.Subset(augmented_view, train_data.indices)
combined_dataset = ConcatDataset([train_data, train_data_augmented])


batch_size = 32
testdata_len = test_size

# Train on the original + augmented images, reshuffled every epoch.
traindata_load = DataLoader(combined_dataset, batch_size, shuffle=True)
testdata_load = DataLoader(test_data, batch_size)

# Build the network and move it (and the loss) to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qnet = Resnet()
print(torch.cuda.is_available())
qnet = qnet.to(device)
# Loss function
loss_fn = nn.CrossEntropyLoss().to(device)

# Training parameters
train_step = 0
test_step = 0
epoch = 30

# Optimizer. The StepLR scheduler is stepped once per *batch* (inside the loop below),
# so step_size=600 multiplies the learning rate by 0.3 every 600 iterations.
learn_rate = 0.001
optimizer = torch.optim.SGD(qnet.parameters(), lr=learn_rate, momentum=0.9, nesterov=False)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=600, gamma=0.3)

# TensorBoard logging
writer = SummaryWriter("./logs/Resnet_au")

try:
    for i in range(epoch):
        print(f"--------第{i + 1}轮训练--------")
        # The evaluation pass at the end of every epoch calls qnet.eval(), so training mode
        # (batch statistics + running-stat updates in BatchNorm) must be restored each epoch;
        # calling train() only once before the loop left epochs 2..N training in eval mode.
        qnet.train()
        for imgs, labels in traindata_load:
            imgs = imgs.to(device)
            labels = labels.to(device)
            output = qnet(imgs)
            loss = loss_fn(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_step = train_step + 1

            # Per-iteration learning-rate update.
            lr_scheduler.step()

            if train_step % 50 == 0:
                print(f"训练次数{train_step},损失值{loss.item()}")
                writer.add_scalar("train_loss{}_{}".format(batch_size, epoch), loss.item(), train_step)

        # Evaluation on the held-out set.
        qnet.eval()

        # Sum of per-batch mean losses (scales with the number of test batches).
        total_test_loss = 0
        # Number of correctly classified test images, kept as a Python int.
        total_accuracy = 0
        with torch.no_grad():
            for imgs, labels in testdata_load:
                imgs = imgs.to(device)
                labels = labels.to(device)
                output = qnet(imgs)
                loss = loss_fn(output, labels)
                total_test_loss = total_test_loss + loss.item()
                total_accuracy = total_accuracy + (output.argmax(1) == labels).sum().item()
        test_accuracy = total_accuracy / testdata_len
        print(f"整体测试集上的loss{total_test_loss}")
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")
        print(f"整体测试集上的正确率{test_accuracy}")
        writer.add_scalar("test_loss{}_{}".format(batch_size, epoch), total_test_loss, test_step)
        writer.add_scalar("accuracy{}_{}".format(batch_size, epoch),
                          test_accuracy,
                          test_step)
        test_step = test_step + 1

        # Save a checkpoint after every epoch. This pickles the whole module, so loading it
        # requires the Resnet/ResModule class definitions to be importable.
        torch.save(qnet, "qnet_{}.pth".format(i))
finally:
    # Flush and close the TensorBoard event file even if training is interrupted.
    writer.close()
