In [None]:
import os

# Work from the project root (the parent of the notebook's starting directory) so
# that `code_base` and the relative `data/` and `imgs/` paths used below resolve.
# Guarded so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir("code_base"):
    os.chdir("../")

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import Counter
from code_base.functions import train, train_models, select_best_model, evaluate_performance
from code_base.models import TestNet, LeNetVariant
from code_base.DataAnalysis import DataAnalysis

In [None]:
# Seed PyTorch's global RNG so stochastic steps are repeatable across runs.
SEED = 265
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Object Localization

## Data exploration

In [None]:
# Load the pre-split localization datasets saved as PyTorch dataset objects.
# `weights_only=False` is required because PyTorch >= 2.6 defaults to
# `weights_only=True`, which refuses to unpickle arbitrary Python objects such as
# dataset classes.
# SECURITY: this uses pickle under the hood, so only load these files from a
# trusted source.
data_train = torch.load("data/localization_train.pt", weights_only=False)
data_val = torch.load("data/localization_val.pt", weights_only=False)
data_test = torch.load("data/localization_test.pt", weights_only=False)
print("Type: ", type(data_train))

In [None]:
# Summarise the training split. The cell output is whatever
# DataAnalysis.get_summary prints and/or returns; the call is deliberately left as
# the last expression so a non-None return value is displayed by Jupyter.
DataAnalysis.get_summary(data_train)

In [None]:
# Summarise the validation split (same helper as for the training split); left as
# the last expression so any non-None return value is displayed.
DataAnalysis.get_summary(data_val)

In [None]:
# Summarise the test split (same helper as for the training split); left as
# the last expression so any non-None return value is displayed.
DataAnalysis.get_summary(data_test)

In [None]:
# The class index is taken from the last entry of each label; the number of
# classes is the number of distinct values seen in the training split.
n_classes = len({int(label[-1]) for _, label in data_train})

# Four ground-truth examples per class, each grid saved to its own file.
for class_idx in range(n_classes):
    DataAnalysis.plot_instances_with_bounding_box(
        data_train,
        class_idx,
        n_instances=4,
        save_to_file=f"imgs/localization/true_instances/class_{class_idx}.png",
    )

# Four examples without a class (passed as None to the plotting helper).
DataAnalysis.plot_instances_with_bounding_box(
    data_train,
    None,
    n_instances=4,
    save_to_file="imgs/localization/true_instances/class_none.png",
)

## Training

In [None]:
# Per-channel statistics come from the training split only; the stacked batch is
# reduced over dims (0, 2, 3), i.e. everything except the channel axis. The same
# normalizer is then applied to val/test so no statistics leak from them.
imgs = torch.stack([img for img, _ in data_train])
normalizer = transforms.Normalize(imgs.mean(dim=(0, 2, 3)), imgs.std(dim=(0, 2, 3)))

# Normalize the images of every split, keeping labels untouched.
data_train_norm, data_val_norm, data_test_norm = (
    [(normalizer(img), label) for img, label in split]
    for split in (data_train, data_val, data_test)
)

In [None]:
# Mini-batch size shared by the loaders (the test loader later reuses it too).
batch_size = 32

# Shuffle the training set every epoch so batches are not always drawn in the
# stored (possibly grouped) order. A dedicated generator seeded with SEED keeps
# the shuffling reproducible across runs.
train_loader = DataLoader(
    data_train_norm,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# Evaluation order should not matter, so validation stays in a fixed order.
val_loader = DataLoader(data_val_norm, batch_size=batch_size, shuffle=False)

In [None]:
# Training budget and candidate architectures handed to `train_models`.
epochs = 5
networks = [TestNet, LeNetVariant]

# Grid of settings, presumably forwarded to the optimizer by `train_models`:
# a fixed learning rate, with and without weight decay.
# NOTE(review): weight_decay=0.5 is unusually strong; confirm it is intentional.
# Earlier configs that also set `momentum` (lr=0.01) were dropped from the search.
hyper_parameters = [
    {"lr": 0.001, "weight_decay": weight_decay}
    for weight_decay in (0.0, 0.5)
]

In [None]:
# Train the candidate networks over the hyper-parameter grid. Arguments are passed
# positionally in the order train_models expects: task name, candidates, grid,
# batch size, epochs, loaders, device, seed. The returned `results` is indexed by
# metric name below, each entry holding one item per trained model.
results = train_models(
    "localization", networks, hyper_parameters,
    batch_size, epochs,
    train_loader, val_loader,
    DEVICE, SEED,
)

In [None]:
# Choose among the trained models using their validation strict-accuracy
# histories; the returned index is used below to look up that model's metrics.
best_model, best_model_idx = select_best_model(
    results["models"],
    results["strict_val"],
)
print(best_model)

In [None]:
# Pull the selected model's hyper-parameters and metric histories out of `results`
# (every entry there is a per-model sequence indexed by model position).
(
    bm_hyper_params,
    bm_train_loss, bm_val_loss,
    bm_train_strict_acc, bm_val_strict_acc,
    bm_train_box_acc, bm_val_box_acc,
    bm_train_detect_acc, bm_val_detect_acc,
    bm_train_mean_acc, bm_val_mean_acc,
) = (
    results[key][best_model_idx]
    for key in (
        "hyper_params",
        "loss_train", "loss_val",
        "strict_train", "strict_val",
        "box_train", "box_val",
        "detection_train", "detection_val",
        "mean_perf_train", "mean_perf_val",
    )
)

In [None]:
# Report the hyper-parameters of the selected model.
print("Selected hyper parameters")
print(bm_hyper_params)

# One train-vs-validation figure per tracked quantity:
# (train curve, val curve, title, y-axis label, output file).
for train_curve, val_curve, title, y_label, out_path in (
    (bm_train_loss, bm_val_loss,
     "Training and val loss", "Loss",
     "imgs/localization/performance/loss.png"),
    (bm_train_strict_acc, bm_val_strict_acc,
     "Training and validation strict accuracy over epochs", "Strict Accuracy",
     "imgs/localization/performance/strict.png"),
    (bm_train_box_acc, bm_val_box_acc,
     "Training and validation box accuracy over epochs", "Box Accuracy",
     "imgs/localization/performance/box.png"),
    (bm_train_detect_acc, bm_val_detect_acc,
     "Training and validation detection accuracy over epochs", "Detection Accuracy",
     "imgs/localization/performance/detection.png"),
    (bm_train_mean_acc, bm_val_mean_acc,
     "Training and validation mean accuracy over epochs", "Mean Accuracy",
     "imgs/localization/performance/mean.png"),
):
    DataAnalysis.plot_performance_over_time(train_curve, val_curve, title, y_label=y_label,
                                            save_to_file=out_path)

## Evaluation

In [None]:
# Score the selected model on the held-out test split, in a fixed order.
test_loader = DataLoader(data_test_norm, batch_size=batch_size, shuffle=False)
perf, output = evaluate_performance("localization", best_model, test_loader, device=DEVICE)

# `perf` maps metric name -> score (presumably a fraction, hence the *100 and %).
print("--- Test performances ---")
for metric in ("strict", "box", "detection", "mean"):
    print(f"{metric.capitalize()} performance: {perf[metric] * 100:.2f}%")

In [None]:
# Plot test-set predictions for every class counted in the training labels
# (`n_classes`, computed in the data-exploration section) rather than a hard-coded
# 10, so these plots stay consistent with the ground-truth plots. The final call
# plots the instances without a class (None).
for i in range(n_classes):
    DataAnalysis.plot_instances_with_bounding_box(data_test, i, predictions=output, save_to_file=f"imgs/localization/predictions/class_{i}.png")
DataAnalysis.plot_instances_with_bounding_box(data_test, None, predictions=output, save_to_file="imgs/localization/predictions/class_none.png")