In [None]:
import os

# Remember the directory the kernel started in, once per kernel session.
# A bare os.chdir("../") climbs one level higher on every re-run of this
# cell; anchoring on the recorded start directory makes the cell idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Work from the parent of the start directory (same target as "../" on the
# first run). Relative paths used later in the notebook, such as
# "imgs/detection/...", are written relative to the working directory.
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from collections import Counter
from code_base.functions import train, train_models, select_best_model, evaluate_performance
from code_base.models import ObjectDetect_2x3
from code_base.DataAnalysis import DataAnalysis
from code_base.object_detection import get_converted_data

In [None]:
# Seed PyTorch's global random number generator so runs are repeatable.
SEED = 265
torch.manual_seed(SEED)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

## Data exploration

In [None]:
# Detection grid size, passed as `grid_dimensions` to the data conversion and
# plotting helpers below (presumably (rows, cols) -- TODO confirm the order).
grid_dimensions = (2, 3)

In [None]:
# Fetch the three dataset splits for the chosen grid size and unpack them
# into train / validation / test.
converted_splits = get_converted_data(grid_dimensions=grid_dimensions)
object_train, object_val, object_test = converted_splits

In [None]:
# Plot training samples with their ground-truth annotations and save the
# figure to disk. (4, 4) is presumably the subplot layout -- TODO confirm.
DataAnalysis.plot_detection_instances(
    object_train,
    (4, 4),
    grid_dimensions=grid_dimensions,
    title="Ground truth bounding boxes and class",
    save_to_file="imgs/detection/true.png",
)

## Training

In [None]:
# Normalize from training data
imgs = torch.stack([img for img, _ in object_train])

# Define normalizer
normalizer = transforms.Normalize(
    imgs.mean(dim=(0, 2, 3)), 
    imgs.std(dim=(0, 2, 3))
    )

object_train_norm = [(normalizer(img), label) for img, label in object_train]
object_val_norm = [(normalizer(img), label) for img, label in object_val]
object_test_norm = [(normalizer(img), label) for img, label in object_test]

In [None]:
batch_size = 32

# Pinned (page-locked) host memory only pays off when batches are copied to
# a CUDA device; requesting it on a CPU-only machine is of no use.
use_pinned_memory = DEVICE.type == "cuda"

# Training data is reshuffled every epoch so the optimizer does not see the
# same batches in the same order each time. A dedicated generator seeded
# with SEED keeps the shuffling order reproducible.
train_loader = DataLoader(
    object_train_norm,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    pin_memory=use_pinned_memory,
)

# Validation is evaluation only, so a fixed order is kept.
val_loader = DataLoader(
    object_val_norm,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=use_pinned_memory,
)

In [None]:
# Number of epochs and the model classes passed to train_models.
epochs = 5
networks = [ObjectDetect_2x3]

# One dict of hyper parameters per configuration to try; only the first
# entry is currently active.
hyper_parameters = [
    dict(lr=0.001, weight_decay=0.0),  # "momentum": 0.0 is currently disabled
    # Disabled alternatives:
    # {"lr": 0.05, "weight_decay": 0.9},
    # {"lr": 0.05, "weight_decay": 0.1},
    # {"lr": 0.01, "weight_decay": 0.01, "momentum": 0.0},
    # {"lr": 0.01, "weight_decay": 0.9, "momentum": 0.0},
]

In [None]:
# Run train_models for the "detection" task. Later cells index `results` by
# key ("models", "hyper_params", "loss_train", ...) and then by model index.
results = train_models(
    "detection", networks, hyper_parameters,
    batch_size, epochs,
    train_loader, val_loader,
    DEVICE, SEED,
)

In [None]:
# Choose a model from the trained candidates using the "strict_val" results;
# the helper returns the model together with its index in `results`.
candidate_models = results["models"]
validation_scores = results["strict_val"]
best_model, best_model_idx = select_best_model(candidate_models, validation_scores)
print(best_model)

In [None]:
def _best_model_entry(key):
    """Return the entry of results[key] that belongs to the selected model."""
    return results[key][best_model_idx]


# Hyper parameters and recorded results of the selected model.
bm_hyper_params = _best_model_entry("hyper_params")
bm_train_loss = _best_model_entry("loss_train")
bm_val_loss = _best_model_entry("loss_val")
best_model_train_performance = _best_model_entry("strict_train")
best_model_val_performance = _best_model_entry("strict_val")

In [None]:
# Report the hyper parameters of the selected model.
print("Selected hyper parameters", bm_hyper_params, sep="\n")

# Training vs. validation loss of the selected model; figure saved to disk.
DataAnalysis.plot_performance_over_time(
    bm_train_loss,
    bm_val_loss,
    "Training and val loss",
    y_label="Loss",
    save_to_file="imgs/detection/loss.png",
)

# Training vs. validation strict accuracy of the selected model.
DataAnalysis.plot_performance_over_time(
    best_model_train_performance,
    best_model_val_performance,
    "Training and validation strict accuracy over epochs",
    y_label="Strict Accuracy",
    save_to_file="imgs/detection/strict.png",
)

## Evaluation

In [None]:
# Keep the test samples in dataset order. The following cell plots `output`
# next to `object_test`, so (assuming evaluate_performance returns predictions
# in loader order -- TODO confirm) a shuffled loader would pair each image
# with another sample's prediction. Shuffling brings no benefit at evaluation.
test_loader = DataLoader(object_test_norm, batch_size=batch_size, shuffle=False)

# Evaluate the selected model on the held-out test split; `perf` is indexed
# by metric name and `output` holds what the helper returns as predictions.
perf, output = evaluate_performance("detection", best_model, test_loader, device=DEVICE)
print("--- Test performance ---")
print(f"Strict accuracy: {perf['strict']*100:.2f}%")

In [None]:
# Plot test samples together with the predictions in `output` from the
# evaluation cell and save the figure to disk.
# (4, 4) is presumably the subplot layout -- TODO confirm.
DataAnalysis.plot_detection_instances(
    object_test,
    (4, 4),
    predictions=output,
    grid_dimensions=grid_dimensions,
    title="Predicted bounding boxes and class",
    save_to_file="imgs/detection/pred.png",
)