In [None]:
import os
import sys

# Move the working directory to the nearest ancestor (or self) named "src" so the
# project's modules (utils, layers, training, ...) can be imported, then make the
# vendored Gold-YOLO sources importable as well.
# The search stops with an error at the filesystem root: there os.path.dirname()
# returns its input (like os.chdir("..") at "/"), which would otherwise loop forever.
start_dir = os.getcwd()
current_dir = start_dir
while os.path.basename(current_dir) != "src":
    parent_dir = os.path.dirname(current_dir)
    if parent_dir == current_dir:
        raise FileNotFoundError(f"No directory named 'src' found at or above {start_dir!r}")
    current_dir = parent_dir
os.chdir(current_dir)

# Resolved relative to `src`; guarded so re-running this cell does not pile up duplicates.
gold_yolo_path = os.path.abspath("Efficient-Computing/Detection/Gold-YOLO")
if gold_yolo_path not in sys.path:
    sys.path.append(gold_yolo_path)

In [None]:
from utils import create_all_folders, Folders

# Ensure the project's folder layout exists before anything is read or written
# (presumably creates the directories referenced by utils.Folders — see utils).
create_all_folders()

In [None]:
from layers import AMF_GD_YOLOv8
import torch
from preprocessing.data import ImageData
from training_parameters import (
    class_names,
    class_indices,
    class_colors,
    transform_pixel_rgb,
    transform_pixel_chm,
    transform_spatial,
    proba_drop_rgb,
    proba_drop_chm,
    labels_transformation_drop_chm,
    labels_transformation_drop_rgb,
)
from training import (
    train_and_validate,
    create_and_save_splitted_datasets,
    load_tree_datasets_from_split,
    compute_mean_and_std,
)
import multiprocessing as mp
from geojson_conversions import open_geojson_feature_collection
from preprocessing.rgb_cir import download_rgb_image_from_polygon

In [None]:
# Configure the CUDA caching allocator first, before anything queries torch.cuda,
# and echo the value that is now in effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "garbage_collection_threshold:0.6,max_split_size_mb:512"
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))

# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Find and split the data

# ---- Configuration -----------------------------------------------------------
annotations_file_name = "122000_484000.geojson"
# CHM resolution in metres: `round(resolution*100)` selects the "<N>cm" CHM folder.
resolution = 0.08

# Relative set sizes, paired element-wise with `sets_names`.
sets_ratios = [3, 1, 1]
sets_names = ["training", "validation", "test"]
data_split_file_path = "../data/others/data_split.json"
dismissed_classes = []

# ---- Full image covering the annotated area ----------------------------------
annotations_path = os.path.join(Folders.FULL_ANNOTS.value, annotations_file_name)
annotations = open_geojson_feature_collection(annotations_path)
downloaded_images = download_rgb_image_from_polygon(annotations["bbox"])
full_image_path_tif = downloaded_images[0]
image_data = ImageData(full_image_path_tif)

# ---- Folders holding the cropped tiles of that image -------------------------
annotations_folder_path = os.path.join(Folders.CROPPED_ANNOTS.value, image_data.base_name)
rgb_cir_folder_path = os.path.join(
    Folders.IMAGES.value, "merged", "cropped", image_data.base_name
)
chm_resolution_folder = f"{round(resolution*100)}cm"
chm_folder_path = os.path.join(
    Folders.CHM.value,
    chm_resolution_folder,
    "filtered",
    "merged",
    "cropped",
    image_data.coord_name,
)

# ---- Split the tiles into sets and write the split file ----------------------
create_and_save_splitted_datasets(
    rgb_cir_folder_path,
    chm_folder_path,
    annotations_folder_path,
    sets_ratios,
    sets_names,
    data_split_file_path,
    random_seed=0,
)

# ---- Normalization statistics ------------------------------------------------
# NOTE(review): these are computed from the whole tile folders, which presumably
# include the validation/test tiles as well — consider restricting to the
# training split to avoid leaking evaluation data into preprocessing.
mean_rgb, std_rgb = compute_mean_and_std(rgb_cir_folder_path, per_channel=True)
mean_chm, std_chm = compute_mean_and_std(chm_folder_path, per_channel=False)
print(f"{mean_rgb = }")
print(f"{mean_chm = }")

# ---- Datasets ----------------------------------------------------------------
normalization_stats = {
    "mean_rgb": mean_rgb,
    "std_rgb": std_rgb,
    "mean_chm": mean_chm,
    "std_chm": std_chm,
}
drop_settings = {
    "proba_drop_rgb": proba_drop_rgb,
    "labels_transformation_drop_rgb": labels_transformation_drop_rgb,
    "proba_drop_chm": proba_drop_chm,
    "labels_transformation_drop_chm": labels_transformation_drop_chm,
}
training_transforms = {
    "transform_spatial_training": transform_spatial,
    "transform_pixel_rgb_training": transform_pixel_rgb,
    "transform_pixel_chm_training": transform_pixel_chm,
}
datasets = load_tree_datasets_from_split(
    data_split_file_path,
    class_indices,
    class_colors,
    dismissed_classes=dismissed_classes,
    **normalization_stats,
    **drop_settings,
    **training_transforms,
)

# ---- Training parameters -----------------------------------------------------
lr = 1e-2
epochs = 400
batch_size = 1
accumulate = 12  # forwarded to train_and_validate (presumably gradient-accumulation steps)
num_workers = mp.cpu_count()

In [None]:
# Pick the first unused checkpoint name so earlier runs are never overwritten.
# The existence check probes the same folder the checkpoint is saved to at the
# end of this cell; probing a separately hard-coded path could miss an existing
# checkpoint and overwrite it.
models_dir = Folders.MODELS_AMF_GD_YOLOV8.value
index = 0
model_name = f"trained_model_{epochs}ep_{index}_multi_chm"
while os.path.exists(os.path.join(models_dir, f"{model_name}.pt")):
    index += 1
    model_name = f"trained_model_{epochs}ep_{index}_multi_chm"
state_dict_path = os.path.join(models_dir, f"{model_name}.pt")

# Input channel counts are taken from the training dataset so the network matches the data.
model = AMF_GD_YOLOv8(
    datasets["training"].rgb_channels,
    datasets["training"].chm_channels,
    device=device,
    scale="n",
    class_names=class_names,
    name=model_name,
).to(device)

print(f"{datasets['training'].rgb_channels = }")
print(f"{datasets['training'].chm_channels = }")

final_model = train_and_validate(
    model=model,
    datasets=datasets,
    lr=lr,
    epochs=epochs,
    batch_size=batch_size,
    num_workers=num_workers,
    accumulate=accumulate,
    device=device,
    save_outputs=True,
)

# Only the weights are saved; rebuild the model and call load_state_dict to reuse them.
state_dict = final_model.state_dict()
torch.save(state_dict, state_dict_path)

In [None]:
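# Optional DataLoader benchmark: times two passes over the training set for even
# `num_workers` values from 2 up to about the CPU count, to help choose `num_workers`.
# Left commented out because it is slow. Uncommenting it rebinds the globals
# `batch_size` (to 8) and `num_workers`, which the evaluation cell below also reads.
# from time import time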
# import multiprocessing as mp
# from training import TreeDataLoader

# batch_size = 8

# for num_workers in range(2, mp.cpu_count() + 2, 2):
#     train_loader = TreeDataLoader(
#         datasets["training"],
#         batch_size=batch_size,
#         shuffle=True,
#         num_workers=num_workers,
#         pin_memory=True,
#     )
#     start = time()
#     for epoch in range(1, 3):
#         for i, data in enumerate(train_loader, 0):
#             pass
#     end = time()
#     print(f"Finish with: {end - start} second, num_workers={num_workers}")

In [None]:
from training import test_save_output_image, initialize_dataloaders

# Evaluate the checkpoint written by the training cell. The path is derived from
# `model_name` and the same folder the training cell saves to, so the output files
# are named after the weights that are actually loaded. To evaluate another
# checkpoint, assign its name to `model_name` before this line.
state_dict_path = os.path.join(Folders.MODELS_AMF_GD_YOLOV8.value, f"{model_name}.pt")

# Build the network with the same input channel counts used for training; a fixed
# (3, 1) would not match a checkpoint trained with a different number of CHM channels.
model = AMF_GD_YOLOv8(
    datasets["training"].rgb_channels,
    datasets["training"].chm_channels,
    device=device,
    scale="n",
    class_names=class_names,
    name=model_name,
).to(device)

# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
state_dict = torch.load(state_dict_path, map_location=device)
model.load_state_dict(state_dict)

_, _, test_loader = initialize_dataloaders(
    datasets=datasets, batch_size=batch_size, num_workers=num_workers
)

# Run the test set once per modality-drop combination: output-file suffix -> (no_rgb, no_chm).
drop_configurations = {
    "all": (False, False),
    "no_rgb": (True, False),
    "no_chm": (False, True),
    "no_rgb_no_chm": (True, True),
}
for suffix, (no_rgb, no_chm) in drop_configurations.items():
    test_save_output_image(
        model,
        test_loader,
        -1,  # forwarded unchanged; its meaning is defined by training.test_save_output_image
        device,
        no_rgb=no_rgb,
        no_chm=no_chm,
        save_path=f"../data/others/model_output/{model_name}_{suffix}.geojson",
    )