In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchsummary import summary  # Denne er fin for å printe ut informasjon om PyTorch-modeller
import matplotlib.pyplot as plt
from collections import Counter
from code_base.functions import train, custom_loss
from code_base.models import TestNet
from code_base.DataAnalysis import DataAnalysis

In [None]:
# Load the pre-split localization datasets that were saved with torch.save.
# The files are expected to hold dataset objects (a TensorDataset per the original note),
# not plain tensors. Since PyTorch 2.6, torch.load defaults to weights_only=True, which
# refuses to unpickle such objects, so weights_only=False is passed explicitly.
# weights_only=False runs the full pickle machinery: only use it on trusted files like these.
# (The weights_only argument requires PyTorch >= 1.13.)
DATA_DIR = "data"
data_train = torch.load(f"{DATA_DIR}/localization_train.pt", weights_only=False)
data_val = torch.load(f"{DATA_DIR}/localization_val.pt", weights_only=False)
data_test = torch.load(f"{DATA_DIR}/localization_test.pt", weights_only=False)
print("Type: ", type(data_train))

In [None]:
# Summarise the training split with the project helper DataAnalysis.get_summary (code_base.DataAnalysis).
DataAnalysis.get_summary(data_train)

In [None]:
# Summarise the validation split with the same helper used for the training split.
DataAnalysis.get_summary(data_val)

In [None]:
# Summarise the test split with the same helper used for the training split.
DataAnalysis.get_summary(data_test)

In [None]:
# Display a few example images (with their bounding boxes) for every class,
# followed by examples of images that have no class.
N_INSTANCES_PER_CLASS = 4

# NOTE(review): assumes the last entry of each label is the class index and that the
# class indices are contiguous 0..n_classes-1 — TODO confirm against the dataset spec.
n_classes = len({int(label[-1]) for _, label in data_train})
for class_idx in range(n_classes):
    DataAnalysis.plot_instances_with_bounding_box(
        data_train, class_idx, n_instances=N_INSTANCES_PER_CLASS
    )

# Passing None instead of a class index plots the "no class" samples.
DataAnalysis.plot_instances_with_bounding_box(
    data_train, None, n_instances=N_INSTANCES_PER_CLASS
)

In [None]:
# Mini-batch loaders.
# The training set is reshuffled every epoch so SGD does not see the samples in the same
# fixed order each time. A seeded generator keeps that batch order repeatable across runs.
# Validation order does not affect evaluation, so it stays unshuffled.
batch_size = 512
SHUFFLE_SEED = 42
train_loader = DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Temporary setup for exercising the train function:
# a TestNet on the selected device, optimised with plain SGD.
learning_rate = 0.01

# nn.Module.to moves the parameters and returns the module itself, so the call can be chained.
model = TestNet().to(device)

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# Train the model. `train` returns training/validation loss and performance histories,
# presumably one entry per epoch (they are plotted over epochs further down).
N_EPOCHS = 30
train_loss, val_loss, train_performance, val_performance = train(
    N_EPOCHS, optimizer, model, custom_loss, train_loader, val_loader, device
)

In [None]:
# Plot the training and validation loss histories returned by `train`.
DataAnalysis.plot_performance_over_time(train_loss, val_loss, "Training vs val loss")

In [None]:
# Plot the training and validation performance histories returned by `train`.
DataAnalysis.plot_performance_over_time(
    train_performance,
    val_performance,
    "Training accuracy and validation accuracy over epochs",
    label1="Training Accuracy",
    label2="Validation Accuracy",
)