In [1]:
from IPython import get_ipython

from PIL import Image
import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

import time
import random
from math import ceil
from tqdm import tqdm
import os

import torch.utils.data as data
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import torch.optim as optim
import torch.nn.functional as F


# 设置后续需要的参数（部分实际未被使用）
class VarsConfig(object):
    train_data = "train.csv"
    test_data = "val.csv"
    pre_data = "test.csv"

    epoch = 30
    batch_size = 64
    img_height = 64
    img_weight = 64
    seed = 666


config = VarsConfig()


# 设置图像变换
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机旋转变换
    transforms.RandomCrop(64, padding=4),  # 随机裁剪
    transforms.ToTensor(),  # 转换为tensor格式
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

test_transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为tensor格式
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])



# 读取训练集和测试集
train_csv = pd.read_csv(config.train_data, header=None).values
test_csv = pd.read_csv(config.test_data, header=None).values
val_csv = pd.read_csv(config.pre_data, header=None).values

train_label = train_csv[0:65000, 0]
train_image = train_csv[0:65000, 1:]
test_label = test_csv[0:12500, 0]
test_image= test_csv[0:12500, 1:]


train_image = torch.from_numpy(train_image).type(torch.LongTensor)
test_image = torch.from_numpy(test_image).type(torch.LongTensor)
train_label = torch.from_numpy(train_label).type(torch.LongTensor)
test_label = torch.from_numpy(test_label).type(torch.LongTensor)

# 读取训练集和测试集
train_image = train_image.view(-1, 2, 64, 64).float()
test_image = test_image.view(-1, 2, 64, 64).float()
train_data = torch.utils.data.TensorDataset(train_image, train_label)
test_data = torch.utils.data.TensorDataset(test_image, test_label)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=config.batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=config.batch_size,
                                          shuffle=False)



class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=12):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv0 = nn.Sequential(
            nn.Conv2d(2, 3, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)  # strides=[1,1]
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv0(x)
        out = self.conv1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def ResNet18():
    return ResNet(ResidualBlock)

# 定义模型保存函数
def save_models(epoch):
    torch.save(model, "Resnet_{}_RG.mdl".format(epoch + 1))
    print("Chekcpoint saved")



model = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model)


# train the model
best_acc = 0.0
results = pd.DataFrame(columns=['epoch', 'train_acc', 'test_acc',
                                'train_loss', 'test_loss'], index=range(0, config.epoch))
for epoch in range(config.epoch):
    print("epoch:", +epoch)
    model.train()
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)

        lossvalue = criterion(output, labels)
        lossvalue.backward()
        optimizer.step()

    # evaluate the model
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    correct_num_test = 0
    train_loss = 0.0
    train_acc = 0.0
    correct_num_train = 0
    with torch.no_grad():
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            train_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_train += pred.eq(labels.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    correct_rate = 100. * correct_num_train / len(train_loader.dataset)
    train_acc = correct_num_train / len(train_loader.dataset)
    print('Train Accuracy: {0}/{1}({2:.4}%),Train Loss: {3}'.format(
        correct_num_train, len(train_loader.dataset), correct_rate, train_loss))

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            test_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_test += pred.eq(labels.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    correct_rate = 100. * correct_num_test / len(test_loader.dataset)
    test_acc = correct_num_test / len(test_loader.dataset)


    if test_acc > best_acc:
        save_models(epoch)
        best_acc = test_acc
        best_epoch = epoch + 1
    results.iloc[epoch] = [epoch + 1, train_acc,
                           test_acc, train_loss, test_loss]
    print('Test  Accuracy: {0}/{1}({2:.4}%), Test Loss: {3}'.format(
        correct_num_test, len(test_loader.dataset), correct_rate, test_loss))


# %%
results.to_csv('resnet_RG.csv', index=False)


  0%|          | 3/1016 [00:00<00:58, 17.44it/s]

ResNet(
  (conv0): Sequential(
    (0): Conv2d(2, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (left): Sequen

100%|██████████| 1016/1016 [01:17<00:00, 13.09it/s]


Train Accuracy: 47005/65000(72.32%),Train Loss: 0.012887772414775995


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  0%|          | 3/1016 [00:00<00:54, 18.63it/s]

Chekcpoint saved
Test  Accuracy: 8092/12500(64.74%), Test Loss: 0.01752677803039551
epoch: 1


100%|██████████| 1016/1016 [01:17<00:00, 13.09it/s]


Train Accuracy: 51930/65000(79.89%),Train Loss: 0.00949064022210928


  0%|          | 3/1016 [00:00<00:54, 18.47it/s]

Chekcpoint saved
Test  Accuracy: 8978/12500(71.82%), Test Loss: 0.014338200905323029
epoch: 2


100%|██████████| 1016/1016 [01:17<00:00, 13.08it/s]


Train Accuracy: 55008/65000(84.63%),Train Loss: 0.0074002322412454165


  0%|          | 3/1016 [00:00<00:54, 18.60it/s]

Chekcpoint saved
Test  Accuracy: 9639/12500(77.11%), Test Loss: 0.012101534149646759
epoch: 3


100%|██████████| 1016/1016 [01:17<00:00, 13.07it/s]


Train Accuracy: 55394/65000(85.22%),Train Loss: 0.00697952571098621


  0%|          | 3/1016 [00:00<00:54, 18.61it/s]

Chekcpoint saved
Test  Accuracy: 9759/12500(78.07%), Test Loss: 0.011385370898246766
epoch: 4


100%|██████████| 1016/1016 [01:17<00:00, 13.05it/s]


Train Accuracy: 57459/65000(88.4%),Train Loss: 0.005485958107274312


  0%|          | 3/1016 [00:00<00:54, 18.53it/s]

Test  Accuracy: 9758/12500(78.06%), Test Loss: 0.011620780217647552
epoch: 5


100%|██████████| 1016/1016 [01:18<00:00, 13.02it/s]


Train Accuracy: 58787/65000(90.44%),Train Loss: 0.004624057704210281


  0%|          | 3/1016 [00:00<00:54, 18.62it/s]

Chekcpoint saved
Test  Accuracy: 9897/12500(79.18%), Test Loss: 0.011480068047046661
epoch: 6


100%|██████████| 1016/1016 [01:17<00:00, 13.03it/s]


Train Accuracy: 60016/65000(92.33%),Train Loss: 0.003692979823052883


  0%|          | 3/1016 [00:00<00:55, 18.24it/s]

Chekcpoint saved
Test  Accuracy: 9950/12500(79.6%), Test Loss: 0.011668679243326187
epoch: 7


100%|██████████| 1016/1016 [01:18<00:00, 12.99it/s]


Train Accuracy: 60329/65000(92.81%),Train Loss: 0.0033176752673891876


  0%|          | 3/1016 [00:00<00:55, 18.22it/s]

Test  Accuracy: 9919/12500(79.35%), Test Loss: 0.011483445432186127
epoch: 8


100%|██████████| 1016/1016 [01:18<00:00, 12.99it/s]


Train Accuracy: 61052/65000(93.93%),Train Loss: 0.0027739594891094243


  0%|          | 3/1016 [00:00<00:54, 18.61it/s]

Test  Accuracy: 9926/12500(79.41%), Test Loss: 0.012412639799118042
epoch: 9


100%|██████████| 1016/1016 [01:18<00:00, 12.97it/s]


Train Accuracy: 61484/65000(94.59%),Train Loss: 0.002436279986254298


  0%|          | 3/1016 [00:00<00:55, 18.11it/s]

Test  Accuracy: 9933/12500(79.46%), Test Loss: 0.013074889688491822
epoch: 10


100%|██████████| 1016/1016 [01:18<00:00, 12.96it/s]


Train Accuracy: 62369/65000(95.95%),Train Loss: 0.0018783906625727048


  0%|          | 3/1016 [00:00<00:55, 18.41it/s]

Test  Accuracy: 9775/12500(78.2%), Test Loss: 0.01479254652261734
epoch: 11


100%|██████████| 1016/1016 [01:18<00:00, 12.95it/s]


Train Accuracy: 62854/65000(96.7%),Train Loss: 0.001538294864933078


  0%|          | 3/1016 [00:00<00:54, 18.66it/s]

Test  Accuracy: 9689/12500(77.51%), Test Loss: 0.0162341202378273
epoch: 12


100%|██████████| 1016/1016 [01:18<00:00, 12.94it/s]


Train Accuracy: 63604/65000(97.85%),Train Loss: 0.0010359385760930868


  0%|          | 3/1016 [00:00<00:55, 18.33it/s]

Chekcpoint saved
Test  Accuracy: 10016/12500(80.13%), Test Loss: 0.01509991004705429
epoch: 13


100%|██████████| 1016/1016 [01:18<00:00, 12.93it/s]


Train Accuracy: 63439/65000(97.6%),Train Loss: 0.0011216936794897685


  0%|          | 3/1016 [00:00<00:55, 18.38it/s]

Test  Accuracy: 9814/12500(78.51%), Test Loss: 0.01745634373664856
epoch: 14


100%|██████████| 1016/1016 [01:18<00:00, 12.91it/s]


Train Accuracy: 63584/65000(97.82%),Train Loss: 0.0009540380962193012


  0%|          | 3/1016 [00:00<00:54, 18.59it/s]

Test  Accuracy: 9793/12500(78.34%), Test Loss: 0.01770775090932846
epoch: 15


100%|██████████| 1016/1016 [01:18<00:00, 12.91it/s]


Train Accuracy: 64151/65000(98.69%),Train Loss: 0.0006371441792696715


  0%|          | 3/1016 [00:00<00:54, 18.55it/s]

Test  Accuracy: 9968/12500(79.74%), Test Loss: 0.017056134798526763
epoch: 16


100%|██████████| 1016/1016 [01:18<00:00, 12.90it/s]


Train Accuracy: 63970/65000(98.42%),Train Loss: 0.0007064251139473457


  0%|          | 3/1016 [00:00<00:54, 18.49it/s]

Test  Accuracy: 9884/12500(79.07%), Test Loss: 0.01898557251691818
epoch: 17


100%|██████████| 1016/1016 [01:18<00:00, 12.86it/s]


Train Accuracy: 63941/65000(98.37%),Train Loss: 0.0007735566375777125


  0%|          | 3/1016 [00:00<00:54, 18.56it/s]

Test  Accuracy: 9881/12500(79.05%), Test Loss: 0.019118107838630678
epoch: 18


100%|██████████| 1016/1016 [01:18<00:00, 12.91it/s]


Train Accuracy: 64425/65000(99.12%),Train Loss: 0.00043772421592416674


  0%|          | 3/1016 [00:00<00:54, 18.47it/s]

Test  Accuracy: 9791/12500(78.33%), Test Loss: 0.020460833270549775
epoch: 19


100%|██████████| 1016/1016 [01:18<00:00, 12.87it/s]


Train Accuracy: 64267/65000(98.87%),Train Loss: 0.0005106958359336624


  0%|          | 3/1016 [00:00<00:55, 18.36it/s]

Test  Accuracy: 9903/12500(79.22%), Test Loss: 0.0193811878490448
epoch: 20


100%|██████████| 1016/1016 [01:18<00:00, 12.87it/s]


Train Accuracy: 64461/65000(99.17%),Train Loss: 0.0004006918352742035


  0%|          | 3/1016 [00:00<00:55, 18.42it/s]

Test  Accuracy: 9971/12500(79.77%), Test Loss: 0.019665533814430237
epoch: 21


100%|██████████| 1016/1016 [01:19<00:00, 12.84it/s]


Train Accuracy: 64741/65000(99.6%),Train Loss: 0.00023450733152791285


  0%|          | 3/1016 [00:00<00:55, 18.37it/s]

Test  Accuracy: 9958/12500(79.66%), Test Loss: 0.01916437611579895
epoch: 22


100%|██████████| 1016/1016 [01:19<00:00, 12.85it/s]


Train Accuracy: 64527/65000(99.27%),Train Loss: 0.0003358486984713146


  0%|          | 3/1016 [00:00<00:55, 18.39it/s]

Test  Accuracy: 9796/12500(78.37%), Test Loss: 0.02181279773712158
epoch: 23


100%|██████████| 1016/1016 [01:19<00:00, 12.83it/s]


Train Accuracy: 64675/65000(99.5%),Train Loss: 0.0002456864282560463


  0%|          | 3/1016 [00:00<00:57, 17.73it/s]

Test  Accuracy: 9921/12500(79.37%), Test Loss: 0.02154823073387146
epoch: 24


100%|██████████| 1016/1016 [01:19<00:00, 12.81it/s]


Train Accuracy: 64714/65000(99.56%),Train Loss: 0.0002179170373433198


  0%|          | 3/1016 [00:00<00:55, 18.33it/s]

Test  Accuracy: 9998/12500(79.98%), Test Loss: 0.020364581487178804
epoch: 25


100%|██████████| 1016/1016 [01:19<00:00, 12.80it/s]


Train Accuracy: 64727/65000(99.58%),Train Loss: 0.0002197597877552303


  0%|          | 3/1016 [00:00<00:54, 18.46it/s]

Test  Accuracy: 9862/12500(78.9%), Test Loss: 0.021448053345680236
epoch: 26


100%|██████████| 1016/1016 [01:19<00:00, 12.80it/s]


Train Accuracy: 64324/65000(98.96%),Train Loss: 0.0004415972341460964


  0%|          | 3/1016 [00:00<00:54, 18.44it/s]

Test  Accuracy: 9725/12500(77.8%), Test Loss: 0.02421370985507965
epoch: 27


100%|██████████| 1016/1016 [01:19<00:00, 12.79it/s]


Train Accuracy: 64535/65000(99.28%),Train Loss: 0.0003265098957129969


  0%|          | 3/1016 [00:00<00:54, 18.44it/s]

Test  Accuracy: 9851/12500(78.81%), Test Loss: 0.022582296476364137
epoch: 28


100%|██████████| 1016/1016 [00:47<00:00, 21.45it/s]


Train Accuracy: 64710/65000(99.55%),Train Loss: 0.00021694417046920324


  0%|          | 4/1016 [00:00<00:32, 31.10it/s]

Test  Accuracy: 9868/12500(78.94%), Test Loss: 0.022896438319683073
epoch: 29


100%|██████████| 1016/1016 [00:36<00:00, 27.69it/s]


Train Accuracy: 64489/65000(99.21%),Train Loss: 0.00037025984233030333
Test  Accuracy: 9698/12500(77.58%), Test Loss: 0.0249389559173584
