# This is for training with Resnet on RGB.

In [2]:
from IPython import get_ipython

from PIL import Image
import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

import time
import random
from math import ceil
from tqdm import tqdm
import os

import torch.utils.data as data
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
import torch.optim as optim
import torch.nn.functional as F


# Set vars used after
class VarsConfig(object):
    train_data = "E3_data/train_img"
    test_data = "E3_data/val_img"
    pre_data = "E3_data/test_img"

    epoch = 30
    batch_size = 64
    img_height = 64
    img_weight = 64
    seed = 666


config = VarsConfig()


# Set methods of transform
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  
    transforms.RandomCrop(64, padding=4),  
    transforms.ToTensor(),  
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 
])

test_transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


# Load datasets
train_data = datasets.ImageFolder(root=config.train_data,
                                  transform=train_transform)
test_data = datasets.ImageFolder(root=config.test_data,
                                 transform=test_transform)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=config.batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=config.batch_size,
                                          shuffle=False)



class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=12):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv0 = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU(),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)   #strides=[1,1]
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out=self.conv0(x)
        out = self.conv1(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def ResNet18():
    return ResNet(ResidualBlock)

# Function for saving model
def save_models(epoch):
    torch.save(model, "Resnet_{}.mdl".format(epoch + 1))
    print("Chekcpoint saved")


model = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Get the existence of GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model)


# train the model
best_acc = 0.0
results = pd.DataFrame(columns=['epoch','train_acc', 'test_acc', 'train_loss', 'test_loss'], index=range(0,config.epoch))
for epoch in range(config.epoch):
    print("epoch:", +epoch)
    model.train()
    for images, labels in tqdm(train_loader):
        images,labels = images.to(device),labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        
        lossvalue = criterion(output,labels)
        lossvalue.backward()
        optimizer.step()

    # evaluate the model
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    correct_num_test = 0
    train_loss = 0.0
    train_acc = 0.0
    correct_num_train = 0
    with torch.no_grad():
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            train_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_train += pred.eq(labels.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    correct_rate = 100. * correct_num_train / len(train_loader.dataset)
    train_acc = correct_num_train / len(train_loader.dataset)
    print('Train Accuracy: {0}/{1}({2:.4}%),Train Loss: {3}'.format(correct_num_train,len(train_loader.dataset),correct_rate, train_loss))

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            test_loss += criterion(output, labels).item()
            pred = output.max(1, keepdim=True)[1]
            correct_num_test += pred.eq(labels.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    correct_rate = 100. * correct_num_test / len(test_loader.dataset)
    test_acc = correct_num_test / len(test_loader.dataset)

    # 将更优的模型保存下来
    if test_acc > best_acc:
        save_models(epoch)
        best_acc = test_acc
        best_epoch = epoch + 1
    results.iloc[epoch] = [epoch + 1, train_acc, test_acc, train_loss, test_loss]
    print('Test  Accuracy: {0}/{1}({2:.4}%), Test Loss: {3}'.format(correct_num_test,len(test_loader.dataset),correct_rate, test_loss)) 


# %%
results.to_csv('resnet.csv', index=False)



  0%|          | 2/1016 [00:00<01:28, 11.49it/s]

ResNet(
  (conv0): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (left): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (left): Sequen

100%|██████████| 1016/1016 [01:25<00:00, 11.93it/s]


Train Accuracy: 23584/65000(36.28%),Train Loss: 0.08164842954782339


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  0%|          | 2/1016 [00:00<01:29, 11.36it/s]

Chekcpoint saved
Test  Accuracy: 4209/12500(33.67%), Test Loss: 0.09897198125898839
epoch: 1


100%|██████████| 1016/1016 [01:23<00:00, 12.10it/s]


Train Accuracy: 49379/65000(75.97%),Train Loss: 0.012875611788951434


  0%|          | 2/1016 [00:00<01:24, 11.93it/s]

Chekcpoint saved
Test  Accuracy: 9046/12500(72.37%), Test Loss: 0.015716091667413713
epoch: 2


100%|██████████| 1016/1016 [01:23<00:00, 12.10it/s]


Train Accuracy: 55437/65000(85.29%),Train Loss: 0.006906778425436753


  0%|          | 2/1016 [00:00<01:25, 11.79it/s]

Chekcpoint saved
Test  Accuracy: 9977/12500(79.82%), Test Loss: 0.00999248888462782
epoch: 3


100%|██████████| 1016/1016 [01:22<00:00, 12.31it/s]


Train Accuracy: 53360/65000(82.09%),Train Loss: 0.008595899703869453


  0%|          | 2/1016 [00:00<01:23, 12.13it/s]

Test  Accuracy: 9637/12500(77.1%), Test Loss: 0.010734991001486779
epoch: 4


100%|██████████| 1016/1016 [01:22<00:00, 12.25it/s]


Train Accuracy: 46253/65000(71.16%),Train Loss: 0.013997065472602844


  0%|          | 2/1016 [00:00<01:25, 11.92it/s]

Test  Accuracy: 8375/12500(67.0%), Test Loss: 0.01679132677078247
epoch: 5


100%|██████████| 1016/1016 [01:23<00:00, 12.15it/s]


Train Accuracy: 50114/65000(77.1%),Train Loss: 0.011586163822962688


  0%|          | 2/1016 [00:00<01:24, 11.99it/s]

Test  Accuracy: 9268/12500(74.14%), Test Loss: 0.014142616802155972
epoch: 6


100%|██████████| 1016/1016 [01:24<00:00, 12.06it/s]


Train Accuracy: 55571/65000(85.49%),Train Loss: 0.008126349413394927


  0%|          | 2/1016 [00:00<01:32, 11.02it/s]

Test  Accuracy: 9723/12500(77.78%), Test Loss: 0.0158391494089365
epoch: 7


100%|██████████| 1016/1016 [01:24<00:00, 12.03it/s]


Train Accuracy: 48841/65000(75.14%),Train Loss: 0.01947493549172695


  0%|          | 2/1016 [00:00<01:21, 12.46it/s]

Test  Accuracy: 8191/12500(65.53%), Test Loss: 0.03244691676080227
epoch: 8


100%|██████████| 1016/1016 [01:23<00:00, 12.15it/s]


Train Accuracy: 58638/65000(90.21%),Train Loss: 0.004358884704800753


  0%|          | 2/1016 [00:00<01:30, 11.22it/s]

Chekcpoint saved
Test  Accuracy: 10520/12500(84.16%), Test Loss: 0.00835614380031824
epoch: 9


100%|██████████| 1016/1016 [01:24<00:00, 11.99it/s]


Train Accuracy: 58847/65000(90.53%),Train Loss: 0.0043220050384218875


  0%|          | 2/1016 [00:00<01:28, 11.46it/s]

Test  Accuracy: 10436/12500(83.49%), Test Loss: 0.008461672978699207
epoch: 10


100%|██████████| 1016/1016 [01:23<00:00, 12.17it/s]


Train Accuracy: 57670/65000(88.72%),Train Loss: 0.005179042702397475


  0%|          | 2/1016 [00:00<01:24, 11.94it/s]

Test  Accuracy: 10118/12500(80.94%), Test Loss: 0.01046081656217575
epoch: 11


100%|██████████| 1016/1016 [01:25<00:00, 11.94it/s]


Train Accuracy: 56866/65000(87.49%),Train Loss: 0.006100033858533089


  0%|          | 2/1016 [00:00<01:23, 12.07it/s]

Test  Accuracy: 10159/12500(81.27%), Test Loss: 0.010890891262441874
epoch: 12


100%|██████████| 1016/1016 [01:23<00:00, 12.11it/s]


Train Accuracy: 57454/65000(88.39%),Train Loss: 0.005149705993097562


  0%|          | 2/1016 [00:00<01:21, 12.47it/s]

Test  Accuracy: 10299/12500(82.39%), Test Loss: 0.009944645783603191
epoch: 13


100%|██████████| 1016/1016 [01:22<00:00, 12.38it/s]


Train Accuracy: 59637/65000(91.75%),Train Loss: 0.0036730270149616094


  0%|          | 2/1016 [00:00<01:25, 11.85it/s]

Chekcpoint saved
Test  Accuracy: 10615/12500(84.92%), Test Loss: 0.008809197162985801
epoch: 14


100%|██████████| 1016/1016 [01:22<00:00, 12.25it/s]


Train Accuracy: 58706/65000(90.32%),Train Loss: 0.004555164368737203


  0%|          | 2/1016 [00:00<01:24, 12.02it/s]

Test  Accuracy: 10445/12500(83.56%), Test Loss: 0.010040301506221294
epoch: 15


100%|██████████| 1016/1016 [01:23<00:00, 12.19it/s]


Train Accuracy: 60338/65000(92.83%),Train Loss: 0.0032158356494055343


  0%|          | 2/1016 [00:00<01:29, 11.31it/s]

Chekcpoint saved
Test  Accuracy: 10777/12500(86.22%), Test Loss: 0.007041596397906542
epoch: 16


100%|██████████| 1016/1016 [01:22<00:00, 12.27it/s]


Train Accuracy: 59083/65000(90.9%),Train Loss: 0.004107772154246386


  0%|          | 2/1016 [00:00<01:24, 12.02it/s]

Test  Accuracy: 10412/12500(83.3%), Test Loss: 0.009232090216577053
epoch: 17


100%|██████████| 1016/1016 [01:23<00:00, 12.21it/s]


Train Accuracy: 42210/65000(64.94%),Train Loss: 0.0376920392036438


  0%|          | 2/1016 [00:00<01:27, 11.55it/s]

Test  Accuracy: 7519/12500(60.15%), Test Loss: 0.04841946382641792
epoch: 18


100%|██████████| 1016/1016 [01:23<00:00, 12.18it/s]


Train Accuracy: 61482/65000(94.59%),Train Loss: 0.002341805311808219


  0%|          | 2/1016 [00:00<01:23, 12.17it/s]

Chekcpoint saved
Test  Accuracy: 10921/12500(87.37%), Test Loss: 0.007123818454146385
epoch: 19


100%|██████████| 1016/1016 [01:23<00:00, 12.22it/s]


Train Accuracy: 58073/65000(89.34%),Train Loss: 0.005497305249680694


  0%|          | 2/1016 [00:00<01:28, 11.41it/s]

Test  Accuracy: 9864/12500(78.91%), Test Loss: 0.013565058484375477
epoch: 20


100%|██████████| 1016/1016 [01:22<00:00, 12.30it/s]


Train Accuracy: 59894/65000(92.14%),Train Loss: 0.003663161896197842


  0%|          | 2/1016 [00:00<01:24, 12.07it/s]

Test  Accuracy: 10607/12500(84.86%), Test Loss: 0.008575274046212434
epoch: 21


100%|██████████| 1016/1016 [01:23<00:00, 12.18it/s]


Train Accuracy: 56936/65000(87.59%),Train Loss: 0.00929446441072684


  0%|          | 2/1016 [00:00<01:21, 12.37it/s]

Test  Accuracy: 9950/12500(79.6%), Test Loss: 0.01798140227779746
epoch: 22


100%|██████████| 1016/1016 [01:23<00:00, 12.22it/s]


Train Accuracy: 60541/65000(93.14%),Train Loss: 0.0031343333741793264


  0%|          | 2/1016 [00:00<01:21, 12.51it/s]

Test  Accuracy: 10793/12500(86.34%), Test Loss: 0.007282103643119335
epoch: 23


100%|██████████| 1016/1016 [01:23<00:00, 12.10it/s]


Train Accuracy: 61969/65000(95.34%),Train Loss: 0.0020422816147884497


  0%|          | 2/1016 [00:00<01:25, 11.83it/s]

Chekcpoint saved
Test  Accuracy: 10979/12500(87.83%), Test Loss: 0.0064951540316641335
epoch: 24


100%|██████████| 1016/1016 [01:23<00:00, 12.20it/s]


Train Accuracy: 62064/65000(95.48%),Train Loss: 0.002032482512982992


  0%|          | 2/1016 [00:00<01:30, 11.24it/s]

Test  Accuracy: 10974/12500(87.79%), Test Loss: 0.006662458070069551
epoch: 25


100%|██████████| 1016/1016 [01:23<00:00, 12.23it/s]


Train Accuracy: 59274/65000(91.19%),Train Loss: 0.004529877617267462


  0%|          | 2/1016 [00:00<01:25, 11.88it/s]

Test  Accuracy: 10388/12500(83.1%), Test Loss: 0.011886123383939266
epoch: 26


100%|██████████| 1016/1016 [01:23<00:00, 12.19it/s]


Train Accuracy: 61749/65000(95.0%),Train Loss: 0.002163677581055806


  0%|          | 2/1016 [00:00<01:25, 11.85it/s]

Test  Accuracy: 10617/12500(84.94%), Test Loss: 0.009359612171649934
epoch: 27


100%|██████████| 1016/1016 [01:22<00:00, 12.26it/s]


Train Accuracy: 59315/65000(91.25%),Train Loss: 0.004897075955340496


  0%|          | 2/1016 [00:00<01:24, 12.00it/s]

Test  Accuracy: 10286/12500(82.29%), Test Loss: 0.013570847352445126
epoch: 28


100%|██████████| 1016/1016 [01:24<00:00, 12.05it/s]


Train Accuracy: 60021/65000(92.34%),Train Loss: 0.0035514374650441683


  0%|          | 2/1016 [00:00<01:24, 12.07it/s]

Test  Accuracy: 10396/12500(83.17%), Test Loss: 0.011557167078256606
epoch: 29


100%|██████████| 1016/1016 [01:23<00:00, 12.13it/s]


Train Accuracy: 63117/65000(97.1%),Train Loss: 0.001315866860976586
Test  Accuracy: 10930/12500(87.44%), Test Loss: 0.00777453663289547
