In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import CIFAR10

In [2]:
class MYCIFAR10(Dataset):
    """CIFAR-10 dataset parsed from raw binary batch files found under ``root``.

    Every record in a batch file is read as 1 label byte followed by
    3 * 32 * 32 pixel bytes stored channel-first. After loading:

    * ``self.data``   -- ``uint8`` array of shape ``(N, 32, 32, 3)`` (HWC).
    * ``self.labels`` -- ``int64`` array of shape ``(N,)``.

    Args:
        root: Directory containing ``data_batch_1`` .. ``data_batch_5``
            (training split) or ``test_batch`` (test split).
        train: Load the training split when true, the test split otherwise.
        transforms: Optional callable applied to the PIL image in
            ``__getitem__``.
        target_transforms: Optional callable applied to the label in
            ``__getitem__``.

    Raises:
        ValueError: If the batch files are empty or their combined size is not
            a whole number of records.
    """

    _IMG_SHAPE = (3, 32, 32)            # on-disk layout: channel, row, column
    _RECORD_BYTES = 1 + 3 * 32 * 32     # one label byte + pixel payload
    _TRAIN_FILES = ("data_batch_1", "data_batch_2", "data_batch_3",
                    "data_batch_4", "data_batch_5")
    _TEST_FILES = ("test_batch",)

    def __init__(self, root, train=True, transforms=None, target_transforms=None):
        self.cls = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship' ,'truck']
        self.train = train
        self.transforms = transforms
        self.target_transforms = target_transforms
        self.root = root

        file_names = self._TRAIN_FILES if self.train else self._TEST_FILES
        # imgs: np.uint8, labels: np.int64
        self.labels, self.data = self._load_batches(file_names)

    def _load_batches(self, file_names):
        """Read the given batch files and return ``(labels, images)`` arrays."""
        chunks = []
        for name in file_names:
            # ``with`` guarantees the handle is closed even if reading fails.
            with open(os.path.join(self.root, name), "rb") as fh:
                chunks.append(fh.read())
        raw = b"".join(chunks)

        if len(raw) == 0 or len(raw) % self._RECORD_BYTES != 0:
            raise ValueError(
                "CIFAR-10 batch data under {!r} has {} bytes, which is not a "
                "positive multiple of the {}-byte record size".format(
                    self.root, len(raw), self._RECORD_BYTES))

        # One row per record; column 0 is the label, the rest are pixels.
        records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, self._RECORD_BYTES)
        labels = records[:, 0].astype(np.int64)
        images = records[:, 1:].reshape((-1,) + self._IMG_SHAPE)
        # CHW -> HWC so each sample can be handed to PIL directly; the
        # contiguous copy also detaches the result from the read-only buffer.
        images = np.ascontiguousarray(images.transpose((0, 2, 3, 1)))
        return labels, images

    def __len__(self):
        """Return the number of samples that were actually loaded."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        The image is a PIL image built from the HWC ``uint8`` array, passed
        through ``self.transforms`` when set; the target is the label, passed
        through ``self.target_transforms`` when set.
        """
        img, target = self.data[idx], self.labels[idx]

        img = Image.fromarray(img)

        if self.transforms is not None:
            img = self.transforms(img)
        if self.target_transforms is not None:
            target = self.target_transforms(target)

        return img, target

    def show(self):
        """Plot a 3x5 grid of randomly chosen samples titled with their class name."""
        for i in range(15):
            plt.subplot(3, 5, i + 1)
            # Index with the random draw (not the loop counter) so that the
            # grid is an actual random sample of the split.
            idx = random.randint(0, len(self.data) - 1)
            img, label = Image.fromarray(self.data[idx]), self.labels[idx]
            plt.axis('off')
            plt.imshow(img)
            plt.title(self.cls[label])
        plt.show()

# Preprocessing pipelines. Random crop/flip augmentation is applied to the
# training split only; both splits are converted to tensors and normalized
# with mean 0.5 and std 0.5 on each of the three channels.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Datasets are read from ./data and are expected to be present already
# (download is disabled).
train_dataset = CIFAR10(root="./data", train=True, transform=transform_train, download=False)
test_dataset = CIFAR10(root="./data", train=False, transform=transform_test, download=False)

# Only the training loader shuffles; both use batches of 64 and 2 workers.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

In [3]:
def resnet(num_classes=10):
    """Build a randomly initialised ResNet-34 with its stem and head replaced.

    The torchvision model is modified in place:

    * ``conv1``   -> 3x3, stride-1 convolution (3 -> 64 channels, no bias).
    * ``bn1``     -> fresh ``BatchNorm2d(64)``.
    * ``maxpool`` -> 1x1, stride-1 pooling, i.e. no spatial reduction.
    * ``avgpool`` -> 4x4 average pooling; sized for 32x32 inputs, where the
      remaining stride-2 stages leave a 4x4 map (32 / 8).
    * ``fc``      -> linear classifier with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits. Defaults to 10.

    Returns:
        The modified ``torchvision`` ResNet module.
    """
    net = models.resnet34(pretrained=False)
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.bn1 = nn.BatchNorm2d(64)
    net.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
    net.avgpool = nn.AvgPool2d(4, stride=1)
    # Read the feature width from the existing head instead of hard-coding
    # 512 * expansion, so the value cannot drift if the backbone is swapped.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net

# Model / device configuration.
load = False     # restore weights from 'net.pkl' before training
gpu_en = True    # use CUDA when it is available
gpus = [2]       # CUDA device ids; gpus[0] is the primary device

net = resnet()
if load:
    # map_location='cpu' lets a checkpoint written on a GPU be restored on any
    # machine; the model is moved to its final device below.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load('net.pkl', map_location='cpu')
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module."; strip it so they load into the bare network.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }
    net.load_state_dict(state_dict)

# Alternative that is not used here: restrict the visible GPUs through
# os.environ["CUDA_VISIBLE_DEVICES"] and select the plain "cuda" device.
device = torch.device("cpu")
if gpu_en and torch.cuda.is_available():
    device = torch.device("cuda:{}".format(gpus[0]))
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net, device_ids=gpus)
# Move to the resolved device (not the raw gpus[0] index) so that the CPU
# fallback actually works when CUDA is unavailable.
net.to(device)
print(net)

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace)
    (maxpool): MaxPool2d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [4]:
num_epochs = 1
save = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.001)
# Take the device from the model itself so batches always land where the
# weights live, whether or not the GPU path was enabled.
model_device = next(net.parameters()).device

for epoch in range(num_epochs):
    # Re-enter training mode every epoch because evaluation below switches
    # the network to eval mode.
    net.train()
    for i, (inputs, labels) in enumerate(train_dataloader):
        # Place the training batch on the primary device.
        inputs = inputs.to(model_device)
        labels = labels.to(model_device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        y_pred = torch.max(outputs, 1)[1]
        acc = float((y_pred == labels).sum().item()) / len(labels)
        print('epc:%d,step:%d,loss:%f,acc:%f' % (epoch, i, loss.item(), acc))

    # Evaluation: eval mode freezes BatchNorm statistics and no_grad avoids
    # building autograd graphs for the test batches.
    net.eval()
    test_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            # Place the evaluation batch on the primary device.
            inputs = inputs.to(model_device)
            labels = labels.to(model_device)
            outputs = net(inputs)
            # criterion returns the batch mean; weight it by the batch size so
            # that dividing by the dataset size gives a per-sample average.
            test_loss += criterion(outputs, labels).item() * labels.size(0)
    test_loss /= len(test_dataloader.dataset)
    print('\nVAL set: Average loss: {:.4f}\n'.format(
        test_loss))

if save:
    # Save the underlying module's weights so the checkpoint keys carry no
    # "module." prefix and load into a network that is not wrapped.
    model_to_save = net.module if isinstance(net, nn.DataParallel) else net
    torch.save(model_to_save.state_dict(), 'net.pkl')

epc:0,step:0,loss:2.529641,acc:0.109375
epc:0,step:1,loss:2.856798,acc:0.187500
epc:0,step:2,loss:2.866618,acc:0.140625
epc:0,step:3,loss:2.871935,acc:0.156250
epc:0,step:4,loss:2.433586,acc:0.109375
epc:0,step:5,loss:3.054872,acc:0.125000
epc:0,step:6,loss:2.285182,acc:0.171875
epc:0,step:7,loss:2.488332,acc:0.203125
epc:0,step:8,loss:2.324755,acc:0.140625
epc:0,step:9,loss:2.410621,acc:0.171875
epc:0,step:10,loss:2.431606,acc:0.109375
epc:0,step:11,loss:2.429797,acc:0.281250
epc:0,step:12,loss:2.318946,acc:0.234375
epc:0,step:13,loss:2.049551,acc:0.265625
epc:0,step:14,loss:2.199800,acc:0.093750
epc:0,step:15,loss:2.163374,acc:0.250000
epc:0,step:16,loss:2.321077,acc:0.171875
epc:0,step:17,loss:2.007106,acc:0.203125
epc:0,step:18,loss:2.076537,acc:0.187500
epc:0,step:19,loss:2.124095,acc:0.218750
epc:0,step:20,loss:2.074638,acc:0.265625
epc:0,step:21,loss:2.268789,acc:0.296875
epc:0,step:22,loss:2.101147,acc:0.218750
epc:0,step:23,loss:2.216924,acc:0.171875
epc:0,step:24,loss:2.11872

epc:0,step:198,loss:1.508167,acc:0.437500
epc:0,step:199,loss:1.587357,acc:0.453125
epc:0,step:200,loss:1.653886,acc:0.343750
epc:0,step:201,loss:1.553892,acc:0.359375
epc:0,step:202,loss:1.680575,acc:0.390625
epc:0,step:203,loss:1.662138,acc:0.343750
epc:0,step:204,loss:1.790108,acc:0.375000
epc:0,step:205,loss:1.631426,acc:0.406250
epc:0,step:206,loss:1.790998,acc:0.375000
epc:0,step:207,loss:1.829654,acc:0.312500
epc:0,step:208,loss:1.611768,acc:0.375000
epc:0,step:209,loss:1.629990,acc:0.437500
epc:0,step:210,loss:1.656806,acc:0.390625
epc:0,step:211,loss:1.779224,acc:0.343750
epc:0,step:212,loss:1.507980,acc:0.421875
epc:0,step:213,loss:1.653992,acc:0.312500
epc:0,step:214,loss:1.728939,acc:0.359375
epc:0,step:215,loss:1.889975,acc:0.296875
epc:0,step:216,loss:1.834121,acc:0.375000
epc:0,step:217,loss:1.604552,acc:0.359375
epc:0,step:218,loss:1.633944,acc:0.437500
epc:0,step:219,loss:1.459094,acc:0.406250
epc:0,step:220,loss:1.648626,acc:0.437500
epc:0,step:221,loss:1.800475,acc:0

epc:0,step:394,loss:1.651608,acc:0.453125
epc:0,step:395,loss:1.624426,acc:0.406250
epc:0,step:396,loss:1.674324,acc:0.406250
epc:0,step:397,loss:1.412560,acc:0.453125
epc:0,step:398,loss:1.430778,acc:0.500000
epc:0,step:399,loss:1.532195,acc:0.437500
epc:0,step:400,loss:1.486508,acc:0.500000
epc:0,step:401,loss:1.381965,acc:0.468750
epc:0,step:402,loss:1.696637,acc:0.437500
epc:0,step:403,loss:1.414034,acc:0.484375
epc:0,step:404,loss:1.451959,acc:0.468750
epc:0,step:405,loss:1.381572,acc:0.515625
epc:0,step:406,loss:1.739638,acc:0.375000
epc:0,step:407,loss:1.303762,acc:0.484375
epc:0,step:408,loss:1.482362,acc:0.468750
epc:0,step:409,loss:1.475715,acc:0.343750
epc:0,step:410,loss:1.466106,acc:0.468750
epc:0,step:411,loss:1.303799,acc:0.546875
epc:0,step:412,loss:1.258653,acc:0.531250
epc:0,step:413,loss:1.398704,acc:0.484375
epc:0,step:414,loss:1.643488,acc:0.484375
epc:0,step:415,loss:1.412202,acc:0.453125
epc:0,step:416,loss:1.447120,acc:0.437500
epc:0,step:417,loss:1.576023,acc:0

epc:0,step:590,loss:1.577927,acc:0.437500
epc:0,step:591,loss:1.106823,acc:0.531250
epc:0,step:592,loss:1.330951,acc:0.468750
epc:0,step:593,loss:1.790042,acc:0.296875
epc:0,step:594,loss:1.723180,acc:0.390625
epc:0,step:595,loss:1.418445,acc:0.453125
epc:0,step:596,loss:1.245846,acc:0.593750
epc:0,step:597,loss:1.162270,acc:0.609375
epc:0,step:598,loss:1.195924,acc:0.625000
epc:0,step:599,loss:1.249360,acc:0.546875
epc:0,step:600,loss:1.343473,acc:0.500000
epc:0,step:601,loss:1.241715,acc:0.562500
epc:0,step:602,loss:1.226144,acc:0.562500
epc:0,step:603,loss:1.206354,acc:0.609375
epc:0,step:604,loss:1.270456,acc:0.515625
epc:0,step:605,loss:1.171633,acc:0.515625
epc:0,step:606,loss:1.401724,acc:0.531250
epc:0,step:607,loss:1.503014,acc:0.453125
epc:0,step:608,loss:1.080057,acc:0.625000
epc:0,step:609,loss:1.252995,acc:0.531250
epc:0,step:610,loss:1.328882,acc:0.531250
epc:0,step:611,loss:1.257773,acc:0.546875
epc:0,step:612,loss:1.050150,acc:0.578125
epc:0,step:613,loss:1.541835,acc:0