In [None]:
import os
import sys

# Move the working directory up the tree until it is the project's "src"
# folder, so that the project modules imported by the next cells can be found.
# The loop stops as soon as the current directory equals "<parent>/src",
# i.e. when the current directory itself is named "src".
start_dir = os.getcwd()
current_dir = start_dir
while current_dir != os.path.abspath("../src"):
    os.chdir("..")
    parent_dir = os.getcwd()
    if parent_dir == current_dir:
        # ".." no longer changes the directory: the filesystem root was reached
        # without finding "src". Restore the starting directory instead of
        # looping forever and leaving the kernel stranded at the root.
        os.chdir(start_dir)
        raise FileNotFoundError(
            f"No 'src' directory found among the ancestors of {start_dir!r}"
        )
    current_dir = parent_dir
# sys.path.append(os.path.abspath("Efficient-Computing/Detection/Gold-YOLO"))

In [None]:
from utils import create_all_folders, Folders

# Run the project's folder setup before any later cell builds paths from `Folders`.
# NOTE(review): presumably creates the directories listed in `Folders` — confirm in utils.
create_all_folders()

In [None]:
from layers import AMF_GD_YOLOv8
import torch
from preprocessing.data import ImageData
from training_parameters import (
    class_names,
    class_indices,
    class_colors,
    transform_pixel_rgb,
    transform_pixel_chm,
    transform_spatial,
    proba_drop_rgb,
    proba_drop_chm,
    labels_transformation_drop_chm,
    labels_transformation_drop_rgb,
)
from training import (
    train_and_validate,
    create_and_save_splitted_datasets,
    load_tree_datasets_from_split,
    compute_mean_and_std,
    compute_metrics,
    plot_sorted_ap,
    plot_sorted_ap_confs,
)
import multiprocessing as mp
from geojson_conversions import open_geojson_feature_collection
from preprocessing.rgb_cir import get_rgb_images_paths_from_polygon
import numpy as np

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Allocator settings handed to PyTorch through the environment; echoed so the
# active value is visible in the notebook output.
# NOTE(review): torch is already imported when this runs — confirm the setting
# is still picked up by the CUDA allocator at this point.
cuda_alloc_conf = "garbage_collection_threshold:0.6,max_split_size_mb:512"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = cuda_alloc_conf
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])

In [None]:
# Find and split the data

# --- Annotations and the full image they belong to ---
annotations_file_name = "122000_484000.geojson"
annotations_path = os.path.join(Folders.FULL_ANNOTS.value, annotations_file_name)
annotations = open_geojson_feature_collection(annotations_path)

# Only the first image path returned for the annotations' bounding box is used.
matching_rgb_images = get_rgb_images_paths_from_polygon(annotations["bbox"])
full_image_path_tif = matching_rgb_images[0]
image_data = ImageData(full_image_path_tif)

# --- Folders holding the cropped annotations, RGB/CIR images and CHM rasters ---
resolution = 0.08

annotations_folder_path = os.path.join(Folders.CROPPED_ANNOTS.value, image_data.base_name)
rgb_cir_folder_path = os.path.join(Folders.IMAGES.value, "merged", "cropped", image_data.base_name)
# The CHM folder name encodes the resolution in centimetres ("8cm" for 0.08).
chm_folder_parts = [
    Folders.CHM.value,
    f"{round(resolution*100)}cm",
    "filtered",
    "merged",
    "cropped",
    image_data.coord_name,
]
chm_folder_path = os.path.join(*chm_folder_parts)

# --- Split definition: relative weight of each set, in a fixed order ---
split_ratio_by_set = {"training": 3, "validation": 1, "test": 1}
sets_names = list(split_ratio_by_set)
sets_ratios = list(split_ratio_by_set.values())
data_split_file_path = os.path.join(Folders.OTHERS_DIR.value, "data_split.json")
dismissed_classes = []

# Fixed seed so the split written to `data_split_file_path` is the same on every run.
create_and_save_splitted_datasets(
    rgb_cir_folder_path,
    chm_folder_path,
    annotations_folder_path,
    sets_ratios,
    sets_names,
    data_split_file_path,
    random_seed=0,
)

# --- Normalisation statistics ---
no_data_new_value = -5  # TODO: Variable to add to the Dataset!
mean_rgb, std_rgb = compute_mean_and_std(
    rgb_cir_folder_path,
    per_channel=True,
    replace_no_data=False,
)
mean_chm, std_chm = compute_mean_and_std(
    chm_folder_path,
    per_channel=False,
    replace_no_data=True,
    no_data_new_value=no_data_new_value,
)

# --- Datasets built from the saved split ---
# Keyword arguments are grouped by theme; the values in the last two groups
# come from `training_parameters`.
normalization_stats = {
    "mean_rgb": mean_rgb,
    "std_rgb": std_rgb,
    "mean_chm": mean_chm,
    "std_chm": std_chm,
}
drop_settings = {
    "proba_drop_rgb": proba_drop_rgb,
    "labels_transformation_drop_rgb": labels_transformation_drop_rgb,
    "proba_drop_chm": proba_drop_chm,
    "labels_transformation_drop_chm": labels_transformation_drop_chm,
}
training_transforms = {
    "transform_spatial_training": transform_spatial,
    "transform_pixel_rgb_training": transform_pixel_rgb,
    "transform_pixel_chm_training": transform_pixel_chm,
}

datasets = load_tree_datasets_from_split(
    data_split_file_path,
    class_indices,
    dismissed_classes=dismissed_classes,
    no_data_new_value=no_data_new_value,
    **normalization_stats,
    **drop_settings,
    **training_transforms,
)

# Training parameters

# Optimisation schedule: learning rate and number of epochs.
lr, epochs = 1e-2, 1000

# Data loading: samples per batch, and one loader worker per CPU reported by
# the multiprocessing module.
batch_size = 6
num_workers = mp.cpu_count()

# Forwarded as `accumulate` to train_and_validate.
# NOTE(review): presumably a gradient-accumulation setting — confirm in training.
accumulate = 12

In [None]:
# Name and checkpoint path for this new training run.
postfix = "multi_chm"
model_name, model_path = AMF_GD_YOLOv8.get_new_model_name_and_path(epochs, postfix)


# The model's input channel counts are taken from the training dataset.
training_set = datasets["training"]
model = AMF_GD_YOLOv8(
    training_set.rgb_channels,
    training_set.chm_channels,
    device=device,
    scale="n",
    class_names=class_names,
    name=model_name,
)
model = model.to(device)

print(f"{datasets['training'].rgb_channels = }")
print(f"{datasets['training'].chm_channels = }")

# Hyper-parameters and options forwarded to the training routine.
training_options = dict(
    lr=lr,
    epochs=epochs,
    batch_size=batch_size,
    num_workers=num_workers,
    accumulate=accumulate,
    device=device,
    save_outputs=False,
    show_training_metrics=True,
)
final_model = train_and_validate(model=model, datasets=datasets, **training_options)

# Persist only the weights (state dict) of the returned model.
state_dict = final_model.state_dict()
torch.save(state_dict, model_path)

In [None]:
# Disabled benchmark, kept for reference: for each even `num_workers` value it
# builds a TreeDataLoader on the training set, iterates over it for two epochs
# without doing any work, and prints the elapsed time. Uncomment to run.
# from time import time
# import multiprocessing as mp
# from training import TreeDataLoader

# batch_size = 8

# for num_workers in range(2, mp.cpu_count() + 2, 2):
#     train_loader = TreeDataLoader(
#         datasets["training"],
#         batch_size=batch_size,
#         shuffle=True,
#         num_workers=num_workers,
#         pin_memory=True,
#     )
#     start = time()
#     for epoch in range(1, 3):
#         for i, data in enumerate(train_loader, 0):
#             pass
#     end = time()
#     print(f"Finish with: {end - start} second, num_workers={num_workers}")

In [None]:
from training import test_save_output_image, initialize_dataloaders
from tqdm.notebook import tqdm

# Resolve the checkpoint to evaluate BEFORE building the model, so the model is
# created with the name that goes with the weights it is about to load (and so
# this cell does not rely on a `model_name` left over from the training cell).
# `epochs` and `postfix` still come from the earlier cells.
model_name, model_path = AMF_GD_YOLOv8.get_last_model_name_and_path(epochs, postfix)

# Fresh model with the same architecture arguments as the one that was trained.
model = AMF_GD_YOLOv8(
    datasets["training"].rgb_channels,
    datasets["training"].chm_channels,
    device=device,
    scale="n",
    class_names=class_names,
    name=model_name,
).to(device)

# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint saved from a GPU run can also be loaded on a CPU-only machine.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

# Only the third loader returned by initialize_dataloaders is needed here.
_, _, test_loader = initialize_dataloaders(
    datasets=datasets, batch_size=batch_size, num_workers=num_workers
)

# Per-configuration results, filled in by the loop below (one entry per test).
best_sorted_ious_list = []
best_aps_list = []
best_sorted_ap_list = []
best_conf_threshold__list = []

sorted_ap_lists = []
conf_thresholds_list = []

legend_list = []

# Confidence thresholds to evaluate: 10 log-spaced values from 1e-4 to 1e-1,
# followed by 19 evenly spaced values from 0.1 to 1.0.
thresholds_low = np.power(10, np.linspace(-4, -1, 10))
thresholds_high = np.linspace(0.1, 1.0, 19)
conf_thresholds = np.concatenate((thresholds_low, thresholds_high)).tolist()

# The four input configurations: each test passes a (no_rgb, no_chm) pair of flags.
no_rgbs = [False, False, True, True]
no_chms = [False, True, False, True]
test_names = ["all", "no_chm", "no_rgb", "no_chm_no_rgb"]

# Legend text keyed by the (no_rgb, no_chm) flags: it names the inputs that remain.
legend_by_flags = {
    (False, False): "RGB and CHM",
    (False, True): "RGB",
    (True, False): "CHM",
    (True, True): "No data",
}

output_dir = Folders.OUTPUT_DIR.value

pbar = tqdm(zip(no_rgbs, no_chms, test_names), total=len(no_rgbs))
for no_rgb, no_chm, test_name in pbar:
    legend = legend_by_flags[(no_rgb, no_chm)]
    pbar.set_description(legend)
    pbar.refresh()

    # Save the model output for this configuration as a GeoJSON file.
    test_save_output_image(
        model,
        test_loader,
        -1,
        device,
        no_rgb=no_rgb,
        no_chm=no_chm,
        save_path=os.path.join(output_dir, f"{model_name}_{test_name}.geojson"),
    )

    # Compute the metrics over all confidence thresholds; the per-configuration
    # figures are written to the two PNG paths.
    metrics = compute_metrics(
        model,
        test_loader,
        device,
        conf_thresholds=conf_thresholds,
        no_rgb=no_rgb,
        no_chm=no_chm,
        save_path_ap_iou=os.path.join(output_dir, f"{model_name}_ap_iou_{test_name}.png"),
        save_path_sap_conf=os.path.join(output_dir, f"{model_name}_sap_conf_{test_name}.png"),
    )
    (
        best_sorted_ious,
        best_aps,
        best_sorted_ap,
        best_conf_threshold,
        sorted_ious_list,
        aps_list,
        sorted_ap_list_2,
    ) = metrics

    # Results at the best confidence threshold.
    best_sorted_ious_list.append(best_sorted_ious)
    best_aps_list.append(best_aps)
    best_sorted_ap_list.append(best_sorted_ap)
    best_conf_threshold__list.append(best_conf_threshold)

    # Results across every confidence threshold.
    sorted_ap_lists.append(sorted_ap_list_2)
    conf_thresholds_list.append(conf_thresholds)

    legend_list.append(legend)

# Comparison figures across the four configurations.
plot_sorted_ap(
    best_sorted_ious_list,
    best_aps_list,
    best_sorted_ap_list,
    conf_thresholds=best_conf_threshold__list,
    legend_list=legend_list,
    show=True,
    save_path=os.path.join(output_dir, f"{model_name}_ap_iou.png"),
)

plot_sorted_ap_confs(
    sorted_ap_lists=sorted_ap_lists,
    conf_thresholds_list=conf_thresholds_list,
    legend_list=legend_list,
    show=True,
    save_path=os.path.join(output_dir, f"{model_name}_sap_conf.png"),
)