In [None]:
import os
import copy
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets, models
from torch import optim
import pickle as pkl
from torchsummary import summary
from torchvision.models import vgg19_bn, VGG19_BN_Weights
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

In [None]:
# Run timestamp (YYYYmmdd_HHMMSS) used to tag saved checkpoints and plots.
time = f"{datetime.now():%Y%m%d_%H%M%S}"

### Load the CIFAR10 training and test datasets

In [None]:
# Training transform: tensor conversion, random-rotation augmentation, then
# normalisation of each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Evaluation transform: same conversion and normalisation but NO augmentation.
# Test accuracy is used to pick the best checkpoint, so it must be
# deterministic. Random rotations here would make it noisy and biased.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 16

train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2
)

test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=test_transform
)

# Test order is fixed (shuffle=False) so evaluation is reproducible.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2
)

# CIFAR-10 label names, in the order used to index model predictions.
classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

### Define the model

In [None]:
class MyVGG19_BN(nn.Module):
    """VGG19-BN with a pretrained convolutional trunk and a freshly built head.

    The convolutional ``features`` stage and the ``avgpool`` stage are taken
    from torchvision's ``vgg19_bn`` loaded with ``VGG19_BN_Weights.DEFAULT``.
    The classifier is newly initialised and produces ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        dropout: dropout probability used after each hidden classifier layer.
    """

    def __init__(self, num_classes=10, dropout: float = 0.5):
        super().__init__()
        backbone = models.vgg19_bn(weights=VGG19_BN_Weights.DEFAULT)
        # Reuse only the pretrained feature extractor and pooling. The
        # backbone's own classifier is discarded together with `backbone`.
        self.features = backbone.features
        self.avgpool = backbone.avgpool

        hidden_units = 4096
        # Layer order fixes the state_dict keys (classifier.0 / .3 / .6), so
        # checkpoints stay loadable. Do not reorder.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, hidden_units),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(start_dim=1))

In [None]:
model = MyVGG19_BN(num_classes=10)
# torchsummary's `summary` defaults to device="cuda" and creates its probe
# input on the GPU whenever CUDA is available. The model is still on the CPU
# at this point (it is moved to `device` in the training cell), so pin the
# probe to the CPU to avoid a device-mismatch error.
# The probe is a real forward pass, so run it in eval mode. Otherwise the
# BatchNorm running statistics would be updated from random input.
model.eval()
summary(model, (3, 32, 32), device="cpu")
model.train();  # restore the default mode; ";" suppresses the module repr

In [None]:
# Layer-by-layer module tree of the model.
print(str(model))

### Define a Loss function and optimizer

In [None]:
# Cross-entropy over the classifier logits, optimised with SGD + momentum.
# NOTE(review): the optimizer is built before `model.to(device)` in the
# training cell. This works because `Module.to` updates parameters in place,
# but PyTorch recommends moving the model before constructing the optimizer.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

### Train the network

In [None]:
num_epoches: int = 50  # number of full passes over the training set

In [None]:
def harmonic(train_acc, test_acc):
    """Harmonic mean of training and testing accuracy.

    Intended as an alternative checkpoint score that rewards models doing
    well on both splits.

    Args:
        train_acc: training accuracy (e.g. a fraction such as 0.93).
        test_acc: testing accuracy, on the same scale as ``train_acc``.

    Returns:
        ``2 * train_acc * test_acc / (train_acc + test_acc)``, or ``0.0`` when
        both accuracies are zero (the formula would otherwise divide by zero).
    """
    total = train_acc + test_acc
    if total == 0:
        return 0.0
    return 2 * train_acc * test_acc / total

In [None]:
# Per-epoch history, consumed by the plotting cells.
training_acc = []
training_loss = []
testing_acc = []
testing_loss = []
# H = 0.0  # harmonic
best_testing_acc = 0.0
best_model = copy.deepcopy(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def _train_one_epoch(net, loader, criterion, opt, dev, desc):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean per-sample loss, accuracy) over ``loader.dataset``.
    """
    net.train()
    running_loss = 0.0
    num_correct = 0
    for img, label in tqdm(loader, desc=desc):
        img, label = img.to(dev), label.to(dev)

        out = net(img)
        loss = criterion(out, label)
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Weight the batch-mean loss by batch size so the epoch mean is exact
        # even when the last batch is smaller.
        running_loss += loss.item() * label.size(0)
        # Softmax is monotonic, so the argmax of the logits is the prediction.
        num_correct += (out.argmax(dim=1) == label).sum().item()

    num_samples = len(loader.dataset)
    return running_loss / num_samples, num_correct / num_samples


@torch.no_grad()
def _evaluate(net, loader, criterion, dev, desc):
    """Evaluate `net` on `loader` without building an autograd graph.

    Returns:
        (mean per-sample loss, accuracy) over ``loader.dataset``.
    """
    net.eval()  # BatchNorm uses running stats; Dropout is disabled
    eval_loss = 0.0
    num_correct = 0
    for img, label in tqdm(loader, desc=desc):
        img, label = img.to(dev), label.to(dev)
        out = net(img)
        eval_loss += criterion(out, label).item() * label.size(0)
        num_correct += (out.argmax(dim=1) == label).sum().item()

    num_samples = len(loader.dataset)
    return eval_loss / num_samples, num_correct / num_samples


# train model
for epoch in range(num_epoches):
    print("\n", "*" * 25, "epoch {}".format(epoch + 1), "*" * 25)

    train_loss, train_acc = _train_one_epoch(
        model, train_loader, loss_func, optimizer, device,
        desc=f"Train Epoch [{epoch + 1}/{num_epoches}]",
    )
    training_acc.append(train_acc)
    training_loss.append(train_loss)
    print("Train --> Loss: {:.6f}, Acc: {:.6f}".format(train_loss, train_acc))

    # Evaluate the model on the testing dataset.
    test_loss, test_acc = _evaluate(
        model, test_loader, loss_func, device,
        desc=f"Test Epoch [{epoch + 1}/{num_epoches}]",
    )
    testing_acc.append(test_acc)
    testing_loss.append(test_loss)
    print("Test -->  Loss: {:.6f}, Acc: {:.6f}".format(test_loss, test_acc))

    # Alternative criterion: keep the model with the highest harmonic mean.
    # current_H = harmonic(train_acc, test_acc)
    # if current_H > H:
    #     best_model = copy.deepcopy(model)
    #     H = current_H
    # print("Current Harmonic : {:.6f}, Best Harmonic : {:.6f}".format(current_H, H))

    # Keep the model with the highest testing accuracy as the best model.
    if test_acc > best_testing_acc:
        best_model = copy.deepcopy(model)
        best_testing_acc = test_acc
    print("Current testing acc : {:.6f}, Best testing acc : {:.6f}".format(test_acc, best_testing_acc))


# Save the best model. The filename records the best accuracy (the one the
# saved weights achieved), not the last epoch's accuracy.
os.makedirs("./models", exist_ok=True)
torch.save(
    best_model.state_dict(),
    f"./models/VGG19_bn_cifar10_state_dict_{time}_{best_testing_acc}.pth",
)
# torch.save(best_model, "./models/VGG19_bn_cifar10.pth")

In [None]:
# Plot the per-epoch training/testing loss curves and save them under ./log.
os.makedirs("./log", exist_ok=True)
title = "Loss"
plt.figure()
# Each curve's x-axis comes from its own recorded length (not `num_epoches`),
# so the plot still renders if training stopped early or mid-epoch.
for history in (training_loss, testing_loss):
    plt.plot(range(1, len(history) + 1), history)
plt.title(title)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(["training", "testing"], loc="upper right")
plt.savefig(f"./log/vgg19_bn_{title}_{time}.jpg")
plt.show()

In [None]:
# Plot the per-epoch training/testing accuracy curves and save them under ./log.
# Create the directory here too, so this cell does not silently depend on the
# loss-plot cell having run first.
os.makedirs("./log", exist_ok=True)
title = "Accuracy"
# Start a fresh figure so these curves are never drawn onto a previous plot.
plt.figure()
# Each curve's x-axis comes from its own recorded length (not `num_epoches`),
# so the plot still renders if training stopped early or mid-epoch.
for history in (training_acc, testing_acc):
    plt.plot(range(1, len(history) + 1), history)
plt.title(title)
plt.xlabel("epoch")
plt.ylabel("acc")
plt.legend(["training", "testing"], loc="lower right")
plt.savefig(f"./log/vgg19_bn_{title}_{time}.jpg")
plt.show()

## Inference

In [None]:
# CIFAR-10 label names in prediction-index order. They are redefined here so
# the inference section can run without re-running the data-loading cell.
classes = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()

In [None]:
# Inference example (disabled): load a saved state_dict and classify one image,
# then plot the per-class softmax probabilities.
# NOTE(review) before re-enabling this cell:
#   * `model_path` does not match the checkpoint the training cell writes,
#     which has a timestamped name of the form
#     "./models/VGG19_bn_cifar10_state_dict_<time>_<accuracy>.pth".
#   * The checkpoint is saved from the model after it was moved to `device`.
#     Loading it on a CPU-only machine presumably needs
#     `torch.load(model_path, map_location=device)` — confirm.
#   * Inside the `with` block, `test_model = model.to(device)` replaces the
#     freshly loaded `test_model` with the in-memory training `model`, so the
#     checkpoint loaded above is never used. `test_model = test_model.to(device)`
#     was likely intended.
#   * If the PNG has an alpha channel, the 3-channel Normalize may fail.
#     `Image.open(...).convert("RGB")` may be needed — verify.
# from PIL import Image

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_path = "./models/VGG19_bn_cifar10_state_dict.pth"
# test_model = MyVGG19_BN(num_classes=10)
# test_model.load_state_dict(torch.load(model_path))


# inference_img = Image.open("./Dataset_CvDl_Hw1/Q5_image/Q5_4/airplane.png")
# transform = transforms.Compose(
#     [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
# )

# inference_img = transform(inference_img)

# img_normalized = inference_img.unsqueeze_(0)
# img_normalized = img_normalized.to(device)


# with torch.no_grad():
#     test_model = model.to(device)
#     test_model.eval()
#     output = test_model(img_normalized)
#     probs = torch.softmax(output, dim=1)
#     # _, pred = torch.max(probs, dim=1)
#     index = probs.data.cpu().numpy().argmax()
#     pred_class_name = classes[index]
#     print(f"Predicted Class: {pred_class_name}")
#     x = [i for i in range(len(classes))]
#     fig = plt.figure(figsize=(5, 5))
#     plt.bar(x, probs.data.cpu().numpy()[0], tick_label=classes)
#     plt.title(f"Probability of each class")
#     plt.xticks(rotation=45)
#     plt.xlabel("Class")
#     plt.ylabel("Probability")
#     plt.show()