In [1]:
import os
import pickle
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
from google.colab import drive
import torch.optim as opt
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
from tqdm import tqdm
from torchvision import datasets
import torch.nn.functional as F
drive.mount('/content/drive')

Mounted at /content/drive


In [38]:
class BLOCK(nn.Module) :
  def __init__(self,inplanes:int, planes:int, stride:int, downsample=None) :
    super(BLOCK,self).__init__() #繼承父類
    self.conv1 = nn.Conv2d(inplanes,planes,3,stride,padding=1,bias=False) #子模塊第一層
    self.bn1 = nn.BatchNorm2d(planes) #子模塊第一層標準化
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(planes,planes,3,stride=1,padding=1,bias=False)
    self.bn2 = nn.BatchNorm2d(planes)
    self.stride = stride
    self.downsample = downsample

  def forward(self,x) :
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample is not None :
      identity = self.downsample(x)
    out = out + identity
    out = self.relu(out)
    return out

In [39]:
class ResNet(nn.Module) :
  def __init__(self,lays:list,num_classes:int) :
    super(ResNet,self).__init__()
    self.inplanes = 64
    # self.conv1 = nn.Conv2d(3,self.inplanes,7,stride=2,padding=3)
    self.conv1 = nn.Conv2d(3,self.inplanes,3,stride=1,padding=1)
    #稍微修改一下第一次的捲積操作，便於保留特徵
    self.bn1 = nn.BatchNorm2d(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3,stride=1,padding=1) #後面池化操作進行刪除
    self.lay1 = self._make_lay(64,lays[0],stride=1)
    self.lay2 = self._make_lay(128,lays[1],stride=2)
    self.lay3 = self._make_lay(256,lays[2],stride=2)
    self.lay4 = self._make_lay(512,lays[3],stride=2)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512, num_classes)

  def _make_lay(self,planes:int,num_blocks:int,stride:int) -> nn.Sequential :
    downsample = None
    if stride != 1 or self.inplanes != planes:
      downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False), nn.BatchNorm2d(planes))
    lays = []
    lays.append(BLOCK(self.inplanes, planes, stride, downsample))
    self.inplanes = planes
    for i in range(num_blocks) :
      lays.append(BLOCK(self.inplanes,planes,stride=1))
    return nn.Sequential(*lays)

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    # x = self.maxpool(x)

    x = self.lay1(x)
    x = self.lay2(x)
    x = self.lay3(x)
    x = self.lay4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x


In [40]:
def resnet18(num_classes)->ResNet :
  model=ResNet([2, 2, 2, 2], num_classes)
  return model

In [14]:
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def load_CIFAR10_data_batch(batch_id = 1) :
    file_dir = '/content/drive/MyDrive/cifar-10-batches-py/data_batch_' + str(batch_id)
    dict_ = unpickle(file_dir)
    img = dict_[b'data']
    labels = dict_[b'labels']
    return np.array(img),np.array(labels)

class CIFAR10_Dataset(Dataset) :
    def __init__(self,mode = "train"):
        if mode == "test" :
            data ,target = load_CIFAR10_data_batch(6)
            self.data = data
            self.target = target
            self.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.49139968 ,0.48215841, 0.44653091), std=(0.20220212, 0.19931542, 0.20086346))])
        elif mode == "train" :
            self.data ,self.target = load_CIFAR10_data_batch(1)
            for i in range(2,6) :
                temp_data ,temp_target = load_CIFAR10_data_batch(i)
                self.data = np.concatenate([self.data,temp_data])
                self.target = np.concatenate([self.target,temp_target])
            self.transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize(mean=(0.49139968 ,0.48215841, 0.44653091), std=(0.20220212, 0.19931542, 0.20086346))])
        self.data = self.data.reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))

    def __len__(self):
      return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        label = self.target[idx]
        img = Image.fromarray(np.uint8(img))
        img = self.transform(img)
        return img,label


In [15]:
train_data = CIFAR10_Dataset("train")
test_data = CIFAR10_Dataset("test")

In [41]:
resnet18_model = resnet18(num_classes=10) #返回一個類實例，然而類的實例會接受num_classes=10的參數

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
lr = 0.001
batch_size = 128
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)
model = resnet18_model
model.to(device)
optimizer = opt.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=0.8)  # 每1个epoch学习率乘以0.1

cuda


In [43]:
criterion = nn.CrossEntropyLoss()  # 定义损失函数
criterion = criterion.to(device)
num_epochs = 25  # 设置训练轮数
for epoch in range(num_epochs):
  model.train()  # 设置模型为训练模式
  running_loss = 0.0
  for i,(inputs,labels) in enumerate(train_loader):
      inputs, labels = inputs.to(device), labels.to(device)  # 将输入和标签移至设备
      optimizer.zero_grad()  # 清零梯度
      outputs = model(inputs)  # 前向传播
      loss = criterion(outputs, labels)  # 计算损失
      loss.backward()  # 反向传播:
      optimizer.step()  # 更新模型参数
      running_loss += loss.item()  # 累加损失
  scheduler.step()
  model.eval()
  acc = 0
  for j,(inputs,labels) in enumerate(test_loader):
    inputs, labels = inputs.to(device), labels.to(device)  # 将输入和标签移至设备
    outputs = model(inputs)
    for k in range(len(outputs)) :
      s = torch.argmax(outputs[k])
      if s == labels[k] :
        acc +=1
  print(acc/len(test_data.data),running_loss)
  print(epoch)

0.5441 613.5042065382004
0
0.6543 421.63701379299164
1
0.6806 332.44990944862366
2
0.7378 273.3976976275444
3
0.7727 224.02515038847923
4
0.8292 194.26150715351105
5
0.8382 167.99492514133453
6
0.8483 147.98041181266308
7
0.8563 132.5680220425129
8
0.8719 117.62333180010319
9
0.8735 106.45362375676632
10
0.8788 95.961235396564
11
0.8866 87.66079300642014
12
0.8848 79.27086404711008
13
0.8896 74.10963126271963
14
0.889 68.21156156808138
15
0.8922 64.41388633474708
16
0.892 61.40816478058696
17
0.8942 57.92977983132005
18
0.8946 55.49798947945237
19
0.8961 54.162969287484884
20
0.8951 53.476618794724345
21
0.8972 52.32555706053972
22
0.8956 51.64264262840152
23
0.8949 50.43897262215614
24
