In [1]:
from IPython import get_ipython

from PIL import Image
import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

import time
import random
from math import ceil
from tqdm import tqdm
import os

import torch.utils.data as data
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import torch.optim as optim
import torch.nn.functional as F


# Set vars used after
class VarsConfig(object):
    train_data = "train_csv_G.csv"
    test_data = "val_csv_G.csv"
    pre_data = "test_csv_G.csv"

    epoch = 30
    batch_size = 64
    img_height = 64
    img_weight = 64
    seed = 666


config = VarsConfig()


# Set methods of transform
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机旋转变换
    transforms.RandomCrop(64, padding=4),  # 随机裁剪
    transforms.ToTensor(),  # 转换为tensor格式
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

test_transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为tensor格式
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])


# 读取训练集和测试集
train_csv = pd.read_csv(config.train_data, header=None).values
test_csv = pd.read_csv(config.test_data, header=None).values
val_csv = pd.read_csv(config.pre_data, header=None).values

train_label = train_csv[0:65000, 0]
train_image = train_csv[0:65000, 1:]
test_label = test_csv[0:12500, 0]
test_image= test_csv[0:12500, 1:]


train_image = torch.from_numpy(train_image).type(torch.LongTensor)
test_image = torch.from_numpy(test_image).type(torch.LongTensor)
train_label = torch.from_numpy(train_label).type(torch.LongTensor)
test_label = torch.from_numpy(test_label).type(torch.LongTensor)

# 读取训练集和测试集
train_image = train_image.view(-1, 1, 64, 64).float()
test_image = test_image.view(-1, 1, 64, 64).float()
train_data = torch.utils.data.TensorDataset(train_image, train_label)
test_data = torch.utils.data.TensorDataset(test_image, test_label)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=config.batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=config.batch_size,
                                          shuffle=False)


class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=12):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv0 = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)  # strides=[1,1]
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv0(x)
        out = self.conv1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def ResNet18():
    return ResNet(ResidualBlock)

# 定义模型保存函数
def save_models(epoch):
    torch.save(model, "Resnet_{}_G.mdl".format(epoch + 1))
    print("Chekcpoint saved")


model = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model)


# train the model
best_acc = 0.0
results = pd.DataFrame(columns=['epoch', 'train_acc', 'test_acc',
                                'train_loss', 'test_loss'], index=range(0, config.epoch))
for epoch in range(config.epoch):
    print("epoch:", +epoch)
    model.train()
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)

        lossvalue = criterion(output, labels)
        lossvalue.backward()
        optimizer.step()

    # evaluate the model
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    correct_num_test = 0
    train_loss = 0.0
    train_acc = 0.0
    correct_num_train = 0
    with torch.no_grad():
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            train_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_train += pred.eq(labels.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    correct_rate = 100. * correct_num_train / len(train_loader.dataset)
    train_acc = correct_num_train / len(train_loader.dataset)
    print('Train Accuracy: {0}/{1}({2:.4}%),Train Loss: {3}'.format(
        correct_num_train, len(train_loader.dataset), correct_rate, train_loss))

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            test_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_test += pred.eq(labels.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    correct_rate = 100. * correct_num_test / len(test_loader.dataset)
    test_acc = correct_num_test / len(test_loader.dataset)

    # 将更优的模型保存下来
    if test_acc > best_acc:
        save_models(epoch)
        best_acc = test_acc
        best_epoch = epoch + 1
    results.iloc[epoch] = [epoch + 1, train_acc,
                           test_acc, train_loss, test_loss]
    print('Test  Accuracy: {0}/{1}({2:.4}%), Test Loss: {3}'.format(
        correct_num_test, len(test_loader.dataset), correct_rate, test_loss))


# %%
results.to_csv('resnet_G.csv', index=False)

  0%|          | 2/1016 [00:00<00:51, 19.69it/s]

ResNet(
  (conv0): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (left): Sequen

100%|██████████| 1016/1016 [00:36<00:00, 27.91it/s]


Train Accuracy: 53909/65000(82.94%),Train Loss: 0.008249969831567543


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  0%|          | 4/1016 [00:00<00:31, 31.64it/s]

Chekcpoint saved
Test  Accuracy: 9746/12500(77.97%), Test Loss: 0.010939275116920472
epoch: 1


100%|██████████| 1016/1016 [00:37<00:00, 27.35it/s]


Train Accuracy: 53832/65000(82.82%),Train Loss: 0.008049675388290331


  0%|          | 2/1016 [00:00<00:50, 19.91it/s]

Test  Accuracy: 9388/12500(75.1%), Test Loss: 0.012401625385284424
epoch: 2


100%|██████████| 1016/1016 [01:17<00:00, 13.14it/s]


Train Accuracy: 54931/65000(84.51%),Train Loss: 0.007408914883778646


  0%|          | 3/1016 [00:00<01:00, 16.62it/s]

Chekcpoint saved
Test  Accuracy: 9777/12500(78.22%), Test Loss: 0.010915641431808471
epoch: 3


100%|██████████| 1016/1016 [01:17<00:00, 13.14it/s]


Train Accuracy: 58276/65000(89.66%),Train Loss: 0.0048714744079571505


  0%|          | 2/1016 [00:00<00:52, 19.20it/s]

Chekcpoint saved
Test  Accuracy: 10396/12500(83.17%), Test Loss: 0.008250920157432557
epoch: 4


100%|██████████| 1016/1016 [01:17<00:00, 13.11it/s]


Train Accuracy: 59357/65000(91.32%),Train Loss: 0.004092858835653617


  0%|          | 2/1016 [00:00<00:51, 19.80it/s]

Chekcpoint saved
Test  Accuracy: 10479/12500(83.83%), Test Loss: 0.008210834829807282
epoch: 5


100%|██████████| 1016/1016 [01:17<00:00, 13.10it/s]


Train Accuracy: 60383/65000(92.9%),Train Loss: 0.003322521037321824


  0%|          | 2/1016 [00:00<00:52, 19.40it/s]

Chekcpoint saved
Test  Accuracy: 10582/12500(84.66%), Test Loss: 0.0077498864185810085
epoch: 6


100%|██████████| 1016/1016 [01:17<00:00, 13.08it/s]


Train Accuracy: 59915/65000(92.18%),Train Loss: 0.003576256943666018


  0%|          | 3/1016 [00:00<01:01, 16.53it/s]

Test  Accuracy: 10385/12500(83.08%), Test Loss: 0.008611539257764816
epoch: 7


100%|██████████| 1016/1016 [01:17<00:00, 13.07it/s]


Train Accuracy: 61947/65000(95.3%),Train Loss: 0.002266145281235759


  0%|          | 2/1016 [00:00<00:51, 19.54it/s]

Test  Accuracy: 10537/12500(84.3%), Test Loss: 0.008568594843149185
epoch: 8


100%|██████████| 1016/1016 [01:17<00:00, 13.07it/s]


Train Accuracy: 61526/65000(94.66%),Train Loss: 0.002436714506951662


  0%|          | 3/1016 [00:00<01:00, 16.61it/s]

Test  Accuracy: 10490/12500(83.92%), Test Loss: 0.008920655435323716
epoch: 9


100%|██████████| 1016/1016 [01:18<00:00, 13.02it/s]


Train Accuracy: 62987/65000(96.9%),Train Loss: 0.0014747425009568151


  0%|          | 2/1016 [00:00<00:51, 19.68it/s]

Chekcpoint saved
Test  Accuracy: 10642/12500(85.14%), Test Loss: 0.00954374898672104
epoch: 10


100%|██████████| 1016/1016 [01:18<00:00, 13.02it/s]


Train Accuracy: 62172/65000(95.65%),Train Loss: 0.001954107905551791


  0%|          | 2/1016 [00:00<00:51, 19.77it/s]

Test  Accuracy: 10481/12500(83.85%), Test Loss: 0.01000856784105301
epoch: 11


100%|██████████| 1016/1016 [01:18<00:00, 13.01it/s]


Train Accuracy: 63678/65000(97.97%),Train Loss: 0.0010158599260191505


  0%|          | 3/1016 [00:00<01:01, 16.57it/s]

Test  Accuracy: 10620/12500(84.96%), Test Loss: 0.01015289942264557
epoch: 12


100%|██████████| 1016/1016 [01:18<00:00, 13.00it/s]


Train Accuracy: 63730/65000(98.05%),Train Loss: 0.000922601485811174


  0%|          | 3/1016 [00:00<01:01, 16.52it/s]

Test  Accuracy: 10599/12500(84.79%), Test Loss: 0.011341804695129394
epoch: 13


100%|██████████| 1016/1016 [01:18<00:00, 12.99it/s]


Train Accuracy: 63379/65000(97.51%),Train Loss: 0.001169409836549312


  0%|          | 3/1016 [00:00<01:00, 16.66it/s]

Test  Accuracy: 10531/12500(84.25%), Test Loss: 0.011563194721937179
epoch: 14


100%|██████████| 1016/1016 [01:18<00:00, 12.98it/s]


Train Accuracy: 64200/65000(98.77%),Train Loss: 0.0006039384989282833


  0%|          | 2/1016 [00:00<00:50, 19.98it/s]

Test  Accuracy: 10597/12500(84.78%), Test Loss: 0.011921052943468093
epoch: 15


100%|██████████| 1016/1016 [01:18<00:00, 12.97it/s]


Train Accuracy: 63727/65000(98.04%),Train Loss: 0.0009171650946785051


  0%|          | 2/1016 [00:00<00:50, 19.98it/s]

Test  Accuracy: 10496/12500(83.97%), Test Loss: 0.012468632054328918
epoch: 16


100%|██████████| 1016/1016 [01:18<00:00, 12.94it/s]


Train Accuracy: 64341/65000(98.99%),Train Loss: 0.0005075884439839193


  0%|          | 2/1016 [00:00<00:50, 19.99it/s]

Test  Accuracy: 10582/12500(84.66%), Test Loss: 0.012611805115938186
epoch: 17


100%|██████████| 1016/1016 [01:18<00:00, 12.94it/s]


Train Accuracy: 64278/65000(98.89%),Train Loss: 0.0005169899573549628


  0%|          | 2/1016 [00:00<00:51, 19.73it/s]

Test  Accuracy: 10567/12500(84.54%), Test Loss: 0.013281761249303818
epoch: 18


100%|██████████| 1016/1016 [01:18<00:00, 12.93it/s]


Train Accuracy: 64300/65000(98.92%),Train Loss: 0.0005214721951812793


  0%|          | 2/1016 [00:00<00:51, 19.76it/s]

Test  Accuracy: 10531/12500(84.25%), Test Loss: 0.013715630067586899
epoch: 19


100%|██████████| 1016/1016 [01:18<00:00, 12.90it/s]


Train Accuracy: 64543/65000(99.3%),Train Loss: 0.0003482364684522438


  0%|          | 2/1016 [00:00<00:51, 19.87it/s]

Test  Accuracy: 10539/12500(84.31%), Test Loss: 0.014015999935865402
epoch: 20


100%|██████████| 1016/1016 [01:18<00:00, 12.94it/s]


Train Accuracy: 64194/65000(98.76%),Train Loss: 0.0005517809982948865


  0%|          | 2/1016 [00:00<00:51, 19.84it/s]

Test  Accuracy: 10554/12500(84.43%), Test Loss: 0.014718520925045013
epoch: 21


100%|██████████| 1016/1016 [01:18<00:00, 12.90it/s]


Train Accuracy: 64418/65000(99.1%),Train Loss: 0.0004577214078547863


  0%|          | 2/1016 [00:00<00:51, 19.58it/s]

Test  Accuracy: 10554/12500(84.43%), Test Loss: 0.014587218261957169
epoch: 22


100%|██████████| 1016/1016 [01:18<00:00, 12.90it/s]


Train Accuracy: 64355/65000(99.01%),Train Loss: 0.0004805536158669453


  0%|          | 2/1016 [00:00<00:51, 19.80it/s]

Test  Accuracy: 10380/12500(83.04%), Test Loss: 0.016643224401473998
epoch: 23


100%|██████████| 1016/1016 [01:18<00:00, 12.87it/s]


Train Accuracy: 64505/65000(99.24%),Train Loss: 0.00035618715797676345


  0%|          | 2/1016 [00:00<00:50, 19.88it/s]

Test  Accuracy: 10404/12500(83.23%), Test Loss: 0.014779987487792969
epoch: 24


100%|██████████| 1016/1016 [01:18<00:00, 12.88it/s]


Train Accuracy: 64314/65000(98.94%),Train Loss: 0.0005027807484219711


  0%|          | 2/1016 [00:00<00:51, 19.87it/s]

Test  Accuracy: 10599/12500(84.79%), Test Loss: 0.014111048413515091
epoch: 25


100%|██████████| 1016/1016 [01:18<00:00, 12.87it/s]


Train Accuracy: 64724/65000(99.58%),Train Loss: 0.00020967409076043763


  0%|          | 2/1016 [00:00<00:51, 19.76it/s]

Chekcpoint saved
Test  Accuracy: 10676/12500(85.41%), Test Loss: 0.014807172012329102
epoch: 26


100%|██████████| 1016/1016 [01:19<00:00, 12.84it/s]


Train Accuracy: 64609/65000(99.4%),Train Loss: 0.00027425749655812977


  0%|          | 3/1016 [00:00<01:00, 16.65it/s]

Test  Accuracy: 10509/12500(84.07%), Test Loss: 0.015390890893936157
epoch: 27


100%|██████████| 1016/1016 [01:19<00:00, 12.83it/s]


Train Accuracy: 64605/65000(99.39%),Train Loss: 0.00027116971207209506


  0%|          | 2/1016 [00:00<00:50, 19.89it/s]

Test  Accuracy: 10425/12500(83.4%), Test Loss: 0.017525251581668854
epoch: 28


100%|██████████| 1016/1016 [01:19<00:00, 12.83it/s]


Train Accuracy: 64252/65000(98.85%),Train Loss: 0.0005018652763844539


  0%|          | 2/1016 [00:00<00:52, 19.33it/s]

Test  Accuracy: 10473/12500(83.78%), Test Loss: 0.01642757047176361
epoch: 29


100%|██████████| 1016/1016 [01:19<00:00, 12.82it/s]


Train Accuracy: 64772/65000(99.65%),Train Loss: 0.00017783190208272293
Test  Accuracy: 10610/12500(84.88%), Test Loss: 0.014789369572997093
