# <div align='center'> Test MNIST </div>

In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms, models
from torchvision import datasets
from torch import optim
from torch.utils.data import (Dataset, DataLoader)
from k12libs.utils.nb_easy import K12AI_PRETRAINED_ROOT, K12AI_DATASETS_ROOT

In [None]:
def k12ai_init_project(task, dataset, use_cuda=True):
    """Resolve the dataset directory and the torch device for a K12AI project.

    Args:
        task: task category; a sub-directory of K12AI_DATASETS_ROOT (e.g. 'cv').
        dataset: dataset name; a sub-directory below ``task`` (e.g. 'mnist').
        use_cuda: request GPU training. Falls back to the CPU when CUDA is not
            available, so the notebook still runs on CPU-only machines instead of
            failing on the first ``.to(device)`` call.

    Returns:
        ``(data_root, device)``: the path ``K12AI_DATASETS_ROOT/task/dataset`` and
        the ``torch.device`` to train on.
    """
    data_root = os.path.join(K12AI_DATASETS_ROOT, task, dataset)
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
    return data_root, device

def k12ai_load_dataset(data_root, json_file, resize=None, transform=None):
    """Build a map-style Dataset from a JSON index file.

    The index file ``data_root/json_file`` must contain a JSON list of objects,
    each with an ``'image_path'`` (relative to ``data_root``) and a ``'label'``.

    Args:
        data_root: directory holding the index file and the images.
        json_file: file name of the JSON index, relative to ``data_root``.
        resize: optional size passed to ``PIL.Image.resize`` before the transform.
        transform: callable applied to the RGB PIL image; defaults to
            ``transforms.ToTensor()`` when None.

    Returns:
        A ``torch.utils.data.Dataset`` whose items are ``(image, label)`` pairs.
    """
    class JsonfileDataset(Dataset):
        def __init__(self, data_root, json_file, resize=None, transform=None):
            self.data_root = data_root
            self.json_file = json_file
            self.resize = resize
            self.image_list, self.label_list = self.__read_jsonfile(json_file)
            # Compare against None rather than relying on truthiness: a container-like
            # transform (e.g. an empty nn.Sequential) is falsy but is still an
            # explicit choice by the caller.
            if transform is not None:
                self.transform = transform
            else:
                self.transform = transforms.Compose([transforms.ToTensor()])

        def __getitem__(self, index):
            # The context manager closes the file handle; convert() returns an
            # independent in-memory image.
            with Image.open(self.image_list[index]) as raw:
                img = raw.convert('RGB')
            if self.resize:
                img = img.resize(self.resize)
            if self.transform is not None:
                img = self.transform(img)
            return img, self.label_list[index]

        def __len__(self):
            return len(self.image_list)

        def __read_jsonfile(self, jsonfile):
            """Return (absolute image paths, labels) read from ``data_root/jsonfile``."""
            image_list = []
            label_list = []
            with open(os.path.join(self.data_root, jsonfile), encoding='utf-8') as f:
                items = json.load(f)
            for item in items:
                image_list.append(os.path.join(self.data_root, item['image_path']))
                label_list.append(item['label'])
            return image_list, label_list

    return JsonfileDataset(data_root, json_file, resize, transform)
    
def k12ai_load_model(name, pretrained=True, device='cuda'):
    """Build a torchvision classification model, optionally with cached weights.

    Args:
        name: one of 'vgg16', 'resnet50', 'resnet152', 'alexnet'.
        pretrained: when True, load the state dict from
            ``K12AI_PRETRAINED_ROOT/cv/<checkpoint file>`` instead of downloading.
        device: device the model is moved to before returning.

    Returns:
        The model on ``device``, or None when ``name`` is not supported.
    """
    # Architecture name -> local checkpoint file under K12AI_PRETRAINED_ROOT/cv.
    pretrained_files = {
        'vgg16': 'vgg16-397923af.pth',
        'resnet50': 'resnet50-19c8e357.pth',
        'resnet152': 'resnet152-b121ed2d.pth',
        'alexnet': 'alexnet-owt-4df8aa71.pth',
    }
    pretrained_file = pretrained_files.get(name)
    if pretrained_file is None:
        return None

    # Build the bare architecture; weights (if requested) come from the local cache.
    model = getattr(models, name)(pretrained=False)
    if pretrained:
        # Load onto the CPU first so a checkpoint saved from a GPU also loads on
        # CPU-only hosts; the model is moved to `device` below.
        checkpoint = os.path.join(K12AI_PRETRAINED_ROOT, 'cv', pretrained_file)
        state = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state)
    return model.to(device)

## 任务设定

In [None]:
### Task setup: the 'mnist' dataset of the 'cv' task, with GPU training requested
data_root, device = k12ai_init_project('cv', 'mnist', use_cuda=True)

## 数据准备

In [None]:
### Load the dataset
# Training preprocessing: light augmentation, PIL Image -> tensor, then normalization
# with the MNIST mean/std. No horizontal flip: a mirrored digit (2, 3, 4, 7, ...) is no
# longer a valid example of its label, so flipping corrupts the training signal.
transform = transforms.Compose([
    transforms.RandomRotation(10),              # augmentation: random rotation in [-10, 10] degrees
    transforms.ToTensor(),                      # PIL Image -> Tensor
    transforms.Normalize((0.1307,), (0.3081,)), # normalize with the MNIST mean/std
])
# Validation uses the same preprocessing without random augmentation, so the reported
# accuracy is measured on unmodified images.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = k12ai_load_dataset(data_root, 'train.json', transform=transform)
valid_dataset = k12ai_load_dataset(data_root, 'train.json', transform=eval_transform)

### Split into 80% train / 20% validation
# Shuffle the indices with a fixed seed before splitting, so the split is reproducible
# and does not depend on the ordering of the entries in train.json.
split_8_2 = int(0.8 * len(train_dataset))
data_list = np.random.RandomState(0).permutation(len(train_dataset)).tolist()
train_idx, valid_idx = data_list[:split_8_2], data_list[split_8_2:]

### Samplers that draw each subset in random order every epoch
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

### Mini-batch loaders (batch_size=64); the indices are disjoint across the two views
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)
valid_loader = DataLoader(valid_dataset, batch_size=64, sampler=valid_sampler)

## 算法模型的选择与设计

In [None]:
### Option 1: the built-in resnet50 model
pretrained = True  # start from the locally cached pretrained weights
resnet_model = k12ai_load_model('resnet50', pretrained, device=device)
# The pretrained ImageNet classifier head has 1000 outputs, while this task has 10 digit
# classes. Swap in a freshly initialized 10-way head on the same device before using it.
resnet_model.fc = nn.Linear(resnet_model.fc.in_features, 10).to(device)

### 方式 2: 用户自定义模型
class CustomNet(nn.Module):
    """Small two-convolution CNN for 10-class classification of 28x28 images.

    Takes input of shape (N, 3, 28, 28). There are 3 channels because the dataset
    loader converts images to RGB. Returns per-class log-probabilities of shape (N, 10).
    """

    def __init__(self):
        super(CustomNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)   # 3x3 conv, stride 1: 28x28 -> 26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 26x26 -> 24x24
        # Channel-wise dropout on the 4-D convolutional feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: it is applied to the flattened 2-D (N, 128) features,
        # for which Dropout2d is deprecated.
        self.dropout2 = nn.Dropout(0.5)
        # 64 channels * 12 * 12 (24x24 after the 2x2 max-pool) = 9216 features,
        # which ties this layer to 28x28 inputs.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)         # one output per class

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)               # (N, 64, 12, 12) -> (N, 9216)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)        # log-probabilities over the 10 classes
    
# Build the user-defined network and move its parameters to the training device
# (nn.Module.to moves the module in place and returns the same object).
custom_model = CustomNet()
custom_model.to(device)

## 超参数调整

In [None]:
### Number of training epochs
max_epoch = 10

### Loss: cross-entropy averaged over the batch ('mean' reduction)
# CustomNet already outputs log-probabilities. CrossEntropyLoss applies log_softmax
# again, which leaves log-probabilities unchanged, so here it equals NLLLoss on those
# outputs.
reduction = 'mean'
loss_func = nn.CrossEntropyLoss(reduction=reduction)

### Optimizer: SGD with Nesterov momentum over the custom model's parameters
optimizer = SGD(
    params=custom_model.parameters(),
    lr=0.01,            # base learning rate
    weight_decay=1e-6,  # weight decay (L2 penalty) to keep parameters small
    momentum=0.9,       # momentum factor
    nesterov=True,      # use the Nesterov momentum variant
)

### Optional LR schedule: multiply the learning rate by 0.6 every 2 epochs (StepLR)
scheduler = StepLR(optimizer=optimizer, step_size=2, gamma=0.6)

## 模型训练及反馈

In [None]:
def train_epoch(model, device, data_loader, loss_func, optimizer, epoch):
    """Run one training pass over ``data_loader``.

    Args:
        model: network to train; switched to training mode (Dropout/BatchNorm active).
        device: device each batch is moved to; should match the model's device.
        data_loader: iterable of ``(data, target)`` batches.
        loss_func: criterion called as ``loss_func(output, target)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        epoch: current epoch index, used only in the progress message.

    Returns:
        The mean of the per-batch loss values, or ``nan`` if the loader yields
        no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data, target in data_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    # One summary line per epoch, replacing the per-batch debug print of tensor shapes.
    print('Epoch:', epoch, ', Training Loss:', mean_loss)
    return mean_loss
        
def valid_epoch(model, device, data_loader, epoch):
    """Measure the classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: network to evaluate; switched to eval mode (Dropout/BatchNorm disabled).
        device: device each batch is moved to; should match the model's device.
        data_loader: iterable of ``(data, target)`` batches.
        epoch: current epoch index; accuracy is printed when it is a multiple of 5.

    Returns:
        Accuracy in percent over the samples actually yielded by ``data_loader``
        (0.0 if it yields none). The prediction is the arg-max over dimension 1.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    # Divide by the samples actually evaluated, not len(data_loader.dataset). With a
    # subset sampler (the 20% validation split) the underlying dataset is much larger
    # than what was iterated.
    acc = 100.0 * correct / total if total else 0.0
    # Report on every 5th epoch (0, 5, 10, ...).
    if epoch % 5 == 0:
        print("ACC:", acc)
    return acc
    
def train(epoch_num, model, train_loader, valid_loader, loss_func, optimizer, scheduler,
          save_path="last.pt"):
    """Train ``model`` for ``epoch_num`` epochs, validating after each one.

    Each epoch runs ``train_epoch``, then ``valid_epoch``, then ``scheduler.step()``.
    The final state dict is saved with ``torch.save``.

    Args:
        epoch_num: number of epochs.
        model: network to train; batches are sent to the device of its parameters.
        train_loader: loader for the training split.
        valid_loader: loader for the validation split.
        loss_func: training criterion.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        save_path: where the final ``state_dict`` is written (default "last.pt").

    Returns:
        List with the validation accuracy returned by ``valid_epoch`` for each epoch.
    """
    # Use the same device the model's parameters live on (cpu or gpu).
    device = next(model.parameters()).device
    history = []
    for epoch in range(epoch_num):
        train_epoch(model, device, train_loader, loss_func, optimizer, epoch)
        history.append(valid_epoch(model, device, valid_loader, epoch))
        scheduler.step()

    torch.save(model.state_dict(), save_path)
    return history
        
### 启动训练
# Launch training of the custom model; the final weights are written to last.pt.
train(
    epoch_num=max_epoch,
    model=custom_model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    loss_func=loss_func,
    optimizer=optimizer,
    scheduler=scheduler,
)

## 模型评估及测试

In [None]:
### Load the test dataset
# Apply the same ToTensor + Normalize preprocessing the model was trained with (but no
# random augmentation). Without Normalize, the test inputs would be on a different scale
# than the training inputs.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_dataset = k12ai_load_dataset(data_root, 'test.json', transform=test_transform)
# NOTE(review): num_workers > 0 pickles the dataset for worker processes. The dataset
# class is defined inside k12ai_load_dataset, which may fail under the 'spawn' start
# method (Windows/macOS). Confirm the target platform or use num_workers=0 there.
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

### Load the trained model
# map_location='cpu': last.pt may contain CUDA tensors (the model was trained on
# `device`), while this copy of the model lives on the CPU.
last_model = CustomNet()
last_model.load_state_dict(torch.load('last.pt', map_location='cpu'))

def evaluate(model, data_loader):
    """Return the classification accuracy (in percent) of ``model`` on ``data_loader``.

    Each batch is moved to the device of the model's parameters (CPU for a model
    without parameters), so CPU and GPU models both work. The prediction is the
    arg-max over dimension 1 of the model output. Returns 0.0 for an empty loader.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    # Divide by the samples actually evaluated, so a loader with a subset sampler
    # (or drop_last=True) is scored correctly.
    return 100.0 * correct / total if total else 0.0

### 启动评估
# Evaluate the trained model on the test set and report the accuracy (percent).
acc = evaluate(model=last_model, data_loader=test_loader)
print("Acc:", acc)