## Global Settings


In [1]:
import torch

# Set to True to also display each figure in the notebook cell output;
# figures are saved to disk either way. False only saves them.
SHOW_PLOT_IN_BLOCK = True

GLOBAL_SEED = 42
# Root directory of the CIFAR-10 dataset (downloaded there if missing).
CIFAR_PATH = '../data/'
# Backward-compatible alias for the original (misspelled) name, which other
# cells still reference.
CIFAR_PTAH = CIFAR_PATH

LEARNING_RATE = 5e-2
BATCH_SIZE = 32
MAX_EPOCHS = 200
# Passed to cnn_train.train as eval_freq; presumably "evaluate every N epochs"
# -- TODO confirm against cnn_train.
EVAL_FREQ = 1
OPTIMIZER = 'ADAM'
# Passed to cnn_train.train as model_para_path. '' presumably means "start from
# a fresh model"; set it to a checkpoint such as '../model/part2_cifar.pth' to
# load saved parameters instead (confirm semantics in cnn_train).
MODEL_PARA_PATH = ''

# Seed the CPU and all CUDA RNGs to make runs (more) reproducible.
torch.manual_seed(GLOBAL_SEED)
torch.cuda.manual_seed_all(GLOBAL_SEED)

## Task 3

### Define Utility Functions

In [None]:
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader


def save_fig_acc(eval_epochs, train_acc, test_acc, train_loss, test_loss, file_name, show=SHOW_PLOT_IN_BLOCK,
                 save_dir='../Report/img/Part2/'):
    """Plot train/test accuracy and loss curves side by side and save the figure.

    Args:
        eval_epochs: x-axis values, i.e. the epochs at which metrics were recorded.
        train_acc: training accuracy, one value per entry of ``eval_epochs``.
        test_acc: test accuracy, one value per entry of ``eval_epochs``.
        train_loss: training loss, one value per entry of ``eval_epochs``.
        test_loss: test loss, one value per entry of ``eval_epochs``.
        file_name: stem used in the panel titles and the output file name; the
            figure is written to ``<save_dir>/<file_name>_latest.png``.
        show: if True, also display the figure (e.g. inline in the notebook).
        save_dir: output directory; created if it does not exist.
    """
    import os

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: accuracy curves.
    ax_acc.plot(eval_epochs, train_acc, label='CIFAR10 Train Acc')
    ax_acc.plot(eval_epochs, test_acc, label='CIFAR10 Test Acc')
    ax_acc.set_title(f'{file_name} Accuracy')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.legend()
    ax_acc.grid(True)

    # Right panel: loss curves (titled and labelled as loss, not accuracy).
    ax_loss.plot(eval_epochs, train_loss, label='CIFAR10 Train Loss')
    ax_loss.plot(eval_epochs, test_loss, label='CIFAR10 Test Loss')
    ax_loss.set_title(f'{file_name} Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()
    ax_loss.grid(True)

    fig.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, file_name + '_latest.png'))

    if show:
        plt.show()
    # Release the figure so repeated calls in a long notebook session don't accumulate figures.
    plt.close(fig)


def save_fig_class(train_acc, test_acc, file_name, show=SHOW_PLOT_IN_BLOCK,
                   save_dir='../Report/img/Part2/'):
    """Plot per-class train/test accuracy as two bar charts and save the figure.

    Args:
        train_acc: one training-accuracy value per CIFAR-10 class, in the order
            of ``classes_name`` below.
        test_acc: one test-accuracy value per class, same order.
        file_name: output stem; the figure is written to
            ``<save_dir>/<file_name>_latest.png``.
        show: if True, also display the figure (e.g. inline in the notebook).
        save_dir: output directory; created if it does not exist.
    """
    import os

    classes_name = ('plane', 'car', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck')

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (axes[0], train_acc, 'Train Accuracy of 10 Classes'),
        (axes[1], test_acc, 'Test Accuracy of 10 Classes'),
    )
    for ax, accs, title in panels:
        bars = ax.bar(classes_name, accs, color='skyblue')
        # Print each bar's value just above it.
        for bar, acc in zip(bars, accs):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f'{acc:.2f}',
                    ha='center', va='bottom')
        # Text is ignored by autoscaling; add headroom so labels on the tallest bars are not clipped.
        ax.margins(y=0.15)
        ax.set_title(title)
        ax.set_xlabel('Classes')
        ax.set_ylabel('Accuracy')
        ax.tick_params(axis='x', labelrotation=45)

    fig.tight_layout()

    # savefig does not create missing directories, so make sure the target exists.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, file_name + '_latest.png'))

    if show:
        plt.show()
    plt.close(fig)


def get_data(root=None, batch_size=None):
    """Build CIFAR-10 train/test DataLoaders normalized with training-set statistics.

    Per-channel mean and std are computed from the raw training images (scaled
    by 1/255) and applied to both splits. No test-set statistics enter the
    preprocessing.

    Args:
        root: dataset directory; defaults to the global ``CIFAR_PTAH``, read at
            call time. The dataset is downloaded there if missing.
        batch_size: batch size for both loaders; defaults to the global
            ``BATCH_SIZE``, read at call time.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and the
        test loader does not.
    """
    # Resolve defaults at call time so that edits to the global settings cell
    # take effect without re-running this cell.
    if root is None:
        root = CIFAR_PTAH
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Untransformed copy used only to compute normalization statistics. The raw
    # images are presumably (N, H, W, C) uint8; reducing over axes (0, 1, 2)
    # leaves one value per channel.
    stats_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
    train_mean = stats_set.data.mean(axis=(0, 1, 2)) / 255
    train_std = stats_set.data.std(axis=(0, 1, 2)) / 255

    transform = transforms.Compose([
        transforms.ToTensor(),  # scales pixels to [0, 1], matching the /255 above
        transforms.Normalize(train_mean, train_std),
    ])

    train_data = torchvision.datasets.CIFAR10(root=root, train=True,
                                              download=True, transform=transform)
    test_data = torchvision.datasets.CIFAR10(root=root, train=False,
                                             download=True, transform=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=False)
    return train_loader, test_loader

### Start Train


In [None]:
import cnn_train

# Build the data loaders, then train with the global settings.
# model_para_path is presumably an optional checkpoint to load ('' = fresh model);
# confirm in cnn_train.
train_loader, test_loader = get_data()

(epochs, train_acc, test_acc, train_loss, test_loss,
 train_class_acc, test_class_acc) = cnn_train.train(
    train_loader=train_loader,
    test_loader=test_loader,
    model_para_path=MODEL_PARA_PATH,
    eval_freq=EVAL_FREQ,
    learning_rate=LEARNING_RATE,
    optimizer_type=OPTIMIZER,
    max_epoch=MAX_EPOCHS,
)

# Saved as curve_cnn_latest.png and acc_cnn_class_latest.png in the report image folder.
save_fig_acc(epochs, train_acc, test_acc, train_loss, test_loss, file_name='curve_cnn')
save_fig_class(train_class_acc, test_class_acc, file_name='acc_cnn_class')


## Saved Result Images

### Default Parameters

Accuracy and Loss Curve

![curve_cnn_default](../Report/img/Part2/curve_cnn_default.png)

Class Accuracy

![acc_cnn_class_default](../Report/img/Part2/acc_cnn_class_default.png)

### SGD Optimizer

Accuracy and Loss Curve

![curve_cnn_sgd](../Report/img/Part2/optimizer/curve_cnn_sgd.png)

Class Accuracy

![acc_cnn_class_sgd](../Report/img/Part2/optimizer/acc_cnn_class_sgd.png)

### RMSprop Optimizer

Accuracy and Loss Curve

![curve_cnn_RMS](../Report/img/Part2/optimizer/curve_cnn_RMS.png)

Class Accuracy

![acc_cnn_class_RMS](../Report/img/Part2/optimizer/acc_cnn_class_RMS.png)

### Learning Rate 1e-3

Accuracy and Loss Curve

![curve_cnn_1e-3](../Report/img/Part2/lr/curve_cnn_1e-3.png)

Class Accuracy

![acc_cnn_class_1e-3](../Report/img/Part2/lr/acc_cnn_class_1e-3.png)

### Learning Rate 5e-2

Accuracy and Loss Curve

![curve_cnn_5e-2](../Report/img/Part2/lr/curve_cnn_5e-2.png)

Class Accuracy

![acc_cnn_class_5e-2](../Report/img/Part2/lr/acc_cnn_class_5e-2.png)


