In [25]:
from AlexNet import AlexNet
from PlywoodDataset import PlywoodDataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import warnings

warnings.filterwarnings("ignore")

In [26]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((227, 227)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
train_dataset = PlywoodDataset("labels.csv", "data", transform=transform)
test_dataset = PlywoodDataset("test_labels.csv", "test_data", transform=transform)

In [27]:
trainloader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
)
testloader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=True,
)
loss_function = torch.nn.SmoothL1Loss()

In [28]:
NUM_EPOCH = 50

In [29]:
def calculate_error(return_values):
    true_number = []
    predicted_number = []
    for data in return_values:
        true_number += data[0]
        predicted_number += data[1]

    error = []
    for i in range(len(true_number)):
        error.append(abs(true_number[i] - predicted_number[i]) / true_number[i])
    error = sum(error) / len(error)
    
    return error

In [30]:
train_errors = []
for epoch in range(NUM_EPOCH):
    net = AlexNet()
    net.load_state_dict(torch.load(f"alexnet/alexnet_epoch{epoch}.pt"))
    with torch.no_grad():
        return_values = []
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            return_values.append(
                (labels.tolist(), outputs.flatten().tolist(), loss.tolist())
            )
    train_error = calculate_error(return_values)
    train_errors.append(train_error)
    print(f'Epoch: {epoch} Train error: {train_error}')


Epoch: 0 Train error: 0.19652609784974093
Epoch: 1 Train error: 0.2957806386749943
Epoch: 2 Train error: 0.2014191806055172
Epoch: 3 Train error: 0.21249268774853486
Epoch: 4 Train error: 0.20221360564496552
Epoch: 5 Train error: 0.18698187479876105
Epoch: 6 Train error: 0.17426490876812697
Epoch: 7 Train error: 0.1928611934950114
Epoch: 8 Train error: 0.2892568469863693
Epoch: 9 Train error: 0.21032948912378657
Epoch: 10 Train error: 0.2122377341209306
Epoch: 11 Train error: 0.17818280490324281
Epoch: 12 Train error: 0.184187020623126
Epoch: 13 Train error: 0.1872959408874722
Epoch: 14 Train error: 0.18756691697456693
Epoch: 15 Train error: 0.20542068267660502
Epoch: 16 Train error: 0.19365697639821944
Epoch: 17 Train error: 0.19833131403681348
Epoch: 18 Train error: 0.18312221117389668
Epoch: 19 Train error: 0.19618337020041532
Epoch: 20 Train error: 0.15718419645046422
Epoch: 21 Train error: 0.18333470267247895
Epoch: 22 Train error: 0.1894881126102098
Epoch: 23 Train error: 0.17078

In [31]:
test_errors = []
for epoch in range(NUM_EPOCH):
    net = AlexNet()
    net.load_state_dict(torch.load(f"alexnet/alexnet_epoch{epoch}.pt"))
    with torch.no_grad():
        return_values = []
        for i, data in enumerate(testloader, 0):
            inputs, labels = data
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            return_values.append(
                (labels.tolist(), outputs.flatten().tolist(), loss.tolist())
            )
    test_error = calculate_error(return_values)
    test_errors.append(test_error)
    print(f'Epoch: {epoch} Test error: {test_error}')


Epoch: 0 Test error: 0.2615332141641031
Epoch: 1 Test error: 0.18539689655591496
Epoch: 2 Test error: 0.15923618611945659
Epoch: 3 Test error: 0.27090975274891904
Epoch: 4 Test error: 0.19988974962706704
Epoch: 5 Test error: 0.10145564548686792
Epoch: 6 Test error: 0.16501692560387168
Epoch: 7 Test error: 0.0821945475709825
Epoch: 8 Test error: 0.2157630797346154
Epoch: 9 Test error: 0.20020395158100324
Epoch: 10 Test error: 0.21445789940336565
Epoch: 11 Test error: 0.1906525527801329
Epoch: 12 Test error: 0.11523401378002042
Epoch: 13 Test error: 0.16857276098855586
Epoch: 14 Test error: 0.13699647796824652
Epoch: 15 Test error: 0.17141156361614765
Epoch: 16 Test error: 0.19923589233456152
Epoch: 17 Test error: 0.14920893409467575
Epoch: 18 Test error: 0.15473980459290448
Epoch: 19 Test error: 0.09416947998946919
Epoch: 20 Test error: 0.11058086339116091
Epoch: 21 Test error: 0.193285566438801
Epoch: 22 Test error: 0.12912237564124449
Epoch: 23 Test error: 0.15860205947682926
Epoch: 2