# 以下是神经网络方法。

In [None]:
import os
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.optim import SGD, Adagrad, Adam
from torchvision import datasets, models, transforms
from torchvision.models import (ResNet18_Weights, ResNet50_Weights,
                                VGG11_Weights, resnet18, resnet50, vgg11)
from tqdm import tqdm

from dataset import MyDataset
from my_mlp import MyMLP
from utils import label_int2str, label_str2int, save_config


device = "cuda"
# Fail fast when no GPU is available. An explicit raise (instead of `assert`)
# still runs when Python is started with -O, which strips assert statements.
if not torch.cuda.is_available():
    raise RuntimeError("你TM连CUDA都没你跑个JB")

## 运行设置：

In [None]:
# Run name: a timestamp such as "2024-01-31 13.05.09". It is also used as the
# directory name for checkpoints under ./outputs/runs/.
run_name = f"{datetime.now():%Y-%m-%d %H.%M.%S}"

# Network architecture: one of "mlp", "resnet18", "resnet50", "vgg11".
model_name = "mlp"
# Whether to load ImageNet-pretrained torchvision weights (not read for "mlp").
model_is_pretrained = False
# Number of output classes of the classifier head.
num_classes = 6

# Optimizer: one of "SGD", "Adagrad", "Adam".
optimizer_name = "Adam"
lr = 0.001

# Data augmentation: one of "none", "flip", "crop", "norm", "all".
augmentation = "none"

# Regularization: in this notebook it is only passed to save_config and is not
# applied during training.
regularization = "none"

## 模型网络架构：

(Reference: pytorch.org. Acc@1 / Acc@5 are the ImageNet-1K top-1 / top-5 accuracies of the listed pretrained weights.)

| Weight |Acc@1 |Acc@5 |Params |GFLOPS|
| ---- | ---- | ---- | ---- | ---- |
| ResNet18_Weights.IMAGENET1K_V1 | 69.758 | 89.078 | 11.7M | 1.81 |
| ResNet50_Weights.IMAGENET1K_V2 | 80.858 | 95.434 | 25.6M | 4.09 |
| VGG11_Weights.IMAGENET1K_V1 | 69.02 | 88.628 | 132.9M | 7.61 |

In [None]:
# Build the configured architecture and replace its final layer so that it
# outputs `num_classes` logits. The MLP takes flattened 150x150x3 inputs; the
# torchvision CNNs are fed 224x224 images.
if model_name == "mlp":
    model = MyMLP(150 * 150 * 3, num_classes)
    input_size = (150, 150)
elif model_name == "resnet18":
    model = resnet18(weights=ResNet18_Weights.DEFAULT if model_is_pretrained else None)
    input_size = (224, 224)
    # Read the head's input width instead of hard-coding it (512 for ResNet-18).
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
elif model_name == "resnet50":
    model = resnet50(weights=ResNet50_Weights.DEFAULT if model_is_pretrained else None)
    input_size = (224, 224)
    # ResNet-50's head takes 2048 features, not 512; a hard-coded Linear(512, ...)
    # fails with a shape mismatch on the first forward pass.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
elif model_name == "vgg11":
    model = vgg11(weights=VGG11_Weights.DEFAULT if model_is_pretrained else None)
    input_size = (224, 224)
    # classifier[6] is VGG's last Linear layer (4096 -> 1000 for ImageNet).
    model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)
else:
    raise ValueError(f"Unknown model_name: {model_name!r}")

model = model.to(device)
print(model)

## 优化器：

In [None]:
# Instantiate the configured optimizer over all model parameters.
if optimizer_name == "SGD":
    optimizer = SGD(model.parameters(), lr=lr)
elif optimizer_name == "Adagrad":
    optimizer = Adagrad(model.parameters(), lr=lr)
elif optimizer_name == "Adam":
    optimizer = Adam(model.parameters(), lr=lr)
else:
    # Without this, a typo leaves `optimizer` undefined and fails much later.
    raise ValueError(f"Unknown optimizer_name: {optimizer_name!r}")

## 数据增强与预处理：

In [None]:
# Training transform. PIL-level ops (resize, flip, crop) come first. ToTensor
# converts to a float tensor in [0, 1]. Normalize only accepts tensors, so it
# must come after ToTensor.
# NOTE(review): mean=0 / std=1 makes Normalize an identity; substitute dataset
# (or ImageNet) statistics if real normalization is intended.
if augmentation == "none":
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
    ])
elif augmentation == "flip":
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
elif augmentation == "crop":
    transform = transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.ToTensor(),
    ])
elif augmentation == "norm":
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
    ])
elif augmentation == "all":
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
    ])
else:
    raise ValueError(f"Unknown augmentation: {augmentation!r}")

# Test-time transform: no random augmentation, but images must still be resized
# to the model's input size and converted to tensors. The default DataLoader
# collate cannot batch PIL images. Apply the same normalization as training.
test_steps = [transforms.Resize(input_size), transforms.ToTensor()]
if augmentation in ("norm", "all"):
    test_steps.append(transforms.Normalize(mean=(0, 0, 0), std=(1, 1, 1)))
test_transform = transforms.Compose(test_steps)

dataset = datasets.ImageFolder("./dataset/seg_train/", transform)
testset = datasets.ImageFolder("./dataset/seg_test/", test_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
# Evaluation order does not affect accuracy, so the test set is not shuffled.
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)

## 训练过程：

In [None]:
def train_one_epoch(model: torch.nn.Module, dataloader: DataLoader, transform: transforms.Compose, 
                    optimizer: torch.optim.Optimizer, current_epoch: int):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network to train. It is switched to training mode, and batches are
            moved to the module-level ``device``.
        dataloader: Yields ``(inputs, labels)`` batches.
        transform: Unused. Kept for call compatibility; the dataloader's dataset
            applies its own transform.
        optimizer: Optimizer that updates ``model``'s parameters.
        current_epoch: Epoch number, used for logging and checkpoint naming.

    Returns:
        ``(epoch_accuracy, epoch_loss)`` as Python floats: the fraction of
        correctly classified training samples, and the mean per-sample
        cross-entropy loss.

    On every epoch divisible by 10, the whole model is saved with ``torch.save``
    to ``./outputs/runs/<run_name>/model_epoch_<epoch>``.
    """
    # test() switches the model to eval mode; make sure dropout/batch-norm
    # behave in training mode here.
    model.train()
    correct_predictions = 0
    epoch_instance_count = 0
    epoch_loss = 0.0
    for inputs, labels in tqdm(dataloader, desc="Training epoch %d" % current_epoch):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # cross_entropy returns the batch mean; weighting by batch size makes the
        # epoch loss a true per-sample mean even when the last batch is short.
        epoch_loss += loss.item() * batch_size
        predictions = outputs.argmax(dim=1)
        # .item() keeps the counter a Python int rather than a CUDA tensor, so the
        # returned accuracy can be plotted directly.
        correct_predictions += (predictions == labels).sum().item()
        epoch_instance_count += batch_size
    epoch_accuracy = correct_predictions / epoch_instance_count
    epoch_loss /= epoch_instance_count
    print("Epoch %d (%d instances), loss = %f, accuracy = %f" % (current_epoch, epoch_instance_count, epoch_loss, epoch_accuracy))
    if current_epoch % 10 == 0:
        checkpoint_dir = os.path.join("./outputs/runs", run_name)
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model, os.path.join(checkpoint_dir, "model_epoch_%d" % current_epoch))
    return epoch_accuracy, epoch_loss

def test(model: torch.nn.Module, dataloader: DataLoader):
    """Evaluate the classification accuracy of ``model`` on ``dataloader``.

    Runs in eval mode without gradient tracking, moving batches to the
    module-level ``device``. Prints the accuracy and returns it as a Python
    float. The model's previous train/eval mode is restored afterwards, so this
    is safe to call between training epochs.
    """
    was_training = model.training
    correct_predictions = 0
    epoch_instance_count = 0
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in tqdm(dataloader, desc="Testing"):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                predictions = outputs.argmax(dim=1)
                # .item() keeps the count a Python int instead of a CUDA tensor.
                correct_predictions += (predictions == labels).sum().item()
                epoch_instance_count += inputs.size(0)
    finally:
        model.train(was_training)
    epoch_accuracy = correct_predictions / epoch_instance_count
    print("Testing (%d instances), accuracy = %f" % (epoch_instance_count, epoch_accuracy))
    return epoch_accuracy

epoch_losses = []
epoch_accuracies = []
# Epochs are numbered 1..num_epochs (inclusive). range(1, 30) ran only 29 epochs,
# so the final model was never reached by the every-10-epochs checkpoint.
num_epochs = 30
model.train()
for epoch in range(1, num_epochs + 1):
    # train_one_epoch returns (accuracy, loss) in that order.
    epoch_accuracy, epoch_loss = train_one_epoch(model, dataloader, transform, optimizer, epoch)
    # float() turns a 0-dim (possibly CUDA) tensor into a plain number matplotlib can plot.
    epoch_losses.append(float(epoch_loss))
    epoch_accuracies.append(float(epoch_accuracy))
test(model, testloader)
save_config(run_name, model_name, model_is_pretrained, num_classes, optimizer_name, lr, augmentation, regularization)

## 画图

In [None]:
import matplotlib.pyplot as plt

# Training epochs are numbered from 1, so plot against 1-based epoch numbers
# instead of matplotlib's default 0-based list indices.
plt.plot(range(1, len(epoch_losses) + 1), epoch_losses)
plt.title("Epoch Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
plt.plot(range(1, len(epoch_accuracies) + 1), epoch_accuracies)
plt.title("Epoch Accuracies")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()