In [None]:
# Automatically reload edited project modules (such as component.script.utilities)
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
# %load_ext cudf.pandas
# import pandas as pd
# print(pd)


In [None]:
# Optimizations
# GDAL optimizations: configure threading and the block cache before any raster I/O.
import multiprocessing as mp
import os

cpu_count: int = mp.cpu_count()
# Leave two cores free for the kernel/OS, but never request fewer than one
# thread: on 1-2 core machines `cpu_count - 2` would be 0 or negative, which is
# not a meaningful GDAL_NUM_THREADS value.
num_cores: int = max(1, cpu_count - 2)
os.environ["GDAL_NUM_THREADS"] = f"{num_cores}"
# GDAL block cache size; small values (< 100000) are interpreted as megabytes.
os.environ["GDAL_CACHEMAX"] = "1024"


## Libraries

In [None]:
# Imports
from pathlib import Path
import pandas as pd
import numpy as np
import riskmapjnr as rmj
from typing import Literal


In [None]:
# Add root to path
import sys

sys.path.append("..")
from component.script.utilities.file_filter import (
    list_files_by_extension,
    filter_files_by_keywords,
)


## Set user parameters

In [None]:
# Name of the project sub-folder (under <repo>/data) holding this project's inputs and outputs.
project_name = "test"


In [None]:
# Reference years: years[0]->years[1] is the calibration period, years[1]->years[2]
# the validation period and years[0]->years[2] the historical period; the forecast
# uses dynamic variables taken at years[2].
years: list[int] = [2015, 2020, 2024]
forest_source: Literal["gfc", "tmf"] = "gfc"
tree_cover_threshold: int = 10

# Fail fast on invalid parameters instead of deep inside the modelling steps.
assert len(years) >= 3, "`years` needs at least three dates."
assert all(a < b for a, b in zip(years, years[1:])), "`years` must be strictly increasing."
assert forest_source in ("gfc", "tmf"), f"Unsupported forest_source: {forest_source!r}"


In [None]:
# Period-independent predictors: their rasters are matched by name only.
static_variables = ["altitude", "slope", "pa", "subj", "dist_rivers", "dist_roads"]
# Multi-temporal predictors: their rasters are matched by name plus the period's year(s).
dynamic_variables = ["forest", "deforestation", "forest_edge", "dist_towns"]


## Connect folders

In [None]:
# Assumes the kernel's working directory is one level below the repository root;
# all data then lives in <repo>/data.
root_folder: Path = Path.cwd().parent
downloads_folder: Path = root_folder.joinpath("data")
downloads_folder.mkdir(exist_ok=True, parents=True)


In [None]:
# Project layout: data/<project_name>/{data, plots, rmj_bm}
project_folder = downloads_folder / project_name
processed_data_folder = project_folder / "data"
plots_folder = project_folder / "plots"
rmj_bm = project_folder / "rmj_bm"
for _folder in (project_folder, processed_data_folder, plots_folder, rmj_bm):
    _folder.mkdir(parents=True, exist_ok=True)
del _folder


## Forest Files

In [None]:
# List all raster files in the processed data folder. The forest change, yearly
# forest and forest edge rasters are selected (and displayed) in the following
# cells, so they are not selected here a second time.
input_raster_files = list_files_by_extension(processed_data_folder, [".tiff", ".tif"])
if not input_raster_files:
    raise FileNotFoundError(f"No .tif/.tiff rasters found in {processed_data_folder}")


In [None]:
# Forest cover change stack for the selected source, excluding derived distance/edge layers.
forest_change_candidates = filter_files_by_keywords(
    input_raster_files,
    ["defostack", forest_source],
    False,
    ["distance", "edge"],
)
# Raise a descriptive error instead of a bare IndexError when nothing matches.
if not forest_change_candidates:
    raise FileNotFoundError(
        f"No forest change raster matching 'defostack' and '{forest_source}' "
        f"found in {processed_data_folder}"
    )
forest_change_file = forest_change_candidates[0]
forest_change_file


In [None]:
# Forest cover rasters for the selected source; "loss" and "edge" layers are excluded.
forest_yearly_files = filter_files_by_keywords(
    input_raster_files, ["forest", forest_source], False, ["loss", "edge"]
)
forest_yearly_files


In [None]:
# Forest edge rasters for the selected source, matched on "forest", the source and "edge".
forest_edge_files = filter_files_by_keywords(
    input_raster_files, ["forest", forest_source, "edge"], False
)
forest_edge_files


In [None]:
# Raster passed as `subj_file` to the vulnerability map (presumably subjurisdictions).
subj_candidates = filter_files_by_keywords(input_raster_files, ["subj"])
# Raise a descriptive error instead of a bare IndexError when nothing matches.
if not subj_candidates:
    raise FileNotFoundError(
        f"No raster with 'subj' in its name found in {processed_data_folder}"
    )
raster_subj_file = subj_candidates[0]
raster_subj_file


## Period dictionaries

In [None]:
import re
from pathlib import Path


def _is_token_separate_in_name(token: str, name: str) -> bool:
    """Return True if ``token`` occurs in ``name`` as a stand-alone word.

    A match must not be directly preceded or followed by an ASCII letter or
    digit. Unlike a regex word boundary, underscores and other punctuation act
    as separators, so ``"rivers"`` matches ``"dist_rivers.tif"``. Purely numeric
    tokens (e.g. years) are matched as a plain substring.
    """
    if token.isdigit():
        return token in name
    pattern = rf"(?<![0-9A-Za-z]){re.escape(token)}(?![0-9A-Za-z])"
    return re.search(pattern, name) is not None


def _strict_candidate_filter(candidates, tokens):
    """Return, in order, the candidates whose lower-cased path contains every
    token as a stand-alone word (see ``_is_token_separate_in_name``)."""
    lowered_tokens = [tok.lower() for tok in tokens]
    filtered = []
    for candidate in candidates:
        name = str(candidate).lower()
        if all(_is_token_separate_in_name(tok, name) for tok in lowered_tokens):
            filtered.append(candidate)
    return filtered


def create_full_period_dict(
    years: list[int],
    period: str,
    processed_data_folder: Path,
    static_variables: list[str],
    dynamic_variables: list[str],
) -> dict:
    """Build the configuration dictionary for one modelling period.

    The period fixes which entries of ``years`` are used, which values of the
    forest cover change raster count as deforestation, and which period's
    model is reused (``train_period``):

    * ``calibration``: years[0] -> years[1], defor_value 1, train_period calibration.
    * ``validation``: years[1] -> years[2], defor_value 1, train_period calibration.
    * ``historical``: years[0] -> years[2], defor_value [1, 2], train_period historical.
    * ``forecast``: years[0] -> years[2], defor_value [1, 2], train_period
      historical, dynamic variables taken at years[2].

    Each variable name is then matched to a raster in ``processed_data_folder``.
    Static variables are matched by name only. Dynamic variables are matched by
    name plus the period's initial year (initial and final year for
    ``deforestation``). For ``forecast`` they are matched by ``var_year``,
    falling back to years[1] when nothing matches.

    Args:
        years: At least three strictly increasing years; only the first three
            are used.
        period: One of "calibration", "validation", "historical", "forecast".
        processed_data_folder: Folder searched for ``.tif``/``.tiff`` rasters.
        static_variables: Names of period-independent variables.
        dynamic_variables: Names of multi-temporal variables.

    Returns:
        dict: ``period``, ``train_period``, ``initial_year``, ``final_year``,
        ``defor_value``, ``time_interval`` and ``var_year``, plus one key per
        variable mapped to the matched raster, or None when nothing matched.

    Raises:
        ValueError: If ``years`` has fewer than three entries or is not
            strictly increasing, or if ``period`` is unknown.
    """
    if len(years) < 3:
        raise ValueError("The 'years' list must contain at least three elements.")
    # Unordered years would yield negative time intervals and swapped periods.
    if any(earlier >= later for earlier, later in zip(years, years[1:])):
        raise ValueError(f"The 'years' list must be strictly increasing, got {years}.")

    configs = {
        "calibration": {
            "train_period": "calibration",
            "initial_idx": 0,
            "final_idx": 1,
            "defor_value": 1,
            "var_idx": 0,
        },
        "validation": {
            "train_period": "calibration",
            "initial_idx": 1,
            "final_idx": 2,
            "defor_value": 1,
            "var_idx": 1,
        },
        "historical": {
            "train_period": "historical",
            "initial_idx": 0,
            "final_idx": 2,
            "defor_value": [1, 2],
            "var_idx": 0,
        },
        "forecast": {
            "train_period": "historical",
            "initial_idx": 0,
            "final_idx": 2,
            "defor_value": [1, 2],
            "var_idx": 2,
        },
    }

    if period not in configs:
        raise ValueError(f"Unknown period '{period}'. Must be one of: {list(configs)}.")

    c = configs[period]

    # --- Base period dictionary ---
    period_dict = {
        "period": period,
        "train_period": c["train_period"],
        "initial_year": years[c["initial_idx"]],
        "final_year": years[c["final_idx"]],
        "defor_value": c["defor_value"],
        "time_interval": years[c["final_idx"]] - years[c["initial_idx"]],
        "var_year": years[c["var_idx"]],
    }

    # Years are matched as text inside file names.
    initial_year = str(period_dict["initial_year"])
    final_year = str(period_dict["final_year"])
    var_year = str(period_dict["var_year"])
    # Year tried for forecast rasters when none exist for var_year.
    fallback_year = str(years[1])

    input_raster_files = list_files_by_extension(
        processed_data_folder, [".tiff", ".tif"]
    )

    def _search(include_terms, exclude_terms):
        """Thin wrapper around filter_files_by_keywords with this function's fixed flags."""
        return filter_files_by_keywords(
            input_raster_files, include_terms, False, exclude_terms, True
        )

    def find_file(var_name, dynamic=False):
        """Return the raster matching ``var_name``, or None if nothing matches.

        Dynamic variables additionally require the period's year(s) in the name.
        When several files match, a file containing every name part as a
        stand-alone word is preferred; otherwise the first candidate is used.
        """
        parts = var_name.split("_")
        # Single-word names (e.g. "forest") must not pick up derived distance/edge layers.
        exclude_terms = ["distance", "edge"] if len(parts) == 1 else None

        def _terms(*year_terms):
            # Always build a new list so `parts` (reused by the strict filter) is never mutated.
            terms = [*parts, *year_terms]
            # "dist_*" variables are stored as "distance" rasters.
            if "dist" in parts and "distance" not in terms:
                terms.append("distance")
            return terms

        if not dynamic:
            include_terms = _terms()
        elif period == "forecast":
            include_terms = _terms(var_year)
        elif "deforestation" in parts:
            include_terms = _terms(initial_year, final_year)
        else:
            include_terms = _terms(initial_year)

        files = _search(include_terms, exclude_terms)
        # The year fallback only makes sense for multi-temporal variables.
        if not files and dynamic and period == "forecast":
            files = _search(_terms(fallback_year), exclude_terms)
        if not files:
            return None
        if len(files) == 1:
            return files[0]
        strict = _strict_candidate_filter(files, parts)
        # Heuristic: take the first strict match; if none, keep the first loose
        # match instead of silently dropping files that did match.
        return strict[0] if strict else files[0]

    # Dynamic variables are resolved last, so they win on a name clash (as before).
    for var in static_variables:
        period_dict[var] = find_file(var, dynamic=False)
    for var in dynamic_variables:
        period_dict[var] = find_file(var, dynamic=True)
    return period_dict


In [None]:
# One configuration per modelling period, unpacked in the same order as the names.
calibration_dict, validation_dict, historical_dict, forecast_dict = (
    create_full_period_dict(
        years,
        period_name,
        processed_data_folder,
        static_variables,
        dynamic_variables,
    )
    for period_name in ("calibration", "validation", "historical", "forecast")
)


In [None]:
# Inspect the calibration configuration: years, deforestation values and matched rasters.
calibration_dict


## 1 Compute the distance-to-forest-edge threshold

In [None]:
# Deforestation threshold passed as `defor_threshold` to rmj.dist_edge_threshold
# (presumably a percentage of deforestation).
deforestation_thresh = 99.5
# Exclusive upper bound of the distance bins scanned for the calibration
# (max_dist1) and historical (max_dist2) periods; presumably in metres.
max_dist1 = 5000
max_dist2 = 5000


In [None]:
def calculate_period_dist_edge_threshold(
    forest_change_file,
    period_dictionary,
    deforestation_thresh,
    max_dist,
    model_folder,
    plots_folder,
    dist_step=30,
    blk_rows=128,
):
    """Compute and save the distance-to-forest-edge threshold for one period.

    Calls ``rmj.dist_edge_threshold`` with the period's deforestation values
    and forest-edge raster. It writes the distance table to
    ``<model_folder>/<period>/tab_dist.csv`` and the figure to
    ``<plots_folder>/perc_dist_<period>.png``. The result is saved as a
    one-row CSV in ``<model_folder>/<period>/dist_edge_threshold.csv``.

    Args:
        forest_change_file: Forest cover change raster (``fcc_file``).
        period_dictionary: Period configuration from ``create_full_period_dict``;
            reads ``period``, ``defor_value`` and ``forest_edge``.
        deforestation_thresh: Value passed as ``defor_threshold``.
        max_dist: Exclusive upper bound of the distance bins.
        model_folder: Root output folder (``pathlib.Path``).
        plots_folder: Folder for figures (``pathlib.Path``).
        dist_step: Spacing between distance bins (default 30).
        blk_rows: Rows per processing block passed to riskmapjnr (default 128).

    Returns:
        The value returned by ``rmj.dist_edge_threshold``.
    """
    period = period_dictionary["period"]
    period_output_folder = model_folder / period
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    period_output_folder.mkdir(parents=True, exist_ok=True)
    dist_thresh = rmj.dist_edge_threshold(
        fcc_file=forest_change_file,
        defor_values=period_dictionary["defor_value"],
        defor_threshold=deforestation_thresh,
        dist_file=period_dictionary["forest_edge"],
        dist_bins=np.arange(0, max_dist, step=dist_step),
        tab_file_dist=period_output_folder / "tab_dist.csv",
        fig_file_dist=plots_folder / f"perc_dist_{period}.png",
        blk_rows=blk_rows,
        dist_file_available=True,
        check_fcc=True,
        verbose=True,
    )
    # Presumably a mapping of scalars (get_dist_thresh reads its "dist_thresh"
    # column back), stored as a one-row table.
    pd.DataFrame(dist_thresh, index=[0]).to_csv(
        period_output_folder / "dist_edge_threshold.csv",
        sep=",",
        header=True,
        index=False,
    )
    return dist_thresh


In [None]:
calculate_period_dist_edge_threshold(
    forest_change_file=forest_change_file,
    period_dictionary=calibration_dict,
    deforestation_thresh=deforestation_thresh,
    max_dist=max_dist1,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_dist_edge_threshold(
    forest_change_file=forest_change_file,
    period_dictionary=historical_dict,
    deforestation_thresh=deforestation_thresh,
    max_dist=max_dist2,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)


## 2 Compute bins

In [None]:
def get_dist_thresh(ifile):
    """Return the ``dist_thresh`` value stored in the first row of a CSV file."""
    return pd.read_csv(ifile).at[0, "dist_thresh"]


def save_dist_bins(dist_bins, output_file):
    """Write the distance bins to ``output_file``, one value per line, without a trailing newline."""
    content = "\n".join(map(str, dist_bins))
    with open(output_file, mode="w", encoding="utf-8") as handle:
        handle.write(content)


def get_dist_bins(dist_bins_file):
    """Read distance bins written one value per line (see ``save_dist_bins``).

    Args:
        dist_bins_file: Path of a text file with one number per line.

    Returns:
        list[float]: The bins, in file order. Blank lines (e.g. a trailing
        newline added when the file is edited) are skipped instead of making
        ``float("")`` raise ValueError.
    """
    with open(dist_bins_file, "r", encoding="utf-8") as f:
        # float() tolerates surrounding whitespace, so no explicit strip is needed.
        return [float(line) for line in f if line.strip()]


In [None]:
def calculate_period_dist_bins(period_dictionary, model_folder):
    """Compute and save the distance bins for one period.

    Reads the threshold saved by ``calculate_period_dist_edge_threshold`` in
    ``<model_folder>/<period>/dist_edge_threshold.csv`` and calls
    ``rmj.benchmark.compute_dist_bins`` with the period's forest-edge raster.
    The bins are written to ``dist_bins.csv`` in the same folder.

    Args:
        period_dictionary: Period configuration; reads ``period`` and ``forest_edge``.
        model_folder: Root output folder (``pathlib.Path``).

    Returns:
        The bins returned by ``rmj.benchmark.compute_dist_bins``.

    Raises:
        FileNotFoundError: If no forest-edge raster was matched for the period.
    """
    period_output_folder = model_folder / period_dictionary["period"]
    forest_edge_file = period_dictionary.get("forest_edge")
    if forest_edge_file is None:
        raise FileNotFoundError(
            f"No forest edge raster matched for period '{period_dictionary['period']}'."
        )
    # Read the threshold once and reuse it.
    dist_thresh = get_dist_thresh(period_output_folder / "dist_edge_threshold.csv")
    dist_bins = rmj.benchmark.compute_dist_bins(forest_edge_file, dist_thresh)
    save_dist_bins(dist_bins, period_output_folder / "dist_bins.csv")
    return dist_bins


In [None]:
calculate_period_dist_bins(
    period_dictionary=calibration_dict,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_dist_bins(
    period_dictionary=historical_dict,
    model_folder=rmj_bm,
)


## 3 Compute vulnerability map

In [None]:
def calculate_period_vulnerability_map(
    period_dictionary, raster_subj_file, model_folder
):
    """Compute the benchmark vulnerability map for one period.

    Uses the distance bins of the period's training period
    (``<model_folder>/<train_period>/dist_bins.csv``). The map is written to
    ``<model_folder>/<period>/prob_bm_<period>.tif``.

    Args:
        period_dictionary: Period configuration; reads ``period``,
            ``train_period``, ``forest`` and ``forest_edge``.
        raster_subj_file: Raster passed as ``subj_file``.
        model_folder: Root output folder (``pathlib.Path``).

    Returns:
        pathlib.Path: Path of the written vulnerability map.

    Raises:
        FileNotFoundError: If the training period's ``dist_bins.csv`` is missing.
    """
    period = period_dictionary["period"]
    train_period = period_dictionary["train_period"]
    period_output_folder = model_folder / period
    period_output_folder.mkdir(parents=True, exist_ok=True)
    dist_bins_file = model_folder / train_period / "dist_bins.csv"
    # Validation/forecast reuse another period's bins; make the ordering dependency explicit.
    if not dist_bins_file.exists():
        raise FileNotFoundError(
            f"Distance bins for training period '{train_period}' not found at "
            f"{dist_bins_file}; run calculate_period_dist_bins for it first."
        )
    vulnerability_map_file = period_output_folder / f"prob_bm_{period}.tif"

    rmj.benchmark.vulnerability_map(
        forest_file=period_dictionary["forest"],
        dist_file=period_dictionary["forest_edge"],
        dist_bins=get_dist_bins(dist_bins_file),
        subj_file=raster_subj_file,
        output_file=vulnerability_map_file,
        blk_rows=128,
        verbose=False,
    )
    return vulnerability_map_file


In [None]:
calculate_period_vulnerability_map(
    period_dictionary=calibration_dict,
    raster_subj_file=raster_subj_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_map(
    period_dictionary=historical_dict,
    raster_subj_file=raster_subj_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_map(
    period_dictionary=validation_dict,
    raster_subj_file=raster_subj_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_map(
    period_dictionary=forecast_dict,
    raster_subj_file=raster_subj_file,
    model_folder=rmj_bm,
)


## 4 Compute deforestation rate per vulnerability class

In [None]:
def calculate_period_vulnerability_classes(
    period_dictionary, forest_change_file, model_folder
):
    """Compute deforestation rates per vulnerability class for one period.

    Reads the period's vulnerability map
    (``<model_folder>/<period>/prob_bm_<period>.tif``) and writes the rates to
    ``<model_folder>/<period>/defrate_cat_bm_<period>.csv``. For the
    "validation" and "forecast" periods, the training period's
    ``defrate_cat_bm_<train_period>.csv`` is passed as ``deforate_model``.
    For all other periods ``deforate_model`` is None.

    Args:
        period_dictionary: Period configuration; reads ``period``,
            ``train_period`` and ``time_interval``.
        forest_change_file: Forest cover change raster (``fcc_file``).
        model_folder: Root output folder (``pathlib.Path``).

    Returns:
        pathlib.Path: Path of the written deforestation-rate table.

    Raises:
        FileNotFoundError: If a required training-period table is missing.
    """
    period = period_dictionary["period"]
    train_period = period_dictionary["train_period"]
    period_output_folder = model_folder / period
    vulnerability_file_path = period_output_folder / f"prob_bm_{period}.tif"
    if period in ("validation", "forecast"):
        deforate_model = model_folder / train_period / f"defrate_cat_bm_{train_period}.csv"
        # The training period must be processed first; fail with a clear hint.
        if not deforate_model.exists():
            raise FileNotFoundError(
                f"Deforestation rates of training period '{train_period}' not found at "
                f"{deforate_model}; run calculate_period_vulnerability_classes for it first."
            )
    else:
        deforate_model = None
    output_file = period_output_folder / f"defrate_cat_bm_{period}.csv"
    rmj.benchmark.defrate_per_class(
        fcc_file=forest_change_file,
        vulnerability_file=vulnerability_file_path,
        time_interval=period_dictionary["time_interval"],
        period=period,
        deforate_model=deforate_model,
        tab_file_defrate=output_file,
        blk_rows=128,
        verbose=False,
    )
    return output_file


In [None]:
calculate_period_vulnerability_classes(
    period_dictionary=calibration_dict,
    forest_change_file=forest_change_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_classes(
    period_dictionary=historical_dict,
    forest_change_file=forest_change_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_classes(
    period_dictionary=validation_dict,
    forest_change_file=forest_change_file,
    model_folder=rmj_bm,
)


In [None]:
calculate_period_vulnerability_classes(
    period_dictionary=forecast_dict,
    forest_change_file=forest_change_file,
    model_folder=rmj_bm,
)


## 5 Plot deforestation risk maps

In [None]:
# Area-of-interest shapefile, drawn as borders on the maps below.
input_vector_files = list_files_by_extension(processed_data_folder, [".shp"])
aoi_candidates = filter_files_by_keywords(input_vector_files, ["aoi"])
# Raise a descriptive error instead of a bare IndexError when nothing matches.
if not aoi_candidates:
    raise FileNotFoundError(
        f"No shapefile with 'aoi' in its name found in {processed_data_folder}"
    )
aoi_vector = aoi_candidates[0]


In [None]:
# Forest cover change map with the AOI boundary overlaid, saved as fcc123.png.
fcc_plot = str(plots_folder / "fcc123.png")
fig_fcc123 = rmj.plot.fcc123(
    input_fcc_raster=str(forest_change_file),
    output_file=fcc_plot,
    borders=aoi_vector,
    maxpixels=1e8,
    figsize=(5, 4),
    dpi=100,
    linewidth=0.2,
)


In [None]:
def calculate_period_plot_risk_map(
    period_dictionary, aoi_vector, model_folder, plots_folder
):
    """Plot the benchmark vulnerability (risk) map of one period.

    Reads ``<model_folder>/<period>/prob_bm_<period>.tif`` and saves the figure
    to ``<plots_folder>/prob_bm_<period>.png`` with ``aoi_vector`` as borders.

    Args:
        period_dictionary: Period configuration; reads ``period``.
        aoi_vector: Boundary vector file drawn on the map.
        model_folder: Root output folder (``pathlib.Path``).
        plots_folder: Folder for figures (``pathlib.Path``).

    Raises:
        FileNotFoundError: If the period's vulnerability map has not been computed.
    """
    period = period_dictionary["period"]
    input_map = model_folder / period / f"prob_bm_{period}.tif"
    if not input_map.exists():
        raise FileNotFoundError(
            f"Vulnerability map not found at {input_map}; "
            "run calculate_period_vulnerability_map first."
        )
    # The returned figure is not needed here; the plot is saved to output_file.
    rmj.benchmark.plot.vulnerability_map(
        input_map=str(input_map),
        maxpixels=1e8,
        output_file=str(plots_folder / f"prob_bm_{period}.png"),
        borders=aoi_vector,
        legend=True,
        figsize=(6, 5),
        dpi=100,
        linewidth=0.3,
    )


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=calibration_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=historical_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=validation_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)


In [None]:
calculate_period_plot_risk_map(
    period_dictionary=forecast_dict,
    aoi_vector=aoi_vector,
    model_folder=rmj_bm,
    plots_folder=plots_folder,
)
