In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from VGG_hc import VGG19
import time
import os
# Restrict this process to physical GPUs 0 and 1 (they are then addressed as cuda:0 / cuda:1).
# This only has an effect if set before CUDA is first initialized, so keep it above any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Data preprocessing and loading

In [2]:
# Preprocess image data with torchvision: random augmentation for training,
# deterministic resize + center crop for validation.

# Standard ImageNet per-channel statistics; they match the ImageNet-pretrained
# VGG19 checkpoint loaded in the next cell.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Data locations and loader settings (tune here rather than inside the calls below).
TRAIN_DIR = 'train/'
VAL_DIR = 'valid/'
BATCH_SIZE = 20
NUM_WORKERS = 4

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomAffine(degrees=15, scale=(0.8, 1.5)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# ImageFolder reads one sub-directory per class under the given root.
trainset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                                          num_workers=NUM_WORKERS)

valset = torchvision.datasets.ImageFolder(root=VAL_DIR, transform=val_transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False,
                                        num_workers=NUM_WORKERS)

In [3]:
# Number of *batches* per epoch in each loader (len of a DataLoader), not number of images.
print(len(trainloader))
print(len(valloader))
# Use the first visible GPU if CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Build the network and initialize it from the pretrained VGG19 checkpoint.
model = VGG19(num_classes=40, init_weights=False)
model_dict = model.state_dict()
# map_location='cpu' lets the checkpoint load on machines without a GPU;
# the model is moved to `device` later.
state_dict = torch.load('vgg19-dcbb9e9d.pth', map_location='cpu')
# Keep only tensors whose name AND shape match the model. A same-named layer with a
# different size (e.g. a classifier head sized for 40 classes instead of the checkpoint's)
# is then left as constructed instead of raising a size-mismatch error in load_state_dict.
new_state_dict = {k: v for k, v in state_dict.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
# load_state_dict below always reports that all keys matched, because model_dict is the
# model's own complete state dict; report how many pretrained tensors were actually used.
print('Transferred {}/{} pretrained tensors'.format(len(new_state_dict), len(model_dict)))
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)

859
119


<All keys matched successfully>

## Training and validation

In [4]:
# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print('We are using', torch.cuda.device_count(), 'GPUs!')
    model = nn.DataParallel(model)
model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# Per-epoch history: mean train loss, train accuracy (%), validation accuracy (%).
Accuracy = []
Loss = []
Val_Accuracy = []
BEST_VAL_ACC = 0.
LOG_INTERVAL = 20  # print running train stats every LOG_INTERVAL batches


def _save_curve(values, title, ylabel, path):
    """Plot one value per epoch against the epoch index and save the figure to `path`."""
    fig, ax = plt.subplots(figsize=(11, 8))
    ax.plot(range(len(values)), values)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    fig.savefig(path)
    plt.close(fig)  # release the figure; a new one is created every epoch


# Train
since = time.time()
for epoch in range(15):
    train_loss = 0.      # sum of per-batch mean losses over the whole epoch
    train_accuracy = 0.  # correctly classified training samples over the whole epoch
    total = 0.           # training samples seen this epoch
    num_batches = 0
    run_accuracy = 0.    # correct predictions in the current logging window
    run_loss = 0.        # summed batch losses in the current logging window
    run_total = 0        # samples in the current logging window
    model.train()
    for i, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)
        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outs = model(images)
        loss = criterion(outs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        _, prediction = torch.max(outs, 1)
        batch_correct = (prediction == labels).sum().item()
        n = labels.size(0)

        # Epoch totals include every batch, including the tail after the last full log window.
        train_loss += batch_loss
        train_accuracy += batch_correct
        total += n
        num_batches += 1

        run_loss += batch_loss
        run_accuracy += batch_correct
        run_total += n
        if i % LOG_INTERVAL == LOG_INTERVAL - 1:
            print('epoch {},iter {},train accuracy: {:.4f}%   loss:  {:.4f}'.format(
                epoch, i + 1, 100 * run_accuracy / run_total, run_loss / LOG_INTERVAL))
            run_accuracy, run_loss, run_total = 0., 0., 0
    # CrossEntropyLoss (default reduction='mean') already averages within a batch,
    # so the epoch loss is the mean of the per-batch losses, not a per-sample sum.
    Loss.append(train_loss / num_batches if num_batches else 0.)
    Accuracy.append(100 * train_accuracy / total if total else 0.)

    # Visualize training progress so far.
    _save_curve(Accuracy, "Average trainset accuracy vs epochs", "Avg. train. accuracy",
                'Train_accuracy_vs_epochs.png')
    _save_curve(Loss, "Average trainset loss vs epochs", "Current loss",
                'loss_vs_epochs.png')

    # Validate
    model.eval()
    print('waitting for Val...')
    accuracy = 0.
    total = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.to(device)
            labels = labels.to(device)
            out = model(images)
            _, prediction = torch.max(out, 1)
            total += labels.size(0)
            accuracy += (prediction == labels).sum().item()
    acc = 100. * accuracy / total if total else 0.
    print('epoch {}  The ValSet accuracy is {:.4f}% \n'.format(epoch, acc))
    Val_Accuracy.append(acc)
    if acc > BEST_VAL_ACC:
        print('Find Better Model and Saving it...')
        os.makedirs('checkpoint', exist_ok=True)
        # NOTE(review): if the model was wrapped in nn.DataParallel above, the saved keys
        # carry a 'module.' prefix; load into an identically wrapped model or strip it.
        torch.save(model.state_dict(), './checkpoint/VGG19_Cats_Dogs_hc.pth')
        BEST_VAL_ACC = acc
        print('Saved!')

    _save_curve(Val_Accuracy, "Average Val accuracy vs epochs", "Current Val accuracy",
                'val_accuracy_vs_epoch.png')
    # Printed every epoch: cumulative wall-clock time so far and best validation accuracy.
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Now the best val Acc is {:.4f}%'.format(BEST_VAL_ACC))

We are using 2 GPUs!
epoch 0,iter 20,train accuracy: 3.0000%   loss:  3.7907
epoch 0,iter 40,train accuracy: 3.2500%   loss:  3.7315
epoch 0,iter 60,train accuracy: 6.0000%   loss:  3.6596
epoch 0,iter 80,train accuracy: 5.2500%   loss:  3.6144
epoch 0,iter 100,train accuracy: 7.2500%   loss:  3.5496
epoch 0,iter 120,train accuracy: 10.0000%   loss:  3.5219
epoch 0,iter 140,train accuracy: 8.7500%   loss:  3.4580
epoch 0,iter 160,train accuracy: 16.2500%   loss:  3.3732
epoch 0,iter 180,train accuracy: 18.2500%   loss:  3.3297
epoch 0,iter 200,train accuracy: 18.0000%   loss:  3.2580
epoch 0,iter 220,train accuracy: 21.0000%   loss:  3.1486
epoch 0,iter 240,train accuracy: 22.0000%   loss:  3.0233
epoch 0,iter 260,train accuracy: 22.0000%   loss:  2.9848
epoch 0,iter 280,train accuracy: 25.5000%   loss:  2.8335
epoch 0,iter 300,train accuracy: 30.7500%   loss:  2.6605
epoch 0,iter 320,train accuracy: 30.7500%   loss:  2.5748
epoch 0,iter 340,train accuracy: 31.0000%   loss:  2.5478
epo

epoch 3,iter 160,train accuracy: 63.5000%   loss:  1.2581
epoch 3,iter 180,train accuracy: 65.7500%   loss:  1.1335
epoch 3,iter 200,train accuracy: 66.7500%   loss:  1.0965
epoch 3,iter 220,train accuracy: 64.2500%   loss:  1.1960
epoch 3,iter 240,train accuracy: 63.5000%   loss:  1.2053
epoch 3,iter 260,train accuracy: 68.0000%   loss:  1.1239
epoch 3,iter 280,train accuracy: 61.5000%   loss:  1.3479
epoch 3,iter 300,train accuracy: 64.5000%   loss:  1.2294
epoch 3,iter 320,train accuracy: 66.2500%   loss:  1.1409
epoch 3,iter 340,train accuracy: 69.7500%   loss:  0.9713
epoch 3,iter 360,train accuracy: 68.0000%   loss:  1.1109
epoch 3,iter 380,train accuracy: 66.7500%   loss:  1.1992
epoch 3,iter 400,train accuracy: 66.0000%   loss:  1.1366
epoch 3,iter 420,train accuracy: 65.0000%   loss:  1.2213
epoch 3,iter 440,train accuracy: 62.7500%   loss:  1.2269
epoch 3,iter 460,train accuracy: 62.7500%   loss:  1.2248
epoch 3,iter 480,train accuracy: 64.7500%   loss:  1.1966
epoch 3,iter 5

epoch 6,iter 340,train accuracy: 73.5000%   loss:  1.0021
epoch 6,iter 360,train accuracy: 67.7500%   loss:  1.0790
epoch 6,iter 380,train accuracy: 73.5000%   loss:  0.9280
epoch 6,iter 400,train accuracy: 66.2500%   loss:  1.0716
epoch 6,iter 420,train accuracy: 68.0000%   loss:  1.0864
epoch 6,iter 440,train accuracy: 66.2500%   loss:  1.0716
epoch 6,iter 460,train accuracy: 67.0000%   loss:  1.1594
epoch 6,iter 480,train accuracy: 65.7500%   loss:  1.1047
epoch 6,iter 500,train accuracy: 70.0000%   loss:  1.0838
epoch 6,iter 520,train accuracy: 69.7500%   loss:  1.0506
epoch 6,iter 540,train accuracy: 73.0000%   loss:  0.9316
epoch 6,iter 560,train accuracy: 68.2500%   loss:  1.0423
epoch 6,iter 580,train accuracy: 68.0000%   loss:  1.0735
epoch 6,iter 600,train accuracy: 70.2500%   loss:  0.9385
epoch 6,iter 620,train accuracy: 69.0000%   loss:  1.1381
epoch 6,iter 640,train accuracy: 72.2500%   loss:  1.0079
epoch 6,iter 660,train accuracy: 66.7500%   loss:  1.1137
epoch 6,iter 6

epoch 9,iter 480,train accuracy: 72.0000%   loss:  0.9768
epoch 9,iter 500,train accuracy: 70.7500%   loss:  1.0123
epoch 9,iter 520,train accuracy: 72.2500%   loss:  0.9796
epoch 9,iter 540,train accuracy: 73.2500%   loss:  0.9878
epoch 9,iter 560,train accuracy: 66.0000%   loss:  1.0862
epoch 9,iter 580,train accuracy: 70.7500%   loss:  0.9633
epoch 9,iter 600,train accuracy: 71.2500%   loss:  0.9782
epoch 9,iter 620,train accuracy: 72.7500%   loss:  0.9036
epoch 9,iter 640,train accuracy: 71.2500%   loss:  0.8959
epoch 9,iter 660,train accuracy: 70.0000%   loss:  0.9515
epoch 9,iter 680,train accuracy: 69.7500%   loss:  1.0184
epoch 9,iter 700,train accuracy: 70.7500%   loss:  1.0734
epoch 9,iter 720,train accuracy: 75.2500%   loss:  0.8940
epoch 9,iter 740,train accuracy: 70.2500%   loss:  1.0060
epoch 9,iter 760,train accuracy: 71.0000%   loss:  0.9287
epoch 9,iter 780,train accuracy: 71.7500%   loss:  1.0171
epoch 9,iter 800,train accuracy: 73.2500%   loss:  0.9106
epoch 9,iter 8

epoch 12,iter 620,train accuracy: 73.7500%   loss:  0.9231
epoch 12,iter 640,train accuracy: 73.5000%   loss:  0.9335
epoch 12,iter 660,train accuracy: 72.5000%   loss:  0.9258
epoch 12,iter 680,train accuracy: 70.2500%   loss:  0.9783
epoch 12,iter 700,train accuracy: 72.5000%   loss:  0.9746
epoch 12,iter 720,train accuracy: 72.7500%   loss:  0.9214
epoch 12,iter 740,train accuracy: 73.0000%   loss:  0.9050
epoch 12,iter 760,train accuracy: 73.2500%   loss:  0.9416
epoch 12,iter 780,train accuracy: 77.5000%   loss:  0.8291
epoch 12,iter 800,train accuracy: 75.5000%   loss:  0.8622
epoch 12,iter 820,train accuracy: 73.2500%   loss:  0.9702
epoch 12,iter 840,train accuracy: 75.5000%   loss:  0.8372
waitting for Val...
epoch 12  The ValSet accuracy is 86.5263% 

Find Better Model and Saving it...
Saved!
Training complete in 105m 15s
Now the best val Acc is 86.5263%
epoch 13,iter 20,train accuracy: 72.2500%   loss:  0.8896
epoch 13,iter 40,train accuracy: 72.2500%   loss:  0.8678
epoch 1