In [1]:
from AlexNet import AlexNet
from PlywoodDataset import PlywoodDataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import warnings

warnings.filterwarnings("ignore")

In [2]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((227, 227)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
train_dataset = PlywoodDataset("labels.csv", "data", transform=transform)
test_dataset = PlywoodDataset("test_labels.csv", "test_data", transform=transform)

In [3]:
trainloader = DataLoader(
    dataset=train_dataset,
    batch_size=5,
    shuffle=True,
)
testloader = DataLoader(
    dataset=test_dataset,
    batch_size=5,
    shuffle=True,
)
loss_function = torch.nn.SmoothL1Loss()

In [4]:
NUM_EPOCH = 50

In [5]:
def calculate_error(return_values):
    true_number = []
    predicted_number = []
    for data in return_values:
        true_number += data[0]
        predicted_number += data[1]

    error = []
    for i in range(len(true_number)):
        error.append(abs(true_number[i] - predicted_number[i]) / true_number[i])
    error = sum(error) / len(error)
    
    return error

In [6]:
train_errors = []
for epoch in range(NUM_EPOCH):
    net = AlexNet()
    net.load_state_dict(torch.load(f"alexnet/alexnet_epoch{epoch}.pt"))
    with torch.no_grad():
        return_values = []
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            return_values.append(
                (labels.tolist(), outputs.flatten().tolist(), loss.tolist())
            )
    train_error = calculate_error(return_values)
    train_errors.append(train_error)
    print(f'Epoch: {epoch} Train error: {train_error}')


Epoch: 0 Train error: 0.20165307797879492
Epoch: 1 Train error: 0.31081946136771016
Epoch: 2 Train error: 0.19743020612602197
Epoch: 3 Train error: 0.21930332662539825
Epoch: 4 Train error: 0.2112212318690381
Epoch: 5 Train error: 0.18747081375452554
Epoch: 6 Train error: 0.17131587775457843
Epoch: 7 Train error: 0.21716223501521978
Epoch: 8 Train error: 0.27883037390304877
Epoch: 9 Train error: 0.20417580178538275
Epoch: 10 Train error: 0.21434700681682925
Epoch: 11 Train error: 0.1779741024234779
Epoch: 12 Train error: 0.19635881590152998
Epoch: 13 Train error: 0.17696802037623807
Epoch: 14 Train error: 0.17168301964664348
Epoch: 15 Train error: 0.2088797131654661
Epoch: 16 Train error: 0.16585996936084446
Epoch: 17 Train error: 0.17192165306694562
Epoch: 18 Train error: 0.19577050887191724
Epoch: 19 Train error: 0.19754063259942145
Epoch: 20 Train error: 0.1759273573652541
Epoch: 21 Train error: 0.1787403828647965
Epoch: 22 Train error: 0.18734659494906875
Epoch: 23 Train error: 0.1

In [7]:
test_errors = []
for epoch in range(NUM_EPOCH):
    net = AlexNet()
    net.load_state_dict(torch.load(f"alexnet/alexnet_epoch{epoch}.pt"))
    with torch.no_grad():
        return_values = []
        for i, data in enumerate(testloader, 0):
            inputs, labels = data
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            return_values.append(
                (labels.tolist(), outputs.flatten().tolist(), loss.tolist())
            )
    test_error = calculate_error(return_values)
    test_errors.append(test_error)
    print(f'Epoch: {epoch} Test error: {test_error}')


Epoch: 0 Test error: 0.2299447360893875
Epoch: 1 Test error: 0.23426022598531296
Epoch: 2 Test error: 0.20182869246272953
Epoch: 3 Test error: 0.17557484567581336
Epoch: 4 Test error: 0.1227016040187721
Epoch: 5 Test error: 0.23369479123895517
Epoch: 6 Test error: 0.14228294977100367
Epoch: 7 Test error: 0.09055146135896523
Epoch: 8 Test error: 0.19561244452966964
Epoch: 9 Test error: 0.09891747683812942
Epoch: 10 Test error: 0.18398321664317893
Epoch: 11 Test error: 0.11575932216472916
Epoch: 12 Test error: 0.16013309348449858
Epoch: 13 Test error: 0.1937097729167273
Epoch: 14 Test error: 0.08536263223991221
Epoch: 15 Test error: 0.19190589353479326
Epoch: 16 Test error: 0.14387277991583486
Epoch: 17 Test error: 0.16398622492572548
Epoch: 18 Test error: 0.14045006098766522
Epoch: 19 Test error: 0.1727741910672571
Epoch: 20 Test error: 0.14401043210653355
Epoch: 21 Test error: 0.15564984766474396
Epoch: 22 Test error: 0.16380416583441285
Epoch: 23 Test error: 0.12455026906760502
Epoch: