In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms, datasets, models
from d2l import torch as d2l
import os
import cv2
import time
import random
from model.residual_attention_network import ResidualAttentionModel_92_32input_update as ResidualAttentionModel
from model.residual_attention_network import ResidualModel_92_32input_update as ResNet

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything(42)

In [2]:
# 数据增强
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop((32, 32), padding=4), 
    transforms.ToTensor()
])
test_transform = transforms.Compose([
    transforms.ToTensor()
])
# 加载CIFAR-10 Dataset
train_dataset = datasets.CIFAR10(root='./data/',
                               train=True,
                               transform=transform,
                               download=False)

test_dataset = datasets.CIFAR10(root='./data/',
                              train=False,
                              transform=test_transform)

# Data Loader (Input Pipeline)
train_iter = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=64, # 64
                                           shuffle=True, num_workers=8)
test_iter = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=20,
                                          shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ResNet() # 初始化模型
model.to(device)
print(model) # 打印模型结构

ResidualModel_92_32input_update(
  (conv1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (residual_block1): ResidualBlock(
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (conv4): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (residual_block2): ResidualBlock(
    (bn1): BatchNorm2d(128, eps=1e-05, 

In [4]:
# 记录后期可视化数据
vis_data = {
    'class_name': ['plane', 'car', 'bird', 'cat',
                   'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
    'eval_acc': [],
    'eval_class_acc': [],
    'eval_loss': [],
    'train_loss': []
}

In [5]:
def test(model, test_loader, criterion, btrain=False, model_file='./best_model_resnet.pth'):
    # Test
    if not btrain:
        model.load_state_dict(torch.load(model_file))
    model.eval()

    correct = 0
    total = 0
    #
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.data).sum()
        #
        c = (predicted == labels.data).squeeze()
        labels = labels.detach().cpu().numpy()
        for i in range(labels.shape[0]):
            label = labels[i]
            class_correct[label] += c[i]
            class_total[label] += 1

    print('Accuracy of the model on the test images:', float(correct)/total)
    tmp_acc = []
    for i in range(10):
        print(
            f'Accuracy of {classes[i] :.5s} : {class_correct[i] / class_total[i] : .4f}')
        tmp_acc.append((class_correct[i] / class_total[i]).detach().cpu().numpy().tolist())

    if btrain:
        vis_data['eval_loss'] += [float(loss.detach().cpu().numpy())]
        vis_data['eval_acc'] += [float(correct)/total]
        vis_data['eval_class_acc'].append(tmp_acc)

    return correct / total


In [6]:
lr = 0.1
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr,
                      momentum=0.9, nesterov=True, weight_decay=0.0001)
is_train = True
acc_best = 0
total_epoch = 100

if is_train:
    # Training
    for epoch in range(total_epoch):
        model.train()
        tims = time.time()
        for i, (images, labels) in enumerate(train_iter):
            images = images.to(device)
            # print(images.data)
            labels = labels.to(device)

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print("Epoch [%d/%d], Loss: %.4f" %
              (epoch+1, total_epoch, loss.item()))
        print('the epoch takes time:', time.time()-tims)
        vis_data['train_loss'] += [loss.item()]

        print('evaluate test set:')
        acc = test(model, test_iter, criterion, btrain=True)
        if acc > acc_best:
            acc_best = acc
            print('current best acc,', acc_best)
            torch.save(model.state_dict(), './best_model_resnet.pth')

        # Decaying Learning Rate
        if (epoch+1) / float(total_epoch) == 0.3 or (epoch+1) / float(total_epoch) == 0.6 or (epoch+1) / float(total_epoch) == 0.9:
            lr /= 10
            print('reset learning rate to:', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
                print(param_group['lr'])
    # Save the Model
    torch.save(model.state_dict(), 'last_model_resnet.pth')

else:
    test(model, test_iter, criterion, btrain=False)


Epoch [1/100], Loss: 1.0811
the epoch takes time: 61.3788366317749
evaluate test set:
Accuracy of the model on the test images: 0.3841
Accuracy of plane :  0.4540
Accuracy of car :  0.5100
Accuracy of bird :  0.4550
Accuracy of cat :  0.0820
Accuracy of deer :  0.1560
Accuracy of dog :  0.1470
Accuracy of frog :  0.8480
Accuracy of horse :  0.2000
Accuracy of ship :  0.8510
Accuracy of truck :  0.1380
current best acc, tensor(0.3841, device='cuda:0')
Epoch [2/100], Loss: 1.0092
the epoch takes time: 55.825268030166626
evaluate test set:
Accuracy of the model on the test images: 0.4381
Accuracy of plane :  0.2320
Accuracy of car :  0.7170
Accuracy of bird :  0.7780
Accuracy of cat :  0.3840
Accuracy of deer :  0.1040
Accuracy of dog :  0.2760
Accuracy of frog :  0.3830
Accuracy of horse :  0.2370
Accuracy of ship :  0.8690
Accuracy of truck :  0.4010
current best acc, tensor(0.4381, device='cuda:0')
Epoch [3/100], Loss: 1.3854
the epoch takes time: 54.93459749221802
evaluate test set:
A

In [7]:
import json
if is_train:
    json_str = json.dumps(vis_data)
    with open('vis_data_resnet.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json_str)