In [2]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import os
import argparse
import numpy as np
import time
import zipfile
from PIL import Image
from resnet import*

In [22]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [11]:
net = ResNet50()

In [12]:
net = net.to(device)

In [5]:
trainsform_train = transforms.Compose([
    transforms.RandomCrop(32, padding = 4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainsform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

train_set = torchvision.datasets.CIFAR10(root = './data', train=True, download=True, transform=trainsform_train)
test_set = torchvision.datasets.CIFAR10(root = './data', train=False, download=True, transform=trainsform_test)


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
batch_size = 128
trainloader = torch.utils.data.DataLoader(train_set, batch_size = batch_size,shuffle = True, num_workers = 2)
testloader = torch.utils.data.DataLoader(test_set, batch_size = batch_size, shuffle = True, num_workers = 2)



In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [16]:
def train():
    acc = 0
    testacc = 0
    totaltime = 0
    EPOCH = 100
    print("Start Training, Resnet-50!")  # 定义遍历数据集的次数
    
    ts = time.time()
    for epoch in range(EPOCH):
        
        print('\nEpoch: %d' % (epoch + 1))
        epochtime = 0
        epochstart=time.perf_counter()
        
        net.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0
        traintime = 0.0
        loadtime = 0.0
        
        
        loadstart=time.perf_counter()
        for i, data in enumerate(trainloader, 0):
            
            length = len(trainloader)
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            loadend=time.perf_counter()
            
            
            optimizer.zero_grad()
            trainstart=time.perf_counter()


            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            trainend=time.perf_counter()
            
            sum_loss += loss.item()
            traintime += trainend-trainstart
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += predicted.eq(labels.data).cpu().sum()
            loadtime += loadend - loadstart
            #print(trainend - trainstart)
            loadstart=time.perf_counter()
        

        epochend=time.perf_counter()
        epochtime+=epochend-epochstart
        acc = max(acc,correct/total)
        totaltime+=epochtime
        print("loss: ",sum_loss/total)
        print("accuracy: ",acc)
        #print("epochtime", epochtime)
        #print("traintime:", traintime)
        #print("dataloadtime:", loadtime)

        



        with torch.no_grad():
            correct = 0.0
            total = 0.0
            for data in testloader:
                net.eval()
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += predicted.eq(labels.data).cpu().sum()
            val_acc = correct/total
            print('val_acc',val_acc)
            t1 = time.time()
            print('t1 time: ',t1-ts)
            if val_acc>=0.91:
                te = time.time()
                print('total time: ',te-ts)
                PATH = str(4)+".pth"
            if val_acc>=0.92:
                te = time.time()
                print('total time: ',te-ts)
                break
    PATH = str(4)+".pth"
    torch.save(net, PATH) 
#    print("average running time:",totaltime/EPOCH)
    print("best training accuracy:",acc)
    print("Training Finished, TotalEPOCH=%d" % EPOCH)

In [None]:
train()

Start Training, Resnet-50!

Epoch: 1
loss:  0.014289197936058044
accuracy:  tensor(0.3467)
val_acc tensor(0.4137)
t1 time:  83.02027583122253

Epoch: 2
loss:  0.010045299323797226
accuracy:  tensor(0.5356)
val_acc tensor(0.5899)
t1 time:  166.2228651046753

Epoch: 3
loss:  0.007782453709840775
accuracy:  tensor(0.6481)
val_acc tensor(0.6821)
t1 time:  249.0328266620636

Epoch: 4
loss:  0.006214273574352265
accuracy:  tensor(0.7222)
val_acc tensor(0.7261)
t1 time:  332.1075208187103

Epoch: 5
loss:  0.004989673699140548
accuracy:  tensor(0.7790)
val_acc tensor(0.7287)
t1 time:  415.17976355552673

Epoch: 6
loss:  0.004223481510281563
accuracy:  tensor(0.8142)
val_acc tensor(0.7850)
t1 time:  498.30732321739197

Epoch: 7
loss:  0.0037611090064048765
accuracy:  tensor(0.8329)
val_acc tensor(0.8343)
t1 time:  581.421736240387

Epoch: 8
loss:  0.003397668759226799
accuracy:  tensor(0.8495)
val_acc tensor(0.8408)
t1 time:  664.4242115020752

Epoch: 9
loss:  0.0030579580506682395
accuracy:  t

loss:  0.0005641663511656225
accuracy:  tensor(0.9764)
val_acc tensor(0.9160)
t1 time:  5650.280206441879
total time:  5650.280246734619

Epoch: 69
loss:  0.0005473402734845876
accuracy:  tensor(0.9764)
val_acc tensor(0.9155)
t1 time:  5733.3425879478455
total time:  5733.342626571655

Epoch: 70
loss:  0.0005042093744315207
accuracy:  tensor(0.9781)
val_acc tensor(0.9159)
t1 time:  5816.393316745758
total time:  5816.393356323242

Epoch: 71


In [6]:
!mkdir "./data_test"

In [7]:
!unzip "test_imgs.zip" -d "./data_test"

Archive:  test_imgs.zip
  inflating: ./data_test/Test_img/8/04.jpg  
  inflating: ./data_test/Test_img/3/01.jpg  
  inflating: ./data_test/Test_img/2/04.jpg  
  inflating: ./data_test/Test_img/5/00.jpg  
  inflating: ./data_test/Test_img/4/02.jpg  
  inflating: ./data_test/Test_img/7/04.jpg  
  inflating: ./data_test/Test_img/3/00.jpg  
  inflating: ./data_test/Test_img/5/04.jpg  
  inflating: ./data_test/Test_img/5/03.jpg  
  inflating: ./data_test/Test_img/2/00.jpg  
  inflating: ./data_test/Test_img/5/02.jpg  
  inflating: ./data_test/Test_img/3/02.jpg  
  inflating: ./data_test/Test_img/8/00.jpg  
  inflating: ./data_test/Test_img/4/04.jpg  
  inflating: ./data_test/Test_img/7/03.jpg  
  inflating: ./data_test/Test_img/1/04.jpg  
  inflating: ./data_test/Test_img/4/03.jpg  
  inflating: ./data_test/Test_img/8/03.jpg  
  inflating: ./data_test/Test_img/8/02.jpg  
  inflating: ./data_test/Test_img/1/01.jpg  
  inflating: ./data_test/Test_img/2/02.jpg  
  inflating: ./data_test/Test_i

In [31]:
net_path = "3.pth"
net = torch.load(net_path)
net = net.to(device)

In [9]:
trainsform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

In [32]:
def evaluation():
    img_index = ["00.jpg", "01.jpg", "02.jpg", "03.jpg", "04.jpg"]
    net.eval()
    correct = 0
    total = 50
    for i in range(10):
        for iidex in img_index:
            img_path = "./data_test/Test_img" + "/"+str(i)+"/" + iidex
            #img_path += "/"+str(i)+"/" + iidex
            #print(img_path)
            img = Image.open(img_path)
            img = img.resize((32, 32), Image.BILINEAR)
            img = np.array(img)
            img = trainsform_val(img)
            img = torch.unsqueeze(img, 0)
            img = img.to(device)
            label = i
            outputs = net(img)
            outputs = F.softmax(outputs, dim = 1)
            pred = torch.max(outputs, dim = 1)
            if pred[1].item() == label:
                correct += 1
    print("The accuracy is ", correct / total)

In [33]:
evaluation()

The accuracy is  0.94


In [35]:
np.std([1.0,0.96,0.92,0.94,0.94])

0.027129319932501072